setup.py 2.63 KB
Newer Older
Jiashi Li's avatar
Jiashi Li committed
1
2
3
4
5
6
7
8
9
10
import os
from pathlib import Path
from datetime import datetime
import subprocess

from setuptools import setup, find_packages

from torch.utils.cpp_extension import (
    BuildExtension,
    CUDAExtension,
程元's avatar
程元 committed
11
    IS_WINDOWS,
Jiashi Li's avatar
Jiashi Li committed
12
13
14
15
16
17
)

def append_nvcc_threads(nvcc_extra_args):
    """Return a copy of *nvcc_extra_args* with ``--threads <N>`` appended.

    ``N`` comes from the ``NVCC_THREADS`` environment variable and falls back
    to ``"32"`` when that variable is unset or empty. The input list itself is
    not modified; a new list is returned.
    """
    thread_count = os.getenv("NVCC_THREADS")
    if not thread_count:
        thread_count = "32"
    threads_option = ["--threads", thread_count]
    return nvcc_extra_args + threads_option

Sijia Chen's avatar
Sijia Chen committed
18
19
def get_features_args():
    """Return extra compiler arguments selected through environment variables.

    ``FLASH_MLA_DISABLE_FP16`` counts as enabled when its value is ``1`` or
    ``true`` (case-insensitive, surrounding whitespace ignored). In that case
    ``-DFLASH_MLA_DISABLE_FP16`` is returned so the macro is defined for the
    C++/CUDA sources. Any other value, or leaving the variable unset, adds
    nothing.

    Returns:
        A new list of extra compiler arguments (possibly empty).
    """
    features_args = []
    # Normalise the value so ``true`` / ``True`` / `` 1 `` behave like the
    # documented ``TRUE`` / ``1`` instead of being silently ignored.
    fp16_switch = os.getenv("FLASH_MLA_DISABLE_FP16", "FALSE").strip().upper()
    if fp16_switch in ("TRUE", "1"):
        features_args.append("-DFLASH_MLA_DISABLE_FP16")
    return features_args
Jiashi Li's avatar
Jiashi Li committed
24

ljss's avatar
ljss committed
25

Jiashi Li's avatar
Jiashi Li committed
26
27
28
29
30
31
32
33
this_dir = os.path.dirname(os.path.abspath(__file__))

# CUTLASS is pulled in as the git submodule csrc/cutlass; try to make sure it
# is checked out. Run from this file's directory so the relative submodule
# path resolves wherever setup.py is invoked from. A non-zero exit status is
# deliberately not treated as fatal (best effort).
try:
    subprocess.run(
        ["git", "submodule", "update", "--init", "csrc/cutlass"],
        cwd=this_dir,
        check=False,
    )
except OSError as err:
    # git could not be launched (e.g. not installed when building from an
    # unpacked source archive). Continue with whatever csrc/cutlass holds,
    # mirroring how the git rev-parse lookup below tolerates a missing git.
    print(f"WARNING: could not run git ({err}); skipping csrc/cutlass submodule update")

# nvcc code-generation target: compute capability 9.0a only (sm_90a).
cc_flag = ["-gencode", "arch=compute_90a,code=sm_90a"]

程元's avatar
程元 committed
34
35
36
37
38
# Host (C++) compiler flags: MSVC spelling on Windows, GCC/Clang spelling
# elsewhere. Both select optimisation, C++17 and NDEBUG; MSVC disables
# warnings (/W0), GCC/Clang only silence deprecated-declaration warnings.
cxx_args = (
    ["/O2", "/std:c++17", "/DNDEBUG", "/W0"]
    if IS_WINDOWS
    else ["-O3", "-std=c++17", "-DNDEBUG", "-Wno-deprecated-declarations"]
)

Jiashi Li's avatar
Jiashi Li committed
39
40
41
42
# Base nvcc flags for every CUDA translation unit; the gencode target from
# ``cc_flag`` follows them, then ``--threads`` and any feature flags.
_nvcc_flags = [
    "-O3",
    "-std=c++17",
    "-DNDEBUG",
    "-D_USE_MATH_DEFINES",
    "-Wno-deprecated-declarations",
    "-U__CUDA_NO_HALF_OPERATORS__",
    "-U__CUDA_NO_HALF_CONVERSIONS__",
    "-U__CUDA_NO_HALF2_OPERATORS__",
    "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
    "--expt-relaxed-constexpr",
    "--expt-extended-lambda",
    "--use_fast_math",
    "--ptxas-options=-v,--register-usage-level=10",
]

# Root of the C++/CUDA sources; the CUTLASS headers live in its submodule.
_csrc_dir = Path(this_dir) / "csrc"

# The single compiled extension module shipped with the package.
ext_modules = [
    CUDAExtension(
        name="flash_mla_cuda",
        sources=[
            "csrc/flash_api.cpp",
            "csrc/kernels/get_mla_metadata.cu",
            "csrc/kernels/mla_combine.cu",
            "csrc/kernels/splitkv_mla.cu",
        ],
        extra_compile_args={
            "cxx": cxx_args + get_features_args(),
            "nvcc": append_nvcc_threads(_nvcc_flags + cc_flag) + get_features_args(),
        },
        include_dirs=[
            _csrc_dir,
            _csrc_dir / "cutlass" / "include",
        ],
    )
]


def get_local_version(repo_dir=None):
    """Return the PEP 440 local-version suffix (``"+<label>"``) for this build.

    The label is the short hash of the commit checked out in *repo_dir* (the
    current working directory when ``None``). If git cannot be launched, the
    directory is not a git checkout, or the output cannot be decoded, the
    current local time formatted as ``%Y-%m-%d-%H-%M-%S`` is used instead.
    """
    try:
        output = subprocess.check_output(
            ["git", "rev-parse", "--short", "HEAD"],
            cwd=repo_dir,
            # Keep "fatal: not a git repository" out of the build log; the
            # failure is handled by the fallback below.
            stderr=subprocess.DEVNULL,
        )
        return "+" + output.decode("ascii").rstrip()
    except (OSError, subprocess.SubprocessError, UnicodeDecodeError):
        # git missing, bad directory, not a repository, or odd output.
        return "+" + datetime.now().strftime("%Y-%m-%d-%H-%M-%S")


rev = get_local_version()


# Package metadata; ``BuildExtension`` is torch's build_ext command, used to
# compile the CUDA extension listed in ``ext_modules``.
setup(
    name="flash_mla",
    version=f"1.0.0{rev}",
    packages=find_packages(include=["flash_mla"]),
    ext_modules=ext_modules,
    cmdclass={"build_ext": BuildExtension},
)