"""Build script for the ``flash_mla`` package and its ``flash_mla.cuda`` extension.

Environment variables read at build time:

* ``FLASH_MLA_DISABLE_FP16``  -- adds ``-DFLASH_MLA_DISABLE_FP16`` to the host
  compiler and nvcc flags.
* ``FLASH_MLA_DISABLE_SM100`` -- omits the ``sm_100a`` ``-gencode`` target.
* ``FLASH_MLA_DISABLE_SM90``  -- omits the ``sm_90a`` ``-gencode`` target.
* ``NVCC_THREADS``            -- value passed to nvcc ``--threads`` (default "32").

A flag counts as set when its value, case-insensitively, is one of
``true``, ``1``, ``y`` or ``yes``.
"""

import os
import subprocess
import sys
from datetime import datetime
from pathlib import Path

from setuptools import setup, find_packages

from torch.utils.cpp_extension import (
    BuildExtension,
    CUDAExtension,
    IS_WINDOWS,
)


def is_flag_set(flag: str) -> bool:
    """Return True when environment variable *flag* holds a truthy value.

    An unset variable is treated as ``"FALSE"``. The comparison is
    case-insensitive and accepts ``true``, ``1``, ``y`` and ``yes``.
    """
    return os.getenv(flag, "FALSE").lower() in ["true", "1", "y", "yes"]


def get_features_args():
    """Return the list of ``-D`` feature defines selected through the environment."""
    features_args = []
    if is_flag_set("FLASH_MLA_DISABLE_FP16"):
        features_args.append("-DFLASH_MLA_DISABLE_FP16")
    return features_args


def get_arch_flags():
    """Return the nvcc ``-gencode`` arguments for every architecture not disabled.

    sm_100a is listed before sm_90a, matching the original flag order.
    NOTE(review): when both targets are disabled the returned list is empty and
    no ``-gencode`` flag reaches nvcc at all -- confirm that is intended.
    """
    disable_sm100 = is_flag_set("FLASH_MLA_DISABLE_SM100")
    disable_sm90 = is_flag_set("FLASH_MLA_DISABLE_SM90")

    arch_flags = []
    if not disable_sm100:
        arch_flags.extend(["-gencode", "arch=compute_100a,code=sm_100a"])
    if not disable_sm90:
        arch_flags.extend(["-gencode", "arch=compute_90a,code=sm_90a"])
    return arch_flags


def get_nvcc_thread_args():
    """Return the nvcc ``--threads`` argument pair.

    The count comes from ``NVCC_THREADS``; an unset or empty variable falls
    back to ``"32"``.
    """
    nvcc_threads = os.getenv("NVCC_THREADS") or "32"
    return ["--threads", nvcc_threads]


def _update_cutlass_submodule(repo_dir):
    """Best-effort ``git submodule update --init csrc/cutlass`` inside *repo_dir*.

    A missing git executable or a non-zero exit status is reported on stderr
    instead of aborting (or being silently ignored), so that source trees which
    already contain ``csrc/cutlass`` can still be built without git.
    """
    cmd = ["git", "submodule", "update", "--init", "csrc/cutlass"]
    try:
        result = subprocess.run(cmd, cwd=repo_dir)
    except OSError as err:
        # Raised when the git executable cannot be started (e.g. not installed).
        print(
            f"warning: could not run '{' '.join(cmd)}': {err}; "
            "make sure csrc/cutlass is populated",
            file=sys.stderr,
        )
        return
    if result.returncode != 0:
        print(
            f"warning: '{' '.join(cmd)}' exited with status {result.returncode}; "
            "make sure csrc/cutlass is populated",
            file=sys.stderr,
        )


def _get_local_version_suffix(repo_dir):
    """Return the ``+<suffix>`` appended to the package version.

    The suffix is the short git hash of ``HEAD`` in *repo_dir*; if that cannot
    be obtained for any reason, a ``%Y-%m-%d-%H-%M-%S`` timestamp is used.
    """
    try:
        cmd = ['git', 'rev-parse', '--short', 'HEAD']
        return '+' + subprocess.check_output(cmd, cwd=repo_dir).decode('ascii').rstrip()
    except Exception:
        # Deliberately broad: any failure simply selects the timestamp fallback.
        now = datetime.now()
        date_time_str = now.strftime("%Y-%m-%d-%H-%M-%S")
        return '+' + date_time_str


this_dir = os.path.dirname(os.path.abspath(__file__))

# Run git relative to this file rather than the caller's working directory.
_update_cutlass_submodule(this_dir)

# Host-compiler flags; MSVC uses a different flag syntax.
if IS_WINDOWS:
    cxx_args = ["/O2", "/std:c++17", "/DNDEBUG", "/W0"]
else:
    cxx_args = ["-O3", "-std=c++17", "-DNDEBUG", "-Wno-deprecated-declarations"]

ext_modules = []
ext_modules.append(
    CUDAExtension(
        name="flash_mla.cuda",
        sources=[
            "csrc/pybind.cpp",
            "csrc/smxx/get_mla_metadata.cu",
            "csrc/smxx/mla_combine.cu",
            "csrc/sm90/decode/dense/splitkv_mla.cu",
            "csrc/sm90/decode/sparse_fp8/splitkv_mla.cu",
            "csrc/sm90/prefill/sparse/fwd.cu",
            "csrc/sm100/prefill/dense/fmha_cutlass_fwd_sm100.cu",
            "csrc/sm100/prefill/dense/fmha_cutlass_bwd_sm100.cu",
        ],
        extra_compile_args={
            "cxx": cxx_args + get_features_args(),
            "nvcc": [
                "-O3",
                "-std=c++17",
                "-DNDEBUG",
                "-D_USE_MATH_DEFINES",
                "-Wno-deprecated-declarations",
                "-U__CUDA_NO_HALF_OPERATORS__",
                "-U__CUDA_NO_HALF_CONVERSIONS__",
                "-U__CUDA_NO_HALF2_OPERATORS__",
                "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
                "--expt-relaxed-constexpr",
                "--expt-extended-lambda",
                "--use_fast_math",
                "--ptxas-options=-v,--register-usage-level=10"
            ] + get_features_args() + get_arch_flags() + get_nvcc_thread_args(),
        },
        include_dirs=[
            Path(this_dir) / "csrc",
            Path(this_dir) / "csrc" / "sm90",
            Path(this_dir) / "csrc" / "cutlass" / "include",
            Path(this_dir) / "csrc" / "cutlass" / "tools" / "util" / "include",
        ],
    )
)

# Local version label: short commit hash, or a timestamp when git is unavailable.
rev = _get_local_version_suffix(this_dir)

setup(
    name="flash_mla",
    version="1.0.0" + rev,
    packages=find_packages(include=['flash_mla']),
    ext_modules=ext_modules,
    cmdclass={"build_ext": BuildExtension},
)