import os
import re
import subprocess
import sys
from datetime import date

import setuptools
import torch
from packaging import version as packaging_version
from torch.utils.cpp_extension import CUDA_HOME, BuildExtension, CUDAExtension

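
# BuildExtension subclass that maps the per-compiler flag lists declared in
# extra_compile_args ("gcc", "msvc", "nvcc_msvc") onto the "cxx" and "nvcc"
# lists that torch's build machinery actually consumes, depending on whether
# the host compiler is MSVC or GCC-compatible.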
class CustomBuildExtension(BuildExtension):
    def build_extensions(self):
        for ext in self.extensions:
            if "cxx" not in ext.extra_compile_args:
                ext.extra_compile_args["cxx"] = []
            if "nvcc" not in ext.extra_compile_args:
                ext.extra_compile_args["nvcc"] = []
            if self.compiler.compiler_type == "msvc":
                ext.extra_compile_args["cxx"] += ext.extra_compile_args["msvc"]
                ext.extra_compile_args["nvcc"] += ext.extra_compile_args["nvcc_msvc"]
            else:
                ext.extra_compile_args["cxx"] += ext.extra_compile_args["gcc"]
        super().build_extensions()


def get_sm_targets() -> list[str]:
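    """Return the list of GPU architectures to compile for.

    On ROCm this is a single AMD "gfxXXXX" target taken from the environment;
    on CUDA it is a list of SM compute capabilities such as ["86", "120a"],
    selected according to NUNCHAKU_INSTALL_MODE.
    """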
    is_rocm = hasattr(torch.version, "hip") and torch.version.hip is not None

    if is_rocm:
        # ========== ROCm / AMD path ==========
        # ROCm does not use "SM" compute capabilities; it uses "gfxXXXX"
        # architecture names, which must be specified manually.
        gfx_arch = os.getenv("NUNCHAKU_AMD_ARCH") or os.getenv("AMDGPU_TARGETS")
        if gfx_arch is None:
            # Auto-detection via torch.cuda.get_device_properties(0) is not
            # reliable across ROCm versions (mapping the device name to a gfx
            # target may require a lookup table), so conservatively require
            # the user to set it explicitly.
            raise RuntimeError(
                "Running on ROCm, but NUNCHAKU_AMD_ARCH is not set. "
                "Please specify your AMD GPU architecture, e.g.: "
                "export NUNCHAKU_AMD_ARCH=gfx942  # for MI300X"
            )
        return [gfx_arch]  # e.g. ["gfx942"]
    else:
        nvcc_path = os.path.join(CUDA_HOME, "bin/nvcc") if CUDA_HOME else "nvcc"
        try:
            nvcc_output = subprocess.check_output([nvcc_path, "--version"]).decode()
        except (OSError, subprocess.CalledProcessError) as e:
            raise RuntimeError(f"nvcc not found at {nvcc_path}") from e
        match = re.search(r"release (\d+\.\d+), V(\d+\.\d+\.\d+)", nvcc_output)
        if match is None:
            raise RuntimeError("Could not parse the nvcc version from `nvcc --version` output")
        nvcc_version = match.group(2)
        print(f"Found nvcc version: {nvcc_version}")

        support_sm120 = packaging_version.parse(nvcc_version) >= packaging_version.parse("12.8")

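        # NUNCHAKU_INSTALL_MODE=FAST (the default) builds only for the GPU
        # architectures detected on this machine; ALL builds for every
        # supported architecture.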
        install_mode = os.getenv("NUNCHAKU_INSTALL_MODE", "FAST")
        if install_mode == "FAST":
            ret = []
            for i in range(torch.cuda.device_count()):
                capability = torch.cuda.get_device_capability(i)
                sm = f"{capability[0]}{capability[1]}"
                if sm == "120" and support_sm120:
                    sm = "120a"
                assert sm in ["75", "80", "86", "89", "92", "120a"], f"Unsupported SM {sm}"
                if sm not in ret:
                    ret.append(sm)
        else:
            assert install_mode == "ALL"
            ret = ["75", "80", "86", "89", "92"]
            if support_sm120:
                ret.append("120a")
        return ret


if __name__ == "__main__":
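    # Version scheme: base version from nunchaku/__version__.py; if it is a
    # dev version, today's date is appended, then "+torch<major>.<minor>".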
    with open("nunchaku/__version__.py", "r") as f:
        version = eval(f.read().strip().split()[-1])

    torch_version = torch.__version__.split("+")[0]
    torch_major_minor_version = ".".join(torch_version.split(".")[:2])
    if "dev" in version:
        version = version + date.today().strftime("%Y%m%d")  # append build date for dev versions
    version = version + "+torch" + torch_major_minor_version

    ROOT_DIR = os.path.dirname(__file__)

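    # Third-party subtrees to exclude; forwarded as `ignores=` to the
    # CUDAExtension below.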
    ignores = [
        "third_party/cutlass/*",
        "third_party/json/*",
        "third_party/mio/*",
        "third_party/spdlog/*",
    ]

    INCLUDE_DIRS = [
        "src",
        "third_party/cutlass/include",
        "third_party/json/include",
        "third_party/mio/include",
        "third_party/spdlog/include",
        "third_party/Block-Sparse-Attention/csrc/block_sparse_attn",
    ]

    INCLUDE_DIRS = [os.path.join(ROOT_DIR, d) for d in INCLUDE_DIRS]

    DEBUG = False
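    # ncond(s) includes s only in release builds; cond(s) includes s only in
    # DEBUG builds. DEBUG drops the heavy model and attention sources below
    # and adds "-G" (device debug info) to the NVCC flags.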

    def ncond(s) -> list:
        if DEBUG:
            return []
        else:
            return [s]

    def cond(s) -> list:
        if DEBUG:
            return [s]
        else:
            return []

    # sm_targets = get_sm_targets()
    # print(f"Detected SM targets: {sm_targets}", file=sys.stderr)

    # assert len(sm_targets) > 0, "No SM targets found"

    GCC_FLAGS = ["-w", "-DENABLE_BF16=1", "-DBUILD_NUNCHAKU=1", "-fvisibility=hidden", "-g", "-std=c++2a", "-UNDEBUG", "-Og"]
    MSVC_FLAGS = ["-w", "/DENABLE_BF16=1", "/DBUILD_NUNCHAKU=1", "/std:c++20", "/UNDEBUG", "/Zc:__cplusplus", "/FS"]
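    # Note: -DDCU_ASM and the -mllvm/-nv-ptx-asm-transform/-finline-asm-ptx
    # options below are LLVM-style flags that appear to target a DCU/ROCm
    # toolchain rather than stock nvcc.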
    NVCC_FLAGS = [
        "-w",
        "-DDCU_ASM",
        "-DENABLE_BF16=1",
        "-DBUILD_NUNCHAKU=1",
        "-g",
        "-std=c++2a",
        "-UNDEBUG",
        "-mllvm",
        "-nv-ptx-asm-transform=true",
        "-finline-asm-ptx",
        # "-Xcudafe",
        # "--diag_suppress=20208",  # spdlog: 'long double' is treated as 'double' in device code
        *cond("-G"),
        "-U__CUDA_NO_HALF_OPERATORS__",
        "-U__CUDA_NO_HALF_CONVERSIONS__",
        "-U__CUDA_NO_HALF2_OPERATORS__",
        "-U__CUDA_NO_HALF2_CONVERSIONS__",
        "-U__CUDA_NO_BFLOAT16_OPERATORS__",
        "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
        "-U__CUDA_NO_BFLOAT162_OPERATORS__",
        "-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
        # f"--threads={len(sm_targets)}",
        "--expt-relaxed-constexpr",
        "--expt-extended-lambda",
        # "--ptxas-options=--allow-expensive-optimizations=true",
    ]

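    # Keep source line info in the device code (useful for profilers such as
    # Nsight Compute) except when building release wheels.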
    if os.getenv("NUNCHAKU_BUILD_WHEELS", "0") == "0":
        NVCC_FLAGS.append("--generate-line-info")

    # for target in sm_targets:
    #     NVCC_FLAGS += ["-gencode", f"arch=compute_{target},code=sm_{target}"]

    NVCC_MSVC_FLAGS = ["-Xcompiler", "/Zc:__cplusplus", "-Xcompiler", "/FS", "-Xcompiler", "/bigobj"]

    nunchaku_extension = CUDAExtension(
        name="nunchaku._C",
        sources=[
            "nunchaku/csrc/pybind.cpp",
            "src/interop/torch.cpp",
            "src/activation.cpp",
            "src/layernorm.cpp",
            "src/Linear.cpp",
            *ncond("src/FluxModel.cpp"),
            *ncond("src/SanaModel.cpp"),
            "src/Serialization.cpp",
            "src/Module.cpp",
            *ncond("third_party/Block-Sparse-Attention/csrc/block_sparse_attn/src/flash_fwd_hdim64_fp16_sm80.cu"),
            *ncond("third_party/Block-Sparse-Attention/csrc/block_sparse_attn/src/flash_fwd_hdim64_bf16_sm80.cu"),
            *ncond("third_party/Block-Sparse-Attention/csrc/block_sparse_attn/src/flash_fwd_hdim128_fp16_sm80.cu"),
            *ncond("third_party/Block-Sparse-Attention/csrc/block_sparse_attn/src/flash_fwd_hdim128_bf16_sm80.cu"),
            *ncond("third_party/Block-Sparse-Attention/csrc/block_sparse_attn/src/flash_fwd_block_hdim64_fp16_sm80.cu"),
            *ncond("third_party/Block-Sparse-Attention/csrc/block_sparse_attn/src/flash_fwd_block_hdim64_bf16_sm80.cu"),
            *ncond(
                "third_party/Block-Sparse-Attention/csrc/block_sparse_attn/src/flash_fwd_block_hdim128_fp16_sm80.cu"
            ),
            *ncond(
                "third_party/Block-Sparse-Attention/csrc/block_sparse_attn/src/flash_fwd_block_hdim128_bf16_sm80.cu"
            ),
            "src/kernels/activation_kernels.cu",
            "src/kernels/layernorm_kernels.cu",
            "src/kernels/misc_kernels.cu",
            "src/kernels/zgemm/gemm_w4a4.cu",
            "src/kernels/zgemm/gemm_w4a4_test.cu",
            "src/kernels/zgemm/gemm_w4a4_launch_fp16_int4.cu",
            "src/kernels/zgemm/gemm_w4a4_launch_fp16_int4_fasteri2f.cu",
            "src/kernels/zgemm/gemm_w4a4_launch_fp16_fp4.cu",
            "src/kernels/zgemm/gemm_w4a4_launch_bf16_int4.cu",
            "src/kernels/zgemm/gemm_w4a4_launch_bf16_fp4.cu",
            "src/kernels/zgemm/gemm_w8a8.cu",
            "src/kernels/zgemm/attention.cu",
            "src/kernels/dwconv.cu",
            "src/kernels/gemm_batched.cu",
            "src/kernels/gemm_f16.cu",
            "src/kernels/awq/gemm_awq.cu",
            "src/kernels/awq/gemv_awq.cu",
            *ncond("third_party/Block-Sparse-Attention/csrc/block_sparse_attn/flash_api.cpp"),
            *ncond("third_party/Block-Sparse-Attention/csrc/block_sparse_attn/flash_api_adapter.cpp"),
        ],
        extra_compile_args={"gcc": GCC_FLAGS, "msvc": MSVC_FLAGS, "nvcc": NVCC_FLAGS, "nvcc_msvc": NVCC_MSVC_FLAGS},
        include_dirs=INCLUDE_DIRS,
        ignores=ignores,
    )

    setuptools.setup(
        name="nunchaku",
        version=version,
        packages=setuptools.find_packages(),
        ext_modules=[nunchaku_extension],
        cmdclass={"build_ext": CustomBuildExtension},
    )
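
# Typical invocations (environment variables consumed above):
#   NUNCHAKU_INSTALL_MODE=FAST pip install -e .    # build only for local GPUs
#   NUNCHAKU_INSTALL_MODE=ALL pip install -e .     # build for all supported SMs
#   NUNCHAKU_AMD_ARCH=gfx942 pip install -e .      # ROCm, e.g. MI300X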