# Copyright (c) 2023, Tri Dao.

import sys
import warnings
import os
import re
import ast
import glob
import shutil
from pathlib import Path
from packaging.version import parse, Version
import platform

from setuptools import setup, find_packages
import subprocess

import urllib.request
import urllib.error
from wheel.bdist_wheel import bdist_wheel as _bdist_wheel

import torch
from torch.utils.cpp_extension import (
    BuildExtension,
    CppExtension,
    CUDAExtension,
    CUDA_HOME,
    ROCM_HOME,
    IS_HIP_EXTENSION,
)


with open("README.md", "r", encoding="utf-8") as fh:
    long_description = fh.read()


# ninja build does not work unless include_dirs are absolute paths
this_dir = os.path.dirname(os.path.abspath(__file__))

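# BUILD_TARGET selects which backend to compile for: "cuda", "rocm", or "auto"
# (auto-detect from the torch build via IS_HIP_EXTENSION).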
BUILD_TARGET = os.environ.get("BUILD_TARGET", "auto")

if BUILD_TARGET == "auto":
    if IS_HIP_EXTENSION:
        IS_ROCM = True
    else:
        IS_ROCM = False
else:
    if BUILD_TARGET == "cuda":
        IS_ROCM = False
    elif BUILD_TARGET == "rocm":
        IS_ROCM = True

PACKAGE_NAME = "flash_attn"

BASE_WHEEL_URL = (
    "https://github.com/Dao-AILab/flash-attention/releases/download/{tag_name}/{wheel_name}"
)

# FORCE_BUILD: Force a fresh build locally, instead of attempting to find prebuilt wheels
# SKIP_CUDA_BUILD: Intended to allow CI to use a simple `python setup.py sdist` run to copy over raw files, without any cuda compilation
FORCE_BUILD = os.getenv("FLASH_ATTENTION_FORCE_BUILD", "FALSE") == "TRUE"
SKIP_CUDA_BUILD = os.getenv("FLASH_ATTENTION_SKIP_CUDA_BUILD", "FALSE") == "TRUE"
# For CI, we want the option to build with C++11 ABI since the nvcr images use C++11 ABI
FORCE_CXX11_ABI = os.getenv("FLASH_ATTENTION_FORCE_CXX11_ABI", "FALSE") == "TRUE"


def get_platform():
    """
    Returns the platform name as used in wheel filenames.
    """
    if sys.platform.startswith("linux"):
        return f"linux_{platform.uname().machine}"
    elif sys.platform == "darwin":
        mac_version = ".".join(platform.mac_ver()[0].split(".")[:2])
        return f"macosx_{mac_version}_x86_64"
    elif sys.platform == "win32":
        return "win_amd64"
    else:
        raise ValueError("Unsupported platform: {}".format(sys.platform))


def get_cuda_bare_metal_version(cuda_dir):
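    # Run `nvcc -V` and parse the "release X.Y" token to recover the installed toolkit version.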
    raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
    output = raw_output.split()
    release_idx = output.index("release") + 1
    bare_metal_version = parse(output[release_idx].split(",")[0])

    return raw_output, bare_metal_version


def check_if_cuda_home_none(global_option: str) -> None:
    if CUDA_HOME is not None:
        return
    # warn instead of error because user could be downloading prebuilt wheels, so nvcc won't be necessary
    # in that case.
    warnings.warn(
        f"{global_option} was requested, but nvcc was not found.  Are you sure your environment has nvcc available?  "
        "If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, "
        "only images whose names contain 'devel' will provide nvcc."
    )


def check_if_rocm_home_none(global_option: str) -> None:
    if ROCM_HOME is not None:
        return
    # warn instead of error because user could be downloading prebuilt wheels, so hipcc won't be necessary
    # in that case.
    warnings.warn(
        f"{global_option} was requested, but hipcc was not found."
    )


def append_nvcc_threads(nvcc_extra_args):
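    # nvcc's --threads flag compiles for multiple target architectures in parallel;
    # the default of 4 threads can be overridden via the NVCC_THREADS environment variable.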
    nvcc_threads = os.getenv("NVCC_THREADS") or "4"
    return nvcc_extra_args + ["--threads", nvcc_threads]


def rename_cpp_to_cu(cpp_files):
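    # Copy each .cpp source to a .cu sibling so the ROCm toolchain compiles it as device code.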
    for entry in cpp_files:
        shutil.copy(entry, os.path.splitext(entry)[0] + ".cu")


def validate_and_update_archs(archs):
    # List of allowed architectures
    allowed_archs = ["native", "gfx90a", "gfx940", "gfx941", "gfx942"]

    # Validate if each element in archs is in allowed_archs
    assert all(
        arch in allowed_archs for arch in archs
    ), f"One of GPU archs of {archs} is invalid or not supported by Flash-Attention"


cmdclass = {}
ext_modules = []

# We want this even if SKIP_CUDA_BUILD because when we run python setup.py sdist we want the .hpp
# files included in the source distribution, in case the user compiles from source.
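# Check out the kernel-library submodule for the selected backend
# (composable_kernel for ROCm, CUTLASS for CUDA).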
if IS_ROCM:
    subprocess.run(["git", "submodule", "update", "--init", "csrc/composable_kernel"])
else:
    subprocess.run(["git", "submodule", "update", "--init", "csrc/cutlass"])

if not SKIP_CUDA_BUILD and not IS_ROCM:
    print("\n\ntorch.__version__  = {}\n\n".format(torch.__version__))
    TORCH_MAJOR = int(torch.__version__.split(".")[0])
    TORCH_MINOR = int(torch.__version__.split(".")[1])

    # Check if ATen/CUDAGeneratorImpl.h is found; otherwise use ATen/cuda/CUDAGeneratorImpl.h
    # See https://github.com/pytorch/pytorch/pull/70650
    generator_flag = []
    torch_dir = torch.__path__[0]
    if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.h")):
        generator_flag = ["-DOLD_GENERATOR_PATH"]

    check_if_cuda_home_none("flash_attn")
    # Check that a recent enough CUDA toolkit is installed (compute capability 8.0 needs CUDA >= 11.6)
    cc_flag = []
    if CUDA_HOME is not None:
        _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
        if bare_metal_version < Version("11.6"):
            raise RuntimeError(
                "FlashAttention is only supported on CUDA 11.6 and above.  "
                "Note: make sure nvcc has a supported version by running nvcc -V."
            )
    # cc_flag.append("-gencode")
    # cc_flag.append("arch=compute_75,code=sm_75")
    cc_flag.append("-gencode")
    cc_flag.append("arch=compute_80,code=sm_80")
    if CUDA_HOME is not None:
        if bare_metal_version >= Version("11.8"):
            cc_flag.append("-gencode")
            cc_flag.append("arch=compute_90,code=sm_90")

    # HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as
    # torch._C._GLIBCXX_USE_CXX11_ABI
    # https://github.com/pytorch/pytorch/blob/8472c24e3b5b60150096486616d98b7bea01500b/torch/utils/cpp_extension.py#L920
    if FORCE_CXX11_ABI:
        torch._C._GLIBCXX_USE_CXX11_ABI = True
    ext_modules.append(
        CUDAExtension(
            name="flash_attn_2_cuda",
            sources=[
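                # One translation unit per (head dim, dtype, causal/non-causal, split-KV)
                # combination, matching the kernel instantiations below.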
                "csrc/flash_attn/flash_api.cpp",
                "csrc/flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim32_fp16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim32_bf16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim64_fp16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim64_bf16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim96_fp16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim96_bf16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim128_fp16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim128_bf16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim160_fp16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim160_bf16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim192_fp16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim192_bf16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim256_fp16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim256_bf16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim32_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim32_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim64_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim64_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim96_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim96_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim128_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim128_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim160_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim160_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim192_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim192_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim256_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim256_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim32_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim32_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim64_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim64_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim96_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim96_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim128_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim128_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim160_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim160_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim192_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim192_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim256_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim256_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim32_fp16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim32_bf16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim64_fp16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim64_bf16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim96_fp16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim96_bf16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim128_fp16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim128_bf16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim160_fp16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim160_bf16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim192_fp16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim192_bf16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim256_fp16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim256_bf16_causal_sm80.cu",
            ],
            extra_compile_args={
                "cxx": ["-O3", "-std=c++17"] + generator_flag,
                "nvcc": append_nvcc_threads(
                    [
                        "-O3",
                        "-std=c++17",
                        "-U__CUDA_NO_HALF_OPERATORS__",
                        "-U__CUDA_NO_HALF_CONVERSIONS__",
                        "-U__CUDA_NO_HALF2_OPERATORS__",
                        "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
                        "--expt-relaxed-constexpr",
                        "--expt-extended-lambda",
                        "--use_fast_math",
                        # "--ptxas-options=-v",
                        # "--ptxas-options=-O2",
                        # "-lineinfo",
                        # "-DFLASHATTENTION_DISABLE_BACKWARD",
                        # "-DFLASHATTENTION_DISABLE_DROPOUT",
                        # "-DFLASHATTENTION_DISABLE_ALIBI",
                        # "-DFLASHATTENTION_DISABLE_SOFTCAP",
                        # "-DFLASHATTENTION_DISABLE_UNEVEN_K",
                        # "-DFLASHATTENTION_DISABLE_LOCAL",
                    ]
                    + generator_flag
                    + cc_flag
                ),
            },
            include_dirs=[
                Path(this_dir) / "csrc" / "flash_attn",
                Path(this_dir) / "csrc" / "flash_attn" / "src",
                Path(this_dir) / "csrc" / "cutlass" / "include",
            ],
        )
    )
elif not SKIP_CUDA_BUILD and IS_ROCM:
    ck_dir = "csrc/composable_kernel"

    # Use the CK tile codegen script to generate the fwd/bwd kernel dispatch sources into build/
    if not os.path.exists("./build"):
        os.makedirs("build")

    os.system(f"{sys.executable} {ck_dir}/example/ck_tile/01_fmha/generate.py -d fwd --output_dir build --receipt 2")
    os.system(f"{sys.executable} {ck_dir}/example/ck_tile/01_fmha/generate.py -d bwd --output_dir build --receipt 2")

    print("\n\ntorch.__version__  = {}\n\n".format(torch.__version__))
    TORCH_MAJOR = int(torch.__version__.split(".")[0])
    TORCH_MINOR = int(torch.__version__.split(".")[1])

    # Check if ATen/CUDAGeneratorImpl.h is found; otherwise use ATen/cuda/CUDAGeneratorImpl.h
    # See https://github.com/pytorch/pytorch/pull/70650
    generator_flag = []
    torch_dir = torch.__path__[0]
    if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.h")):
        generator_flag = ["-DOLD_GENERATOR_PATH"]

    check_if_rocm_home_none("flash_attn")
    cc_flag = []

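    # GPU_ARCHS is a semicolon-separated list of gfx targets (e.g. "gfx90a;gfx942");
    # "native" builds for the GPU detected on the local machine.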
    archs = os.getenv("GPU_ARCHS", "native").split(";")
    validate_and_update_archs(archs)

    cc_flag = [f"--offload-arch={arch}" for arch in archs]

    # HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as
    # torch._C._GLIBCXX_USE_CXX11_ABI
    # https://github.com/pytorch/pytorch/blob/8472c24e3b5b60150096486616d98b7bea01500b/torch/utils/cpp_extension.py#L920
    if FORCE_CXX11_ABI:
        torch._C._GLIBCXX_USE_CXX11_ABI = True

    sources = ["csrc/flash_attn_ck/flash_api.cpp",
               "csrc/flash_attn_ck/mha_bwd.cpp",
               "csrc/flash_attn_ck/mha_fwd.cpp",
               "csrc/flash_attn_ck/mha_varlen_bwd.cpp",
               "csrc/flash_attn_ck/mha_varlen_fwd.cpp"] + glob.glob(
        f"build/fmha_*wd*.cpp"
    )

    rename_cpp_to_cu(sources)

    renamed_sources = ["csrc/flash_attn_ck/flash_api.cu",
                       "csrc/flash_attn_ck/mha_bwd.cu",
                       "csrc/flash_attn_ck/mha_fwd.cu",
                       "csrc/flash_attn_ck/mha_varlen_bwd.cu",
                       "csrc/flash_attn_ck/mha_varlen_fwd.cu"] + glob.glob(f"build/fmha_*wd*.cu")
    extra_compile_args = {
        "cxx": ["-O3", "-std=c++17"] + generator_flag,
        "nvcc":
            [
                "-O3","-std=c++17",
                "-mllvm", "-enable-post-misched=0",
                "-DCK_TILE_FMHA_FWD_FAST_EXP2=1",
                "-fgpu-flush-denormals-to-zero",
                "-DCK_ENABLE_BF16",
                "-DCK_ENABLE_BF8",
                "-DCK_ENABLE_FP16",
                "-DCK_ENABLE_FP32",
                "-DCK_ENABLE_FP64",
                "-DCK_ENABLE_FP8",
                "-DCK_ENABLE_INT8",
                "-DCK_USE_XDL",
                "-DUSE_PROF_API=1",
                "-D__HIP_PLATFORM_HCC__=1",
                # "-DFLASHATTENTION_DISABLE_BACKWARD",
            ]
            + generator_flag
            + cc_flag,
    }

    include_dirs = [
        Path(this_dir) / "csrc" / "composable_kernel" / "include",
        Path(this_dir) / "csrc" / "composable_kernel" / "library" / "include",
        Path(this_dir) / "csrc" / "composable_kernel" / "example" / "ck_tile" / "01_fmha",
    ]

    ext_modules.append(
        CUDAExtension(
            name="flash_attn_2_cuda",
            sources=renamed_sources,
            extra_compile_args=extra_compile_args,
            include_dirs=include_dirs,
        )
    )


def get_package_version():
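    # Read __version__ from flash_attn/__init__.py; FLASH_ATTN_LOCAL_VERSION, if set,
    # is appended as a local version label.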
    with open(Path(this_dir) / "flash_attn" / "__init__.py", "r") as f:
        version_match = re.search(r"^__version__\s*=\s*(.*)$", f.read(), re.MULTILINE)
    public_version = ast.literal_eval(version_match.group(1))
    local_version = os.environ.get("FLASH_ATTN_LOCAL_VERSION")
    if local_version:
        return f"{public_version}+{local_version}"
    else:
        return str(public_version)


def get_wheel_url():
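    # The prebuilt wheel filename encodes the flash-attn version, the CUDA or ROCm version,
    # the torch version, the C++11 ABI flag, the Python version, and the platform.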
    torch_version_raw = parse(torch.__version__)
    python_version = f"cp{sys.version_info.major}{sys.version_info.minor}"
    platform_name = get_platform()
    flash_version = get_package_version()
    torch_version = f"{torch_version_raw.major}.{torch_version_raw.minor}"
    cxx11_abi = str(torch._C._GLIBCXX_USE_CXX11_ABI).upper()

    if IS_ROCM:
        torch_hip_version = parse(torch.version.hip.split()[-1].rstrip('-').replace('-', '+'))
        hip_version = f"{torch_hip_version.major}{torch_hip_version.minor}"
        wheel_filename = f"{PACKAGE_NAME}-{flash_version}+rocm{hip_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl"
    else:
        # Determine the version numbers used to select the correct wheel
        # We're using the CUDA version used to build torch, not the one currently installed
        # _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME)
        torch_cuda_version = parse(torch.version.cuda)
        # For CUDA 11, we only compile for CUDA 11.8, and for CUDA 12 we only compile for CUDA 12.3
        # to save CI time. Minor versions should be compatible.
        torch_cuda_version = parse("11.8") if torch_cuda_version.major == 11 else parse("12.3")
        # cuda_version = f"{cuda_version_raw.major}{cuda_version_raw.minor}"
        cuda_version = f"{torch_cuda_version.major}{torch_cuda_version.minor}"

        # Determine wheel URL based on CUDA version, torch version, python version and OS
        wheel_filename = f"{PACKAGE_NAME}-{flash_version}+cu{cuda_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl"

    wheel_url = BASE_WHEEL_URL.format(tag_name=f"v{flash_version}", wheel_name=wheel_filename)

    return wheel_url, wheel_filename


class CachedWheelsCommand(_bdist_wheel):
    """
    The CachedWheelsCommand plugs into the default bdist wheel, which is ran by pip when it cannot
    find an existing wheel (which is currently the case for all flash attention installs). We use
    the environment parameters to detect whether there is already a pre-built version of a compatible
    wheel available and short-circuits the standard full build pipeline.
    """

    def run(self):
        if FORCE_BUILD:
            return super().run()

        wheel_url, wheel_filename = get_wheel_url()
        print("Guessing wheel URL: ", wheel_url)
        try:
            urllib.request.urlretrieve(wheel_url, wheel_filename)

            # Make the archive
            # Lifted from the root wheel processing command
            # https://github.com/pypa/wheel/blob/cf71108ff9f6ffc36978069acb28824b44ae028e/src/wheel/bdist_wheel.py#LL381C9-L381C85
            if not os.path.exists(self.dist_dir):
                os.makedirs(self.dist_dir)

            impl_tag, abi_tag, plat_tag = self.get_tag()
            archive_basename = f"{self.wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}"

            wheel_path = os.path.join(self.dist_dir, archive_basename + ".whl")
            print("Raw wheel path", wheel_path)
            os.rename(wheel_filename, wheel_path)
        except (urllib.error.HTTPError, urllib.error.URLError):
            print("Precompiled wheel not found. Building from source...")
            # If the wheel could not be downloaded, build from source
            super().run()


class NinjaBuildExtension(BuildExtension):
    def __init__(self, *args, **kwargs) -> None:
        # do not override env MAX_JOBS if already exists
        if not os.environ.get("MAX_JOBS"):
            import psutil

            # calculate the maximum allowed NUM_JOBS based on cores
            max_num_jobs_cores = max(1, os.cpu_count() // 2)

            # calculate the maximum allowed NUM_JOBS based on free memory
            free_memory_gb = psutil.virtual_memory().available / (1024 ** 3)  # free memory in GB
            max_num_jobs_memory = int(free_memory_gb / 9)  # each JOB peak memory cost is ~8-9GB when threads = 4

            # pick lower value of jobs based on cores vs memory metric to minimize oom and swap usage during compilation
            max_jobs = max(1, min(max_num_jobs_cores, max_num_jobs_memory))
            os.environ["MAX_JOBS"] = str(max_jobs)

        super().__init__(*args, **kwargs)


setup(
    name=PACKAGE_NAME,
    version=get_package_version(),
    packages=find_packages(
        exclude=(
            "build",
            "csrc",
            "include",
            "tests",
            "dist",
            "docs",
            "benchmarks",
            "flash_attn.egg-info",
        )
    ),
    author="Tri Dao",
    author_email="tri@tridao.me",
    description="Flash Attention: Fast and Memory-Efficient Exact Attention",
    long_description=long_description,
    long_description_content_type="text/markdown",
    url="https://github.com/Dao-AILab/flash-attention",
    classifiers=[
        "Programming Language :: Python :: 3",
        "License :: OSI Approved :: BSD License",
        "Operating System :: Unix",
    ],
    ext_modules=ext_modules,
    cmdclass={"bdist_wheel": CachedWheelsCommand, "build_ext": NinjaBuildExtension}
    if ext_modules
    else {
        "bdist_wheel": CachedWheelsCommand,
    },
    python_requires=">=3.8",
    install_requires=[
        "torch",
        "einops",
    ],
    setup_requires=[
        "packaging",
        "psutil",
        "ninja",
    ],
)