Unverified commit 79b6fbd8, authored by Casper and committed by GitHub
Browse files

Support Fused Mixtral on multi-GPU (#352)

parent 74053104
...@@ -111,6 +111,7 @@ jobs: ...@@ -111,6 +111,7 @@ jobs:
if ( $env:CUDA_VERSION -eq $env:PYPI_CUDA_VERSION ){ if ( $env:CUDA_VERSION -eq $env:PYPI_CUDA_VERSION ){
$env:PYPI_BUILD = 1 $env:PYPI_BUILD = 1
} }
$env:PYPI_FORCE_TAGS = 1
python setup.py sdist bdist_wheel python setup.py sdist bdist_wheel
...@@ -223,7 +224,7 @@ jobs: ...@@ -223,7 +224,7 @@ jobs:
python --version python --version
which python which python
ROCM_VERSION=${{ matrix.rocm }} python setup.py sdist bdist_wheel ROCM_VERSION=${{ matrix.rocm }} PYPI_FORCE_TAGS=1 python setup.py sdist bdist_wheel
- name: Upload Assets - name: Upload Assets
uses: shogo82148/actions-upload-release-asset@v1 uses: shogo82148/actions-upload-release-asset@v1
......
...@@ -127,7 +127,7 @@ class MixtralFuser: ...@@ -127,7 +127,7 @@ class MixtralFuser:
) )
sparse_moe = module.block_sparse_moe sparse_moe = module.block_sparse_moe
if isinstance(sparse_moe.experts[0].w1, WQLinear_GEMM) and torch.cuda.device_count() == 1: if isinstance(sparse_moe.experts[0].w1, WQLinear_GEMM):
fused_w1w3s = [ fused_w1w3s = [
fuse_linears( fuse_linears(
[ [
......
...@@ -132,6 +132,18 @@ if not KERNELS_INSTALLED and (CUDA_VERSION or ROCM_VERSION): ...@@ -132,6 +132,18 @@ if not KERNELS_INSTALLED and (CUDA_VERSION or ROCM_VERSION):
"Please install the kernels manually from https://github.com/casper-hansen/AutoAWQ_kernels" "Please install the kernels manually from https://github.com/casper-hansen/AutoAWQ_kernels"
) )
force_extension = os.getenv("PYPI_FORCE_TAGS", "0")
if force_extension == "1":
# NOTE: We create an empty CUDAExtension because torch helps us with
# creating the right boilerplate to enable correct targeting of
# the autoawq-kernels package
common_setup_kwargs["ext_modules"] = [
CUDAExtension(
name="test_kernel",
sources=[],
)
]
setup( setup(
packages=find_packages(), packages=find_packages(),
install_requires=requirements, install_requires=requirements,
...@@ -139,14 +151,5 @@ setup( ...@@ -139,14 +151,5 @@ setup(
"eval": ["lm_eval>=0.4.0", "tabulate", "protobuf", "evaluate", "scipy"], "eval": ["lm_eval>=0.4.0", "tabulate", "protobuf", "evaluate", "scipy"],
"dev": ["black", "mkdocstrings-python", "mkdocs-material", "griffe-typingdoc"] "dev": ["black", "mkdocstrings-python", "mkdocs-material", "griffe-typingdoc"]
}, },
# NOTE: We create an empty CUDAExtension because torch helps us with
# creating the right boilerplate to enable correct targeting of
# the autoawq-kernels package
ext_modules=[
CUDAExtension(
name="__build_artifact_for_awq_kernel_targeting",
sources=[],
)
],
**common_setup_kwargs, **common_setup_kwargs,
) )
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment