import os
import sysconfig

import torch
# setuptools must be imported before distutils so that setuptools' bundled
# distutils replacement is the one that gets loaded.
from setuptools import setup, find_packages
from distutils.sysconfig import get_python_lib
from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME


def check_dependencies():
    """Fail fast when torch could not locate a CUDA toolkit.

    Raises:
        RuntimeError: if ``torch.utils.cpp_extension.CUDA_HOME`` is ``None``,
            i.e. no CUDA installation was found to build the extension with.
    """
    if CUDA_HOME is not None:
        return
    raise RuntimeError(
        "Cannot find CUDA_HOME. CUDA must be available to build the package."
    )


def get_compute_capabilities(target_capabilities=(80, 86, 89, 90)):
    """Build the nvcc ``-gencode`` flags for the CUDA kernels.

    Every GPU visible to torch is checked first, and the build is rejected if
    any of them has compute capability below 8.0. The flags themselves are
    generated for ``target_capabilities``, not for the detected GPUs, so the
    resulting flag list does not depend on which GPUs the build machine has.

    Args:
        target_capabilities: iterable of compute capabilities encoded as
            ``major * 10 + minor`` (e.g. 86 for 8.6). Defaults to 8.0, 8.6,
            8.9 and 9.0. Duplicates are ignored.

    Returns:
        A flat list ``["-gencode", "arch=compute_XX,code=sm_XX", ...]`` with
        one pair per capability, in ascending capability order.

    Raises:
        RuntimeError: if any visible GPU has compute capability below 8.0.
    """
    # Validation only: detected capabilities are not used to pick targets.
    for device_index in range(torch.cuda.device_count()):
        major, _minor = torch.cuda.get_device_capability(device_index)
        if major < 8:
            raise RuntimeError("GPUs with compute capability less than 8.0 are not supported.")

    capability_flags = []
    # Sorted so the generated nvcc command line is deterministic.
    for cap in sorted(set(target_capabilities)):
        capability_flags += ["-gencode", f"arch=compute_{cap},code=sm_{cap}"]
    return capability_flags


# Runtime requirements, passed to setup() as install_requires.
dependencies = [
    "accelerate",
    "sentencepiece",
    "tokenizers>=0.12.1",
    "transformers>=4.32.0",
    "lm_eval",
    "texttable",
    "toml",
    "attributedict",
    "protobuf",
    "torch>=2.0.0",
    "torchvision",
]


# Build the CUDA extension unless BUILD_CUDA_EXT is set to something other
# than "1"; an unset variable counts as "1".
build_cuda_extension = os.getenv("BUILD_CUDA_EXT", "1") == "1"


# Setup CUDA extension (left empty when BUILD_CUDA_EXT is not "1").
ext_modules = []

if build_cuda_extension:
    # Fail early, before any compilation starts, if torch found no CUDA toolkit.
    check_dependencies()

    # Parallel nvcc threads, capped at 8. os.cpu_count() returns None when the
    # CPU count cannot be determined, so fall back to a single thread.
    n_threads = str(min(os.cpu_count() or 1, 8))

    # final args
    capability_flags = get_compute_capabilities()
    cxx_args = ["-g", "-O3", "-fopenmp", "-std=c++17"]
    nvcc_args = ["-O3", "-std=c++17", "--threads", n_threads] + capability_flags
    # -lgomp is a linker option: in extra_compile_args it only reaches the
    # compile-only invocations, where it has no effect, so pass it at link time.
    link_args = ["-lgomp"]

    ext_modules.append(
        CUDAExtension(
            name="awq_inference_engine",
            sources=[
                "awq_cuda/pybind.cpp",
                "awq_cuda/quantization/gemm_cuda_gen.cu",
                "awq_cuda/layernorm/layernorm.cu",
                "awq_cuda/position_embedding/pos_encoding_kernels.cu",
            ],
            extra_compile_args={
                "cxx": cxx_args,
                "nvcc": nvcc_args,
            },
            extra_link_args=link_args,
        )
    )


# Find directories to be included in setup.
# CUDA runtime headers may live under site-packages in an NVIDIA pip wheel
# (presumably nvidia-cuda-runtime-cu*; verify). Add the directory only if it
# exists. sysconfig's "purelib" path replaces distutils' get_python_lib(),
# since distutils is deprecated (PEP 632) and was removed in Python 3.12.
include_dirs = []
conda_cuda_include_dir = os.path.join(sysconfig.get_path("purelib"), "nvidia/cuda_runtime/include")

if os.path.isdir(conda_cuda_include_dir):
    include_dirs.append(conda_cuda_include_dir)


# Read the README as UTF-8 (not the locale default) and close the file
# handle as soon as it has been read.
with open("README.md", "r", encoding="utf-8") as readme_file:
    long_description = readme_file.read()

setup(
    name="autoawq",
    version="0.1.0",
    author="Casper Hansen",
    license="MIT",
    description="AutoAWQ implements the AWQ algorithm for 4-bit quantization with a 2x speedup during inference.",
    long_description=long_description,
    long_description_content_type="text/markdown",
    python_requires=">=3.8",
    url="https://github.com/casper-hansen/AutoAWQ",
    keywords=["awq", "autoawq", "quantization", "transformers"],
    classifiers=[
        "Environment :: GPU :: NVIDIA CUDA :: 11.8",
        "Environment :: GPU :: NVIDIA CUDA :: 12",
        "License :: OSI Approved :: MIT License",
        "Natural Language :: English",
        "Programming Language :: Python :: 3.8",
        "Programming Language :: Python :: 3.9",
        "Programming Language :: Python :: 3.10",
        "Programming Language :: Python :: 3.11",
        "Programming Language :: C++",
    ],
    install_requires=dependencies,
    include_dirs=include_dirs,
    packages=find_packages(exclude=["examples*"]),
    # Empty when BUILD_CUDA_EXT disables the CUDA build.
    ext_modules=ext_modules,
    # torch's build_ext, which knows how to compile the CUDAExtension sources.
    cmdclass={"build_ext": BuildExtension},
)