setup.py 6.77 KB
Newer Older
Jiashi Li's avatar
Jiashi Li committed
1
2
3
4
import os
from pathlib import Path
from datetime import datetime
import subprocess
zhanghj2's avatar
zhanghj2 committed
5
6
from typing import Optional
from get_version import get_version
Jiashi Li's avatar
Jiashi Li committed
7
8

from setuptools import setup, find_packages
zhanghj2's avatar
zhanghj2 committed
9
import torch
Jiashi Li's avatar
Jiashi Li committed
10
11
12
from torch.utils.cpp_extension import (
    BuildExtension,
    CUDAExtension,
程元's avatar
程元 committed
13
    IS_WINDOWS,
14
    CUDA_HOME
Jiashi Li's avatar
Jiashi Li committed
15
16
)

ljss's avatar
ljss committed
17

18
19
def is_flag_set(flag: str) -> bool:
    """Report whether the environment variable ``flag`` is switched on.

    The variable's value is lower-cased and compared against ``true``,
    ``1``, ``y`` and ``yes``. An unset variable is treated as ``FALSE``.

    Args:
        flag: Name of the environment variable to inspect.

    Returns:
        ``True`` when the value is one of the accepted spellings,
        ``False`` otherwise.
    """
    raw_value = os.environ.get(flag, "FALSE")
    accepted = ("true", "1", "y", "yes")
    return raw_value.lower() in accepted
ljss's avatar
ljss committed
20

Sijia Chen's avatar
Sijia Chen committed
21
def get_features_args():
    """Build the feature-selection ``-D`` defines passed to the compilers.

    Reads two environment variables:

    * ``FLASH_MLA_BF16_TYPE`` (default ``"1"``): must be ``"0"`` or ``"1"``;
      forwarded as ``-DFLASH_MLA_BF16_TYPE=<value>``.
    * ``FLASH_MLA_DISABLE_FP16``: when truthy according to ``is_flag_set``,
      ``-DFLASH_MLA_DISABLE_FP16`` is added.

    The name of the selected BFloat16 rounding mode is printed as a side
    effect.

    Returns:
        A new list of compiler define flags.

    Raises:
        ValueError: If ``FLASH_MLA_BF16_TYPE`` is neither ``"0"`` nor ``"1"``.
    """
    bf16_mode_names = {"0": "round_toward_zero", "1": "round_half_ulp_truncate"}
    bf16_type = os.getenv("FLASH_MLA_BF16_TYPE", "1")
    # Explicit raise instead of `assert`: the check must survive `python -O`,
    # otherwise an arbitrary string would be forwarded into the -D define.
    if bf16_type not in bf16_mode_names:
        raise ValueError(f"bf16_type must be 0 or 1, got {bf16_type!r}")
    print(f"Using BFloat16 rounding mode: {bf16_mode_names[bf16_type]}")

    features_args = []
    if is_flag_set("FLASH_MLA_DISABLE_FP16"):
        features_args.append("-DFLASH_MLA_DISABLE_FP16")
    features_args.append(f"-DFLASH_MLA_BF16_TYPE={bf16_type}")
    return features_args
Jiashi Li's avatar
Jiashi Li committed
31

32
def get_arch_flags():
    """Return the device-architecture flags for the GPU compiler.

    Returns:
        A new single-element list holding the ``--offload-arch`` option
        that names the gfx938, gfx936 and gfx928 targets.
    """
    # A fresh list on every call so callers may extend it safely.
    return ["--offload-arch=gfx938,gfx936,gfx928"]
36

zhanghj2's avatar
zhanghj2 committed
37
38
39
# def get_nvcc_thread_args():
#     # nvcc_threads = os.getenv("NVCC_THREADS") or "32"
#     return ["--threads", nvcc_threads]
ljss's avatar
ljss committed
40

Jiashi Li's avatar
Jiashi Li committed
41
# Absolute path of the directory that contains this setup script; the other
# project paths in this file are built relative to it.
this_dir = os.path.dirname(os.path.abspath(__file__))

# Exported before the extension is configured.
# NOTE(review): presumably makes torch's extension builder invoke `aicc` as
# the device compiler - confirm against the torch build in use.
os.environ['PYTORCH_NVCC'] = 'aicc'
Jiashi Li's avatar
Jiashi Li committed
43

shenzhe's avatar
shenzhe committed
44
45
46
47
48
49
50
51
52
53
54
55
56
# Location of the vendored CUTLASS checkout (a git submodule).
cutlass_dir = Path(this_dir) / "csrc" / "cutlass" / "cutlass_3.2.1"

_CUTLASS_MISSING_MSG = (
    "CUTLASS dependency is missing. Run "
    "`git submodule update --init --recursive csrc/cutlass/cutlass_3.2.1` first."
)

# If the headers are absent, try once to fetch the submodule automatically.
if not (cutlass_dir / "include").exists():
    try:
        subprocess.run(
            ["git", "submodule", "update", "--init", "--recursive", "csrc/cutlass/cutlass_3.2.1"],
            cwd=this_dir,
            check=True,
        )
    except (OSError, subprocess.CalledProcessError) as err:
        # git is not installed (OSError) or the update failed
        # (CalledProcessError): report the actionable message and keep the
        # underlying failure attached as the cause.
        raise RuntimeError(_CUTLASS_MISSING_MSG) from err

# The fetch may succeed without producing the headers (for example when the
# tree is not a git checkout); refuse to continue without them.
if not (cutlass_dir / "include").exists():
    raise RuntimeError(_CUTLASS_MISSING_MSG)

zhanghj2's avatar
zhanghj2 committed
57
# Host C++ compiler flags, written in GCC/Clang syntax. The MSVC-style flag
# set was guarded by a constant-false condition and could never be selected,
# so only the live assignment is kept.
cxx_args = [
    "-O3",
    "-std=c++20",
    "-DNDEBUG",
    "-Wno-deprecated-declarations",
    "-DDCU_ASM",
    "-Wno-return-type",
]
zhanghj2's avatar
zhanghj2 committed
61

zhanghj2's avatar
zhanghj2 committed
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
# Options forwarded to the LLVM backend through `-mllvm=`.
_aicc_llvm_options = [
    "-mllvm=-support-768-vgprs=true",
    "-mllvm=-disable-machine-sink",
    "-mllvm=-disable-code-sink",
    "-mllvm=-amdgpu-enable-rewrite-partial-reg-uses=false",
    "-mllvm=-allow-gvn-convergent-call=true",
    "-mllvm=-disallow-uniform-vmed3-combine=true",
    "-mllvm=-hcu-pre-emit-load-store-opt=false",
    "-mllvm=-amdgpu-early-inline-all=true",
    "-mllvm=-amdgpu-function-calls=false",
]

# Extra device-compiler flags, in the order they are appended to the
# extension's "nvcc" argument list.
aicc_flags = [
    "-mcode-object-version=5",
    *_aicc_llvm_options,
    "-fno-finite-math-only",
    "--gpu-max-threads-per-block=256",
]

程元's avatar
程元 committed
77

Jiashi Li's avatar
Jiashi Li committed
78
79
80
# Feature defines are computed once and shared by the host and device
# compiler argument lists; get_features_args() prints a status line, so a
# single call also avoids emitting that message twice.
_feature_args = get_features_args()

# The single native extension of the package, importable as `flash_mla.cuda`.
ext_modules = [
    CUDAExtension(
        name="flash_mla.cuda",
        sources=[
            # API
            "csrc/api/api.cpp",

            # Misc kernels for decoding
            "csrc/gfx9/decode/get_decoding_sched_meta/get_decoding_sched_meta.cu",
            "csrc/gfx9/decode/combine/combine.cu",

            # gfx93 dense decode
            "csrc/gfx93/decode/dense/instantiations/fp16.cu",
            "csrc/gfx93/decode/dense/instantiations/bf16.cu",

            # gfx93 dense qkvfp8 decode
            "csrc/gfx93/decode/dense_qkvfp8/instantiations/fp8e4m3.cu",

            # gfx93 dense kvfp8 decode
            "csrc/gfx93/decode/dense_kvfp8/instantiations/kvfp8.cu",

            # gfx93 sparse decode
            "csrc/gfx93/decode/sparse_fp8/instantiations/model1_persistent_h16.cu",
            "csrc/gfx93/decode/sparse_fp8/instantiations/model1_persistent_h64.cu",
            "csrc/gfx93/decode/sparse_fp8/instantiations/model1_persistent_h128.cu",
            "csrc/gfx93/decode/sparse_fp8/instantiations/v32_persistent_h16.cu",
            "csrc/gfx93/decode/sparse_fp8/instantiations/v32_persistent_h64.cu",
            "csrc/gfx93/decode/sparse_fp8/instantiations/v32_persistent_h128.cu",
            "csrc/gfx93/decode/sparse_bf16_dsa/fwd.cu",

            # gfx93 sparse prefill
            "csrc/gfx93/prefill/sparse/fwd.cu",
            "csrc/gfx93/prefill/sparse/instantiations/phase1_k512.cu",
            "csrc/gfx93/prefill/sparse/instantiations/phase1_k512_topklen.cu",
            "csrc/gfx93/prefill/sparse/instantiations/phase1_k576.cu",
            "csrc/gfx93/prefill/sparse/instantiations/phase1_k576_topklen.cu",
            "csrc/gfx93/prefill/sparse/dsa_mls/fwd.cu",

            # Sources under csrc/extension
            "csrc/extension/flash_fwd_mla_bf16_gfx936.cu",
            "csrc/extension/flash_fwd_mla_fp16_gfx936.cu",
            "csrc/extension/flash_fwd_mla_fp8_gfx938.cu",
            "csrc/extension/flash_fwd_mla_fp8_qbf16_gfx938.cu",
            "csrc/extension/flash_fwd_mla_metadata.cu",
        ],
        extra_compile_args={
            # Host compiler: base flags plus the feature defines.
            "cxx": cxx_args + _feature_args,
            # Device compiler: base flags, then feature defines, target
            # architectures and the extra aicc options, in that order.
            "nvcc": [
                "-O3",
                "-std=c++20",
                "-DNDEBUG",
                "-DHIP_ENABLE_WARP_SYNC_BUILTINS",
                "-ffast-math",
                "-ftemplate-backtrace-limit=0",
                "-Rpass-analysis=kernel-resource-usage",
                "-DDCU_ASM",
                # "--save-temps",
                "-w",
                # Currently disabled:
                # "-mllvm -enable-num-vgprs-512=true",
                # "-mllvm -allow-cse-cross-bb-convergent-call=true",
                # "-mllvm -full-vectorize-slp=true",
            ] + _feature_args + get_arch_flags() + aicc_flags,
        },
        include_dirs=[
            Path(this_dir) / "csrc",
            Path(this_dir) / "csrc" / "kerutils" / "include",   # TODO Remove me
            Path(this_dir) / "csrc" / "gfx93",
            Path(this_dir) / "csrc" / "cutlass" / "cutlass_3.2.1" / "include",
            Path(this_dir) / "csrc" / "gfx93" / "prefill" / "sparse" / "dsa_mls" / "legacy" / "include",
        ],
    )
]

zhanghj2's avatar
zhanghj2 committed
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
def _find_rocm_home() -> Optional[str]:
    """Locate the ROCm installation directory.

    Lookup order:

    1. The ``ROCM_HOME`` environment variable, then ``ROCM_PATH``.
    2. The directory two levels above the resolved ``hipcc`` executable
       found on ``PATH``; if that directory is itself named ``hip``, its
       parent is used instead.
    3. ``/opt/rocm`` when that path exists.

    When a location is found but ``torch.version.hip`` is ``None``, a
    notice is printed.

    Returns:
        The ROCm root as a string, or ``None`` if nothing was found.
    """
    # Local import keeps this block self-contained with respect to the
    # file's import section.
    import shutil

    rocm_home = os.environ.get('ROCM_HOME') or os.environ.get('ROCM_PATH')
    if not rocm_home:
        # shutil.which returns None when hipcc is absent, so the fallback
        # below is reached; a shell pipeline reports that case as empty
        # output rather than an error and would yield '' here.
        hipcc_path = shutil.which('hipcc')
        if hipcc_path is not None:
            # realpath resolves symlinks, matching `readlink -f`.
            hipcc_real = os.path.realpath(hipcc_path)
            rocm_home = os.path.dirname(os.path.dirname(hipcc_real))
            if os.path.basename(rocm_home) == 'hip':
                rocm_home = os.path.dirname(rocm_home)
        else:
            fallback_path = '/opt/rocm'
            rocm_home = fallback_path if os.path.exists(fallback_path) else None
    if rocm_home and torch.version.hip is None:
        print(f"No ROCm runtime is found, using ROCM_HOME='{rocm_home}'")
    return rocm_home


# Resolved once at import time and passed to get_version() by setup().
ROCM_HOME = _find_rocm_home()
Jiashi Li's avatar
Jiashi Li committed
170

171
172
173
174
# Install requirement for PyTorch: pinned to an exact version when the
# PYTORCH_VERSION environment variable is set to a non-empty value,
# otherwise left unpinned.
_pinned_torch_version = os.getenv('PYTORCH_VERSION')
pytorch_dep = ('torch' + "==" + _pinned_torch_version) if _pinned_torch_version else 'torch'

Jiashi Li's avatar
Jiashi Li committed
175
# Package definition: the pure-Python `flash_mla` package plus the compiled
# extension declared in `ext_modules`.
setup(
    name="flash_mla",
    # Version string is derived by get_version() from the ROCm location.
    version=get_version(ROCM_HOME),
    packages=find_packages(include=['flash_mla']),
    ext_modules=ext_modules,
    # Ship the `asm/*.co` files found inside the package directory.
    package_data={"flash_mla": ["asm/*.co"]},
    # torch's BuildExtension drives compilation of the extension sources.
    cmdclass={"build_ext": BuildExtension},
    install_requires=[pytorch_dep],
)