Commit ff556eb0 authored by Casper
Browse files

Create more detailed setup.py

parent 60296077
import os
import torch
from setuptools import setup, find_packages
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME
# Abort early when no CUDA toolkit location was detected: CUDA_HOME is None in
# that case and the package cannot be built.
if CUDA_HOME is None:
    raise RuntimeError(
        "Cannot find CUDA_HOME. CUDA must be available to build the package."
    )
# Collect the compute capabilities of all available GPUs.
# Each capability is stored as a single integer, major * 10 + minor
# (for example 8.6 becomes 86). The build is aborted as soon as a visible
# device reports a major version below 8.
compute_capabilities = set()
for i in range(torch.cuda.device_count()):
    major, minor = torch.cuda.get_device_capability(i)
    if major < 8:
        raise RuntimeError("GPUs with compute capability less than 8.0 are not supported.")
    compute_capabilities.add(major * 10 + minor)
# Read build options from the environment: the CUDA extension is built only
# when BUILD_CUDA_EXT is unset or set to exactly "1".
build_cuda_extension = os.getenv("BUILD_CUDA_EXT", "1") == "1"
......@@ -22,6 +35,11 @@ if not torch_is_prebuilt:
ext_modules = []
if build_cuda_extension:
# Cap the nvcc thread count at 8. os.cpu_count() may return None when the
# number of CPUs cannot be determined, so fall back to a single thread.
n_threads = min(os.cpu_count() or 1, 8)
cxx_args = ["-g", "-O3", "-fopenmp", "-lgomp", "-std=c++17"]
# These flags end up on the compiler command line, so every entry must be a
# string; the value following --threads is converted explicitly.
nvcc_args = ["-O3", "-std=c++17", "--threads", str(n_threads)]
ext_modules.append(
CUDAExtension(
name="awq_inference_engine",
......@@ -32,25 +50,36 @@ if build_cuda_extension:
"awq_cuda/position_embedding/pos_encoding_kernels.cu"
],
extra_compile_args={
"cxx": ["-g", "-O3", "-fopenmp", "-lgomp", "-std=c++17"],
"nvcc": ["-O3", "-std=c++17"]
"cxx": cxx_args,
"nvcc": nvcc_args
},
)
)
# Read the README through a context manager so the file handle is closed
# deterministically, and decode it as UTF-8 instead of the platform default.
with open("README.md", "r", encoding="utf-8") as readme_file:
    long_description_text = readme_file.read()

# Package metadata. Each keyword appears exactly once (a repeated keyword
# argument is a syntax error), and the license classifier agrees with the
# declared MIT license.
setup(
    name="autoawq",
    version="0.1.0",
    author="Casper Hansen",
    license="MIT",
    description="AutoAWQ implements the AWQ algorithm for 4-bit quantization with a 2x speedup during inference.",
    long_description=long_description_text,
    long_description_content_type="text/markdown",
    python_requires=">=3.8",
    url="https://github.com/casper-hansen/AutoAWQ",
    keywords=["awq", "autoawq", "quantization", "transformers"],
    classifiers=[
        "Programming Language :: Python :: 3",
        "Environment :: GPU :: NVIDIA CUDA :: 11.8",
        "Environment :: GPU :: NVIDIA CUDA :: 12",
        "License :: OSI Approved :: MIT License",
        "Natural Language :: English",
        "Programming Language :: Python :: 3.8",
        "Programming Language :: Python :: 3.9",
        "Programming Language :: Python :: 3.10",
        "Programming Language :: Python :: 3.11",
        "Programming Language :: C++",
    ],
    install_requires=dependencies,
    packages=find_packages(exclude=["examples*"]),
    ext_modules=ext_modules,
    # BuildExtension drives compilation of the entries in ext_modules.
    cmdclass={"build_ext": BuildExtension},
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment