gaoqiong / flash-attention · Commits · 9dbc491a

Commit 9dbc491a, authored May 26, 2022 by Tri Dao
Rename, add benchmarking script
Parent: 1fcbe6f0

Changes: 34
Showing 14 changed files with 42 additions and 48 deletions (+42 / -48)
Files changed on this page:

  csrc/flash_attn/src/fmha_dgrad_fp16_512_64_kernel.sm80.cu  (+0 -0)
  csrc/flash_attn/src/fmha_dgrad_fp16_kernel_loop.sm80.cu  (+0 -0)
  csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h  (+0 -0)
  csrc/flash_attn/src/fmha_dgrad_kernel_1xN_reload_recompute.h  (+0 -0)
  csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu  (+0 -0)
  csrc/flash_attn/src/fmha_fprop_kernel_1xN.h  (+0 -0)
  csrc/flash_attn/src/fmha_kernel.h  (+0 -0)
  csrc/flash_attn/src/fmha_utils.h  (+0 -0)
  csrc/flash_attn/src/philox.cuh  (+0 -0)
  csrc/stream_attn/README.md  (+0 -6)
  flash_attention.py  (+7 -7)
  flash_attn_interface.py  (+13 -13)
  flash_blocksparse_attention.py  (+9 -9)
  flash_blocksparse_attn_interface.py  (+13 -13)
Files moved (renamed without content changes):

  csrc/stream_attn/src/fmha_dgrad_fp16_512_64_kernel.sm80.cu → csrc/flash_attn/src/fmha_dgrad_fp16_512_64_kernel.sm80.cu
  csrc/stream_attn/src/fmha_dgrad_fp16_kernel_loop.sm80.cu → csrc/flash_attn/src/fmha_dgrad_fp16_kernel_loop.sm80.cu
  csrc/stream_attn/src/fmha_dgrad_kernel_1xN_loop.h → csrc/flash_attn/src/fmha_dgrad_kernel_1xN_loop.h
  csrc/stream_attn/src/fmha_dgrad_kernel_1xN_reload_recompute.h → csrc/flash_attn/src/fmha_dgrad_kernel_1xN_reload_recompute.h
  csrc/stream_attn/src/fmha_fprop_fp16_kernel.sm80.cu → csrc/flash_attn/src/fmha_fprop_fp16_kernel.sm80.cu
  csrc/stream_attn/src/fmha_fprop_kernel_1xN.h → csrc/flash_attn/src/fmha_fprop_kernel_1xN.h
  csrc/stream_attn/src/fmha_kernel.h → csrc/flash_attn/src/fmha_kernel.h
  csrc/stream_attn/src/fmha_utils.h → csrc/flash_attn/src/fmha_utils.h
  csrc/stream_attn/src/philox.cuh → csrc/flash_attn/src/philox.cuh
csrc/stream_attn/README.md  (deleted, 100644 → 0, -6 lines)

  Our implementation uses Apex's [FMHA](https://github.com/NVIDIA/apex/tree/master/apex/contrib/csrc/fmha) code as a starting point.
  We thank [Young-jun Ko](https://yjk21.github.io/) for the in-depth explanation of his FMHA implementation and for his thoughtful answers to our questions about CUDA.
streaming_attention.py → flash_attention.py  (renamed)

@@ -5,11 +5,11 @@ import torch.nn as nn
 from einops import rearrange
 from rotary import RotaryEmbedding, RotaryEmbedding2D
-from stream_attn_interface import stream_attn_func
+from flash_attn_interface import flash_attn_func
 from bert_padding import unpad_input, pad_input, index_first_axis

-class StreamingAttention(nn.Module):
+class FlashAttention(nn.Module):
     """Implement the scaled dot product attention with softmax.
     Arguments
     ---------
@@ -49,7 +49,7 @@ class StreamingAttention(nn.Module):
             max_s = seqlen
             cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
                                       device=qkv.device)
-            output = stream_attn_func(qkv, cu_seqlens, self.dropout_p if self.training else 0.0,
+            output = flash_attn_func(qkv, cu_seqlens, self.dropout_p if self.training else 0.0,
                                       max_s, softmax_scale=self.softmax_temp, causal=causal)
             output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
         else:
@@ -58,7 +58,7 @@ class StreamingAttention(nn.Module):
                 x = rearrange(qkv, 'b s three h d -> b s (three h d)')
                 x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask_bool)
                 x_unpad = rearrange(x_unpad, 'nnz (three h d) -> nnz three h d', three=3, h=nheads)
-                output_unpad = stream_attn_func(x_unpad, cu_seqlens,
+                output_unpad = flash_attn_func(x_unpad, cu_seqlens,
                                                 self.dropout_p if self.training else 0.0,
                                                 max_s, softmax_scale=self.softmax_temp, causal=causal)
                 output = rearrange(pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'),
@@ -66,14 +66,14 @@ class StreamingAttention(nn.Module):
                                    'b s (h d) -> b s h d', h=nheads)
             else:
                 assert max_s is not None
-                output = stream_attn_func(qkv, cu_seqlens,
+                output = flash_attn_func(qkv, cu_seqlens,
                                           self.dropout_p if self.training else 0.0,
                                           max_s, softmax_scale=self.softmax_temp, causal=causal)
         return output, None

-class StreamingMHA(nn.Module):
+class FlashMHA(nn.Module):
     def __init__(self, embed_dim, num_heads, bias=True, batch_first=True, attention_dropout=0.0,
                  causal=False, use_rotary_emb=None, device=None, dtype=None, **kwargs) -> None:
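The hunks above only swap the `stream_` names for `flash_` ones; the surrounding context lines already show how `FlashAttention.forward` handles a padded batch: flatten the qkv projection, strip padding with `unpad_input`, run `flash_attn_func` on the packed tokens, then re-pad. Below is a minimal standalone sketch of that path. Assumptions: the `bert_padding` helpers and the compiled `flash_attn_cuda` extension from this repo are importable, tensors are fp16 on GPU (matching the fp16 kernels in csrc/), and the sizes are purely illustrative.

```python
import torch
from einops import rearrange
from bert_padding import unpad_input, pad_input
from flash_attn_interface import flash_attn_func

batch_size, seqlen, nheads, head_dim = 2, 512, 16, 64   # illustrative sizes
qkv = torch.randn(batch_size, seqlen, 3, nheads, head_dim,
                  device='cuda', dtype=torch.float16)
# Boolean mask marking real tokens (True) vs padding (False).
key_padding_mask_bool = torch.ones(batch_size, seqlen, dtype=torch.bool, device='cuda')

# Flatten, drop padded tokens, and repack to (nnz, 3, nheads, head_dim), as in the diff above.
x = rearrange(qkv, 'b s three h d -> b s (three h d)')
x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask_bool)
x_unpad = rearrange(x_unpad, 'nnz (three h d) -> nnz three h d', three=3, h=nheads)

# dropout_p is 0.0 here (evaluation); softmax_scale=None falls back to head_dim ** -0.5.
output_unpad = flash_attn_func(x_unpad, cu_seqlens, 0.0, max_s,
                               softmax_scale=None, causal=False)
# pad_input(...) then restores the (batch, seqlen, ...) layout; its remaining arguments
# are elided in the hunk above, so they are not reproduced here.
```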
@@ -96,7 +96,7 @@ class StreamingMHA(nn.Module):
                 self.rotary_emb = RotaryEmbedding2D(self.head_dim)
         self.Wqkv = nn.Linear(embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs)
-        self.inner_attn = StreamingAttention(attention_dropout=attention_dropout, **factory_kwargs)
+        self.inner_attn = FlashAttention(attention_dropout=attention_dropout, **factory_kwargs)
         self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs)

     def forward(self, x, x_ignored_, x_ignored_1_, attn_mask=None, key_padding_mask=None,
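Since this module is the user-facing entry point, a short usage sketch of the renamed wrapper may help. Assumptions: the CUDA extension is built, tensors are fp16 on GPU, and the return value of `FlashMHA.forward` mirrors the inner module's `(output, attn_weights)` pair, which is not shown on this page.

```python
import torch
from flash_attention import FlashMHA   # renamed from StreamingMHA in this commit

batch, seqlen, embed_dim, num_heads = 2, 512, 1024, 16   # illustrative sizes
mha = FlashMHA(embed_dim, num_heads, attention_dropout=0.1, causal=True,
               device='cuda', dtype=torch.float16)

x = torch.randn(batch, seqlen, embed_dim, device='cuda', dtype=torch.float16)
# forward takes two extra positional arguments that it ignores (x_ignored_, x_ignored_1_),
# matching the signature in the hunk above; passing x for both keeps the call simple.
out = mha(x, x, x, key_padding_mask=None)
```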
stream_attn_interface.py → flash_attn_interface.py  (renamed)

@@ -2,11 +2,11 @@
 import torch
 import torch.nn as nn

-import stream_attn_cuda
+import flash_attn_cuda

-def _stream_attn_forward(qkv, cu_seqlens, dropout_p, max_s, softmax_scale, causal, return_softmax):
-    context, softmax_lse, *rest = stream_attn_cuda.fwd(qkv, cu_seqlens, dropout_p, max_s, softmax_scale,
+def _flash_attn_forward(qkv, cu_seqlens, dropout_p, max_s, softmax_scale, causal, return_softmax):
+    context, softmax_lse, *rest = flash_attn_cuda.fwd(qkv, cu_seqlens, dropout_p, max_s, softmax_scale,
                                                        False, causal, return_softmax, None)
     # if context.isnan().any() or softmax_lse.isnan().any():
     #     breakpoint()
@@ -14,16 +14,16 @@ def _stream_attn_forward(qkv, cu_seqlens, dropout_p, max_s, softmax_scale, causa
     return context, softmax_lse, S_dmask

-def _stream_attn_backward(dout, qkv, out, S_dmask, softmax_lse, cu_seqlens, dropout_p, max_s,
+def _flash_attn_backward(dout, qkv, out, S_dmask, softmax_lse, cu_seqlens, dropout_p, max_s,
                           softmax_scale, causal):
-    dqkv, dp, softmax_d = stream_attn_cuda.bwd(dout, qkv, out, S_dmask, softmax_lse, cu_seqlens, dropout_p,
+    dqkv, dp, softmax_d = flash_attn_cuda.bwd(dout, qkv, out, S_dmask, softmax_lse, cu_seqlens, dropout_p,
                                                softmax_scale, max_s, False, causal, None)
     # if dqkv.isnan().any() or softmax_d.isnan().any():
     #     breakpoint()
     return dqkv

-class StreamAttnFun(torch.autograd.Function):
+class FlashAttnFun(torch.autograd.Function):
     @staticmethod
     def forward(ctx, qkv, cu_seqlens, dropout_p, max_s, softmax_scale, causal):
@@ -31,7 +31,7 @@ class StreamAttnFun(torch.autograd.Function):
         rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
         if softmax_scale is None:
             softmax_scale = qkv.shape[-1] ** (-0.5)
-        context, softmax_lse, S_dmask = _stream_attn_forward(
+        context, softmax_lse, S_dmask = _flash_attn_forward(
             qkv, cu_seqlens, dropout_p, max_s, softmax_scale, causal=causal, return_softmax=False
         )
         ctx.save_for_backward(qkv, context, S_dmask, softmax_lse, cu_seqlens, rng_state)
@@ -48,7 +48,7 @@ class StreamAttnFun(torch.autograd.Function):
             cur_rng_state = torch.cuda.get_rng_state()
             torch.cuda.set_rng_state(rng_state)
         # S_dmask is None, temporarily use another tensor just to get it running
-        dqkv = _stream_attn_backward(
+        dqkv = _flash_attn_backward(
             dout, qkv, context, context, softmax_lse, cu_seqlens, ctx.dropout_p,
             ctx.max_s, ctx.softmax_scale, ctx.causal
         )
@@ -59,7 +59,7 @@ class StreamAttnFun(torch.autograd.Function):
 # We duplicate code to return both the output and the softmax for testing
 # Returning both makes backward a bit slower, so we want to keep using the other version for speed.
-class StreamAttnFunWithS(torch.autograd.Function):
+class FlashAttnFunWithS(torch.autograd.Function):
     @staticmethod
     def forward(ctx, qkv, cu_seqlens, dropout_p, max_s, softmax_scale, causal):
@@ -67,7 +67,7 @@ class StreamAttnFunWithS(torch.autograd.Function):
         rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
         if softmax_scale is None:
             softmax_scale = qkv.shape[-1] ** (-0.5)
-        context, softmax_lse, S_dmask = _stream_attn_forward(
+        context, softmax_lse, S_dmask = _flash_attn_forward(
             qkv, cu_seqlens, dropout_p, max_s, softmax_scale, causal=causal, return_softmax=True
         )
         ctx.save_for_backward(qkv, context, S_dmask, softmax_lse, cu_seqlens, rng_state)
@@ -83,7 +83,7 @@ class StreamAttnFunWithS(torch.autograd.Function):
         if rng_state is not None:
             cur_rng_state = torch.cuda.get_rng_state()
             torch.cuda.set_rng_state(rng_state)
-        dqkv = _stream_attn_backward(
+        dqkv = _flash_attn_backward(
             dout, qkv, context, S_dmask, softmax_lse, cu_seqlens, ctx.dropout_p,
             ctx.max_s, ctx.softmax_scale, ctx.causal
         )
@@ -92,9 +92,9 @@ class StreamAttnFunWithS(torch.autograd.Function):
         return dqkv, None, None, None, None, None

-def stream_attn_func(qkv, cu_seqlens, dropout_p, max_s, softmax_scale=None, causal=False,
+def flash_attn_func(qkv, cu_seqlens, dropout_p, max_s, softmax_scale=None, causal=False,
                     return_attn_probs=False):
     """dropout_p should be set to 0.0 during evaluation
     """
-    func = StreamAttnFun if not return_attn_probs else StreamAttnFunWithS
+    func = FlashAttnFun if not return_attn_probs else FlashAttnFunWithS
     return func.apply(qkv, cu_seqlens, dropout_p, max_s, softmax_scale, causal)
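Beyond the renames, the interface keeps the same semantics: `flash_attn_func` dispatches to `FlashAttnFun`, or to `FlashAttnFunWithS` when `return_attn_probs=True`, and `softmax_scale=None` defaults to `head_dim ** -0.5`. A minimal direct call is sketched below, assuming the compiled `flash_attn_cuda` extension, packed fp16 qkv of shape (total_tokens, 3, nheads, head_dim), and the `cu_seqlens` construction shown in flash_attention.py; the return value is treated as the attention output tensor, as in `FlashAttention.forward`.

```python
import torch
from flash_attn_interface import flash_attn_func

batch_size, seqlen, nheads, head_dim = 2, 512, 16, 64   # illustrative sizes
qkv = torch.randn(batch_size * seqlen, 3, nheads, head_dim,
                  device='cuda', dtype=torch.float16, requires_grad=True)
# Cumulative sequence lengths for equal-length sequences, as in FlashAttention.forward.
cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen,
                          dtype=torch.int32, device=qkv.device)

out = flash_attn_func(qkv, cu_seqlens, 0.1, seqlen,       # dropout_p=0.1, max_s=seqlen
                      softmax_scale=None, causal=False,   # None -> head_dim ** -0.5
                      return_attn_probs=False)            # False -> FlashAttnFun path
out.sum().backward()   # gradients come back through _flash_attn_backward
```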
streaming_blocksparse_attention.py → flash_blocksparse_attention.py  (renamed)

@@ -6,11 +6,11 @@ from einops import rearrange
 import hydra

-from stream_blocksparse_attn_interface import stream_blocksparse_attn_func
-from stream_blocksparse_attn_interface import convert_blockmask
+from flash_blocksparse_attn_interface import flash_blocksparse_attn_func
+from flash_blocksparse_attn_interface import convert_blockmask
 from bert_padding import unpad_input, pad_input, index_first_axis

-class StreamingBlocksparseAttention(nn.Module):
+class FlashBlocksparseAttention(nn.Module):
     """Implement the scaled dot product attention with softmax.
     Arguments
     ---------
@@ -63,7 +63,7 @@ class StreamingBlocksparseAttention(nn.Module):
             max_s = seqlen
             cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
                                       device=qkv.device)
-            output = stream_blocksparse_attn_func(
+            output = flash_blocksparse_attn_func(
                 qkv, cu_seqlens, blockmask, self.dropout_p if self.training else 0.0,
                 max_s, softmax_scale=self.softmax_temp, causal=causal
             )
@@ -74,7 +74,7 @@ class StreamingBlocksparseAttention(nn.Module):
                 x = rearrange(qkv, 'b s three h d -> b s (three h d)')
                 x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask_bool)
                 x_unpad = rearrange(x_unpad, 'nnz (three h d) -> nnz three h d', three=3, h=nheads)
-                output_unpad = stream_blocksparse_attn_func(
+                output_unpad = flash_blocksparse_attn_func(
                     x_unpad, cu_seqlens, blockmask, self.dropout_p if self.training else 0.0,
                     max_s, softmax_scale=self.softmax_temp, causal=causal
                 )
@@ -89,12 +89,12 @@ class StreamingBlocksparseAttention(nn.Module):
             assert seqlen_rounded // 16 <= self.layout.shape[0], seqlen_rounded // 256 <= self.layout.shape[1]
             blockmask = self.layout[:seqlen_rounded // 16, :seqlen_rounded // 256]
             if convert_mask:
-                output = stream_blocksparse_attn_func(
+                output = flash_blocksparse_attn_func(
                     qkv, cu_seqlens, blockmask, self.dropout_p if self.training else 0.0,
                     max_s, softmax_scale=self.softmax_temp, causal=causal
                 )
             else:
-                output = stream_blocksparse_attn_func(
+                output = flash_blocksparse_attn_func(
                     qkv, cu_seqlens, self.blockmask_converted, self.dropout_p if self.training else 0.0,
                     max_s, softmax_scale=self.softmax_temp, causal=causal,
                     convert_mask=False,
@@ -103,7 +103,7 @@ class StreamingBlocksparseAttention(nn.Module):
         return output, None

-class StreamingBlocksparseMHA(nn.Module):
+class FlashBlocksparseMHA(nn.Module):
     def __init__(self, embed_dim, num_heads, sparsity_config, bias=True, batch_first=True,
                  attention_dropout=0.0, causal=False, max_seq_length=2048,
@@ -120,7 +120,7 @@ class StreamingBlocksparseMHA(nn.Module):
         assert self.head_dim in [16, 32, 64], "Only support head_dim == 16, 32, or 64"
         self.Wqkv = nn.Linear(embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs)
-        self.inner_attn = StreamingBlocksparseAttention(
+        self.inner_attn = FlashBlocksparseAttention(
             sparsity_config, attention_dropout=attention_dropout,
             max_seq_length=max_seq_length, **factory_kwargs
         )
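The block-sparse module slices a per-layer `layout` down to the current sequence length, with one entry per block of 16 query rows by 256 key columns, before handing it to `flash_blocksparse_attn_func`. A small sketch of that geometry follows; how `seqlen_rounded` is computed is not shown in these hunks, so rounding up to a multiple of 256 is an assumption, as is the boolean dtype of the layout.

```python
import math
import torch

max_seq_length = 2048
# One layout entry per (16 query rows) x (256 key columns) block; random and boolean
# here purely for illustration.
layout = torch.rand(max_seq_length // 16, max_seq_length // 256) > 0.5

seqlen = 1000
seqlen_rounded = math.ceil(seqlen / 256) * 256          # assumed rounding rule
assert seqlen_rounded // 16 <= layout.shape[0]
assert seqlen_rounded // 256 <= layout.shape[1]

# Slice the layout exactly as FlashBlocksparseAttention.forward does above.
blockmask = layout[:seqlen_rounded // 16, :seqlen_rounded // 256]
print(blockmask.shape)   # torch.Size([64, 4])
```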
stream_blocksparse_attn_interface.py → flash_blocksparse_attn_interface.py  (renamed)

@@ -2,7 +2,7 @@
 import torch
 import torch.nn as nn

-import stream_attn_cuda
+import flash_attn_cuda

 def convert_blockmask(blockmask, causal):
@@ -40,9 +40,9 @@ def convert_blockmask(blockmask, causal):
     return nonzero_idx.T.contiguous().to(dtype=torch.int32)

-def _stream_blocksparse_attn_forward(qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale,
+def _flash_blocksparse_attn_forward(qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale,
                                      causal, return_softmax):
-    context, softmax_lse, *rest = stream_attn_cuda.fwd_block(qkv, cu_seqlens, blockmask, dropout_p,
+    context, softmax_lse, *rest = flash_attn_cuda.fwd_block(qkv, cu_seqlens, blockmask, dropout_p,
                                                              max_s, softmax_scale, causal,
                                                              return_softmax, None)
     # if context.isnan().any() or softmax_lse.isnan().any():
@@ -51,9 +51,9 @@ def _stream_blocksparse_attn_forward(qkv, cu_seqlens, blockmask, dropout_p, max_
     return context, softmax_lse, S_dmask

-def _stream_blocksparse_attn_backward(dout, qkv, out, S_dmask, softmax_lse, cu_seqlens, blockmask,
+def _flash_blocksparse_attn_backward(dout, qkv, out, S_dmask, softmax_lse, cu_seqlens, blockmask,
                                       dropout_p, max_s, softmax_scale, causal):
-    dqkv, dp, softmax_d = stream_attn_cuda.bwd_block(dout, qkv, out, S_dmask, softmax_lse, cu_seqlens,
+    dqkv, dp, softmax_d = flash_attn_cuda.bwd_block(dout, qkv, out, S_dmask, softmax_lse, cu_seqlens,
                                                      blockmask, dropout_p, softmax_scale, max_s,
                                                      causal, None)
     # if dqkv.isnan().any() or softmax_d.isnan().any():
@@ -61,7 +61,7 @@ def _stream_blocksparse_attn_backward(dout, qkv, out, S_dmask, softmax_lse, cu_s
     return dqkv

-class StreamBlocksparseAttnFun(torch.autograd.Function):
+class FlashBlocksparseAttnFun(torch.autograd.Function):
     @staticmethod
     def forward(ctx, qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale, causal):
@@ -69,7 +69,7 @@ class StreamBlocksparseAttnFun(torch.autograd.Function):
         rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
         if softmax_scale is None:
             softmax_scale = qkv.shape[-1] ** (-0.5)
-        context, softmax_lse, S_dmask = _stream_blocksparse_attn_forward(
+        context, softmax_lse, S_dmask = _flash_blocksparse_attn_forward(
             qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale, causal=causal,
             return_softmax=False
         )
@@ -87,7 +87,7 @@ class StreamBlocksparseAttnFun(torch.autograd.Function):
             cur_rng_state = torch.cuda.get_rng_state()
             torch.cuda.set_rng_state(rng_state)
         # S_dmask is None, temporarily use another tensor just to get it running
-        dqkv = _stream_blocksparse_attn_backward(
+        dqkv = _flash_blocksparse_attn_backward(
             dout, qkv, context, context, softmax_lse, cu_seqlens, blockmask, ctx.dropout_p,
             ctx.max_s, ctx.softmax_scale, ctx.causal
         )
@@ -98,7 +98,7 @@ class StreamBlocksparseAttnFun(torch.autograd.Function):
 # We duplicate code to return both the output and the softmax for testing
 # Returning both makes backward a bit slower, so we want to keep using the other version for speed.
-class StreamBlocksparseAttnFunWithS(torch.autograd.Function):
+class FlashBlocksparseAttnFunWithS(torch.autograd.Function):
     @staticmethod
     def forward(ctx, qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale, causal):
@@ -106,7 +106,7 @@ class StreamBlocksparseAttnFunWithS(torch.autograd.Function):
         rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None
         if softmax_scale is None:
             softmax_scale = qkv.shape[-1] ** (-0.5)
-        context, softmax_lse, S_dmask = _stream_blocksparse_attn_forward(
+        context, softmax_lse, S_dmask = _flash_blocksparse_attn_forward(
             qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale, causal=causal,
             return_softmax=True
         )
@@ -123,7 +123,7 @@ class StreamBlocksparseAttnFunWithS(torch.autograd.Function):
         if rng_state is not None:
             cur_rng_state = torch.cuda.get_rng_state()
             torch.cuda.set_rng_state(rng_state)
-        dqkv = _stream_blocksparse_attn_backward(
+        dqkv = _flash_blocksparse_attn_backward(
             dout, qkv, context, S_dmask, softmax_lse, cu_seqlens, blockmask, ctx.dropout_p,
             ctx.max_s, ctx.softmax_scale, ctx.causal
         )
@@ -132,11 +132,11 @@ class StreamBlocksparseAttnFunWithS(torch.autograd.Function):
         return dqkv, None, None, None, None, None, None

-def stream_blocksparse_attn_func(qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale=None,
+def flash_blocksparse_attn_func(qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale=None,
                                 causal=False, return_attn_probs=False, convert_mask=True):
     """dropout_p should be set to 0.0 during evaluation
     """
-    func = StreamBlocksparseAttnFun if not return_attn_probs else StreamBlocksparseAttnFunWithS
+    func = FlashBlocksparseAttnFun if not return_attn_probs else FlashBlocksparseAttnFunWithS
     if convert_mask:
         blockmask = convert_blockmask(blockmask, causal=causal)
     return func.apply(qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale, causal)
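Putting the renamed pieces together: `flash_blocksparse_attn_func` converts the block mask with `convert_blockmask` (unless `convert_mask=False`, in which case a pre-converted mask is expected, as in the `blockmask_converted` branch above) and then dispatches to `FlashBlocksparseAttnFun`. A hedged end-to-end sketch, assuming the compiled `flash_attn_cuda` extension, fp16 CUDA tensors, and an illustrative random layout:

```python
import torch
from flash_blocksparse_attn_interface import flash_blocksparse_attn_func

batch_size, seqlen, nheads, head_dim = 2, 1024, 16, 64   # head_dim must be 16, 32, or 64
qkv = torch.randn(batch_size * seqlen, 3, nheads, head_dim,
                  device='cuda', dtype=torch.float16, requires_grad=True)
cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen,
                          dtype=torch.int32, device='cuda')
# Illustrative boolean block layout with the (seqlen // 16, seqlen // 256) shape used above.
blockmask = torch.rand(seqlen // 16, seqlen // 256, device='cuda') > 0.5

out = flash_blocksparse_attn_func(qkv, cu_seqlens, blockmask,
                                  0.1, seqlen,                      # dropout_p, max_s
                                  softmax_scale=None, causal=False,
                                  return_attn_probs=False,
                                  convert_mask=True)                # convert_blockmask runs internally
```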