setup.py 14.1 KB
Newer Older
Tri Dao's avatar
Tri Dao committed
1
2
3
4
# Adapted from https://github.com/NVIDIA/apex/blob/master/setup.py
import sys
import warnings
import os
5
6
import re
import ast
Tri Dao's avatar
Tri Dao committed
7
from pathlib import Path
Tri Dao's avatar
Tri Dao committed
8
from packaging.version import parse, Version
9
import platform
Tri Dao's avatar
Tri Dao committed
10
11
12
13

from setuptools import setup, find_packages
import subprocess

Pierce Freeman's avatar
Pierce Freeman committed
14
15
import urllib.request
import urllib.error
Tri Dao's avatar
Tri Dao committed
16
17
from wheel.bdist_wheel import bdist_wheel as _bdist_wheel

Tri Dao's avatar
Tri Dao committed
18
19
20
21
22
23
24
import torch
from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME


# ninja build does not work unless include_dirs are abs path
this_dir = os.path.dirname(os.path.abspath(__file__))

# Read the README relative to this file rather than the current working
# directory, so the long description is found wherever setup.py is invoked from.
with open(os.path.join(this_dir, "README.md"), "r", encoding="utf-8") as fh:
    long_description = fh.read()

PACKAGE_NAME = "flash_attn"

# URL template for prebuilt wheels; filled in with the release tag and wheel file name.
BASE_WHEEL_URL = "https://github.com/Dao-AILab/flash-attention/releases/download/{tag_name}/{wheel_name}"

# FORCE_BUILD: Force a fresh build locally, instead of attempting to find prebuilt wheels
# SKIP_CUDA_BUILD: Intended to allow CI to use a simple `python setup.py sdist` run to copy over raw files, without any cuda compilation
FORCE_BUILD = os.getenv("FLASH_ATTENTION_FORCE_BUILD", "FALSE") == "TRUE"
SKIP_CUDA_BUILD = os.getenv("FLASH_ATTENTION_SKIP_CUDA_BUILD", "FALSE") == "TRUE"
# For CI, we want the option to build with C++11 ABI since the nvcr images use C++11 ABI
FORCE_CXX11_ABI = os.getenv("FLASH_ATTENTION_FORCE_CXX11_ABI", "FALSE") == "TRUE"
39
40


41
42
def get_platform():
    """
    Returns the platform name as used in wheel filenames.

    On Linux the machine architecture is taken from ``platform.machine()``
    (falling back to ``x86_64`` when it cannot be determined) so that
    non-x86 hosts do not look up an x86_64 wheel. The macOS and Windows
    tags keep their fixed architecture suffixes.

    Raises:
        ValueError: if ``sys.platform`` is not Linux, macOS or Windows.
    """
    if sys.platform.startswith('linux'):
        # platform.machine() may return an empty string; keep the historical default then.
        machine = platform.machine() or 'x86_64'
        return f'linux_{machine}'
    elif sys.platform == 'darwin':
        # Only the "major.minor" part of the macOS version is used in the tag.
        mac_version = '.'.join(platform.mac_ver()[0].split('.')[:2])
        return f'macosx_{mac_version}_x86_64'
    elif sys.platform == 'win32':
        return 'win_amd64'
    else:
        raise ValueError('Unsupported platform: {}'.format(sys.platform))

Tri Dao's avatar
Tri Dao committed
55
56
57
58
59

def get_cuda_bare_metal_version(cuda_dir):
    """
    Runs ``nvcc -V`` from ``cuda_dir`` and extracts the CUDA release version.

    Args:
        cuda_dir: root of the CUDA installation; nvcc is expected at ``<cuda_dir>/bin/nvcc``.

    Returns:
        A tuple ``(raw_output, bare_metal_version)`` where ``raw_output`` is the
        full text printed by nvcc and ``bare_metal_version`` is the result of
        ``parse`` applied to the token following the word "release".

    Raises:
        ValueError: if the nvcc output contains no "release" token.
        Errors from ``subprocess.check_output`` (e.g. nvcc missing or failing)
        propagate unchanged.
    """
    nvcc_path = os.path.join(cuda_dir, "bin", "nvcc")
    raw_output = subprocess.check_output([nvcc_path, "-V"], universal_newlines=True)
    output = raw_output.split()
    # The version is the whitespace-separated token right after "release", e.g. "11.8,".
    release_idx = output.index("release") + 1
    bare_metal_version = parse(output[release_idx].split(",")[0])

    return raw_output, bare_metal_version
Tri Dao's avatar
Tri Dao committed
63
64
65


def check_cuda_torch_binary_vs_bare_metal(cuda_dir):
    """
    Verifies that the nvcc found under ``cuda_dir`` matches the CUDA version torch was built with.

    Prints the nvcc banner and its location, then compares the parsed nvcc
    release with ``torch.version.cuda``.

    Args:
        cuda_dir: root of the CUDA installation to check.

    Raises:
        RuntimeError: if the two versions are not equal.
    """
    raw_output, bare_metal_version = get_cuda_bare_metal_version(cuda_dir)
    torch_binary_version = parse(torch.version.cuda)

    print("\nCompiling cuda extensions with")
    print(raw_output + "from " + cuda_dir + "/bin\n")

    if bare_metal_version != torch_binary_version:
        raise RuntimeError(
            "Cuda extensions are being compiled with a version of Cuda that does "
            "not match the version used to compile Pytorch binaries.  "
            "Pytorch binaries were compiled with Cuda {}.\n".format(torch.version.cuda)
            + "In some cases, a minor-version mismatch will not cause later errors:  "
            "https://github.com/NVIDIA/apex/pull/323#discussion_r287021798.  "
            "You can try commenting out this check (at your own risk)."
        )


def raise_if_cuda_home_none(global_option: str) -> None:
    """
    Fails with an explanatory error when no CUDA toolkit location is known.

    Args:
        global_option: name of the feature that needs nvcc; used as the
            prefix of the error message.

    Raises:
        RuntimeError: if ``CUDA_HOME`` is ``None``.
    """
    if CUDA_HOME is None:
        raise RuntimeError(
            f"{global_option} was requested, but nvcc was not found.  Are you sure your environment has nvcc available?  "
            "If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, "
            "only images whose names contain 'devel' will provide nvcc."
        )


def append_nvcc_threads(nvcc_extra_args):
    """
    Adds nvcc's ``--threads`` option to the given flags when the toolkit supports it.

    The thread count defaults to "4" and can be overridden with the
    ``NVCC_THREADS`` environment variable. The option is only added when the
    nvcc found under ``CUDA_HOME`` reports version 11.2 or newer.

    Args:
        nvcc_extra_args: list of nvcc flags; it is not modified.

    Returns:
        A new list with ``["--threads", <count>]`` appended, or
        ``nvcc_extra_args`` itself when nvcc is older than 11.2.
    """
    _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
    if bare_metal_version >= Version("11.2"):
        # An unset or empty NVCC_THREADS falls back to the previous fixed value.
        nvcc_threads = os.getenv("NVCC_THREADS") or "4"
        return nvcc_extra_args + ["--threads", nvcc_threads]
    return nvcc_extra_args


if not torch.cuda.is_available():
    # https://github.com/NVIDIA/apex/issues/486
    # Extension builds after https://github.com/pytorch/pytorch/pull/23408 attempt to query torch.cuda.get_device_capability(),
    # which will fail if you are compiling in an environment without visible GPUs (e.g. during an nvidia-docker build command).
    print(
        "\nWarning: Torch did not find available GPUs on this system.\n",
        "If your intention is to cross-compile, this is not an error.\n"
        "By default, FlashAttention will cross-compile for Ampere (compute capability 8.0, 8.6, "
        "8.9), and, if the CUDA version is >= 11.8, Hopper (compute capability 9.0).\n"
        "If you wish to cross-compile for a single specific architecture,\n"
        'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n',
    )
    # Only pick a default architecture list when the user has not set one and nvcc can be queried.
    if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None and CUDA_HOME is not None:
        _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
        # 9.0 is added from CUDA 11.8 on; every older toolkit gets the same 8.0/8.6 list.
        if bare_metal_version >= Version("11.8"):
            os.environ["TORCH_CUDA_ARCH_LIST"] = "8.0;8.6;9.0"
        else:
            os.environ["TORCH_CUDA_ARCH_LIST"] = "8.0;8.6"
Tri Dao's avatar
Tri Dao committed
120

Tri Dao's avatar
Tri Dao committed
121
122
123
cmdclass = {}
ext_modules = []

# We want this even if SKIP_CUDA_BUILD because when we run python setup.py sdist we want the .hpp
# files included in the source distribution, in case the user compiles from source.
# Only invoke git inside a git checkout: an unpacked source distribution has no .git entry,
# and calling git there would fail (or raise if git is not installed at all).
if os.path.exists(os.path.join(this_dir, ".git")):
    subprocess.run(["git", "submodule", "update", "--init", "csrc/cutlass"], cwd=this_dir)

128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
# Configure the CUDA extension unless the build is explicitly skipped (e.g. for sdist runs).
if not SKIP_CUDA_BUILD:
    print("\n\ntorch.__version__  = {}\n\n".format(torch.__version__))
    TORCH_MAJOR = int(torch.__version__.split(".")[0])
    TORCH_MINOR = int(torch.__version__.split(".")[1])

    # Check, if ATen/CUDAGeneratorImpl.h is found, otherwise use ATen/cuda/CUDAGeneratorImpl.h
    # See https://github.com/pytorch/pytorch/pull/70650
    generator_flag = []
    torch_dir = torch.__path__[0]
    if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.h")):
        generator_flag = ["-DOLD_GENERATOR_PATH"]

    raise_if_cuda_home_none("flash_attn")
    # Check, if CUDA11 is installed for compute capability 8.0
    cc_flag = []
    _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
    if bare_metal_version < Version("11.4"):
        raise RuntimeError("FlashAttention is only supported on CUDA 11.4 and above")
    # cc_flag.append("-gencode")
    # cc_flag.append("arch=compute_75,code=sm_75")
    cc_flag.append("-gencode")
    cc_flag.append("arch=compute_80,code=sm_80")
    # sm_90 code generation is only requested when nvcc is 11.8 or newer.
    if bare_metal_version >= Version("11.8"):
        cc_flag.append("-gencode")
        cc_flag.append("arch=compute_90,code=sm_90")

    # HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as
    # torch._C._GLIBCXX_USE_CXX11_ABI
    # https://github.com/pytorch/pytorch/blob/8472c24e3b5b60150096486616d98b7bea01500b/torch/utils/cpp_extension.py#L920
    if FORCE_CXX11_ABI:
        torch._C._GLIBCXX_USE_CXX11_ABI = True
    ext_modules.append(
        CUDAExtension(
            name="flash_attn_2_cuda",
            sources=[
                "csrc/flash_attn/flash_api.cpp",
                "csrc/flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim224_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim224_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim32_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim32_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim64_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim64_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim96_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim96_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim128_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim128_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim160_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim160_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim192_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim192_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim224_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim224_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim256_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_bwd_hdim256_bf16_sm80.cu",
            ],
            extra_compile_args={
                "cxx": ["-O3", "-std=c++17"] + generator_flag,
                "nvcc": append_nvcc_threads(
                    [
                        "-O3",
                        "-std=c++17",
                        "-U__CUDA_NO_HALF_OPERATORS__",
                        "-U__CUDA_NO_HALF_CONVERSIONS__",
                        "-U__CUDA_NO_HALF2_OPERATORS__",
                        "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
                        "--expt-relaxed-constexpr",
                        "--expt-extended-lambda",
                        "--use_fast_math",
                        "--ptxas-options=-v",
                        # "--ptxas-options=-O2",
                        "-lineinfo"
                    ]
                    + generator_flag
                    + cc_flag
                ),
            },
            include_dirs=[
                Path(this_dir) / 'csrc' / 'flash_attn',
                Path(this_dir) / 'csrc' / 'flash_attn' / 'src',
                Path(this_dir) / 'csrc' / 'cutlass' / 'include',
            ],
        )
    )
Tri Dao's avatar
Tri Dao committed
225

Tri Dao's avatar
Tri Dao committed
226

227
228
229
230
231
232
233
234
235
236
def get_package_version():
    """
    Returns the package version string read from ``flash_attn/__init__.py``.

    The right-hand side of the ``__version__ = ...`` assignment is evaluated
    as a Python literal. When the ``FLASH_ATTN_LOCAL_VERSION`` environment
    variable is set to a non-empty value, it is appended after a ``+``.
    """
    init_path = Path(this_dir) / "flash_attn" / "__init__.py"
    with open(init_path, "r") as init_file:
        init_source = init_file.read()
    version_match = re.search(r"^__version__\s*=\s*(.*)$", init_source, re.MULTILINE)
    public_version = ast.literal_eval(version_match.group(1))
    local_version = os.environ.get("FLASH_ATTN_LOCAL_VERSION")
    return f"{public_version}+{local_version}" if local_version else str(public_version)

Tri Dao's avatar
Tri Dao committed
237

238
class CachedWheelsCommand(_bdist_wheel):
    """
    The CachedWheelsCommand plugs into the default bdist wheel, which is run by pip when it cannot
    find an existing wheel (which is currently the case for all flash attention installs). We use
    the environment parameters to detect whether there is already a pre-built version of a compatible
    wheel available and short-circuit the standard full build pipeline.
    """

    def run(self):
        """
        Downloads a matching prebuilt wheel into ``self.dist_dir``, or builds from source.

        The regular build runs when ``FORCE_BUILD`` is set or when the guessed
        wheel URL cannot be retrieved (HTTP error or any other URL error, such
        as no network access).
        """
        # Function-scope import: shutil is only needed on the download path.
        import shutil

        if FORCE_BUILD:
            return super().run()

        raise_if_cuda_home_none("flash_attn")

        # Determine the version numbers that will be used to determine the correct wheel
        # We're using the CUDA version used to build torch, not the one currently installed
        torch_cuda_version = parse(torch.version.cuda)
        torch_version_raw = parse(torch.__version__)
        python_version = f"cp{sys.version_info.major}{sys.version_info.minor}"
        platform_name = get_platform()
        flash_version = get_package_version()
        cuda_version = f"{torch_cuda_version.major}{torch_cuda_version.minor}"
        torch_version = f"{torch_version_raw.major}.{torch_version_raw.minor}"
        cxx11_abi = str(torch._C._GLIBCXX_USE_CXX11_ABI).upper()

        # Determine wheel URL based on CUDA version, torch version, python version and OS
        wheel_filename = f'{PACKAGE_NAME}-{flash_version}+cu{cuda_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl'
        wheel_url = BASE_WHEEL_URL.format(
            tag_name=f"v{flash_version}",
            wheel_name=wheel_filename
        )
        print("Guessing wheel URL: ", wheel_url)

        try:
            urllib.request.urlretrieve(wheel_url, wheel_filename)
        except urllib.error.URLError:
            # URLError is the base class of HTTPError, so this covers both a missing
            # wheel and an unreachable server.
            print("Precompiled wheel not found. Building from source...")
            # If the wheel could not be downloaded, build from source
            return super().run()

        # Make the archive
        # Lifted from the root wheel processing command
        # https://github.com/pypa/wheel/blob/cf71108ff9f6ffc36978069acb28824b44ae028e/src/wheel/bdist_wheel.py#LL381C9-L381C85
        os.makedirs(self.dist_dir, exist_ok=True)

        impl_tag, abi_tag, plat_tag = self.get_tag()
        archive_basename = f"{self.wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}"

        wheel_path = os.path.join(self.dist_dir, archive_basename + ".whl")
        print("Raw wheel path", wheel_path)
        # shutil.move also works when dist_dir is on a different filesystem than
        # the download location, where os.rename would fail.
        shutil.move(wheel_filename, wheel_path)
291
292


Tri Dao's avatar
Tri Dao committed
293
# Package metadata and build hooks.
setup(
    name=PACKAGE_NAME,
    version=get_package_version(),
    packages=find_packages(
        exclude=("build", "csrc", "include", "tests", "dist", "docs", "benchmarks", "flash_attn.egg-info",)
    ),
    author="Tri Dao",
    author_email="trid@cs.stanford.edu",
    description="Flash Attention: Fast and Memory-Efficient Exact Attention",
    long_description=long_description,
    long_description_content_type="text/markdown",
    url="https://github.com/Dao-AILab/flash-attention",
    classifiers=[
        "Programming Language :: Python :: 3",
        "License :: OSI Approved :: BSD License",
        "Operating System :: Unix",
    ],
    ext_modules=ext_modules,
    # build_ext is only registered when there is an extension to build
    # (ext_modules stays empty when the CUDA build is skipped).
    cmdclass={
        'bdist_wheel': CachedWheelsCommand,
        "build_ext": BuildExtension
    } if ext_modules else {
        'bdist_wheel': CachedWheelsCommand,
    },
    python_requires=">=3.7",
    install_requires=[
        "torch",
        "einops",
        "packaging",
        "ninja",
    ],
)