gaoqiong / flash-attention / Commits / 0c04943f

Commit 0c04943f, authored Sep 03, 2023 by Tri Dao

    Require CUDA 11.6+, clean up setup.py

parent 798858f9
Showing 3 changed files, with 41 additions and 55 deletions:

    README.md                   +1   -1
    setup.py                    +37  -47
    tests/test_flash_attn.py    +3   -7
README.md (view file @ 0c04943f)

@@ -29,7 +29,7 @@ Please cite and credit FlashAttention if you use it.
 ## Installation and features
 
 Requirements:
-- CUDA 11.4 and above.
+- CUDA 11.6 and above.
 - PyTorch 1.12 and above.
 
 We recommend the
...
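Since the minimum requirement moves from CUDA 11.4 to 11.6, it can be useful to verify an environment before installing. The following is an illustrative check (not part of this commit), based on the CUDA version PyTorch was built against:

    # Illustrative pre-install check (not part of this commit): confirm that the
    # CUDA version PyTorch was built against satisfies the new >= 11.6 requirement.
    import torch
    from packaging.version import parse

    cuda = torch.version.cuda  # None for CPU-only builds of PyTorch
    if cuda is None or parse(cuda) < parse("11.6"):
        print(f"PyTorch CUDA version is {cuda}; FlashAttention now requires CUDA 11.6+")
    else:
        print(f"PyTorch CUDA version {cuda} meets the requirement")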
setup.py (view file @ 0c04943f)

@@ -64,28 +64,12 @@ def get_cuda_bare_metal_version(cuda_dir):
     return raw_output, bare_metal_version
 
 
-def check_cuda_torch_binary_vs_bare_metal(cuda_dir):
-    raw_output, bare_metal_version = get_cuda_bare_metal_version(cuda_dir)
-    torch_binary_version = parse(torch.version.cuda)
-
-    print("\nCompiling cuda extensions with")
-    print(raw_output + "from " + cuda_dir + "/bin\n")
-
-    if (bare_metal_version != torch_binary_version):
-        raise RuntimeError(
-            "Cuda extensions are being compiled with a version of Cuda that does "
-            "not match the version used to compile Pytorch binaries. "
-            "Pytorch binaries were compiled with Cuda {}.\n".format(torch.version.cuda)
-            + "In some cases, a minor-version mismatch will not cause later errors: "
-            "https://github.com/NVIDIA/apex/pull/323#discussion_r287021798. "
-            "You can try commenting out this check (at your own risk)."
-        )
-
-
-def raise_if_cuda_home_none(global_option: str) -> None:
+def check_if_cuda_home_none(global_option: str) -> None:
     if CUDA_HOME is not None:
         return
-    raise RuntimeError(
+    # warn instead of error because user could be downloading prebuilt wheels, so nvcc won't be necessary
+    # in that case.
+    warnings.warn(
         f"{global_option} was requested, but nvcc was not found. Are you sure your environment has nvcc available? "
         "If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, "
         "only images whose names contain 'devel' will provide nvcc."
...
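The net effect of this hunk: the strict torch-vs-nvcc version comparison is dropped, and a missing nvcc now only produces a warning, since the user may be installing a prebuilt wheel. A minimal standalone sketch of the new behavior, assuming CUDA_HOME comes from torch's C++ extension utilities as elsewhere in setup.py:

    # Minimal sketch (assumes setup.py imports CUDA_HOME from torch's cpp_extension
    # helpers); CUDA_HOME is None when no CUDA toolkit / nvcc can be located.
    import warnings
    from torch.utils.cpp_extension import CUDA_HOME

    def check_if_cuda_home_none(global_option: str) -> None:
        if CUDA_HOME is not None:
            return
        # Warn instead of raising: the user may be downloading a prebuilt wheel,
        # in which case nvcc is never needed.
        warnings.warn(f"{global_option} was requested, but nvcc was not found.")

    check_if_cuda_home_none("flash_attn")  # emits a warning only when nvcc is absent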
@@ -117,19 +101,21 @@ if not SKIP_CUDA_BUILD:
     if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.h")):
         generator_flag = ["-DOLD_GENERATOR_PATH"]
 
-    raise_if_cuda_home_none("flash_attn")
+    check_if_cuda_home_none("flash_attn")
     # Check, if CUDA11 is installed for compute capability 8.0
     cc_flag = []
-    _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
-    if bare_metal_version < Version("11.4"):
-        raise RuntimeError("FlashAttention is only supported on CUDA 11.4 and above")
+    if CUDA_HOME is not None:
+        _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
+        if bare_metal_version < Version("11.6"):
+            raise RuntimeError("FlashAttention is only supported on CUDA 11.6 and above")
     # cc_flag.append("-gencode")
     # cc_flag.append("arch=compute_75,code=sm_75")
     cc_flag.append("-gencode")
     cc_flag.append("arch=compute_80,code=sm_80")
-    if bare_metal_version >= Version("11.8"):
-        cc_flag.append("-gencode")
-        cc_flag.append("arch=compute_90,code=sm_90")
+    if CUDA_HOME is not None:
+        if bare_metal_version >= Version("11.8"):
+            cc_flag.append("-gencode")
+            cc_flag.append("arch=compute_90,code=sm_90")
 
     # HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as
     # torch._C._GLIBCXX_USE_CXX11_ABI
...
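For context, the -gencode pairs collected in cc_flag are eventually forwarded to nvcc through the CUDA extension's compile arguments. A hypothetical sketch of that wiring (module name, source list, and the other nvcc flags below are placeholders, not taken from this commit):

    # Hypothetical sketch: how -gencode flags like the ones assembled above are
    # typically passed to nvcc via torch's CUDAExtension. Names and paths are
    # placeholders, not the actual values used by flash-attention's setup.py.
    from torch.utils.cpp_extension import CUDAExtension

    cc_flag = ["-gencode", "arch=compute_80,code=sm_80"]
    cuda_is_11_8_plus = True  # would come from get_cuda_bare_metal_version(CUDA_HOME)
    if cuda_is_11_8_plus:
        cc_flag += ["-gencode", "arch=compute_90,code=sm_90"]

    ext = CUDAExtension(
        name="example_cuda_ext",                # placeholder module name
        sources=["csrc/example_kernel.cu"],     # placeholder source list
        extra_compile_args={
            "cxx": ["-O3"],
            "nvcc": ["-O3", "--use_fast_math"] + cc_flag,
        },
    )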
@@ -231,6 +217,29 @@ def get_package_version():
         return str(public_version)
 
 
+def get_wheel_url():
+    # Determine the version numbers that will be used to determine the correct wheel
+    # We're using the CUDA version used to build torch, not the one currently installed
+    # _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME)
+    torch_cuda_version = parse(torch.version.cuda)
+    torch_version_raw = parse(torch.__version__)
+    python_version = f"cp{sys.version_info.major}{sys.version_info.minor}"
+    platform_name = get_platform()
+    flash_version = get_package_version()
+    # cuda_version = f"{cuda_version_raw.major}{cuda_version_raw.minor}"
+    cuda_version = f"{torch_cuda_version.major}{torch_cuda_version.minor}"
+    torch_version = f"{torch_version_raw.major}.{torch_version_raw.minor}"
+    cxx11_abi = str(torch._C._GLIBCXX_USE_CXX11_ABI).upper()
+
+    # Determine wheel URL based on CUDA version, torch version, python version and OS
+    wheel_filename = f'{PACKAGE_NAME}-{flash_version}+cu{cuda_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl'
+    wheel_url = BASE_WHEEL_URL.format(tag_name=f"v{flash_version}", wheel_name=wheel_filename)
+    return wheel_url, wheel_filename
+
+
 class CachedWheelsCommand(_bdist_wheel):
     """
     The CachedWheelsCommand plugs into the default bdist wheel, which is ran by pip when it cannot
...
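As an illustration of the wheel-naming scheme that get_wheel_url() encodes, here is the filename it would produce for a hypothetical environment (all version values below are made up for the example):

    # Hypothetical example of the wheel filename format; every value here is
    # illustrative, not taken from this commit.
    PACKAGE_NAME = "flash_attn"
    flash_version = "2.1.1"        # get_package_version()
    cuda_version = "118"           # torch built against CUDA 11.8
    torch_version = "2.0"          # torch major.minor
    cxx11_abi = "FALSE"            # str(torch._C._GLIBCXX_USE_CXX11_ABI).upper()
    python_version = "cp310"       # CPython 3.10
    platform_name = "linux_x86_64"

    wheel_filename = (
        f"{PACKAGE_NAME}-{flash_version}+cu{cuda_version}torch{torch_version}"
        f"cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl"
    )
    print(wheel_filename)
    # flash_attn-2.1.1+cu118torch2.0cxx11abiFALSE-cp310-cp310-linux_x86_64.whl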
@@ -242,27 +251,8 @@ class CachedWheelsCommand(_bdist_wheel):
         if FORCE_BUILD:
             return super().run()
 
-        # Determine the version numbers that will be used to determine the correct wheel
-        # We're using the CUDA version used to build torch, not the one currently installed
-        # _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME)
-        torch_cuda_version = parse(torch.version.cuda)
-        torch_version_raw = parse(torch.__version__)
-        python_version = f"cp{sys.version_info.major}{sys.version_info.minor}"
-        platform_name = get_platform()
-        flash_version = get_package_version()
-        # cuda_version = f"{cuda_version_raw.major}{cuda_version_raw.minor}"
-        cuda_version = f"{torch_cuda_version.major}{torch_cuda_version.minor}"
-        torch_version = f"{torch_version_raw.major}.{torch_version_raw.minor}"
-        cxx11_abi = str(torch._C._GLIBCXX_USE_CXX11_ABI).upper()
-
-        # Determine wheel URL based on CUDA version, torch version, python version and OS
-        wheel_filename = f'{PACKAGE_NAME}-{flash_version}+cu{cuda_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl'
-        wheel_url = BASE_WHEEL_URL.format(tag_name=f"v{flash_version}", wheel_name=wheel_filename)
+        wheel_url, wheel_filename = get_wheel_url()
         print("Guessing wheel URL: ", wheel_url)
-
         try:
             urllib.request.urlretrieve(wheel_url, wheel_filename)
...
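The hunk ends just after the urlretrieve call, so the failure path is not shown here; presumably run() falls back to building from source when the guessed wheel cannot be downloaded. A hedged sketch of that download-or-build pattern:

    # Hedged sketch of the download-or-build pattern around the code above; the
    # exception handling and fallback are assumptions, since the hunk is truncated
    # before the except clause.
    import urllib.error
    import urllib.request

    def fetch_or_build(wheel_url: str, wheel_filename: str, build_from_source) -> None:
        print("Guessing wheel URL: ", wheel_url)
        try:
            urllib.request.urlretrieve(wheel_url, wheel_filename)
            # A real implementation would then move/rename the downloaded wheel
            # into the dist directory that bdist_wheel expects.
        except (urllib.error.HTTPError, urllib.error.URLError):
            print("Precompiled wheel not found. Building from source...")
            build_from_source()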
tests/test_flash_attn.py (view file @ 0c04943f)

@@ -12,7 +12,7 @@ from flash_attn import (
     flash_attn_varlen_kvpacked_func,
     flash_attn_varlen_qkvpacked_func,
 )
-from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input
+from flash_attn.bert_padding import pad_input, unpad_input
 from flash_attn.flash_attn_interface import _get_block_size
 
 MAX_HEADDIM_SM8x = 192
...
@@ -1376,7 +1376,7 @@ def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, dtype):
 # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
 # @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
 # @pytest.mark.parametrize('d', [56, 80])
-# @pytest.mark.parametrize("d", [128])
+# @pytest.mark.parametrize("d", [64])
 @pytest.mark.parametrize("swap_sq_sk", [False, True])
 # @pytest.mark.parametrize("swap_sq_sk", [False])
 @pytest.mark.parametrize(
...
@@ -1384,6 +1384,7 @@ def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, dtype):
     [
         (3, 1024),
         (1, 339),
+        (64, 800),
         (3, 799),
         (64, 2048),
         (16, 20000),
...
@@ -1394,11 +1395,6 @@ def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, dtype):
 )
 # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
 def test_flash_attn_splitkv(seqlen_q, seqlen_k, swap_sq_sk, d, causal, dtype):
-    if (
-        max(seqlen_q, seqlen_k) >= 2048
-        and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30
-    ):
-        pytest.skip()  # Reference implementation OOM
     if swap_sq_sk:
         seqlen_q, seqlen_k = seqlen_k, seqlen_q
     device = "cuda"
...