setup.py 3.2 KB
Newer Older
Jiashi Li's avatar
Jiashi Li committed
1
2
3
4
5
6
7
8
9
10
import os
from pathlib import Path
from datetime import datetime
import subprocess

from setuptools import setup, find_packages

from torch.utils.cpp_extension import (
    BuildExtension,
    CUDAExtension,
程元's avatar
程元 committed
11
    IS_WINDOWS,
Jiashi Li's avatar
Jiashi Li committed
12
13
)

ljss's avatar
ljss committed
14

15
16
def is_flag_set(flag: str) -> bool:
    """Report whether the environment variable *flag* holds a truthy value.

    The comparison is case-insensitive; "true", "1", "y" and "yes" count as
    set. An unset variable is treated as "FALSE", i.e. not set.
    """
    raw_value = os.getenv(flag, "FALSE")
    return raw_value.lower() in ("true", "1", "y", "yes")
ljss's avatar
ljss committed
17

Sijia Chen's avatar
Sijia Chen committed
18
19
def get_features_args():
    """Return the preprocessor defines that toggle optional FlashMLA features.

    Currently the only switch is the ``FLASH_MLA_DISABLE_FP16`` environment
    variable: when it is truthy (as judged by ``is_flag_set``), the define
    ``-DFLASH_MLA_DISABLE_FP16`` is emitted. Otherwise the list is empty.

    Returns:
        A fresh list of compiler flag strings.
    """
    features_args = []
    if is_flag_set("FLASH_MLA_DISABLE_FP16"):
        features_args.append("-DFLASH_MLA_DISABLE_FP16")
    return features_args
Jiashi Li's avatar
Jiashi Li committed
23

24
25
26
27
28
29
30
31
32
33
34
35
36
def get_arch_flags():
    """Return the nvcc ``-gencode`` flags for the GPU targets to compile.

    Both targets are emitted by default, SM100 (``sm_100a``) first and then
    SM90 (``sm_90a``). Setting ``FLASH_MLA_DISABLE_SM100`` or
    ``FLASH_MLA_DISABLE_SM90`` to a truthy value (see ``is_flag_set``) drops
    the corresponding target.
    """
    # (environment switch that disables the target, gencode spec for it)
    targets = (
        ("FLASH_MLA_DISABLE_SM100", "arch=compute_100a,code=sm_100a"),
        ("FLASH_MLA_DISABLE_SM90", "arch=compute_90a,code=sm_90a"),
    )
    flags = []
    for disable_var, gencode_spec in targets:
        if not is_flag_set(disable_var):
            flags += ["-gencode", gencode_spec]
    return flags

def get_nvcc_thread_args():
    """Return nvcc's ``--threads`` option as a two-element argument list.

    The thread count is read verbatim from the ``NVCC_THREADS`` environment
    variable. When that variable is unset or empty, "32" is used.
    """
    thread_count = os.environ.get("NVCC_THREADS")
    if not thread_count:
        thread_count = "32"
    return ["--threads", thread_count]
ljss's avatar
ljss committed
37

Jiashi Li's avatar
Jiashi Li committed
38
39
40
41
this_dir = os.path.dirname(os.path.abspath(__file__))

# Fetch the CUTLASS submodule the CUDA sources rely on. The command runs from
# the directory containing this setup.py, so the relative submodule path does
# not depend on the caller's working directory. It remains best-effort: a
# non-zero exit status (e.g. a source tree without .git) is not checked, and a
# missing `git` executable only prints a warning instead of aborting the build
# with FileNotFoundError.
try:
    subprocess.run(
        ["git", "submodule", "update", "--init", "csrc/cutlass"],
        cwd=this_dir,
    )
except FileNotFoundError:
    print("WARNING: `git` not found; skipping csrc/cutlass submodule update")

程元's avatar
程元 committed
42
43
44
45
46
# Host (C++) compiler flags. MSVC on Windows spells optimization, language
# standard, define and warning options differently from GCC/Clang.
cxx_args = (
    ["/O2", "/std:c++17", "/DNDEBUG", "/W0"]
    if IS_WINDOWS
    else ["-O3", "-std=c++17", "-DNDEBUG", "-Wno-deprecated-declarations"]
)

Jiashi Li's avatar
Jiashi Li committed
47
48
49
# The single compiled extension, built as the module `flash_mla.cuda`. It
# combines the shared (smxx), SM90 and SM100 sources with the Python binding
# entry point in csrc/pybind.cpp.
csrc_dir = Path(this_dir) / "csrc"
ext_modules = [
    CUDAExtension(
        name="flash_mla.cuda",
        sources=[
            "csrc/pybind.cpp",
            "csrc/smxx/get_mla_metadata.cu",
            "csrc/smxx/mla_combine.cu",
            "csrc/sm90/decode/dense/splitkv_mla.cu",
            "csrc/sm90/decode/sparse_fp8/splitkv_mla.cu",
            "csrc/sm90/prefill/sparse/fwd.cu",
            "csrc/sm100/prefill/dense/fmha_cutlass_fwd_sm100.cu",
            "csrc/sm100/prefill/dense/fmha_cutlass_bwd_sm100.cu",
        ],
        extra_compile_args={
            # Host compiler: platform-specific base flags plus feature defines.
            "cxx": cxx_args + get_features_args(),
            # Device compiler: fixed flags, then feature defines, -gencode
            # targets and the nvcc --threads option (all env-configurable).
            "nvcc": [
                "-O3",
                "-std=c++17",
                "-DNDEBUG",
                "-D_USE_MATH_DEFINES",
                "-Wno-deprecated-declarations",
                # Undo PyTorch's default __CUDA_NO_HALF*/BFLOAT16 defines so
                # the half/bfloat16 operators and conversions are available.
                "-U__CUDA_NO_HALF_OPERATORS__",
                "-U__CUDA_NO_HALF_CONVERSIONS__",
                "-U__CUDA_NO_HALF2_OPERATORS__",
                "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
                "--expt-relaxed-constexpr",
                "--expt-extended-lambda",
                "--use_fast_math",
                "--ptxas-options=-v,--register-usage-level=10"
            ] + get_features_args() + get_arch_flags() + get_nvcc_thread_args(),
        },
        include_dirs=[
            csrc_dir,
            csrc_dir / "sm90",
            csrc_dir / "cutlass" / "include",
            csrc_dir / "cutlass" / "tools" / "util" / "include",
        ],
    )
]

# Local version label appended to "1.0.0" (after a '+'): the short git commit
# hash of this checkout, or a build timestamp when git information is not
# available.
try:
    cmd = ['git', 'rev-parse', '--short', 'HEAD']
    # Query this project's own repository rather than whatever repository the
    # caller's working directory belongs to, and keep git's error chatter
    # (e.g. "not a git repository") out of the build log.
    rev = '+' + subprocess.check_output(
        cmd, cwd=this_dir, stderr=subprocess.DEVNULL
    ).decode('ascii').rstrip()
except (OSError, subprocess.CalledProcessError):
    # OSError: git could not be executed (e.g. not installed).
    # CalledProcessError: git ran but failed (e.g. not a git checkout).
    now = datetime.now()
    date_time_str = now.strftime("%Y-%m-%d-%H-%M-%S")
    rev = '+' + date_time_str


# Register the `flash_mla` Python package together with its compiled
# extension. PyTorch's BuildExtension is installed as the `build_ext` command.
setup(
    name="flash_mla",
    version=f"1.0.0{rev}",
    packages=find_packages(include=["flash_mla"]),
    ext_modules=ext_modules,
    cmdclass={"build_ext": BuildExtension},
)