Unverified commit f018d2b7 authored by Ilyas Moutawwakil, committed by GitHub

AMD ROCM Support (#315)

parent 8c78db4d
......@@ -32,8 +32,8 @@ jobs:
const script = require('.github/workflows/scripts/github_create_release.js')
await script(github, context, core)
build_wheels:
name: Build AWQ
build_cuda_wheels:
name: Build AWQ with CUDA
runs-on: ${{ matrix.os }}
needs: release
......@@ -114,6 +114,119 @@ jobs:
python setup.py sdist bdist_wheel
- name: Upload Assets
uses: shogo82148/actions-upload-release-asset@v1
with:
upload_url: ${{ needs.release.outputs.upload_url }}
asset_path: ./dist/*.whl
build_rocm_wheels:
name: Build AWQ with ROCm
runs-on: ${{ matrix.os }}
needs: release
strategy:
matrix:
os: [ubuntu-20.04]
python: ["3.8", "3.9", "3.10", "3.11"]
rocm: ["5.6.1", "5.7.1"]
defaults:
run:
shell: bash
env:
ROCM_VERSION: ${{ matrix.rocm }}
steps:
- uses: actions/checkout@v3
- name: Free Disk Space
run: |
df -h
echo "Removing large packages"
sudo apt-get remove -y '^dotnet-.*'
sudo apt-get remove -y 'php.*'
sudo apt-get remove -y azure-cli google-chrome-stable firefox powershell mono-devel
df -h
sudo apt-get autoremove -y >/dev/null 2>&1
sudo apt-get clean
sudo apt-get autoremove -y >/dev/null 2>&1
sudo apt-get autoclean -y >/dev/null 2>&1
df -h
echo "https://github.com/actions/virtual-environments/issues/709"
sudo rm -rf "$AGENT_TOOLSDIRECTORY"
df -h
echo "remove big /usr/local"
sudo rm -rf "/usr/local/share/boost"
sudo rm -rf /usr/local/lib/android >/dev/null 2>&1
df -h
sudo rm -rf /usr/share/dotnet/sdk > /dev/null 2>&1
sudo rm -rf /usr/share/dotnet/shared > /dev/null 2>&1
sudo rm -rf /usr/share/swift > /dev/null 2>&1
df -h
- uses: actions/setup-python@v3
with:
python-version: ${{ matrix.python }}
- name: Setup Mamba
uses: conda-incubator/setup-miniconda@v2.2.0
with:
activate-environment: "build"
python-version: ${{ matrix.python }}
mamba-version: "*"
use-mamba: false
channels: conda-forge,defaults
channel-priority: true
add-pip-as-python-dependency: true
auto-activate-base: false
- name: Set up ROCm
run: |
echo "Using python:"
python --version
which python
if [[ "${{ matrix.rocm }}" == "5.4.2" ]]; then
export ROCM_DL_FILE=amdgpu-install_5.4.50402-1_all.deb
elif [[ "${{ matrix.rocm }}" == "5.6.1" ]]; then
export ROCM_DL_FILE=amdgpu-install_5.6.50601-1_all.deb
elif [[ "${{ matrix.rocm }}" == "5.7.1" ]]; then
export ROCM_DL_FILE=amdgpu-install_5.7.50701-1_all.deb
else
echo Unknown rocm version
exit 1
fi
curl -O https://repo.radeon.com/amdgpu-install/${{ matrix.rocm }}/ubuntu/focal/$ROCM_DL_FILE
sudo dpkg -i $ROCM_DL_FILE
sudo DEBIAN_FRONTEND=noninteractive amdgpu-install --usecase=rocm --no-dkms --no-32 -y
- name: Install Dependencies
run: |
sudo apt-get update
sudo apt-get install -y --no-install-recommends rocsparse-dev rocthrust-dev rocblas-dev hipblas-dev hipsparse-dev
python -m pip install --upgrade build setuptools wheel
if [[ "${{ matrix.rocm }}" == "5.7.1" ]]; then
echo "Using PyTorch nightly"
python -m pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/rocm5.7
elif [[ "${{ matrix.rocm }}" == "5.6.1" ]]; then
echo "Using PyTorch stable"
python -m pip install torch --index-url https://download.pytorch.org/whl/rocm5.6
else
echo Unknown rocm version for python install
exit 1
fi
- name: Build Wheel
run: |
echo "Using python for build:"
python --version
which python
ROCM_VERSION=${{ matrix.rocm }} python setup.py sdist bdist_wheel
- name: Upload Assets
uses: shogo82148/actions-upload-release-asset@v1
with:
......
......@@ -30,9 +30,11 @@ AutoAWQ is an easy-to-use package for 4-bit quantized models. AutoAWQ speeds up
### Prerequisites
- Your GPU(s) must be of Compute Capability 7.5. Turing and later architectures are supported.
- Your CUDA version must be CUDA 11.8 or later.
- Requires installing [AutoAWQ kernels](https://github.com/casper-hansen/AutoAWQ_kernels).
- NVIDIA:
- Your NVIDIA GPU(s) must be of Compute Capability 7.5. Turing and later architectures are supported.
- Your CUDA version must be CUDA 11.8 or later.
- AMD:
- Your ROCm version must be ROCm 5.6 or later.
### Install from PyPi
......@@ -42,13 +44,21 @@ To install the newest AutoAWQ from PyPi, you need CUDA 12.1 installed.
pip install autoawq
```
If you cannot use CUDA 12.1, you can still use CUDA 11.8 and install the wheel from the [latest release](https://github.com/casper-hansen/AutoAWQ/releases).
### Build from source
For CUDA 11.8, ROCm 5.6, and ROCm 5.7, you can install wheels from the [release page](https://github.com/casper-hansen/AutoAWQ/releases/latest):
```
pip install https://github.com/casper-hansen/AutoAWQ/releases/download/v0.1.6/autoawq-0.1.6+cu118-cp310-cp310-linux_x86_64.whl
pip install autoawq@https://github.com/casper-hansen/AutoAWQ/releases/download/v0.2.0/autoawq-0.2.0+cu118-cp310-cp310-linux_x86_64.whl
```
### Build from source
Or from the main branch directly:
```
pip install autoawq@https://github.com/casper-hansen/AutoAWQ.git
```
Or by cloning the repository and installing from source:
```
git clone https://github.com/casper-hansen/AutoAWQ
......@@ -56,12 +66,16 @@ cd AutoAWQ
pip install -e .
```
All three methods will install the latest and correct kernels for your system from [AutoAWQ_Kernels](https://github.com/casper-hansen/AutoAWQ_kernels/releases).
If your system is not supported (i.e. not on the release page), you can build the kernels yourself by following the instructions in [AutoAWQ_Kernels](https://github.com/casper-hansen/AutoAWQ_kernels/releases) and then install AutoAWQ from source.
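A quick way to confirm that the kernels are usable is to import the extension modules AutoAWQ loads internally (`awq_ext`, `exl_ext`, `exlv2_ext`, the names seen in the diffs below). This is a minimal diagnostic sketch, not an official tool:
```
# Minimal diagnostic sketch: the module names below are the extensions that
# AutoAWQ imports internally; an ImportError means the corresponding kernels
# from AutoAWQ_kernels are not installed in this environment.
for name in ("awq_ext", "exl_ext", "exlv2_ext"):
    try:
        __import__(name)
        print(f"{name}: available")
    except ImportError as err:
        print(f"{name}: missing ({err})")
```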
## Supported models
The detailed support list:
| Models | Sizes |
| -------- | --------------------------- |
| LLaMA-2 | 7B/13B/70B |
| LLaMA | 7B/13B/30B/65B |
| Mistral | 7B |
......@@ -195,40 +209,40 @@ These benchmarks showcase the speed and memory usage of processing context (pref
- Command: `python examples/benchmark.py --model_path <hf_model> --batch_size 1`
- 🟢 for GEMV, 🔵 for GEMM, 🔴 for avoid using
| Model Name | Size | Version | Batch Size | Prefill Length | Decode Length | Prefill tokens/s | Decode tokens/s | Memory (VRAM) |
| ---------- | ---- | ------- | ---------- | -------------- | ------------- | ---------------- | --------------- | ----------------- |
| Vicuna | 7B | 🟢GEMV | 1 | 64 | 64 | 639.65 | 198.848 | 4.50 GB (19.05%) |
| Vicuna | 7B | 🟢GEMV | 1 | 2048 | 2048 | 1123.63 | 133.191 | 6.15 GB (26.02%) |
| ... | ... | ... | ... | ... | ... | ... | ... | ... |
| Mistral | 7B | 🔵GEMM | 1 | 64 | 64 | 1093.35 | 156.317 | 4.35 GB (18.41%) |
| Mistral | 7B | 🔵GEMM | 1 | 2048 | 2048 | 3897.02 | 114.355 | 5.55 GB (23.48%) |
| Mistral | 7B | 🔵GEMM | 8 | 64 | 64 | 4199.18 | 1185.25 | 4.35 GB (18.41%) |
| Mistral | 7B | 🔵GEMM | 8 | 2048 | 2048 | 3661.46 | 829.754 | 16.82 GB (71.12%) |
| ... | ... | ... | ... | ... | ... | ... | ... | ... |
| Mistral | 7B | 🟢GEMV | 1 | 64 | 64 | 531.99 | 188.29 | 4.28 GB (18.08%) |
| Mistral | 7B | 🟢GEMV | 1 | 2048 | 2048 | 903.83 | 130.66 | 5.55 GB (23.48%) |
| Mistral | 7B | 🔴GEMV | 8 | 64 | 64 | 897.87 | 486.46 | 4.33 GB (18.31%) |
| Mistral | 7B | 🔴GEMV | 8 | 2048 | 2048 | 884.22 | 411.893 | 16.82 GB (71.12%) |
| ... | ... | ... | ... | ... | ... | ... | ... | ... |
| TinyLlama | 1B | 🟢GEMV | 1 | 64 | 64 | 1088.63 | 548.993 | 0.86 GB (3.62%) |
| TinyLlama | 1B | 🟢GEMV | 1 | 2048 | 2048 | 5178.98 | 431.468 | 2.10 GB (8.89%) |
| ... | ... | ... | ... | ... | ... | ... | ... | ... |
| Llama 2 | 13B | 🔵GEMM | 1 | 64 | 64 | 820.34 | 96.74 | 8.47 GB (35.83%) |
| Llama 2 | 13B | 🔵GEMM | 1 | 2048 | 2048 | 2279.41 | 73.8213 | 10.28 GB (43.46%) |
| Llama 2 | 13B | 🔵GEMM | 3 | 64 | 64 | 1593.88 | 286.249 | 8.57 GB (36.24%) |
| Llama 2 | 13B | 🔵GEMM | 3 | 2048 | 2048 | 2226.7 | 189.573 | 16.90 GB (71.47%) |
| ... | ... | ... | ... | ... | ... | ... | ... | ... |
| MPT | 7B | 🔵GEMM | 1 | 64 | 64 | 1079.06 | 161.344 | 3.67 GB (15.51%) |
| MPT | 7B | 🔵GEMM | 1 | 2048 | 2048 | 4069.78 | 114.982 | 5.87 GB (24.82%) |
| ... | ... | ... | ... | ... | ... | ... | ... | ... |
| Falcon | 7B | 🔵GEMM | 1 | 64 | 64 | 1139.93 | 133.585 | 4.47 GB (18.92%) |
| Falcon | 7B | 🔵GEMM | 1 | 2048 | 2048 | 2850.97 | 115.73 | 6.83 GB (28.88%) |
| ... | ... | ... | ... | ... | ... | ... | ... | ... |
| CodeLlama | 34B | 🔵GEMM | 1 | 64 | 64 | 681.74 | 41.01 | 19.05 GB (80.57%) |
| CodeLlama | 34B | 🔵GEMM | 1 | 2048 | 2048 | 1072.36 | 35.8316 | 20.26 GB (85.68%) |
| ... | ... | ... | ... | ... | ... | ... | ... | ... |
| DeepSeek | 33B | 🔵GEMM | 1 | 64 | 64 | 1160.18 | 40.29 | 18.92 GB (80.00%) |
| DeepSeek | 33B | 🔵GEMM | 1 | 2048 | 2048 | 1012.1 | 34.0093 | 19.87 GB (84.02%) |
## Reference
......
......@@ -301,17 +301,16 @@ class BaseAWQForCausalLM(nn.Module):
# Dispatch to devices
if fuse_layers:
self.fuse_layers(model)
if use_exllama:
# creates q4 handle
model = exllama_post_init(model)
elif use_exllama_v2:
# creates q4 handle and allocates scratch spaces wrt max_input_len and
# max_batch_size, which are hardcoded for now but might be worth interfacing
# creates q4 handle and allocates scratch spaces wrt max_input_len and max_batch_size
model = exllamav2_post_init(
model,
max_input_len=max_new_tokens,
max_batch_size=int(os.getenv("AWQ_BATCH_SIZE", 1))
max_input_len=max_new_tokens or 2048,
max_batch_size=int(os.getenv("AWQ_BATCH_SIZE", 1)),
)
return self(
......
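For context, a hedged usage sketch of the loading path touched by this hunk. It assumes `from_quantized()` exposes the `use_exllama_v2` and `max_new_tokens` options referenced above and that `AWQ_BATCH_SIZE` sizes the ExLlamaV2 scratch space, as in the diff; the model id is only an example.
```
# Hedged sketch, not the canonical API: parameter names are taken from the
# diff above (use_exllama_v2, max_new_tokens, AWQ_BATCH_SIZE); the checkpoint
# name is illustrative.
import os
from awq import AutoAWQForCausalLM

os.environ["AWQ_BATCH_SIZE"] = "1"  # becomes max_batch_size for the q4 scratch space

model = AutoAWQForCausalLM.from_quantized(
    "TheBloke/Mistral-7B-v0.1-AWQ",  # example quantized checkpoint
    fuse_layers=False,               # fusion, if enabled, runs before the ExLlama post-init above
    use_exllama_v2=True,             # routes the weights through exllamav2_post_init()
    max_new_tokens=2048,             # used as max_input_len (falls back to 2048 when unset)
)
```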
......@@ -3,10 +3,12 @@ from torch import nn
try:
import awq_ext # with CUDA kernels
AWQ_INSTALLED = True
except:
AWQ_INSTALLED = False
class FasterTransformerRMSNorm(nn.Module):
def __init__(self, weight, eps=1e-6):
super().__init__()
......@@ -14,6 +16,12 @@ class FasterTransformerRMSNorm(nn.Module):
self.variance_epsilon = eps
def forward(self, x):
assert AWQ_INSTALLED, (
"AWQ kernels could not be loaded. "
"Please install them from https://github.com/casper-hansen/AutoAWQ_kernels"
)
output = torch.empty_like(x)
awq_ext.layernorm_forward_cuda(x, self.weight, output, self.variance_epsilon)
return output
......@@ -4,9 +4,10 @@ from awq.utils.packing_utils import unpack_reorder_pack
try:
import exl_ext # with CUDA kernels (AutoAWQ_kernels)
AWQ_INSTALLED = True
EXL_INSTALLED = True
except:
AWQ_INSTALLED = False
EXL_INSTALLED = False
# Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension
none_tensor = torch.empty((1, 1), device="meta")
......@@ -103,6 +104,10 @@ class WQLinear_Exllama(nn.Module):
"module.post_init() must be called before module.forward(). "
"Use exllama_post_init() on the whole model."
)
assert EXL_INSTALLED, (
"Exllama kernels could not be loaded. "
"Please install them from https://github.com/casper-hansen/AutoAWQ_kernels"
)
input_dtype = x.dtype
out_shape = x.shape[:-1] + (self.out_features,)
......
......@@ -5,9 +5,11 @@ from awq.utils.packing_utils import unpack_reorder_pack
try:
import exlv2_ext # with CUDA kernels (AutoAWQ_kernels)
AWQ_INSTALLED = True
EXLV2_INSTALLED = True
except:
AWQ_INSTALLED = False
EXLV2_INSTALLED = False
# Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension
none_tensor = torch.empty((1, 1), device="meta")
......@@ -131,6 +133,11 @@ class WQLinear_ExllamaV2(nn.Module):
"module.post_init() must be called before module.forward(). "
"Use exllamav2_post_init() on the whole model."
)
assert EXLV2_INSTALLED, (
"Exllama kernels could not be loaded. "
"Please install them from https://github.com/casper-hansen/AutoAWQ_kernels"
)
input_dtype = x.dtype
out_shape = x.shape[:-1] + (self.out_features,)
......
......@@ -4,7 +4,8 @@ from awq.utils.utils import get_best_device
from awq.utils.packing_utils import dequantize_gemm
try:
import awq_ext # with CUDA kernels
import awq_ext # with CUDA kernels (AutoAWQ_kernels)
AWQ_INSTALLED = True
except:
AWQ_INSTALLED = False
......@@ -125,7 +126,7 @@ class WQLinear_GEMM(nn.Module):
if "mps" in best_device:
zeros = zeros.to("cpu")
qzeros = torch.zeros(
(zeros.shape[0], zeros.shape[1] // 32 * awq_linear.w_bit),
dtype=torch.int32,
......@@ -153,7 +154,7 @@ class WQLinear_GEMM(nn.Module):
x = x.half()
if AWQ_INSTALLED:
FP16_MATMUL_HEURISTIC_CONDITION = x.shape[0]*x.shape[1] >= 1024
FP16_MATMUL_HEURISTIC_CONDITION = x.shape[0] * x.shape[1] >= 1024
if FP16_MATMUL_HEURISTIC_CONDITION:
out = awq_ext.dequantize_weights_cuda(
......@@ -163,12 +164,16 @@ class WQLinear_GEMM(nn.Module):
0,
0,
0,
False
False,
)
out = torch.matmul(x, out)
else:
out = awq_ext.gemm_forward_cuda(
x.reshape(-1, x.shape[-1]), self.qweight, self.scales, self.qzeros, 8
x.reshape(-1, x.shape[-1]),
self.qweight,
self.scales,
self.qzeros,
8,
)
else:
out = dequantize_gemm(
......@@ -176,7 +181,7 @@ class WQLinear_GEMM(nn.Module):
self.qzeros,
self.scales,
self.w_bit,
self.group_size
self.group_size,
)
out = torch.matmul(x, out)
......
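The heuristic in this hunk chooses between two kernels: when batch times sequence length is at least 1024 (typically prefill), the weights are dequantized once and multiplied with a regular FP16 matmul; smaller inputs (typically single-token decode) go through the fused 4-bit GEMM kernel. A minimal illustration of that dispatch, with made-up shapes:
```
# Illustration only: mirrors FP16_MATMUL_HEURISTIC_CONDITION from the hunk
# above; shapes are made up and no CUDA/ROCm device is needed to run it.
import torch

def pick_gemm_path(x: torch.Tensor) -> str:
    # x: activations of shape (batch, seq_len, in_features)
    if x.shape[0] * x.shape[1] >= 1024:
        return "dequantize weights, then FP16 torch.matmul"
    return "fused awq_ext.gemm_forward_cuda"

prefill = torch.empty(1, 2048, 4096, dtype=torch.float16)  # long prompt
decode = torch.empty(1, 1, 4096, dtype=torch.float16)      # one decode step
print(pick_gemm_path(prefill))  # dequantize weights, then FP16 torch.matmul
print(pick_gemm_path(decode))   # fused awq_ext.gemm_forward_cuda
```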
......@@ -3,6 +3,7 @@ import torch.nn as nn
try:
import awq_ext # with CUDA kernels
AWQ_INSTALLED = True
except:
AWQ_INSTALLED = False
......@@ -158,6 +159,11 @@ class WQLinear_GEMV(nn.Module):
@torch.no_grad()
def forward(self, x):
assert AWQ_INSTALLED, (
"AWQ kernels could not be loaded. "
"Please install them from https://github.com/casper-hansen/AutoAWQ_kernels"
)
out_shape = x.shape[:-1] + (self.out_features,)
inputs = x.reshape(-1, x.shape[-1])
......
import os
import sys
import torch
import platform
import requests
import importlib.util
from pathlib import Path
from setuptools import setup, find_packages
os.environ["CC"] = "g++"
os.environ["CXX"] = "g++"
def get_latest_kernels_version(repo):
"""
Get the latest version of the kernels from the github repo.
"""
response = requests.get(f"https://api.github.com/repos/{repo}/releases/latest")
data = response.json()
tag_name = data["tag_name"]
version = tag_name.replace("v", "")
return version
def get_kernels_whl_url(
gpu_system_version,
release_version,
python_version,
platform,
architecture,
):
"""
Get the url for the kernels wheel file.
"""
return f"https://github.com/casper-hansen/AutoAWQ_kernels/releases/download/v{release_version}/autoawq_kernels-{release_version}+{gpu_system_version}-cp{python_version}-cp{python_version}-{platform}_{architecture}.whl"
AUTOAWQ_VERSION = "0.1.8"
PYPI_BUILD = os.getenv("PYPI_BUILD", "0") == "1"
HAS_CUDA = torch.cuda.is_available()
if not PYPI_BUILD and HAS_CUDA:
try:
CUDA_VERSION = "".join(os.environ.get("CUDA_VERSION", torch.version.cuda).split("."))[:3]
CUDA_VERSION = os.getenv("CUDA_VERSION", None) or torch.version.cuda
if CUDA_VERSION:
CUDA_VERSION = "".join(CUDA_VERSION.split("."))[:3]
ROCM_VERSION = os.getenv("ROCM_VERSION", None) or torch.version.hip
if ROCM_VERSION:
if ROCM_VERSION.startswith("5.6"):
ROCM_VERSION = "5.6.1"
elif ROCM_VERSION.startswith("5.7"):
ROCM_VERSION = "5.7.1"
ROCM_VERSION = "".join(ROCM_VERSION.split("."))[:3]
if not PYPI_BUILD:
if CUDA_VERSION:
AUTOAWQ_VERSION += f"+cu{CUDA_VERSION}"
except Exception as ex:
raise RuntimeError("Your system must have an Nvidia GPU for installing AutoAWQ")
elif ROCM_VERSION:
AUTOAWQ_VERSION += f"+rocm{ROCM_VERSION}"
else:
raise RuntimeError(
"Your system must have either Nvidia or AMD GPU to build this package."
)
common_setup_kwargs = {
"version": AUTOAWQ_VERSION,
......@@ -25,7 +64,9 @@ common_setup_kwargs = {
"license": "MIT",
"python_requires": ">=3.8.0",
"description": "AutoAWQ implements the AWQ algorithm for 4-bit quantization with a 2x speedup during inference.",
"long_description": (Path(__file__).parent / "README.md").read_text(encoding="UTF-8"),
"long_description": (Path(__file__).parent / "README.md").read_text(
encoding="UTF-8"
),
"long_description_content_type": "text/markdown",
"url": "https://github.com/casper-hansen/AutoAWQ",
"keywords": ["awq", "autoawq", "quantization", "transformers"],
......@@ -40,7 +81,7 @@ common_setup_kwargs = {
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: C++",
]
],
}
requirements = [
......@@ -51,21 +92,44 @@ requirements = [
"datasets",
]
# CUDA kernels
if platform.system().lower() != "darwin" and HAS_CUDA:
requirements.append("autoawq-kernels")
try:
importlib.metadata.version("autoawq-kernels")
KERNELS_INSTALLED = True
except importlib.metadata.PackageNotFoundError:
KERNELS_INSTALLED = False
# kernels can be downloaded from PyPI for CUDA 12.1 only
# for everything else, we need to download the wheels from GitHub
if not KERNELS_INSTALLED and (CUDA_VERSION or ROCM_VERSION):
if CUDA_VERSION.startswith("12"):
requirements.append("autoawq-kernels")
elif CUDA_VERSION.startswith("11") or ROCM_VERSION in ["561", "571"]:
gpu_system_version = (
f"cu{CUDA_VERSION}" if CUDA_VERSION else f"rocm{ROCM_VERSION}"
)
kernels_version = get_latest_kernels_version("casper-hansen/AutoAWQ_kernels")
python_version = "".join(platform.python_version_tuple()[:2])
platform_name = platform.system().lower()
architecture = platform.machine().lower()
latest_rocm_kernels_wheels = get_kernels_whl_url(
gpu_system_version,
kernels_version,
python_version,
platform_name,
architecture,
)
requirements.append(f"autoawq-kernels@{latest_rocm_kernels_wheels}")
else:
raise RuntimeError(
"Your system have a GPU with an unsupported CUDA or ROCm version. "
"Please install the kernels manually from https://github.com/casper-hansen/AutoAWQ_kernels"
)
setup(
packages=find_packages(),
install_requires=requirements,
extras_require={
"eval": [
"lm_eval>=0.4.0",
"tabulate",
"protobuf",
"evaluate",
"scipy"
],
"eval": ["lm_eval>=0.4.0", "tabulate", "protobuf", "evaluate", "scipy"],
},
**common_setup_kwargs
**common_setup_kwargs,
)
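To make the wheel-selection logic above concrete, here is a hedged sketch of the URL that `get_kernels_whl_url` produces for a ROCm 5.7.1 / Python 3.10 / Linux x86_64 build. The release version "0.0.6" is a placeholder; setup.py normally fetches the real value from the GitHub releases API via `get_latest_kernels_version`.
```
# Standalone copy of the URL construction from setup.py above, called with
# illustrative arguments; "0.0.6" is a placeholder release version.
def get_kernels_whl_url(gpu_system_version, release_version, python_version, platform, architecture):
    return (
        "https://github.com/casper-hansen/AutoAWQ_kernels/releases/download/"
        f"v{release_version}/autoawq_kernels-{release_version}+{gpu_system_version}"
        f"-cp{python_version}-cp{python_version}-{platform}_{architecture}.whl"
    )

print(get_kernels_whl_url("rocm571", "0.0.6", "310", "linux", "x86_64"))
# -> .../releases/download/v0.0.6/autoawq_kernels-0.0.6+rocm571-cp310-cp310-linux_x86_64.whl
```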