import os import re import subprocess import sys from datetime import date import setuptools import torch from packaging import version as packaging_version from torch.utils.cpp_extension import CUDA_HOME, BuildExtension, CUDAExtension class CustomBuildExtension(BuildExtension): def build_extensions(self): for ext in self.extensions: if not "cxx" in ext.extra_compile_args: ext.extra_compile_args["cxx"] = [] if not "nvcc" in ext.extra_compile_args: ext.extra_compile_args["nvcc"] = [] if self.compiler.compiler_type == "msvc": ext.extra_compile_args["cxx"] += ext.extra_compile_args["msvc"] ext.extra_compile_args["nvcc"] += ext.extra_compile_args["nvcc_msvc"] else: ext.extra_compile_args["cxx"] += ext.extra_compile_args["gcc"] super().build_extensions() def get_sm_targets() -> list[str]: is_rocm = hasattr(torch.version, 'hip') and torch.version.hip is not None if is_rocm: # ========== ROCm / AMD 路径 ========== # 手动指定或自动检测 AMD gfx 架构 # 注意:ROCm 不使用 "SM",而是 "gfxXXXX" gfx_arch = os.getenv("AMDGPU_TARGETS", None) if gfx_arch is None: # 尝试从 PyTorch 获取(部分版本支持) try: # 示例:'gfx942' for MI300X props = torch.cuda.get_device_properties(0) # 在 ROCm 中,name 可能包含 gfx 信息,或需查表 # 这里保守起见,要求用户显式设置 raise NotImplementedError("Auto-detection of AMD arch not reliable. Set NUNCHAKU_AMD_ARCH=gfx942 etc.") except: raise RuntimeError( "Running on ROCm, but NUNCHAKU_AMD_ARCH not set. " "Please specify your AMD GPU architecture, e.g.: " "export NUNCHAKU_AMD_ARCH=gfx942 # for MI300X" ) return [gfx_arch] # 返回如 ["gfx942"] else: nvcc_path = os.path.join(CUDA_HOME, "bin/nvcc") if CUDA_HOME else "nvcc" try: nvcc_output = subprocess.check_output([nvcc_path, "--version"]).decode() match = re.search(r"release (\d+\.\d+), V(\d+\.\d+\.\d+)", nvcc_output) if match: nvcc_version = match.group(2) else: raise Exception("nvcc version not found") print(f"Found nvcc version: {nvcc_version}") except: raise Exception("nvcc not found") support_sm120 = packaging_version.parse(nvcc_version) >= packaging_version.parse("12.8") install_mode = os.getenv("NUNCHAKU_INSTALL_MODE", "FAST") if install_mode == "FAST": ret = [] for i in range(torch.cuda.device_count()): capability = torch.cuda.get_device_capability(i) sm = f"{capability[0]}{capability[1]}" if sm == "120" and support_sm120: sm = "120a" assert sm in ["75", "80", "86", "89", "92", "120a"], f"Unsupported SM {sm}" if sm not in ret: ret.append(sm) else: assert install_mode == "ALL" ret = ["75", "80", "86", "89", "92"] if support_sm120: ret.append("120a") return ret if __name__ == "__main__": fp = open("nunchaku/__version__.py", "r").read() version = eval(fp.strip().split()[-1]) torch_version = torch.__version__.split("+")[0] torch_major_minor_version = ".".join(torch_version.split(".")[:2]) if "dev" in version: version = version + date.today().strftime("%Y%m%d") # data version = version + "+torch" + torch_major_minor_version ROOT_DIR = os.path.dirname(__file__) INCLUDE_DIRS = [ "src", "third_party/cutlass/include", "third_party/json/include", "third_party/mio/include", "third_party/spdlog/include", # "third_party/Block-Sparse-Attention/csrc/block_sparse_attn", "third_party/flash_c_api/include", ] INCLUDE_DIRS = [os.path.join(ROOT_DIR, dir) for dir in INCLUDE_DIRS] DEBUG = False def ncond(s) -> list: if DEBUG: return [] else: return [s] def cond(s) -> list: if DEBUG: return [s] else: return [] #sm_targets = get_sm_targets() #print(f"Detected SM targets: {sm_targets}", file=sys.stderr) #assert len(sm_targets) > 0, "No SM targets found" GCC_FLAGS = ["-DENABLE_BF16=1", "-DBUILD_NUNCHAKU=1", "-fvisibility=hidden", "-g", "-std=c++2a", "-UNDEBUG", "-O1"] MSVC_FLAGS = ["/DENABLE_BF16=1", "/DBUILD_NUNCHAKU=1", "/std:c++2a", "/UNDEBUG", "/Zc:__cplusplus", "/FS"] NVCC_FLAGS = [ "-DDCU_ASM", "-DUSE_ROCM", "-DENABLE_BF16=1", "-DBUILD_NUNCHAKU=1", "-g", "-O1", "-std=c++2a", "-UNDEBUG", "-mllvm", "-nv-ptx-asm-transform=true", "-finline-asm-ptx", #"-Xcudafe", #"--diag_suppress=20208", # spdlog: 'long double' is treated as 'double' in device code *cond("-G"), "-U__CUDA_NO_HALF_OPERATORS__", "-U__CUDA_NO_HALF_CONVERSIONS__", "-U__CUDA_NO_HALF2_OPERATORS__", "-U__CUDA_NO_HALF2_CONVERSIONS__", "-U__CUDA_NO_BFLOAT16_OPERATORS__", "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", "-U__CUDA_NO_BFLOAT162_OPERATORS__", "-U__CUDA_NO_BFLOAT162_CONVERSIONS__", #f"--threads={len(sm_targets)}", f"--expt-relaxed-constexpr", "--expt-extended-lambda", #"--ptxas-options=--allow-expensive-optimizations=true", ] if os.getenv("NUNCHAKU_BUILD_WHEELS", "0") == "0": NVCC_FLAGS.append("--generate-line-info") #for target in sm_targets: # NVCC_FLAGS += ["-gencode", f"arch=compute_{target},code=sm_{target}"] NVCC_MSVC_FLAGS = ["-Xcompiler", "/Zc:__cplusplus", "-Xcompiler", "/FS", "-Xcompiler", "/bigobj"] nunchaku_extension = CUDAExtension( name="nunchaku._C", sources=[ "nunchaku/csrc/pybind.cpp", "src/interop/torch.cpp", "src/activation.cpp", "src/layernorm.cpp", "src/Linear.cpp", *ncond("src/FluxModel.cpp"), *ncond("src/SanaModel.cpp"), "src/Serialization.cpp", "src/Module.cpp", # *ncond("third_party/Block-Sparse-Attention/csrc/block_sparse_attn/src/flash_fwd_hdim64_fp16_sm80.cu"), # *ncond("third_party/Block-Sparse-Attention/csrc/block_sparse_attn/src/flash_fwd_hdim64_bf16_sm80.cu"), # *ncond("third_party/Block-Sparse-Attention/csrc/block_sparse_attn/src/flash_fwd_hdim128_fp16_sm80.cu"), # *ncond("third_party/Block-Sparse-Attention/csrc/block_sparse_attn/src/flash_fwd_hdim128_bf16_sm80.cu"), # *ncond("third_party/Block-Sparse-Attention/csrc/block_sparse_attn/src/flash_fwd_block_hdim64_fp16_sm80.cu"), # *ncond("third_party/Block-Sparse-Attention/csrc/block_sparse_attn/src/flash_fwd_block_hdim64_bf16_sm80.cu"), # *ncond( # "third_party/Block-Sparse-Attention/csrc/block_sparse_attn/src/flash_fwd_block_hdim128_fp16_sm80.cu" # ), # *ncond( # "third_party/Block-Sparse-Attention/csrc/block_sparse_attn/src/flash_fwd_block_hdim128_bf16_sm80.cu" # ), "src/kernels/activation_kernels.cu", "src/kernels/layernorm_kernels.cu", "src/kernels/misc_kernels.cu", "src/kernels/zgemm/gemm_w4a4.cu", "src/kernels/zgemm/gemm_w4a4_test.cu", "src/kernels/zgemm/gemm_w4a4_launch_fp16_int4.cu", "src/kernels/zgemm/gemm_w4a4_launch_fp16_int4_fasteri2f.cu", "src/kernels/zgemm/gemm_w4a4_launch_fp16_fp4.cu", "src/kernels/zgemm/gemm_w4a4_launch_bf16_int4.cu", "src/kernels/zgemm/gemm_w4a4_launch_bf16_fp4.cu", "src/kernels/zgemm/gemm_w8a8.cu", "src/kernels/zgemm/attention.cu", "src/kernels/dwconv.cu", "src/kernels/gemm_batched.cu", "src/kernels/gemm_f16.cu", "src/kernels/awq/gemm_awq.cu", "src/kernels/awq/gemv_awq.cu", #*ncond("third_party/Block-Sparse-Attention/csrc/block_sparse_attn/flash_api.cpp"), #*ncond("third_party/Block-Sparse-Attention/csrc/block_sparse_attn/flash_api_adapter.cpp"), ], extra_compile_args={"gcc": GCC_FLAGS, "msvc": MSVC_FLAGS, "nvcc": NVCC_FLAGS, "nvcc_msvc": NVCC_MSVC_FLAGS}, include_dirs=INCLUDE_DIRS, libraries=["flash_atten_c"], library_dirs=["third_party/flash_c_api/lib"] ) setuptools.setup( name="nunchaku", version=version, packages=setuptools.find_packages(), ext_modules=[nunchaku_extension], cmdclass={"build_ext": CustomBuildExtension}, )