setup.py 3.14 KB
Newer Older
Casper Hansen's avatar
Casper Hansen committed
1
import os
Casper's avatar
Casper committed
2
import torch
3
from pathlib import Path
Casper Hansen's avatar
Casper Hansen committed
4
from setuptools import setup, find_packages
Casper's avatar
Casper committed
5
from distutils.sysconfig import get_python_lib
qwopqwop200's avatar
qwopqwop200 committed
6
from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME
Casper's avatar
Casper committed
7

8
9
# Force g++ for both the C and the C++ compiler slots used by the build.
os.environ.update(CC="g++", CXX="g++")
Casper's avatar
Casper committed
10

11
12
13
14
15
16
17
18
19
20
21
# Package metadata passed to setup() at the bottom of this file.
common_setup_kwargs = dict(
    version="0.0.1",
    name="autoawq",
    author="Casper Hansen",
    license="MIT",
    python_requires=">=3.8.0",
    description="AutoAWQ implements the AWQ algorithm for 4-bit quantization with a 2x speedup during inference.",
    # README.md next to this file is used verbatim as the long description.
    long_description=(Path(__file__).parent / "README.md").read_text(encoding="UTF-8"),
    long_description_content_type="text/markdown",
    url="https://github.com/casper-hansen/AutoAWQ",
    keywords=["awq", "autoawq", "quantization", "transformers"],
    platforms=["windows", "linux"],
    classifiers=[
        "Environment :: GPU :: NVIDIA CUDA :: 11.8",
        "Environment :: GPU :: NVIDIA CUDA :: 12",
        "License :: OSI Approved :: MIT License",
        "Natural Language :: English",
        "Programming Language :: Python :: 3.8",
        "Programming Language :: Python :: 3.9",
        "Programming Language :: Python :: 3.10",
        "Programming Language :: Python :: 3.11",
        "Programming Language :: C++",
    ],
)
Casper Hansen's avatar
Casper Hansen committed
35

36
37
38
39
40
41
42
43
44
45
# Runtime dependencies installed alongside the package (order preserved).
requirements = [
    # Dependencies with a minimum-version pin.
    "torch>=2.0.0",
    "transformers>=4.32.0",
    "tokenizers>=0.12.1",
    # Unpinned dependencies.
    "accelerate",
    "sentencepiece",
    "lm_eval",
    "texttable",
    "toml",
    "attributedict",
    "protobuf",
    "torchvision",
]

50
import sysconfig  # stdlib replacement for the deprecated distutils.sysconfig (PEP 632)

# Extra header search paths handed to setup(include_dirs=...).
include_dirs = []

# If site-packages contains an ``nvidia/cuda_runtime/include`` directory, add it
# to the header search path. sysconfig's "purelib" path is the modern stand-in
# for distutils.sysconfig.get_python_lib().
# NOTE(review): despite the variable name, this layout looks like a pip-installed
# NVIDIA CUDA runtime package rather than a conda install — confirm.
conda_cuda_include_dir = os.path.join(
    sysconfig.get_paths()["purelib"], "nvidia/cuda_runtime/include"
)
if os.path.isdir(conda_cuda_include_dir):
    include_dirs.append(conda_cuda_include_dir)
Casper's avatar
Casper committed
55

Casper's avatar
Casper committed
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
def check_dependencies():
    """Abort the build early when no CUDA toolkit location is known.

    Raises:
        RuntimeError: If ``CUDA_HOME`` (imported from
            ``torch.utils.cpp_extension``) is ``None``.
    """
    if CUDA_HOME is None:
        # Plain string literal: the message has no placeholders to interpolate.
        raise RuntimeError(
            "Cannot find CUDA_HOME. CUDA must be available to build the package.")

def get_compute_capabilities(*, target_capabilities=None):
    """Return nvcc ``-gencode`` flags for the CUDA architectures to build.

    Every GPU reported by ``torch.cuda`` is checked first, and the build is
    refused if any of them has a compute capability below 8.0. The flags
    themselves are generated for a fixed list of targets, not for the
    detected GPUs, so the result does not depend on the build machine's GPUs.

    Args:
        target_capabilities: Optional iterable of compute capabilities
            encoded as ``major * 10 + minor`` (e.g. ``86`` for 8.6).
            Duplicates are ignored. Defaults to ``(80, 86, 89, 90)``.

    Returns:
        A flat list alternating ``"-gencode"`` and
        ``"arch=compute_XY,code=sm_XY"``, in ascending capability order.

    Raises:
        RuntimeError: If a visible GPU has a compute capability major < 8.
    """
    # Detected capabilities are only validated; they are not used as targets.
    for device_index in range(torch.cuda.device_count()):
        major, _minor = torch.cuda.get_device_capability(device_index)
        if major < 8:
            raise RuntimeError("GPUs with compute capability less than 8.0 are not supported.")

    if target_capabilities is None:
        target_capabilities = (80, 86, 89, 90)

    capability_flags = []
    # Sorting makes the flag order deterministic and easy to read in build logs.
    for cap in sorted(set(target_capabilities)):
        capability_flags += ["-gencode", f"arch=compute_{cap},code=sm_{cap}"]

    return capability_flags

# Fail fast when CUDA is missing, then compute the nvcc architecture flags.
check_dependencies()
arch_flags = get_compute_capabilities()

extensions = [
    CUDAExtension(
        "awq_inference_engine",
        [
            "awq_cuda/pybind.cpp",
            "awq_cuda/quantization/gemm_cuda_gen.cu",
            "awq_cuda/layernorm/layernorm.cu",
            "awq_cuda/position_embedding/pos_encoding_kernels.cu",
        ],
        # Hand the -gencode flags to nvcc so the extension is built for exactly
        # these architectures. Without them, arch_flags would go unused and
        # torch.utils.cpp_extension would choose architectures on its own
        # (TORCH_CUDA_ARCH_LIST or the GPUs visible at build time).
        # The empty "cxx" list leaves host-compiler arguments unchanged.
        extra_compile_args={"cxx": [], "nvcc": arch_flags},
    )
]
Casper Hansen's avatar
Casper Hansen committed
93

94
95
# Build-time options (compiled extensions and the torch-aware build_ext
# command), merged into the shared setup() keyword arguments.
additional_setup_kwargs = dict(
    ext_modules=extensions,
    cmdclass={'build_ext': BuildExtension},
)
common_setup_kwargs.update(additional_setup_kwargs)
Casper's avatar
Casper committed
100

Casper Hansen's avatar
Casper Hansen committed
101
# Register the package: shared metadata and build options come from
# common_setup_kwargs; packages, dependencies and header paths are added here.
setup(
    **common_setup_kwargs,
    packages=find_packages(),
    install_requires=requirements,
    include_dirs=include_dirs,
)