gaoqiong/flash-attention
Commit 0c04943f, authored Sep 03, 2023 by Tri Dao

Require CUDA 11.6+, clean up setup.py

Parent: 798858f9

Showing 3 changed files with 41 additions and 55 deletions:
README.md: +1 / -1
setup.py: +37 / -47
tests/test_flash_attn.py: +3 / -7
README.md
@@ -29,7 +29,7 @@ Please cite and credit FlashAttention if you use it.
 ## Installation and features
 Requirements:
-- CUDA 11.4 and above.
+- CUDA 11.6 and above.
 - PyTorch 1.12 and above.
 We recommend the
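The floor moves from CUDA 11.4 to 11.6 for the toolkit used to compile the extension. As a hedged, illustrative check (not part of this commit), one can confirm both the PyTorch-bundled CUDA version and the local nvcc release before installing:

# Illustrative only; uses the same parsing style as setup.py's version helpers.
import re
import subprocess

import torch
from packaging.version import Version, parse

print("PyTorch", torch.__version__, "built against CUDA", torch.version.cuda)

# Bare-metal nvcc version, similar in spirit to get_cuda_bare_metal_version() in setup.py.
raw = subprocess.check_output(["nvcc", "-V"], universal_newlines=True)
match = re.search(r"release (\d+\.\d+)", raw)
if match and parse(match.group(1)) < Version("11.6"):
    print("nvcc", match.group(1), "is older than the required CUDA 11.6")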
setup.py
@@ -64,28 +64,12 @@ def get_cuda_bare_metal_version(cuda_dir):
     return raw_output, bare_metal_version


-def check_cuda_torch_binary_vs_bare_metal(cuda_dir):
-    raw_output, bare_metal_version = get_cuda_bare_metal_version(cuda_dir)
-    torch_binary_version = parse(torch.version.cuda)
-
-    print("\nCompiling cuda extensions with")
-    print(raw_output + "from " + cuda_dir + "/bin\n")
-
-    if (bare_metal_version != torch_binary_version):
-        raise RuntimeError(
-            "Cuda extensions are being compiled with a version of Cuda that does "
-            "not match the version used to compile Pytorch binaries. "
-            "Pytorch binaries were compiled with Cuda {}.\n".format(torch.version.cuda)
-            + "In some cases, a minor-version mismatch will not cause later errors: "
-            "https://github.com/NVIDIA/apex/pull/323#discussion_r287021798. "
-            "You can try commenting out this check (at your own risk)."
-        )
-
-
-def raise_if_cuda_home_none(global_option: str) -> None:
+def check_if_cuda_home_none(global_option: str) -> None:
     if CUDA_HOME is not None:
         return
-    raise RuntimeError(
+    # warn instead of error because user could be downloading prebuilt wheels, so nvcc won't be necessary
+    # in that case.
+    warnings.warn(
         f"{global_option} was requested, but nvcc was not found. Are you sure your environment has nvcc available? "
         "If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, "
         "only images whose names contain 'devel' will provide nvcc."
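The rename from raise_if_cuda_home_none to check_if_cuda_home_none goes with the behavioural change: a missing nvcc now produces a warning instead of aborting the install, so users who only need a prebuilt wheel can proceed. A minimal standalone sketch of the same guard pattern (stand-in CUDA_HOME, not the project's exact code):

# Sketch only; CUDA_HOME here is a stand-in for torch.utils.cpp_extension.CUDA_HOME,
# which is None when nvcc cannot be located.
import warnings

CUDA_HOME = None

def check_if_cuda_home_none(global_option: str) -> None:
    if CUDA_HOME is not None:
        return
    # A RuntimeError here would abort `pip install` even for wheel-only users;
    # a warning lets installation continue.
    warnings.warn(f"{global_option} was requested, but nvcc was not found.")

check_if_cuda_home_none("flash_attn")  # emits a UserWarning and returns normally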
@@ -117,19 +101,21 @@ if not SKIP_CUDA_BUILD:
     if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.h")):
         generator_flag = ["-DOLD_GENERATOR_PATH"]

-    raise_if_cuda_home_none("flash_attn")
+    check_if_cuda_home_none("flash_attn")
     # Check, if CUDA11 is installed for compute capability 8.0
     cc_flag = []
-    _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
-    if bare_metal_version < Version("11.4"):
-        raise RuntimeError("FlashAttention is only supported on CUDA 11.4 and above")
+    if CUDA_HOME is not None:
+        _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
+        if bare_metal_version < Version("11.6"):
+            raise RuntimeError("FlashAttention is only supported on CUDA 11.6 and above")
     # cc_flag.append("-gencode")
     # cc_flag.append("arch=compute_75,code=sm_75")
     cc_flag.append("-gencode")
     cc_flag.append("arch=compute_80,code=sm_80")
-    if bare_metal_version >= Version("11.8"):
-        cc_flag.append("-gencode")
-        cc_flag.append("arch=compute_90,code=sm_90")
+    if CUDA_HOME is not None:
+        if bare_metal_version >= Version("11.8"):
+            cc_flag.append("-gencode")
+            cc_flag.append("arch=compute_90,code=sm_90")

     # HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as
     # torch._C._GLIBCXX_USE_CXX11_ABI
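The gencode flags collected in cc_flag select the GPU targets nvcc compiles for: sm_80 always, and sm_90 only when the bare-metal toolkit is 11.8 or newer. A hedged sketch of how such flags are typically passed on to the extension build via torch's CUDAExtension; the extension name, source list, and other compiler flags below are illustrative, not this commit's exact configuration:

# Hedged sketch of wiring gencode flags into a CUDA extension build.
from torch.utils.cpp_extension import CUDAExtension

cc_flag = ["-gencode", "arch=compute_80,code=sm_80",
           "-gencode", "arch=compute_90,code=sm_90"]  # as produced above on CUDA >= 11.8

ext = CUDAExtension(
    name="flash_attn_cuda_example",            # illustrative module name
    sources=["csrc/flash_attn/flash_api.cpp"],  # illustrative source list
    extra_compile_args={
        "cxx": ["-O3"],
        "nvcc": ["-O3", "--use_fast_math"] + cc_flag,  # cc_flag picks the sm_80 / sm_90 targets
    },
)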
@@ -231,6 +217,29 @@ def get_package_version():
     return str(public_version)


+def get_wheel_url():
+    # Determine the version numbers that will be used to determine the correct wheel
+    # We're using the CUDA version used to build torch, not the one currently installed
+    # _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME)
+    torch_cuda_version = parse(torch.version.cuda)
+    torch_version_raw = parse(torch.__version__)
+    python_version = f"cp{sys.version_info.major}{sys.version_info.minor}"
+    platform_name = get_platform()
+    flash_version = get_package_version()
+    # cuda_version = f"{cuda_version_raw.major}{cuda_version_raw.minor}"
+    cuda_version = f"{torch_cuda_version.major}{torch_cuda_version.minor}"
+    torch_version = f"{torch_version_raw.major}.{torch_version_raw.minor}"
+    cxx11_abi = str(torch._C._GLIBCXX_USE_CXX11_ABI).upper()
+
+    # Determine wheel URL based on CUDA version, torch version, python version and OS
+    wheel_filename = f'{PACKAGE_NAME}-{flash_version}+cu{cuda_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl'
+    wheel_url = BASE_WHEEL_URL.format(tag_name=f"v{flash_version}", wheel_name=wheel_filename)
+    return wheel_url, wheel_filename
+
+
 class CachedWheelsCommand(_bdist_wheel):
     """
     The CachedWheelsCommand plugs into the default bdist wheel, which is ran by pip when it cannot
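For illustration, with hypothetical values (flash_attn 2.1.1, torch 2.0 built against CUDA 11.8, CPython 3.10, old C++ ABI, linux_x86_64), the new get_wheel_url() helper would name the wheel as follows; only the format string is taken from the code above, the concrete values are assumptions:

# Hypothetical example values; only the filename format mirrors get_wheel_url().
PACKAGE_NAME = "flash_attn"
flash_version, cuda_version, torch_version = "2.1.1", "118", "2.0"
cxx11_abi, python_version, platform_name = "FALSE", "cp310", "linux_x86_64"

wheel_filename = (f"{PACKAGE_NAME}-{flash_version}+cu{cuda_version}torch{torch_version}"
                  f"cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl")
print(wheel_filename)
# flash_attn-2.1.1+cu118torch2.0cxx11abiFALSE-cp310-cp310-linux_x86_64.whl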
@@ -242,27 +251,8 @@ class CachedWheelsCommand(_bdist_wheel):
         if FORCE_BUILD:
             return super().run()

-        # Determine the version numbers that will be used to determine the correct wheel
-        # We're using the CUDA version used to build torch, not the one currently installed
-        # _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME)
-        torch_cuda_version = parse(torch.version.cuda)
-        torch_version_raw = parse(torch.__version__)
-        python_version = f"cp{sys.version_info.major}{sys.version_info.minor}"
-        platform_name = get_platform()
-        flash_version = get_package_version()
-        # cuda_version = f"{cuda_version_raw.major}{cuda_version_raw.minor}"
-        cuda_version = f"{torch_cuda_version.major}{torch_cuda_version.minor}"
-        torch_version = f"{torch_version_raw.major}.{torch_version_raw.minor}"
-        cxx11_abi = str(torch._C._GLIBCXX_USE_CXX11_ABI).upper()
-
-        # Determine wheel URL based on CUDA version, torch version, python version and OS
-        wheel_filename = f'{PACKAGE_NAME}-{flash_version}+cu{cuda_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl'
-        wheel_url = BASE_WHEEL_URL.format(tag_name=f"v{flash_version}", wheel_name=wheel_filename)
-
+        wheel_url, wheel_filename = get_wheel_url()
         print("Guessing wheel URL: ", wheel_url)
-
         try:
             urllib.request.urlretrieve(wheel_url, wheel_filename)
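With the version plumbing factored out into get_wheel_url(), CachedWheelsCommand.run() reduces to: honor FORCE_BUILD, otherwise guess the wheel URL and try to download it. A hedged sketch of that flow; FORCE_BUILD and get_wheel_url below are stand-ins for the definitions in setup.py, and the download-failure fallback to a source build is an assumption based on the class docstring, not code shown in this hunk:

# Hedged sketch of the cached-wheel flow, not the project's exact implementation.
import urllib.error
import urllib.request
from wheel.bdist_wheel import bdist_wheel as _bdist_wheel

FORCE_BUILD = False                      # stand-in for the env-var flag in setup.py

def get_wheel_url():                     # stand-in returning a hypothetical URL/filename
    return ("https://example.invalid/flash_attn.whl", "flash_attn.whl")

class CachedWheelsCommand(_bdist_wheel):
    def run(self):
        if FORCE_BUILD:
            return super().run()         # build from source unconditionally
        wheel_url, wheel_filename = get_wheel_url()
        print("Guessing wheel URL: ", wheel_url)
        try:
            urllib.request.urlretrieve(wheel_url, wheel_filename)
            # ...then hand the downloaded wheel to bdist_wheel's output machinery...
        except (urllib.error.HTTPError, urllib.error.URLError):
            # Assumed fallback: no matching prebuilt wheel, so compile from source.
            return super().run()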
tests/test_flash_attn.py
@@ -12,7 +12,7 @@ from flash_attn import (
     flash_attn_varlen_kvpacked_func,
     flash_attn_varlen_qkvpacked_func,
 )
-from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input
+from flash_attn.bert_padding import pad_input, unpad_input
 from flash_attn.flash_attn_interface import _get_block_size

 MAX_HEADDIM_SM8x = 192
@@ -1376,7 +1376,7 @@ def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, dtype):
 # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
 # @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
 # @pytest.mark.parametrize('d', [56, 80])
-# @pytest.mark.parametrize("d", [128])
+# @pytest.mark.parametrize("d", [64])
 @pytest.mark.parametrize("swap_sq_sk", [False, True])
 # @pytest.mark.parametrize("swap_sq_sk", [False])
 @pytest.mark.parametrize(
@@ -1384,6 +1384,7 @@ def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, dtype):
     [
         (3, 1024),
         (1, 339),
         (64, 800),
         (3, 799),
         (64, 2048),
         (16, 20000),
@@ -1394,11 +1395,6 @@ def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, dtype):
 )
 # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
 def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, dtype):
-    if (
-        max(seqlen_q, seqlen_k) >= 2048
-        and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
-    ):
-        pytest.skip()  # Reference implementation OOM
     if swap_sq_sk:
         seqlen_q, seqlen_k = seqlen_k, seqlen_q
     device = "cuda"