"""Build script for the ``flash_mla`` package and its ``flash_mla_cuda`` CUDA extension."""

import os
from pathlib import Path
from datetime import datetime
import subprocess

from setuptools import setup, find_packages

from torch.utils.cpp_extension import (
    BuildExtension,
    CUDAExtension,
    IS_WINDOWS,
)

def append_nvcc_threads(nvcc_extra_args):
    """Return ``nvcc_extra_args`` extended with a ``--threads <n>`` option.

    The thread count is read from the ``NVCC_THREADS`` environment variable;
    when that variable is unset or empty, ``"32"`` is used. The input list is
    not modified -- a new list is returned.
    """
    thread_count = os.environ.get("NVCC_THREADS")
    if not thread_count:
        # Unset and empty-string values both fall back to the default.
        thread_count = "32"
    threads_option = ["--threads", thread_count]
    return nvcc_extra_args + threads_option

def get_features_args():
    """Return the preprocessor defines selected through environment variables.

    Currently a single switch is recognised: when ``FLASH_MLA_DISABLE_FP16``
    is set to exactly ``"TRUE"`` or ``"1"`` (the comparison is case-sensitive),
    ``-DFLASH_MLA_DISABLE_FP16`` is added. Any other value, or an unset
    variable, adds nothing.

    Returns:
        A new list of compiler flags (possibly empty).
    """
    features_args = []
    disable_fp16 = os.getenv("FLASH_MLA_DISABLE_FP16", "FALSE") in ["TRUE", "1"]
    if disable_fp16:
        features_args.append("-DFLASH_MLA_DISABLE_FP16")
    return features_args


# Check out the CUTLASS submodule, whose headers are on the extension's include
# path. The result is not inspected, so a non-zero git exit status does not
# stop the script at this point.
subprocess.run(
    args=["git", "submodule", "update", "--init", "csrc/cutlass"],
)

# Code-generation target passed to nvcc: compute_90a / sm_90a only.
cc_flag = ["-gencode", "arch=compute_90a,code=sm_90a"]

# Absolute path of the directory that contains this file.
this_dir = os.path.split(os.path.abspath(__file__))[0]

# Host C++ compiler flags: MSVC-style switches on Windows, GCC/Clang-style
# switches everywhere else.
cxx_args = (
    ["/O2", "/std:c++17", "/DNDEBUG", "/W0"]
    if IS_WINDOWS
    else ["-O3", "-std=c++17", "-DNDEBUG", "-Wno-deprecated-declarations"]
)

# The package ships a single CUDA extension, ``flash_mla_cuda``, built from the
# C++ binding file and the CUDA kernels under ``csrc/``.
ext_modules = [
    CUDAExtension(
        name="flash_mla_cuda",
        sources=[
            "csrc/flash_api.cpp",
            "csrc/kernels/get_mla_metadata.cu",
            "csrc/kernels/mla_combine.cu",
            "csrc/kernels/splitkv_mla.cu",
        ],
        extra_compile_args={
            # Host compiler: platform flags plus the optional feature defines.
            "cxx": cxx_args + get_features_args(),
            # nvcc: base flags + target architecture, then ``--threads <n>``,
            # then the same optional feature defines as the host compiler.
            "nvcc": append_nvcc_threads(
                [
                    "-O3",
                    "-std=c++17",
                    "-DNDEBUG",
                    "-D_USE_MATH_DEFINES",
                    "-Wno-deprecated-declarations",
                    "-U__CUDA_NO_HALF_OPERATORS__",
                    "-U__CUDA_NO_HALF_CONVERSIONS__",
                    "-U__CUDA_NO_HALF2_OPERATORS__",
                    "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
                    "--expt-relaxed-constexpr",
                    "--expt-extended-lambda",
                    "--use_fast_math",
                    "--ptxas-options=-v,--register-usage-level=10",
                ]
                + cc_flag
            ) + get_features_args(),
        },
        include_dirs=[
            Path(this_dir) / "csrc",
            # CUTLASS headers come from the git submodule at csrc/cutlass.
            Path(this_dir) / "csrc" / "cutlass" / "include",
        ],
    )
]


# Local version suffix: "+<short git hash>" when git can report the current
# commit, otherwise "+<timestamp>" taken at build time.
try:
    git_output = subprocess.check_output(['git', 'rev-parse', '--short', 'HEAD'])
    rev = f"+{git_output.decode('ascii').rstrip()}"
except Exception:
    # Any failure (git unavailable, git exiting non-zero, undecodable output)
    # falls back to the current local time.
    rev = f"+{datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}"


# Register the package; the version is "1.0.0" followed by the local ``rev``
# suffix, and extensions are built through torch's BuildExtension command.
setup(
    name="flash_mla",
    version=f"1.0.0{rev}",
    packages=find_packages(include=["flash_mla"]),
    ext_modules=ext_modules,
    cmdclass=dict(build_ext=BuildExtension),
)