gaoqiong / flash-attention · Commits · f1a73d07
"vscode:/vscode.git/clone" did not exist on "79292ff3e06219bc67d06932cee98a5ad2ce5c04"
Commit f1a73d07, authored Aug 18, 2023 by Tri Dao

Run isort and black on python files

Parent: cbb4cf5f
Changes: 34
Showing 20 changed files with 2219 additions and 1011 deletions (+2219 −1011)
flash_attn/__init__.py                            +8    −6
flash_attn/bert_padding.py                        +23   −18
flash_attn/flash_attn_interface.py                +336  −88
flash_attn/flash_attn_triton.py                   +533  −205
flash_attn/flash_attn_triton_og.py                +134  −45
flash_attn/flash_blocksparse_attention.py         +105  −44
flash_attn/flash_blocksparse_attn_interface.py    +84   −26
flash_attn/fused_softmax.py                       +10   −14
flash_attn/layers/patch_embed.py                  +30   −19
flash_attn/layers/rotary.py                       +171  −79
flash_attn/losses/cross_entropy.py                +46   −27
flash_attn/models/bert.py                         +244  −155
flash_attn/models/falcon.py                       +58   −37
flash_attn/models/gpt.py                          +20   −10
flash_attn/models/gpt_neox.py                     +52   −32
flash_attn/models/gptj.py                         +36   −25
flash_attn/models/llama.py                        +106  −70
flash_attn/models/opt.py                          +51   −37
flash_attn/models/vit.py                          +4    −3
flash_attn/modules/block.py                       +168  −71
flash_attn/__init__.py

 __version__ = "2.0.8"

-from flash_attn.flash_attn_interface import flash_attn_func
-from flash_attn.flash_attn_interface import flash_attn_kvpacked_func
-from flash_attn.flash_attn_interface import flash_attn_qkvpacked_func
-from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
-from flash_attn.flash_attn_interface import flash_attn_varlen_kvpacked_func
-from flash_attn.flash_attn_interface import flash_attn_varlen_func
+from flash_attn.flash_attn_interface import (
+    flash_attn_func,
+    flash_attn_kvpacked_func,
+    flash_attn_qkvpacked_func,
+    flash_attn_varlen_func,
+    flash_attn_varlen_kvpacked_func,
+    flash_attn_varlen_qkvpacked_func,
+)
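For orientation, a minimal usage sketch of the re-exported entry points. This is not part of the commit; the shapes, dtype, and device below are assumed example values.

import torch
from flash_attn import flash_attn_func

# flash_attn_func expects (batch, seqlen, nheads, headdim) tensors in fp16/bf16 on GPU.
q = torch.randn(2, 1024, 8, 64, dtype=torch.float16, device="cuda")
k = torch.randn(2, 1024, 8, 64, dtype=torch.float16, device="cuda")
v = torch.randn(2, 1024, 8, 64, dtype=torch.float16, device="cuda")
out = flash_attn_func(q, k, v, dropout_p=0.0, causal=True)  # same shape as q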
flash_attn/bert_padding.py

@@ -2,12 +2,10 @@
 import torch
 import torch.nn.functional as F
 from einops import rearrange, repeat

 class IndexFirstAxis(torch.autograd.Function):
     @staticmethod
     def forward(ctx, input, indices):
         ctx.save_for_backward(indices)

@@ -16,20 +14,24 @@ class IndexFirstAxis(torch.autograd.Function):
         second_dim = other_shape.numel()
         # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
         # return input[indices]
-        return torch.gather(rearrange(input, 'b ... -> b (...)'), 0,
-                            repeat(indices, 'z -> z d', d=second_dim)).reshape(-1, *other_shape)
+        return torch.gather(
+            rearrange(input, "b ... -> b (...)"), 0, repeat(indices, "z -> z d", d=second_dim)
+        ).reshape(-1, *other_shape)

     @staticmethod
     def backward(ctx, grad_output):
-        indices, = ctx.saved_tensors
+        (indices,) = ctx.saved_tensors
         assert grad_output.ndim >= 2
         other_shape = grad_output.shape[1:]
-        grad_output = rearrange(grad_output, 'b ... -> b (...)')
-        grad_input = torch.zeros([ctx.first_axis_dim, grad_output.shape[1]],
-                                 device=grad_output.device, dtype=grad_output.dtype)
+        grad_output = rearrange(grad_output, "b ... -> b (...)")
+        grad_input = torch.zeros(
+            [ctx.first_axis_dim, grad_output.shape[1]],
+            device=grad_output.device,
+            dtype=grad_output.dtype,
+        )
         # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
         # grad_input[indices] = grad_output
-        grad_input.scatter_(0, repeat(indices, 'z -> z d', d=grad_output.shape[1]), grad_output)
+        grad_input.scatter_(0, repeat(indices, "z -> z d", d=grad_output.shape[1]), grad_output)
         return grad_input.reshape(ctx.first_axis_dim, *other_shape), None

@@ -37,14 +39,14 @@ index_first_axis = IndexFirstAxis.apply
 class IndexPutFirstAxis(torch.autograd.Function):
     @staticmethod
     def forward(ctx, values, indices, first_axis_dim):
         ctx.save_for_backward(indices)
         assert indices.ndim == 1
         assert values.ndim >= 2
-        output = torch.zeros(first_axis_dim, *values.shape[1:], device=values.device,
-                             dtype=values.dtype)
+        output = torch.zeros(
+            first_axis_dim, *values.shape[1:], device=values.device, dtype=values.dtype
+        )
         # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
         output[indices] = values
         # output.scatter_(0, repeat(indices, 'z -> z d', d=values.shape[1]), values)

@@ -52,7 +54,7 @@ class IndexPutFirstAxis(torch.autograd.Function):
     @staticmethod
     def backward(ctx, grad_output):
-        indices, = ctx.saved_tensors
+        (indices,) = ctx.saved_tensors
         # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
         grad_values = grad_output[indices]
         # grad_values = torch.gather(grad_output, 0, repeat(indices, 'z -> z d', d=grad_output.shape[1]))

@@ -63,7 +65,6 @@ index_put_first_axis = IndexPutFirstAxis.apply
 class IndexFirstAxisResidual(torch.autograd.Function):
     @staticmethod
     def forward(ctx, input, indices):
         ctx.save_for_backward(indices)

@@ -79,7 +80,7 @@ class IndexFirstAxisResidual(torch.autograd.Function):
     @staticmethod
     def backward(ctx, grad_output, grad_residual):
-        indices, = ctx.saved_tensors
+        (indices,) = ctx.saved_tensors
         assert grad_output.ndim >= 2
         other_shape = grad_output.shape[1:]
         assert grad_residual.shape[1:] == other_shape

@@ -113,8 +114,12 @@ def unpad_input(hidden_states, attention_mask):
     # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
     # index with integer indices. Moreover, torch's index is a bit slower than it needs to be,
     # so we write custom forward and backward to make it a bit faster.
-    return (index_first_axis(rearrange(hidden_states, 'b s ... -> (b s) ...'), indices), indices,
-            cu_seqlens, max_seqlen_in_batch)
+    return (
+        index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices),
+        indices,
+        cu_seqlens,
+        max_seqlen_in_batch,
+    )


 def pad_input(hidden_states, indices, batch, seqlen):

@@ -129,4 +134,4 @@ def pad_input(hidden_states, indices, batch, seqlen):
     # output = torch.zeros((batch * seqlen), dim, device=hidden_states.device, dtype=hidden_states.dtype)
     # output[indices] = hidden_states
     output = index_put_first_axis(hidden_states, indices, batch * seqlen)
-    return rearrange(output, '(b s) ... -> b s ...', b=batch)
+    return rearrange(output, "(b s) ... -> b s ...", b=batch)
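The comments above explain why unpad_input/pad_input route through custom gather/scatter autograd functions. As a reference, a sketch of how the pair is typically round-tripped; only the function signatures come from the diff, the sizes and mask below are made-up example values.

import torch
from flash_attn.bert_padding import pad_input, unpad_input

batch, seqlen, dim = 2, 8, 16
hidden_states = torch.randn(batch, seqlen, dim)
attention_mask = torch.zeros(batch, seqlen, dtype=torch.bool)
attention_mask[0, :5] = True  # first sequence has 5 real tokens
attention_mask[1, :8] = True  # second sequence is full length

# Drop padded positions: (batch, seqlen, dim) -> (total_tokens, dim), plus bookkeeping tensors.
unpadded, indices, cu_seqlens, max_seqlen_in_batch = unpad_input(hidden_states, attention_mask)

# Scatter back into the padded layout: (total_tokens, dim) -> (batch, seqlen, dim).
repadded = pad_input(unpadded, indices, batch, seqlen)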
flash_attn/flash_attn_interface.py

-import flash_attn_2_cuda as flash_attn_cuda
 import torch
 import torch.nn as nn
+import flash_attn_2_cuda as flash_attn_cuda
 from einops import rearrange
@@ -45,40 +44,109 @@ def _flash_attn_forward(q, k, v, dropout_p, softmax_scale, causal, return_softma
     return out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state


-def _flash_attn_varlen_forward(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
-                               dropout_p, softmax_scale, causal, return_softmax):
+def _flash_attn_varlen_forward(
+    q,
+    k,
+    v,
+    cu_seqlens_q,
+    cu_seqlens_k,
+    max_seqlen_q,
+    max_seqlen_k,
+    dropout_p,
+    softmax_scale,
+    causal,
+    return_softmax,
+):
     maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
     q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
     out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.varlen_fwd(
-        q, k, v, None, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p,
-        softmax_scale, False, causal, return_softmax, None
+        q,
+        k,
+        v,
+        None,
+        cu_seqlens_q,
+        cu_seqlens_k,
+        max_seqlen_q,
+        max_seqlen_k,
+        dropout_p,
+        softmax_scale,
+        False,
+        causal,
+        return_softmax,
+        None,
     )
     # if out.isnan().any() or softmax_lse.isnan().any():
     #     breakpoint()
     return out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state


-def _flash_attn_backward(dout, q, k, v, out, softmax_lse, dq, dk, dv,
-                         dropout_p, softmax_scale, causal, rng_state=None):
+def _flash_attn_backward(
+    dout, q, k, v, out, softmax_lse, dq, dk, dv, dropout_p, softmax_scale, causal, rng_state=None
+):
     maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
     # dq, dk, dv are allocated by us so they should already be contiguous
     dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
     dq, dk, dv, softmax_d, = flash_attn_cuda.bwd(
-        dout, q, k, v, out, softmax_lse, dq, dk, dv, dropout_p,
-        softmax_scale, causal, None, rng_state
+        dout,
+        q,
+        k,
+        v,
+        out,
+        softmax_lse,
+        dq,
+        dk,
+        dv,
+        dropout_p,
+        softmax_scale,
+        causal,
+        None,
+        rng_state,
     )
     return dq, dk, dv, softmax_d


-def _flash_attn_varlen_backward(dout, q, k, v, out, softmax_lse, dq, dk, dv,
-                                cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
-                                dropout_p, softmax_scale, causal, rng_state=None):
+def _flash_attn_varlen_backward(
+    dout,
+    q,
+    k,
+    v,
+    out,
+    softmax_lse,
+    dq,
+    dk,
+    dv,
+    cu_seqlens_q,
+    cu_seqlens_k,
+    max_seqlen_q,
+    max_seqlen_k,
+    dropout_p,
+    softmax_scale,
+    causal,
+    rng_state=None,
+):
     maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
     # dq, dk, dv are allocated by us so they should already be contiguous
     dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
     dq, dk, dv, softmax_d, = flash_attn_cuda.varlen_bwd(
-        dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k,
-        max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, False, causal, None, rng_state
+        dout,
+        q,
+        k,
+        v,
+        out,
+        softmax_lse,
+        dq,
+        dk,
+        dv,
+        cu_seqlens_q,
+        cu_seqlens_k,
+        max_seqlen_q,
+        max_seqlen_k,
+        dropout_p,
+        softmax_scale,
+        False,
+        causal,
+        None,
+        rng_state,
     )
     # if dk.isnan().any() or dk.isnan().any() or dv.isnan().any() or softmax_d.isnan().any():
     #     breakpoint()
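The varlen entry points above all take cu_seqlens_* and max_seqlen_* describing a batch packed into one long token stream. A short sketch of how those are usually derived from per-sequence lengths; the lengths here are made-up example values.

import torch
import torch.nn.functional as F

seqlens = torch.tensor([5, 8, 3], dtype=torch.int32, device="cuda")  # example lengths
# Cumulative boundaries with a leading zero: tensor([0, 5, 13, 16], dtype=torch.int32)
cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))
max_seqlen = int(seqlens.max())  # 8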
@@ -86,14 +154,18 @@ def _flash_attn_varlen_backward(dout, q, k, v, out, softmax_lse, dq, dk, dv,
 class FlashAttnQKVPackedFunc(torch.autograd.Function):
     @staticmethod
     def forward(ctx, qkv, dropout_p, softmax_scale, causal, return_softmax):
         if softmax_scale is None:
             softmax_scale = qkv.shape[-1] ** (-0.5)
         out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
-            qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2], dropout_p, softmax_scale,
-            causal=causal, return_softmax=return_softmax and dropout_p > 0
+            qkv[:, :, 0],
+            qkv[:, :, 1],
+            qkv[:, :, 2],
+            dropout_p,
+            softmax_scale,
+            causal=causal,
+            return_softmax=return_softmax and dropout_p > 0,
         )
         ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
         ctx.dropout_p = dropout_p

@@ -107,22 +179,41 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
         qkv_shape = q.shape[:-2] + (3, *q.shape[-2:])
         dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
         _flash_attn_backward(
-            dout, q, k, v, out, softmax_lse, dqkv[:, :, 0], dqkv[:, :, 1], dqkv[:, :, 2],
-            ctx.dropout_p, ctx.softmax_scale, ctx.causal, rng_state=rng_state
+            dout,
+            q,
+            k,
+            v,
+            out,
+            softmax_lse,
+            dqkv[:, :, 0],
+            dqkv[:, :, 1],
+            dqkv[:, :, 2],
+            ctx.dropout_p,
+            ctx.softmax_scale,
+            ctx.causal,
+            rng_state=rng_state,
         )
         dqkv = dqkv[..., : dout.shape[-1]]  # We could have padded the head dimension
         return dqkv, None, None, None, None


 class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
     @staticmethod
     def forward(ctx, qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale, causal, return_softmax):
         if softmax_scale is None:
             softmax_scale = qkv.shape[-1] ** (-0.5)
         out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
-            qkv[:, 0], qkv[:, 1], qkv[:, 2], cu_seqlens, cu_seqlens, max_seqlen, max_seqlen,
-            dropout_p, softmax_scale, causal=causal, return_softmax=return_softmax and dropout_p > 0
+            qkv[:, 0],
+            qkv[:, 1],
+            qkv[:, 2],
+            cu_seqlens,
+            cu_seqlens,
+            max_seqlen,
+            max_seqlen,
+            dropout_p,
+            softmax_scale,
+            causal=causal,
+            return_softmax=return_softmax and dropout_p > 0,
         )
         ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens, rng_state)
         ctx.dropout_p = dropout_p
@@ -137,23 +228,41 @@ class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
         qkv_shape = q.shape[:-2] + (3, *q.shape[-2:])
         dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
         _flash_attn_varlen_backward(
-            dout, q, k, v, out, softmax_lse, dqkv[:, 0], dqkv[:, 1], dqkv[:, 2],
-            cu_seqlens, cu_seqlens, ctx.max_seqlen, ctx.max_seqlen,
-            ctx.dropout_p, ctx.softmax_scale, ctx.causal, rng_state=rng_state
+            dout,
+            q,
+            k,
+            v,
+            out,
+            softmax_lse,
+            dqkv[:, 0],
+            dqkv[:, 1],
+            dqkv[:, 2],
+            cu_seqlens,
+            cu_seqlens,
+            ctx.max_seqlen,
+            ctx.max_seqlen,
+            ctx.dropout_p,
+            ctx.softmax_scale,
+            ctx.causal,
+            rng_state=rng_state,
         )
         dqkv = dqkv[..., : dout.shape[-1]]  # We could have padded the head dimension
         return dqkv, None, None, None, None, None, None


 class FlashAttnKVPackedFunc(torch.autograd.Function):
     @staticmethod
     def forward(ctx, q, kv, dropout_p, softmax_scale, causal, return_softmax):
         if softmax_scale is None:
             softmax_scale = q.shape[-1] ** (-0.5)
         out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
-            q, kv[:, :, 0], kv[:, :, 1], dropout_p, softmax_scale, causal=causal,
-            return_softmax=return_softmax and dropout_p > 0
+            q,
+            kv[:, :, 0],
+            kv[:, :, 1],
+            dropout_p,
+            softmax_scale,
+            causal=causal,
+            return_softmax=return_softmax and dropout_p > 0,
         )
         ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
         ctx.dropout_p = dropout_p
@@ -168,28 +277,58 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
         kv_shape = k.shape[:-2] + (2, *k.shape[-2:])
         dkv = torch.empty(kv_shape, dtype=k.dtype, device=k.device)
         _flash_attn_backward(
-            dout, q, k, v, out, softmax_lse,
-            dq, dkv[:, :, 0], dkv[:, :, 1],
-            ctx.dropout_p, ctx.softmax_scale, ctx.causal,
-            rng_state=rng_state
+            dout,
+            q,
+            k,
+            v,
+            out,
+            softmax_lse,
+            dq,
+            dkv[:, :, 0],
+            dkv[:, :, 1],
+            ctx.dropout_p,
+            ctx.softmax_scale,
+            ctx.causal,
+            rng_state=rng_state,
         )
         dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
         dkv = dkv[..., : dout.shape[-1]]
         return dq, dkv, None, None, None, None


 class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, q, kv, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p,
-                softmax_scale, causal, return_softmax):
+    def forward(
+        ctx,
+        q,
+        kv,
+        cu_seqlens_q,
+        cu_seqlens_k,
+        max_seqlen_q,
+        max_seqlen_k,
+        dropout_p,
+        softmax_scale,
+        causal,
+        return_softmax,
+    ):
         if softmax_scale is None:
             softmax_scale = q.shape[-1] ** (-0.5)
         out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
-            q, kv[:, 0], kv[:, 1], cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
-            dropout_p, softmax_scale, causal=causal, return_softmax=return_softmax and dropout_p > 0
+            q,
+            kv[:, 0],
+            kv[:, 1],
+            cu_seqlens_q,
+            cu_seqlens_k,
+            max_seqlen_q,
+            max_seqlen_k,
+            dropout_p,
+            softmax_scale,
+            causal=causal,
+            return_softmax=return_softmax and dropout_p > 0,
         )
-        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state)
+        ctx.save_for_backward(
+            q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
+        )
         ctx.dropout_p = dropout_p
         ctx.max_seqlen_q = max_seqlen_q
         ctx.max_seqlen_k = max_seqlen_k
@@ -204,24 +343,42 @@ class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
         kv_shape = k.shape[:-2] + (2, *k.shape[-2:])
         dkv = torch.empty(kv_shape, dtype=k.dtype, device=k.device)
         _flash_attn_varlen_backward(
-            dout, q, k, v, out, softmax_lse,
-            dq, dkv[:, 0], dkv[:, 1],
-            cu_seqlens_q, cu_seqlens_k, ctx.max_seqlen_q, ctx.max_seqlen_k,
-            ctx.dropout_p, ctx.softmax_scale, ctx.causal, rng_state=rng_state
+            dout,
+            q,
+            k,
+            v,
+            out,
+            softmax_lse,
+            dq,
+            dkv[:, 0],
+            dkv[:, 1],
+            cu_seqlens_q,
+            cu_seqlens_k,
+            ctx.max_seqlen_q,
+            ctx.max_seqlen_k,
+            ctx.dropout_p,
+            ctx.softmax_scale,
+            ctx.causal,
+            rng_state=rng_state,
         )
         dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
         dkv = dkv[..., : dout.shape[-1]]
         return dq, dkv, None, None, None, None, None, None, None, None


 class FlashAttnFunc(torch.autograd.Function):
     @staticmethod
     def forward(ctx, q, k, v, dropout_p, softmax_scale, causal, return_softmax):
         if softmax_scale is None:
             softmax_scale = q.shape[-1] ** (-0.5)
         out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward(
-            q, k, v, dropout_p, softmax_scale, causal=causal,
-            return_softmax=return_softmax and dropout_p > 0
+            q,
+            k,
+            v,
+            dropout_p,
+            softmax_scale,
+            causal=causal,
+            return_softmax=return_softmax and dropout_p > 0,
         )
         ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
         ctx.dropout_p = dropout_p
@@ -234,29 +391,60 @@ class FlashAttnFunc(torch.autograd.Function):
         q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
         dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
         _flash_attn_backward(
-            dout, q, k, v, out, softmax_lse,
-            dq, dk, dv,
-            ctx.dropout_p, ctx.softmax_scale, ctx.causal,
-            rng_state=rng_state
+            dout,
+            q,
+            k,
+            v,
+            out,
+            softmax_lse,
+            dq,
+            dk,
+            dv,
+            ctx.dropout_p,
+            ctx.softmax_scale,
+            ctx.causal,
+            rng_state=rng_state,
         )
         dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
         dk = dk[..., : dout.shape[-1]]
         dv = dv[..., : dout.shape[-1]]
         return dq, dk, dv, None, None, None, None, None, None, None, None


 class FlashAttnVarlenFunc(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p,
-                softmax_scale, causal, return_softmax):
+    def forward(
+        ctx,
+        q,
+        k,
+        v,
+        cu_seqlens_q,
+        cu_seqlens_k,
+        max_seqlen_q,
+        max_seqlen_k,
+        dropout_p,
+        softmax_scale,
+        causal,
+        return_softmax,
+    ):
         if softmax_scale is None:
             softmax_scale = q.shape[-1] ** (-0.5)
         out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
-            q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
-            dropout_p, softmax_scale, causal=causal, return_softmax=return_softmax and dropout_p > 0
+            q,
+            k,
+            v,
+            cu_seqlens_q,
+            cu_seqlens_k,
+            max_seqlen_q,
+            max_seqlen_k,
+            dropout_p,
+            softmax_scale,
+            causal=causal,
+            return_softmax=return_softmax and dropout_p > 0,
         )
-        ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state)
+        ctx.save_for_backward(
+            q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
+        )
         ctx.dropout_p = dropout_p
         ctx.max_seqlen_q = max_seqlen_q
         ctx.max_seqlen_k = max_seqlen_k
@@ -269,18 +457,33 @@ class FlashAttnVarlenFunc(torch.autograd.Function):
         q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors
         dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
         _flash_attn_varlen_backward(
-            dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k,
-            ctx.max_seqlen_q, ctx.max_seqlen_k, ctx.dropout_p, ctx.softmax_scale, ctx.causal,
-            rng_state=rng_state
+            dout,
+            q,
+            k,
+            v,
+            out,
+            softmax_lse,
+            dq,
+            dk,
+            dv,
+            cu_seqlens_q,
+            cu_seqlens_k,
+            ctx.max_seqlen_q,
+            ctx.max_seqlen_k,
+            ctx.dropout_p,
+            ctx.softmax_scale,
+            ctx.causal,
+            rng_state=rng_state,
         )
         dq = dq[..., : dout.shape[-1]]  # We could have padded the head dimension
         dk = dk[..., : dout.shape[-1]]
         dv = dv[..., : dout.shape[-1]]
         return dq, dk, dv, None, None, None, None, None, None, None, None


-def flash_attn_qkvpacked_func(qkv, dropout_p=0.0, softmax_scale=None, causal=False,
-                              return_attn_probs=False):
+def flash_attn_qkvpacked_func(
+    qkv, dropout_p=0.0, softmax_scale=None, causal=False, return_attn_probs=False
+):
     """dropout_p should be set to 0.0 during evaluation
     If Q, K, V are already stacked into 1 tensor, this function will be faster than
     calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
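As the docstring notes, the packed variant avoids re-concatenating the gradients in the backward pass. A hedged usage sketch; the (batch, seqlen, 3, nheads, headdim) packed layout is inferred from the qkv[:, :, i] indexing above, and the sizes are example values.

import torch
from flash_attn import flash_attn_qkvpacked_func

# Q, K, V stacked along a size-3 dimension: (batch, seqlen, 3, nheads, headdim).
qkv = torch.randn(2, 1024, 3, 8, 64, dtype=torch.float16, device="cuda")
out = flash_attn_qkvpacked_func(qkv, dropout_p=0.0, causal=True)  # (2, 1024, 8, 64)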
@@ -309,8 +512,9 @@ def flash_attn_qkvpacked_func(qkv, dropout_p=0.0, softmax_scale=None, causal=Fal
     return FlashAttnQKVPackedFunc.apply(qkv, dropout_p, softmax_scale, causal, return_attn_probs)


-def flash_attn_kvpacked_func(q, kv, dropout_p=0.0, softmax_scale=None, causal=False,
-                             return_attn_probs=False):
+def flash_attn_kvpacked_func(
+    q, kv, dropout_p=0.0, softmax_scale=None, causal=False, return_attn_probs=False
+):
     """dropout_p should be set to 0.0 during evaluation
     If K, V are already stacked into 1 tensor, this function will be faster than
     calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
@@ -342,8 +546,9 @@ def flash_attn_kvpacked_func(q, kv, dropout_p=0.0, softmax_scale=None, causal=Fa
     return FlashAttnKVPackedFunc.apply(q, kv, dropout_p, softmax_scale, causal, return_attn_probs)


-def flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False,
-                    return_attn_probs=False):
+def flash_attn_func(
+    q, k, v, dropout_p=0.0, softmax_scale=None, causal=False, return_attn_probs=False
+):
     """dropout_p should be set to 0.0 during evaluation
     Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
     than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
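To make the MQA/GQA note concrete, a sketch where K and V carry fewer heads than Q. The head counts are example values; the only constraint stated above is that the number of Q heads be divisible by the number of KV heads.

import torch
from flash_attn import flash_attn_func

q = torch.randn(2, 1024, 8, 64, dtype=torch.float16, device="cuda")  # 8 query heads
k = torch.randn(2, 1024, 2, 64, dtype=torch.float16, device="cuda")  # 2 KV heads (GQA)
v = torch.randn(2, 1024, 2, 64, dtype=torch.float16, device="cuda")
out = flash_attn_func(q, k, v, causal=True)  # output keeps q's 8 heads: (2, 1024, 8, 64)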
@@ -373,8 +578,15 @@ def flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False,
     return FlashAttnFunc.apply(q, k, v, dropout_p, softmax_scale, causal, return_attn_probs)


-def flash_attn_varlen_qkvpacked_func(qkv, cu_seqlens, max_seqlen, dropout_p=0.0, softmax_scale=None,
-                                     causal=False, return_attn_probs=False):
+def flash_attn_varlen_qkvpacked_func(
+    qkv,
+    cu_seqlens,
+    max_seqlen,
+    dropout_p=0.0,
+    softmax_scale=None,
+    causal=False,
+    return_attn_probs=False,
+):
     """dropout_p should be set to 0.0 during evaluation
     If Q, K, V are already stacked into 1 tensor, this function will be faster than
     calling flash_attn_varlen_func on Q, K, V since the backward pass avoids explicit concatenation
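A hedged sketch of the varlen packed call, combining the (total_tokens, 3, nheads, headdim) layout implied by the qkv[:, i] indexing earlier with cu_seqlens built as in the previous sketch. All sizes are example values.

import torch
from flash_attn import flash_attn_varlen_qkvpacked_func

cu_seqlens = torch.tensor([0, 5, 13, 16], dtype=torch.int32, device="cuda")
max_seqlen = 8
qkv = torch.randn(16, 3, 8, 64, dtype=torch.float16, device="cuda")  # 16 = total tokens
out = flash_attn_varlen_qkvpacked_func(qkv, cu_seqlens, max_seqlen, dropout_p=0.0, causal=True)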
@@ -408,9 +620,18 @@ def flash_attn_varlen_qkvpacked_func(qkv, cu_seqlens, max_seqlen, dropout_p=0.0,
     )


-def flash_attn_varlen_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
-                                    dropout_p=0.0, softmax_scale=None, causal=False,
-                                    return_attn_probs=False):
+def flash_attn_varlen_kvpacked_func(
+    q,
+    kv,
+    cu_seqlens_q,
+    cu_seqlens_k,
+    max_seqlen_q,
+    max_seqlen_k,
+    dropout_p=0.0,
+    softmax_scale=None,
+    causal=False,
+    return_attn_probs=False,
+):
     """dropout_p should be set to 0.0 during evaluation
     If K, V are already stacked into 1 tensor, this function will be faster than
     calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
@@ -446,14 +667,32 @@ def flash_attn_varlen_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_k, max_seqle
     pattern (negative means that location was dropped, nonnegative means it was kept).
     """
     return FlashAttnVarlenKVPackedFunc.apply(
-        q, kv, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
-        dropout_p, softmax_scale, causal, return_attn_probs
+        q,
+        kv,
+        cu_seqlens_q,
+        cu_seqlens_k,
+        max_seqlen_q,
+        max_seqlen_k,
+        dropout_p,
+        softmax_scale,
+        causal,
+        return_attn_probs,
     )


-def flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
-                           dropout_p=0.0, softmax_scale=None, causal=False,
-                           return_attn_probs=False):
+def flash_attn_varlen_func(
+    q,
+    k,
+    v,
+    cu_seqlens_q,
+    cu_seqlens_k,
+    max_seqlen_q,
+    max_seqlen_k,
+    dropout_p=0.0,
+    softmax_scale=None,
+    causal=False,
+    return_attn_probs=False,
+):
     """dropout_p should be set to 0.0 during evaluation
     Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
     than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
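And for the fully unpacked varlen path, a sketch with separate Q and K/V token streams. Shapes and lengths are example values; the signature is the one shown above.

import torch
from flash_attn import flash_attn_varlen_func

cu_seqlens_q = torch.tensor([0, 5, 13, 16], dtype=torch.int32, device="cuda")
cu_seqlens_k = torch.tensor([0, 7, 15, 20], dtype=torch.int32, device="cuda")
q = torch.randn(16, 8, 64, dtype=torch.float16, device="cuda")  # (total_q, nheads, headdim)
k = torch.randn(20, 8, 64, dtype=torch.float16, device="cuda")  # (total_k, nheads_k, headdim)
v = torch.randn(20, 8, 64, dtype=torch.float16, device="cuda")
out = flash_attn_varlen_func(
    q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q=8, max_seqlen_k=8, causal=True
)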
@@ -487,6 +726,15 @@ def flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, ma
     pattern (negative means that location was dropped, nonnegative means it was kept).
     """
     return FlashAttnVarlenFunc.apply(
-        q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
-        dropout_p, softmax_scale, causal, return_attn_probs
+        q,
+        k,
+        v,
+        cu_seqlens_q,
+        cu_seqlens_k,
+        max_seqlen_q,
+        max_seqlen_k,
+        dropout_p,
+        softmax_scale,
+        causal,
+        return_attn_probs,
     )
flash_attn/flash_attn_triton.py

@@ -42,7 +42,6 @@ than CUDA forward + backward.
 import math
 import torch
 import triton
 import triton.language as tl

@@ -65,21 +64,44 @@ import triton.language as tl
 )
 @triton.jit
 def _fwd_kernel(
-    Q, K, V, Bias, Out,
-    Lse, TMP,  # NOTE: TMP is a scratchpad buffer to workaround a compiler bug
+    Q,
+    K,
+    V,
+    Bias,
+    Out,
+    Lse,
+    TMP,  # NOTE: TMP is a scratchpad buffer to workaround a compiler bug
     softmax_scale,
-    stride_qb, stride_qh, stride_qm,
-    stride_kb, stride_kh, stride_kn,
-    stride_vb, stride_vh, stride_vn,
-    stride_bb, stride_bh, stride_bm,
-    stride_ob, stride_oh, stride_om,
-    nheads, seqlen_q, seqlen_k, seqlen_q_rounded, headdim,
-    CACHE_KEY_SEQLEN_Q, CACHE_KEY_SEQLEN_K,
+    stride_qb,
+    stride_qh,
+    stride_qm,
+    stride_kb,
+    stride_kh,
+    stride_kn,
+    stride_vb,
+    stride_vh,
+    stride_vn,
+    stride_bb,
+    stride_bh,
+    stride_bm,
+    stride_ob,
+    stride_oh,
+    stride_om,
+    nheads,
+    seqlen_q,
+    seqlen_k,
+    seqlen_q_rounded,
+    headdim,
+    CACHE_KEY_SEQLEN_Q,
+    CACHE_KEY_SEQLEN_K,
     BIAS_TYPE: tl.constexpr,
     IS_CAUSAL: tl.constexpr,
     BLOCK_HEADDIM: tl.constexpr,
-    EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr,
-    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
+    EVEN_M: tl.constexpr,
+    EVEN_N: tl.constexpr,
+    EVEN_HEADDIM: tl.constexpr,
+    BLOCK_M: tl.constexpr,
+    BLOCK_N: tl.constexpr,
 ):
     start_m = tl.program_id(0)
     off_hb = tl.program_id(1)
@@ -96,13 +118,24 @@ def _fwd_kernel(
     # Adding parenthesis around indexing might use int32 math instead of int64 math?
     # https://github.com/openai/triton/issues/741
     # I'm seeing a tiny bit of difference (5-7us)
-    q_ptrs = Q + off_b * stride_qb + off_h * stride_qh + (offs_m[:, None] * stride_qm + offs_d[None, :])
-    k_ptrs = K + off_b * stride_kb + off_h * stride_kh + (offs_n[:, None] * stride_kn + offs_d[None, :])
-    v_ptrs = V + off_b * stride_vb + off_h * stride_vh + (offs_n[:, None] * stride_vn + offs_d[None, :])
-    if BIAS_TYPE == 'vector':
+    q_ptrs = (
+        Q + off_b * stride_qb + off_h * stride_qh + (offs_m[:, None] * stride_qm + offs_d[None, :])
+    )
+    k_ptrs = (
+        K + off_b * stride_kb + off_h * stride_kh + (offs_n[:, None] * stride_kn + offs_d[None, :])
+    )
+    v_ptrs = (
+        V + off_b * stride_vb + off_h * stride_vh + (offs_n[:, None] * stride_vn + offs_d[None, :])
+    )
+    if BIAS_TYPE == "vector":
         b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + offs_n
-    elif BIAS_TYPE == 'matrix':
-        b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + (offs_m[:, None] * stride_bm + offs_n[None, :])
+    elif BIAS_TYPE == "matrix":
+        b_ptrs = (
+            Bias
+            + off_b * stride_bb
+            + off_h * stride_bh
+            + (offs_m[:, None] * stride_bm + offs_n[None, :])
+        )
     # initialize pointer to m and l
     t_ptrs = TMP + off_hb * seqlen_q_rounded + offs_m
     lse_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")

@@ -120,8 +153,9 @@ def _fwd_kernel(
         if EVEN_HEADDIM:
             q = tl.load(q_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0)
         else:
-            q = tl.load(q_ptrs, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
-                        other=0.0)
+            q = tl.load(
+                q_ptrs, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0
+            )
     # loop over k, v and update accumulator
     end_n = seqlen_k if not IS_CAUSAL else tl.minimum((start_m + 1) * BLOCK_M, seqlen_k)
     for start_n in range(0, end_n, BLOCK_N):

@@ -134,12 +168,17 @@ def _fwd_kernel(
             k = tl.load(k_ptrs + start_n * stride_kn, mask=offs_d[None, :] < headdim, other=0.0)
         else:
             if EVEN_HEADDIM:
-                k = tl.load(k_ptrs + start_n * stride_kn, mask=(start_n + offs_n)[:, None] < seqlen_k,
-                            other=0.0)
+                k = tl.load(
+                    k_ptrs + start_n * stride_kn,
+                    mask=(start_n + offs_n)[:, None] < seqlen_k,
+                    other=0.0,
+                )
             else:
-                k = tl.load(k_ptrs + start_n * stride_kn,
-                            mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
-                            other=0.0)
+                k = tl.load(
+                    k_ptrs + start_n * stride_kn,
+                    mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
+                    other=0.0,
+                )
         qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
         qk += tl.dot(q, k, trans_b=True)
         # Trying to combine the two masks seem to make the result wrong
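The stride arithmetic in the bias pointer setup above implies two supported bias layouts. A hedged note on the shapes they correspond to, inferred from the indexing rather than stated explicitly in this hunk:

import torch

batch, nheads, seqlen_q, seqlen_k = 2, 8, 128, 128
# BIAS_TYPE == "vector": one value per key position, indexed by (batch, head, n).
bias_vector = torch.randn(batch, nheads, seqlen_k, device="cuda")
# BIAS_TYPE == "matrix": a full (query, key) grid, indexed by (batch, head, m, n).
bias_matrix = torch.randn(batch, nheads, seqlen_q, seqlen_k, device="cuda")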
@@ -147,21 +186,25 @@ def _fwd_kernel(
             qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float("-inf"))
         if IS_CAUSAL:
             qk += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float("-inf"))
-        if BIAS_TYPE != 'none':
-            if BIAS_TYPE == 'vector':
+        if BIAS_TYPE != "none":
+            if BIAS_TYPE == "vector":
                 if EVEN_N:
                     bias = tl.load(b_ptrs + start_n).to(tl.float32)
                 else:
-                    bias = tl.load(b_ptrs + start_n, mask=(start_n + offs_n) < seqlen_k, other=0.0).to(tl.float32)
+                    bias = tl.load(
+                        b_ptrs + start_n, mask=(start_n + offs_n) < seqlen_k, other=0.0
+                    ).to(tl.float32)
                 bias = bias[None, :]
-            elif BIAS_TYPE == 'matrix':
+            elif BIAS_TYPE == "matrix":
                 if EVEN_M & EVEN_N:
                     bias = tl.load(b_ptrs + start_n).to(tl.float32)
                 else:
-                    bias = tl.load(b_ptrs + start_n,
-                                   mask=(offs_m[:, None] < seqlen_q)
-                                   & ((start_n + offs_n)[None, :] < seqlen_k),
-                                   other=0.0).to(tl.float32)
+                    bias = tl.load(
+                        b_ptrs + start_n,
+                        mask=(offs_m[:, None] < seqlen_q)
+                        & ((start_n + offs_n)[None, :] < seqlen_k),
+                        other=0.0,
+                    ).to(tl.float32)
             # Slightly faster to multiply the softmax_scale in the tl.exp below since the compiler
             # can then fuse the mult and add into an fma instruction. But if we have bias we need to
             # to multiply with softmax_scale here.

@@ -189,12 +232,17 @@ def _fwd_kernel(
             v = tl.load(v_ptrs + start_n * stride_vn, mask=offs_d[None, :] < headdim, other=0.0)
         else:
             if EVEN_HEADDIM:
-                v = tl.load(v_ptrs + start_n * stride_vn, mask=(start_n + offs_n)[:, None] < seqlen_k,
-                            other=0.0)
+                v = tl.load(
+                    v_ptrs + start_n * stride_vn,
+                    mask=(start_n + offs_n)[:, None] < seqlen_k,
+                    other=0.0,
+                )
             else:
-                v = tl.load(v_ptrs + start_n * stride_vn,
-                            mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
-                            other=0.0)
+                v = tl.load(
+                    v_ptrs + start_n * stride_vn,
+                    mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
+                    other=0.0,
+                )
         p = p.to(v.dtype)
         acc_o += tl.dot(p, v)
@@ -216,7 +264,12 @@ def _fwd_kernel(
     tl.store(lse_ptrs, lse_i)
     # initialize pointers to output
     offs_d = tl.arange(0, BLOCK_HEADDIM)
-    out_ptrs = Out + off_b * stride_ob + off_h * stride_oh + (offs_m[:, None] * stride_om + offs_d[None, :])
+    out_ptrs = (
+        Out
+        + off_b * stride_ob
+        + off_h * stride_oh
+        + (offs_m[:, None] * stride_om + offs_d[None, :])
+    )
     if EVEN_M:
         if EVEN_HEADDIM:
             tl.store(out_ptrs, acc_o)

@@ -226,17 +279,28 @@ def _fwd_kernel(
         if EVEN_HEADDIM:
             tl.store(out_ptrs, acc_o, mask=offs_m[:, None] < seqlen_q)
         else:
-            tl.store(out_ptrs, acc_o,
-                     mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim))
+            tl.store(
+                out_ptrs, acc_o, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim)
+            )


 @triton.jit
 def _bwd_preprocess_do_o_dot(
-    Out, DO, Delta,
-    stride_ob, stride_oh, stride_om,
-    stride_dob, stride_doh, stride_dom,
-    nheads, seqlen_q, seqlen_q_rounded, headdim,
-    BLOCK_M: tl.constexpr, BLOCK_HEADDIM: tl.constexpr,
+    Out,
+    DO,
+    Delta,
+    stride_ob,
+    stride_oh,
+    stride_om,
+    stride_dob,
+    stride_doh,
+    stride_dom,
+    nheads,
+    seqlen_q,
+    seqlen_q_rounded,
+    headdim,
+    BLOCK_M: tl.constexpr,
+    BLOCK_HEADDIM: tl.constexpr,
 ):
     start_m = tl.program_id(0)
     off_hb = tl.program_id(1)
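For reference, the quantity _bwd_preprocess_do_o_dot accumulates is the per-row dot product of the output with its gradient. A small PyTorch equivalent; the (batch, nheads, seqlen_q, headdim) layout here is an illustrative logical view of the strided access above.

import torch

o = torch.randn(2, 8, 128, 64)   # output O
do = torch.randn_like(o)         # upstream gradient dO
delta = (o * do).sum(dim=-1)     # delta[b, h, m] = sum_d O[b, h, m, d] * dO[b, h, m, d]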
@@ -246,10 +310,20 @@ def _bwd_preprocess_do_o_dot(
     offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
     offs_d = tl.arange(0, BLOCK_HEADDIM)
     # load
-    o = tl.load(Out + off_b * stride_ob + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :],
-                mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0).to(tl.float32)
-    do = tl.load(DO + off_b * stride_dob + off_h * stride_doh + offs_m[:, None] * stride_dom + offs_d[None, :],
-                 mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0).to(tl.float32)
+    o = tl.load(
+        Out + off_b * stride_ob + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :],
+        mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
+        other=0.0,
+    ).to(tl.float32)
+    do = tl.load(
+        DO + off_b * stride_dob + off_h * stride_doh + offs_m[:, None] * stride_dom + offs_d[None, :],
+        mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
+        other=0.0,
+    ).to(tl.float32)
     delta = tl.sum(o * do, axis=1)
     # write-back
     tl.store(Delta + off_hb * seqlen_q_rounded + offs_m, delta)

@@ -257,8 +331,17 @@ def _bwd_preprocess_do_o_dot(
 @triton.jit
 def _bwd_store_dk_dv(
-    dk_ptrs, dv_ptrs, dk, dv, offs_n, offs_d, seqlen_k, headdim,
-    EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr,
+    dk_ptrs,
+    dv_ptrs,
+    dk,
+    dv,
+    offs_n,
+    offs_d,
+    seqlen_k,
+    headdim,
+    EVEN_M: tl.constexpr,
+    EVEN_N: tl.constexpr,
+    EVEN_HEADDIM: tl.constexpr,
 ):
     # [2022-11-01] TD: Same bug. In the case of EVEN_N=True and EVEN_M=False,
     # if we just call tl.store(dv_ptrs), there's a race condition
@@ -281,19 +364,37 @@ def _bwd_store_dk_dv(
 @triton.jit
 def _bwd_kernel_one_col_block(
     start_n,
-    Q, K, V, Bias,
-    DO, DQ, DK, DV,
-    LSE, D,
+    Q,
+    K,
+    V,
+    Bias,
+    DO,
+    DQ,
+    DK,
+    DV,
+    LSE,
+    D,
     softmax_scale,
-    stride_qm, stride_kn, stride_vn, stride_bm,
-    stride_dom, stride_dqm, stride_dkn, stride_dvn,
-    seqlen_q, seqlen_k, headdim,
+    stride_qm,
+    stride_kn,
+    stride_vn,
+    stride_bm,
+    stride_dom,
+    stride_dqm,
+    stride_dkn,
+    stride_dvn,
+    seqlen_q,
+    seqlen_k,
+    headdim,
     ATOMIC_ADD: tl.constexpr,
     BIAS_TYPE: tl.constexpr,
     IS_CAUSAL: tl.constexpr,
     BLOCK_HEADDIM: tl.constexpr,
-    EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr,
-    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
+    EVEN_M: tl.constexpr,
+    EVEN_N: tl.constexpr,
+    EVEN_HEADDIM: tl.constexpr,
+    BLOCK_M: tl.constexpr,
+    BLOCK_N: tl.constexpr,
 ):
     # We need to make sure begin_m is a multiple of BLOCK_M (not BLOCK_N)
     begin_m = 0 if not IS_CAUSAL else ((start_n * BLOCK_N) // BLOCK_M) * BLOCK_M

@@ -308,9 +409,9 @@ def _bwd_kernel_one_col_block(
     v_ptrs = V + (offs_n[:, None] * stride_vn + offs_d[None, :])
     do_ptrs = DO + (offs_qm[:, None] * stride_dom + offs_d[None, :])
     dq_ptrs = DQ + (offs_qm[:, None] * stride_dqm + offs_d[None, :])
-    if BIAS_TYPE == 'vector':
+    if BIAS_TYPE == "vector":
         b_ptrs = Bias + offs_n
-    elif BIAS_TYPE == 'matrix':
+    elif BIAS_TYPE == "matrix":
         b_ptrs = Bias + (offs_qm[:, None] * stride_bm + offs_n[None, :])
     # initialize dv and dk
     dv = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32)
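A small worked example of the causal begin_m formula above, with made-up block sizes, to show why the backward loop can skip early query blocks:

# For the key block starting at start_n = 3 with BLOCK_N = 64, the first key index is 192,
# so under causal masking no query row before 192 contributes; the loop may start at the
# enclosing BLOCK_M boundary.
start_n, BLOCK_N, BLOCK_M = 3, 64, 128
begin_m = ((start_n * BLOCK_N) // BLOCK_M) * BLOCK_M
assert begin_m == 128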
@@ -322,8 +423,19 @@ def _bwd_kernel_one_col_block(
     if begin_m >= seqlen_q:
         dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :])
         dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :])
-        _bwd_store_dk_dv(dk_ptrs, dv_ptrs, dk, dv, offs_n, offs_d, seqlen_k, headdim,
-                         EVEN_M=EVEN_M, EVEN_N=EVEN_N, EVEN_HEADDIM=EVEN_HEADDIM)
+        _bwd_store_dk_dv(
+            dk_ptrs,
+            dv_ptrs,
+            dk,
+            dv,
+            offs_n,
+            offs_d,
+            seqlen_k,
+            headdim,
+            EVEN_M=EVEN_M,
+            EVEN_N=EVEN_N,
+            EVEN_HEADDIM=EVEN_HEADDIM,
+        )
         return
     # k and v stay in SRAM throughout
     # [2022-10-30] TD: Same bug as the fwd. In the case of EVEN_N=True and EVEN_M=False,

@@ -340,10 +452,12 @@ def _bwd_kernel_one_col_block(
             k = tl.load(k_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0)
             v = tl.load(v_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0)
         else:
-            k = tl.load(k_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
-                        other=0.0)
-            v = tl.load(v_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
-                        other=0.0)
+            k = tl.load(
+                k_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0
+            )
+            v = tl.load(
+                v_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0
+            )
     # loop over rows
     num_block_m = tl.cdiv(seqlen_q, BLOCK_M)
     for start_m in range(begin_m, num_block_m * BLOCK_M, BLOCK_M):
@@ -357,8 +471,11 @@ def _bwd_kernel_one_col_block(
         if EVEN_HEADDIM:
             q = tl.load(q_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0)
         else:
-            q = tl.load(q_ptrs, mask=(offs_m_curr[:, None] < seqlen_q)
-                        & (offs_d[None, :] < headdim), other=0.0)
+            q = tl.load(
+                q_ptrs,
+                mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
+                other=0.0,
+            )
         # recompute p = softmax(qk, dim=-1).T
         qk = tl.dot(q, k, trans_b=True)
         # Trying to combine the two masks seem to make the result wrong

@@ -366,29 +483,30 @@ def _bwd_kernel_one_col_block(
             qk = tl.where(offs_n[None, :] < seqlen_k, qk, float("-inf"))
         if IS_CAUSAL:
             qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf"))
-        if BIAS_TYPE != 'none':
+        if BIAS_TYPE != "none":
             tl.debug_barrier()  # Race condition otherwise
-            if BIAS_TYPE == 'vector':
+            if BIAS_TYPE == "vector":
                 if EVEN_N:
                     bias = tl.load(b_ptrs).to(tl.float32)
                 else:
                     bias = tl.load(b_ptrs, mask=offs_n < seqlen_k, other=0.0).to(tl.float32)
                 bias = bias[None, :]
-            elif BIAS_TYPE == 'matrix':
+            elif BIAS_TYPE == "matrix":
                 if EVEN_M & EVEN_N:
                     bias = tl.load(b_ptrs).to(tl.float32)
                 else:
-                    bias = tl.load(b_ptrs,
-                                   mask=(offs_m_curr[:, None] < seqlen_q)
-                                   & (offs_n[None, :] < seqlen_k),
-                                   other=0.0).to(tl.float32)
+                    bias = tl.load(
+                        b_ptrs,
+                        mask=(offs_m_curr[:, None] < seqlen_q) & (offs_n[None, :] < seqlen_k),
+                        other=0.0,
+                    ).to(tl.float32)
             qk = qk * softmax_scale + bias
         # There seems to be a race condition when headdim=48/96, and dq, dk, dv are wrong.
         # Also wrong for headdim=64.
         if not (EVEN_M & EVEN_HEADDIM):
             tl.debug_barrier()
         lse_i = tl.load(LSE + offs_m_curr)
-        if BIAS_TYPE == 'none':
+        if BIAS_TYPE == "none":
             p = tl.exp(qk * softmax_scale - lse_i[:, None])
         else:
             p = tl.exp(qk - lse_i[:, None])
...
@@ -401,8 +519,11 @@ def _bwd_kernel_one_col_block(
...
@@ -401,8 +519,11 @@ def _bwd_kernel_one_col_block(
do
=
tl
.
load
(
do_ptrs
)
do
=
tl
.
load
(
do_ptrs
)
else
:
else
:
# [2022-11-01] TD: Triton bug, there's a race condition if we just use m_mask and not d_mask.
# [2022-11-01] TD: Triton bug, there's a race condition if we just use m_mask and not d_mask.
do
=
tl
.
load
(
do_ptrs
,
mask
=
(
offs_m_curr
[:,
None
]
<
seqlen_q
)
do
=
tl
.
load
(
&
(
offs_d
[
None
,
:]
<
headdim
),
other
=
0.0
)
do_ptrs
,
mask
=
(
offs_m_curr
[:,
None
]
<
seqlen_q
)
&
(
offs_d
[
None
,
:]
<
headdim
),
other
=
0.0
,
)
# if EVEN_M:
# if EVEN_M:
# if EVEN_HEADDIM:
# if EVEN_HEADDIM:
# do = tl.load(do_ptrs)
# do = tl.load(do_ptrs)
...
@@ -434,7 +555,9 @@ def _bwd_kernel_one_col_block(
...
@@ -434,7 +555,9 @@ def _bwd_kernel_one_col_block(
# compute dk = dot(ds.T, q)
# compute dk = dot(ds.T, q)
dk
+=
tl
.
dot
(
ds
,
q
,
trans_a
=
True
)
dk
+=
tl
.
dot
(
ds
,
q
,
trans_a
=
True
)
# compute dq
# compute dq
if
not
(
EVEN_M
&
EVEN_HEADDIM
):
# Otherewise there's a race condition when BIAS_TYPE='matrix'
if
not
(
EVEN_M
&
EVEN_HEADDIM
):
# Otherewise there's a race condition when BIAS_TYPE='matrix'
tl
.
debug_barrier
()
tl
.
debug_barrier
()
if
not
ATOMIC_ADD
:
if
not
ATOMIC_ADD
:
if
EVEN_M
&
EVEN_HEADDIM
:
# Race condition if we just do EVEN_M
if
EVEN_M
&
EVEN_HEADDIM
:
# Race condition if we just do EVEN_M
...
@@ -443,19 +566,33 @@ def _bwd_kernel_one_col_block(
...
@@ -443,19 +566,33 @@ def _bwd_kernel_one_col_block(
tl
.
store
(
dq_ptrs
,
dq
,
eviction_policy
=
"evict_last"
)
tl
.
store
(
dq_ptrs
,
dq
,
eviction_policy
=
"evict_last"
)
else
:
else
:
if
EVEN_HEADDIM
:
if
EVEN_HEADDIM
:
dq
=
tl
.
load
(
dq_ptrs
,
mask
=
offs_m_curr
[:,
None
]
<
seqlen_q
,
other
=
0.0
,
dq
=
tl
.
load
(
eviction_policy
=
"evict_last"
)
dq_ptrs
,
mask
=
offs_m_curr
[:,
None
]
<
seqlen_q
,
other
=
0.0
,
eviction_policy
=
"evict_last"
,
)
dq
+=
tl
.
dot
(
ds
,
k
)
dq
+=
tl
.
dot
(
ds
,
k
)
tl
.
store
(
dq_ptrs
,
dq
,
mask
=
offs_m_curr
[:,
None
]
<
seqlen_q
,
tl
.
store
(
eviction_policy
=
"evict_last"
)
dq_ptrs
,
dq
,
mask
=
offs_m_curr
[:,
None
]
<
seqlen_q
,
eviction_policy
=
"evict_last"
,
)
else
:
else
:
dq
=
tl
.
load
(
dq_ptrs
,
dq
=
tl
.
load
(
dq_ptrs
,
mask
=
(
offs_m_curr
[:,
None
]
<
seqlen_q
)
&
(
offs_d
[
None
,
:]
<
headdim
),
mask
=
(
offs_m_curr
[:,
None
]
<
seqlen_q
)
&
(
offs_d
[
None
,
:]
<
headdim
),
other
=
0.0
,
eviction_policy
=
"evict_last"
)
other
=
0.0
,
eviction_policy
=
"evict_last"
,
)
dq
+=
tl
.
dot
(
ds
,
k
)
dq
+=
tl
.
dot
(
ds
,
k
)
tl
.
store
(
dq_ptrs
,
dq
,
tl
.
store
(
dq_ptrs
,
dq
,
mask
=
(
offs_m_curr
[:,
None
]
<
seqlen_q
)
&
(
offs_d
[
None
,
:]
<
headdim
),
mask
=
(
offs_m_curr
[:,
None
]
<
seqlen_q
)
&
(
offs_d
[
None
,
:]
<
headdim
),
eviction_policy
=
"evict_last"
)
eviction_policy
=
"evict_last"
,
)
else
:
# If we're parallelizing across the seqlen_k dimension
else
:
# If we're parallelizing across the seqlen_k dimension
dq
=
tl
.
dot
(
ds
,
k
)
dq
=
tl
.
dot
(
ds
,
k
)
if
EVEN_M
&
EVEN_HEADDIM
:
# Race condition if we just do EVEN_M
if
EVEN_M
&
EVEN_HEADDIM
:
# Race condition if we just do EVEN_M
...
@@ -464,19 +601,33 @@ def _bwd_kernel_one_col_block(
...
@@ -464,19 +601,33 @@ def _bwd_kernel_one_col_block(
if
EVEN_HEADDIM
:
if
EVEN_HEADDIM
:
tl
.
atomic_add
(
dq_ptrs
,
dq
,
mask
=
offs_m_curr
[:,
None
]
<
seqlen_q
)
tl
.
atomic_add
(
dq_ptrs
,
dq
,
mask
=
offs_m_curr
[:,
None
]
<
seqlen_q
)
else
:
else
:
tl
.
atomic_add
(
dq_ptrs
,
dq
,
tl
.
atomic_add
(
mask
=
(
offs_m_curr
[:,
None
]
<
seqlen_q
)
&
(
offs_d
[
None
,
:]
<
headdim
))
dq_ptrs
,
dq
,
mask
=
(
offs_m_curr
[:,
None
]
<
seqlen_q
)
&
(
offs_d
[
None
,
:]
<
headdim
),
)
# increment pointers
# increment pointers
dq_ptrs
+=
BLOCK_M
*
stride_dqm
dq_ptrs
+=
BLOCK_M
*
stride_dqm
q_ptrs
+=
BLOCK_M
*
stride_qm
q_ptrs
+=
BLOCK_M
*
stride_qm
do_ptrs
+=
BLOCK_M
*
stride_dom
do_ptrs
+=
BLOCK_M
*
stride_dom
if
BIAS_TYPE
==
'
matrix
'
:
if
BIAS_TYPE
==
"
matrix
"
:
b_ptrs
+=
BLOCK_M
*
stride_bm
b_ptrs
+=
BLOCK_M
*
stride_bm
# write-back
# write-back
dv_ptrs
=
DV
+
(
offs_n
[:,
None
]
*
stride_dvn
+
offs_d
[
None
,
:])
dv_ptrs
=
DV
+
(
offs_n
[:,
None
]
*
stride_dvn
+
offs_d
[
None
,
:])
dk_ptrs
=
DK
+
(
offs_n
[:,
None
]
*
stride_dkn
+
offs_d
[
None
,
:])
dk_ptrs
=
DK
+
(
offs_n
[:,
None
]
*
stride_dkn
+
offs_d
[
None
,
:])
_bwd_store_dk_dv
(
dk_ptrs
,
dv_ptrs
,
dk
,
dv
,
offs_n
,
offs_d
,
seqlen_k
,
headdim
,
_bwd_store_dk_dv
(
EVEN_M
=
EVEN_M
,
EVEN_N
=
EVEN_N
,
EVEN_HEADDIM
=
EVEN_HEADDIM
)
dk_ptrs
,
dv_ptrs
,
dk
,
dv
,
offs_n
,
offs_d
,
seqlen_k
,
headdim
,
EVEN_M
=
EVEN_M
,
EVEN_N
=
EVEN_N
,
EVEN_HEADDIM
=
EVEN_HEADDIM
,
)
def
init_to_zero
(
name
):
def
init_to_zero
(
name
):
...
@@ -485,8 +636,18 @@ def init_to_zero(name):
...
@@ -485,8 +636,18 @@ def init_to_zero(name):
@
triton
.
autotune
(
@
triton
.
autotune
(
configs
=
[
configs
=
[
triton
.
Config
({
"BLOCK_M"
:
128
,
"BLOCK_N"
:
128
,
"SEQUENCE_PARALLEL"
:
False
},
num_warps
=
8
,
num_stages
=
1
,
pre_hook
=
init_to_zero
(
'DQ'
)),
triton
.
Config
(
triton
.
Config
({
"BLOCK_M"
:
128
,
"BLOCK_N"
:
128
,
"SEQUENCE_PARALLEL"
:
True
},
num_warps
=
8
,
num_stages
=
1
,
pre_hook
=
init_to_zero
(
'DQ'
)),
{
"BLOCK_M"
:
128
,
"BLOCK_N"
:
128
,
"SEQUENCE_PARALLEL"
:
False
},
num_warps
=
8
,
num_stages
=
1
,
pre_hook
=
init_to_zero
(
"DQ"
),
),
triton
.
Config
(
{
"BLOCK_M"
:
128
,
"BLOCK_N"
:
128
,
"SEQUENCE_PARALLEL"
:
True
},
num_warps
=
8
,
num_stages
=
1
,
pre_hook
=
init_to_zero
(
"DQ"
),
),
# Other configs seem to give wrong results when seqlen_q % 128 != 0, disabling them for now
# Other configs seem to give wrong results when seqlen_q % 128 != 0, disabling them for now
# # Kernel is buggy (give wrong result) if we set BLOCK_m=128, BLOCK_n=64, num_warps=*4*
# # Kernel is buggy (give wrong result) if we set BLOCK_m=128, BLOCK_n=64, num_warps=*4*
# triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')),
# triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')),
...
@@ -494,7 +655,7 @@ def init_to_zero(name):
...
@@ -494,7 +655,7 @@ def init_to_zero(name):
# triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=4, num_stages=1, pre_hook=init_to_zero('DQ')),
# triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=4, num_stages=1, pre_hook=init_to_zero('DQ')),
# triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=4, num_stages=1, pre_hook=init_to_zero('DQ')),
# triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=4, num_stages=1, pre_hook=init_to_zero('DQ')),
],
],
key
=
[
'
CACHE_KEY_SEQLEN_Q
'
,
'
CACHE_KEY_SEQLEN_K
'
,
'
BIAS_TYPE
'
,
'
IS_CAUSAL
'
,
'
BLOCK_HEADDIM
'
],
key
=
[
"
CACHE_KEY_SEQLEN_Q
"
,
"
CACHE_KEY_SEQLEN_K
"
,
"
BIAS_TYPE
"
,
"
IS_CAUSAL
"
,
"
BLOCK_HEADDIM
"
],
)
)
@
triton
.
heuristics
(
@
triton
.
heuristics
(
{
{
...
@@ -505,26 +666,57 @@ def init_to_zero(name):
...
@@ -505,26 +666,57 @@ def init_to_zero(name):
)
)
@
triton
.
jit
@
triton
.
jit
def
_bwd_kernel
(
def
_bwd_kernel
(
Q
,
K
,
V
,
Bias
,
Q
,
DO
,
DQ
,
DK
,
DV
,
K
,
LSE
,
D
,
V
,
Bias
,
DO
,
DQ
,
DK
,
DV
,
LSE
,
D
,
softmax_scale
,
softmax_scale
,
stride_qb
,
stride_qh
,
stride_qm
,
stride_qb
,
stride_kb
,
stride_kh
,
stride_kn
,
stride_qh
,
stride_vb
,
stride_vh
,
stride_vn
,
stride_qm
,
stride_bb
,
stride_bh
,
stride_bm
,
stride_kb
,
stride_dob
,
stride_doh
,
stride_dom
,
stride_kh
,
stride_dqb
,
stride_dqh
,
stride_dqm
,
stride_kn
,
stride_dkb
,
stride_dkh
,
stride_dkn
,
stride_vb
,
stride_dvb
,
stride_dvh
,
stride_dvn
,
stride_vh
,
nheads
,
seqlen_q
,
seqlen_k
,
seqlen_q_rounded
,
headdim
,
stride_vn
,
CACHE_KEY_SEQLEN_Q
,
CACHE_KEY_SEQLEN_K
,
stride_bb
,
stride_bh
,
stride_bm
,
stride_dob
,
stride_doh
,
stride_dom
,
stride_dqb
,
stride_dqh
,
stride_dqm
,
stride_dkb
,
stride_dkh
,
stride_dkn
,
stride_dvb
,
stride_dvh
,
stride_dvn
,
nheads
,
seqlen_q
,
seqlen_k
,
seqlen_q_rounded
,
headdim
,
CACHE_KEY_SEQLEN_Q
,
CACHE_KEY_SEQLEN_K
,
BIAS_TYPE
:
tl
.
constexpr
,
BIAS_TYPE
:
tl
.
constexpr
,
IS_CAUSAL
:
tl
.
constexpr
,
IS_CAUSAL
:
tl
.
constexpr
,
BLOCK_HEADDIM
:
tl
.
constexpr
,
BLOCK_HEADDIM
:
tl
.
constexpr
,
SEQUENCE_PARALLEL
:
tl
.
constexpr
,
SEQUENCE_PARALLEL
:
tl
.
constexpr
,
EVEN_M
:
tl
.
constexpr
,
EVEN_N
:
tl
.
constexpr
,
EVEN_HEADDIM
:
tl
.
constexpr
,
EVEN_M
:
tl
.
constexpr
,
BLOCK_M
:
tl
.
constexpr
,
BLOCK_N
:
tl
.
constexpr
,
EVEN_N
:
tl
.
constexpr
,
EVEN_HEADDIM
:
tl
.
constexpr
,
BLOCK_M
:
tl
.
constexpr
,
BLOCK_N
:
tl
.
constexpr
,
):
):
off_hb
=
tl
.
program_id
(
1
)
off_hb
=
tl
.
program_id
(
1
)
off_b
=
off_hb
//
nheads
off_b
=
off_hb
//
nheads
...
@@ -537,7 +729,7 @@ def _bwd_kernel(
...
@@ -537,7 +729,7 @@ def _bwd_kernel(
DQ
+=
off_b
*
stride_dqb
+
off_h
*
stride_dqh
DQ
+=
off_b
*
stride_dqb
+
off_h
*
stride_dqh
DK
+=
off_b
*
stride_dkb
+
off_h
*
stride_dkh
DK
+=
off_b
*
stride_dkb
+
off_h
*
stride_dkh
DV
+=
off_b
*
stride_dvb
+
off_h
*
stride_dvh
DV
+=
off_b
*
stride_dvb
+
off_h
*
stride_dvh
if
BIAS_TYPE
!=
'
none
'
:
if
BIAS_TYPE
!=
"
none
"
:
Bias
+=
off_b
*
stride_bb
+
off_h
*
stride_bh
Bias
+=
off_b
*
stride_bb
+
off_h
*
stride_bh
# pointer to row-wise quantities in value-like data
# pointer to row-wise quantities in value-like data
D
+=
off_hb
*
seqlen_q_rounded
D
+=
off_hb
*
seqlen_q_rounded
...
@@ -547,37 +739,73 @@ def _bwd_kernel(
...
@@ -547,37 +739,73 @@ def _bwd_kernel(
for
start_n
in
range
(
0
,
num_block_n
):
for
start_n
in
range
(
0
,
num_block_n
):
_bwd_kernel_one_col_block
(
_bwd_kernel_one_col_block
(
start_n
,
start_n
,
Q
,
K
,
V
,
Bias
,
Q
,
DO
,
DQ
,
DK
,
DV
,
K
,
LSE
,
D
,
V
,
Bias
,
DO
,
DQ
,
DK
,
DV
,
LSE
,
D
,
softmax_scale
,
softmax_scale
,
stride_qm
,
stride_kn
,
stride_vn
,
stride_bm
,
stride_qm
,
stride_dom
,
stride_dqm
,
stride_dkn
,
stride_dvn
,
stride_kn
,
seqlen_q
,
seqlen_k
,
headdim
,
stride_vn
,
stride_bm
,
stride_dom
,
stride_dqm
,
stride_dkn
,
stride_dvn
,
seqlen_q
,
seqlen_k
,
headdim
,
ATOMIC_ADD
=
False
,
ATOMIC_ADD
=
False
,
BIAS_TYPE
=
BIAS_TYPE
,
BIAS_TYPE
=
BIAS_TYPE
,
IS_CAUSAL
=
IS_CAUSAL
,
IS_CAUSAL
=
IS_CAUSAL
,
BLOCK_HEADDIM
=
BLOCK_HEADDIM
,
BLOCK_HEADDIM
=
BLOCK_HEADDIM
,
EVEN_M
=
EVEN_M
,
EVEN_N
=
EVEN_N
,
EVEN_HEADDIM
=
EVEN_HEADDIM
,
EVEN_M
=
EVEN_M
,
BLOCK_M
=
BLOCK_M
,
BLOCK_N
=
BLOCK_N
EVEN_N
=
EVEN_N
,
EVEN_HEADDIM
=
EVEN_HEADDIM
,
BLOCK_M
=
BLOCK_M
,
BLOCK_N
=
BLOCK_N
,
)
)
else
:
else
:
start_n
=
tl
.
program_id
(
0
)
start_n
=
tl
.
program_id
(
0
)
_bwd_kernel_one_col_block
(
_bwd_kernel_one_col_block
(
start_n
,
start_n
,
Q
,
K
,
V
,
Bias
,
Q
,
DO
,
DQ
,
DK
,
DV
,
K
,
LSE
,
D
,
V
,
Bias
,
DO
,
DQ
,
DK
,
DV
,
LSE
,
D
,
softmax_scale
,
softmax_scale
,
stride_qm
,
stride_kn
,
stride_vn
,
stride_bm
,
stride_qm
,
stride_dom
,
stride_dqm
,
stride_dkn
,
stride_dvn
,
stride_kn
,
seqlen_q
,
seqlen_k
,
headdim
,
stride_vn
,
stride_bm
,
stride_dom
,
stride_dqm
,
stride_dkn
,
stride_dvn
,
seqlen_q
,
seqlen_k
,
headdim
,
ATOMIC_ADD
=
True
,
ATOMIC_ADD
=
True
,
BIAS_TYPE
=
BIAS_TYPE
,
BIAS_TYPE
=
BIAS_TYPE
,
IS_CAUSAL
=
IS_CAUSAL
,
IS_CAUSAL
=
IS_CAUSAL
,
BLOCK_HEADDIM
=
BLOCK_HEADDIM
,
BLOCK_HEADDIM
=
BLOCK_HEADDIM
,
EVEN_M
=
EVEN_M
,
EVEN_N
=
EVEN_N
,
EVEN_HEADDIM
=
EVEN_HEADDIM
,
EVEN_M
=
EVEN_M
,
BLOCK_M
=
BLOCK_M
,
BLOCK_N
=
BLOCK_N
EVEN_N
=
EVEN_N
,
EVEN_HEADDIM
=
EVEN_HEADDIM
,
BLOCK_M
=
BLOCK_M
,
BLOCK_N
=
BLOCK_N
,
)
)
...
@@ -587,14 +815,14 @@ def _flash_attn_forward(q, k, v, bias=None, causal=False, softmax_scale=None):
...
@@ -587,14 +815,14 @@ def _flash_attn_forward(q, k, v, bias=None, causal=False, softmax_scale=None):
_
,
seqlen_k
,
_
,
_
=
k
.
shape
_
,
seqlen_k
,
_
,
_
=
k
.
shape
assert
k
.
shape
==
(
batch
,
seqlen_k
,
nheads
,
d
)
assert
k
.
shape
==
(
batch
,
seqlen_k
,
nheads
,
d
)
assert
v
.
shape
==
(
batch
,
seqlen_k
,
nheads
,
d
)
assert
v
.
shape
==
(
batch
,
seqlen_k
,
nheads
,
d
)
assert
d
<=
128
,
'
FlashAttention only support head dimensions up to 128
'
assert
d
<=
128
,
"
FlashAttention only support head dimensions up to 128
"
assert
q
.
dtype
==
k
.
dtype
==
v
.
dtype
,
'
All tensors must have the same type
'
assert
q
.
dtype
==
k
.
dtype
==
v
.
dtype
,
"
All tensors must have the same type
"
assert
q
.
dtype
in
[
torch
.
float16
,
torch
.
bfloat16
],
'
Only support fp16 and bf16
'
assert
q
.
dtype
in
[
torch
.
float16
,
torch
.
bfloat16
],
"
Only support fp16 and bf16
"
assert
q
.
is_cuda
and
k
.
is_cuda
and
v
.
is_cuda
assert
q
.
is_cuda
and
k
.
is_cuda
and
v
.
is_cuda
softmax_scale
=
softmax_scale
or
1.0
/
math
.
sqrt
(
d
)
softmax_scale
=
softmax_scale
or
1.0
/
math
.
sqrt
(
d
)
has_bias
=
bias
is
not
None
has_bias
=
bias
is
not
None
bias_type
=
'
none
'
bias_type
=
"
none
"
if
has_bias
:
if
has_bias
:
assert
bias
.
dtype
in
[
q
.
dtype
,
torch
.
float
]
assert
bias
.
dtype
in
[
q
.
dtype
,
torch
.
float
]
assert
bias
.
is_cuda
assert
bias
.
is_cuda
...
@@ -602,12 +830,13 @@ def _flash_attn_forward(q, k, v, bias=None, causal=False, softmax_scale=None):
...
@@ -602,12 +830,13 @@ def _flash_attn_forward(q, k, v, bias=None, causal=False, softmax_scale=None):
if
bias
.
stride
(
-
1
)
!=
1
:
if
bias
.
stride
(
-
1
)
!=
1
:
bias
=
bias
.
contiguous
()
bias
=
bias
.
contiguous
()
if
bias
.
shape
[
2
:]
==
(
1
,
seqlen_k
):
if
bias
.
shape
[
2
:]
==
(
1
,
seqlen_k
):
bias_type
=
'
vector
'
bias_type
=
"
vector
"
elif
bias
.
shape
[
2
:]
==
(
seqlen_q
,
seqlen_k
):
elif
bias
.
shape
[
2
:]
==
(
seqlen_q
,
seqlen_k
):
bias_type
=
'
matrix
'
bias_type
=
"
matrix
"
else
:
else
:
raise
RuntimeError
(
'Last 2 dimensions of bias must be (1, seqlen_k)'
raise
RuntimeError
(
' or (seqlen_q, seqlen_k)'
)
"Last 2 dimensions of bias must be (1, seqlen_k)"
" or (seqlen_q, seqlen_k)"
)
bias
=
bias
.
expand
(
batch
,
nheads
,
seqlen_q
,
seqlen_k
)
bias
=
bias
.
expand
(
batch
,
nheads
,
seqlen_q
,
seqlen_k
)
bias_strides
=
(
bias
.
stride
(
0
),
bias
.
stride
(
1
),
bias
.
stride
(
2
))
if
has_bias
else
(
0
,
0
,
0
)
bias_strides
=
(
bias
.
stride
(
0
),
bias
.
stride
(
1
),
bias
.
stride
(
2
))
if
has_bias
else
(
0
,
0
,
0
)
...
@@ -621,27 +850,50 @@ def _flash_attn_forward(q, k, v, bias=None, causal=False, softmax_scale=None):
...
@@ -621,27 +850,50 @@ def _flash_attn_forward(q, k, v, bias=None, causal=False, softmax_scale=None):
num_warps
=
4
if
d
<=
64
else
8
num_warps
=
4
if
d
<=
64
else
8
grid
=
lambda
META
:
(
triton
.
cdiv
(
seqlen_q
,
META
[
"BLOCK_M"
]),
batch
*
nheads
)
grid
=
lambda
META
:
(
triton
.
cdiv
(
seqlen_q
,
META
[
"BLOCK_M"
]),
batch
*
nheads
)
_fwd_kernel
[
grid
](
_fwd_kernel
[
grid
](
q
,
k
,
v
,
bias
,
o
,
q
,
lse
,
tmp
,
k
,
v
,
bias
,
o
,
lse
,
tmp
,
softmax_scale
,
softmax_scale
,
q
.
stride
(
0
),
q
.
stride
(
2
),
q
.
stride
(
1
),
q
.
stride
(
0
),
k
.
stride
(
0
),
k
.
stride
(
2
),
k
.
stride
(
1
),
q
.
stride
(
2
),
v
.
stride
(
0
),
v
.
stride
(
2
),
v
.
stride
(
1
),
q
.
stride
(
1
),
k
.
stride
(
0
),
k
.
stride
(
2
),
k
.
stride
(
1
),
v
.
stride
(
0
),
v
.
stride
(
2
),
v
.
stride
(
1
),
*
bias_strides
,
*
bias_strides
,
o
.
stride
(
0
),
o
.
stride
(
2
),
o
.
stride
(
1
),
o
.
stride
(
0
),
nheads
,
seqlen_q
,
seqlen_k
,
seqlen_q_rounded
,
d
,
o
.
stride
(
2
),
seqlen_q
//
32
,
seqlen_k
//
32
,
# key for triton cache (limit number of compilations)
o
.
stride
(
1
),
nheads
,
seqlen_q
,
seqlen_k
,
seqlen_q_rounded
,
d
,
seqlen_q
//
32
,
seqlen_k
//
32
,
# key for triton cache (limit number of compilations)
# Can't use kwargs here because triton autotune expects key to be args, not kwargs
# Can't use kwargs here because triton autotune expects key to be args, not kwargs
# IS_CAUSAL=causal, BLOCK_HEADDIM=d,
# IS_CAUSAL=causal, BLOCK_HEADDIM=d,
bias_type
,
causal
,
BLOCK_HEADDIM
,
bias_type
,
BLOCK_M
=
BLOCK
,
BLOCK_N
=
BLOCK
,
causal
,
BLOCK_HEADDIM
,
BLOCK_M
=
BLOCK
,
BLOCK_N
=
BLOCK
,
num_warps
=
num_warps
,
num_warps
=
num_warps
,
num_stages
=
1
,
num_stages
=
1
,
)
)
return
o
,
lse
,
softmax_scale
# softmax_scale could have been updated
return
o
,
lse
,
softmax_scale
# softmax_scale could have been updated
def
_flash_attn_backward
(
do
,
q
,
k
,
v
,
o
,
lse
,
dq
,
dk
,
dv
,
bias
=
None
,
causal
=
False
,
softmax_scale
=
None
):
def
_flash_attn_backward
(
do
,
q
,
k
,
v
,
o
,
lse
,
dq
,
dk
,
dv
,
bias
=
None
,
causal
=
False
,
softmax_scale
=
None
):
# Make sure that the last dimension is contiguous
# Make sure that the last dimension is contiguous
if
do
.
stride
(
-
1
)
!=
1
:
if
do
.
stride
(
-
1
)
!=
1
:
do
=
do
.
contiguous
()
do
=
do
.
contiguous
()
...
@@ -662,53 +914,94 @@ def _flash_attn_backward(do, q, k, v, o, lse, dq, dk, dv, bias=None, causal=Fals
...
@@ -662,53 +914,94 @@ def _flash_attn_backward(do, q, k, v, o, lse, dq, dk, dv, bias=None, causal=Fals
BLOCK_HEADDIM
=
max
(
triton
.
next_power_of_2
(
d
),
16
)
BLOCK_HEADDIM
=
max
(
triton
.
next_power_of_2
(
d
),
16
)
grid
=
lambda
META
:
(
triton
.
cdiv
(
seqlen_q
,
META
[
"BLOCK_M"
]),
batch
*
nheads
)
grid
=
lambda
META
:
(
triton
.
cdiv
(
seqlen_q
,
META
[
"BLOCK_M"
]),
batch
*
nheads
)
_bwd_preprocess_do_o_dot
[
grid
](
_bwd_preprocess_do_o_dot
[
grid
](
o
,
do
,
delta
,
o
,
o
.
stride
(
0
),
o
.
stride
(
2
),
o
.
stride
(
1
),
do
,
do
.
stride
(
0
),
do
.
stride
(
2
),
do
.
stride
(
1
),
delta
,
nheads
,
seqlen_q
,
seqlen_q_rounded
,
d
,
o
.
stride
(
0
),
BLOCK_M
=
128
,
BLOCK_HEADDIM
=
BLOCK_HEADDIM
,
o
.
stride
(
2
),
o
.
stride
(
1
),
do
.
stride
(
0
),
do
.
stride
(
2
),
do
.
stride
(
1
),
nheads
,
seqlen_q
,
seqlen_q_rounded
,
d
,
BLOCK_M
=
128
,
BLOCK_HEADDIM
=
BLOCK_HEADDIM
,
)
)
has_bias
=
bias
is
not
None
has_bias
=
bias
is
not
None
bias_type
=
'
none
'
bias_type
=
"
none
"
if
has_bias
:
if
has_bias
:
assert
bias
.
dtype
in
[
q
.
dtype
,
torch
.
float
]
assert
bias
.
dtype
in
[
q
.
dtype
,
torch
.
float
]
assert
bias
.
is_cuda
assert
bias
.
is_cuda
assert
bias
.
dim
()
==
4
assert
bias
.
dim
()
==
4
assert
bias
.
stride
(
-
1
)
==
1
assert
bias
.
stride
(
-
1
)
==
1
if
bias
.
shape
[
2
:]
==
(
1
,
seqlen_k
):
if
bias
.
shape
[
2
:]
==
(
1
,
seqlen_k
):
bias_type
=
'
vector
'
bias_type
=
"
vector
"
elif
bias
.
shape
[
2
:]
==
(
seqlen_q
,
seqlen_k
):
elif
bias
.
shape
[
2
:]
==
(
seqlen_q
,
seqlen_k
):
bias_type
=
'
matrix
'
bias_type
=
"
matrix
"
else
:
else
:
raise
RuntimeError
(
'Last 2 dimensions of bias must be (1, seqlen_k)'
raise
RuntimeError
(
' or (seqlen_q, seqlen_k)'
)
"Last 2 dimensions of bias must be (1, seqlen_k)"
" or (seqlen_q, seqlen_k)"
)
bias
=
bias
.
expand
(
batch
,
nheads
,
seqlen_q
,
seqlen_k
)
bias
=
bias
.
expand
(
batch
,
nheads
,
seqlen_q
,
seqlen_k
)
bias_strides
=
(
bias
.
stride
(
0
),
bias
.
stride
(
1
),
bias
.
stride
(
2
))
if
has_bias
else
(
0
,
0
,
0
)
bias_strides
=
(
bias
.
stride
(
0
),
bias
.
stride
(
1
),
bias
.
stride
(
2
))
if
has_bias
else
(
0
,
0
,
0
)
# BLOCK_M = 128
# BLOCK_M = 128
# BLOCK_N = 64
# BLOCK_N = 64
# num_warps = 4
# num_warps = 4
grid
=
lambda
META
:
(
triton
.
cdiv
(
seqlen_k
,
META
[
"BLOCK_N"
])
if
META
[
"SEQUENCE_PARALLEL"
]
else
1
,
grid
=
lambda
META
:
(
batch
*
nheads
)
triton
.
cdiv
(
seqlen_k
,
META
[
"BLOCK_N"
])
if
META
[
"SEQUENCE_PARALLEL"
]
else
1
,
batch
*
nheads
,
)
_bwd_kernel
[
grid
](
_bwd_kernel
[
grid
](
q
,
k
,
v
,
bias
,
q
,
do
,
dq_accum
,
dk
,
dv
,
k
,
lse
,
delta
,
v
,
bias
,
do
,
dq_accum
,
dk
,
dv
,
lse
,
delta
,
softmax_scale
,
softmax_scale
,
q
.
stride
(
0
),
q
.
stride
(
2
),
q
.
stride
(
1
),
q
.
stride
(
0
),
k
.
stride
(
0
),
k
.
stride
(
2
),
k
.
stride
(
1
),
q
.
stride
(
2
),
v
.
stride
(
0
),
v
.
stride
(
2
),
v
.
stride
(
1
),
q
.
stride
(
1
),
k
.
stride
(
0
),
k
.
stride
(
2
),
k
.
stride
(
1
),
v
.
stride
(
0
),
v
.
stride
(
2
),
v
.
stride
(
1
),
*
bias_strides
,
*
bias_strides
,
do
.
stride
(
0
),
do
.
stride
(
2
),
do
.
stride
(
1
),
do
.
stride
(
0
),
dq_accum
.
stride
(
0
),
dq_accum
.
stride
(
2
),
dq_accum
.
stride
(
1
),
do
.
stride
(
2
),
dk
.
stride
(
0
),
dk
.
stride
(
2
),
dk
.
stride
(
1
),
do
.
stride
(
1
),
dv
.
stride
(
0
),
dv
.
stride
(
2
),
dv
.
stride
(
1
),
dq_accum
.
stride
(
0
),
nheads
,
seqlen_q
,
seqlen_k
,
seqlen_q_rounded
,
d
,
dq_accum
.
stride
(
2
),
seqlen_q
//
32
,
seqlen_k
//
32
,
# key for triton cache (limit number of compilations)
dq_accum
.
stride
(
1
),
dk
.
stride
(
0
),
dk
.
stride
(
2
),
dk
.
stride
(
1
),
dv
.
stride
(
0
),
dv
.
stride
(
2
),
dv
.
stride
(
1
),
nheads
,
seqlen_q
,
seqlen_k
,
seqlen_q_rounded
,
d
,
seqlen_q
//
32
,
seqlen_k
//
32
,
# key for triton cache (limit number of compilations)
# Can't use kwargs here because triton autotune expects key to be args, not kwargs
# Can't use kwargs here because triton autotune expects key to be args, not kwargs
# IS_CAUSAL=causal, BLOCK_HEADDIM=d,
# IS_CAUSAL=causal, BLOCK_HEADDIM=d,
bias_type
,
causal
,
BLOCK_HEADDIM
,
bias_type
,
causal
,
BLOCK_HEADDIM
,
# SEQUENCE_PARALLEL=False,
# SEQUENCE_PARALLEL=False,
# BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N,
# BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N,
# num_warps=num_warps,
# num_warps=num_warps,
...
@@ -718,7 +1011,6 @@ def _flash_attn_backward(do, q, k, v, o, lse, dq, dk, dv, bias=None, causal=Fals
...
@@ -718,7 +1011,6 @@ def _flash_attn_backward(do, q, k, v, o, lse, dq, dk, dv, bias=None, causal=Fals
class
FlashAttnQKVPackedFunc
(
torch
.
autograd
.
Function
):
class
FlashAttnQKVPackedFunc
(
torch
.
autograd
.
Function
):
@
staticmethod
@
staticmethod
def
forward
(
ctx
,
qkv
,
bias
=
None
,
causal
=
False
,
softmax_scale
=
None
):
def
forward
(
ctx
,
qkv
,
bias
=
None
,
causal
=
False
,
softmax_scale
=
None
):
"""
"""
...
@@ -731,8 +1023,12 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
...
@@ -731,8 +1023,12 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
if
qkv
.
stride
(
-
1
)
!=
1
:
if
qkv
.
stride
(
-
1
)
!=
1
:
qkv
=
qkv
.
contiguous
()
qkv
=
qkv
.
contiguous
()
o
,
lse
,
ctx
.
softmax_scale
=
_flash_attn_forward
(
o
,
lse
,
ctx
.
softmax_scale
=
_flash_attn_forward
(
qkv
[:,
:,
0
],
qkv
[:,
:,
1
],
qkv
[:,
:,
2
],
bias
=
bias
,
causal
=
causal
,
qkv
[:,
:,
0
],
softmax_scale
=
softmax_scale
qkv
[:,
:,
1
],
qkv
[:,
:,
2
],
bias
=
bias
,
causal
=
causal
,
softmax_scale
=
softmax_scale
,
)
)
ctx
.
save_for_backward
(
qkv
,
o
,
lse
,
bias
)
ctx
.
save_for_backward
(
qkv
,
o
,
lse
,
bias
)
ctx
.
causal
=
causal
ctx
.
causal
=
causal
...
@@ -741,14 +1037,25 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
...
@@ -741,14 +1037,25 @@ class FlashAttnQKVPackedFunc(torch.autograd.Function):
@
staticmethod
@
staticmethod
def
backward
(
ctx
,
do
):
def
backward
(
ctx
,
do
):
qkv
,
o
,
lse
,
bias
=
ctx
.
saved_tensors
qkv
,
o
,
lse
,
bias
=
ctx
.
saved_tensors
assert
not
ctx
.
needs_input_grad
[
1
],
'
FlashAttention does not support bias gradient yet
'
assert
not
ctx
.
needs_input_grad
[
1
],
"
FlashAttention does not support bias gradient yet
"
# Triton's autotune causes the Tensor._version to change, and so Pytorch autograd
# Triton's autotune causes the Tensor._version to change, and so Pytorch autograd
# does a memcpy. To avoid this we run in inference_mode, which doesn't track the version.
# does a memcpy. To avoid this we run in inference_mode, which doesn't track the version.
with
torch
.
inference_mode
():
with
torch
.
inference_mode
():
dqkv
=
torch
.
empty_like
(
qkv
)
dqkv
=
torch
.
empty_like
(
qkv
)
_flash_attn_backward
(
do
,
qkv
[:,
:,
0
],
qkv
[:,
:,
1
],
qkv
[:,
:,
2
],
o
,
lse
,
_flash_attn_backward
(
dqkv
[:,
:,
0
],
dqkv
[:,
:,
1
],
dqkv
[:,
:,
2
],
do
,
bias
=
bias
,
causal
=
ctx
.
causal
,
softmax_scale
=
ctx
.
softmax_scale
)
qkv
[:,
:,
0
],
qkv
[:,
:,
1
],
qkv
[:,
:,
2
],
o
,
lse
,
dqkv
[:,
:,
0
],
dqkv
[:,
:,
1
],
dqkv
[:,
:,
2
],
bias
=
bias
,
causal
=
ctx
.
causal
,
softmax_scale
=
ctx
.
softmax_scale
,
)
return
dqkv
,
None
,
None
,
None
return
dqkv
,
None
,
None
,
None
...
@@ -756,7 +1063,6 @@ flash_attn_qkvpacked_func = FlashAttnQKVPackedFunc.apply
...
@@ -756,7 +1063,6 @@ flash_attn_qkvpacked_func = FlashAttnQKVPackedFunc.apply
class
FlashAttnKVPackedFunc
(
torch
.
autograd
.
Function
):
class
FlashAttnKVPackedFunc
(
torch
.
autograd
.
Function
):
@
staticmethod
@
staticmethod
def
forward
(
ctx
,
q
,
kv
,
bias
=
None
,
causal
=
False
,
softmax_scale
=
None
):
def
forward
(
ctx
,
q
,
kv
,
bias
=
None
,
causal
=
False
,
softmax_scale
=
None
):
"""
"""
...
@@ -779,15 +1085,26 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
...
@@ -779,15 +1085,26 @@ class FlashAttnKVPackedFunc(torch.autograd.Function):
def
backward
(
ctx
,
do
):
def
backward
(
ctx
,
do
):
q
,
kv
,
o
,
lse
,
bias
=
ctx
.
saved_tensors
q
,
kv
,
o
,
lse
,
bias
=
ctx
.
saved_tensors
if
len
(
ctx
.
needs_input_grad
)
>=
3
:
if
len
(
ctx
.
needs_input_grad
)
>=
3
:
assert
not
ctx
.
needs_input_grad
[
2
],
'
FlashAttention does not support bias gradient yet
'
assert
not
ctx
.
needs_input_grad
[
2
],
"
FlashAttention does not support bias gradient yet
"
# Triton's autotune causes the Tensor._version to change, and so Pytorch autograd
# Triton's autotune causes the Tensor._version to change, and so Pytorch autograd
# does a memcpy. To avoid this we run in inference_mode, which doesn't track the version.
# does a memcpy. To avoid this we run in inference_mode, which doesn't track the version.
with
torch
.
inference_mode
():
with
torch
.
inference_mode
():
dq
=
torch
.
empty_like
(
q
)
dq
=
torch
.
empty_like
(
q
)
dkv
=
torch
.
empty_like
(
kv
)
dkv
=
torch
.
empty_like
(
kv
)
_flash_attn_backward
(
do
,
q
,
kv
[:,
:,
0
],
kv
[:,
:,
1
],
o
,
lse
,
_flash_attn_backward
(
dq
,
dkv
[:,
:,
0
],
dkv
[:,
:,
1
],
do
,
bias
=
bias
,
causal
=
ctx
.
causal
,
softmax_scale
=
ctx
.
softmax_scale
)
q
,
kv
[:,
:,
0
],
kv
[:,
:,
1
],
o
,
lse
,
dq
,
dkv
[:,
:,
0
],
dkv
[:,
:,
1
],
bias
=
bias
,
causal
=
ctx
.
causal
,
softmax_scale
=
ctx
.
softmax_scale
,
)
return
dq
,
dkv
,
None
,
None
,
None
return
dq
,
dkv
,
None
,
None
,
None
...
@@ -795,7 +1112,6 @@ flash_attn_kvpacked_func = FlashAttnKVPackedFunc.apply
...
@@ -795,7 +1112,6 @@ flash_attn_kvpacked_func = FlashAttnKVPackedFunc.apply
class
FlashAttnFunc
(
torch
.
autograd
.
Function
):
class
FlashAttnFunc
(
torch
.
autograd
.
Function
):
@
staticmethod
@
staticmethod
def
forward
(
ctx
,
q
,
k
,
v
,
bias
=
None
,
causal
=
False
,
softmax_scale
=
None
):
def
forward
(
ctx
,
q
,
k
,
v
,
bias
=
None
,
causal
=
False
,
softmax_scale
=
None
):
"""
"""
...
@@ -817,15 +1133,27 @@ class FlashAttnFunc(torch.autograd.Function):
...
@@ -817,15 +1133,27 @@ class FlashAttnFunc(torch.autograd.Function):
@
staticmethod
@
staticmethod
def
backward
(
ctx
,
do
):
def
backward
(
ctx
,
do
):
q
,
k
,
v
,
o
,
lse
,
bias
=
ctx
.
saved_tensors
q
,
k
,
v
,
o
,
lse
,
bias
=
ctx
.
saved_tensors
assert
not
ctx
.
needs_input_grad
[
3
],
'
FlashAttention does not support bias gradient yet
'
assert
not
ctx
.
needs_input_grad
[
3
],
"
FlashAttention does not support bias gradient yet
"
# Triton's autotune causes the Tensor._version to change, and so Pytorch autograd
# Triton's autotune causes the Tensor._version to change, and so Pytorch autograd
# does a memcpy. To avoid this we run in inference_mode, which doesn't track the version.
# does a memcpy. To avoid this we run in inference_mode, which doesn't track the version.
with
torch
.
inference_mode
():
with
torch
.
inference_mode
():
dq
=
torch
.
empty_like
(
q
)
dq
=
torch
.
empty_like
(
q
)
dk
=
torch
.
empty_like
(
k
)
dk
=
torch
.
empty_like
(
k
)
dv
=
torch
.
empty_like
(
v
)
dv
=
torch
.
empty_like
(
v
)
_flash_attn_backward
(
do
,
q
,
k
,
v
,
o
,
lse
,
dq
,
dk
,
dv
,
_flash_attn_backward
(
bias
=
bias
,
causal
=
ctx
.
causal
,
softmax_scale
=
ctx
.
softmax_scale
)
do
,
q
,
k
,
v
,
o
,
lse
,
dq
,
dk
,
dv
,
bias
=
bias
,
causal
=
ctx
.
causal
,
softmax_scale
=
ctx
.
softmax_scale
,
)
return
dq
,
dk
,
dv
,
None
,
None
,
None
return
dq
,
dk
,
dv
,
None
,
None
,
None
...
...
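For orientation, a minimal usage sketch of the qkv-packed Triton interface shown above. This is illustrative only and not part of the commit; it assumes flash-attn is installed with a compatible Triton version and that a CUDA GPU is available. The shape and dtype constraints (packed (batch, seqlen, 3, nheads, headdim), fp16/bf16, headdim <= 128) come from the asserts in _flash_attn_forward above.

# Hedged sketch: exercising FlashAttnQKVPackedFunc via its module-level alias.
import torch

from flash_attn.flash_attn_triton import flash_attn_qkvpacked_func

batch, seqlen, nheads, headdim = 2, 1024, 16, 64  # headdim must be <= 128
# Packed qkv: (batch, seqlen, 3, nheads, headdim), fp16 or bf16, on CUDA.
qkv = torch.randn(
    batch, seqlen, 3, nheads, headdim, device="cuda", dtype=torch.float16, requires_grad=True
)
# autograd.Function.apply takes positional args: (qkv, bias, causal, softmax_scale)
out = flash_attn_qkvpacked_func(qkv, None, True, None)
out.sum().backward()  # dqkv is produced by FlashAttnQKVPackedFunc.backward above
print(out.shape)  # (batch, seqlen, nheads, headdim)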
flash_attn/flash_attn_triton_og.py
View file @
f1a73d07
...
@@ -11,22 +11,41 @@ This is a Triton implementation of the Flash Attention algorithm
 import pytest
 import torch
 import triton
 import triton.language as tl


 @triton.jit
 def _fwd_kernel(
-    Q, K, V, sm_scale,
-    TMP, L, M,  # NOTE: TMP is a scratchpad buffer to workaround a compiler bug
+    Q,
+    K,
+    V,
+    sm_scale,
+    TMP,
+    L,
+    M,  # NOTE: TMP is a scratchpad buffer to workaround a compiler bug
     Out,
-    stride_qz, stride_qh, stride_qm, stride_qk,
-    stride_kz, stride_kh, stride_kn, stride_kk,
-    stride_vz, stride_vh, stride_vk, stride_vn,
-    stride_oz, stride_oh, stride_om, stride_on,
-    Z, H, N_CTX,
-    BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
+    stride_qz,
+    stride_qh,
+    stride_qm,
+    stride_qk,
+    stride_kz,
+    stride_kh,
+    stride_kn,
+    stride_kk,
+    stride_vz,
+    stride_vh,
+    stride_vk,
+    stride_vn,
+    stride_oz,
+    stride_oh,
+    stride_om,
+    stride_on,
+    Z,
+    H,
+    N_CTX,
+    BLOCK_M: tl.constexpr,
+    BLOCK_DMODEL: tl.constexpr,
     BLOCK_N: tl.constexpr,
 ):
     start_m = tl.program_id(0)
...
@@ -100,9 +119,13 @@ def _fwd_kernel(
 @triton.jit
 def _bwd_preprocess(
-    Out, DO, L,
-    NewDO, Delta,
-    BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr,
+    Out,
+    DO,
+    L,
+    NewDO,
+    Delta,
+    BLOCK_M: tl.constexpr,
+    D_HEAD: tl.constexpr,
 ):
     off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
     off_n = tl.arange(0, D_HEAD)
...
@@ -120,16 +143,36 @@ def _bwd_preprocess(
 @triton.jit
 def _bwd_kernel(
-    Q, K, V, sm_scale, Out, DO,
-    DQ, DK, DV,
-    L, M,
+    Q,
+    K,
+    V,
+    sm_scale,
+    Out,
+    DO,
+    DQ,
+    DK,
+    DV,
+    L,
+    M,
     D,
-    stride_qz, stride_qh, stride_qm, stride_qk,
-    stride_kz, stride_kh, stride_kn, stride_kk,
-    stride_vz, stride_vh, stride_vk, stride_vn,
-    Z, H, N_CTX,
+    stride_qz,
+    stride_qh,
+    stride_qm,
+    stride_qk,
+    stride_kz,
+    stride_kh,
+    stride_kn,
+    stride_kk,
+    stride_vz,
+    stride_vh,
+    stride_vk,
+    stride_vn,
+    Z,
+    H,
+    N_CTX,
     num_block,
-    BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
+    BLOCK_M: tl.constexpr,
+    BLOCK_DMODEL: tl.constexpr,
     BLOCK_N: tl.constexpr,
 ):
     off_hz = tl.program_id(0)
...
@@ -203,7 +246,6 @@ def _bwd_kernel(
 class _attention(torch.autograd.Function):
     @staticmethod
     def forward(ctx, q, k, v, sm_scale):
         BLOCK = 128
...
@@ -213,22 +255,45 @@ class _attention(torch.autograd.Function):
         assert Lk in {16, 32, 64, 128}
         o = torch.empty_like(q)
         grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1])
-        tmp = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
+        tmp = torch.empty(
+            (q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32
+        )
         L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
         m = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
         num_warps = 4 if Lk <= 64 else 8

         _fwd_kernel[grid](
-            q, k, v, sm_scale,
-            tmp, L, m,
+            q,
+            k,
+            v,
+            sm_scale,
+            tmp,
+            L,
+            m,
             o,
-            q.stride(0), q.stride(1), q.stride(2), q.stride(3),
-            k.stride(0), k.stride(1), k.stride(2), k.stride(3),
-            v.stride(0), v.stride(1), v.stride(2), v.stride(3),
-            o.stride(0), o.stride(1), o.stride(2), o.stride(3),
-            q.shape[0], q.shape[1], q.shape[2],
-            BLOCK_M=BLOCK, BLOCK_N=BLOCK,
-            BLOCK_DMODEL=Lk, num_warps=num_warps,
+            q.stride(0),
+            q.stride(1),
+            q.stride(2),
+            q.stride(3),
+            k.stride(0),
+            k.stride(1),
+            k.stride(2),
+            k.stride(3),
+            v.stride(0),
+            v.stride(1),
+            v.stride(2),
+            v.stride(3),
+            o.stride(0),
+            o.stride(1),
+            o.stride(2),
+            o.stride(3),
+            q.shape[0],
+            q.shape[1],
+            q.shape[2],
+            BLOCK_M=BLOCK,
+            BLOCK_N=BLOCK,
+            BLOCK_DMODEL=Lk,
+            num_warps=num_warps,
             num_stages=1,
         )

         ctx.save_for_backward(q, k, v, o, L, m)
...
@@ -247,27 +312,51 @@ class _attention(torch.autograd.Function):
         dv = torch.empty_like(v)
         do_scaled = torch.empty_like(do)
         delta = torch.empty_like(l)
-        _bwd_preprocess[(ctx.grid[0] * ctx.grid[1], )](
-            o, do, l,
-            do_scaled, delta,
-            BLOCK_M=ctx.BLOCK, D_HEAD=ctx.BLOCK_DMODEL,
+        _bwd_preprocess[(ctx.grid[0] * ctx.grid[1],)](
+            o,
+            do,
+            l,
+            do_scaled,
+            delta,
+            BLOCK_M=ctx.BLOCK,
+            D_HEAD=ctx.BLOCK_DMODEL,
         )
         # NOTE: kernel currently buggy for other values of `num_warps`
         num_warps = 8
         _bwd_kernel[(ctx.grid[1],)](
-            q, k, v, ctx.sm_scale,
-            o, do_scaled,
-            dq, dk, dv,
-            l, m,
+            q,
+            k,
+            v,
+            ctx.sm_scale,
+            o,
+            do_scaled,
+            dq,
+            dk,
+            dv,
+            l,
+            m,
             delta,
-            q.stride(0), q.stride(1), q.stride(2), q.stride(3),
-            k.stride(0), k.stride(1), k.stride(2), k.stride(3),
-            v.stride(0), v.stride(1), v.stride(2), v.stride(3),
-            q.shape[0], q.shape[1], q.shape[2],
+            q.stride(0),
+            q.stride(1),
+            q.stride(2),
+            q.stride(3),
+            k.stride(0),
+            k.stride(1),
+            k.stride(2),
+            k.stride(3),
+            v.stride(0),
+            v.stride(1),
+            v.stride(2),
+            v.stride(3),
+            q.shape[0],
+            q.shape[1],
+            q.shape[2],
             ctx.grid[0],
-            BLOCK_M=ctx.BLOCK, BLOCK_N=ctx.BLOCK,
-            BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=num_warps,
+            BLOCK_M=ctx.BLOCK,
+            BLOCK_N=ctx.BLOCK,
+            BLOCK_DMODEL=ctx.BLOCK_DMODEL,
+            num_warps=num_warps,
             num_stages=1,
         )
         return dq.to(q.dtype), dk, dv, None
...
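A similar hedged sketch for the _attention autograd Function in this file, illustrative only: it assumes a CUDA GPU, fp16 inputs in the kernel's (Z, H, N_CTX, D) layout, D in {16, 32, 64, 128}, and a sequence length that is a multiple of the 128-token block used above.

# Hedged sketch: calling the "og" Triton attention Function directly.
import torch

from flash_attn.flash_attn_triton_og import _attention

Z, H, N_CTX, D = 2, 8, 1024, 64  # batch, heads, seqlen (multiple of 128), head dim
q, k, v = (
    torch.randn(Z, H, N_CTX, D, device="cuda", dtype=torch.float16, requires_grad=True)
    for _ in range(3)
)
sm_scale = D ** -0.5
out = _attention.apply(q, k, v, sm_scale)
out.sum().backward()  # dq, dk, dv come from the _bwd_preprocess/_bwd_kernel launches above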
flash_attn/flash_blocksparse_attention.py
View file @
f1a73d07
 import math
+
+import hydra
 import torch
 import torch.nn as nn
 from einops import rearrange
-import hydra

-from flash_attn.flash_blocksparse_attn_interface import flash_blocksparse_attn_func
-from flash_attn.flash_blocksparse_attn_interface import convert_blockmask
-from flash_attn.bert_padding import unpad_input, pad_input, index_first_axis
+from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input
+from flash_attn.flash_blocksparse_attn_interface import (
+    convert_blockmask,
+    flash_blocksparse_attn_func,
+)


 class FlashBlocksparseAttention(nn.Module):
...
@@ -21,8 +22,16 @@ class FlashBlocksparseAttention(nn.Module):
         attention_dropout: The dropout rate to apply to the attention
                            (default: 0.1)
     """
-    def __init__(self, sparsity_config, softmax_temp=None, attention_dropout=0.0,
-                 max_seq_length=2048, device=None, dtype=None):
+
+    def __init__(
+        self,
+        sparsity_config,
+        softmax_temp=None,
+        attention_dropout=0.0,
+        max_seq_length=2048,
+        device=None,
+        dtype=None,
+    ):
         super().__init__()
         self.sparsity_config = hydra.utils.instantiate(sparsity_config)
         self.softmax_temp = softmax_temp
...
@@ -36,8 +45,17 @@ class FlashBlocksparseAttention(nn.Module):
         self.register_buffer("blockmask_converted", blockmask_converted)
         # logger.info(f'Attention class {self.__class__}: saving={self.layout.float().mean()}')

-    def forward(self, qkv, attn_mask=None, key_padding_mask=None, causal=False, cu_seqlens=None,
-                max_s=None, need_weights=False, convert_mask=True):
+    def forward(
+        self,
+        qkv,
+        attn_mask=None,
+        key_padding_mask=None,
+        causal=False,
+        cu_seqlens=None,
+        max_s=None,
+        need_weights=False,
+        convert_mask=True,
+    ):
         """Implements the multihead softmax attention.
         Arguments
         ---------
...
@@ -57,47 +75,76 @@ class FlashBlocksparseAttention(nn.Module):
             seqlen = qkv.shape[1]
             # Convert mask to take a subset
             seqlen_rounded = ((seqlen + 256 - 1) // 256) * 256
-            assert seqlen_rounded // 16 <= self.layout.shape[0], seqlen_rounded // 256 <= self.layout.shape[1]
+            assert seqlen_rounded // 16 <= self.layout.shape[0], (
+                seqlen_rounded // 256 <= self.layout.shape[1]
+            )
             blockmask = self.layout[: seqlen_rounded // 16, : seqlen_rounded // 256]
             if key_padding_mask is None:
-                qkv = rearrange(qkv, 'b s ... -> (b s) ...')
+                qkv = rearrange(qkv, "b s ... -> (b s) ...")
                 max_s = seqlen
-                cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
-                                          device=qkv.device)
+                cu_seqlens = torch.arange(
+                    0,
+                    (batch_size + 1) * seqlen,
+                    step=seqlen,
+                    dtype=torch.int32,
+                    device=qkv.device,
+                )
                 output = flash_blocksparse_attn_func(
-                    qkv, cu_seqlens, blockmask, self.dropout_p if self.training else 0.0,
-                    max_s, softmax_scale=self.softmax_temp, causal=causal
+                    qkv,
+                    cu_seqlens,
+                    blockmask,
+                    self.dropout_p if self.training else 0.0,
+                    max_s,
+                    softmax_scale=self.softmax_temp,
+                    causal=causal,
                 )
-                output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
+                output = rearrange(output, "(b s) ... -> b s ...", b=batch_size)
             else:
                 key_padding_mask_bool = key_padding_mask.bool_matrix
                 nheads = qkv.shape[-2]
-                x = rearrange(qkv, 'b s three h d -> b s (three h d)')
+                x = rearrange(qkv, "b s three h d -> b s (three h d)")
                 x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask_bool)
-                x_unpad = rearrange(x_unpad, 'nnz (three h d) -> nnz three h d', three=3, h=nheads)
+                x_unpad = rearrange(x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads)
                 output_unpad = flash_blocksparse_attn_func(
-                    x_unpad, cu_seqlens, blockmask, self.dropout_p if self.training else 0.0,
-                    max_s, softmax_scale=self.softmax_temp, causal=causal
+                    x_unpad,
+                    cu_seqlens,
+                    blockmask,
+                    self.dropout_p if self.training else 0.0,
+                    max_s,
+                    softmax_scale=self.softmax_temp,
+                    causal=causal,
                 )
-                output = rearrange(pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'),
-                                             indices, batch_size, seqlen),
-                                   'b s (h d) -> b s h d', h=nheads)
+                output = rearrange(
+                    pad_input(
+                        rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, batch_size, seqlen
+                    ),
+                    "b s (h d) -> b s h d",
+                    h=nheads,
+                )
         else:
             assert max_s is not None
             seqlen = max_s
             # Convert mask to take a subset
             seqlen_rounded = ((seqlen + 256 - 1) // 256) * 256
-            assert seqlen_rounded // 16 <= self.layout.shape[0], seqlen_rounded // 256 <= self.layout.shape[1]
+            assert seqlen_rounded // 16 <= self.layout.shape[0], (
+                seqlen_rounded // 256 <= self.layout.shape[1]
+            )
             blockmask = self.layout[: seqlen_rounded // 16, : seqlen_rounded // 256]
             if convert_mask:
                 output = flash_blocksparse_attn_func(
-                    qkv, cu_seqlens, blockmask, self.dropout_p if self.training else 0.0,
-                    max_s, softmax_scale=self.softmax_temp, causal=causal
+                    qkv,
+                    cu_seqlens,
+                    blockmask,
+                    self.dropout_p if self.training else 0.0,
+                    max_s,
+                    softmax_scale=self.softmax_temp,
+                    causal=causal,
                 )
             else:
                 output = flash_blocksparse_attn_func(
-                    qkv, cu_seqlens, self.blockmask_converted, self.dropout_p if self.training else 0.0,
-                    max_s, softmax_scale=self.softmax_temp, causal=causal,
-                    convert_mask=False,
+                    qkv,
+                    cu_seqlens,
+                    self.blockmask_converted,
+                    self.dropout_p if self.training else 0.0,
+                    max_s,
+                    softmax_scale=self.softmax_temp,
+                    causal=causal,
+                    convert_mask=False,
                 )
...
@@ -105,12 +152,22 @@ class FlashBlocksparseAttention(nn.Module):
 class FlashBlocksparseMHA(nn.Module):
-    def __init__(self, embed_dim, num_heads, sparsity_config, bias=True, batch_first=True,
-                 attention_dropout=0.0, causal=False, max_seq_length=2048,
-                 device=None, dtype=None, **kwargs) -> None:
+    def __init__(
+        self,
+        embed_dim,
+        num_heads,
+        sparsity_config,
+        bias=True,
+        batch_first=True,
+        attention_dropout=0.0,
+        causal=False,
+        max_seq_length=2048,
+        device=None,
+        dtype=None,
+        **kwargs,
+    ) -> None:
         assert batch_first
-        factory_kwargs = {'device': device, 'dtype': dtype}
+        factory_kwargs = {"device": device, "dtype": dtype}
         super().__init__()
         self.embed_dim = embed_dim
         self.causal = causal
...
@@ -122,15 +179,19 @@ class FlashBlocksparseMHA(nn.Module):
         self.Wqkv = nn.Linear(embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs)
         self.inner_attn = FlashBlocksparseAttention(
-            sparsity_config, attention_dropout=attention_dropout,
-            max_seq_length=max_seq_length, **factory_kwargs
+            sparsity_config,
+            attention_dropout=attention_dropout,
+            max_seq_length=max_seq_length,
+            **factory_kwargs,
         )
         self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs)

-    def forward(self, x, x_ignored_, x_ignored_1_, attn_mask=None, key_padding_mask=None,
-                need_weights=False):
+    def forward(
+        self, x, x_ignored_, x_ignored_1_, attn_mask=None, key_padding_mask=None, need_weights=False
+    ):
         qkv = self.Wqkv(x)
-        qkv = rearrange(qkv, 'b s (three h d) -> b s three h d', three=3, h=self.num_heads)
-        context, attn_weights = self.inner_attn(qkv, key_padding_mask=key_padding_mask,
-                                                need_weights=need_weights, causal=self.causal)
-        return self.out_proj(rearrange(context, 'b s h d -> b s (h d)')), attn_weights
+        qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, h=self.num_heads)
+        context, attn_weights = self.inner_attn(
+            qkv, key_padding_mask=key_padding_mask, need_weights=need_weights, causal=self.causal
+        )
+        return self.out_proj(rearrange(context, "b s h d -> b s (h d)")), attn_weights
flash_attn/flash_blocksparse_attn_interface.py
View file @
f1a73d07
 # Adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/fmha.py
+import flash_attn_cuda
 import torch
 import torch.nn as nn

-import flash_attn_cuda


 def convert_blockmask(blockmask, causal):
     """Convert from the 0-1 format to the format used by the CUDA code.
...
@@ -40,29 +39,51 @@ def convert_blockmask(blockmask, causal):
     return nonzero_idx.T.contiguous().to(dtype=torch.int32)


-def _flash_blocksparse_attn_forward(qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale,
-                                    causal, return_softmax):
-    context, softmax_lse, *rest = flash_attn_cuda.fwd_block(qkv, cu_seqlens, blockmask, dropout_p,
-                                                            max_s, softmax_scale, causal,
-                                                            return_softmax, None)
+def _flash_blocksparse_attn_forward(
+    qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale, causal, return_softmax
+):
+    context, softmax_lse, *rest = flash_attn_cuda.fwd_block(
+        qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale, causal, return_softmax, None
+    )
     # if context.isnan().any() or softmax_lse.isnan().any():
     #     breakpoint()
     S_dmask = rest[0] if return_softmax else None
     return context, softmax_lse, S_dmask


-def _flash_blocksparse_attn_backward(dout, qkv, out, S_dmask, softmax_lse, cu_seqlens, blockmask,
-                                     dropout_p, max_s, softmax_scale, causal):
-    dqkv, dp, softmax_d = flash_attn_cuda.bwd_block(dout, qkv, out, S_dmask, softmax_lse, cu_seqlens,
-                                                    blockmask, dropout_p, softmax_scale, max_s,
-                                                    causal, None)
+def _flash_blocksparse_attn_backward(
+    dout,
+    qkv,
+    out,
+    S_dmask,
+    softmax_lse,
+    cu_seqlens,
+    blockmask,
+    dropout_p,
+    max_s,
+    softmax_scale,
+    causal,
+):
+    dqkv, dp, softmax_d = flash_attn_cuda.bwd_block(
+        dout,
+        qkv,
+        out,
+        S_dmask,
+        softmax_lse,
+        cu_seqlens,
+        blockmask,
+        dropout_p,
+        softmax_scale,
+        max_s,
+        causal,
+        None,
+    )
     # if dqkv.isnan().any() or softmax_d.isnan().any():
     #     breakpoint()
     return dqkv


 class FlashBlocksparseAttnFun(torch.autograd.Function):
     @staticmethod
     def forward(ctx, qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale, causal):
         # Save rng_state because the backward pass will regenerate the dropout mask
...
@@ -70,8 +91,14 @@ class FlashBlocksparseAttnFun(torch.autograd.Function):
         if softmax_scale is None:
             softmax_scale = qkv.shape[-1] ** (-0.5)
         context, softmax_lse, S_dmask = _flash_blocksparse_attn_forward(
-            qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale, causal=causal,
-            return_softmax=False
+            qkv,
+            cu_seqlens,
+            blockmask,
+            dropout_p,
+            max_s,
+            softmax_scale,
+            causal=causal,
+            return_softmax=False,
         )
         ctx.save_for_backward(qkv, context, S_dmask, softmax_lse, cu_seqlens, blockmask, rng_state)
         ctx.dropout_p = dropout_p
...
@@ -88,8 +115,17 @@ class FlashBlocksparseAttnFun(torch.autograd.Function):
             torch.cuda.set_rng_state(rng_state)
         # S_dmask is None, temporarily use another tensor just to get it running
         dqkv = _flash_blocksparse_attn_backward(
-            dout, qkv, context, context, softmax_lse, cu_seqlens, blockmask, ctx.dropout_p,
-            ctx.max_s, ctx.softmax_scale, ctx.causal
+            dout,
+            qkv,
+            context,
+            context,
+            softmax_lse,
+            cu_seqlens,
+            blockmask,
+            ctx.dropout_p,
+            ctx.max_s,
+            ctx.softmax_scale,
+            ctx.causal,
         )
         if rng_state is not None:
             torch.cuda.set_rng_state(cur_rng_state)
...
@@ -99,7 +135,6 @@ class FlashBlocksparseAttnFun(torch.autograd.Function):
 # We duplicate code to return both the output and the softmax for testing
 # Returning both makes backward a bit slower, so we want to keep using the other version for speed.
 class FlashBlocksparseAttnFunWithS(torch.autograd.Function):
     @staticmethod
     def forward(ctx, qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale, causal):
         # Save rng_state because the backward pass is gonna regenerate the dropout mask
...
@@ -107,8 +142,14 @@ class FlashBlocksparseAttnFunWithS(torch.autograd.Function):
         if softmax_scale is None:
             softmax_scale = qkv.shape[-1] ** (-0.5)
         context, softmax_lse, S_dmask = _flash_blocksparse_attn_forward(
-            qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale, causal=causal,
-            return_softmax=True
+            qkv,
+            cu_seqlens,
+            blockmask,
+            dropout_p,
+            max_s,
+            softmax_scale,
+            causal=causal,
+            return_softmax=True,
         )
         ctx.save_for_backward(qkv, context, S_dmask, softmax_lse, cu_seqlens, blockmask, rng_state)
         ctx.dropout_p = dropout_p
...
@@ -124,18 +165,35 @@ class FlashBlocksparseAttnFunWithS(torch.autograd.Function):
         cur_rng_state = torch.cuda.get_rng_state()
         torch.cuda.set_rng_state(rng_state)
         dqkv = _flash_blocksparse_attn_backward(
-            dout, qkv, context, S_dmask, softmax_lse, cu_seqlens, blockmask, ctx.dropout_p,
-            ctx.max_s, ctx.softmax_scale, ctx.causal
+            dout,
+            qkv,
+            context,
+            S_dmask,
+            softmax_lse,
+            cu_seqlens,
+            blockmask,
+            ctx.dropout_p,
+            ctx.max_s,
+            ctx.softmax_scale,
+            ctx.causal,
         )
         if rng_state is not None:
             torch.cuda.set_rng_state(cur_rng_state)
         return dqkv, None, None, None, None, None, None


-def flash_blocksparse_attn_func(qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale=None,
-                                causal=False, return_attn_probs=False, convert_mask=True):
-    """dropout_p should be set to 0.0 during evaluation
-    """
+def flash_blocksparse_attn_func(
+    qkv,
+    cu_seqlens,
+    blockmask,
+    dropout_p,
+    max_s,
+    softmax_scale=None,
+    causal=False,
+    return_attn_probs=False,
+    convert_mask=True,
+):
+    """dropout_p should be set to 0.0 during evaluation"""
     func = FlashBlocksparseAttnFun if not return_attn_probs else FlashBlocksparseAttnFunWithS
     if convert_mask:
         blockmask = convert_blockmask(blockmask, causal=causal)
...
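A small sketch of how the block-mask conversion above can be driven, illustrative only: the 0-1 layout shape follows the seqlen/16 by seqlen/256 convention used in flash_blocksparse_attention.py, and the exact internals of convert_blockmask beyond its signature and return type are an assumption here.

# Hedged sketch: converting a random 0-1 block layout to the kernel's index format.
import torch

from flash_attn.flash_blocksparse_attn_interface import convert_blockmask

seqlen = 512  # already a multiple of 256, matching the rounding done by the caller
# One row per 16-token query block, one column per 256-token key block.
layout = (torch.rand(seqlen // 16, seqlen // 256) > 0.5).to(torch.uint8)
converted = convert_blockmask(layout, causal=False)
print(converted.shape, converted.dtype)  # int32 index tensor consumed by the CUDA kernel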
flash_attn/fused_softmax.py
View file @ f1a73d07
...
@@ -17,13 +17,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from apex._autocast_utils import _cast_if_autocast_enabled
from apex.transformer.enums import AttnMaskType
from fused_softmax_lib import (
    scaled_masked_softmax_backward,
    scaled_masked_softmax_forward,
    scaled_masked_softmax_get_batch_per_block,
    scaled_upper_triang_masked_softmax_backward,
    scaled_upper_triang_masked_softmax_forward,
)


class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
...
@@ -37,9 +39,7 @@ class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inputs, scale):
        scale_t = torch.tensor([scale])
        softmax_results = scaled_upper_triang_masked_softmax_forward(inputs, scale_t[0])
        ctx.save_for_backward(softmax_results, scale_t)
        return softmax_results
...
@@ -81,9 +81,7 @@ class ScaledMaskedSoftmax(torch.autograd.Function):
    @staticmethod
    def backward(ctx, output_grads):
        softmax_results, scale_t = ctx.saved_tensors
        input_grads = scaled_masked_softmax_backward(output_grads, softmax_results, scale_t[0])
        return input_grads, None, None
...
@@ -122,9 +120,7 @@ class FusedScaleMaskSoftmax(torch.nn.Module):
        self.input_in_fp16 = input_in_fp16
        self.input_in_bf16 = input_in_bf16
        if self.input_in_fp16 and self.input_in_bf16:
            raise RuntimeError("both fp16 and bf16 flags cannot be active at the same time.")
        self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16
        self.attn_mask_type = attn_mask_type
        self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion
...
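The fused kernels imported above implement scaled, masked softmax on the GPU. As a point of reference, here is a plain-PyTorch sketch of what the upper-triangular (causal) variant computes; the shapes and the scale value are illustrative, and the fused CUDA kernel has its own dtype/shape constraints that this diff does not show.

# Pure-PyTorch reference for a scaled upper-triangular (causal) masked softmax.
import torch

def scaled_causal_softmax_ref(scores: torch.Tensor, scale: float) -> torch.Tensor:
    # scores: (batch, seqlen_q, seqlen_k) attention logits
    seqlen_q, seqlen_k = scores.shape[-2:]
    causal_mask = torch.triu(
        torch.ones(seqlen_q, seqlen_k, dtype=torch.bool, device=scores.device), diagonal=1
    )
    scores = scores * scale                           # apply the softmax scale
    scores = scores.masked_fill(causal_mask, float("-inf"))  # mask future positions
    return torch.softmax(scores, dim=-1)

probs = scaled_causal_softmax_ref(torch.randn(2, 16, 16), scale=0.125)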
flash_attn/layers/patch_embed.py
View file @ f1a73d07
...
@@ -4,11 +4,10 @@
from functools import partial

import torch.nn as nn
from einops import rearrange
from torch import _assert
from torch.nn.modules.utils import _pair

try:
    from flash_attn.ops.fused_dense import FusedDense
except ImportError:
...
@@ -16,8 +15,8 @@ except ImportError:
class PatchEmbed(nn.Module):
    """2D Image to Patch Embedding"""

    def __init__(
        self,
        img_size=224,
...
@@ -38,7 +37,7 @@ class PatchEmbed(nn.Module):
        self.num_patches = self.grid_size[0] * self.grid_size[1]
        self.flatten = flatten
        if fused_bias_fc and FusedDense is None:
            raise ImportError("fused_dense is not installed")
        linear_cls = nn.Linear if not fused_bias_fc or not bias else FusedDense
        self.proj = linear_cls(in_chans * patch_size[0] * patch_size[1], embed_dim, bias=bias)
...
@@ -46,11 +45,23 @@ class PatchEmbed(nn.Module):
    def forward(self, x):
        _, _, H, W = x.shape
        _assert(
            H == self.img_size[0],
            f"Input image height ({H}) doesn't match model ({self.img_size[0]}).",
        )
        _assert(
            W == self.img_size[1],
            f"Input image width ({W}) doesn't match model ({self.img_size[1]}).",
        )
        x = self.proj(
            rearrange(
                x,
                "b c (h p1) (w p2) -> b h w (c p1 p2)",
                p1=self.patch_size[0],
                p2=self.patch_size[1],
            )
        )
        if self.flatten:
            x = rearrange(x, "b h w c -> b (h w) c")
        x = self.norm(x)
        return x
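A hypothetical usage sketch for the PatchEmbed module above (not part of the commit): only img_size, patch_size, in_chans, embed_dim, bias and flatten are visible in this diff, and the concrete values chosen here are assumptions.

# Sketch: embed a batch of images into patch tokens.
import torch
from flash_attn.layers.patch_embed import PatchEmbed

patch_embed = PatchEmbed(img_size=224, patch_size=16, in_chans=3, embed_dim=768)
images = torch.randn(2, 3, 224, 224)  # (batch, channels, height, width); H/W must equal img_size
tokens = patch_embed(images)          # with flatten=True: (batch, (224 // 16) ** 2, 768)
print(tokens.shape)                   # expected: torch.Size([2, 196, 768])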
flash_attn/layers/rotary.py
View file @ f1a73d07
# Copyright (c) 2023, Tri Dao.

import math
from typing import Optional, Tuple

import rotary_emb
import torch
from einops import rearrange, repeat


def rotate_half(x, interleaved=False):
    if not interleaved:
...
@@ -16,7 +14,7 @@ def rotate_half(x, interleaved=False):
        return torch.cat((-x2, x1), dim=-1)
    else:
        x1, x2 = x[..., ::2], x[..., 1::2]
        return rearrange(torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2)


def apply_rotary_emb_torch(x, cos, sin, interleaved=False):
...
@@ -26,14 +24,15 @@ def apply_rotary_emb_torch(x, cos, sin, interleaved=False):
    """
    ro_dim = cos.shape[-1] * 2
    assert ro_dim <= x.shape[-1]
    cos = repeat(cos, "s d -> s 1 (2 d)")
    sin = repeat(sin, "s d -> s 1 (2 d)")
    return torch.cat(
        [x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin, x[..., ro_dim:]],
        dim=-1,
    )


class ApplyRotaryEmb(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, cos, sin, interleaved=False, inplace=False):
        """
...
@@ -57,10 +56,20 @@ class ApplyRotaryEmb(torch.autograd.Function):
        if inplace:
            o1, o2 = x1, x2
        else:
            o1, o2 = (
                out_ro.chunk(2, dim=-1)
                if not interleaved
                else (out_ro[..., ::2], out_ro[..., 1::2])
            )
        rotary_emb.apply_rotary(
            x1,
            x2,
            rearrange(cos[:seqlen], "s d -> s 1 d"),
            rearrange(sin[:seqlen], "s d -> s 1 d"),
            o1,
            o2,
            False,
        )
        if not inplace and rotary_dim < headdim:
            out[..., rotary_dim:].copy_(x[..., rotary_dim:])
        ctx.save_for_backward(cos, sin)
...
@@ -76,17 +85,28 @@ class ApplyRotaryEmb(torch.autograd.Function):
        rotary_dim *= 2
        inplace = ctx.inplace
        do_ro = do[..., :rotary_dim]
        do1, do2 = (
            do_ro.chunk(2, dim=-1)
            if not ctx.interleaved
            else (do_ro[..., ::2], do_ro[..., 1::2])
        )
        dx = torch.empty_like(do) if not inplace else do
        if inplace:
            dx1, dx2 = do1, do2
        else:
            dx_ro = dx[..., :rotary_dim]
            dx1, dx2 = (
                dx_ro.chunk(2, dim=-1)
                if not ctx.interleaved
                else (dx_ro[..., ::2], dx_ro[..., 1::2])
            )
        rotary_emb.apply_rotary(
            do1,
            do2,
            rearrange(cos[:seqlen], "s d -> s 1 d"),
            rearrange(sin[:seqlen], "s d -> s 1 d"),
            dx1,
            dx2,
            True,
        )
        if not inplace and rotary_dim < headdim:
            dx[..., rotary_dim:].copy_(do[..., rotary_dim:])
        return dx, None, None, None, None
...
@@ -96,7 +116,6 @@ apply_rotary_emb_func = ApplyRotaryEmb.apply
class ApplyRotaryEmbQKV_(torch.autograd.Function):
    @staticmethod
    def forward(ctx, qkv, cos, sin, cos_k=None, sin_k=None, interleaved=False):
        """
...
@@ -119,12 +138,26 @@ class ApplyRotaryEmbQKV_(torch.autograd.Function):
        assert sin.shape == cos_k.shape == sin_k.shape == (rotary_seqlen, rotary_dim // 2)
        q_ro = qkv[:, :, 0, :, :rotary_dim]
        q1, q2 = q_ro.chunk(2, dim=-1) if not interleaved else (q_ro[..., ::2], q_ro[..., 1::2])
        rotary_emb.apply_rotary(
            q1,
            q2,
            rearrange(cos[:seqlen], "s d -> s 1 d"),
            rearrange(sin[:seqlen], "s d -> s 1 d"),
            q1,
            q2,
            False,
        )
        k_ro = qkv[:, :, 1, :, :rotary_dim]
        k1, k2 = k_ro.chunk(2, dim=-1) if not interleaved else (k_ro[..., ::2], k_ro[..., 1::2])
        rotary_emb.apply_rotary(
            k1,
            k2,
            rearrange(cos_k[:seqlen], "s d -> s 1 d"),
            rearrange(sin_k[:seqlen], "s d -> s 1 d"),
            k1,
            k2,
            False,
        )
        ctx.save_for_backward(cos, sin, cos_k, sin_k)
        ctx.interleaved = interleaved
        return qkv
...
@@ -136,15 +169,31 @@ class ApplyRotaryEmbQKV_(torch.autograd.Function):
        rotary_dim = cos.shape[-1]
        rotary_dim *= 2
        dq_ro = dqkv[:, :, 0, :, :rotary_dim]
        dq1, dq2 = (
            dq_ro.chunk(2, dim=-1)
            if not ctx.interleaved
            else (dq_ro[..., ::2], dq_ro[..., 1::2])
        )
        rotary_emb.apply_rotary(
            dq1,
            dq2,
            rearrange(cos[:seqlen], "s d -> s 1 d"),
            rearrange(sin[:seqlen], "s d -> s 1 d"),
            dq1,
            dq2,
            True,
        )
        dk_ro = dqkv[:, :, 1, :, :rotary_dim]
        dk1, dk2 = (
            dk_ro.chunk(2, dim=-1)
            if not ctx.interleaved
            else (dk_ro[..., ::2], dk_ro[..., 1::2])
        )
        rotary_emb.apply_rotary(
            dk1,
            dk2,
            rearrange(cos_k[:seqlen], "s d -> s 1 d"),
            rearrange(sin_k[:seqlen], "s d -> s 1 d"),
            dk1,
            dk2,
            True,
        )
        return dqkv, None, None, None, None, None
...
@@ -152,7 +201,6 @@ apply_rotary_emb_qkv_ = ApplyRotaryEmbQKV_.apply
class ApplyRotaryEmbKV_(torch.autograd.Function):
    @staticmethod
    def forward(ctx, kv, cos, sin, interleaved=False):
        """
...
@@ -171,9 +219,15 @@ class ApplyRotaryEmbKV_(torch.autograd.Function):
        assert seqlen <= rotary_seqlen
        k_ro = kv[:, :, 0, :, :rotary_dim]
        k1, k2 = k_ro.chunk(2, dim=-1) if not interleaved else (k_ro[..., ::2], k_ro[..., 1::2])
        rotary_emb.apply_rotary(
            k1,
            k2,
            rearrange(cos[:seqlen], "s d -> s 1 d"),
            rearrange(sin[:seqlen], "s d -> s 1 d"),
            k1,
            k2,
            False,
        )  # conj=False since this is the forward pass
        ctx.save_for_backward(cos, sin)
        ctx.interleaved = interleaved
        return kv
...
@@ -185,11 +239,18 @@ class ApplyRotaryEmbKV_(torch.autograd.Function):
        rotary_dim = cos.shape[-1]
        rotary_dim *= 2
        dk_ro = dkv[:, :, 0, :, :rotary_dim]
        dk1, dk2 = (
            dk_ro.chunk(2, dim=-1)
            if not ctx.interleaved
            else (dk_ro[..., ::2], dk_ro[..., 1::2])
        )
        rotary_emb.apply_rotary(
            dk1,
            dk2,
            rearrange(cos[:seqlen], "s d -> s 1 d"),
            rearrange(sin[:seqlen], "s d -> s 1 d"),
            dk1,
            dk2,
            True,
        )  # conj=True since this is the backward pass
        return dkv, None, None, None
...
@@ -214,8 +275,15 @@ class RotaryEmbedding(torch.nn.Module):
    Reference: https://github.com/sunyt32/torchscale/blob/main/torchscale/component/xpos_relative_position.py
    """

    def __init__(
        self,
        dim: int,
        base=10000.0,
        interleaved=False,
        scale_base=None,
        pos_idx_in_fp32=True,
        device=None,
    ):
        """
        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
            of 1st half and 2nd half (GPT-NeoX style).
...
@@ -239,8 +307,11 @@ class RotaryEmbedding(torch.nn.Module):
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.interleaved = interleaved
        self.scale_base = scale_base
        scale = (
            (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim)
            / (1.4 * dim)
            if scale_base is not None
            else None
        )
        self.register_buffer("scale", scale, persistent=False)

        self._seq_len_cached = 0
...
@@ -250,17 +321,21 @@ class RotaryEmbedding(torch.nn.Module):
        self._sin_k_cached = None

    def _compute_inv_freq(self, device=None):
        return 1.0 / (
            self.base
            ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim)
        )

    def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
        # Reset the tables if the sequence length has changed,
        # if we're on a new device (possibly due to tracing for instance),
        # or if we're switching from inference mode to training
        if (
            seqlen > self._seq_len_cached
            or self._cos_cached.device != device
            or self._cos_cached.dtype != dtype
            or (self.training and self._cos_cached.is_inference())
        ):
            self._seq_len_cached = seqlen
            # We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16
            # And the output of arange can be quite large, so bf16 would lose a lot of precision.
...
@@ -285,17 +360,20 @@ class RotaryEmbedding(torch.nn.Module):
                self._cos_cached = torch.cos(freqs).to(dtype)
                self._sin_cached = torch.sin(freqs).to(dtype)
            else:
                power = (
                    torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device)
                    - seqlen // 2
                ) / self.scale_base
                scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
                # We want the multiplication by scale to happen in fp32
                self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
                self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
                self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
                self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)

    def forward(
        self, qkv: torch.Tensor, kv: Optional[torch.Tensor] = None, seqlen_offset: int = 0
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        qkv: (batch, seqlen, 3, nheads, headdim) if kv is none,
            else it's just q of shape (batch, seqlen, nheads, headdim)
...
@@ -308,29 +386,43 @@ class RotaryEmbedding(torch.nn.Module):
        if kv is None:
            if self.scale is None:
                return apply_rotary_emb_qkv_(
                    qkv,
                    self._cos_cached[seqlen_offset:],
                    self._sin_cached[seqlen_offset:],
                    None,
                    None,
                    self.interleaved,
                )
            else:
                return apply_rotary_emb_qkv_(
                    qkv,
                    self._cos_cached[seqlen_offset:],
                    self._sin_cached[seqlen_offset:],
                    self._cos_k_cached[seqlen_offset:],
                    self._sin_k_cached[seqlen_offset:],
                    self.interleaved,
                )
        else:
            q = qkv
            q = apply_rotary_emb_func(
                q,
                self._cos_cached[seqlen_offset:],
                self._sin_cached[seqlen_offset:],
                self.interleaved,
                True,
            )
            if self.scale is None:
                kv = apply_rotary_emb_kv_(
                    kv,
                    self._cos_cached[seqlen_offset:],
                    self._sin_cached[seqlen_offset:],
                    self.interleaved,
                )
            else:
                kv = apply_rotary_emb_kv_(
                    kv,
                    self._cos_k_cached[seqlen_offset:],
                    self._sin_k_cached[seqlen_offset:],
                    self.interleaved,
                )
            return q, kv
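A hypothetical usage sketch for RotaryEmbedding, based only on the constructor and forward signatures visible above and the documented packed qkv layout (batch, seqlen, 3, nheads, headdim). The concrete sizes, the fp16/CUDA placement, and the comment about seqlen_offset are assumptions; the module also requires the rotary_emb CUDA extension to be installed.

# Sketch: apply rotary position embeddings to a packed qkv tensor.
import torch
from flash_attn.layers.rotary import RotaryEmbedding

batch, seqlen, nheads, headdim = 2, 128, 12, 64
rotary = RotaryEmbedding(dim=headdim, base=10000.0, interleaved=False).cuda()

qkv = torch.randn(batch, seqlen, 3, nheads, headdim, device="cuda", dtype=torch.float16)
qkv = rotary(qkv)  # q and k slices are rotated in place; v is left untouched
# During incremental decoding, seqlen_offset would be the number of tokens already processed,
# so the cached cos/sin tables are indexed starting at that position.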
flash_attn/losses/cross_entropy.py
View file @ f1a73d07
...
@@ -5,7 +5,6 @@
# The original xentropy interface is here: https://github.com/NVIDIA/apex/blob/master/apex/contrib/xentropy/softmax_xentropy.py
import torch
import torch.nn as nn
import xentropy_cuda_lib

# `all_gather_into_tensor` and `reduce_scatter_tensor` are new placeholders for
...
@@ -17,10 +16,16 @@ if "all_gather_into_tensor" not in dir(torch.distributed):
class SoftmaxCrossEntropyLossFn(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        logits,
        labels,
        smoothing=0.0,
        ignored_index=-100,
        inplace_backward=False,
        process_group=None,
    ):
        """
        logits: (batch, vocab_size)
        labels: (batch,)
...
@@ -34,7 +39,7 @@ class SoftmaxCrossEntropyLossFn(torch.autograd.Function):
        if world_size == 1:
            losses, lse = xentropy_cuda_lib.forward(logits, labels, smoothing)
            losses.masked_fill_(labels == ignored_index, 0)
            labels_local = labels
        else:
            rank = torch.distributed.get_rank(process_group)
...
@@ -48,8 +53,9 @@ class SoftmaxCrossEntropyLossFn(torch.autograd.Function):
            # For tensor parallel cross entropy with smoothing, we want to pass in the total number
            # of classes so that smoothing can be applied correctly. If total_classes=-1, use the
            # last dimension of the input tensor.
            losses, lse_local = xentropy_cuda_lib.forward(
                logits, labels_local, smoothing, world_size * vocab_size
            )
            assert lse_local.shape == (batch,)
            assert losses.shape == (batch,)
            losses.masked_fill_(ignored_mask, 0)
...
@@ -61,10 +67,12 @@ class SoftmaxCrossEntropyLossFn(torch.autograd.Function):
            # For labels not in the vocab of this partition, losses contains
            # 0.1 * (lse_local - sum logit / total_classes).
            lse_allgather = torch.empty(
                world_size, batch, dtype=lse_local.dtype, device=lse_local.device
            )
            torch.distributed.all_gather_into_tensor(
                lse_allgather, lse_local.contiguous(), group=process_group
            )
            handle_losses = torch.distributed.all_reduce(
                losses, op=torch.distributed.ReduceOp.SUM, group=process_group, async_op=True
            )
...
@@ -74,16 +82,18 @@ class SoftmaxCrossEntropyLossFn(torch.autograd.Function):
            # If there's smoothing=0.1, the total losses are
            # 0.9 * (lse_local - predicted_logit) + 0.1 * (sum of all lse_local - sum logit / total_classes)
            # We want 0.9 * (lse - predicted_logit) + 0.1 * (lse - sum logit / total_classes).
            rank_per_sample = torch.div(labels, vocab_size, rounding_mode="floor")
            lse_local = lse_allgather[
                rank_per_sample, torch.arange(batch, device=lse_allgather.device)
            ]
            handle_losses.wait()
            if smoothing == 0.0:
                losses += lse - lse_local
            else:
                losses += (1 - smoothing) * (lse - lse_local) + smoothing * (
                    lse - lse_allgather.sum(dim=0)
                )
            losses.masked_fill_(ignored_mask, 0)

        ctx.save_for_backward(logits, lse, labels_local)
...
@@ -96,19 +106,24 @@ class SoftmaxCrossEntropyLossFn(torch.autograd.Function):
    def backward(ctx, grad_loss):
        logits, lse, labels = ctx.saved_tensors
        grad_loss = grad_loss.contiguous()
        grad_loss.masked_fill_(labels == ctx.ignored_index, 0)
        grad_logits = xentropy_cuda_lib.backward(
            grad_loss, logits, lse, labels, ctx.smoothing, ctx.inplace_backward, ctx.total_classes
        )
        return grad_logits, None, None, None, None, None, None


class CrossEntropyLoss(nn.Module):
    def __init__(
        self,
        ignore_index=-100,
        reduction="mean",
        label_smoothing=0.0,
        inplace_backward=False,
        process_group=None,
    ):
        super().__init__()
        if reduction not in ["mean", "none"]:
            raise NotImplementedError("Only support reduction = 'mean' or 'none'")
        self.ignore_index = ignore_index
        self.reduction = reduction
...
@@ -120,10 +135,14 @@ class CrossEntropyLoss(nn.Module):
        assert input.is_cuda and target.is_cuda
        # SoftmaxCrossEntropyLoss implicitly casts to float
        loss = SoftmaxCrossEntropyLossFn.apply(
            input,
            target,
            self.label_smoothing,
            self.ignore_index,
            self.inplace_backward,
            self.process_group,
        )
        if self.reduction == "mean":
            return loss.sum() / (target != self.ignore_index).sum()
        else:
            return loss
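A hypothetical usage sketch for the xentropy-backed CrossEntropyLoss above (not part of the commit). Shapes follow the docstring visible in the diff (logits: (batch, vocab_size), labels: (batch,)), and both tensors must be CUDA tensors, as the forward pass asserts; the concrete sizes are illustrative.

# Sketch: compute a language-model cross-entropy loss on the GPU.
import torch
from flash_attn.losses.cross_entropy import CrossEntropyLoss

batch, vocab_size = 8, 50257
loss_fn = CrossEntropyLoss(label_smoothing=0.0, inplace_backward=True)

logits = torch.randn(batch, vocab_size, device="cuda", requires_grad=True)
labels = torch.randint(0, vocab_size, (batch,), device="cuda")

loss = loss_fn(logits, labels)  # reduction="mean" averages over non-ignored targets
loss.backward()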
flash_attn/models/bert.py
View file @
f1a73d07
...
@@ -5,29 +5,32 @@
...
@@ -5,29 +5,32 @@
# Inspired by https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/modeling_bert.py
# Inspired by https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/modeling_bert.py
import
re
import
logging
import
logging
from
functools
import
partial
import
re
from
collections.abc
import
Sequence
from
collections
import
OrderedDict
from
collections
import
OrderedDict
from
collections.abc
import
Sequence
from
functools
import
partial
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
from
transformers
import
BertConfig
from
transformers.models.bert.modeling_bert
import
BaseModelOutputWithPoolingAndCrossAttentions
from
transformers.models.bert.modeling_bert
import
BertForPreTrainingOutput
from
einops
import
rearrange
from
einops
import
rearrange
from
transformers
import
BertConfig
from
flash_attn.modules.mha
import
MHA
from
transformers.models.bert.modeling_bert
import
(
from
flash_attn.modules.mlp
import
Mlp
,
FusedMLP
BaseModelOutputWithPoolingAndCrossAttentions
,
BertForPreTrainingOutput
,
)
from
flash_attn.bert_padding
import
(
index_first_axis
,
index_first_axis_residual
,
pad_input
,
unpad_input
,
)
from
flash_attn.modules.block
import
Block
from
flash_attn.modules.block
import
Block
from
flash_attn.modules.embedding
import
BertEmbeddings
from
flash_attn.modules.embedding
import
BertEmbeddings
from
flash_attn.
bert_padding
import
unpad_input
,
pad_input
from
flash_attn.
modules.mha
import
MHA
from
flash_attn.
bert_padding
import
index_first_axis
,
index_first_axis_residual
from
flash_attn.
modules.mlp
import
FusedMLP
,
Mlp
from
flash_attn.utils.pretrained
import
state_dict_from_pretrained
from
flash_attn.utils.pretrained
import
state_dict_from_pretrained
try
:
try
:
...
@@ -50,48 +53,63 @@ logger = logging.getLogger(__name__)
...
@@ -50,48 +53,63 @@ logger = logging.getLogger(__name__)
def
create_mixer_cls
(
config
,
cross_attn
=
False
,
return_residual
=
False
):
def
create_mixer_cls
(
config
,
cross_attn
=
False
,
return_residual
=
False
):
use_flash_attn
=
getattr
(
config
,
'
use_flash_attn
'
,
False
)
use_flash_attn
=
getattr
(
config
,
"
use_flash_attn
"
,
False
)
fused_bias_fc
=
getattr
(
config
,
'
fused_bias_fc
'
,
False
)
fused_bias_fc
=
getattr
(
config
,
"
fused_bias_fc
"
,
False
)
rotary_kwargs
=
{}
rotary_kwargs
=
{}
if
config
.
position_embedding_type
==
"rotary"
:
if
config
.
position_embedding_type
==
"rotary"
:
rotary_kwargs
[
"rotary_emb_dim"
]
=
getattr
(
config
,
"rotary_emb_dim"
,
config
.
hidden_size
)
rotary_kwargs
[
"rotary_emb_dim"
]
=
getattr
(
config
,
"rotary_emb_dim"
,
config
.
hidden_size
)
rotary_kwargs
[
"rotary_emb_base"
]
=
getattr
(
config
,
"rotary_emb_base"
,
10000.0
)
rotary_kwargs
[
"rotary_emb_base"
]
=
getattr
(
config
,
"rotary_emb_base"
,
10000.0
)
rotary_kwargs
[
"rotary_emb_scale_base"
]
=
getattr
(
config
,
"rotary_emb_scale_base"
,
None
)
rotary_kwargs
[
"rotary_emb_scale_base"
]
=
getattr
(
config
,
"rotary_emb_scale_base"
,
None
)
rotary_kwargs
[
"rotary_emb_interleaved"
]
=
getattr
(
config
,
"rotary_emb_interleaved"
,
False
)
rotary_kwargs
[
"rotary_emb_interleaved"
]
=
getattr
(
config
,
"rotary_emb_interleaved"
,
False
)
mixer_cls
=
partial
(
MHA
,
num_heads
=
config
.
num_attention_heads
,
cross_attn
=
cross_attn
,
mixer_cls
=
partial
(
dropout
=
config
.
attention_probs_dropout_prob
,
causal
=
False
,
MHA
,
fused_bias_fc
=
fused_bias_fc
,
use_flash_attn
=
use_flash_attn
,
num_heads
=
config
.
num_attention_heads
,
return_residual
=
return_residual
,
**
rotary_kwargs
)
cross_attn
=
cross_attn
,
dropout
=
config
.
attention_probs_dropout_prob
,
causal
=
False
,
fused_bias_fc
=
fused_bias_fc
,
use_flash_attn
=
use_flash_attn
,
return_residual
=
return_residual
,
**
rotary_kwargs
,
)
return
mixer_cls
return
mixer_cls
def
create_mlp_cls
(
config
,
layer_idx
=
None
,
return_residual
=
False
):
def
create_mlp_cls
(
config
,
layer_idx
=
None
,
return_residual
=
False
):
inner_dim
=
config
.
intermediate_size
inner_dim
=
config
.
intermediate_size
fused_mlp
=
getattr
(
config
,
'
fused_mlp
'
,
False
)
fused_mlp
=
getattr
(
config
,
"
fused_mlp
"
,
False
)
if
fused_mlp
:
if
fused_mlp
:
assert
config
.
hidden_act
in
[
'gelu_new'
,
'gelu_fast'
],
(
'fused_mlp only '
assert
config
.
hidden_act
in
[
"gelu_new"
,
"gelu_fast"
],
(
'supports approximate gelu'
)
"fused_mlp only "
"supports approximate gelu"
)
if
not
fused_mlp
:
if
not
fused_mlp
:
approximate
=
'tanh'
if
config
.
hidden_act
in
[
'gelu_new'
,
'gelu_fast'
]
else
'none'
approximate
=
"tanh"
if
config
.
hidden_act
in
[
"gelu_new"
,
"gelu_fast"
]
else
"none"
mlp_cls
=
partial
(
Mlp
,
hidden_features
=
inner_dim
,
mlp_cls
=
partial
(
Mlp
,
hidden_features
=
inner_dim
,
activation
=
partial
(
F
.
gelu
,
approximate
=
approximate
),
activation
=
partial
(
F
.
gelu
,
approximate
=
approximate
),
return_residual
=
return_residual
)
return_residual
=
return_residual
,
)
else
:
else
:
if
FusedMLP
is
None
:
if
FusedMLP
is
None
:
raise
ImportError
(
'
fused_dense is not installed
'
)
raise
ImportError
(
"
fused_dense is not installed
"
)
mlp_checkpoint_lvl
=
getattr
(
config
,
'
mlp_checkpoint_lvl
'
,
0
)
mlp_checkpoint_lvl
=
getattr
(
config
,
"
mlp_checkpoint_lvl
"
,
0
)
# mlp_checkpoint_lvl could be a list, which contains the checkpoint_lvl for each layer
# mlp_checkpoint_lvl could be a list, which contains the checkpoint_lvl for each layer
if
isinstance
(
mlp_checkpoint_lvl
,
Sequence
):
if
isinstance
(
mlp_checkpoint_lvl
,
Sequence
):
assert
layer_idx
is
not
None
assert
layer_idx
is
not
None
mlp_checkpoint_lvl
=
mlp_checkpoint_lvl
[
layer_idx
]
mlp_checkpoint_lvl
=
mlp_checkpoint_lvl
[
layer_idx
]
mlp_cls
=
partial
(
FusedMLP
,
hidden_features
=
inner_dim
,
mlp_cls
=
partial
(
checkpoint_lvl
=
mlp_checkpoint_lvl
,
return_residual
=
return_residual
)
FusedMLP
,
hidden_features
=
inner_dim
,
checkpoint_lvl
=
mlp_checkpoint_lvl
,
return_residual
=
return_residual
,
)
return
mlp_cls
return
mlp_cls
def
create_block
(
config
,
layer_idx
=
None
):
def
create_block
(
config
,
layer_idx
=
None
):
last_layer_subset
=
getattr
(
config
,
'
last_layer_subset
'
,
False
)
last_layer_subset
=
getattr
(
config
,
"
last_layer_subset
"
,
False
)
cross_attn
=
last_layer_subset
and
layer_idx
==
config
.
num_hidden_layers
-
1
cross_attn
=
last_layer_subset
and
layer_idx
==
config
.
num_hidden_layers
-
1
# TD [2022-12-19]: For cross attention (last layer), we actually want to return the
# TD [2022-12-19]: For cross attention (last layer), we actually want to return the
# residual x_kv, not residual x. But it's annoying to change the API (and it only affects
# residual x_kv, not residual x. But it's annoying to change the API (and it only affects
# one layer) so we just choose not to return residual in this case.
# one layer) so we just choose not to return residual in this case.
...
@@ -99,11 +117,17 @@ def create_block(config, layer_idx=None):
...
@@ -99,11 +117,17 @@ def create_block(config, layer_idx=None):
mixer_cls
=
create_mixer_cls
(
config
,
cross_attn
,
return_residual
=
return_residual
)
mixer_cls
=
create_mixer_cls
(
config
,
cross_attn
,
return_residual
=
return_residual
)
mlp_cls
=
create_mlp_cls
(
config
,
layer_idx
,
return_residual
=
return_residual
)
mlp_cls
=
create_mlp_cls
(
config
,
layer_idx
,
return_residual
=
return_residual
)
norm_cls
=
partial
(
nn
.
LayerNorm
,
eps
=
config
.
layer_norm_eps
)
norm_cls
=
partial
(
nn
.
LayerNorm
,
eps
=
config
.
layer_norm_eps
)
block
=
Block
(
config
.
hidden_size
,
mixer_cls
,
mlp_cls
,
norm_cls
=
norm_cls
,
block
=
Block
(
prenorm
=
False
,
resid_dropout1
=
config
.
hidden_dropout_prob
,
config
.
hidden_size
,
mixer_cls
,
mlp_cls
,
norm_cls
=
norm_cls
,
prenorm
=
False
,
resid_dropout1
=
config
.
hidden_dropout_prob
,
resid_dropout2
=
config
.
hidden_dropout_prob
,
resid_dropout2
=
config
.
hidden_dropout_prob
,
fused_dropout_add_ln
=
getattr
(
config
,
'fused_dropout_add_ln'
,
False
),
fused_dropout_add_ln
=
getattr
(
config
,
"fused_dropout_add_ln"
,
False
),
return_residual
=
return_residual
)
return_residual
=
return_residual
,
)
return
block
return
block
...
@@ -120,12 +144,12 @@ def _init_weights(module, initializer_range=0.02):
...
@@ -120,12 +144,12 @@ def _init_weights(module, initializer_range=0.02):
class
BertEncoder
(
nn
.
Module
):
class
BertEncoder
(
nn
.
Module
):
def
__init__
(
self
,
config
:
BertConfig
):
def
__init__
(
self
,
config
:
BertConfig
):
super
().
__init__
()
super
().
__init__
()
self
.
use_flash_attn
=
getattr
(
config
,
'use_flash_attn'
,
False
)
self
.
use_flash_attn
=
getattr
(
config
,
"use_flash_attn"
,
False
)
self
.
layers
=
nn
.
ModuleList
([
create_block
(
config
,
layer_idx
=
i
)
self
.
layers
=
nn
.
ModuleList
(
for
i
in
range
(
config
.
num_hidden_layers
)])
[
create_block
(
config
,
layer_idx
=
i
)
for
i
in
range
(
config
.
num_hidden_layers
)]
)
def
forward
(
self
,
hidden_states
,
key_padding_mask
=
None
,
subset_mask
=
None
):
def
forward
(
self
,
hidden_states
,
key_padding_mask
=
None
,
subset_mask
=
None
):
"""If subset_mask is not None, we only want output for the subset of the sequence.
"""If subset_mask is not None, we only want output for the subset of the sequence.
...
@@ -133,8 +157,9 @@ class BertEncoder(nn.Module):
...
@@ -133,8 +157,9 @@ class BertEncoder(nn.Module):
subset_mask: (batch, seqlen), dtype=torch.bool
subset_mask: (batch, seqlen), dtype=torch.bool
"""
"""
if
key_padding_mask
is
None
or
not
self
.
use_flash_attn
:
if
key_padding_mask
is
None
or
not
self
.
use_flash_attn
:
mixer_kwargs
=
({
'key_padding_mask'
:
key_padding_mask
}
mixer_kwargs
=
(
if
key_padding_mask
is
not
None
else
None
)
{
"key_padding_mask"
:
key_padding_mask
}
if
key_padding_mask
is
not
None
else
None
)
for
layer
in
self
.
layers
:
for
layer
in
self
.
layers
:
hidden_states
=
layer
(
hidden_states
,
mixer_kwargs
=
mixer_kwargs
)
hidden_states
=
layer
(
hidden_states
,
mixer_kwargs
=
mixer_kwargs
)
if
subset_mask
is
not
None
:
if
subset_mask
is
not
None
:
...
@@ -144,7 +169,7 @@ class BertEncoder(nn.Module):
...
@@ -144,7 +169,7 @@ class BertEncoder(nn.Module):
hidden_states
,
indices
,
cu_seqlens
,
max_seqlen_in_batch
=
unpad_input
(
hidden_states
,
indices
,
cu_seqlens
,
max_seqlen_in_batch
=
unpad_input
(
hidden_states
,
key_padding_mask
hidden_states
,
key_padding_mask
)
)
mixer_kwargs
=
{
'
cu_seqlens
'
:
cu_seqlens
,
'
max_seqlen
'
:
max_seqlen_in_batch
}
mixer_kwargs
=
{
"
cu_seqlens
"
:
cu_seqlens
,
"
max_seqlen
"
:
max_seqlen_in_batch
}
if
subset_mask
is
None
:
if
subset_mask
is
None
:
for
layer
in
self
.
layers
:
for
layer
in
self
.
layers
:
hidden_states
=
layer
(
hidden_states
,
mixer_kwargs
=
mixer_kwargs
)
hidden_states
=
layer
(
hidden_states
,
mixer_kwargs
=
mixer_kwargs
)
...
@@ -153,33 +178,40 @@ class BertEncoder(nn.Module):
...
@@ -153,33 +178,40 @@ class BertEncoder(nn.Module):
for
layer
in
self
.
layers
[:
-
1
]:
for
layer
in
self
.
layers
[:
-
1
]:
hidden_states
=
layer
(
hidden_states
,
mixer_kwargs
=
mixer_kwargs
)
hidden_states
=
layer
(
hidden_states
,
mixer_kwargs
=
mixer_kwargs
)
if
key_padding_mask
is
not
None
:
if
key_padding_mask
is
not
None
:
subset_idx
=
torch
.
nonzero
(
subset_mask
[
key_padding_mask
],
as_tuple
=
False
).
flatten
()
subset_idx
=
torch
.
nonzero
(
subset_mask
[
key_padding_mask
],
as_tuple
=
False
).
flatten
()
subset_seqlens
=
(
subset_mask
&
key_padding_mask
).
sum
(
dim
=-
1
,
dtype
=
torch
.
int32
)
subset_seqlens
=
(
subset_mask
&
key_padding_mask
).
sum
(
dim
=-
1
,
dtype
=
torch
.
int32
)
subset_cu_seqlens
=
F
.
pad
(
torch
.
cumsum
(
subset_seqlens
,
dim
=
0
,
subset_cu_seqlens
=
F
.
pad
(
dtype
=
torch
.
torch
.
int32
),
(
1
,
0
))
torch
.
cumsum
(
subset_seqlens
,
dim
=
0
,
dtype
=
torch
.
torch
.
int32
),
(
1
,
0
)
)
else
:
else
:
subset_idx
=
torch
.
nonzero
(
subset_mask
,
as_tuple
=
False
).
flatten
()
subset_idx
=
torch
.
nonzero
(
subset_mask
,
as_tuple
=
False
).
flatten
()
subset_seqlens
=
subset_mask
.
sum
(
dim
=-
1
,
dtype
=
torch
.
int32
)
subset_seqlens
=
subset_mask
.
sum
(
dim
=-
1
,
dtype
=
torch
.
int32
)
subset_cu_seqlens
=
F
.
pad
(
torch
.
cumsum
(
subset_seqlens
,
dim
=
0
,
subset_cu_seqlens
=
F
.
pad
(
dtype
=
torch
.
torch
.
int32
),
(
1
,
0
))
torch
.
cumsum
(
subset_seqlens
,
dim
=
0
,
dtype
=
torch
.
torch
.
int32
),
(
1
,
0
)
)
hidden_states_subset
,
hidden_states
=
index_first_axis_residual
(
hidden_states_subset
,
hidden_states
=
index_first_axis_residual
(
hidden_states
,
subset_idx
hidden_states
,
subset_idx
)
)
# It's ok to set max_seqlen_q to be much larger
# It's ok to set max_seqlen_q to be much larger
mixer_kwargs
=
{
'x_kv'
:
hidden_states
,
mixer_kwargs
=
{
'cu_seqlens'
:
subset_cu_seqlens
,
'max_seqlen'
:
max_seqlen_in_batch
,
"x_kv"
:
hidden_states
,
'cu_seqlens_k'
:
cu_seqlens
,
'max_seqlen_k'
:
max_seqlen_in_batch
}
"cu_seqlens"
:
subset_cu_seqlens
,
"max_seqlen"
:
max_seqlen_in_batch
,
"cu_seqlens_k"
:
cu_seqlens
,
"max_seqlen_k"
:
max_seqlen_in_batch
,
}
hidden_states
=
self
.
layers
[
-
1
](
hidden_states_subset
,
mixer_kwargs
=
mixer_kwargs
)
hidden_states
=
self
.
layers
[
-
1
](
hidden_states_subset
,
mixer_kwargs
=
mixer_kwargs
)
return
hidden_states
return
hidden_states
class
BertPooler
(
nn
.
Module
):
class
BertPooler
(
nn
.
Module
):
def
__init__
(
self
,
config
):
def
__init__
(
self
,
config
):
super
().
__init__
()
super
().
__init__
()
fused_bias_fc
=
getattr
(
config
,
'
fused_bias_fc
'
,
False
)
fused_bias_fc
=
getattr
(
config
,
"
fused_bias_fc
"
,
False
)
if
fused_bias_fc
and
FusedDense
is
None
:
if
fused_bias_fc
and
FusedDense
is
None
:
raise
ImportError
(
'
fused_dense is not installed
'
)
raise
ImportError
(
"
fused_dense is not installed
"
)
linear_cls
=
nn
.
Linear
if
not
fused_bias_fc
else
FusedDense
linear_cls
=
nn
.
Linear
if
not
fused_bias_fc
else
FusedDense
self
.
dense
=
linear_cls
(
config
.
hidden_size
,
config
.
hidden_size
)
self
.
dense
=
linear_cls
(
config
.
hidden_size
,
config
.
hidden_size
)
self
.
activation
=
nn
.
Tanh
()
self
.
activation
=
nn
.
Tanh
()
...
@@ -194,18 +226,17 @@ class BertPooler(nn.Module):
...
@@ -194,18 +226,17 @@ class BertPooler(nn.Module):
class
BertPredictionHeadTransform
(
nn
.
Module
):
class
BertPredictionHeadTransform
(
nn
.
Module
):
def
__init__
(
self
,
config
):
def
__init__
(
self
,
config
):
super
().
__init__
()
super
().
__init__
()
fused_bias_fc
=
getattr
(
config
,
'
fused_bias_fc
'
,
False
)
fused_bias_fc
=
getattr
(
config
,
"
fused_bias_fc
"
,
False
)
if
fused_bias_fc
and
FusedDense
is
None
:
if
fused_bias_fc
and
FusedDense
is
None
:
raise
ImportError
(
'
fused_dense is not installed
'
)
raise
ImportError
(
"
fused_dense is not installed
"
)
self
.
fused_dropout_add_ln
=
getattr
(
config
,
'
fused_dropout_add_ln
'
,
False
)
self
.
fused_dropout_add_ln
=
getattr
(
config
,
"
fused_dropout_add_ln
"
,
False
)
if
self
.
fused_dropout_add_ln
and
layer_norm
is
None
:
if
self
.
fused_dropout_add_ln
and
layer_norm
is
None
:
raise
ImportError
(
'
dropout_add_layer_norm is not installed
'
)
raise
ImportError
(
"
dropout_add_layer_norm is not installed
"
)
linear_cls
=
nn
.
Linear
if
not
fused_bias_fc
else
FusedDense
linear_cls
=
nn
.
Linear
if
not
fused_bias_fc
else
FusedDense
self
.
dense
=
linear_cls
(
config
.
hidden_size
,
config
.
hidden_size
)
self
.
dense
=
linear_cls
(
config
.
hidden_size
,
config
.
hidden_size
)
approximate
=
'
tanh
'
if
config
.
hidden_act
in
[
'
gelu_new
'
,
'
gelu_fast
'
]
else
'
none
'
approximate
=
"
tanh
"
if
config
.
hidden_act
in
[
"
gelu_new
"
,
"
gelu_fast
"
]
else
"
none
"
self
.
transform_act_fn
=
nn
.
GELU
(
approximate
=
approximate
)
self
.
transform_act_fn
=
nn
.
GELU
(
approximate
=
approximate
)
self
.
layer_norm
=
nn
.
LayerNorm
(
config
.
hidden_size
,
eps
=
config
.
layer_norm_eps
)
self
.
layer_norm
=
nn
.
LayerNorm
(
config
.
hidden_size
,
eps
=
config
.
layer_norm_eps
)
...
@@ -215,18 +246,18 @@ class BertPredictionHeadTransform(nn.Module):
...
@@ -215,18 +246,18 @@ class BertPredictionHeadTransform(nn.Module):
if
not
self
.
fused_dropout_add_ln
:
if
not
self
.
fused_dropout_add_ln
:
hidden_states
=
self
.
layer_norm
(
hidden_states
)
hidden_states
=
self
.
layer_norm
(
hidden_states
)
else
:
else
:
hidden_states
=
layer_norm
(
hidden_states
,
self
.
layer_norm
.
weight
,
self
.
layer_norm
.
bias
,
hidden_states
=
layer_norm
(
self
.
layer_norm
.
eps
)
hidden_states
,
self
.
layer_norm
.
weight
,
self
.
layer_norm
.
bias
,
self
.
layer_norm
.
eps
)
return
hidden_states
return
hidden_states
class
BertLMPredictionHead
(
nn
.
Module
):
class
BertLMPredictionHead
(
nn
.
Module
):
def
__init__
(
self
,
config
):
def
__init__
(
self
,
config
):
super
().
__init__
()
super
().
__init__
()
fused_bias_fc
=
getattr
(
config
,
'
fused_bias_fc
'
,
False
)
fused_bias_fc
=
getattr
(
config
,
"
fused_bias_fc
"
,
False
)
if
fused_bias_fc
and
FusedDense
is
None
:
if
fused_bias_fc
and
FusedDense
is
None
:
raise
ImportError
(
'
fused_dense is not installed
'
)
raise
ImportError
(
"
fused_dense is not installed
"
)
linear_cls
=
nn
.
Linear
if
not
fused_bias_fc
else
FusedDense
linear_cls
=
nn
.
Linear
if
not
fused_bias_fc
else
FusedDense
self
.
transform
=
BertPredictionHeadTransform
(
config
)
self
.
transform
=
BertPredictionHeadTransform
(
config
)
...
@@ -254,9 +285,10 @@ class BertPreTrainingHeads(nn.Module):
...
@@ -254,9 +285,10 @@ class BertPreTrainingHeads(nn.Module):
class
BertPreTrainedModel
(
nn
.
Module
):
class
BertPreTrainedModel
(
nn
.
Module
):
"""
An abstract class to handle weights initialization and
"""An abstract class to handle weights initialization and
a simple interface for dowloading and loading pretrained models.
a simple interface for dowloading and loading pretrained models.
"""
"""
def
__init__
(
self
,
config
,
*
inputs
,
**
kwargs
):
def
__init__
(
self
,
config
,
*
inputs
,
**
kwargs
):
super
().
__init__
()
super
().
__init__
()
if
not
isinstance
(
config
,
BertConfig
):
if
not
isinstance
(
config
,
BertConfig
):
...
@@ -265,7 +297,8 @@ class BertPreTrainedModel(nn.Module):
...
@@ -265,7 +297,8 @@ class BertPreTrainedModel(nn.Module):
"To create a model from a Google pretrained model use "
"To create a model from a Google pretrained model use "
"`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`"
.
format
(
"`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`"
.
format
(
self
.
__class__
.
__name__
,
self
.
__class__
.
__name__
self
.
__class__
.
__name__
,
self
.
__class__
.
__name__
))
)
)
self
.
config
=
config
self
.
config
=
config
@
classmethod
@
classmethod
...
@@ -287,28 +320,33 @@ class BertPreTrainedModel(nn.Module):
...
@@ -287,28 +320,33 @@ class BertPreTrainedModel(nn.Module):
"""
"""
# Instantiate model.
# Instantiate model.
model
=
cls
(
config
,
*
inputs
,
**
kwargs
)
model
=
cls
(
config
,
*
inputs
,
**
kwargs
)
load_return
=
model
.
load_state_dict
(
remap_state_dict
(
state_dict_from_pretrained
(
model_name
),
load_return
=
model
.
load_state_dict
(
config
),
strict
=
False
)
remap_state_dict
(
state_dict_from_pretrained
(
model_name
),
config
),
strict
=
False
)
logger
.
info
(
load_return
)
logger
.
info
(
load_return
)
return
model
return
model
class
BertModel
(
BertPreTrainedModel
):
class
BertModel
(
BertPreTrainedModel
):
def
__init__
(
self
,
config
:
BertConfig
,
add_pooling_layer
=
True
):
def
__init__
(
self
,
config
:
BertConfig
,
add_pooling_layer
=
True
):
super
().
__init__
(
config
)
super
().
__init__
(
config
)
self
.
pad_vocab_size_multiple
=
getattr
(
config
,
'
pad_vocab_size_multiple
'
,
1
)
self
.
pad_vocab_size_multiple
=
getattr
(
config
,
"
pad_vocab_size_multiple
"
,
1
)
if
config
.
vocab_size
%
self
.
pad_vocab_size_multiple
!=
0
:
if
config
.
vocab_size
%
self
.
pad_vocab_size_multiple
!=
0
:
config
.
vocab_size
+=
(
self
.
pad_vocab_size_multiple
config
.
vocab_size
+=
self
.
pad_vocab_size_multiple
-
(
-
(
config
.
vocab_size
%
self
.
pad_vocab_size_multiple
))
config
.
vocab_size
%
self
.
pad_vocab_size_multiple
self
.
fused_dropout_add_ln
=
getattr
(
config
,
'fused_dropout_add_ln'
,
False
)
)
self
.
fused_dropout_add_ln
=
getattr
(
config
,
"fused_dropout_add_ln"
,
False
)
if
self
.
fused_dropout_add_ln
and
layer_norm
is
None
:
if
self
.
fused_dropout_add_ln
and
layer_norm
is
None
:
raise
ImportError
(
'dropout_add_layer_norm is not installed'
)
raise
ImportError
(
"dropout_add_layer_norm is not installed"
)
assert
config
.
hidden_act
in
[
'gelu'
,
'gelu_new'
,
'gelu_fast'
]
assert
config
.
hidden_act
in
[
"gelu"
,
"gelu_new"
,
"gelu_fast"
]
self
.
embeddings
=
BertEmbeddings
(
config
.
hidden_size
,
config
.
vocab_size
,
self
.
embeddings
=
BertEmbeddings
(
config
.
max_position_embeddings
,
config
.
type_vocab_size
,
config
.
hidden_size
,
padding_idx
=
config
.
pad_token_id
)
config
.
vocab_size
,
config
.
max_position_embeddings
,
config
.
type_vocab_size
,
padding_idx
=
config
.
pad_token_id
,
)
self
.
emb_drop
=
nn
.
Dropout
(
config
.
hidden_dropout_prob
)
self
.
emb_drop
=
nn
.
Dropout
(
config
.
hidden_dropout_prob
)
self
.
emb_ln
=
nn
.
LayerNorm
(
config
.
hidden_size
,
eps
=
config
.
layer_norm_eps
)
self
.
emb_ln
=
nn
.
LayerNorm
(
config
.
hidden_size
,
eps
=
config
.
layer_norm_eps
)
self
.
encoder
=
BertEncoder
(
config
)
self
.
encoder
=
BertEncoder
(
config
)
...
@@ -316,36 +354,46 @@ class BertModel(BertPreTrainedModel):
...
@@ -316,36 +354,46 @@ class BertModel(BertPreTrainedModel):
self
.
apply
(
partial
(
_init_weights
,
initializer_range
=
config
.
initializer_range
))
self
.
apply
(
partial
(
_init_weights
,
initializer_range
=
config
.
initializer_range
))
def
forward
(
self
,
input_ids
,
position_ids
=
None
,
token_type_ids
=
None
,
attention_mask
=
None
,
def
forward
(
masked_tokens_mask
=
None
):
self
,
input_ids
,
position_ids
=
None
,
token_type_ids
=
None
,
attention_mask
=
None
,
masked_tokens_mask
=
None
,
):
"""If masked_tokens_mask is not None (i.e. last_layer_subset == True in BertForPreTraining),
"""If masked_tokens_mask is not None (i.e. last_layer_subset == True in BertForPreTraining),
we only want the output for the masked tokens. This means that we only compute the last
we only want the output for the masked tokens. This means that we only compute the last
layer output for these tokens.
layer output for these tokens.
masked_tokens_mask: (batch, seqlen), dtype=torch.bool
masked_tokens_mask: (batch, seqlen), dtype=torch.bool
"""
"""
hidden_states
=
self
.
embeddings
(
input_ids
,
position_ids
=
position_ids
,
hidden_states
=
self
.
embeddings
(
token_type_ids
=
token_type_ids
)
input_ids
,
position_ids
=
position_ids
,
token_type_ids
=
token_type_ids
)
# TD [2022-12:18]: Don't need to force residual in fp32
# TD [2022-12:18]: Don't need to force residual in fp32
# BERT puts embedding LayerNorm before embedding dropout.
# BERT puts embedding LayerNorm before embedding dropout.
if
not
self
.
fused_dropout_add_ln
:
if
not
self
.
fused_dropout_add_ln
:
hidden_states
=
self
.
emb_ln
(
hidden_states
)
hidden_states
=
self
.
emb_ln
(
hidden_states
)
else
:
else
:
hidden_states
=
layer_norm
(
hidden_states
,
self
.
emb_ln
.
weight
,
self
.
emb_ln
.
bias
,
hidden_states
=
layer_norm
(
self
.
emb_ln
.
eps
)
hidden_states
,
self
.
emb_ln
.
weight
,
self
.
emb_ln
.
bias
,
self
.
emb_ln
.
eps
)
hidden_states
=
self
.
emb_drop
(
hidden_states
)
hidden_states
=
self
.
emb_drop
(
hidden_states
)
if
masked_tokens_mask
is
not
None
:
if
masked_tokens_mask
is
not
None
:
batch_size
,
seqlen
=
input_ids
.
shape
[:
2
]
batch_size
,
seqlen
=
input_ids
.
shape
[:
2
]
# We also need the first column for the CLS token
# We also need the first column for the CLS token
first_col_mask
=
torch
.
zeros
(
batch_size
,
seqlen
,
dtype
=
torch
.
bool
,
first_col_mask
=
torch
.
zeros
(
device
=
input_ids
.
device
)
batch_size
,
seqlen
,
dtype
=
torch
.
bool
,
device
=
input_ids
.
device
)
first_col_mask
[:,
0
]
=
True
first_col_mask
[:,
0
]
=
True
subset_mask
=
masked_tokens_mask
|
first_col_mask
subset_mask
=
masked_tokens_mask
|
first_col_mask
else
:
else
:
subset_mask
=
None
subset_mask
=
None
sequence_output
=
self
.
encoder
(
hidden_states
,
key_padding_mask
=
attention_mask
,
sequence_output
=
self
.
encoder
(
subset_mask
=
subset_mask
)
hidden_states
,
key_padding_mask
=
attention_mask
,
subset_mask
=
subset_mask
)
if
masked_tokens_mask
is
None
:
if
masked_tokens_mask
is
None
:
pooled_output
=
self
.
pooler
(
sequence_output
)
if
self
.
pooler
is
not
None
else
None
pooled_output
=
self
.
pooler
(
sequence_output
)
if
self
.
pooler
is
not
None
else
None
...
@@ -358,8 +406,7 @@ class BertModel(BertPreTrainedModel):
...
@@ -358,8 +406,7 @@ class BertModel(BertPreTrainedModel):
else
:
else
:
pool_input
=
sequence_output
[
first_col_mask
[
subset_mask
]]
pool_input
=
sequence_output
[
first_col_mask
[
subset_mask
]]
sequence_output
=
sequence_output
[
masked_tokens_mask
[
subset_mask
]]
sequence_output
=
sequence_output
[
masked_tokens_mask
[
subset_mask
]]
pooled_output
=
(
self
.
pooler
(
pool_input
,
pool
=
False
)
pooled_output
=
self
.
pooler
(
pool_input
,
pool
=
False
)
if
self
.
pooler
is
not
None
else
None
if
self
.
pooler
is
not
None
else
None
)
return
BaseModelOutputWithPoolingAndCrossAttentions
(
return
BaseModelOutputWithPoolingAndCrossAttentions
(
last_hidden_state
=
sequence_output
,
last_hidden_state
=
sequence_output
,
...
@@ -368,22 +415,24 @@ class BertModel(BertPreTrainedModel):

class BertForPreTraining(BertPreTrainedModel):
    def __init__(self, config: BertConfig):
        super().__init__(config)
        # If dense_seq_output, we only need to pass the hidden states for the masked out tokens
        # (around 15%) to the classifier heads.
        self.dense_seq_output = getattr(config, "dense_seq_output", False)
        # If last_layer_subset, we only need to compute the last layer for a subset of tokens
        # (e.g., the tokens we need to compute the masked LM loss and the next-sentence prediction).
        self.last_layer_subset = getattr(config, "last_layer_subset", False)
        if self.last_layer_subset:
            assert self.dense_seq_output, "last_layer_subset requires dense_seq_output"
        use_xentropy = getattr(config, "use_xentropy", False)
        if use_xentropy and CrossEntropyLoss is None:
            raise ImportError("xentropy_cuda is not installed")
        loss_cls = (
            nn.CrossEntropyLoss
            if not use_xentropy
            else partial(CrossEntropyLoss, inplace_backward=True)
        )

        self.bert = BertModel(config)
        self.cls = BertPreTrainingHeads(config)
...
@@ -397,8 +446,15 @@ class BertForPreTraining(BertPreTrainedModel):
    def tie_weights(self):
        self.cls.predictions.decoder.weight = self.bert.embeddings.word_embeddings.weight

    def forward(
        self,
        input_ids,
        position_ids=None,
        token_type_ids=None,
        attention_mask=None,
        labels=None,
        next_sentence_label=None,
    ):
        """
        If labels are provided, they must be 0 for masked out tokens (as specified in the attention
        mask).
...
@@ -414,28 +470,38 @@ class BertForPreTraining(BertPreTrainedModel):
        """
        masked_tokens_mask = labels > 0 if (self.last_layer_subset and labels is not None) else None
        outputs = self.bert(
            input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask.bool() if attention_mask is not None else None,
            masked_tokens_mask=masked_tokens_mask,
        )
        sequence_output, pooled_output = outputs.last_hidden_state, outputs.pooler_output
        if self.dense_seq_output and labels is not None:
            masked_token_idx = torch.nonzero(labels.flatten() > 0, as_tuple=False).flatten()
            if not self.last_layer_subset:
                sequence_output = index_first_axis(
                    rearrange(sequence_output, "b s d -> (b s) d"), masked_token_idx
                )
        prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)

        total_loss = None
        if labels is not None and next_sentence_label is not None:
            if (
                self.dense_seq_output and labels is not None
            ):  # prediction_scores are already flattened
                masked_lm_loss = self.mlm_loss(
                    prediction_scores, labels.flatten()[masked_token_idx]
                )
            else:
                masked_lm_loss = self.mlm_loss(
                    rearrange(prediction_scores, "... v -> (...) v"),
                    rearrange(labels, "... -> (...)"),
                )
            next_sentence_loss = self.nsp_loss(
                rearrange(seq_relationship_score, "... t -> (...) t"),
                rearrange(next_sentence_label, "... -> (...)"),
            )
            total_loss = masked_lm_loss.float() + next_sentence_loss.float()

        return BertForPreTrainingOutput(
...
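When dense_seq_output is set, the code above keeps only the rows of the flattened (batch * seqlen, hidden) output that correspond to masked positions before running the MLM head. A small stand-alone sketch of that gather, using plain indexing in place of index_first_axis (toy sizes and illustrative names, not from the commit):

import torch

batch, seqlen, hidden = 2, 8, 4
sequence_output = torch.randn(batch, seqlen, hidden)
labels = torch.zeros(batch, seqlen, dtype=torch.long)
labels[0, 3] = 42  # pretend these two positions are masked-LM targets
labels[1, 5] = 7

# Same selection rule as the forward pass: positions with label > 0 are the masked tokens.
masked_token_idx = torch.nonzero(labels.flatten() > 0, as_tuple=False).flatten()

# Gather only those rows from the flattened (batch * seqlen, hidden) output.
flat = sequence_output.reshape(-1, hidden)
dense_rows = flat[masked_token_idx]

assert dense_rows.shape == (2, hidden)
# The MLM loss is then computed against labels.flatten()[masked_token_idx].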
@@ -448,83 +514,106 @@ class BertForPreTraining(BertPreTrainedModel):

def remap_state_dict(state_dict, config):
    # LayerNorm
    def key_mapping_ln_gamma_beta(key):
        key = re.sub(r"LayerNorm.gamma$", "LayerNorm.weight", key)
        key = re.sub(r"LayerNorm.beta$", "LayerNorm.bias", key)
        return key

    state_dict = OrderedDict((key_mapping_ln_gamma_beta(k), v) for k, v in state_dict.items())

    # Layers
    def key_mapping_layers(key):
        return re.sub(r"^bert.encoder.layer.", "bert.encoder.layers.", key)

    state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items())

    # LayerNorm
    def key_mapping_ln(key):
        key = re.sub(r"^bert.embeddings.LayerNorm.", "bert.emb_ln.", key)
        key = re.sub(
            r"^bert.encoder.layers.(\d+).attention.output.LayerNorm.(weight|bias)",
            r"bert.encoder.layers.\1.norm1.\2",
            key,
        )
        key = re.sub(
            r"^bert.encoder.layers.(\d+).output.LayerNorm.(weight|bias)",
            r"bert.encoder.layers.\1.norm2.\2",
            key,
        )
        key = re.sub(
            r"^cls.predictions.transform.LayerNorm.(weight|bias)",
            r"cls.predictions.transform.layer_norm.\1",
            key,
        )
        return key

    state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())

    # MLP
    def key_mapping_mlp(key):
        key = re.sub(
            r"^bert.encoder.layers.(\d+).intermediate.dense.(weight|bias)",
            r"bert.encoder.layers.\1.mlp.fc1.\2",
            key,
        )
        key = re.sub(
            r"^bert.encoder.layers.(\d+).output.dense.(weight|bias)",
            r"bert.encoder.layers.\1.mlp.fc2.\2",
            key,
        )
        return key

    state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())

    # Attention
    last_layer_subset = getattr(config, "last_layer_subset", False)
    for d in range(config.num_hidden_layers):
        Wq = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.query.weight")
        Wk = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.key.weight")
        Wv = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.value.weight")
        bq = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.query.bias")
        bk = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.key.bias")
        bv = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.value.bias")
        if not (last_layer_subset and d == config.num_hidden_layers - 1):
            state_dict[f"bert.encoder.layers.{d}.mixer.Wqkv.weight"] = torch.cat(
                [Wq, Wk, Wv], dim=0
            )
            state_dict[f"bert.encoder.layers.{d}.mixer.Wqkv.bias"] = torch.cat([bq, bk, bv], dim=0)
        else:
            state_dict[f"bert.encoder.layers.{d}.mixer.Wq.weight"] = Wq
            state_dict[f"bert.encoder.layers.{d}.mixer.Wkv.weight"] = torch.cat([Wk, Wv], dim=0)
            state_dict[f"bert.encoder.layers.{d}.mixer.Wq.bias"] = bq
            state_dict[f"bert.encoder.layers.{d}.mixer.Wkv.bias"] = torch.cat([bk, bv], dim=0)

    def key_mapping_attn(key):
        return re.sub(
            r"^bert.encoder.layers.(\d+).attention.output.dense.(weight|bias)",
            r"bert.encoder.layers.\1.mixer.out_proj.\2",
            key,
        )

    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())

    def key_mapping_decoder_bias(key):
        return re.sub(r"^cls.predictions.bias", "cls.predictions.decoder.bias", key)

    state_dict = OrderedDict((key_mapping_decoder_bias(k), v) for k, v in state_dict.items())

    # Word embedding
    pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
    if pad_vocab_size_multiple > 1:
        word_embeddings = state_dict["bert.embeddings.word_embeddings.weight"]
        state_dict["bert.embeddings.word_embeddings.weight"] = F.pad(
            word_embeddings, (0, 0, 0, config.vocab_size - word_embeddings.shape[0])
        )
        decoder_weight = state_dict["cls.predictions.decoder.weight"]
        state_dict["cls.predictions.decoder.weight"] = F.pad(
            decoder_weight, (0, 0, 0, config.vocab_size - decoder_weight.shape[0])
        )
        # If the vocab was padded, we want to set the decoder bias for those padded indices to be
        # strongly negative (i.e. the decoder shouldn't predict those indices).
        # TD [2022-05-09]: I don't think it affects the MLPerf training.
        decoder_bias = state_dict["cls.predictions.decoder.bias"]
        state_dict["cls.predictions.decoder.bias"] = F.pad(
            decoder_bias, (0, config.vocab_size - decoder_bias.shape[0]), value=-100.0
        )
...
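As a rough sketch of how remap_state_dict is meant to be consumed (the checkpoint name and the strict=False load are assumptions for illustration, not part of this commit):

from transformers import BertConfig
from transformers import BertForPreTraining as HFBertForPreTraining

from flash_attn.models.bert import BertForPreTraining, remap_state_dict

config = BertConfig.from_pretrained("bert-base-uncased")
hf_state_dict = HFBertForPreTraining.from_pretrained("bert-base-uncased").state_dict()

# Rename LayerNorm keys, fuse Wq/Wk/Wv into Wqkv, and pad the vocab-sized weights.
model = BertForPreTraining(config)
model.load_state_dict(remap_state_dict(hf_state_dict, config), strict=False)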
flash_attn/models/falcon.py View file @ f1a73d07
...
@@ -2,93 +2,114 @@
import math
import re
from collections import OrderedDict

import torch
import torch.nn.functional as F
from einops import rearrange
from transformers import FalconConfig, GPT2Config


def remap_state_dict_hf_falcon(state_dict, config):
    def key_mapping_layers(key):
        return re.sub(r"^transformer.h.", "transformer.layers.", key)

    state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items())

    # Word embedding
    def key_mapping_emb(key):
        return re.sub(
            r"^transformer.word_embeddings.", "transformer.embeddings.word_embeddings.", key
        )

    state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
    word_embeddings = state_dict.pop("transformer.embeddings.word_embeddings.weight")
    # It's possible that vocab_size is padded to be a multiple of 8, for example.
    pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
    vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
    state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad(
        word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
    )
    if getattr(config, "tie_word_embeddings"):
        state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"]
    else:
        output_embeddings = state_dict.pop("lm_head.weight")
        # It's possible that vocab_size is padded to be a multiple of 8, for example.
        state_dict["lm_head.weight"] = F.pad(
            output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0])
        )
        output_embeddings_bias = state_dict.pop("lm_head.bias")
        state_dict["lm_head.bias"] = F.pad(
            output_embeddings_bias, (0, vocab_size - output_embeddings_bias.shape[0])
        )

    # LayerNorm
    def key_mapping_ln(key):
        key = re.sub(
            r"^transformer.layers.(\d+).input_layernorm.", r"transformer.layers.\1.norm1.", key
        )
        key = re.sub(
            r"^transformer.layers.(\d+).post_attention_layernorm.",
            r"transformer.layers.\1.norm2.",
            key,
        )
        key = re.sub(r"^transformer.layers.(\d+).ln_attn.", r"transformer.layers.\1.norm1.", key)
        key = re.sub(r"^transformer.layers.(\d+).ln_mlp.", r"transformer.layers.\1.norm2.", key)
        return key

    state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())

    # MLP
    def key_mapping_mlp(key):
        key = re.sub(
            r"^transformer.layers.(\d+).mlp.dense_h_to_4h.", r"transformer.layers.\1.mlp.fc1.", key
        )
        key = re.sub(
            r"^transformer.layers.(\d+).mlp.dense_4h_to_h.", r"transformer.layers.\1.mlp.fc2.", key
        )
        return key

    state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())

    def key_mapping_attn(key):
        key = re.sub(
            r"^transformer.layers.(\d+).self_attention.query_key_value.",
            r"transformer.layers.\1.mixer.Wqkv.",
            key,
        )
        key = re.sub(
            r"^transformer.layers.(\d+).self_attention.dense.",
            r"transformer.layers.\1.mixer.out_proj.",
            key,
        )
        return key

    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())

    n_head = config.n_head
    n_head_kv = getattr(config, "n_head_kv", 1)
    headdim = config.hidden_size // n_head
    for l in range(config.n_layer):
        # The weights are stored in a different layout compared to our implementation
        Wqkv = rearrange(
            state_dict.pop(f"transformer.layers.{l}.mixer.Wqkv.weight"),
            "(group ratio headdim) ... -> group ratio headdim ...",
            ratio=n_head // n_head_kv + 2,
            headdim=headdim,
        )
        Wq = rearrange(Wqkv[:, :-2], "group ratio headdim ... -> (group ratio headdim) ...")
        Wk = rearrange(Wqkv[:, [-2]], "group ratio headdim ... -> (group ratio headdim) ...")
        Wv = rearrange(Wqkv[:, [-1]], "group ratio headdim ... -> (group ratio headdim) ...")
        state_dict[f"transformer.layers.{l}.mixer.Wqkv.weight"] = torch.cat([Wq, Wk, Wv], dim=0)

    return state_dict


def falcon_config_to_gpt2_config(falcon_config: FalconConfig) -> GPT2Config:
    # The 40b config uses "n_head_kv" instead of "num_kv_heads"
    n_head_kv = getattr(
        falcon_config,
        "n_head_kv",
        1 if getattr(falcon_config, "multi_query", False) else falcon_config.n_head,
    )
    # HACK: the 40b config has 2 LN per layer instead of 1, but that's not reflected in the config.
    # So we have to infer it from the number of heads in the key/value block
    parallel_block_tied_norm = n_head_kv == 1
...
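For context on how these two helpers fit together, here is a rough usage sketch; the checkpoint name, the GPTLMHeadModel entry point, and the exact load call are assumptions for illustration and are not part of this commit:

import torch
from transformers import FalconConfig

from flash_attn.models.falcon import falcon_config_to_gpt2_config, remap_state_dict_hf_falcon
from flash_attn.models.gpt import GPTLMHeadModel  # assumed entry point, not shown in this diff
from flash_attn.utils.pretrained import state_dict_from_pretrained

model_name = "tiiuae/falcon-7b"  # hypothetical checkpoint choice
config = falcon_config_to_gpt2_config(FalconConfig.from_pretrained(model_name))

# Load the HF weights, then rename and re-fuse them into the layout expected here.
state_dict = remap_state_dict_hf_falcon(state_dict_from_pretrained(model_name), config)

model = GPTLMHeadModel(config, device="cpu", dtype=torch.float32)
model.load_state_dict(state_dict, strict=False)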
flash_attn/models/gpt.py View file @ f1a73d07
...
@@ -11,6 +11,8 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from transformers import GPT2Config

from flash_attn.models.falcon import remap_state_dict_hf_falcon
from flash_attn.models.gpt_neox import remap_state_dict_hf_gpt_neox
from flash_attn.models.gptj import remap_state_dict_hf_gptj
...
@@ -27,10 +29,9 @@ from flash_attn.modules.mlp import (
    ParallelMLP,
)
from flash_attn.ops.activations import sqrelu_fwd
from flash_attn.utils.distributed import all_gather_raw, get_dim_for_local_rank, sync_shared_params
from flash_attn.utils.generation import GenerationMixin
from flash_attn.utils.pretrained import state_dict_from_pretrained

try:
    from flash_attn.ops.fused_dense import ColumnParallelLinear
...
@@ -690,7 +691,7 @@ def shard_state_dict_tp(state_dict, config, world_size, rank):
        if key in state_dict:
            x = state_dict[key]
            dim = x.shape[0] // world_size
            state_dict[key] = x[rank * dim : (rank + 1) * dim]

    def shard_last_dim(state_dict, key, multiple_of=1):
        if key in state_dict:
...
@@ -707,17 +708,19 @@ def shard_state_dict_tp(state_dict, config, world_size, rank):
            x = state_dict[key]
            dim = x.shape[0] // world_size // 2
            state_dict[key] = rearrange(
                rearrange(x, "(two o) ... -> two o ...", two=2)[:, rank * dim : (rank + 1) * dim],
                "two o ... -> (two o) ...",
            )

    def shard_qkv_headdim(state_dict, key):
        if key in state_dict:
            n_head_each_rank = [
                get_dim_for_local_rank(n_head, world_size, local_rank)
                for local_rank in range(world_size)
            ]
            n_head_kv_each_rank = [
                get_dim_for_local_rank(n_head_kv, world_size, local_rank)
                for local_rank in range(world_size)
            ]
            beg_n_head = sum(n_head_each_rank[:rank])
...
@@ -729,7 +732,8 @@ def shard_state_dict_tp(state_dict, config, world_size, rank):
            if n_head_kv == n_head:
                x = rearrange(state_dict[key], "(three d) ... -> three d ...", three=3)
                state_dict[key] = rearrange(
                    x[:, beg_n_head * head_dim : end_n_head * head_dim],
                    "three d ... -> (three d) ...",
                )
            else:
                x = rearrange(
...
@@ -741,8 +745,14 @@ def shard_state_dict_tp(state_dict, config, world_size, rank):
                torch.cat(
                    [
                        x[beg_n_head:end_n_head],
                        x[n_head + beg_n_head_kv : n_head + end_n_head_kv],
                        x[n_head + n_head_kv + beg_n_head_kv : n_head + n_head_kv + end_n_head_kv],
                    ],
                    dim=0,
                ),
...
@@ -824,7 +834,7 @@ def combine_state_dicts_tp(state_dicts, config):
                torch.cat([x[:n_head_per_rank] for x in xs], dim=0),
                torch.cat(
                    [
                        x[n_head_per_rank : n_head_per_rank + n_head_kv_per_rank]
                        for x in xs
                    ],
                    dim=0,
...
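The "(two o)" pattern above splits a gated-MLP fc1 weight so that every rank keeps matching slices of both stacked halves. A toy check of that slicing, with made-up sizes (not taken from the diff):

import torch
from einops import rearrange

world_size, rank = 2, 1
x = torch.arange(8 * 3, dtype=torch.float32).reshape(8, 3)  # stand-in for a fused (2 * o, in) weight
dim = x.shape[0] // world_size // 2  # rows of each half kept by this rank

shard = rearrange(
    rearrange(x, "(two o) ... -> two o ...", two=2)[:, rank * dim : (rank + 1) * dim],
    "two o ... -> (two o) ...",
)

half = x.shape[0] // 2
expected = torch.cat(
    [x[rank * dim : (rank + 1) * dim], x[half + rank * dim : half + (rank + 1) * dim]], dim=0
)
assert shard.shape == (x.shape[0] // world_size, 3)
assert torch.equal(shard, expected)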
flash_attn/models/gpt_neox.py View file @ f1a73d07
...
@@ -2,80 +2,100 @@
import math
import re
from collections import OrderedDict

import torch
import torch.nn.functional as F
from einops import rearrange
from transformers import GPT2Config, GPTNeoXConfig


def remap_state_dict_hf_gpt_neox(state_dict, config):
    def key_mapping_layers(key):
        return re.sub(r"^gpt_neox.", "transformer.", key)

    state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items())

    # Word embedding
    def key_mapping_emb(key):
        return re.sub(r"^transformer.embed_in.", "transformer.embeddings.word_embeddings.", key)

    state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
    word_embeddings = state_dict.pop("transformer.embeddings.word_embeddings.weight")
    # It's possible that vocab_size is padded to be a multiple of 8, for example.
    pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
    vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
    state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad(
        word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
    )
    if getattr(config, "tie_word_embeddings"):
        state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"]
    else:
        output_embeddings = state_dict.pop("embed_out.weight")
        # It's possible that vocab_size is padded to be a multiple of 8, for example.
        state_dict["lm_head.weight"] = F.pad(
            output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0])
        )

    # LayerNorm
    def key_mapping_ln(key):
        key = re.sub(r"^transformer.final_layer_norm.", r"transformer.ln_f.", key)
        key = re.sub(
            r"^transformer.layers.(\d+).input_layernorm.", r"transformer.layers.\1.norm1.", key
        )
        key = re.sub(
            r"^transformer.layers.(\d+).post_attention_layernorm.",
            r"transformer.layers.\1.norm2.",
            key,
        )
        return key

    state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())

    # MLP
    def key_mapping_mlp(key):
        key = re.sub(
            r"^transformer.layers.(\d+).mlp.dense_h_to_4h.", r"transformer.layers.\1.mlp.fc1.", key
        )
        key = re.sub(
            r"^transformer.layers.(\d+).mlp.dense_4h_to_h.", r"transformer.layers.\1.mlp.fc2.", key
        )
        return key

    state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())

    # Attention
    for l in range(config.n_layer):
        # We don't store these biases
        state_dict.pop(f"transformer.layers.{l}.attention.bias")
        state_dict.pop(f"transformer.layers.{l}.attention.masked_bias")
        # GPT-NeoX stores Wqkv as ((nheads 3 headdim), hidden_dim)
        # while we store Wqkv as ((3 nheads headdim), hidden_dim)
        headdim = config.hidden_size // config.num_attention_heads
        Wqkv = state_dict.pop(f"transformer.layers.{l}.attention.query_key_value.weight")
        state_dict[f"transformer.layers.{l}.mixer.Wqkv.weight"] = rearrange(
            Wqkv,
            "(nheads three headdim) ... -> (three nheads headdim) ...",
            three=3,
            headdim=headdim,
        )
        bqkv = state_dict.pop(f"transformer.layers.{l}.attention.query_key_value.bias")
        state_dict[f"transformer.layers.{l}.mixer.Wqkv.bias"] = rearrange(
            bqkv, "(nheads three headdim) -> (three nheads headdim)", three=3, headdim=headdim
        )

    def key_mapping_attn(key):
        key = re.sub(
            r"^transformer.layers.(\d+).attention.dense.",
            r"transformer.layers.\1.mixer.out_proj.",
            key,
        )
        key = re.sub(
            r"^transformer.layers.(\d+).attention.rotary_emb.",
            r"transformer.layers.\1.mixer.rotary_emb.",
            key,
        )
        return key

    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())

    return state_dict
...
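The layout comment above is the crux of the GPT-NeoX remap: HF interleaves q, k, v per head, while the fused attention here wants all q heads first, then all k, then all v. A small check of the rearrange with toy sizes (nothing here comes from a real NeoX config):

import torch
from einops import rearrange

nheads, headdim, hidden = 2, 3, 6  # toy sizes
# Label each row by (head, which-of-q/k/v, dim-within-head) so the reordering is visible.
rows = [
    100 * h + 10 * three + d
    for h in range(nheads)
    for three in range(3)
    for d in range(headdim)
]
Wqkv_neox = torch.tensor(rows, dtype=torch.float32).unsqueeze(1).expand(-1, hidden).clone()

Wqkv_ours = rearrange(
    Wqkv_neox,
    "(nheads three headdim) ... -> (three nheads headdim) ...",
    three=3,
    headdim=headdim,
)

# After the rearrange, the first nheads * headdim rows are the q rows of every head (three == 0).
expected_q = torch.tensor([100.0 * h + d for h in range(nheads) for d in range(headdim)])
assert torch.equal(Wqkv_ours[: nheads * headdim, 0], expected_q)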
flash_attn/models/gptj.py View file @ f1a73d07
...
@@ -2,67 +2,78 @@
import math
import re
from collections import OrderedDict

import torch
import torch.nn.functional as F
from transformers import GPT2Config, GPTJConfig


def remap_state_dict_hf_gptj(state_dict, config):
    def key_mapping_layers(key):
        return re.sub(r"^transformer.h.", "transformer.layers.", key)

    state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items())

    # Word embedding
    def key_mapping_emb(key):
        return re.sub(r"^transformer.wte.", "transformer.embeddings.word_embeddings.", key)

    state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
    word_embeddings = state_dict.pop("transformer.embeddings.word_embeddings.weight")
    # It's possible that vocab_size is padded to be a multiple of 8, for example.
    pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
    vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
    state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad(
        word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
    )
    if getattr(config, "tie_word_embeddings"):
        state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"]
    else:
        output_embeddings = state_dict.pop("lm_head.weight")
        # It's possible that vocab_size is padded to be a multiple of 8, for example.
        state_dict["lm_head.weight"] = F.pad(
            output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0])
        )
        output_embeddings_bias = state_dict.pop("lm_head.bias")
        state_dict["lm_head.bias"] = F.pad(
            output_embeddings_bias, (0, vocab_size - output_embeddings_bias.shape[0])
        )

    # LayerNorm
    def key_mapping_ln(key):
        return re.sub(r"^transformer.layers.(\d+).ln_1.", r"transformer.layers.\1.norm1.", key)

    state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())

    # MLP
    def key_mapping_mlp(key):
        key = re.sub(
            r"^transformer.layers.(\d+).mlp.fc_in.", r"transformer.layers.\1.mlp.fc1.", key
        )
        key = re.sub(
            r"^transformer.layers.(\d+).mlp.fc_out.", r"transformer.layers.\1.mlp.fc2.", key
        )
        return key

    state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())

    # Attention
    for l in range(config.n_layer):
        Wq = state_dict.pop(f"transformer.layers.{l}.attn.q_proj.weight")
        Wk = state_dict.pop(f"transformer.layers.{l}.attn.k_proj.weight")
        Wv = state_dict.pop(f"transformer.layers.{l}.attn.v_proj.weight")
        state_dict[f"transformer.layers.{l}.mixer.Wqkv.weight"] = torch.cat([Wq, Wk, Wv], dim=0)
        # We don't store these biases
        state_dict.pop(f"transformer.layers.{l}.attn.bias")
        state_dict.pop(f"transformer.layers.{l}.attn.masked_bias")

    def key_mapping_attn(key):
        return re.sub(
            r"^transformer.layers.(\d+).attn.out_proj.",
            r"transformer.layers.\1.mixer.out_proj.",
            key,
        )

    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())

    return state_dict
...
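Fusing the three projections with torch.cat along dim 0, as done above, is equivalent to one linear whose output stacks q, k, v. A quick numerical check with toy dimensions (unrelated to any real GPT-J size):

import torch
import torch.nn.functional as F

hidden = 16
Wq, Wk, Wv = (torch.randn(hidden, hidden) for _ in range(3))
Wqkv = torch.cat([Wq, Wk, Wv], dim=0)  # (3 * hidden, hidden), same layout the remap produces

x = torch.randn(2, 5, hidden)
fused = F.linear(x, Wqkv)  # one matmul instead of three
q, k, v = fused.split(hidden, dim=-1)

assert torch.allclose(q, F.linear(x, Wq), atol=1e-6)
assert torch.allclose(k, F.linear(x, Wk), atol=1e-6)
assert torch.allclose(v, F.linear(x, Wv), atol=1e-6)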
flash_attn/models/llama.py View file @ f1a73d07
...
@@ -15,63 +15,81 @@ from transformers import GPT2Config, LlamaConfig
def remap_state_dict_meta_llama(state_dict, config):
    def key_mapping_layers(key):
        return f"transformer.{key}" if not key.startswith("output.") else key

    state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items())

    # Word embedding
    def key_mapping_emb(key):
        return re.sub(
            r"^transformer.tok_embeddings.", "transformer.embeddings.word_embeddings.", key
        )

    state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
    word_embeddings = state_dict.pop("transformer.embeddings.word_embeddings.weight")
    # It's possible that vocab_size is padded to be a multiple of 8, for example.
    pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
    vocab_size = (
        math.ceil(word_embeddings.shape[0] / pad_vocab_size_multiple) * pad_vocab_size_multiple
    )
    state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad(
        word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
    )
    if getattr(config, "tie_word_embeddings"):
        state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"]
    else:
        output_embeddings = state_dict.pop("output.weight")
        # Need to recompute vocab_size since LLaMa shards the word embeddings and output embeddings
        # differently.
        vocab_size = (
            math.ceil(output_embeddings.shape[0] / pad_vocab_size_multiple)
            * pad_vocab_size_multiple
        )
        # It's possible that vocab_size is padded to be a multiple of 8, for example.
        state_dict["lm_head.weight"] = F.pad(
            output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0])
        )

    # LayerNorm
    def key_mapping_ln(key):
        key = re.sub(r"^transformer.norm.", r"transformer.ln_f.", key)
        key = re.sub(
            r"^transformer.layers.(\d+).attention_norm.", r"transformer.layers.\1.norm1.", key
        )
        key = re.sub(r"^transformer.layers.(\d+).ffn_norm.", r"transformer.layers.\1.norm2.", key)
        return key

    state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())

    # MLP
    for l in range(config.n_layer):
        w1 = state_dict.pop(f"transformer.layers.{l}.feed_forward.w1.weight")
        w3 = state_dict.pop(f"transformer.layers.{l}.feed_forward.w3.weight")
        # Our ordering is different
        state_dict[f"transformer.layers.{l}.mlp.fc1.weight"] = torch.cat([w3, w1], dim=0)

    def key_mapping_mlp(key):
        return re.sub(
            r"^transformer.layers.(\d+).feed_forward.w2.", r"transformer.layers.\1.mlp.fc2.", key
        )

    state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())

    # Attention
    for l in range(config.n_layer):
        Wq = state_dict.pop(f"transformer.layers.{l}.attention.wq.weight")
        Wk = state_dict.pop(f"transformer.layers.{l}.attention.wk.weight")
        Wv = state_dict.pop(f"transformer.layers.{l}.attention.wv.weight")
        state_dict[f"transformer.layers.{l}.mixer.Wqkv.weight"] = torch.cat([Wq, Wk, Wv], dim=0)
        # We don't store these
        state_dict.pop(f"transformer.layers.{l}.attention.inner_attention.rope.freqs", None)

    def key_mapping_attn(key):
        return re.sub(
            r"^transformer.layers.(\d+).attention.wo.",
            r"transformer.layers.\1.mixer.out_proj.",
            key,
        )

    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
    state_dict.pop("transformer.rope.freqs", None)
...
@@ -82,29 +100,32 @@ def remap_state_dict_meta_llama(state_dict, config):

def remap_state_dict_hf_llama(state_dict, config):
    # Embedding
    def key_mapping_emb(key):
        return re.sub(r"^model.embed_tokens.", "transformer.embeddings.word_embeddings.", key)

    state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
    word_embeddings = state_dict.pop("transformer.embeddings.word_embeddings.weight")
    # It's possible that vocab_size is padded to be a multiple of 8, for example.
    pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
    vocab_size = (
        math.ceil(word_embeddings.shape[0] / pad_vocab_size_multiple) * pad_vocab_size_multiple
    )
    state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad(
        word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
    )

    # LM head
    if getattr(config, "tie_word_embeddings"):
        state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"]
    else:
        output_embeddings = state_dict.pop("lm_head.weight")
        # Need to recompute vocab_size since LLaMa shards the word embeddings and output embeddings
        # differently.
        vocab_size = (
            math.ceil(output_embeddings.shape[0] / pad_vocab_size_multiple)
            * pad_vocab_size_multiple
        )
        # It's possible that vocab_size is padded to be a multiple of 8, for example.
        state_dict["lm_head.weight"] = F.pad(
            output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0])
        )
...
@@ -113,21 +134,22 @@ def remap_state_dict_hf_llama(state_dict, config):
        # Fusing weights this way based on difference in the following:
        # https://github.com/huggingface/transformers/blob/b42010bb1d3cbf262d27e0a328661885be46dfdb/src/transformers/models/llama/modeling_llama.py#L220
        # https://github.com/Dao-AILab/flash-attention/blob/c60851a8253257eb970e06a022c82517a8033e8c/flash_attn/modules/mlp.py#L115
        w1 = state_dict.pop(f"model.layers.{l}.mlp.gate_proj.weight")
        w3 = state_dict.pop(f"model.layers.{l}.mlp.up_proj.weight")
        state_dict[f"transformer.layers.{l}.mlp.fc1.weight"] = torch.cat([w3, w1], dim=0)

    def key_mapping_mlp(key):
        return re.sub(
            r"^model.layers.(\d+).mlp.down_proj.", r"transformer.layers.\1.mlp.fc2.", key
        )

    state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())

    # LayerNorm
    def key_mapping_ln(key):
        key = re.sub(r"^model.norm.", r"transformer.ln_f.", key)
        key = re.sub(r"^model.layers.(\d+).input_layernorm.", r"transformer.layers.\1.norm1.", key)
        key = re.sub(
            r"^model.layers.(\d+).post_attention_layernorm.", r"transformer.layers.\1.norm2.", key
        )
        return key

    state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())
...
@@ -135,42 +157,52 @@ def remap_state_dict_hf_llama(state_dict, config):
    def inv_permute(w):
        # Inverse of permute implemented in:
        # https://github.com/huggingface/transformers/blob/b42010bb1d3cbf262d27e0a328661885be46dfdb/src/transformers/models/llama/convert_llama_weights_to_hf.py#L114
        return (
            w.reshape(config.n_head, 2, config.n_embd // config.n_head // 2, config.n_embd)
            .transpose(1, 2)
            .reshape(config.n_embd, config.n_embd)
        )

    # Attention
    for l in range(config.n_layer):
        Wq = state_dict.pop(f"model.layers.{l}.self_attn.q_proj.weight")
        Wk = state_dict.pop(f"model.layers.{l}.self_attn.k_proj.weight")
        Wv = state_dict.pop(f"model.layers.{l}.self_attn.v_proj.weight")
        state_dict[f"transformer.layers.{l}.mixer.Wqkv.weight"] = torch.cat(
            [inv_permute(Wq), inv_permute(Wk), Wv], dim=0
        )
        # We don't store these
        state_dict.pop(f"model.layers.{l}.self_attn.rotary_emb.inv_freq", None)

    def key_mapping_attn(key):
        return re.sub(
            r"^model.layers.(\d+).self_attn.o_proj.",
            r"transformer.layers.\1.mixer.out_proj.",
            key,
        )

    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
    return state_dict
def
config_from_meta_checkpoint
(
checkpoint_path
:
Union
[
str
,
os
.
PathLike
],
model_name
:
str
)
->
LlamaConfig
:
def
config_from_meta_checkpoint
(
checkpoint_path
:
Union
[
str
,
os
.
PathLike
],
model_name
:
str
)
->
LlamaConfig
:
"""Load a LlamaConfig from a checkpoint path."""
"""Load a LlamaConfig from a checkpoint path."""
with
open
(
Path
(
checkpoint_path
)
/
model_name
/
'
params.json
'
)
as
f
:
with
open
(
Path
(
checkpoint_path
)
/
model_name
/
"
params.json
"
)
as
f
:
params
=
json
.
load
(
f
)
params
=
json
.
load
(
f
)
config
=
LlamaConfig
(
hidden_size
=
params
[
'dim'
],
intermediate_size
=
None
,
config
=
LlamaConfig
(
num_attention_heads
=
params
[
'n_heads'
],
hidden_size
=
params
[
"dim"
],
num_hidden_layers
=
params
[
'n_layers'
],
intermediate_size
=
None
,
rms_norm_eps
=
params
[
'norm_eps'
])
num_attention_heads
=
params
[
"n_heads"
],
num_hidden_layers
=
params
[
"n_layers"
],
rms_norm_eps
=
params
[
"norm_eps"
],
)
return
config
return
config
def
config_from_hf_checkpoint
(
checkpoint_path
:
Union
[
str
,
os
.
PathLike
],
model_name
:
str
)
->
LlamaConfig
:
def
config_from_hf_checkpoint
(
return
LlamaConfig
.
from_pretrained
(
Path
(
checkpoint_path
)
/
f
'
{
model_name
}
-hf'
/
"config.json"
)
checkpoint_path
:
Union
[
str
,
os
.
PathLike
],
model_name
:
str
)
->
LlamaConfig
:
return
LlamaConfig
.
from_pretrained
(
Path
(
checkpoint_path
)
/
f
"
{
model_name
}
-hf"
/
"config.json"
)
def
config_from_checkpoint
(
def
config_from_checkpoint
(
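A quick round-trip check of inv_permute, with the HF permute paraphrased from memory of the linked conversion script (treat that half as an assumption); toy sizes only.

import torch

n_head, n_embd = 4, 32            # toy sizes; n_embd must be divisible by 2 * n_head
w = torch.randn(n_embd, n_embd)

def hf_permute(w):
    # Paraphrase of the permute in HF's convert_llama_weights_to_hf.py (linked above).
    return w.view(n_head, n_embd // n_head // 2, 2, n_embd).transpose(1, 2).reshape(n_embd, n_embd)

def inv_permute(w):
    # Same computation as in remap_state_dict_hf_llama above.
    return (
        w.reshape(n_head, 2, n_embd // n_head // 2, n_embd)
        .transpose(1, 2)
        .reshape(n_embd, n_embd)
    )

assert torch.equal(inv_permute(hf_permute(w)), w)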
...
@@ -182,10 +214,14 @@ def config_from_checkpoint(
        return config_from_hf_checkpoint(checkpoint_path, model_name)


def state_dicts_from_checkpoint(checkpoint_path: Union[str, os.PathLike], model_name: str) -> list[dict]:
    # Need to sort, otherwise we mess up the ordering and the weights are wrong
    return [
        torch.load(path, map_location="cpu")
        for path in sorted((Path(checkpoint_path) / model_name).glob("consolidated.*.pth"))
    ]


def llama_config_to_gpt2_config(llama_config: LlamaConfig) -> GPT2Config:
...
@@ -196,7 +232,7 @@ def llama_config_to_gpt2_config(llama_config: LlamaConfig) -> GPT2Config:
        n_layer=llama_config.num_hidden_layers,
        n_head=llama_config.num_attention_heads,
        n_inner=llama_config.intermediate_size,
        activation_function="swiglu",  # Hardcode since HF calls it 'silu'
        # Llama doesn't have dropout, idk if it's because they only release the inference code
        resid_pdrop=0.0,
        embd_pdrop=0.0,
...
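Taken together, these helpers turn a Meta-format checkpoint directory into a GPT2Config and a list of weight shards for the flash_attn GPT model. A usage sketch with placeholder paths, not a recipe from this commit:

# Placeholder path and model name, purely illustrative.
checkpoint_path, model_name = "/checkpoints/llama", "7B"

llama_config = config_from_meta_checkpoint(checkpoint_path, model_name)
gpt2_config = llama_config_to_gpt2_config(llama_config)
shards = state_dicts_from_checkpoint(checkpoint_path, model_name)  # list of per-rank state dicts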
flash_attn/models/opt.py  View file @ f1a73d07
...
@@ -2,75 +2,86 @@
import math
import re
from collections import OrderedDict

import torch
import torch.nn.functional as F
from transformers import GPT2Config, OPTConfig


def remap_state_dict_hf_opt(state_dict, config):
    def key_mapping_model(key):
        key = re.sub(r"^model.decoder.", "transformer.", key)
        # The OPT-350m model uses '^decoder' instead of '^model.decoder'
        key = re.sub(r"^decoder.", "transformer.", key)
        return key

    state_dict = OrderedDict((key_mapping_model(k), v) for k, v in state_dict.items())

    # Word embedding and position embedding
    def key_mapping_emb(key):
        key = re.sub(
            r"^transformer.embed_tokens.", "transformer.embeddings.word_embeddings.", key
        )
        # The OPT-350m model uses has project_in and project_out
        key = re.sub(r"^transformer.project_in.", "transformer.embeddings.project_in.", key)
        key = re.sub(r"^transformer.project_out.", "project_out.", key)
        key = re.sub(
            r"^transformer.embed_positions.", "transformer.embeddings.position_embeddings.", key
        )
        return key

    state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
    # OPT uses the first 2 indices of pos_emb for padding tokens
    pos_embeddings = state_dict.pop("transformer.embeddings.position_embeddings.weight")
    state_dict["transformer.embeddings.position_embeddings.weight"] = pos_embeddings[2:]
    word_embeddings = state_dict.pop("transformer.embeddings.word_embeddings.weight")
    # It's possible that vocab_size is padded to be a multiple of 8, for example.
    pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
    vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
    state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad(
        word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
    )
    state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"]

    # LayerNorm
    def key_mapping_ln(key):
        key = re.sub(r"^transformer.final_layer_norm.", r"transformer.ln_f.", key)
        # The OPT-175B checkpoint calls this 'decoder.layer_norm' instead of 'decoder.final_layer_norm'
        key = re.sub(r"^transformer.layer_norm.", r"transformer.ln_f.", key)
        key = re.sub(
            r"^transformer.layers.(\d+).self_attn_layer_norm.", r"transformer.layers.\1.norm1.", key
        )
        key = re.sub(
            r"^transformer.layers.(\d+).final_layer_norm.", r"transformer.layers.\1.norm2.", key
        )
        return key

    state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())

    # MLP
    def key_mapping_mlp(key):
        return re.sub(
            r"^transformer.layers.(\d+).fc(1|2).", r"transformer.layers.\1.mlp.fc\2.", key
        )

    state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())

    # Attention
    for l in range(config.n_layer):
        Wq = state_dict.pop(f"transformer.layers.{l}.self_attn.q_proj.weight")
        Wk = state_dict.pop(f"transformer.layers.{l}.self_attn.k_proj.weight")
        Wv = state_dict.pop(f"transformer.layers.{l}.self_attn.v_proj.weight")
        bq = state_dict.pop(f"transformer.layers.{l}.self_attn.q_proj.bias")
        bk = state_dict.pop(f"transformer.layers.{l}.self_attn.k_proj.bias")
        bv = state_dict.pop(f"transformer.layers.{l}.self_attn.v_proj.bias")
        state_dict[f"transformer.layers.{l}.mixer.Wqkv.weight"] = torch.cat([Wq, Wk, Wv], dim=0)
        state_dict[f"transformer.layers.{l}.mixer.Wqkv.bias"] = torch.cat([bq, bk, bv], dim=0)

    def key_mapping_attn(key):
        return re.sub(
            r"^transformer.layers.(\d+).self_attn.out_proj.",
            r"transformer.layers.\1.mixer.out_proj.",
            key,
        )

    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())

    return state_dict
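The pos_embeddings[2:] slice above compensates for HF OPT's learned positional embedding, which (to my understanding) looks up position i at row i + 2, keeping the first two rows for padding bookkeeping. A toy illustration:

import torch

hf_pos_emb = torch.arange(18, dtype=torch.float32).reshape(6, 3)  # toy (num_rows, hidden) table

position = 1
hf_lookup = hf_pos_emb[position + 2]      # how HF OPT would index this position (offset of 2)
remapped = hf_pos_emb[2:]                 # what the remapping stores
assert torch.equal(remapped[position], hf_lookup)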
...
@@ -79,8 +90,11 @@ def remap_state_dict_hf_opt(state_dict, config):
def opt_config_to_gpt2_config(opt_config: OPTConfig) -> GPT2Config:
    assert opt_config.layerdrop == 0.0
    assert opt_config.layer_norm_elementwise_affine
    word_embed_proj_dim = (
        None
        if opt_config.word_embed_proj_dim == opt_config.hidden_size
        else opt_config.word_embed_proj_dim
    )
    return GPT2Config(
        vocab_size=opt_config.vocab_size,
        n_positions=opt_config.max_position_embeddings,
...
@@ -98,5 +112,5 @@ def opt_config_to_gpt2_config(opt_config: OPTConfig) -> GPT2Config:
        eos_token_id=opt_config.eos_token_id,
        # These are new arguments not in the original GPT2Config
        prenorm=opt_config.do_layer_norm_before,
        word_embed_proj_dim=word_embed_proj_dim,
    )
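The word_embed_proj_dim handling only matters for OPT variants whose embedding projection differs from the hidden size (OPT-350m being the known case); otherwise it collapses to None. A small hedged check with hand-built configs (values are illustrative, not loaded from the hub):

from transformers import OPTConfig

plain = OPTConfig(hidden_size=768, word_embed_proj_dim=768)
projected = OPTConfig(hidden_size=1024, word_embed_proj_dim=512)  # OPT-350m-like shapes

assert opt_config_to_gpt2_config(plain).word_embed_proj_dim is None
assert opt_config_to_gpt2_config(projected).word_embed_proj_dim == 512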
flash_attn/models/vit.py  View file @ f1a73d07
...
@@ -10,13 +10,14 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from timm.models.helpers import named_apply
from torch.nn.init import trunc_normal_
from torchvision.ops import StochasticDepth

from flash_attn.layers.patch_embed import PatchEmbed
from flash_attn.modules.block import Block
from flash_attn.modules.mha import MHA
from flash_attn.modules.mlp import FusedMLP, Mlp

try:
    from flash_attn.ops.layer_norm import dropout_add_layer_norm
...
flash_attn/modules/block.py  View file @ f1a73d07
# Copyright (c) 2022, Tri Dao.

from functools import partial
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torchvision.ops import StochasticDepth

from flash_attn.modules.mha import MHA
...
@@ -35,11 +34,24 @@ except ImportError:
class Block(nn.Module):
    def __init__(
        self,
        dim,
        mixer_cls=None,
        mlp_cls=None,
        norm_cls=nn.LayerNorm,
        dropout_cls=nn.Dropout,
        prenorm=True,
        resid_dropout1=0.0,
        resid_dropout2=0.0,
        drop_path1=0.0,
        drop_path2=0.0,
        fused_dropout_add_ln=False,
        return_residual=False,
        residual_in_fp32=False,
        sequence_parallel=False,
        mark_shared_params=False,
    ):
        """
        For prenorm=True, this Block has a slightly different structure compared to a regular
        prenorm Transformer block.
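A minimal instantiation sketch for the signature above, assuming flash_attn's MHA and Mlp are importable and picking illustrative sizes; with prenorm=True the forward returns a (hidden_states, residual) pair.

from functools import partial

import torch
import torch.nn as nn

from flash_attn.modules.block import Block
from flash_attn.modules.mha import MHA
from flash_attn.modules.mlp import Mlp

dim = 256
block = Block(
    dim,
    mixer_cls=partial(MHA, num_heads=4, use_flash_attn=False),  # non-flash path so it runs anywhere
    mlp_cls=partial(Mlp, hidden_features=4 * dim),
    norm_cls=nn.LayerNorm,
    prenorm=True,
)
x = torch.randn(2, 16, dim)
hidden_states, residual = block(x)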
...
@@ -63,26 +75,27 @@ class Block(nn.Module):
        self.return_residual = return_residual
        self.residual_in_fp32 = residual_in_fp32
        if self.residual_in_fp32:
            assert self.prenorm, "residual_in_fp32 is only compatible with prenorm=True"
        if mixer_cls is None:
            mixer_cls = partial(MHA, num_heads=dim // 64)
        if mlp_cls is None:
            mlp_cls = partial(Mlp, hidden_features=4 * dim)
        self.mixer = mixer_cls(dim)
        self.dropout1 = dropout_cls(resid_dropout1)
        self.drop_path1 = StochasticDepth(drop_path1, mode="row")
        self.norm1 = norm_cls(dim)
        self.mlp = mlp_cls(dim)
        if not isinstance(self.mlp, nn.Identity):
            self.dropout2 = dropout_cls(resid_dropout2)
            self.drop_path2 = StochasticDepth(drop_path2, mode="row")
            self.norm2 = norm_cls(dim)
        if self.fused_dropout_add_ln:
            assert dropout_add_layer_norm is not None, "dropout_layer_norm is not installed"
            assert dropout_add_rms_norm is not None, "dropout_layer_norm is not installed"
            assert isinstance(self.norm1, (nn.LayerNorm, RMSNorm)) and isinstance(
                self.dropout1, nn.Dropout
            )

        # TD [2023-01-07]: TODO: During training, if sequence_parallel is False and dropout != 0.0,
        # then the input to each worker in the tensor parallel group will be different.
...
@@ -94,22 +107,27 @@ class Block(nn.Module):
        if sequence_parallel:
            for p in self.norm1.parameters():
                p._sequence_parallel = True
            if hasattr(self, "norm2"):
                for p in self.norm2.parameters():
                    p._sequence_parallel = True
        # Mark the norm parameters as "shared_params" so that we sync their values at init.
        if mark_shared_params:
            for p in self.norm1.parameters():
                p._shared_params = True
            if hasattr(self, "norm2"):
                for p in self.norm2.parameters():
                    p._shared_params = True

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)

    def forward(
        self,
        hidden_states: Tensor,
        residual: Optional[Tensor] = None,
        mixer_subset=None,
        mixer_kwargs=None,
    ):
        r"""Pass the input through the encoder layer.

        Args:
...
@@ -119,8 +137,11 @@ class Block(nn.Module):
                before applying the query projection. Useful for e.g., ViT where we only care
                about the CLS token in the last layer.
        """
        fused_add_norm_fn = (
            dropout_add_rms_norm
            if RMSNorm and isinstance(self.norm1, RMSNorm)
            else dropout_add_layer_norm
        )
        if self.prenorm:
            if not self.fused_dropout_add_ln:
                dropped = self.drop_path1(self.dropout1(hidden_states))
...
@@ -132,19 +153,28 @@ class Block(nn.Module):
                if self.drop_path1.p == 0 or not self.training:
                    rowscale1 = None
                else:
                    rowscale1 = self.drop_path1(
                        torch.ones(
                            hidden_states.shape[:-1],
                            device=hidden_states.device,
                            dtype=hidden_states.dtype,
                        )
                    )
                hidden_states, residual = fused_add_norm_fn(
                    hidden_states,
                    residual,
                    self.norm1.weight,
                    self.norm1.bias,
                    self.dropout1.p if self.training else 0.0,
                    self.norm1.eps,
                    rowscale=rowscale1,
                    prenorm=True,
                    residual_in_fp32=self.residual_in_fp32,
                )
            if mixer_kwargs is None:
                mixer_kwargs = {}
            if mixer_subset is not None:
                mixer_kwargs["mixer_subset"] = mixer_subset
            hidden_states = self.mixer(hidden_states, **mixer_kwargs)
            if mixer_subset is not None:
                residual = residual[:, mixer_subset]
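The rowscale1 construction in this hunk is how stochastic depth gets folded into the fused kernel: instead of scaling the output tensor afterwards, the per-row keep/scale mask is drawn once on an all-ones tensor and handed to the fused dropout-add-norm as rowscale. A small sketch of what that mask looks like:

import torch
from torchvision.ops import StochasticDepth

torch.manual_seed(0)
drop_path = StochasticDepth(p=0.5, mode="row")
drop_path.train()

x = torch.randn(4, 3, 8)  # (batch, seqlen, dim)

# Every kept row of the mask is 1 / (1 - p), every dropped row is 0.
rowscale = drop_path(torch.ones(x.shape[:-1]))
assert set(rowscale.unique().tolist()) <= {0.0, 1.0 / (1.0 - drop_path.p)}

# Multiplying by this per-row scale is exactly the scaling the fused kernel applies internally.
scaled = x * rowscale.unsqueeze(-1)
assert scaled.shape == x.shape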
...
@@ -159,14 +189,23 @@ class Block(nn.Module):
                    if self.drop_path2.p == 0 or not self.training:
                        rowscale2 = None
                    else:
                        rowscale2 = self.drop_path2(
                            torch.ones(
                                hidden_states.shape[:-1],
                                device=hidden_states.device,
                                dtype=hidden_states.dtype,
                            )
                        )
                    hidden_states, residual = fused_add_norm_fn(
                        hidden_states,
                        residual,
                        self.norm2.weight,
                        self.norm2.bias,
                        self.dropout2.p if self.training else 0.0,
                        self.norm2.eps,
                        rowscale=rowscale2,
                        prenorm=True,
                        residual_in_fp32=self.residual_in_fp32,
                    )
                hidden_states = self.mlp(hidden_states)
            return hidden_states, residual
...
@@ -178,38 +217,58 @@ class Block(nn.Module):
            if self.return_residual:  # mixer out is actually a pair here
                mixer_out, hidden_states = mixer_out
            if not self.fused_dropout_add_ln:
                hidden_states = self.norm1(
                    (self.drop_path1(self.dropout1(mixer_out)) + hidden_states).to(
                        dtype=self.norm1.weight.dtype
                    )
                )
            else:
                if self.drop_path1.p == 0 or not self.training:
                    rowscale1 = None
                else:
                    rowscale1 = self.drop_path1(
                        torch.ones(
                            mixer_out.shape[:-1], device=mixer_out.device, dtype=mixer_out.dtype
                        )
                    )
                hidden_states = fused_add_norm_fn(
                    mixer_out,
                    hidden_states,
                    self.norm1.weight,
                    self.norm1.bias,
                    self.dropout1.p if self.training else 0.0,
                    self.norm1.eps,
                    rowscale=rowscale1,
                    prenorm=False,
                )
            if not isinstance(self.mlp, nn.Identity):
                mlp_out = self.mlp(hidden_states)
                if self.return_residual:  # mlp out is actually a pair here
                    mlp_out, hidden_states = mlp_out
                if not self.fused_dropout_add_ln:
                    hidden_states = self.norm2(
                        (self.drop_path2(self.dropout2(mlp_out)) + hidden_states).to(
                            dtype=self.norm2.weight.dtype
                        )
                    )
                else:
                    if self.drop_path2.p == 0 or not self.training:
                        rowscale2 = None
                    else:
                        rowscale2 = self.drop_path2(
                            torch.ones(
                                mlp_out.shape[:-1], device=mlp_out.device, dtype=mlp_out.dtype
                            )
                        )
                    hidden_states = fused_add_norm_fn(
                        mlp_out,
                        hidden_states,
                        self.norm2.weight,
                        self.norm2.bias,
                        self.dropout2.p if self.training else 0.0,
                        self.norm2.eps,
                        rowscale=rowscale2,
                        prenorm=False,
                    )
            return hidden_states
...
@@ -219,10 +278,21 @@ class ParallelBlock(nn.Module):
    and PaLM.
    """

    def __init__(
        self,
        dim,
        mixer_cls=None,
        mlp_cls=None,
        norm_cls=nn.LayerNorm,
        dropout_cls=nn.Dropout,
        resid_dropout1=0.0,
        resid_dropout2=0.0,
        tied_norm=False,
        fused_dropout_add_ln=False,
        residual_in_fp32=False,
        sequence_parallel=False,
        mark_shared_params=False,
    ):
        """
        This Block has a slightly different structure compared to a regular
        prenorm Transformer block.
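For orientation, ParallelBlock implements the GPT-J / PaLM style layout where attention and MLP both read the same residual stream and their outputs are summed back in. The stand-in module below sketches that dataflow with plain PyTorch parts; the real class additionally staggers the dropout/add across consecutive blocks and can use the fused parallel-residual kernels.

import torch
import torch.nn as nn

class ToyParallelBlock(nn.Module):
    def __init__(self, dim, n_heads=4):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, n_heads, batch_first=True)
        self.mlp = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))

    def forward(self, residual):
        a = self.norm1(residual)      # branch 1: attention reads the residual stream
        m = self.norm2(residual)      # branch 2: MLP reads the *same* residual stream
        attn_out, _ = self.attn(a, a, a, need_weights=False)
        return residual + attn_out + self.mlp(m)

print(ToyParallelBlock(64)(torch.randn(2, 16, 64)).shape)  # torch.Size([2, 16, 64])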
...
@@ -250,10 +320,15 @@ class ParallelBlock(nn.Module):
        self.norm2 = norm_cls(dim)
        if self.fused_dropout_add_ln:
            assert (
                dropout_add_layer_norm_parallel_residual is not None
            ), "dropout_layer_norm is not installed"
            assert (
                dropout_add_rms_norm_parallel_residual is not None
            ), "dropout_layer_norm is not installed"
            assert isinstance(self.norm1, (nn.LayerNorm, RMSNorm)) and isinstance(
                self.dropout1, nn.Dropout
            )

        # TD [2023-01-07]: TODO: During training, if sequence_parallel is False and dropout != 0.0,
        # then the input to each worker in the tensor parallel group will be different.
...
@@ -265,22 +340,27 @@ class ParallelBlock(nn.Module):
        if sequence_parallel:
            for p in self.norm1.parameters():
                p._sequence_parallel = True
            if hasattr(self, "norm2"):
                for p in self.norm2.parameters():
                    p._sequence_parallel = True
        # Mark the norm parameters as "shared_params" so that we sync their values at init.
        if mark_shared_params:
            for p in self.norm1.parameters():
                p._shared_params = True
            if hasattr(self, "norm2"):
                for p in self.norm2.parameters():
                    p._shared_params = True

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)

    def forward(
        self,
        hidden_states1: Tensor,
        hidden_states2: Optional[Tensor] = None,
        residual: Optional[Tensor] = None,
        mixer_kwargs=None,
    ):
        r"""Pass the input through the encoder layer.

        Args:
...
@@ -290,30 +370,47 @@ class ParallelBlock(nn.Module):
        """
        # TODO: Ideally we should only do the allgather / allreduce once for
        # the Linear to MLP & Attention
        fused_add_norm_fn = (
            dropout_add_rms_norm_parallel_residual
            if isinstance(self.norm1, RMSNorm)
            else dropout_add_layer_norm_parallel_residual
        )
        if not self.fused_dropout_add_ln:
            dropped1 = self.dropout1(hidden_states1)
            # For the very 1st block, we only want 1 dropout, not two different dropouts
            if hidden_states2 is not None:
                dropped2 = self.dropout2(hidden_states2)
                residual = (
                    (residual + dropped1 + dropped2)
                    if residual is not None
                    else dropped1 + dropped2
                )
            else:
                residual = (residual + dropped1) if residual is not None else dropped1
            hidden_states1 = self.norm1(residual.to(dtype=self.norm1.weight.dtype))
            hidden_states2 = (
                self.norm2(residual.to(dtype=self.norm2.weight.dtype))
                if not self.tied_norm
                else hidden_states1
            )
            if self.residual_in_fp32:
                residual = residual.to(torch.float32)
        else:
            weight2, bias2 = (
                (self.norm2.weight, self.norm2.bias) if not self.tied_norm else (None, None)
            )
            hidden_states1, hidden_states2, residual = fused_add_norm_fn(
                hidden_states1,
                hidden_states2,
                residual,
                self.norm1.weight,
                self.norm1.bias,
                weight2,
                bias2,
                self.dropout1.p if self.training else 0.0,
                self.norm1.eps,
                prenorm=True,
                residual_in_fp32=self.residual_in_fp32,
            )
            if self.tied_norm:
                hidden_states2 = hidden_states1
...