setup.py 4.45 KB
Newer Older
Jiashi Li's avatar
Jiashi Li committed
1
2
3
4
import os
from pathlib import Path
from datetime import datetime
import subprocess
zhanghj2's avatar
zhanghj2 committed
5
6
from typing import Optional
from get_version import get_version
Jiashi Li's avatar
Jiashi Li committed
7
8

from setuptools import setup, find_packages
zhanghj2's avatar
zhanghj2 committed
9
import torch
Jiashi Li's avatar
Jiashi Li committed
10
11
12
from torch.utils.cpp_extension import (
    BuildExtension,
    CUDAExtension,
程元's avatar
程元 committed
13
    IS_WINDOWS,
14
    CUDA_HOME
Jiashi Li's avatar
Jiashi Li committed
15
16
)

ljss's avatar
ljss committed
17

18
19
def is_flag_set(flag: str) -> bool:
    """Report whether the environment variable *flag* holds a truthy value.

    The comparison is case-insensitive and accepts "true", "1", "y" and
    "yes". An unset variable is treated as the string "FALSE", i.e. not set.
    """
    truthy_values = ("true", "1", "y", "yes")
    raw_value = os.environ.get(flag, "FALSE")
    return raw_value.lower() in truthy_values
ljss's avatar
ljss committed
20

Sijia Chen's avatar
Sijia Chen committed
21
22
def get_features_args():
    """Return extra preprocessor defines selected by environment flags.

    Only ``FLASH_MLA_DISABLE_FP16`` is recognised: when it is set to a value
    accepted by ``is_flag_set`` the define ``-DFLASH_MLA_DISABLE_FP16`` is
    returned, otherwise the list is empty.

    Returns:
        A new list of compiler flag strings.
    """
    features_args = []
    if is_flag_set("FLASH_MLA_DISABLE_FP16"):
        features_args.append("-DFLASH_MLA_DISABLE_FP16")
    return features_args
Jiashi Li's avatar
Jiashi Li committed
26

27
def get_arch_flags(archs=("gfx938", "gfx936")):
    """Return the ``--offload-arch`` flags for the device compiler.

    Args:
        archs: GPU architecture names to compile for. Defaults to the
            targets this build has always used (gfx938 and gfx936), so
            calling with no arguments keeps the original target set.

    Returns:
        A list with one ``--offload-arch=<arch>`` flag per architecture,
        in the order given.
    """
    # One flag per target instead of a single ";"-joined value: the clang/hipcc
    # driver takes one architecture per --offload-arch (repeat the option for
    # more), and a bare ";" acts as a command separator if these flags end up
    # unquoted on a shell command line.
    return [f"--offload-arch={arch}" for arch in archs]
31

zhanghj2's avatar
zhanghj2 committed
32
33
34
# def get_nvcc_thread_args():
#     # nvcc_threads = os.getenv("NVCC_THREADS") or "32"
#     return ["--threads", nvcc_threads]
ljss's avatar
ljss committed
35

zhanghj2's avatar
zhanghj2 committed
36
# subprocess.run(["git", "submodule", "update", "--init", "csrc/cutlass"])
Jiashi Li's avatar
Jiashi Li committed
37
38
39

# Absolute directory containing this setup.py; include paths are built from it.
this_dir = os.path.dirname(os.path.abspath(__file__))

# Host (C++) compiler flags. An MSVC flag set used to sit behind `if False:`
# and could never be selected, so only the GCC/Clang-style flags remain.
cxx_args = [
    "-O3",
    "-std=c++17",
    "-DNDEBUG",
    "-Wno-deprecated-declarations",
    "-DDCU_ASM",
    "-Wno-return-type",
]

程元's avatar
程元 committed
45

Jiashi Li's avatar
Jiashi Li committed
46
47
48
# The package ships a single compiled extension, importable as
# ``flash_mla.cuda``. The "cxx" flags go to the host compiler; the "nvcc" list
# goes to the device compiler (presumably hipcc on this DCU/ROCm build, given
# the clang/HIP-style flags and --offload-arch targets — confirm).
ext_modules = []
ext_modules.append(
    CUDAExtension(
        name="flash_mla.cuda",
        sources=[
            # API
            "csrc/api/api.cpp",
            # Misc kernels for decoding
            "csrc/smxx/decode/get_decoding_sched_meta/get_decoding_sched_meta.cu",
            "csrc/smxx/decode/combine/combine.cu",
            # sm90 dense decode
            "csrc/sm90/decode/dense/instantiations/fp16.cu",
            "csrc/sm90/decode/dense/instantiations/bf16.cu",
            # sm90 dense qkvfp8 decode
            "csrc/sm90/decode/dense_qkvfp8/instantiations/fp8e4m3.cu",
            # sm90 sparse decode
            "csrc/sm90/decode/sparse_fp8/instantiations/model1_persistent_h16.cu",
            "csrc/sm90/decode/sparse_fp8/instantiations/model1_persistent_h64.cu",
            "csrc/sm90/decode/sparse_fp8/instantiations/model1_persistent_h128.cu",
            "csrc/sm90/decode/sparse_fp8/instantiations/v32_persistent_h16.cu",
            "csrc/sm90/decode/sparse_fp8/instantiations/v32_persistent_h64.cu",
            "csrc/sm90/decode/sparse_fp8/instantiations/v32_persistent_h128.cu",
            # sm90 sparse prefill
            "csrc/sm90/prefill/sparse/fwd.cu",
            "csrc/sm90/prefill/sparse/instantiations/phase1_k512.cu",
            "csrc/sm90/prefill/sparse/instantiations/phase1_k512_topklen.cu",
            "csrc/sm90/prefill/sparse/instantiations/phase1_k576.cu",
            "csrc/sm90/prefill/sparse/instantiations/phase1_k576_topklen.cu",
        ],
        extra_compile_args={
            "cxx": cxx_args + get_features_args(),
            "nvcc": [
                "-O3",
                "-std=c++17",
                "-DNDEBUG",
                "-DHIP_ENABLE_WARP_SYNC_BUILTINS",
                "-ffast-math",
                "-ftemplate-backtrace-limit=0",
                # Diagnostic remark output only; does not change generated code.
                "-Rpass-analysis=kernel-resource-usage",
                "-DDCU_ASM",
                # NOTE(review): keeps compiler intermediate files next to the
                # build; presumably intentional for inspecting generated ISA —
                # confirm it is wanted outside development builds.
                "--save-temps",
                # Suppress all compiler warnings.
                "-w",
            ] + get_features_args() + get_arch_flags(),
        },
        include_dirs=[
            Path(this_dir) / "csrc",
            Path(this_dir) / "csrc" / "kerutils" / "include",  # TODO Remove me
            Path(this_dir) / "csrc" / "sm90",
            Path(this_dir) / "csrc" / "cutlass" / "include",
        ],
    )
)

zhanghj2's avatar
zhanghj2 committed
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
def _find_rocm_home() -> Optional[str]:
    rocm_home = os.environ.get('ROCM_HOME') or os.environ.get('ROCM_PATH')
    if rocm_home is None:
        try:
            pipe_hipcc = subprocess.Popen(
                ["which hipcc | xargs readlink -f"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True)
            hipcc, _ = pipe_hipcc.communicate()
            rocm_home = os.path.dirname(os.path.dirname(hipcc.decode(*()).rstrip('\r\n')))
            if os.path.basename(rocm_home) == 'hip':
                rocm_home = os.path.dirname(rocm_home)
        except Exception:
            rocm_home = '/opt/rocm'
            if not os.path.exists(rocm_home):
                rocm_home = None
    if rocm_home and torch.version.hip is None:
        print(f"No ROCm runtime is found, using ROCM_HOME='{rocm_home}'")
    return rocm_home
ROCM_HOME = _find_rocm_home()
Jiashi Li's avatar
Jiashi Li committed
123
124

# Package metadata and build wiring. BuildExtension is registered as the
# build_ext command so the torch extension in ext_modules is compiled.
# NOTE(review): ROCM_HOME may be None when no ROCm install is found — confirm
# that get_version handles that case.
setup(
    name="flash_mla",
    version=get_version(ROCM_HOME),
    packages=find_packages(include=['flash_mla']),
    ext_modules=ext_modules,
    cmdclass={"build_ext": BuildExtension},
)