setup.py 3.28 KB
Newer Older
Jiashi Li's avatar
Jiashi Li committed
1
2
3
4
5
6
7
8
9
10
import os
import subprocess
from datetime import datetime
from pathlib import Path

from setuptools import setup, find_packages

from torch.utils.cpp_extension import (
    BuildExtension,
    CUDAExtension,
    IS_WINDOWS,
    CUDA_HOME,
)

def is_flag_set(flag: str) -> bool:
    """Return True when the environment variable *flag* holds a truthy value.

    The comparison is case-insensitive and accepts "true", "1", "y" or "yes".
    An unset variable is read as "FALSE" and therefore reports False.
    """
    value = os.environ.get(flag, "FALSE").lower()
    return value in ("true", "1", "y", "yes")


def get_features_args():
    """Return the optional feature defines selected through environment flags.

    At present only FLASH_MLA_DISABLE_FP16 is recognised: when that variable is
    truthy (per ``is_flag_set``), ``-DFLASH_MLA_DISABLE_FP16`` is returned.
    A new list is built on every call, so callers may concatenate it freely.
    """
    if not is_flag_set("FLASH_MLA_DISABLE_FP16"):
        return []
    return ["-DFLASH_MLA_DISABLE_FP16"]


def get_arch_flags():
    """Return the ``--offload-arch=...`` flags for the device compiler.

    By default a single target, ``gfx938``, is emitted (the original hard-coded
    value). Set the environment variable FLASH_MLA_OFFLOAD_ARCH to a list of
    targets separated by ``,`` or ``;`` (e.g. ``"gfx90a;gfx938"``) to build for
    other or additional GPUs. Blank entries are ignored. If the variable is
    unset or contains no targets, the default is used.
    """
    raw = os.getenv("FLASH_MLA_OFFLOAD_ARCH", "")
    # Accept both separators: ';' matches PYTORCH_ROCM_ARCH style, ',' is common.
    archs = [a.strip() for a in raw.replace(";", ",").split(",") if a.strip()]
    if not archs:
        archs = ["gfx938"]
    return [f"--offload-arch={arch}" for arch in archs]


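# NOTE: the nvcc `--threads` option is not used by this build; the helper is
# kept for reference only. Uncomment all three lines together to re-enable it.
# def get_nvcc_thread_args():
#     nvcc_threads = os.getenv("NVCC_THREADS") or "32"
#     return ["--threads", nvcc_threads]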
# NOTE: the cutlass git submodule is not fetched automatically; the CUTLASS
# headers are taken from the external include directory configured for the
# extension below.
# subprocess.run(["git", "submodule", "update", "--init", "csrc/cutlass"])

# Absolute path of the directory that contains this setup.py; include paths
# for the extension are built relative to it.
this_dir = os.path.dirname(os.path.abspath(__file__))

# Host (C++) compiler flags. MSVC-style flags are used only on Windows; this
# was previously gated by a dead `if False:`, which sent GCC/Clang flags to
# MSVC there. All other platforms get GCC/Clang-style flags.
if IS_WINDOWS:
    cxx_args = ["/O2", "/std:c++20", "/DNDEBUG", "/W0"]
else:
    cxx_args = [
        "-O3",
        "-std=c++17",
        "-DNDEBUG",
        "-Wno-deprecated-declarations",
        "-DDCU_ASM",
        "-Wno-return-type",
    ]

# Directory holding the CUTLASS headers the kernels are compiled against.
# Defaults to the original author's local checkout; set
# FLASH_MLA_CUTLASS_INCLUDE_DIR to point at a checkout on another machine
# (an empty value falls back to the default).
cutlass_include_dir = (
    os.getenv("FLASH_MLA_CUTLASS_INCLUDE_DIR")
    or "/public/home/zhanghj/work/dev/cutlass_3.2.1-mla/include"
)

ext_modules = []
ext_modules.append(
    CUDAExtension(
        name="flash_mla.cuda",
        sources=[
            # API
            "csrc/api/api.cpp",
            # Misc kernels for decoding
            "csrc/smxx/decode/get_decoding_sched_meta/get_decoding_sched_meta.cu",
            "csrc/smxx/decode/combine/combine.cu",
            # sm90 dense decode
            "csrc/sm90/decode/dense/instantiations/fp16.cu",
            "csrc/sm90/decode/dense/instantiations/bf16.cu",
            # sm90 sparse decode
            "csrc/sm90/decode/sparse_fp8/instantiations/model1_persistent_h64.cu",
            "csrc/sm90/decode/sparse_fp8/instantiations/model1_persistent_h128.cu",
            "csrc/sm90/decode/sparse_fp8/instantiations/v32_persistent_h64.cu",
            "csrc/sm90/decode/sparse_fp8/instantiations/v32_persistent_h128.cu",
            # sm90 sparse prefill
            "csrc/sm90/prefill/sparse/fwd.cu",
            "csrc/sm90/prefill/sparse/instantiations/phase1_k512.cu",
            "csrc/sm90/prefill/sparse/instantiations/phase1_k512_topklen.cu",
            "csrc/sm90/prefill/sparse/instantiations/phase1_k576.cu",
            "csrc/sm90/prefill/sparse/instantiations/phase1_k576_topklen.cu",
        ],
        extra_compile_args={
            "cxx": cxx_args + get_features_args(),
            # Device-code flags. Despite the "nvcc" key, --offload-arch and the
            # HIP_* define appear to target the HIP/clang toolchain.
            "nvcc": [
                "-O3",
                "-std=c++17",
                "-DNDEBUG",
                "-DHIP_ENABLE_WARP_SYNC_BUILTINS",
                "-ffast-math",
                "-ftemplate-backtrace-limit=0",
                "-Rpass-analysis=kernel-resource-usage",
                "-DDCU_ASM",
                "-w",
            ] + get_features_args() + get_arch_flags(),
        },
        include_dirs=[
            Path(this_dir) / "csrc",
            Path(this_dir) / "csrc" / "kerutils" / "include",  # TODO Remove me
            Path(this_dir) / "csrc" / "sm90",
            cutlass_include_dir,
        ],
    )
)


# Package metadata. The compiled extension is built through torch's
# BuildExtension command class.
setup(
    name="flash_mla",
    version="1.0.0",
    packages=find_packages(include=["flash_mla"]),
    cmdclass=dict(build_ext=BuildExtension),
    ext_modules=ext_modules,
)