setup.py 5.87 KB
Newer Older
Jiashi Li's avatar
Jiashi Li committed
1
2
3
4
5
6
7
8
9
10
import os
from pathlib import Path
from datetime import datetime
import subprocess

from setuptools import setup, find_packages

from torch.utils.cpp_extension import (
    BuildExtension,
    CUDAExtension,
程元's avatar
程元 committed
11
    IS_WINDOWS,
12
    CUDA_HOME
Jiashi Li's avatar
Jiashi Li committed
13
14
)

ljss's avatar
ljss committed
15

16
17
def is_flag_set(flag: str) -> bool:
    """Report whether the environment variable ``flag`` is switched on.

    The value is compared case-insensitively against "true", "1", "y" and
    "yes". An unset variable falls back to "FALSE" and therefore reads as off;
    any other value (including an empty string) is also off.
    """
    raw_value = os.environ.get(flag, "FALSE")
    truthy_spellings = ("true", "1", "y", "yes")
    return raw_value.lower() in truthy_spellings
ljss's avatar
ljss committed
18

Sijia Chen's avatar
Sijia Chen committed
19
20
def get_features_args():
    """Return preprocessor defines for optional features disabled via env vars.

    At present only ``FLASH_MLA_DISABLE_FP16`` is recognised. When that
    variable is truthy (as judged by ``is_flag_set``), the result is
    ``["-DFLASH_MLA_DISABLE_FP16"]``; otherwise it is an empty list. A fresh
    list is returned on every call, so callers may concatenate freely.
    """
    fp16_disabled = is_flag_set("FLASH_MLA_DISABLE_FP16")
    return ["-DFLASH_MLA_DISABLE_FP16"] if fp16_disabled else []
Jiashi Li's avatar
Jiashi Li committed
24

25
def _parse_nvcc_version(nvcc_output):
    """Extract ``(major, minor)`` from the text printed by ``nvcc --version``.

    The ``release X.Y`` token of that output is parsed, e.g. the line
    ``Cuda compilation tools, release 12.8, V12.8.93`` yields ``(12, 8)``.

    Raises:
        RuntimeError: if no ``release X.Y`` token is present, so an unexpected
            toolchain produces a readable message instead of an IndexError or
            ValueError from string slicing.
    """
    import re

    match = re.search(r"release (\d+)\.(\d+)", nvcc_output)
    if match is None:
        raise RuntimeError(
            "Could not determine the NVCC version from `nvcc --version` output:\n"
            + nvcc_output
        )
    return int(match.group(1)), int(match.group(2))


def get_arch_flags():
    """Return the ``-gencode`` flags selecting the GPU architectures to build.

    Honors ``FLASH_MLA_DISABLE_SM100`` and ``FLASH_MLA_DISABLE_SM90`` (see
    ``is_flag_set``). sm100 code uses the ``compute_100f``/``sm_100f`` target,
    which this script only permits with NVCC 12.9 or newer. The detected NVCC
    version is printed as a side effect.

    Returns:
        A list of NVCC arguments; empty if both architectures are disabled.

    Raises:
        RuntimeError: if ``CUDA_HOME`` (as resolved by torch) is None, if the
            NVCC version cannot be parsed, or if sm100 is enabled with an NVCC
            older than 12.9.
        subprocess.CalledProcessError / OSError: if running nvcc fails.
    """
    # NOTE The "CUDA_HOME" here is not necessarily from the `CUDA_HOME` environment
    # variable. For more details, see `torch/utils/cpp_extension.py`.
    # Explicit raises instead of `assert` so these checks survive `python -O`.
    if CUDA_HOME is None:
        raise RuntimeError("PyTorch must be compiled with CUDA support")
    nvcc_path = os.path.join(CUDA_HOME, "bin", "nvcc")
    nvcc_output = subprocess.check_output(
        [nvcc_path, '--version'], stderr=subprocess.STDOUT
    ).decode('utf-8')
    major, minor = _parse_nvcc_version(nvcc_output)
    print(f'Compiling using NVCC {major}.{minor}')

    disable_sm100 = is_flag_set("FLASH_MLA_DISABLE_SM100")
    disable_sm90 = is_flag_set("FLASH_MLA_DISABLE_SM90")
    if not disable_sm100 and (major, minor) < (12, 9):
        # TODO Implement this (original note; scope unclear, presumably an
        # automatic fallback instead of failing; confirm with the author)
        raise RuntimeError(
            "sm100 compilation for Flash MLA requires NVCC 12.9 or higher. Please set FLASH_MLA_DISABLE_SM100=1 to disable sm100 compilation, or update your environment."
        )

    arch_flags = []
    if not disable_sm100:
        arch_flags.extend(["-gencode", "arch=compute_100f,code=sm_100f"])
    if not disable_sm90:
        arch_flags.extend(["-gencode", "arch=compute_90a,code=sm_90a"])
    return arch_flags

def get_nvcc_thread_args():
    """Return the NVCC ``--threads`` option for parallel per-architecture compilation.

    The thread count comes from the ``NVCC_THREADS`` environment variable;
    when it is unset or empty, "32" is used. The value is passed through as a
    string without validation.
    """
    thread_count = os.environ.get("NVCC_THREADS")
    if not thread_count:
        thread_count = "32"
    return ["--threads", thread_count]
ljss's avatar
ljss committed
51

Jiashi Li's avatar
Jiashi Li committed
52
53
54
55
# Directory containing this setup.py (the repository root); used to build
# absolute include paths.
this_dir = os.path.dirname(os.path.abspath(__file__))

# Best-effort fetch of the CUTLASS submodule. A non-zero exit status is
# deliberately ignored (e.g. a source tree that is not a git checkout but
# already ships csrc/cutlass). git runs from this_dir so the call targets this
# repository regardless of the directory setup.py is invoked from.
try:
    subprocess.run(
        ["git", "submodule", "update", "--init", "csrc/cutlass"], cwd=this_dir
    )
except OSError as err:
    # git is not installed / not executable. Keep going; if CUTLASS really is
    # missing, the compiler will report the missing headers.
    print(
        f"WARNING: could not run `git submodule update` ({err}); "
        "assuming csrc/cutlass is already present"
    )

程元's avatar
程元 committed
56
# Host (C++) compiler flags: MSVC-style switches on Windows, GCC/Clang-style
# elsewhere. Both request C++20, optimisation and NDEBUG; warning handling
# differs (/W0 on MSVC vs. silencing deprecation warnings on GCC/Clang).
cxx_args = (
    ["/O2", "/std:c++20", "/DNDEBUG", "/W0"]
    if IS_WINDOWS
    else ["-O3", "-std=c++20", "-DNDEBUG", "-Wno-deprecated-declarations"]
)
程元's avatar
程元 committed
60

Jiashi Li's avatar
Jiashi Li committed
61
62
63
# Every translation unit compiled into the single `flash_mla.cuda` extension.
# The paths are relative, so they are resolved against the current working
# directory (not `this_dir`).
flash_mla_sources = [
    # API
    "csrc/api/api.cpp",

    # Misc kernels for decoding
    "csrc/smxx/decode/get_decoding_sched_meta/get_decoding_sched_meta.cu",
    "csrc/smxx/decode/combine/combine.cu",

    # sm90 dense decode
    "csrc/sm90/decode/dense/instantiations/fp16.cu",
    "csrc/sm90/decode/dense/instantiations/bf16.cu",

    # sm90 sparse decode
    "csrc/sm90/decode/sparse_fp8/instantiations/model1_persistent_h64.cu",
    "csrc/sm90/decode/sparse_fp8/instantiations/model1_persistent_h128.cu",
    "csrc/sm90/decode/sparse_fp8/instantiations/v32_persistent_h64.cu",
    "csrc/sm90/decode/sparse_fp8/instantiations/v32_persistent_h128.cu",

    # sm90 sparse prefill
    "csrc/sm90/prefill/sparse/fwd.cu",
    "csrc/sm90/prefill/sparse/instantiations/phase1_k512.cu",
    "csrc/sm90/prefill/sparse/instantiations/phase1_k512_topklen.cu",
    "csrc/sm90/prefill/sparse/instantiations/phase1_k576.cu",
    "csrc/sm90/prefill/sparse/instantiations/phase1_k576_topklen.cu",

    # sm100 dense prefill & backward
    "csrc/sm100/prefill/dense/fmha_cutlass_fwd_sm100.cu",
    "csrc/sm100/prefill/dense/fmha_cutlass_bwd_sm100.cu",

    # sm100 sparse prefill
    "csrc/sm100/prefill/sparse/fwd/head64/instantiations/phase1_k512.cu",
    "csrc/sm100/prefill/sparse/fwd/head64/instantiations/phase1_k576.cu",
    "csrc/sm100/prefill/sparse/fwd/head128/instantiations/phase1_k512.cu",
    "csrc/sm100/prefill/sparse/fwd/head128/instantiations/phase1_k576.cu",
    "csrc/sm100/prefill/sparse/fwd_for_small_topk/head128/instantiations/phase1_prefill_k512.cu",

    # sm100 sparse decode
    "csrc/sm100/decode/head64/instantiations/v32.cu",
    "csrc/sm100/decode/head64/instantiations/model1.cu",
    "csrc/sm100/prefill/sparse/fwd_for_small_topk/head128/instantiations/phase1_decode_k512.cu",
]

# Per-compiler flags. The env-driven helpers run here, host flags first; note
# that `get_arch_flags()` invokes nvcc and prints the detected version.
flash_mla_compile_args = {
    "cxx": cxx_args + get_features_args(),
    "nvcc": [
        "-O3",
        "-std=c++20",
        "-DNDEBUG",
        "-D_USE_MATH_DEFINES",
        "-Wno-deprecated-declarations",
        # Presumably undoes the matching -D__CUDA_NO_*__ defines that torch's
        # default NVCC flags add, re-enabling half/bfloat16 operators.
        "-U__CUDA_NO_HALF_OPERATORS__",
        "-U__CUDA_NO_HALF_CONVERSIONS__",
        "-U__CUDA_NO_HALF2_OPERATORS__",
        "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
        "--expt-relaxed-constexpr",
        "--expt-extended-lambda",
        "--use_fast_math",
        # Verbose ptxas resource report plus warnings on spills, local memory
        # and double-precision use.
        "--ptxas-options=-v,--register-usage-level=10,--warn-on-spills,--warn-on-local-memory-usage,--warn-on-double-precision-use",
        "-lineinfo",
        "--source-in-ptx",
    ] + get_features_args() + get_arch_flags() + get_nvcc_thread_args(),
}

csrc_dir = Path(this_dir) / "csrc"
flash_mla_include_dirs = [
    csrc_dir,
    csrc_dir / "kerutils" / "include",   # TODO Remove me
    csrc_dir / "sm90",
    csrc_dir / "cutlass" / "include",
    csrc_dir / "cutlass" / "tools" / "util" / "include",
]

ext_modules = [
    CUDAExtension(
        name="flash_mla.cuda",
        sources=flash_mla_sources,
        extra_compile_args=flash_mla_compile_args,
        include_dirs=flash_mla_include_dirs,
    )
]

# Local version suffix: "+<short git hash>" when this is a git checkout,
# otherwise "+<build timestamp>". git runs in this_dir so the hash belongs to
# this repository, and its stderr is discarded so non-git builds do not print
# "fatal: not a git repository".
try:
    cmd = ['git', 'rev-parse', '--short', 'HEAD']
    rev = '+' + subprocess.check_output(
        cmd, cwd=this_dir, stderr=subprocess.DEVNULL
    ).decode('ascii').rstrip()
except (OSError, subprocess.CalledProcessError):
    # OSError: git is not installed; CalledProcessError: not a git checkout.
    date_time_str = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
    rev = '+' + date_time_str


# Package metadata. The installed version is "1.0.0" followed by the local
# suffix held in `rev`; torch's BuildExtension is registered as build_ext so
# the CUDA sources in `ext_modules` are compiled with nvcc.
package_version = f"1.0.0{rev}"
setup(
    name="flash_mla",
    version=package_version,
    packages=find_packages(include=["flash_mla"]),
    ext_modules=ext_modules,
    cmdclass={"build_ext": BuildExtension},
)