Unverified commit f018d2b7, authored by Ilyas Moutawwakil, committed by GitHub

AMD ROCM Support (#315)

parent 8c78db4d
@@ -32,8 +32,8 @@ jobs:
           const script = require('.github/workflows/scripts/github_create_release.js')
           await script(github, context, core)
 
-  build_wheels:
-    name: Build AWQ
+  build_cuda_wheels:
+    name: Build AWQ with CUDA
     runs-on: ${{ matrix.os }}
     needs: release
@@ -119,3 +119,116 @@ jobs:
         with:
           upload_url: ${{ needs.release.outputs.upload_url }}
           asset_path: ./dist/*.whl
+
+  build_rocm_wheels:
+    name: Build AWQ with ROCm
+    runs-on: ${{ matrix.os }}
+    needs: release
+    strategy:
+      matrix:
+        os: [ubuntu-20.04]
+        python: ["3.8", "3.9", "3.10", "3.11"]
+        rocm: ["5.6.1", "5.7.1"]
+    defaults:
+      run:
+        shell: bash
+    env:
+      ROCM_VERSION: ${{ matrix.rocm }}
+    steps:
+      - uses: actions/checkout@v3
+
+      - name: Free Disk Space
+        run: |
+          df -h
+          echo "Removing large packages"
+          sudo apt-get remove -y '^dotnet-.*'
+          sudo apt-get remove -y 'php.*'
+          sudo apt-get remove -y azure-cli google-chrome-stable firefox powershell mono-devel
+          df -h
+          sudo apt-get autoremove -y >/dev/null 2>&1
+          sudo apt-get clean
+          sudo apt-get autoremove -y >/dev/null 2>&1
+          sudo apt-get autoclean -y >/dev/null 2>&1
+          df -h
+          echo "https://github.com/actions/virtual-environments/issues/709"
+          sudo rm -rf "$AGENT_TOOLSDIRECTORY"
+          df -h
+          echo "remove big /usr/local"
+          sudo rm -rf "/usr/local/share/boost"
+          sudo rm -rf /usr/local/lib/android >/dev/null 2>&1
+          df -h
+          sudo rm -rf /usr/share/dotnet/sdk > /dev/null 2>&1
+          sudo rm -rf /usr/share/dotnet/shared > /dev/null 2>&1
+          sudo rm -rf /usr/share/swift > /dev/null 2>&1
+          df -h
+
+      - uses: actions/setup-python@v3
+        with:
+          python-version: ${{ matrix.python }}
+
+      - name: Setup Mamba
+        uses: conda-incubator/setup-miniconda@v2.2.0
+        with:
+          activate-environment: "build"
+          python-version: ${{ matrix.python }}
+          mamba-version: "*"
+          use-mamba: false
+          channels: conda-forge,defaults
+          channel-priority: true
+          add-pip-as-python-dependency: true
+          auto-activate-base: false
+
+      - name: Set up ROCm
+        run: |
+          echo "Using python:"
+          python --version
+          which python
+
+          if [[ "${{ matrix.rocm }}" == "5.4.2" ]]; then
+            export ROCM_DL_FILE=amdgpu-install_5.4.50402-1_all.deb
+          elif [[ "${{ matrix.rocm }}" == "5.6.1" ]]; then
+            export ROCM_DL_FILE=amdgpu-install_5.6.50601-1_all.deb
+          elif [[ "${{ matrix.rocm }}" == "5.7.1" ]]; then
+            export ROCM_DL_FILE=amdgpu-install_5.7.50701-1_all.deb
+          else
+            echo Unknown rocm version
+            exit 1
+          fi
+
+          curl -O https://repo.radeon.com/amdgpu-install/${{ matrix.rocm }}/ubuntu/focal/$ROCM_DL_FILE
+          sudo dpkg -i $ROCM_DL_FILE
+          sudo DEBIAN_FRONTEND=noninteractive amdgpu-install --usecase=rocm --no-dkms --no-32 -y
+
+      - name: Install Dependencies
+        run: |
+          sudo apt-get update
+          sudo apt-get install -y --no-install-recommends rocsparse-dev rocthrust-dev rocblas-dev hipblas-dev hipsparse-dev
+
+          python -m pip install --upgrade build setuptools wheel
+
+          if [[ "${{ matrix.rocm }}" == "5.7.1" ]]; then
+            echo "Using PyTorch nightly"
+            python -m pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/rocm5.7
+          elif [[ "${{ matrix.rocm }}" == "5.6.1" ]]; then
+            echo "Using PyTorch stable"
+            python -m pip install torch --index-url https://download.pytorch.org/whl/rocm5.6
+          else
+            echo Unknown rocm version for python install
+            exit 1
+          fi
+
+      - name: Build Wheel
+        run: |
+          echo "Using python for build:"
+          python --version
+          which python
+
+          ROCM_VERSION=${{ matrix.rocm }} python setup.py sdist bdist_wheel
+
+      - name: Upload Assets
+        uses: shogo82148/actions-upload-release-asset@v1
+        with:
+          upload_url: ${{ needs.release.outputs.upload_url }}
+          asset_path: ./dist/*.whl
\ No newline at end of file
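The `ROCM_VERSION` environment variable exported in the matrix (and again in the Build Wheel step) is what setup.py further down in this diff uses to tag the wheel's local version. A small sketch of the resulting version strings for the two matrix entries, assuming the base version stays at 0.1.8:

```python
# Mirrors the version-suffix logic in setup.py for the ROCm entries of the matrix.
AUTOAWQ_VERSION = "0.1.8"

for rocm in ("5.6.1", "5.7.1"):
    suffix = "".join(rocm.split("."))[:3]     # "561" / "571"
    print(f"{AUTOAWQ_VERSION}+rocm{suffix}")  # 0.1.8+rocm561, 0.1.8+rocm571
```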
@@ -30,9 +30,11 @@ AutoAWQ is an easy-to-use package for 4-bit quantized models. AutoAWQ speeds up
 
 ### Prerequisites
 
-- Your GPU(s) must be of Compute Capability 7.5. Turing and later architectures are supported.
-- Your CUDA version must be CUDA 11.8 or later.
-- Requires installing [AutoAWQ kernels](https://github.com/casper-hansen/AutoAWQ_kernels).
+- NVIDIA:
+  - Your NVIDIA GPU(s) must be of Compute Capability 7.5. Turing and later architectures are supported.
+  - Your CUDA version must be CUDA 11.8 or later.
+- AMD:
+  - Your ROCm version must be ROCm 5.6 or later.
 
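A quick way to sanity-check these prerequisites is to ask PyTorch which backend it was built for. A minimal sketch, assuming `torch` is already installed (on ROCm builds `torch.version.hip` is set, `torch.version.cuda` is `None`, and the device is still reached through the `torch.cuda` API):

```python
import torch

# Exactly one of these is non-None for a given PyTorch build.
print("CUDA:", torch.version.cuda)  # e.g. "12.1" on a CUDA build, None on ROCm
print("ROCm:", torch.version.hip)   # e.g. "5.6.31061" on a ROCm build, None on CUDA

if torch.cuda.is_available():
    major, minor = torch.cuda.get_device_capability(0)
    # On NVIDIA, AutoAWQ needs Compute Capability >= 7.5 (Turing or newer).
    print("Device capability:", f"{major}.{minor}")
```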
 ### Install from PyPi
 
@@ -42,13 +44,21 @@ To install the newest AutoAWQ from PyPi, you need CUDA 12.1 installed.
 pip install autoawq
 ```
 
-If you cannot use CUDA 12.1, you can still use CUDA 11.8 and install the wheel from the [latest release](https://github.com/casper-hansen/AutoAWQ/releases).
+### Build from source
+
+For CUDA 11.8, ROCm 5.6, and ROCm 5.7, you can install wheels from the [release page](https://github.com/casper-hansen/AutoAWQ/releases/latest):
 
 ```
-pip install https://github.com/casper-hansen/AutoAWQ/releases/download/v0.1.6/autoawq-0.1.6+cu118-cp310-cp310-linux_x86_64.whl
+pip install autoawq@https://github.com/casper-hansen/AutoAWQ/releases/download/v0.2.0/autoawq-0.2.0+cu118-cp310-cp310-linux_x86_64.whl
 ```
 
-### Build from source
+Or from the main branch directly:
+
+```
+pip install autoawq@https://github.com/casper-hansen/AutoAWQ.git
+```
+
+Or by cloning the repository and installing from source:
 
 ```
 git clone https://github.com/casper-hansen/AutoAWQ
@@ -56,12 +66,16 @@ cd AutoAWQ
 pip install -e .
 ```
 
+All three methods will install the latest and correct kernels for your system from [AutoAWQ_Kernels](https://github.com/casper-hansen/AutoAWQ_kernels/releases).
+
+If your system is not supported (i.e. not on the release page), you can build the kernels yourself by following the instructions in [AutoAWQ_Kernels](https://github.com/casper-hansen/AutoAWQ_kernels/releases) and then install AutoAWQ from source.
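Whichever install route you take, a quick way to confirm the prebuilt kernels were picked up is to try importing the extension modules that appear elsewhere in this diff (`awq_ext`, `exl_ext`, `exlv2_ext`). A minimal post-install check:

```python
# Report which AutoAWQ kernel extensions are importable on this system.
import importlib

for ext in ("awq_ext", "exl_ext", "exlv2_ext"):
    try:
        importlib.import_module(ext)
        print(f"{ext}: OK")
    except ImportError as err:
        print(f"{ext}: not available ({err})")
```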
 ## Supported models
 
 The detailed support list:
 
 | Models | Sizes |
 | -------- | --------------------------- |
 | LLaMA-2 | 7B/13B/70B |
 | LLaMA | 7B/13B/30B/65B |
 | Mistral | 7B |
@@ -196,27 +210,27 @@ These benchmarks showcase the speed and memory usage of processing context (pref
 
 - 🟢 for GEMV, 🔵 for GEMM, 🔴 for avoid using
 
 | Model Name | Size | Version | Batch Size | Prefill Length | Decode Length | Prefill tokens/s | Decode tokens/s | Memory (VRAM) |
 | ---------- | ---- | ------- | ---------- | -------------- | ------------- | ---------------- | --------------- | ----------------- |
 | Vicuna | 7B | 🟢GEMV | 1 | 64 | 64 | 639.65 | 198.848 | 4.50 GB (19.05%) |
 | Vicuna | 7B | 🟢GEMV | 1 | 2048 | 2048 | 1123.63 | 133.191 | 6.15 GB (26.02%) |
 | ... | ... | ... | ... | ... | ... | ... | ... | ... |
 | Mistral | 7B | 🔵GEMM | 1 | 64 | 64 | 1093.35 | 156.317 | 4.35 GB (18.41%) |
 | Mistral | 7B | 🔵GEMM | 1 | 2048 | 2048 | 3897.02 | 114.355 | 5.55 GB (23.48%) |
 | Mistral | 7B | 🔵GEMM | 8 | 64 | 64 | 4199.18 | 1185.25 | 4.35 GB (18.41%) |
 | Mistral | 7B | 🔵GEMM | 8 | 2048 | 2048 | 3661.46 | 829.754 | 16.82 GB (71.12%) |
 | ... | ... | ... | ... | ... | ... | ... | ... | ... |
 | Mistral | 7B | 🟢GEMV | 1 | 64 | 64 | 531.99 | 188.29 | 4.28 GB (18.08%) |
 | Mistral | 7B | 🟢GEMV | 1 | 2048 | 2048 | 903.83 | 130.66 | 5.55 GB (23.48%) |
 | Mistral | 7B | 🔴GEMV | 8 | 64 | 64 | 897.87 | 486.46 | 4.33 GB (18.31%) |
 | Mistral | 7B | 🔴GEMV | 8 | 2048 | 2048 | 884.22 | 411.893 | 16.82 GB (71.12%) |
 | ... | ... | ... | ... | ... | ... | ... | ... | ... |
 | TinyLlama | 1B | 🟢GEMV | 1 | 64 | 64 | 1088.63 | 548.993 | 0.86 GB (3.62%) |
 | TinyLlama | 1B | 🟢GEMV | 1 | 2048 | 2048 | 5178.98 | 431.468 | 2.10 GB (8.89%) |
 | ... | ... | ... | ... | ... | ... | ... | ... | ... |
 | Llama 2 | 13B | 🔵GEMM | 1 | 64 | 64 | 820.34 | 96.74 | 8.47 GB (35.83%) |
 | Llama 2 | 13B | 🔵GEMM | 1 | 2048 | 2048 | 2279.41 | 73.8213 | 10.28 GB (43.46%) |
 | Llama 2 | 13B | 🔵GEMM | 3 | 64 | 64 | 1593.88 | 286.249 | 8.57 GB (36.24%) |
 | Llama 2 | 13B | 🔵GEMM | 3 | 2048 | 2048 | 2226.7 | 189.573 | 16.90 GB (71.47%) |
 | ... | ... | ... | ... | ... | ... | ... | ... | ... |
 | MPT | 7B | 🔵GEMM | 1 | 64 | 64 | 1079.06 | 161.344 | 3.67 GB (15.51%) |
 | MPT | 7B | 🔵GEMM | 1 | 2048 | 2048 | 4069.78 | 114.982 | 5.87 GB (24.82%) |
@@ -224,11 +238,11 @@ These benchmarks showcase the speed and memory usage of processing context (pref
 | Falcon | 7B | 🔵GEMM | 1 | 64 | 64 | 1139.93 | 133.585 | 4.47 GB (18.92%) |
 | Falcon | 7B | 🔵GEMM | 1 | 2048 | 2048 | 2850.97 | 115.73 | 6.83 GB (28.88%) |
 | ... | ... | ... | ... | ... | ... | ... | ... | ... |
 | CodeLlama | 34B | 🔵GEMM | 1 | 64 | 64 | 681.74 | 41.01 | 19.05 GB (80.57%) |
 | CodeLlama | 34B | 🔵GEMM | 1 | 2048 | 2048 | 1072.36 | 35.8316 | 20.26 GB (85.68%) |
 | ... | ... | ... | ... | ... | ... | ... | ... | ... |
 | DeepSeek | 33B | 🔵GEMM | 1 | 64 | 64 | 1160.18 | 40.29 | 18.92 GB (80.00%) |
 | DeepSeek | 33B | 🔵GEMM | 1 | 2048 | 2048 | 1012.1 | 34.0093 | 19.87 GB (84.02%) |
 
 ## Reference
......
@@ -306,12 +306,11 @@ class BaseAWQForCausalLM(nn.Module):
             # creates q4 handle
             model = exllama_post_init(model)
         elif use_exllama_v2:
-            # creates q4 handle and allocates scratch spaces wrt max_input_len and
-            # max_batch_size, which are hardcoded for now but might be worth interfacing
+            # creates q4 handle and allocates scratch spaces wrt max_input_len and max_batch_size
             model = exllamav2_post_init(
                 model,
-                max_input_len=max_new_tokens,
-                max_batch_size=int(os.getenv("AWQ_BATCH_SIZE", 1))
+                max_input_len=max_new_tokens or 2048,
+                max_batch_size=int(os.getenv("AWQ_BATCH_SIZE", 1)),
             )
 
         return self(
......
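Because `exllamav2_post_init` sizes its scratch buffers from `AWQ_BATCH_SIZE`, the environment variable has to be set before the model is loaded. A hedged sketch of how that could look from user code (it assumes the public `AutoAWQForCausalLM.from_quantized` entry point exposes the `use_exllama_v2` flag seen in this file; the model path is a placeholder):

```python
import os

# Must be set before loading: the ExLlamaV2 scratch space is sized from it.
os.environ["AWQ_BATCH_SIZE"] = "8"

from awq import AutoAWQForCausalLM

model = AutoAWQForCausalLM.from_quantized(
    "path/to/awq-quantized-model",  # placeholder path
    use_exllama_v2=True,            # assumed kwarg matching the branch above
)
```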
@@ -3,10 +3,12 @@ from torch import nn
 
 try:
     import awq_ext  # with CUDA kernels
+
     AWQ_INSTALLED = True
 except:
     AWQ_INSTALLED = False
 
+
 class FasterTransformerRMSNorm(nn.Module):
     def __init__(self, weight, eps=1e-6):
         super().__init__()
@@ -14,6 +16,12 @@ class FasterTransformerRMSNorm(nn.Module):
         self.variance_epsilon = eps
 
     def forward(self, x):
+        assert AWQ_INSTALLED, (
+            "AWQ kernels could not be loaded. "
+            "Please install them from https://github.com/casper-hansen/AutoAWQ_kernels"
+        )
+
         output = torch.empty_like(x)
         awq_ext.layernorm_forward_cuda(x, self.weight, output, self.variance_epsilon)
         return output
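The same guarded-import pattern recurs in every kernel-backed module this change touches: try to import the compiled extension, record whether it loaded, and raise a readable error at call time rather than at import time. A generic sketch of the pattern (the extension name and function here are illustrative, not the actual modules):

```python
# Generic version of the guard used in norm.py, exllama.py, exllamav2.py, gemm.py, gemv.py.
try:
    import some_compiled_ext  # illustrative name for a C++/CUDA extension

    EXT_INSTALLED = True
except ImportError:
    EXT_INSTALLED = False


def forward_with_ext(x):
    # Fail with an actionable message only when the kernel is actually needed.
    assert EXT_INSTALLED, (
        "Kernels could not be loaded. "
        "Please install them from https://github.com/casper-hansen/AutoAWQ_kernels"
    )
    return some_compiled_ext.do_something(x)  # illustrative call
```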
@@ -4,9 +4,10 @@ from awq.utils.packing_utils import unpack_reorder_pack
 
 try:
     import exl_ext  # with CUDA kernels (AutoAWQ_kernels)
+
-    AWQ_INSTALLED = True
+    EXL_INSTALLED = True
 except:
-    AWQ_INSTALLED = False
+    EXL_INSTALLED = False
 
 # Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension
 none_tensor = torch.empty((1, 1), device="meta")
@@ -103,6 +104,10 @@ class WQLinear_Exllama(nn.Module):
             "module.post_init() must be called before module.forward(). "
             "Use exllama_post_init() on the whole model."
         )
+        assert EXL_INSTALLED, (
+            "Exllama kernels could not be loaded. "
+            "Please install them from https://github.com/casper-hansen/AutoAWQ_kernels"
+        )
 
         input_dtype = x.dtype
         out_shape = x.shape[:-1] + (self.out_features,)
......
@@ -5,9 +5,11 @@ from awq.utils.packing_utils import unpack_reorder_pack
 
 try:
     import exlv2_ext  # with CUDA kernels (AutoAWQ_kernels)
+
-    AWQ_INSTALLED = True
+    EXLV2_INSTALLED = True
 except:
-    AWQ_INSTALLED = False
+    EXLV2_INSTALLED = False
 
 # Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension
 none_tensor = torch.empty((1, 1), device="meta")
@@ -131,6 +133,11 @@ class WQLinear_ExllamaV2(nn.Module):
             "module.post_init() must be called before module.forward(). "
             "Use exllamav2_post_init() on the whole model."
         )
+        assert EXLV2_INSTALLED, (
+            "Exllama kernels could not be loaded. "
+            "Please install them from https://github.com/casper-hansen/AutoAWQ_kernels"
+        )
+
         input_dtype = x.dtype
         out_shape = x.shape[:-1] + (self.out_features,)
......
@@ -4,7 +4,8 @@ from awq.utils.utils import get_best_device
 from awq.utils.packing_utils import dequantize_gemm
 
 try:
-    import awq_ext  # with CUDA kernels
+    import awq_ext  # with CUDA kernels (AutoAWQ_kernels)
+
     AWQ_INSTALLED = True
 except:
     AWQ_INSTALLED = False
@@ -153,7 +154,7 @@ class WQLinear_GEMM(nn.Module):
         x = x.half()
 
         if AWQ_INSTALLED:
-            FP16_MATMUL_HEURISTIC_CONDITION = x.shape[0]*x.shape[1] >= 1024
+            FP16_MATMUL_HEURISTIC_CONDITION = x.shape[0] * x.shape[1] >= 1024
 
             if FP16_MATMUL_HEURISTIC_CONDITION:
                 out = awq_ext.dequantize_weights_cuda(
@@ -163,12 +164,16 @@ class WQLinear_GEMM(nn.Module):
                     0,
                     0,
                     0,
-                    False
+                    False,
                 )
                 out = torch.matmul(x, out)
             else:
                 out = awq_ext.gemm_forward_cuda(
-                    x.reshape(-1, x.shape[-1]), self.qweight, self.scales, self.qzeros, 8
+                    x.reshape(-1, x.shape[-1]),
+                    self.qweight,
+                    self.scales,
+                    self.qzeros,
+                    8,
                 )
         else:
             out = dequantize_gemm(
@@ -176,7 +181,7 @@ class WQLinear_GEMM(nn.Module):
                 self.qzeros,
                 self.scales,
                 self.w_bit,
-                self.group_size
+                self.group_size,
             )
             out = torch.matmul(x, out)
......
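The `FP16_MATMUL_HEURISTIC_CONDITION` above switches strategy based on how many tokens are in the input: large inputs dequantize the weights once and go through a plain FP16 `torch.matmul`, while small inputs use the fused `gemm_forward_cuda` kernel. A small illustration with made-up shapes:

```python
# Illustrative only: how the token count drives the GEMM-vs-dequantize choice.
def uses_dequantize_path(batch_size: int, seq_len: int) -> bool:
    # Mirrors x.shape[0] * x.shape[1] >= 1024 from WQLinear_GEMM.forward
    return batch_size * seq_len >= 1024

print(uses_dequantize_path(1, 64))    # False -> fused awq_ext.gemm_forward_cuda
print(uses_dequantize_path(1, 2048))  # True  -> dequantize weights, then torch.matmul
print(uses_dequantize_path(8, 512))   # True
```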
@@ -3,6 +3,7 @@ import torch.nn as nn
 
 try:
     import awq_ext  # with CUDA kernels
+
     AWQ_INSTALLED = True
 except:
     AWQ_INSTALLED = False
@@ -158,6 +159,11 @@ class WQLinear_GEMV(nn.Module):
 
     @torch.no_grad()
     def forward(self, x):
+        assert AWQ_INSTALLED, (
+            "AWQ kernels could not be loaded. "
+            "Please install them from https://github.com/casper-hansen/AutoAWQ_kernels"
+        )
+
         out_shape = x.shape[:-1] + (self.out_features,)
         inputs = x.reshape(-1, x.shape[-1])
......
 import os
+import sys
 import torch
 import platform
+import requests
+import importlib.util
 from pathlib import Path
 from setuptools import setup, find_packages
 
-os.environ["CC"] = "g++"
-os.environ["CXX"] = "g++"
+
+def get_latest_kernels_version(repo):
+    """
+    Get the latest version of the kernels from the github repo.
+    """
+    response = requests.get(f"https://api.github.com/repos/{repo}/releases/latest")
+    data = response.json()
+    tag_name = data["tag_name"]
+    version = tag_name.replace("v", "")
+    return version
+
+
+def get_kernels_whl_url(
+    gpu_system_version,
+    release_version,
+    python_version,
+    platform,
+    architecture,
+):
+    """
+    Get the url for the kernels wheel file.
+    """
+    return f"https://github.com/casper-hansen/AutoAWQ_kernels/releases/download/v{release_version}/autoawq_kernels-{release_version}+{gpu_system_version}-cp{python_version}-cp{python_version}-{platform}_{architecture}.whl"
+
 
 AUTOAWQ_VERSION = "0.1.8"
 PYPI_BUILD = os.getenv("PYPI_BUILD", "0") == "1"
-HAS_CUDA = torch.cuda.is_available()
 
-if not PYPI_BUILD and HAS_CUDA:
-    try:
-        CUDA_VERSION = "".join(os.environ.get("CUDA_VERSION", torch.version.cuda).split("."))[:3]
+CUDA_VERSION = os.getenv("CUDA_VERSION", None) or torch.version.cuda
+if CUDA_VERSION:
+    CUDA_VERSION = "".join(CUDA_VERSION.split("."))[:3]
+
+ROCM_VERSION = os.getenv("ROCM_VERSION", None) or torch.version.hip
+if ROCM_VERSION:
+    if ROCM_VERSION.startswith("5.6"):
+        ROCM_VERSION = "5.6.1"
+    elif ROCM_VERSION.startswith("5.7"):
+        ROCM_VERSION = "5.7.1"
+    ROCM_VERSION = "".join(ROCM_VERSION.split("."))[:3]
+
+if not PYPI_BUILD:
+    if CUDA_VERSION:
         AUTOAWQ_VERSION += f"+cu{CUDA_VERSION}"
-    except Exception as ex:
-        raise RuntimeError("Your system must have an Nvidia GPU for installing AutoAWQ")
+    elif ROCM_VERSION:
+        AUTOAWQ_VERSION += f"+rocm{ROCM_VERSION}"
+    else:
+        raise RuntimeError(
+            "Your system must have either Nvidia or AMD GPU to build this package."
+        )
 
 common_setup_kwargs = {
     "version": AUTOAWQ_VERSION,
@@ -25,7 +64,9 @@ common_setup_kwargs = {
     "license": "MIT",
     "python_requires": ">=3.8.0",
     "description": "AutoAWQ implements the AWQ algorithm for 4-bit quantization with a 2x speedup during inference.",
-    "long_description": (Path(__file__).parent / "README.md").read_text(encoding="UTF-8"),
+    "long_description": (Path(__file__).parent / "README.md").read_text(
+        encoding="UTF-8"
+    ),
     "long_description_content_type": "text/markdown",
     "url": "https://github.com/casper-hansen/AutoAWQ",
     "keywords": ["awq", "autoawq", "quantization", "transformers"],
@@ -40,7 +81,7 @@ common_setup_kwargs = {
         "Programming Language :: Python :: 3.10",
         "Programming Language :: Python :: 3.11",
         "Programming Language :: C++",
-    ]
+    ],
 }
 
 requirements = [
@@ -51,21 +92,44 @@ requirements = [
     "datasets",
 ]
 
-# CUDA kernels
-if platform.system().lower() != "darwin" and HAS_CUDA:
+try:
+    importlib.metadata.version("autoawq-kernels")
+    KERNELS_INSTALLED = True
+except importlib.metadata.PackageNotFoundError:
+    KERNELS_INSTALLED = False
+
+# kernels can be downloaded from pypi for cuda+121 only
+# for everything else, we need to download the wheels from github
+if not KERNELS_INSTALLED and (CUDA_VERSION or ROCM_VERSION):
+    if CUDA_VERSION.startswith("12"):
         requirements.append("autoawq-kernels")
+    elif CUDA_VERSION.startswith("11") or ROCM_VERSION in ["561", "571"]:
+        gpu_system_version = (
+            f"cu{CUDA_VERSION}" if CUDA_VERSION else f"rocm{ROCM_VERSION}"
+        )
+        kernels_version = get_latest_kernels_version("casper-hansen/AutoAWQ_kernels")
+        python_version = "".join(platform.python_version_tuple()[:2])
+        platform_name = platform.system().lower()
+        architecture = platform.machine().lower()
+        latest_rocm_kernels_wheels = get_kernels_whl_url(
+            gpu_system_version,
+            kernels_version,
+            python_version,
+            platform_name,
+            architecture,
+        )
+        requirements.append(f"autoawq-kernels@{latest_rocm_kernels_wheels}")
+    else:
+        raise RuntimeError(
+            "Your system has a GPU with an unsupported CUDA or ROCm version. "
+            "Please install the kernels manually from https://github.com/casper-hansen/AutoAWQ_kernels"
+        )
 
 setup(
     packages=find_packages(),
     install_requires=requirements,
     extras_require={
-        "eval": [
-            "lm_eval>=0.4.0",
-            "tabulate",
-            "protobuf",
-            "evaluate",
-            "scipy"
-        ],
+        "eval": ["lm_eval>=0.4.0", "tabulate", "protobuf", "evaluate", "scipy"],
     },
-    **common_setup_kwargs
+    **common_setup_kwargs,
 )
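To make the wheel-selection logic above concrete, this is the URL setup.py composes for one hypothetical combination (Python 3.10, Linux x86_64, ROCm 5.6.1; the kernels release number is only an example):

```python
# Example of the URL format used by get_kernels_whl_url (values are illustrative).
release_version = "0.0.6"       # example AutoAWQ_kernels release tag (v0.0.6)
gpu_system_version = "rocm561"  # f"rocm{ROCM_VERSION}" for ROCm 5.6.1
python_version = "310"          # CPython 3.10
platform_name = "linux"
architecture = "x86_64"

url = (
    f"https://github.com/casper-hansen/AutoAWQ_kernels/releases/download/"
    f"v{release_version}/autoawq_kernels-{release_version}+{gpu_system_version}"
    f"-cp{python_version}-cp{python_version}-{platform_name}_{architecture}.whl"
)
print(url)
# .../v0.0.6/autoawq_kernels-0.0.6+rocm561-cp310-cp310-linux_x86_64.whl
```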