# setup.py: builds the flash_mla package and its SM90 / SM100 CUDA extensions.
import os
import subprocess
from datetime import datetime
from pathlib import Path

from setuptools import find_packages, setup
from torch.utils.cpp_extension import (
    IS_WINDOWS,
    BuildExtension,
    CUDAExtension,
)

def append_nvcc_threads(nvcc_extra_args):
    """Return a copy of *nvcc_extra_args* with nvcc's ``--threads`` option added.

    The thread count is read from the ``NVCC_THREADS`` environment variable.
    When that variable is unset or empty, ``"32"`` is used. The argument list
    passed in is left untouched; a new list is returned.
    """
    thread_count = os.environ.get("NVCC_THREADS") or "32"
    threads_option = ["--threads", thread_count]
    return nvcc_extra_args + threads_option

def get_features_args():
    """Return the optional feature defines selected via environment variables.

    Only ``FLASH_MLA_DISABLE_FP16`` is inspected. When it is set to ``1`` or
    ``true`` (compared case-insensitively, surrounding whitespace ignored),
    ``-DFLASH_MLA_DISABLE_FP16`` is returned. Otherwise the list is empty.

    Returns:
        list[str]: a new list of compiler flags (possibly empty).
    """
    features_args = []

    # Normalise before comparing, so FLASH_MLA_DISABLE_FP16=true is honoured
    # instead of being silently ignored (the check used to be case-sensitive).
    flag_value = os.getenv("FLASH_MLA_DISABLE_FP16", "FALSE").strip().upper()
    if flag_value in ("TRUE", "1"):
        features_args.append("-DFLASH_MLA_DISABLE_FP16")
    return features_args
# Fetch the CUTLASS sources under csrc/cutlass, whose include directories both
# extensions add to include_dirs. This is best-effort: a tree without git
# metadata, or a machine without a git binary, must still be able to build
# when csrc/cutlass is already populated. A failure therefore only prints a
# warning instead of aborting setup.py. git runs from this script's directory
# so the relative submodule path resolves regardless of the caller's cwd.
try:
    _submodule_result = subprocess.run(
        ["git", "submodule", "update", "--init", "csrc/cutlass"],
        cwd=os.path.dirname(os.path.abspath(__file__)),
    )
except OSError as err:  # e.g. git is not installed
    print(f"WARNING: could not run git to fetch csrc/cutlass: {err}")
else:
    if _submodule_result.returncode != 0:
        print(
            "WARNING: `git submodule update --init csrc/cutlass` exited with "
            f"status {_submodule_result.returncode}; the build will fail if "
            "csrc/cutlass is not already present."
        )

# nvcc code-generation flags: each list selects a single arch-specific target
# (sm_90a and sm_100a respectively) via -gencode.
cc_flag_sm90 = ["-gencode", "arch=compute_90a,code=sm_90a"]
cc_flag_sm100 = ["-gencode", "arch=compute_100a,code=sm_100a"]

# Absolute path of the directory holding this setup.py; the extensions build
# their include_dirs relative to it.
this_dir = os.path.split(os.path.abspath(__file__))[0]

# Host-compiler ("cxx") flags: MSVC option syntax on Windows, GCC-style
# option syntax everywhere else.
_MSVC_CXX_ARGS = ["/O2", "/std:c++17", "/DNDEBUG", "/W0"]
_GNU_CXX_ARGS = ["-O3", "-std=c++17", "-DNDEBUG", "-Wno-deprecated-declarations"]
cxx_args = _MSVC_CXX_ARGS if IS_WINDOWS else _GNU_CXX_ARGS

# Extension modules built by this package. First entry: the SM90 extension,
# compiled for sm_90a. The optional feature defines from get_features_args()
# are passed to both the host (cxx) and nvcc compilations.
ext_modules = []
ext_modules.append(
    CUDAExtension(
        name="flash_mla_sm90",
        sources=[
            "csrc/sm90/flash_api.cpp",
            "csrc/sm90/kernels/get_mla_metadata.cu",
            "csrc/sm90/kernels/mla_combine.cu",
            "csrc/sm90/kernels/splitkv_mla.cu",
        ],
        extra_compile_args={
            "cxx": cxx_args + get_features_args(),
            "nvcc": append_nvcc_threads(
                [
                    "-O3",
                    "-std=c++17",
                    "-DNDEBUG",
                    # MSVC's <cmath> only exposes M_PI & co. with this macro.
                    "-D_USE_MATH_DEFINES",
                    "-Wno-deprecated-declarations",
                    # Undefine the __CUDA_NO_HALF*/BFLOAT16 guard macros so the
                    # half/bfloat16 operators and conversions stay available
                    # (presumably overriding PyTorch's default nvcc defines).
                    "-U__CUDA_NO_HALF_OPERATORS__",
                    "-U__CUDA_NO_HALF_CONVERSIONS__",
                    "-U__CUDA_NO_HALF2_OPERATORS__",
                    "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
                    "--expt-relaxed-constexpr",
                    "--expt-extended-lambda",
                    "--use_fast_math",
                    # -v makes ptxas report per-kernel resource usage.
                    "--ptxas-options=-v,--register-usage-level=10",
                ]
                + cc_flag_sm90
            )
            + get_features_args(),
        },
        include_dirs=[
            Path(this_dir) / "csrc" / "sm90",
            Path(this_dir) / "csrc" / "cutlass" / "include",
        ],
    )
)

# SM100 extension, compiled for sm_100a. Host flags reuse the platform-aware
# `cxx_args` (MSVC syntax on Windows) instead of a hard-coded GCC-style list,
# keeping them consistent with the SM90 extension.
ext_modules.append(
    CUDAExtension(
        name="flash_mla_sm100",
        sources=[
            "csrc/sm100/pybind.cu",
            "csrc/sm100/fmha_cutlass_fwd_sm100.cu",
            "csrc/sm100/fmha_cutlass_bwd_sm100.cu",
        ],
        extra_compile_args={
            "cxx": list(cxx_args),
            "nvcc": append_nvcc_threads(
                [
                    "-O3",
                    "-std=c++17",
                    "-DNDEBUG",
                    "-Wno-deprecated-declarations",
                    # Undefine the __CUDA_NO_HALF*/BFLOAT16 guard macros so the
                    # half/bfloat16 operators and conversions stay available.
                    "-U__CUDA_NO_HALF_OPERATORS__",
                    "-U__CUDA_NO_HALF_CONVERSIONS__",
                    "-U__CUDA_NO_HALF2_OPERATORS__",
                    "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
                    "--expt-relaxed-constexpr",
                    "--expt-extended-lambda",
                    "--use_fast_math",
                    # Keep source line info for profilers/debuggers.
                    "-lineinfo",
                    "--ptxas-options=--verbose,--register-usage-level=10,--warn-on-local-memory-usage",
                ]
                + cc_flag_sm100
            ),
        },
        include_dirs=[
            Path(this_dir) / "csrc" / "sm100",
            Path(this_dir) / "csrc" / "cutlass" / "include",
            Path(this_dir) / "csrc" / "cutlass" / "tools" / "util" / "include",
        ],
    )
)


# Local version label: "+<short git hash>" for a git checkout with git
# available, otherwise "+<build timestamp>".
try:
    rev = "+" + subprocess.check_output(
        ["git", "rev-parse", "--short", "HEAD"],
        stderr=subprocess.DEVNULL,  # don't print "not a git repository" noise
    ).decode("ascii").rstrip()
except (OSError, subprocess.CalledProcessError):
    # OSError: git is not installed. CalledProcessError: git exited non-zero
    # (e.g. not a git checkout). Fall back to a timestamp label.
    rev = "+" + datetime.now().strftime("%Y-%m-%d-%H-%M-%S")


# Package metadata: version "1.0.0" plus the local label held in `rev`.
# torch's BuildExtension is registered as build_ext so the CUDA extensions in
# `ext_modules` are compiled through PyTorch's extension builder.
setup(
    name="flash_mla",
    version=f"1.0.0{rev}",
    packages=find_packages(include=["flash_mla"]),
    ext_modules=ext_modules,
    cmdclass=dict(build_ext=BuildExtension),
)