Commit 0c04943f authored by Tri Dao

Require CUDA 11.6+, clean up setup.py

parent 798858f9
@@ -29,7 +29,7 @@ Please cite and credit FlashAttention if you use it.
 ## Installation and features

 Requirements:
-- CUDA 11.4 and above.
+- CUDA 11.6 and above.
 - PyTorch 1.12 and above.

 We recommend the
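The minimum CUDA requirement moves from 11.4 to 11.6. A quick way to see what an environment actually provides is to compare the CUDA version PyTorch was built against with the local nvcc toolkit; the snippet below is a generic illustrative sketch, not part of this commit:

    # Hypothetical environment check, not part of this commit.
    import subprocess

    import torch

    print("PyTorch:", torch.__version__)                   # needs 1.12+
    print("PyTorch built with CUDA:", torch.version.cuda)  # None for CPU-only builds
    # nvcc --version reports the local toolkit that compiles the extension;
    # it should be 11.6 or newer after this change.
    print(subprocess.run(["nvcc", "--version"], capture_output=True, text=True).stdout)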
@@ -64,28 +64,12 @@ def get_cuda_bare_metal_version(cuda_dir):
     return raw_output, bare_metal_version


-def check_cuda_torch_binary_vs_bare_metal(cuda_dir):
-    raw_output, bare_metal_version = get_cuda_bare_metal_version(cuda_dir)
-    torch_binary_version = parse(torch.version.cuda)
-
-    print("\nCompiling cuda extensions with")
-    print(raw_output + "from " + cuda_dir + "/bin\n")
-
-    if (bare_metal_version != torch_binary_version):
-        raise RuntimeError(
-            "Cuda extensions are being compiled with a version of Cuda that does "
-            "not match the version used to compile Pytorch binaries. "
-            "Pytorch binaries were compiled with Cuda {}.\n".format(torch.version.cuda)
-            + "In some cases, a minor-version mismatch will not cause later errors: "
-            "https://github.com/NVIDIA/apex/pull/323#discussion_r287021798. "
-            "You can try commenting out this check (at your own risk)."
-        )
-
-
-def raise_if_cuda_home_none(global_option: str) -> None:
+def check_if_cuda_home_none(global_option: str) -> None:
     if CUDA_HOME is not None:
         return
-    raise RuntimeError(
+    # warn instead of error because user could be downloading prebuilt wheels, so nvcc won't be necessary
+    # in that case.
+    warnings.warn(
         f"{global_option} was requested, but nvcc was not found. Are you sure your environment has nvcc available? "
         "If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, "
         "only images whose names contain 'devel' will provide nvcc."
@@ -117,16 +101,18 @@ if not SKIP_CUDA_BUILD:
     if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.h")):
         generator_flag = ["-DOLD_GENERATOR_PATH"]

-    raise_if_cuda_home_none("flash_attn")
+    check_if_cuda_home_none("flash_attn")
     # Check, if CUDA11 is installed for compute capability 8.0
     cc_flag = []
-    _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
-    if bare_metal_version < Version("11.4"):
-        raise RuntimeError("FlashAttention is only supported on CUDA 11.4 and above")
+    if CUDA_HOME is not None:
+        _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
+        if bare_metal_version < Version("11.6"):
+            raise RuntimeError("FlashAttention is only supported on CUDA 11.6 and above")
     # cc_flag.append("-gencode")
     # cc_flag.append("arch=compute_75,code=sm_75")
     cc_flag.append("-gencode")
     cc_flag.append("arch=compute_80,code=sm_80")
-    if bare_metal_version >= Version("11.8"):
-        cc_flag.append("-gencode")
-        cc_flag.append("arch=compute_90,code=sm_90")
+    if CUDA_HOME is not None:
+        if bare_metal_version >= Version("11.8"):
+            cc_flag.append("-gencode")
+            cc_flag.append("arch=compute_90,code=sm_90")
@@ -231,17 +217,7 @@ def get_package_version():
     return str(public_version)


-class CachedWheelsCommand(_bdist_wheel):
-    """
-    The CachedWheelsCommand plugs into the default bdist wheel, which is ran by pip when it cannot
-    find an existing wheel (which is currently the case for all flash attention installs). We use
-    the environment parameters to detect whether there is already a pre-built version of a compatible
-    wheel available and short-circuits the standard full build pipeline.
-    """
-    def run(self):
-        if FORCE_BUILD:
-            return super().run()
-
+def get_wheel_url():
     # Determine the version numbers that will be used to determine the correct wheel
     # We're using the CUDA version used to build torch, not the one currently installed
     # _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME)

@@ -261,8 +237,22 @@ def get_package_version():
        tag_name=f"v{flash_version}",
        wheel_name=wheel_filename
    )
-    print("Guessing wheel URL: ", wheel_url)
+    return wheel_url, wheel_filename
+
+
+class CachedWheelsCommand(_bdist_wheel):
+    """
+    The CachedWheelsCommand plugs into the default bdist wheel, which is ran by pip when it cannot
+    find an existing wheel (which is currently the case for all flash attention installs). We use
+    the environment parameters to detect whether there is already a pre-built version of a compatible
+    wheel available and short-circuits the standard full build pipeline.
+    """
+    def run(self):
+        if FORCE_BUILD:
+            return super().run()
+
+        wheel_url, wheel_filename = get_wheel_url()
+        print("Guessing wheel URL: ", wheel_url)
         try:
             urllib.request.urlretrieve(wheel_url, wheel_filename)
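With the wheel-URL computation factored out into get_wheel_url(), CachedWheelsCommand.run() reduces to: honor FORCE_BUILD, guess the wheel URL, try to download it, and otherwise fall back to a normal build. Below is a minimal sketch of that flow as a method of CachedWheelsCommand; the except branch is an assumption here, since the failure-handling part of run() lies outside the visible hunks.

    import urllib.request

    # Sketch of the refactored flow; the download-failure fallback is an assumption.
    def run(self):
        if FORCE_BUILD:
            return super().run()

        wheel_url, wheel_filename = get_wheel_url()
        print("Guessing wheel URL: ", wheel_url)
        try:
            urllib.request.urlretrieve(wheel_url, wheel_filename)
            # ... the remainder of run() (not shown in this diff) handles the downloaded wheel
        except Exception:
            # Assumed fallback: no compatible prebuilt wheel, so do the standard source build.
            super().run()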
@@ -12,7 +12,7 @@ from flash_attn import (
     flash_attn_varlen_kvpacked_func,
     flash_attn_varlen_qkvpacked_func,
 )
-from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input
+from flash_attn.bert_padding import pad_input, unpad_input
 from flash_attn.flash_attn_interface import _get_block_size

 MAX_HEADDIM_SM8x = 192

@@ -1376,7 +1376,7 @@ def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, dtype):
 # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
 # @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
 # @pytest.mark.parametrize('d', [56, 80])
-# @pytest.mark.parametrize("d", [128])
+# @pytest.mark.parametrize("d", [64])
 @pytest.mark.parametrize("swap_sq_sk", [False, True])
 # @pytest.mark.parametrize("swap_sq_sk", [False])
 @pytest.mark.parametrize(

@@ -1384,6 +1384,7 @@ def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, dtype):
     [
         (3, 1024),
         (1, 339),
+        (64, 800),
         (3, 799),
         (64, 2048),
         (16, 20000),

@@ -1394,11 +1395,6 @@ def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, dtype):
 )
 # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
 def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, dtype):
-    if (
-        max(seqlen_q, seqlen_k) >= 2048
-        and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
-    ):
-        pytest.skip()  # Reference implementation OOM
     if swap_sq_sk:
         seqlen_q, seqlen_k = seqlen_k, seqlen_q
     device = "cuda"