# Copyright (c) 2023, Tri Dao.

import sys
import warnings
import os
import re
import ast
import glob
import shutil
from pathlib import Path
from packaging.version import parse, Version
import platform

from setuptools import setup, find_packages
import subprocess

import urllib.request
import urllib.error
from wheel.bdist_wheel import bdist_wheel as _bdist_wheel

import torch
from torch.utils.cpp_extension import (
    BuildExtension,
    CppExtension,
    CUDAExtension,
    CUDA_HOME,
    ROCM_HOME,
    IS_HIP_EXTENSION,
)


with open("README.md", "r", encoding="utf-8") as fh:
    long_description = fh.read()


# ninja build does not work unless include_dirs are abs path
this_dir = os.path.dirname(os.path.abspath(__file__))

BUILD_TARGET = os.environ.get("BUILD_TARGET", "auto")
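# Illustrative usage (not from this file's docs): set BUILD_TARGET explicitly to override
# auto-detection, e.g.
#   BUILD_TARGET=rocm pip install .    # force the ROCm/HIP build
#   BUILD_TARGET=cuda pip install .    # force the CUDA build
# With "auto" (the default), the choice follows IS_HIP_EXTENSION from the installed torch.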

if BUILD_TARGET == "auto":
    if IS_HIP_EXTENSION:
        IS_ROCM = True
    else:
        IS_ROCM = False
else:
    if BUILD_TARGET == "cuda":
        IS_ROCM = False
    elif BUILD_TARGET == "rocm":
        IS_ROCM = True
    else:
        # Fail fast on unrecognized values instead of leaving IS_ROCM undefined
        raise ValueError(f"Unsupported BUILD_TARGET: {BUILD_TARGET!r} (expected 'auto', 'cuda', or 'rocm')")

PACKAGE_NAME = "flash_attn"

BASE_WHEEL_URL = (
    "https://github.com/Dao-AILab/flash-attention/releases/download/{tag_name}/{wheel_name}"
)

# FORCE_BUILD: Force a fresh build locally, instead of attempting to find prebuilt wheels
# SKIP_CUDA_BUILD: Intended to allow CI to use a simple `python setup.py sdist` run to copy over raw files, without any cuda compilation
FORCE_BUILD = os.getenv("FLASH_ATTENTION_FORCE_BUILD", "FALSE") == "TRUE"
SKIP_CUDA_BUILD = os.getenv("FLASH_ATTENTION_SKIP_CUDA_BUILD", "FALSE") == "TRUE"
# For CI, we want the option to build with C++11 ABI since the nvcr images use C++11 ABI
FORCE_CXX11_ABI = os.getenv("FLASH_ATTENTION_FORCE_CXX11_ABI", "FALSE") == "TRUE"
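# Example invocations (illustrative, assuming a local checkout of this repo):
#   FLASH_ATTENTION_FORCE_BUILD=TRUE pip install . --no-build-isolation    # always compile, skip prebuilt wheels
#   FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE python setup.py sdist             # source dist only, no CUDA compilation
#   FLASH_ATTENTION_FORCE_CXX11_ABI=TRUE pip install . --no-build-isolation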


def get_platform():
    """
    Returns the platform name as used in wheel filenames.
    """
    if sys.platform.startswith("linux"):
        return f"linux_{platform.uname().machine}"
    elif sys.platform == "darwin":
        mac_version = ".".join(platform.mac_ver()[0].split(".")[:2])
        return f"macosx_{mac_version}_x86_64"
    elif sys.platform == "win32":
        return "win_amd64"
    else:
        raise ValueError("Unsupported platform: {}".format(sys.platform))


def get_cuda_bare_metal_version(cuda_dir):
    raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
    output = raw_output.split()
    release_idx = output.index("release") + 1
    bare_metal_version = parse(output[release_idx].split(",")[0])
    return raw_output, bare_metal_version


def check_if_cuda_home_none(global_option: str) -> None:
    if CUDA_HOME is not None:
        return
    # warn instead of error because user could be downloading prebuilt wheels, so nvcc won't be necessary
    # in that case.
    warnings.warn(
        f"{global_option} was requested, but nvcc was not found.  Are you sure your environment has nvcc available?  "
        "If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, "
        "only images whose names contain 'devel' will provide nvcc."
    )


def check_if_rocm_home_none(global_option: str) -> None:
    if ROCM_HOME is not None:
        return
    # warn instead of error because user could be downloading prebuilt wheels, so hipcc won't be necessary
    # in that case.
    warnings.warn(
        f"{global_option} was requested, but hipcc was not found."
    )


def append_nvcc_threads(nvcc_extra_args):
    nvcc_threads = os.getenv("NVCC_THREADS") or "4"
    return nvcc_extra_args + ["--threads", nvcc_threads]
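# Example (illustrative): NVCC_THREADS=2 pip install . --no-build-isolation limits nvcc to
# 2 parallel threads per compilation unit, trading build speed for lower peak memory;
# when unset, the helper above defaults to 4.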


def rename_cpp_to_cu(cpp_files):
    for entry in cpp_files:
        shutil.copy(entry, os.path.splitext(entry)[0] + ".cu")


def validate_and_update_archs(archs):
    # List of allowed architectures
    allowed_archs = ["native", "gfx90a", "gfx940", "gfx941", "gfx942"]

    # Validate if each element in archs is in allowed_archs
    assert all(
        arch in allowed_archs for arch in archs
    ), f"One of GPU archs of {archs} is invalid or not supported by Flash-Attention"


cmdclass = {}
ext_modules = []

# We want this even if SKIP_CUDA_BUILD because when we run python setup.py sdist we want the .hpp
# files included in the source distribution, in case the user compiles from source.
if IS_ROCM:
    subprocess.run(["git", "submodule", "update", "--init", "csrc/composable_kernel"])
else:
    subprocess.run(["git", "submodule", "update", "--init", "csrc/cutlass"])

if not SKIP_CUDA_BUILD and not IS_ROCM:
    print("\n\ntorch.__version__  = {}\n\n".format(torch.__version__))
    TORCH_MAJOR = int(torch.__version__.split(".")[0])
    TORCH_MINOR = int(torch.__version__.split(".")[1])

    # Check, if ATen/CUDAGeneratorImpl.h is found, otherwise use ATen/cuda/CUDAGeneratorImpl.h
    # See https://github.com/pytorch/pytorch/pull/70650
    generator_flag = []
    torch_dir = torch.__path__[0]
    if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.h")):
        generator_flag = ["-DOLD_GENERATOR_PATH"]

    check_if_cuda_home_none("flash_attn")
    # Check, if CUDA11 is installed for compute capability 8.0
    cc_flag = []
    if CUDA_HOME is not None:
        _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
        if bare_metal_version < Version("11.6"):
            raise RuntimeError(
                "FlashAttention is only supported on CUDA 11.6 and above.  "
                "Note: make sure nvcc has a supported version by running nvcc -V."
            )
    # cc_flag.append("-gencode")
    # cc_flag.append("arch=compute_75,code=sm_75")
    cc_flag.append("-gencode")
    cc_flag.append("arch=compute_80,code=sm_80")
    if CUDA_HOME is not None:
        if bare_metal_version >= Version("11.8"):
            cc_flag.append("-gencode")
            cc_flag.append("arch=compute_90,code=sm_90")

    # HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as
    # torch._C._GLIBCXX_USE_CXX11_ABI
    # https://github.com/pytorch/pytorch/blob/8472c24e3b5b60150096486616d98b7bea01500b/torch/utils/cpp_extension.py#L920
    if FORCE_CXX11_ABI:
        torch._C._GLIBCXX_USE_CXX11_ABI = True
    ext_modules.append(
        CUDAExtension(
            name="flash_attn_2_cuda",
            sources=[
                "csrc/flash_attn/flash_api.cpp",
                "csrc/flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim32_fp16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim32_bf16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim64_fp16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim64_bf16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim96_fp16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim96_bf16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim128_fp16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim128_bf16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim160_fp16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim160_bf16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim192_fp16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim192_bf16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim256_fp16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim256_bf16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim32_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim32_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim64_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim64_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim96_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim96_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim128_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim128_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim160_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim160_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim192_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim192_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim256_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim256_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim32_fp16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim32_bf16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim64_fp16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim64_bf16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim96_fp16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim96_bf16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim128_fp16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim128_bf16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim160_fp16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim160_bf16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim192_fp16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim192_bf16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim256_fp16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim256_bf16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim32_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim32_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim64_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim64_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim96_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim96_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim128_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim128_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim160_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim160_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim192_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim192_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim256_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim256_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim32_fp16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim32_bf16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim64_fp16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim64_bf16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim96_fp16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim96_bf16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim128_fp16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim128_bf16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim160_fp16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim160_bf16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim192_fp16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim192_bf16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim256_fp16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim256_bf16_causal_sm80.cu",
            ],
            extra_compile_args={
                "cxx": ["-O3", "-std=c++17"] + generator_flag,
                "nvcc": append_nvcc_threads(
                    [
                        "-O3",
                        "-std=c++17",
                        "-U__CUDA_NO_HALF_OPERATORS__",
                        "-U__CUDA_NO_HALF_CONVERSIONS__",
                        "-U__CUDA_NO_HALF2_OPERATORS__",
                        "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
                        "--expt-relaxed-constexpr",
                        "--expt-extended-lambda",
                        "--use_fast_math",
                        # "--ptxas-options=-v",
                        # "--ptxas-options=-O2",
                        # "-lineinfo",
                        # "-DFLASHATTENTION_DISABLE_BACKWARD",
                        # "-DFLASHATTENTION_DISABLE_DROPOUT",
                        # "-DFLASHATTENTION_DISABLE_ALIBI",
                        # "-DFLASHATTENTION_DISABLE_SOFTCAP",
                        # "-DFLASHATTENTION_DISABLE_UNEVEN_K",
                        # "-DFLASHATTENTION_DISABLE_LOCAL",
                    ]
                    + generator_flag
                    + cc_flag
                ),
            },
            include_dirs=[
                Path(this_dir) / "csrc" / "flash_attn",
                Path(this_dir) / "csrc" / "flash_attn" / "src",
                Path(this_dir) / "csrc" / "cutlass" / "include",
            ],
        )
    )
elif not SKIP_CUDA_BUILD and IS_ROCM:
    ck_dir = "csrc/composable_kernel"

    # Use the codegen script to generate the kernel dispatch sources
    if not os.path.exists("./build"):
        os.makedirs("build")

    os.system(f"{sys.executable} {ck_dir}/example/ck_tile/01_fmha/generate.py -d fwd --output_dir build --receipt 2")
    os.system(f"{sys.executable} {ck_dir}/example/ck_tile/01_fmha/generate.py -d bwd --output_dir build --receipt 2")

    print("\n\ntorch.__version__  = {}\n\n".format(torch.__version__))
    TORCH_MAJOR = int(torch.__version__.split(".")[0])
    TORCH_MINOR = int(torch.__version__.split(".")[1])

    # Check, if ATen/CUDAGeneratorImpl.h is found, otherwise use ATen/cuda/CUDAGeneratorImpl.h
    # See https://github.com/pytorch/pytorch/pull/70650
    generator_flag = []
    torch_dir = torch.__path__[0]
    if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.h")):
        generator_flag = ["-DOLD_GENERATOR_PATH"]

    check_if_rocm_home_none("flash_attn")
    cc_flag = []

    archs = os.getenv("GPU_ARCHS", "native").split(";")
    validate_and_update_archs(archs)

    cc_flag = [f"--offload-arch={arch}" for arch in archs]

    # HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as
    # torch._C._GLIBCXX_USE_CXX11_ABI
    # https://github.com/pytorch/pytorch/blob/8472c24e3b5b60150096486616d98b7bea01500b/torch/utils/cpp_extension.py#L920
    if FORCE_CXX11_ABI:
        torch._C._GLIBCXX_USE_CXX11_ABI = True

    sources = ["csrc/flash_attn_ck/flash_api.cpp",
               "csrc/flash_attn_ck/mha_bwd.cpp",
               "csrc/flash_attn_ck/mha_fwd.cpp",
               "csrc/flash_attn_ck/mha_varlen_bwd.cpp",
               "csrc/flash_attn_ck/mha_varlen_fwd.cpp"] + glob.glob(
        f"build/fmha_*wd*.cpp"
    )

    rename_cpp_to_cu(sources)

    renamed_sources = ["csrc/flash_attn_ck/flash_api.cu",
                       "csrc/flash_attn_ck/mha_bwd.cu",
                       "csrc/flash_attn_ck/mha_fwd.cu",
                       "csrc/flash_attn_ck/mha_varlen_bwd.cu",
                       "csrc/flash_attn_ck/mha_varlen_fwd.cu"] + glob.glob(f"build/fmha_*wd*.cu")
    extra_compile_args = {
        "cxx": ["-O3", "-std=c++17"] + generator_flag,
        "nvcc":
            [
                "-O3","-std=c++17",
                "-mllvm", "-enable-post-misched=0",
                "-DCK_TILE_FMHA_FWD_FAST_EXP2=1",
                "-fgpu-flush-denormals-to-zero",
                "-DCK_ENABLE_BF16",
                "-DCK_ENABLE_BF8",
                "-DCK_ENABLE_FP16",
                "-DCK_ENABLE_FP32",
                "-DCK_ENABLE_FP64",
                "-DCK_ENABLE_FP8",
                "-DCK_ENABLE_INT8",
                "-DCK_USE_XDL",
                "-DUSE_PROF_API=1",
                "-D__HIP_PLATFORM_HCC__=1",
                # "-DFLASHATTENTION_DISABLE_BACKWARD",
            ]
            + generator_flag
            + cc_flag
        ,
    }

    include_dirs = [
        Path(this_dir) / "csrc" / "composable_kernel" / "include",
        Path(this_dir) / "csrc" / "composable_kernel" / "library" / "include",
        Path(this_dir) / "csrc" / "composable_kernel" / "example" / "ck_tile" / "01_fmha",
    ]

    ext_modules.append(
        CUDAExtension(
            name="flash_attn_2_cuda",
            sources=renamed_sources,
            extra_compile_args=extra_compile_args,
            include_dirs=include_dirs,
        )
    )


def get_package_version():
    with open(Path(this_dir) / "flash_attn" / "__init__.py", "r") as f:
        version_match = re.search(r"^__version__\s*=\s*(.*)$", f.read(), re.MULTILINE)
    public_version = ast.literal_eval(version_match.group(1))
    local_version = os.environ.get("FLASH_ATTN_LOCAL_VERSION")
    if local_version:
        return f"{public_version}+{local_version}"
    else:
        return str(public_version)
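# Example (illustrative): with __version__ = "2.6.3" in flash_attn/__init__.py and
# FLASH_ATTN_LOCAL_VERSION=rocm62 set, get_package_version() returns "2.6.3+rocm62";
# without the env var it returns "2.6.3".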


def get_wheel_url():
    torch_version_raw = parse(torch.__version__)
    python_version = f"cp{sys.version_info.major}{sys.version_info.minor}"
    platform_name = get_platform()
    flash_version = get_package_version()
    torch_version = f"{torch_version_raw.major}.{torch_version_raw.minor}"
    cxx11_abi = str(torch._C._GLIBCXX_USE_CXX11_ABI).upper()

    if IS_ROCM:
        torch_hip_version = parse(torch.version.hip.split()[-1].rstrip('-').replace('-', '+'))
        hip_version = f"{torch_hip_version.major}{torch_hip_version.minor}"
        wheel_filename = f"{PACKAGE_NAME}-{flash_version}+rocm{hip_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl"
    else:
        # Determine the version numbers that will be used to determine the correct wheel
        # We're using the CUDA version used to build torch, not the one currently installed
        # _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME)
        torch_cuda_version = parse(torch.version.cuda)
        # For CUDA 11, we only compile for CUDA 11.8, and for CUDA 12 we only compile for CUDA 12.3
        # to save CI time. Minor versions should be compatible.
        torch_cuda_version = parse("11.8") if torch_cuda_version.major == 11 else parse("12.3")
        # cuda_version = f"{cuda_version_raw.major}{cuda_version_raw.minor}"
        cuda_version = f"{torch_cuda_version.major}{torch_cuda_version.minor}"

        # Determine wheel URL based on CUDA version, torch version, python version and OS
        wheel_filename = f"{PACKAGE_NAME}-{flash_version}+cu{cuda_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl"

    wheel_url = BASE_WHEEL_URL.format(tag_name=f"v{flash_version}", wheel_name=wheel_filename)
    return wheel_url, wheel_filename
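# Example (illustrative) of a resulting wheel filename on the CUDA path:
#   flash_attn-2.6.3+cu123torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
# i.e. package version, CUDA and torch versions, C++11 ABI flag, Python tag (twice), and platform.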


class CachedWheelsCommand(_bdist_wheel):
    """
    The CachedWheelsCommand plugs into the default bdist wheel, which is run by pip when it cannot
    find an existing wheel (which is currently the case for all flash attention installs). We use
    the environment parameters to detect whether there is already a pre-built version of a compatible
    wheel available and short-circuit the standard full build pipeline.
    """

    def run(self):
        if FORCE_BUILD:
            return super().run()

        wheel_url, wheel_filename = get_wheel_url()
        print("Guessing wheel URL: ", wheel_url)
        try:
            urllib.request.urlretrieve(wheel_url, wheel_filename)

            # Make the archive
            # Lifted from the root wheel processing command
            # https://github.com/pypa/wheel/blob/cf71108ff9f6ffc36978069acb28824b44ae028e/src/wheel/bdist_wheel.py#LL381C9-L381C85
            if not os.path.exists(self.dist_dir):
                os.makedirs(self.dist_dir)

            impl_tag, abi_tag, plat_tag = self.get_tag()
            archive_basename = f"{self.wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}"

            wheel_path = os.path.join(self.dist_dir, archive_basename + ".whl")
            print("Raw wheel path", wheel_path)
            os.rename(wheel_filename, wheel_path)
        except (urllib.error.HTTPError, urllib.error.URLError):
            print("Precompiled wheel not found. Building from source...")
            # If the wheel could not be downloaded, build from source
            super().run()


class NinjaBuildExtension(BuildExtension):
    def __init__(self, *args, **kwargs) -> None:
        # do not override env MAX_JOBS if already exists
        if not os.environ.get("MAX_JOBS"):
            import psutil

            # calculate the maximum allowed NUM_JOBS based on cores
            max_num_jobs_cores = max(1, os.cpu_count() // 2)

            # calculate the maximum allowed NUM_JOBS based on free memory
            free_memory_gb = psutil.virtual_memory().available / (1024 ** 3)  # free memory in GB
            max_num_jobs_memory = int(free_memory_gb / 9)  # each JOB peak memory cost is ~8-9GB when threads = 4

            # pick lower value of jobs based on cores vs memory metric to minimize oom and swap usage during compilation
            max_jobs = max(1, min(max_num_jobs_cores, max_num_jobs_memory))
            os.environ["MAX_JOBS"] = str(max_jobs)

        super().__init__(*args, **kwargs)
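# Example (illustrative): MAX_JOBS=4 pip install . --no-build-isolation caps the number of
# parallel ninja jobs explicitly; otherwise the heuristic above picks
# min(cpu_count // 2, free_memory_GB // 9), with a floor of 1.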


setup(
    name=PACKAGE_NAME,
    version=get_package_version(),
    packages=find_packages(
        exclude=(
            "build",
            "csrc",
            "include",
            "tests",
            "dist",
            "docs",
            "benchmarks",
            "flash_attn.egg-info",
        )
    ),
    author="Tri Dao",
    author_email="tri@tridao.me",
    description="Flash Attention: Fast and Memory-Efficient Exact Attention",
    long_description=long_description,
    long_description_content_type="text/markdown",
    url="https://github.com/Dao-AILab/flash-attention",
    classifiers=[
        "Programming Language :: Python :: 3",
        "License :: OSI Approved :: BSD License",
        "Operating System :: Unix",
    ],
    ext_modules=ext_modules,
    cmdclass={"bdist_wheel": CachedWheelsCommand, "build_ext": NinjaBuildExtension}
    if ext_modules
    else {
        "bdist_wheel": CachedWheelsCommand,
    },
    python_requires=">=3.8",
    install_requires=[
        "torch",
        "einops",
    ],
    setup_requires=[
        "packaging",
        "psutil",
        "ninja",
    ],
)