import os
import re
import subprocess
import sys
from datetime import date

import setuptools
import torch
from packaging import version as packaging_version
from torch.utils.cpp_extension import CUDA_HOME, BuildExtension, CUDAExtension
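# Summary (added for readability): build-time environment variables read in this file:
#   NUNCHAKU_INSTALL_MODE  - "FAST" (default: build only for locally visible GPUs) or "ALL"
#   NUNCHAKU_BUILD_WHEELS  - when left at "0" (default), --generate-line-info is added to NVCC_FLAGS
#   NUNCHAKU_AMD_ARCH / AMDGPU_TARGETS - required gfx target (e.g. gfx942) when building on ROCm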
class CustomBuildExtension(BuildExtension):
    """BuildExtension that maps the custom per-compiler flag buckets defined below
    ("gcc", "msvc", "nvcc_msvc") onto the standard "cxx"/"nvcc" flag lists."""

    def build_extensions(self):
        for ext in self.extensions:
            if "cxx" not in ext.extra_compile_args:
                ext.extra_compile_args["cxx"] = []
            if "nvcc" not in ext.extra_compile_args:
                ext.extra_compile_args["nvcc"] = []
            if self.compiler.compiler_type == "msvc":
                ext.extra_compile_args["cxx"] += ext.extra_compile_args["msvc"]
                ext.extra_compile_args["nvcc"] += ext.extra_compile_args["nvcc_msvc"]
            else:
                ext.extra_compile_args["cxx"] += ext.extra_compile_args["gcc"]
        super().build_extensions()


def get_sm_targets() -> list[str]:
    is_rocm = hasattr(torch.version, "hip") and torch.version.hip is not None

    if is_rocm:
        # ========== ROCm / AMD path ==========
        # ROCm does not use "SM" compute capabilities; targets are named "gfxXXXX"
        # (e.g. gfx942 for MI300X). Auto-detecting the gfx architecture from PyTorch
        # is not reliable across ROCm versions (get_device_properties(0).name may or
        # may not contain the gfx name), so the target must be set explicitly.
        gfx_arch = os.getenv("NUNCHAKU_AMD_ARCH", os.getenv("AMDGPU_TARGETS"))
        if gfx_arch is None:
            raise RuntimeError(
                "Running on ROCm, but NUNCHAKU_AMD_ARCH is not set. "
                "Please specify your AMD GPU architecture, e.g.: "
                "export NUNCHAKU_AMD_ARCH=gfx942  # for MI300X"
            )
        return [gfx_arch]  # e.g. ["gfx942"]
    else:
        nvcc_path = os.path.join(CUDA_HOME, "bin/nvcc") if CUDA_HOME else "nvcc"
        try:
            nvcc_output = subprocess.check_output([nvcc_path, "--version"]).decode()
            match = re.search(r"release (\d+\.\d+), V(\d+\.\d+\.\d+)", nvcc_output)
            if match:
                nvcc_version = match.group(2)
            else:
                raise Exception("nvcc version not found")
            print(f"Found nvcc version: {nvcc_version}")
        except Exception:
            raise Exception("nvcc not found")

        support_sm120 = packaging_version.parse(nvcc_version) >= packaging_version.parse("12.8")

        install_mode = os.getenv("NUNCHAKU_INSTALL_MODE", "FAST")
        if install_mode == "FAST":
            ret = []
            for i in range(torch.cuda.device_count()):
                capability = torch.cuda.get_device_capability(i)
                sm = f"{capability[0]}{capability[1]}"
                if sm == "120" and support_sm120:
                    sm = "120a"
                assert sm in ["75", "80", "86", "89", "92", "120a"], f"Unsupported SM {sm}"
                if sm not in ret:
                    ret.append(sm)
        else:
            assert install_mode == "ALL"
            ret = ["75", "80", "86", "89", "92"]
            if support_sm120:
                ret.append("120a")
        return ret
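# Illustrative behavior (not executed at import time): with NUNCHAKU_INSTALL_MODE=FAST
# (the default), get_sm_targets() returns only the compute capabilities of the GPUs
# visible locally, e.g. ["89"] on a single RTX 4090; with NUNCHAKU_INSTALL_MODE=ALL it
# returns the full hard-coded list (plus "120a" when nvcc >= 12.8). On ROCm it simply
# echoes the user-supplied target, e.g. ["gfx942"].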


if __name__ == "__main__":
    with open("nunchaku/__version__.py", "r") as f:
        version = eval(f.read().strip().split()[-1])

    torch_version = torch.__version__.split("+")[0]
    torch_major_minor_version = ".".join(torch_version.split(".")[:2])
    if "dev" in version:
        version = version + date.today().strftime("%Y%m%d")  # append the build date
    version = version + "+torch" + torch_major_minor_version

    ROOT_DIR = os.path.dirname(__file__)

    INCLUDE_DIRS = [
        "src",
        "third_party/cutlass/include",
        "third_party/json/include",
        "third_party/mio/include",
        "third_party/spdlog/include",
        # "third_party/Block-Sparse-Attention/csrc/block_sparse_attn",
        "third_party/flash_c_api/include",
    ]

    INCLUDE_DIRS = [os.path.join(ROOT_DIR, dir) for dir in INCLUDE_DIRS]

    DEBUG = False

    def ncond(s) -> list:
        # Included only in release (non-debug) builds.
        if DEBUG:
            return []
        else:
            return [s]

    def cond(s) -> list:
        # Included only in debug builds.
        if DEBUG:
            return [s]
        else:
            return []
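    # Example (illustrative): with DEBUG = False, [*ncond("a.cpp"), *cond("-G")] == ["a.cpp"];
    # with DEBUG = True it would be ["-G"].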

    #sm_targets = get_sm_targets()
    #print(f"Detected SM targets: {sm_targets}", file=sys.stderr)

    #assert len(sm_targets) > 0, "No SM targets found"

    GCC_FLAGS = ["-DENABLE_BF16=1", "-DBUILD_NUNCHAKU=1", "-fvisibility=hidden", "-g", "-std=c++2a", "-UNDEBUG", "-O1"]
    MSVC_FLAGS = ["/DENABLE_BF16=1", "/DBUILD_NUNCHAKU=1", "/std:c++20", "/UNDEBUG", "/Zc:__cplusplus", "/FS"]
    NVCC_FLAGS = [
        "-DDCU_ASM",
        "-DUSE_ROCM",
        "-DENABLE_BF16=1",
        "-DBUILD_NUNCHAKU=1",
        "-g",
        "-O1",
        "-std=c++2a",
        "-UNDEBUG",
        "-mllvm",
        "-nv-ptx-asm-transform=true",
        "-finline-asm-ptx",
        # "-Xcudafe",
        # "--diag_suppress=20208",  # spdlog: 'long double' is treated as 'double' in device code
        *cond("-G"),
        "-U__CUDA_NO_HALF_OPERATORS__",
        "-U__CUDA_NO_HALF_CONVERSIONS__",
        "-U__CUDA_NO_HALF2_OPERATORS__",
        "-U__CUDA_NO_HALF2_CONVERSIONS__",
        "-U__CUDA_NO_BFLOAT16_OPERATORS__",
        "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
        "-U__CUDA_NO_BFLOAT162_OPERATORS__",
        "-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
        # f"--threads={len(sm_targets)}",
        "--expt-relaxed-constexpr",
        "--expt-extended-lambda",
        # "--ptxas-options=--allow-expensive-optimizations=true",
    ]

    if os.getenv("NUNCHAKU_BUILD_WHEELS", "0") == "0":
        NVCC_FLAGS.append("--generate-line-info")

    #for target in sm_targets:
    #    NVCC_FLAGS += ["-gencode", f"arch=compute_{target},code=sm_{target}"]

    NVCC_MSVC_FLAGS = ["-Xcompiler", "/Zc:__cplusplus", "-Xcompiler", "/FS", "-Xcompiler", "/bigobj"]

    nunchaku_extension = CUDAExtension(
        name="nunchaku._C",
        sources=[
            "nunchaku/csrc/pybind.cpp",
            "src/interop/torch.cpp",
            "src/activation.cpp",
            "src/layernorm.cpp",
            "src/Linear.cpp",
            *ncond("src/FluxModel.cpp"),
            *ncond("src/SanaModel.cpp"),
            "src/Serialization.cpp",
            "src/Module.cpp",
            # *ncond("third_party/Block-Sparse-Attention/csrc/block_sparse_attn/src/flash_fwd_hdim64_fp16_sm80.cu"),
            # *ncond("third_party/Block-Sparse-Attention/csrc/block_sparse_attn/src/flash_fwd_hdim64_bf16_sm80.cu"),
            # *ncond("third_party/Block-Sparse-Attention/csrc/block_sparse_attn/src/flash_fwd_hdim128_fp16_sm80.cu"),
            # *ncond("third_party/Block-Sparse-Attention/csrc/block_sparse_attn/src/flash_fwd_hdim128_bf16_sm80.cu"),
            # *ncond("third_party/Block-Sparse-Attention/csrc/block_sparse_attn/src/flash_fwd_block_hdim64_fp16_sm80.cu"),
            # *ncond("third_party/Block-Sparse-Attention/csrc/block_sparse_attn/src/flash_fwd_block_hdim64_bf16_sm80.cu"),
            # *ncond(
            #     "third_party/Block-Sparse-Attention/csrc/block_sparse_attn/src/flash_fwd_block_hdim128_fp16_sm80.cu"
            # ),
            # *ncond(
            #     "third_party/Block-Sparse-Attention/csrc/block_sparse_attn/src/flash_fwd_block_hdim128_bf16_sm80.cu"
            # ),
            "src/kernels/activation_kernels.cu",
            "src/kernels/layernorm_kernels.cu",
            "src/kernels/misc_kernels.cu",
            "src/kernels/zgemm/gemm_w4a4.cu",
            "src/kernels/zgemm/gemm_w4a4_test.cu",
            "src/kernels/zgemm/gemm_w4a4_launch_fp16_int4.cu",
            "src/kernels/zgemm/gemm_w4a4_launch_fp16_int4_fasteri2f.cu",
            "src/kernels/zgemm/gemm_w4a4_launch_fp16_fp4.cu",
            "src/kernels/zgemm/gemm_w4a4_launch_bf16_int4.cu",
            "src/kernels/zgemm/gemm_w4a4_launch_bf16_fp4.cu",
            "src/kernels/zgemm/gemm_w8a8.cu",
            "src/kernels/zgemm/attention.cu",
            "src/kernels/dwconv.cu",
            "src/kernels/gemm_batched.cu",
            "src/kernels/gemm_f16.cu",
            "src/kernels/awq/gemm_awq.cu",
            "src/kernels/awq/gemv_awq.cu",
            #*ncond("third_party/Block-Sparse-Attention/csrc/block_sparse_attn/flash_api.cpp"),
            #*ncond("third_party/Block-Sparse-Attention/csrc/block_sparse_attn/flash_api_adapter.cpp"),
        ],
        extra_compile_args={"gcc": GCC_FLAGS, "msvc": MSVC_FLAGS, "nvcc": NVCC_FLAGS, "nvcc_msvc": NVCC_MSVC_FLAGS},
        include_dirs=INCLUDE_DIRS,
        libraries=["flash_atten_c"],
        library_dirs=["third_party/flash_c_api/lib"],
    )

    setuptools.setup(
        name="nunchaku",
        version=version,
        packages=setuptools.find_packages(),
        ext_modules=[nunchaku_extension],
        cmdclass={"build_ext": CustomBuildExtension},
    )
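# Typical build invocation (illustrative; not enforced by this script):
#   NUNCHAKU_INSTALL_MODE=ALL python -m pip install --no-build-isolation -e .
# --no-build-isolation may be needed because this setup.py imports torch at build time,
# unless torch is declared as a build requirement in pyproject.toml.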