setup.py 4.59 KB
Newer Older
Casper Hansen's avatar
Casper Hansen committed
1
import os
Casper's avatar
Casper committed
2
import torch
3
from pathlib import Path
Casper Hansen's avatar
Casper Hansen committed
4
from setuptools import setup, find_packages
Casper's avatar
Casper committed
5
from distutils.sysconfig import get_python_lib
Casper's avatar
Casper committed
6
from torch.utils.cpp_extension import BuildExtension, CUDA_HOME, CUDAExtension
Casper's avatar
Casper committed
7

8
9
# Force both the C and C++ compiler environment variables to g++ for the build.
for _compiler_var in ("CC", "CXX"):
    os.environ[_compiler_var] = "g++"
Casper's avatar
Casper committed
10

11
# Package metadata passed to setuptools.setup(); build-specific keys
# (extension modules, build command class) are merged in further down.
common_setup_kwargs = {
    "version": "0.0.2",
    "name": "autoawq",
    "author": "Casper Hansen",
    "license": "MIT",
    "python_requires": ">=3.8.0",
    "description": "AutoAWQ implements the AWQ algorithm for 4-bit quantization with a 2x speedup during inference.",
    # The long description is read from the README.md that sits next to this file.
    "long_description": (Path(__file__).parent / "README.md").read_text(encoding="UTF-8"),
    "long_description_content_type": "text/markdown",
    "url": "https://github.com/casper-hansen/AutoAWQ",
    "keywords": ["awq", "autoawq", "quantization", "transformers"],
    "platforms": ["linux", "windows"],
    "classifiers": [
        "Environment :: GPU :: NVIDIA CUDA :: 11.8",
        "Environment :: GPU :: NVIDIA CUDA :: 12",
        "License :: OSI Approved :: MIT License",
        "Natural Language :: English",
        "Programming Language :: Python :: 3.8",
        "Programming Language :: Python :: 3.9",
        "Programming Language :: Python :: 3.10",
        "Programming Language :: Python :: 3.11",
        "Programming Language :: C++",
    ],
}
Casper Hansen's avatar
Casper Hansen committed
35

36
37
38
39
40
41
42
43
44
45
# Runtime dependencies handed to setup(install_requires=...).
requirements = [
    "torch>=2.0.0",
    "transformers>=4.32.0",
    "tokenizers>=0.12.1",
    "accelerate",
    "sentencepiece",
    "lm_eval",
    "texttable",
    "toml",
    "attributedict",
    "protobuf",
    "torchvision",
    "tabulate",
    "xformers",
]

Casper Hansen's avatar
Casper Hansen committed
52
53
def get_include_dirs():
    """Return the include directories to pass to the extension build.

    The result always ends with the directory containing this file. It is
    preceded by ``<site-packages>/nvidia/cuda_runtime/include`` when that
    directory exists.

    Returns:
        list[str]: Include directory paths, in the order described above.
    """
    # Function-scope import: the stdlib ``sysconfig`` module replaces
    # ``distutils.sysconfig.get_python_lib``, since ``distutils`` is no longer
    # part of the standard library from Python 3.12 on.
    import sysconfig

    include_dirs = []

    # "purelib" mirrors the default (plat_specific=False) lookup of
    # get_python_lib(). NOTE(review): on some distro-patched interpreters the
    # two may resolve differently -- confirm on the target build images.
    site_packages = sysconfig.get_paths()["purelib"]
    conda_cuda_include_dir = os.path.join(site_packages, "nvidia", "cuda_runtime", "include")
    if os.path.isdir(conda_cuda_include_dir):
        include_dirs.append(conda_cuda_include_dir)

    this_dir = os.path.dirname(os.path.abspath(__file__))
    include_dirs.append(this_dir)

    return include_dirs

def get_generator_flag():
    """Return the generator-related preprocessor flags for this torch install.

    Returns:
        list[str]: ``["-DOLD_GENERATOR_PATH"]`` when the installed torch
        package contains ``include/ATen/CUDAGeneratorImpl.h``, otherwise an
        empty list.
    """
    generator_header = os.path.join(
        torch.__path__[0], "include", "ATen", "CUDAGeneratorImpl.h"
    )
    return ["-DOLD_GENERATOR_PATH"] if os.path.exists(generator_header) else []
Casper's avatar
Casper committed
70

Casper's avatar
Casper committed
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
def check_dependencies():
    """Verify that a CUDA installation was located.

    Raises:
        RuntimeError: If ``CUDA_HOME`` (imported from
            ``torch.utils.cpp_extension``) is ``None``.
    """
    if CUDA_HOME is not None:
        return
    raise RuntimeError(
        "Cannot find CUDA_HOME. CUDA must be available to build the package."
    )

def get_compute_capabilities():
    """Return the nvcc ``-gencode`` flags for the architectures to build.

    The GPUs visible to torch are only validated here; they do not select
    the build targets. The targets are a fixed list (80, 86, 89, 90), so
    the returned flags are the same whichever supported GPU is present,
    including when no GPU is visible at all.

    Returns:
        list[str]: Alternating ``"-gencode"`` and
        ``"arch=compute_XX,code=sm_XX"`` entries, in ascending capability
        order.

    Raises:
        RuntimeError: If any visible GPU has a major compute capability
            below 8.
    """
    # Reject unsupported hardware up front. The detected capabilities are
    # deliberately not collected, because the target list below is fixed.
    for device_index in range(torch.cuda.device_count()):
        major, _minor = torch.cuda.get_device_capability(device_index)
        if major < 8:
            raise RuntimeError("GPUs with compute capability less than 8.0 are not supported.")

    # A tuple rather than a set keeps the flag order deterministic across runs.
    target_capabilities = (80, 86, 89, 90)

    capability_flags = []
    for cap in target_capabilities:
        capability_flags += ["-gencode", f"arch=compute_{cap},code=sm_{cap}"]

    return capability_flags

# Fail before computing any build inputs if CUDA_HOME could not be located.
check_dependencies()

# Inputs shared by the compile-argument and setup() sections below.
include_dirs = get_include_dirs()
generator_flags = get_generator_flag()
arch_flags = get_compute_capabilities()

Casper's avatar
Casper committed
99
100
101
102
103
104
105
# Per-compiler argument lists handed to CUDAExtension(extra_compile_args=...).
if os.name == "nt":
    # Relaxed args on Windows: only the architecture flags are passed to nvcc.
    extra_compile_args = {
        "nvcc": arch_flags,
    }
else:
    extra_compile_args = {
        "cxx": ["-g", "-O3", "-fopenmp", "-lgomp", "-std=c++17", "-DENABLE_BF16"],
        "nvcc": [
            "-O3",
            "-std=c++17",
            "-DENABLE_BF16",
            # Undefine the macros that switch off half/bfloat16 operators
            # and conversions.
            "-U__CUDA_NO_HALF_OPERATORS__",
            "-U__CUDA_NO_HALF_CONVERSIONS__",
            "-U__CUDA_NO_BFLOAT16_OPERATORS__",
            "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
            "-U__CUDA_NO_BFLOAT162_OPERATORS__",
            "-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
            "--expt-relaxed-constexpr",
            "--expt-extended-lambda",
            "--use_fast_math",
        ] + arch_flags + generator_flags,
    }

123
# The single CUDA extension module built by this package, with its C++ and
# CUDA source files.
extensions = [
    CUDAExtension(
        "awq_inference_engine",
        [
            "awq_cuda/pybind.cpp",
            "awq_cuda/quantization/gemm_cuda_gen.cu",
            "awq_cuda/layernorm/layernorm.cu",
            "awq_cuda/position_embedding/pos_encoding_kernels.cu",
            "awq_cuda/quantization/gemv_cuda.cu",
            "awq_cuda/attention/ft_attention.cpp",
            "awq_cuda/attention/decoder_masked_multihead_attention.cu",
        ],
        extra_compile_args=extra_compile_args,
    )
]
Casper Hansen's avatar
Casper Hansen committed
137

138
139
# Build-specific setup() arguments: the extension modules and the torch
# BuildExtension command used for the build_ext step.
additional_setup_kwargs = {
    "ext_modules": extensions,
    "cmdclass": {"build_ext": BuildExtension},
}

# Merge the build-specific arguments into the package metadata.
common_setup_kwargs.update(additional_setup_kwargs)
Casper's avatar
Casper committed
144

Casper Hansen's avatar
Casper Hansen committed
145
# Run setuptools with the discovered packages, the runtime requirements, the
# include directories and the merged metadata/build arguments.
setup(
    packages=find_packages(),
    install_requires=requirements,
    include_dirs=include_dirs,
    **common_setup_kwargs
)