gaoqiong / flash-attention · Commits

Commit 1b18f1b7, authored Mar 15, 2023 by Tri Dao

Support H100

Parent: 318e2f1b
Showing 4 changed files with 50 additions and 38 deletions.
  README.md                                +9   -7
  csrc/flash_attn/fmha_api.cpp             +14  -8
  csrc/flash_attn/src/fmha_bwd_hdim64.cu   +2   -2
  setup.py                                 +25  -21
README.md

@@ -62,9 +62,10 @@ PYTHONPATH=$PWD python benchmarks/benchmark_flash_attention.py
 ```
 FlashAttention currently supports:
-1. Turing or Ampere GPUs (e.g., A100, RTX 3090, T4, RTX 2080).
-2. fp16 and bf16 (bf16 requires Ampere GPUs).
-3. Head dimensions that are multiples of 8, up to 128 (e.g., 8, 16, 24, ..., 128). Head dim > 64 backward requires A100.
+1. Turing, Ampere, Ada, or Hopper GPUs (e.g., H100, A100, RTX 3090, T4, RTX 2080).
+2. fp16 and bf16 (bf16 requires Ampere, Ada, or Hopper GPUs).
+3. Head dimensions that are multiples of 8, up to 128 (e.g., 8, 16, 24, ...,
+128). Head dim > 64 backward requires A100 or H100.
 
 Our tentative roadmap:
 1. ~~[Jun 2022] Make package pip-installable~~[Done, thanks to lucidrains].

@@ -74,10 +75,11 @@ Our tentative roadmap:
 5. ~~[Jul 2022] Implement cross-attention~~[Done].
 6. ~~[Jul 2022] Support head dimension 128~~[Done].
 7. ~~[Aug 2022] Fuse rotary embedding~~[Done].
-8. [Apr 2023] Refactor to use Cutlass 3.x.
-9. [May 2023] Support attention bias (e.g. ALiBi, relative positional encoding).
-10. [Jun 2023] Support SM70 GPUs (V100).
-11. [Jun 2023] Support SM90 GPUs (H100).
+8. ~~[Mar 2023] Support SM90 GPUs (H100)~~[Done].
+9. [Apr 2023] Refactor to use Cutlass 3.x.
+10. [May 2023] Support attention bias (e.g. ALiBi, relative positional encoding).
+11. [Jun 2023] Support SM70 GPUs (V100).
+12. [Jun 2023] Support fp8 (H100).
 
 ## How to use FlashAttention
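The updated support matrix can be checked at runtime before calling into the extension. Below is a minimal, hypothetical helper (not part of this commit) that mirrors the README's requirements using torch.cuda.get_device_capability:

# Hypothetical helper mirroring the README's support matrix; not part of this commit.
import torch

def flash_attn_supported(dtype: torch.dtype, head_dim: int, backward: bool = False) -> bool:
    major, minor = torch.cuda.get_device_capability()   # e.g. (9, 0) on H100
    sm = major * 10 + minor
    if sm not in (75, 80, 86, 89, 90):                   # Turing, Ampere, Ada, Hopper
        return False
    if dtype not in (torch.float16, torch.bfloat16):
        return False
    if dtype == torch.bfloat16 and sm < 80:              # bf16 needs Ampere or newer
        return False
    if head_dim % 8 != 0 or head_dim > 128:
        return False
    if backward and head_dim > 64 and sm not in (80, 90):  # hdim > 64 backward: A100/H100
        return False
    return True

print(flash_attn_supported(torch.bfloat16, head_dim=128, backward=True))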
csrc/flash_attn/fmha_api.cpp

@@ -207,13 +207,14 @@ mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
     bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
     bool is_sm80 = dprops->major == 8 && dprops->minor == 0;
     bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
-    TORCH_CHECK(is_sm8x || is_sm75);
+    bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
+    TORCH_CHECK(is_sm90 || is_sm8x || is_sm75);
     auto stream = at::cuda::getCurrentCUDAStream().stream();
     bool is_dropout = p_dropout > 0.0;
     Launch_params<FMHA_fprop_params> launch_params(dprops, stream, is_dropout, return_softmax);
 
     auto q_dtype = q.dtype();
-    TORCH_CHECK(q_dtype == torch::kFloat16 || (is_sm8x && q_dtype == torch::kBFloat16));
+    TORCH_CHECK(q_dtype == torch::kFloat16 || ((is_sm8x || is_sm90) && q_dtype == torch::kBFloat16));
     TORCH_CHECK(k.dtype() == q_dtype);
     TORCH_CHECK(v.dtype() == q_dtype);
     TORCH_CHECK(out.dtype() == q_dtype);

@@ -358,14 +359,15 @@ mha_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
     bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
     bool is_sm80 = dprops->major == 8 && dprops->minor == 0;
     bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
-    TORCH_CHECK(is_sm8x || is_sm75);
+    bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
+    TORCH_CHECK(is_sm90 || is_sm8x || is_sm75);
     auto launch = &run_fmha_bwd;
 
     bool is_dropout = p_dropout > 0.0;
     auto stream = at::cuda::getCurrentCUDAStream().stream();
 
     auto q_dtype = q.dtype();
-    TORCH_CHECK(q_dtype == torch::kFloat16 || (is_sm8x && q_dtype == torch::kBFloat16));
+    TORCH_CHECK(q_dtype == torch::kFloat16 || ((is_sm8x || is_sm90) && q_dtype == torch::kBFloat16));
     TORCH_CHECK(k.dtype() == q_dtype);
     TORCH_CHECK(v.dtype() == q_dtype);
     TORCH_CHECK(out.dtype() == q_dtype);

@@ -406,7 +408,7 @@ mha_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
     TORCH_CHECK(batch_size > 0);
     TORCH_CHECK((head_size % 8 == 0) && (head_size <= 128));
     if (head_size > 64) {  // TODO: eventually we should support SM86 and SM70 with d=128 as well
-        TORCH_CHECK(is_sm80);
+        TORCH_CHECK(is_sm80 || is_sm90);
     }
 
     CHECK_SHAPE(q, total_q, num_heads, head_size);

@@ -518,7 +520,10 @@ mha_fwd_block(const at::Tensor &q, // total_q x num_heads x head_size, t
               c10::optional<at::Generator> gen_) {
     auto dprops = at::cuda::getCurrentDeviceProperties();
-    TORCH_CHECK(dprops->major == 8 && dprops->minor >= 0);
     bool is_sm80 = dprops->major == 8 && dprops->minor == 0;
+    bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
+    bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
+    TORCH_CHECK(is_sm8x || is_sm90);
     auto stream = at::cuda::getCurrentCUDAStream().stream();
     bool is_dropout = p_dropout > 0.0;
     Launch_params<FMHA_fprop_params> launch_params(dprops, stream, is_dropout, return_softmax);

@@ -648,7 +653,8 @@ mha_bwd_block(const at::Tensor &dout, // total x num_heads, x head_size
     auto dprops = at::cuda::getCurrentDeviceProperties();
     bool is_sm80 = dprops->major == 8 && dprops->minor == 0;
     bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
-    TORCH_CHECK(dprops->major == 8 && dprops->minor >= 0);
+    bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
+    TORCH_CHECK(is_sm8x || is_sm90);
     auto launch = &run_fmha_block_dgrad_fp16_sm80;
 
     bool is_dropout = p_dropout > 0.0;

@@ -698,7 +704,7 @@ mha_bwd_block(const at::Tensor &dout, // total x num_heads, x head_size
     TORCH_CHECK(batch_size > 0);
     TORCH_CHECK(head_size == 16 || head_size == 32 || head_size == 64 || head_size == 128);
     if (head_size == 128) {  // TODO: eventually we should support SM86 and SM70 with d=128 as well
-        TORCH_CHECK(is_sm80);
+        TORCH_CHECK(is_sm80 || is_sm90);
     }
 
     CHECK_SHAPE(q, total_q, num_heads, head_size);
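Every hunk above follows the same pattern: derive is_sm75 / is_sm8x / is_sm90 flags from the device's compute capability and widen the TORCH_CHECK guards to also accept SM 9.0 (Hopper). For orientation only, a rough Python rendering of that gating logic (the extension performs these checks in C++; this snippet is not part of the commit):

# Illustrative Python mirror of the C++ capability checks above; not part of the commit.
import torch

dprops = torch.cuda.get_device_properties(torch.cuda.current_device())
is_sm75 = dprops.major == 7 and dprops.minor == 5   # Turing (T4, RTX 2080)
is_sm8x = dprops.major == 8                         # Ampere / Ada (A100, RTX 3090, ...)
is_sm90 = dprops.major == 9 and dprops.minor == 0   # Hopper (H100), newly accepted

assert is_sm90 or is_sm8x or is_sm75, "unsupported GPU architecture"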
csrc/flash_attn/src/fmha_bwd_hdim64.cu

@@ -11,10 +11,10 @@ void run_fmha_bwd_hdim64(FMHA_dgrad_params &params, cudaStream_t stream, const b
         using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u, elem_type>;
         run_fmha_bwd_loop<Kernel_traits>(params, stream, configure);
     } else if (params.seqlen_k >= 256) {
-        if (dprops->major == 8 && dprops->minor == 0) {
+        if ((dprops->major == 8 && dprops->minor == 0) || (dprops->major == 9 && dprops->minor == 0)) {
             // Don't share smem for K & V, and don't keep V in registers
             // This speeds things up by 2-3% by avoiding register spills, but it
-            // uses more shared memory, which is fine on A100 but not other GPUs.
+            // uses more shared memory, which is fine on A100 and H100 but not other GPUs.
             // For other GPUs, we keep V in registers.
             using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 8, 0x100u, elem_type>;
             run_fmha_bwd_loop<Kernel_traits>(params, stream, configure);
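The guard above selects the backward kernel configuration that spends extra shared memory to avoid register spills, and the commit extends it from A100 (SM 8.0) to H100 (SM 9.0), which also has the shared-memory headroom for it. A toy mirror of the same selection rule (hypothetical, not part of the CUDA source):

# Hypothetical mirror of the guard above; illustration only, not part of the CUDA source.
import torch

major, minor = torch.cuda.get_device_capability()
# Only A100 (SM 8.0) and H100 (SM 9.0) take the larger-shared-memory kernel config.
use_large_smem_config = (major, minor) in ((8, 0), (9, 0))
print("large-smem backward config:", use_large_smem_config)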
setup.py

@@ -3,6 +3,7 @@ import sys
 import warnings
 import os
 from pathlib import Path
+from packaging.version import parse, Version
 from setuptools import setup, find_packages
 import subprocess

@@ -23,22 +24,19 @@ def get_cuda_bare_metal_version(cuda_dir):
     raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
     output = raw_output.split()
     release_idx = output.index("release") + 1
-    release = output[release_idx].split(".")
-    bare_metal_major = release[0]
-    bare_metal_minor = release[1][0]
-    return raw_output, bare_metal_major, bare_metal_minor
+    bare_metal_version = parse(output[release_idx].split(",")[0])
+    return raw_output, bare_metal_version
 
 
 def check_cuda_torch_binary_vs_bare_metal(cuda_dir):
-    raw_output, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(cuda_dir)
-    torch_binary_major = torch.version.cuda.split(".")[0]
-    torch_binary_minor = torch.version.cuda.split(".")[1]
+    raw_output, bare_metal_version = get_cuda_bare_metal_version(cuda_dir)
+    torch_binary_version = parse(torch.version.cuda)
 
     print("\nCompiling cuda extensions with")
     print(raw_output + "from " + cuda_dir + "/bin\n")
 
-    if (bare_metal_major != torch_binary_major) or (bare_metal_minor != torch_binary_minor):
+    if (bare_metal_version != torch_binary_version):
         raise RuntimeError(
             "Cuda extensions are being compiled with a version of Cuda that does "
             "not match the version used to compile Pytorch binaries. "
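Switching from string major/minor pairs to packaging.version objects makes the comparisons numeric and robust: the old int(major)/int(minor) style mishandles releases such as CUDA 12.0, whose minor component fails a `minor >= 2` test. A small standalone sketch of the difference (illustrative values only, independent of setup.py):

# Illustrative only: why comparing parsed versions beats comparing major/minor parts.
from packaging.version import parse, Version

for raw in ["11.2", "11.8", "12.0"]:
    v = parse(raw)
    major, minor = (int(x) for x in raw.split("."))
    old_style = major >= 11 and minor >= 2   # wrongly rejects CUDA 12.0
    new_style = v >= Version("11.2")         # correct for all three inputs
    print(raw, old_style, new_style)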
@@ -60,8 +58,8 @@ def raise_if_cuda_home_none(global_option: str) -> None:
 def append_nvcc_threads(nvcc_extra_args):
-    _, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
-    if int(bare_metal_major) >= 11 and int(bare_metal_minor) >= 2:
+    _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
+    if bare_metal_version >= Version("11.2"):
         return nvcc_extra_args + ["--threads", "4"]
     return nvcc_extra_args

@@ -73,20 +71,23 @@ if not torch.cuda.is_available():
     print(
         "\nWarning: Torch did not find available GPUs on this system.\n",
         "If your intention is to cross-compile, this is not an error.\n"
-        "By default, We cross-compile for Volta (compute capability 7.0), "
-        "Turing (compute capability 7.5),\n"
+        "By default, Apex will cross-compile for Pascal (compute capabilities 6.0, 6.1, 6.2),\n"
+        "Volta (compute capability 7.0), Turing (compute capability 7.5),\n"
         "and, if the CUDA version is >= 11.0, Ampere (compute capability 8.0).\n"
         "If you wish to cross-compile for a single specific architecture,\n"
         'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n',
     )
-    if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None:
-        _, bare_metal_major, bare_metal_minor = get_cuda_bare_metal_version(CUDA_HOME)
-        if int(bare_metal_major) == 11:
-            os.environ["TORCH_CUDA_ARCH_LIST"] = "7.0;7.5;8.0"
-            if int(bare_metal_minor) > 0:
-                os.environ["TORCH_CUDA_ARCH_LIST"] = "7.0;7.5;8.0;8.6"
+    if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None and CUDA_HOME is not None:
+        _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
+        if bare_metal_version >= Version("11.8"):
+            os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6;9.0"
+        elif bare_metal_version >= Version("11.1"):
+            os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6"
+        elif bare_metal_version == Version("11.0"):
+            os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0"
         else:
-            os.environ["TORCH_CUDA_ARCH_LIST"] = "7.0;7.5"
+            os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5"
 
 print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__))
 TORCH_MAJOR = int(torch.__version__.split(".")[0])

@@ -105,13 +106,16 @@ if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.
 raise_if_cuda_home_none("flash_attn")
 # Check, if CUDA11 is installed for compute capability 8.0
 cc_flag = []
-_, bare_metal_major, _ = get_cuda_bare_metal_version(CUDA_HOME)
-if int(bare_metal_major) < 11:
+_, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
+if bare_metal_version < Version("11.0"):
     raise RuntimeError("FlashAttention is only supported on CUDA 11")
 cc_flag.append("-gencode")
 cc_flag.append("arch=compute_75,code=sm_75")
 cc_flag.append("-gencode")
 cc_flag.append("arch=compute_80,code=sm_80")
+if bare_metal_version >= Version("11.8"):
+    cc_flag.append("-gencode")
+    cc_flag.append("arch=compute_90,code=sm_90")
 
 subprocess.run(["git", "submodule", "update", "--init", "csrc/flash_attn/cutlass"])
 
 ext_modules.append(
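Taken together, the setup.py changes mean that on a CUDA 11.8+ toolchain the extension is additionally compiled for SM 9.0. A quick standalone way to sanity-check the resulting gencode flags, reusing the version-parsing logic from the hunks above (the CUDA_HOME fallback path is an assumption; adjust for your system):

# Standalone sketch: reproduce the cc_flag construction for a given nvcc without importing setup.py.
import os
import subprocess
from packaging.version import parse, Version

cuda_home = os.environ.get("CUDA_HOME", "/usr/local/cuda")  # assumed default location
raw = subprocess.check_output([cuda_home + "/bin/nvcc", "-V"], universal_newlines=True)
tokens = raw.split()
bare_metal_version = parse(tokens[tokens.index("release") + 1].split(",")[0])

cc_flag = ["-gencode", "arch=compute_75,code=sm_75",
           "-gencode", "arch=compute_80,code=sm_80"]
if bare_metal_version >= Version("11.8"):   # Hopper (sm_90) needs CUDA 11.8 or newer
    cc_flag += ["-gencode", "arch=compute_90,code=sm_90"]

print(bare_metal_version, cc_flag)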