# Copyright (c) 2023, Tri Dao.

import sys
import warnings
import os
import re
import ast
import glob
import shutil
from pathlib import Path
from packaging.version import parse, Version
import platform

from setuptools import setup, find_packages
import subprocess

import urllib.request
import urllib.error
from wheel.bdist_wheel import bdist_wheel as _bdist_wheel

import torch
from torch.utils.cpp_extension import (
    BuildExtension,
    CppExtension,
    CUDAExtension,
    CUDA_HOME,
    ROCM_HOME,
    IS_HIP_EXTENSION,
)


with open("README.md", "r", encoding="utf-8") as fh:
    long_description = fh.read()


# ninja build does not work unless include_dirs are abs path
this_dir = os.path.dirname(os.path.abspath(__file__))

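# BUILD_TARGET selects the compilation backend: "auto" follows the detected PyTorch
# build (HIP vs. CUDA), while "cuda" or "rocm" force a specific target.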
BUILD_TARGET = os.environ.get("BUILD_TARGET", "auto")

if BUILD_TARGET == "auto":
    if IS_HIP_EXTENSION:
        IS_ROCM = True
    else:
        IS_ROCM = False
else:
    if BUILD_TARGET == "cuda":
        IS_ROCM = False
    elif BUILD_TARGET == "rocm":
        IS_ROCM = True
    else:
        raise ValueError(f"Unsupported BUILD_TARGET: {BUILD_TARGET!r} (expected 'auto', 'cuda', or 'rocm')")

PACKAGE_NAME = "vllm_flash_attn"

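# URL template for prebuilt wheels published on the upstream GitHub releases page.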
BASE_WHEEL_URL = (
    "https://github.com/Dao-AILab/flash-attention/releases/download/{tag_name}/{wheel_name}"
)

# FORCE_BUILD: Force a fresh build locally instead of attempting to find prebuilt wheels
# SKIP_CUDA_BUILD: Intended to allow CI to use a simple `python setup.py sdist` run to copy over raw files, without any CUDA compilation
FORCE_BUILD = True
SKIP_CUDA_BUILD = False
# For CI, we want the option to build with C++11 ABI since the nvcr images use C++11 ABI
FORCE_CXX11_ABI = torch._C._GLIBCXX_USE_CXX11_ABI


def get_platform():
    """
    Returns the platform name as used in wheel filenames.
    """
    if sys.platform.startswith("linux"):
        return f"linux_{platform.uname().machine}"
    elif sys.platform == "darwin":
        mac_version = ".".join(platform.mac_ver()[0].split(".")[:2])
        return f"macosx_{mac_version}_x86_64"
    elif sys.platform == "win32":
        return "win_amd64"
    else:
        raise ValueError("Unsupported platform: {}".format(sys.platform))


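# Parse the CUDA release version (e.g. "12.1") out of the `nvcc -V` banner.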
def get_cuda_bare_metal_version(cuda_dir):
    raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
    output = raw_output.split()
    release_idx = output.index("release") + 1
    bare_metal_version = parse(output[release_idx].split(",")[0])

    return raw_output, bare_metal_version


def check_if_cuda_home_none(global_option: str) -> None:
    if CUDA_HOME is not None:
        return
    # warn instead of error because user could be downloading prebuilt wheels, so nvcc won't be necessary
    # in that case.
    warnings.warn(
        f"{global_option} was requested, but nvcc was not found.  Are you sure your environment has nvcc available?  "
        "If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, "
        "only images whose names contain 'devel' will provide nvcc."
    )


def check_if_rocm_home_none(global_option: str) -> None:
    if ROCM_HOME is not None:
        return
    # warn instead of error because user could be downloading prebuilt wheels, so hipcc won't be necessary
    # in that case.
    warnings.warn(
        f"{global_option} was requested, but hipcc was not found."
    )


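# Pass --threads to nvcc so it can parallelize compilation; defaults to 4 unless
# NVCC_THREADS is set in the environment.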
def append_nvcc_threads(nvcc_extra_args):
    nvcc_threads = os.getenv("NVC" "C_THREADS") or "4"
    return nvcc_extra_args + ["--threads", nvcc_threads]


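# The ROCm/CK path generates .cpp sources; copy each one to a .cu file with the
# same stem so the extension builder compiles it as device code.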
def rename_cpp_to_cu(cpp_files):
    for entry in cpp_files:
        shutil.copy(entry, os.path.splitext(entry)[0] + ".cu")


def validate_and_update_archs(archs):
    # List of allowed architectures
    allowed_archs = ["native", "gfx90a", "gfx940", "gfx941", "gfx942"]

    # Validate if each element in archs is in allowed_archs
    assert all(
        arch in allowed_archs for arch in archs
    ), f"One of the GPU archs in {archs} is invalid or not supported by Flash-Attention"


cmdclass = {}
ext_modules = []

# We want this even if SKIP_CUDA_BUILD because when we run python setup.py sdist we want the .hpp
# files included in the source distribution, in case the user compiles from source.
if IS_ROCM:
    subprocess.run(["git", "submodule", "update", "--init", "csrc/composable_kernel"])
else:
    subprocess.run(["git", "submodule", "update", "--init", "csrc/cutlass"])

if not SKIP_CUDA_BUILD and not IS_ROCM:
    print("\n\ntorch.__version__  = {}\n\n".format(torch.__version__))
    TORCH_MAJOR = int(torch.__version__.split(".")[0])
    TORCH_MINOR = int(torch.__version__.split(".")[1])

    # Check if ATen/CUDAGeneratorImpl.h is found, otherwise use ATen/cuda/CUDAGeneratorImpl.h
    # See https://github.com/pytorch/pytorch/pull/70650
    generator_flag = []
    torch_dir = torch.__path__[0]
    if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.h")):
        generator_flag = ["-DOLD_GENERATOR_PATH"]

    check_if_cuda_home_none(PACKAGE_NAME)
    # Check if CUDA 11 is installed for compute capability 8.0
    cc_flag = []
    if CUDA_HOME is not None:
        _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
        if bare_metal_version < Version("11.6"):
            raise RuntimeError(
                "FlashAttention is only supported on CUDA 11.6 and above.  "
                "Note: make sure nvcc has a supported version by running nvcc -V."
            )
    # cc_flag.append("-gencode")
    # cc_flag.append("arch=compute_75,code=sm_75")
    cc_flag.append("-gencode")
    cc_flag.append("arch=compute_80,code=sm_80")
    if CUDA_HOME is not None:
        if bare_metal_version >= Version("11.8"):
            cc_flag.append("-gencode")
            cc_flag.append("arch=compute_90,code=sm_90")

    # HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as
    # torch._C._GLIBCXX_USE_CXX11_ABI
    # https://github.com/pytorch/pytorch/blob/8472c24e3b5b60150096486616d98b7bea01500b/torch/utils/cpp_extension.py#L920
    if FORCE_CXX11_ABI:
        torch._C._GLIBCXX_USE_CXX11_ABI = True
    ext_modules.append(
        CUDAExtension(
            name="vllm_flash_attn_2_cuda",
            sources=[
                "csrc/flash_attn/flash_api.cpp",
                "csrc/flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim32_fp16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim32_bf16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim64_fp16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim64_bf16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim96_fp16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim96_bf16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim128_fp16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim128_bf16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim160_fp16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim160_bf16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim192_fp16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim192_bf16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim256_fp16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_hdim256_bf16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim32_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim32_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim64_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim64_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim96_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim96_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim128_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim128_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim160_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim160_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim192_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim192_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim256_fp16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim256_bf16_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim32_fp16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim32_bf16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim64_fp16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim64_bf16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim96_fp16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim96_bf16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim128_fp16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim128_bf16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim160_fp16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim160_bf16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim192_fp16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim192_bf16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim256_fp16_causal_sm80.cu",
                "csrc/flash_attn/src/flash_fwd_split_hdim256_bf16_causal_sm80.cu",
            ],
            extra_compile_args={
                "cxx": ["-O3", "-std=c++17"] + generator_flag,
                "nvcc": append_nvcc_threads(
                    [
                        "-O3",
                        "-std=c++17",
                        "-U__CUDA_NO_HALF_OPERATORS__",
                        "-U__CUDA_NO_HALF_CONVERSIONS__",
                        "-U__CUDA_NO_HALF2_OPERATORS__",
                        "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
                        "--expt-relaxed-constexpr",
                        "--expt-extended-lambda",
                        "--use_fast_math",
                        # "--ptxas-options=-v",
                        # "--ptxas-options=-O2",
                        # "-lineinfo",
                        # "-DFLASHATTENTION_DISABLE_BACKWARD",
                        "-DFLASHATTENTION_DISABLE_DROPOUT",
                        # "-DFLASHATTENTION_DISABLE_ALIBI",
                        # "-DFLASHATTENTION_DISABLE_SOFTCAP",
                        "-DFLASHATTENTION_DISABLE_UNEVEN_K",
                        # "-DFLASHATTENTION_DISABLE_LOCAL",
                    ]
                    + generator_flag
                    + cc_flag
                ),
            },
            include_dirs=[
                Path(this_dir) / "csrc" / "flash_attn",
                Path(this_dir) / "csrc" / "flash_attn" / "src",
                Path(this_dir) / "csrc" / "cutlass" / "include",
            ],
        )
    )
elif not SKIP_CUDA_BUILD and IS_ROCM:
    ck_dir = "csrc/composable_kernel"

    # Use the composable_kernel codegen script to generate the fwd/bwd kernel dispatch sources
    if not os.path.exists("./build"):
        os.makedirs("build")

    os.system(f"{sys.executable} {ck_dir}/example/ck_tile/01_fmha/generate.py -d fwd --output_dir build --receipt 2")
    os.system(f"{sys.executable} {ck_dir}/example/ck_tile/01_fmha/generate.py -d bwd --output_dir build --receipt 2")

    print("\n\ntorch.__version__  = {}\n\n".format(torch.__version__))
    TORCH_MAJOR = int(torch.__version__.split(".")[0])
    TORCH_MINOR = int(torch.__version__.split(".")[1])

    # Check if ATen/CUDAGeneratorImpl.h is found, otherwise use ATen/cuda/CUDAGeneratorImpl.h
    # See https://github.com/pytorch/pytorch/pull/70650
    generator_flag = []
    torch_dir = torch.__path__[0]
    if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.h")):
        generator_flag = ["-DOLD_GENERATOR_PATH"]

    check_if_rocm_home_none("flash_attn")
    cc_flag = []

    archs = os.getenv("GPU_ARCHS", "native").split(";")
    validate_and_update_archs(archs)

    cc_flag = [f"--offload-arch={arch}" for arch in archs]

    # HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as
    # torch._C._GLIBCXX_USE_CXX11_ABI
    # https://github.com/pytorch/pytorch/blob/8472c24e3b5b60150096486616d98b7bea01500b/torch/utils/cpp_extension.py#L920
    if FORCE_CXX11_ABI:
        torch._C._GLIBCXX_USE_CXX11_ABI = True

    sources = ["csrc/flash_attn_ck/flash_api.cpp",
               "csrc/flash_attn_ck/mha_bwd.cpp",
               "csrc/flash_attn_ck/mha_fwd.cpp",
               "csrc/flash_attn_ck/mha_varlen_bwd.cpp",
               "csrc/flash_attn_ck/mha_varlen_fwd.cpp"] + glob.glob(
        f"build/fmha_*wd*.cpp"
    )

    rename_cpp_to_cu(sources)

    renamed_sources = ["csrc/flash_attn_ck/flash_api.cu",
                       "csrc/flash_attn_ck/mha_bwd.cu",
                       "csrc/flash_attn_ck/mha_fwd.cu",
                       "csrc/flash_attn_ck/mha_varlen_bwd.cu",
                       "csrc/flash_attn_ck/mha_varlen_fwd.cu"] + glob.glob(f"build/fmha_*wd*.cu")
    extra_compile_args = {
        "cxx": ["-O3", "-std=c++17"] + generator_flag,
        "nvcc":
            [
                "-O3","-std=c++17",
                "-mllvm", "-enable-post-misched=0",
                "-DCK_TILE_FMHA_FWD_FAST_EXP2=1",
                "-fgpu-flush-denormals-to-zero",
                "-DCK_ENABLE_BF16",
                "-DCK_ENABLE_BF8",
                "-DCK_ENABLE_FP16",
                "-DCK_ENABLE_FP32",
                "-DCK_ENABLE_FP64",
                "-DCK_ENABLE_FP8",
                "-DCK_ENABLE_INT8",
                "-DCK_USE_XDL",
                "-DUSE_PROF_API=1",
                "-D__HIP_PLATFORM_HCC__=1",
                # "-DFLASHATTENTION_DISABLE_BACKWARD",
            ]
            + generator_flag
            + cc_flag
        ,
    }

    include_dirs = [
        Path(this_dir) / "csrc" / "composable_kernel" / "include",
        Path(this_dir) / "csrc" / "composable_kernel" / "library" / "include",
        Path(this_dir) / "csrc" / "composable_kernel" / "example" / "ck_tile" / "01_fmha",
    ]

    ext_modules.append(
        CUDAExtension(
            name="flash_attn_2_cuda",
            sources=renamed_sources,
            extra_compile_args=extra_compile_args,
            include_dirs=include_dirs,
        )
    )


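# Read __version__ from the package's __init__.py; if FLASH_ATTN_LOCAL_VERSION is
# set, it is appended as a PEP 440 local version suffix.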
def get_package_version():
    with open(Path(this_dir) / PACKAGE_NAME / "__init__.py", "r") as f:
        version_match = re.search(r"^__version__\s*=\s*(.*)$", f.read(), re.MULTILINE)
    public_version = ast.literal_eval(version_match.group(1))
    local_version = os.environ.get("FLASH_ATTN_LOCAL_VERSION")
    if local_version:
        return f"{public_version}+{local_version}"
    else:
        return str(public_version)


class CachedWheelsCommand(_bdist_wheel):
    """
    The CachedWheelsCommand plugs into the default bdist wheel, which is run by pip when it cannot
    find an existing wheel (which is currently the case for all flash attention installs). We use
    the environment parameters to detect whether there is already a pre-built version of a compatible
    wheel available and short-circuit the standard full build pipeline.
    """

    def run(self):
        if FORCE_BUILD:
            return super().run()


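# Cap ninja parallelism via MAX_JOBS (when not already set) using both CPU core
# count and available memory, to avoid OOM during compilation.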
class NinjaBuildExtension(BuildExtension):
    def __init__(self, *args, **kwargs) -> None:
        # do not override env MAX_JOBS if already exists
        if not os.environ.get("MAX_JOBS"):
            import psutil

            # calculate the maximum allowed NUM_JOBS based on cores
            max_num_jobs_cores = max(1, os.cpu_count() // 2)

            # calculate the maximum allowed NUM_JOBS based on free memory
            free_memory_gb = psutil.virtual_memory().available / (1024 ** 3)  # free memory in GB
            max_num_jobs_memory = int(free_memory_gb / 9)  # each JOB peak memory cost is ~8-9GB when threads = 4

            # pick lower value of jobs based on cores vs memory metric to minimize oom and swap usage during compilation
            max_jobs = max(1, min(max_num_jobs_cores, max_num_jobs_memory))
            os.environ["MAX_JOBS"] = str(max_jobs)

        super().__init__(*args, **kwargs)


PYTORCH_VERSION = "2.4.0"
CUDA_VERSION = "12.1"

setup(
    name="vllm-flash-attn",
    version=get_package_version(),
    packages=find_packages(
        exclude=(
            "build",
            "csrc",
            "include",
            "tests",
            "dist",
            "docs",
            "benchmarks",
            f"{PACKAGE_NAME}.egg-info",
        )
    ),
    author="vLLM Team",
    description="Forward-only flash-attn",
    long_description=f"Forward-only flash-attn package built for PyTorch {PYTORCH_VERSION} and CUDA {CUDA_VERSION}",
    url="https://github.com/vllm-project/flash-attention.git",
    classifiers=[
        "Programming Language :: Python :: 3",
        "License :: OSI Approved :: BSD License",
        "Operating System :: Unix",
    ],
    ext_modules=ext_modules,
    cmdclass={"bdist_wheel": CachedWheelsCommand, "build_ext": NinjaBuildExtension}
    if ext_modules
    else {
        "bdist_wheel": CachedWheelsCommand,
    },
    python_requires=">=3.8",
    install_requires=[f"torch == {PYTORCH_VERSION}"],
    setup_requires=["psutil"],
)