gaoqiong / flash-attention · Commits

Commit f1a73d07
Authored Aug 18, 2023 by Tri Dao
Run isort and black on python files
Parent: cbb4cf5f
Changes: 34

Showing 20 changed files with 2219 additions and 1011 deletions (+2219 -1011)
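The change itself is mechanical: isort regroups and alphabetizes imports, and black normalizes quoting, line wrapping, and trailing commas. As a rough, hypothetical illustration (not code from this commit), this is the style the diffs below converge on:

    import torch
    from einops import rearrange  # isort groups and alphabetizes third-party imports


    def flatten_batch(x, scale=1.0):
        # black prefers double quotes and, once a call has to be split across lines,
        # a trailing comma with one argument per line
        return rearrange(
            x * scale,
            "b s d -> (b s) d",
        )


    print(flatten_batch(torch.ones(2, 3, 4)).shape)  # torch.Size([6, 4])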
flash_attn/__init__.py                          +8    -6
flash_attn/bert_padding.py                      +23   -18
flash_attn/flash_attn_interface.py              +336  -88
flash_attn/flash_attn_triton.py                 +533  -205
flash_attn/flash_attn_triton_og.py              +134  -45
flash_attn/flash_blocksparse_attention.py       +105  -44
flash_attn/flash_blocksparse_attn_interface.py  +84   -26
flash_attn/fused_softmax.py                     +10   -14
flash_attn/layers/patch_embed.py                +30   -19
flash_attn/layers/rotary.py                     +171  -79
flash_attn/losses/cross_entropy.py              +46   -27
flash_attn/models/bert.py                       +244  -155
flash_attn/models/falcon.py                     +58   -37
flash_attn/models/gpt.py                        +20   -10
flash_attn/models/gpt_neox.py                   +52   -32
flash_attn/models/gptj.py                       +36   -25
flash_attn/models/llama.py                      +106  -70
flash_attn/models/opt.py                        +51   -37
flash_attn/models/vit.py                        +4    -3
flash_attn/modules/block.py                     +168  -71
flash_attn/__init__.py

 __version__ = "2.0.8"

-from flash_attn.flash_attn_interface import flash_attn_func
-from flash_attn.flash_attn_interface import flash_attn_kvpacked_func
-from flash_attn.flash_attn_interface import flash_attn_qkvpacked_func
-from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
-from flash_attn.flash_attn_interface import flash_attn_varlen_kvpacked_func
-from flash_attn.flash_attn_interface import flash_attn_varlen_func
+from flash_attn.flash_attn_interface import (
+    flash_attn_func,
+    flash_attn_kvpacked_func,
+    flash_attn_qkvpacked_func,
+    flash_attn_varlen_func,
+    flash_attn_varlen_kvpacked_func,
+    flash_attn_varlen_qkvpacked_func,
+)
flash_attn/bert_padding.py

@@ -2,12 +2,10 @@
 import torch
 import torch.nn.functional as F
 from einops import rearrange, repeat


 class IndexFirstAxis(torch.autograd.Function):
     @staticmethod
     def forward(ctx, input, indices):
         ctx.save_for_backward(indices)

@@ -16,20 +14,24 @@ class IndexFirstAxis(torch.autograd.Function):
         second_dim = other_shape.numel()
         # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
         # return input[indices]
-        return torch.gather(rearrange(input, 'b ... -> b (...)'), 0,
-                            repeat(indices, 'z -> z d', d=second_dim)).reshape(-1, *other_shape)
+        return torch.gather(
+            rearrange(input, "b ... -> b (...)"), 0, repeat(indices, "z -> z d", d=second_dim)
+        ).reshape(-1, *other_shape)

     @staticmethod
     def backward(ctx, grad_output):
-        indices, = ctx.saved_tensors
+        (indices,) = ctx.saved_tensors
         assert grad_output.ndim >= 2
         other_shape = grad_output.shape[1:]
-        grad_output = rearrange(grad_output, 'b ... -> b (...)')
-        grad_input = torch.zeros([ctx.first_axis_dim, grad_output.shape[1]],
-                                 device=grad_output.device, dtype=grad_output.dtype)
+        grad_output = rearrange(grad_output, "b ... -> b (...)")
+        grad_input = torch.zeros(
+            [ctx.first_axis_dim, grad_output.shape[1]],
+            device=grad_output.device,
+            dtype=grad_output.dtype,
+        )
         # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
         # grad_input[indices] = grad_output
-        grad_input.scatter_(0, repeat(indices, 'z -> z d', d=grad_output.shape[1]), grad_output)
+        grad_input.scatter_(0, repeat(indices, "z -> z d", d=grad_output.shape[1]), grad_output)
         return grad_input.reshape(ctx.first_axis_dim, *other_shape), None

@@ -37,14 +39,14 @@ index_first_axis = IndexFirstAxis.apply
 class IndexPutFirstAxis(torch.autograd.Function):
     @staticmethod
     def forward(ctx, values, indices, first_axis_dim):
         ctx.save_for_backward(indices)
         assert indices.ndim == 1
         assert values.ndim >= 2
-        output = torch.zeros(first_axis_dim, *values.shape[1:], device=values.device,
-                             dtype=values.dtype)
+        output = torch.zeros(
+            first_axis_dim, *values.shape[1:], device=values.device, dtype=values.dtype
+        )
         # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
         output[indices] = values
         # output.scatter_(0, repeat(indices, 'z -> z d', d=values.shape[1]), values)

@@ -52,7 +54,7 @@ class IndexPutFirstAxis(torch.autograd.Function):
     @staticmethod
     def backward(ctx, grad_output):
-        indices, = ctx.saved_tensors
+        (indices,) = ctx.saved_tensors
         # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
         grad_values = grad_output[indices]
         # grad_values = torch.gather(grad_output, 0, repeat(indices, 'z -> z d', d=grad_output.shape[1]))

@@ -63,7 +65,6 @@ index_put_first_axis = IndexPutFirstAxis.apply
 class IndexFirstAxisResidual(torch.autograd.Function):
     @staticmethod
     def forward(ctx, input, indices):
         ctx.save_for_backward(indices)

@@ -79,7 +80,7 @@ class IndexFirstAxisResidual(torch.autograd.Function):
     @staticmethod
     def backward(ctx, grad_output, grad_residual):
-        indices, = ctx.saved_tensors
+        (indices,) = ctx.saved_tensors
         assert grad_output.ndim >= 2
         other_shape = grad_output.shape[1:]
         assert grad_residual.shape[1:] == other_shape

@@ -113,8 +114,12 @@ def unpad_input(hidden_states, attention_mask):
     # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
     # index with integer indices. Moreover, torch's index is a bit slower than it needs to be,
     # so we write custom forward and backward to make it a bit faster.
-    return (index_first_axis(rearrange(hidden_states, 'b s ... -> (b s) ...'), indices), indices,
-            cu_seqlens, max_seqlen_in_batch)
+    return (
+        index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices),
+        indices,
+        cu_seqlens,
+        max_seqlen_in_batch,
+    )


 def pad_input(hidden_states, indices, batch, seqlen):

@@ -129,4 +134,4 @@ def pad_input(hidden_states, indices, batch, seqlen):
     # output = torch.zeros((batch * seqlen), dim, device=hidden_states.device, dtype=hidden_states.dtype)
     # output[indices] = hidden_states
     output = index_put_first_axis(hidden_states, indices, batch * seqlen)
-    return rearrange(output, '(b s) ... -> b s ...', b=batch)
+    return rearrange(output, "(b s) ... -> b s ...", b=batch)
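The `# TD [2022-03-04]` comments in IndexFirstAxis above note that torch.gather/torch.scatter_ were measured to be slightly faster than plain integer indexing. A small self-contained sketch of the equivalence the forward pass relies on (shapes picked arbitrarily for illustration):

    import torch
    from einops import rearrange, repeat

    x = torch.randn(8, 3, 4)       # (first_axis_dim, *other_shape)
    idx = torch.tensor([1, 5, 6])  # indices into the first axis

    second_dim = x.shape[1:].numel()
    gathered = torch.gather(
        rearrange(x, "b ... -> b (...)"), 0, repeat(idx, "z -> z d", d=second_dim)
    ).reshape(-1, *x.shape[1:])

    assert torch.equal(gathered, x[idx])  # same result as plain indexing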
flash_attn/flash_attn_interface.py (diff collapsed, not shown)

flash_attn/flash_attn_triton.py (diff collapsed, not shown)
flash_attn/flash_attn_triton_og.py

@@ -11,22 +11,41 @@ This is a Triton implementation of the Flash Attention algorithm
 import pytest
 import torch
 import triton
 import triton.language as tl


 @triton.jit
 def _fwd_kernel(
-    Q, K, V, sm_scale,
-    TMP, L, M,  # NOTE: TMP is a scratchpad buffer to workaround a compiler bug
+    Q,
+    K,
+    V,
+    sm_scale,
+    TMP,
+    L,
+    M,  # NOTE: TMP is a scratchpad buffer to workaround a compiler bug
     Out,
-    stride_qz, stride_qh, stride_qm, stride_qk,
-    stride_kz, stride_kh, stride_kn, stride_kk,
-    stride_vz, stride_vh, stride_vk, stride_vn,
-    stride_oz, stride_oh, stride_om, stride_on,
-    Z, H, N_CTX,
-    BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
+    stride_qz,
+    stride_qh,
+    stride_qm,
+    stride_qk,
+    stride_kz,
+    stride_kh,
+    stride_kn,
+    stride_kk,
+    stride_vz,
+    stride_vh,
+    stride_vk,
+    stride_vn,
+    stride_oz,
+    stride_oh,
+    stride_om,
+    stride_on,
+    Z,
+    H,
+    N_CTX,
+    BLOCK_M: tl.constexpr,
+    BLOCK_DMODEL: tl.constexpr,
     BLOCK_N: tl.constexpr,
 ):
     start_m = tl.program_id(0)

@@ -100,9 +119,13 @@ def _fwd_kernel(
 @triton.jit
 def _bwd_preprocess(
-    Out, DO, L,
-    NewDO, Delta,
-    BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr,
+    Out,
+    DO,
+    L,
+    NewDO,
+    Delta,
+    BLOCK_M: tl.constexpr,
+    D_HEAD: tl.constexpr,
 ):
     off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
     off_n = tl.arange(0, D_HEAD)

@@ -120,16 +143,36 @@ def _bwd_preprocess(
 @triton.jit
 def _bwd_kernel(
-    Q, K, V, sm_scale, Out, DO,
-    DQ, DK, DV,
-    L, M,
+    Q,
+    K,
+    V,
+    sm_scale,
+    Out,
+    DO,
+    DQ,
+    DK,
+    DV,
+    L,
+    M,
     D,
-    stride_qz, stride_qh, stride_qm, stride_qk,
-    stride_kz, stride_kh, stride_kn, stride_kk,
-    stride_vz, stride_vh, stride_vk, stride_vn,
-    Z, H, N_CTX,
+    stride_qz,
+    stride_qh,
+    stride_qm,
+    stride_qk,
+    stride_kz,
+    stride_kh,
+    stride_kn,
+    stride_kk,
+    stride_vz,
+    stride_vh,
+    stride_vk,
+    stride_vn,
+    Z,
+    H,
+    N_CTX,
     num_block,
-    BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
+    BLOCK_M: tl.constexpr,
+    BLOCK_DMODEL: tl.constexpr,
     BLOCK_N: tl.constexpr,
 ):
     off_hz = tl.program_id(0)

@@ -203,7 +246,6 @@ def _bwd_kernel(
 class _attention(torch.autograd.Function):
     @staticmethod
     def forward(ctx, q, k, v, sm_scale):
         BLOCK = 128

@@ -213,22 +255,45 @@ class _attention(torch.autograd.Function):
         assert Lk in {16, 32, 64, 128}
         o = torch.empty_like(q)
         grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1])
-        tmp = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
+        tmp = torch.empty(
+            (q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32
+        )
         L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
         m = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)
         num_warps = 4 if Lk <= 64 else 8
         _fwd_kernel[grid](
-            q, k, v, sm_scale,
-            tmp, L, m,
+            q,
+            k,
+            v,
+            sm_scale,
+            tmp,
+            L,
+            m,
             o,
-            q.stride(0), q.stride(1), q.stride(2), q.stride(3),
-            k.stride(0), k.stride(1), k.stride(2), k.stride(3),
-            v.stride(0), v.stride(1), v.stride(2), v.stride(3),
-            o.stride(0), o.stride(1), o.stride(2), o.stride(3),
-            q.shape[0], q.shape[1], q.shape[2],
-            BLOCK_M=BLOCK, BLOCK_N=BLOCK,
-            BLOCK_DMODEL=Lk, num_warps=num_warps,
+            q.stride(0),
+            q.stride(1),
+            q.stride(2),
+            q.stride(3),
+            k.stride(0),
+            k.stride(1),
+            k.stride(2),
+            k.stride(3),
+            v.stride(0),
+            v.stride(1),
+            v.stride(2),
+            v.stride(3),
+            o.stride(0),
+            o.stride(1),
+            o.stride(2),
+            o.stride(3),
+            q.shape[0],
+            q.shape[1],
+            q.shape[2],
+            BLOCK_M=BLOCK,
+            BLOCK_N=BLOCK,
+            BLOCK_DMODEL=Lk,
+            num_warps=num_warps,
             num_stages=1,
         )
         ctx.save_for_backward(q, k, v, o, L, m)

@@ -247,27 +312,51 @@ class _attention(torch.autograd.Function):
         dv = torch.empty_like(v)
         do_scaled = torch.empty_like(do)
         delta = torch.empty_like(l)
-        _bwd_preprocess[(ctx.grid[0] * ctx.grid[1], )](
-            o, do, l,
-            do_scaled, delta,
-            BLOCK_M=ctx.BLOCK, D_HEAD=ctx.BLOCK_DMODEL,
+        _bwd_preprocess[(ctx.grid[0] * ctx.grid[1],)](
+            o,
+            do,
+            l,
+            do_scaled,
+            delta,
+            BLOCK_M=ctx.BLOCK,
+            D_HEAD=ctx.BLOCK_DMODEL,
         )
         # NOTE: kernel currently buggy for other values of `num_warps`
         num_warps = 8
         _bwd_kernel[(ctx.grid[1],)](
-            q, k, v, ctx.sm_scale,
-            o, do_scaled,
-            dq, dk, dv,
-            l, m,
+            q,
+            k,
+            v,
+            ctx.sm_scale,
+            o,
+            do_scaled,
+            dq,
+            dk,
+            dv,
+            l,
+            m,
             delta,
-            q.stride(0), q.stride(1), q.stride(2), q.stride(3),
-            k.stride(0), k.stride(1), k.stride(2), k.stride(3),
-            v.stride(0), v.stride(1), v.stride(2), v.stride(3),
-            q.shape[0], q.shape[1], q.shape[2],
+            q.stride(0),
+            q.stride(1),
+            q.stride(2),
+            q.stride(3),
+            k.stride(0),
+            k.stride(1),
+            k.stride(2),
+            k.stride(3),
+            v.stride(0),
+            v.stride(1),
+            v.stride(2),
+            v.stride(3),
+            q.shape[0],
+            q.shape[1],
+            q.shape[2],
             ctx.grid[0],
-            BLOCK_M=ctx.BLOCK, BLOCK_N=ctx.BLOCK,
-            BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=num_warps,
+            BLOCK_M=ctx.BLOCK,
+            BLOCK_N=ctx.BLOCK,
+            BLOCK_DMODEL=ctx.BLOCK_DMODEL,
+            num_warps=num_warps,
             num_stages=1,
         )
         return dq.to(q.dtype), dk, dv, None
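In _attention.forward above, the launch grid is derived from the query shape and the warp count from the head dimension. A tiny arithmetic sketch with invented sizes (the ceiling division is what triton.cdiv computes):

    batch, heads, seqlen, headdim = 2, 16, 1000, 64
    BLOCK = 128

    grid = ((seqlen + BLOCK - 1) // BLOCK, batch * heads)  # one program per (query block, batch*head)
    num_warps = 4 if headdim <= 64 else 8

    print(grid, num_warps)  # (8, 32) 4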
flash_attn/flash_blocksparse_attention.py

 import math

+import hydra
 import torch
 import torch.nn as nn
 from einops import rearrange
-import hydra

-from flash_attn.flash_blocksparse_attn_interface import flash_blocksparse_attn_func
-from flash_attn.flash_blocksparse_attn_interface import convert_blockmask
-from flash_attn.bert_padding import unpad_input, pad_input, index_first_axis
+from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input
+from flash_attn.flash_blocksparse_attn_interface import (
+    convert_blockmask,
+    flash_blocksparse_attn_func,
+)


 class FlashBlocksparseAttention(nn.Module):

@@ -21,8 +22,16 @@ class FlashBlocksparseAttention(nn.Module):
         attention_dropout: The dropout rate to apply to the attention
                            (default: 0.1)
     """
-    def __init__(self, sparsity_config, softmax_temp=None, attention_dropout=0.0,
-                 max_seq_length=2048, device=None, dtype=None):
+
+    def __init__(
+        self,
+        sparsity_config,
+        softmax_temp=None,
+        attention_dropout=0.0,
+        max_seq_length=2048,
+        device=None,
+        dtype=None,
+    ):
         super().__init__()
         self.sparsity_config = hydra.utils.instantiate(sparsity_config)
         self.softmax_temp = softmax_temp

@@ -36,8 +45,17 @@ class FlashBlocksparseAttention(nn.Module):
         self.register_buffer("blockmask_converted", blockmask_converted)
         # logger.info(f'Attention class {self.__class__}: saving={self.layout.float().mean()}')

-    def forward(self, qkv, attn_mask=None, key_padding_mask=None, causal=False, cu_seqlens=None,
-                max_s=None, need_weights=False, convert_mask=True):
+    def forward(
+        self,
+        qkv,
+        attn_mask=None,
+        key_padding_mask=None,
+        causal=False,
+        cu_seqlens=None,
+        max_s=None,
+        need_weights=False,
+        convert_mask=True,
+    ):
         """Implements the multihead softmax attention.
         Arguments
         ---------

@@ -57,47 +75,76 @@ class FlashBlocksparseAttention(nn.Module):
             seqlen = qkv.shape[1]
             # Convert mask to take a subset
             seqlen_rounded = ((seqlen + 256 - 1) // 256) * 256
-            assert seqlen_rounded // 16 <= self.layout.shape[0], seqlen_rounded // 256 <= self.layout.shape[1]
-            blockmask = self.layout[:seqlen_rounded // 16, :seqlen_rounded // 256]
+            assert seqlen_rounded // 16 <= self.layout.shape[0], (
+                seqlen_rounded // 256 <= self.layout.shape[1]
+            )
+            blockmask = self.layout[: seqlen_rounded // 16, : seqlen_rounded // 256]
             if key_padding_mask is None:
-                qkv = rearrange(qkv, 'b s ... -> (b s) ...')
+                qkv = rearrange(qkv, "b s ... -> (b s) ...")
                 max_s = seqlen
-                cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen,
-                                          dtype=torch.int32, device=qkv.device)
+                cu_seqlens = torch.arange(
+                    0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32, device=qkv.device
+                )
                 output = flash_blocksparse_attn_func(
-                    qkv, cu_seqlens, blockmask, self.dropout_p if self.training else 0.0, max_s,
-                    softmax_scale=self.softmax_temp, causal=causal
+                    qkv,
+                    cu_seqlens,
+                    blockmask,
+                    self.dropout_p if self.training else 0.0,
+                    max_s,
+                    softmax_scale=self.softmax_temp,
+                    causal=causal,
                 )
-                output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
+                output = rearrange(output, "(b s) ... -> b s ...", b=batch_size)
             else:
                 key_padding_mask_bool = key_padding_mask.bool_matrix
                 nheads = qkv.shape[-2]
-                x = rearrange(qkv, 'b s three h d -> b s (three h d)')
+                x = rearrange(qkv, "b s three h d -> b s (three h d)")
                 x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask_bool)
-                x_unpad = rearrange(x_unpad, 'nnz (three h d) -> nnz three h d', three=3, h=nheads)
+                x_unpad = rearrange(x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads)
                 output_unpad = flash_blocksparse_attn_func(
-                    x_unpad, cu_seqlens, blockmask, self.dropout_p if self.training else 0.0,
-                    max_s, softmax_scale=self.softmax_temp, causal=causal
+                    x_unpad,
+                    cu_seqlens,
+                    blockmask,
+                    self.dropout_p if self.training else 0.0,
+                    max_s,
+                    softmax_scale=self.softmax_temp,
+                    causal=causal,
                 )
-                output = rearrange(pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'),
-                                             indices, batch_size, seqlen),
-                                   'b s (h d) -> b s h d', h=nheads)
+                output = rearrange(
+                    pad_input(
+                        rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, batch_size, seqlen
+                    ),
+                    "b s (h d) -> b s h d",
+                    h=nheads,
+                )
         else:
             assert max_s is not None
             seqlen = max_s
             # Convert mask to take a subset
             seqlen_rounded = ((seqlen + 256 - 1) // 256) * 256
-            assert seqlen_rounded // 16 <= self.layout.shape[0], seqlen_rounded // 256 <= self.layout.shape[1]
-            blockmask = self.layout[:seqlen_rounded // 16, :seqlen_rounded // 256]
+            assert seqlen_rounded // 16 <= self.layout.shape[0], (
+                seqlen_rounded // 256 <= self.layout.shape[1]
+            )
+            blockmask = self.layout[: seqlen_rounded // 16, : seqlen_rounded // 256]
             if convert_mask:
                 output = flash_blocksparse_attn_func(
-                    qkv, cu_seqlens, blockmask, self.dropout_p if self.training else 0.0, max_s,
-                    softmax_scale=self.softmax_temp, causal=causal
+                    qkv,
+                    cu_seqlens,
+                    blockmask,
+                    self.dropout_p if self.training else 0.0,
+                    max_s,
+                    softmax_scale=self.softmax_temp,
+                    causal=causal,
                 )
             else:
                 output = flash_blocksparse_attn_func(
-                    qkv, cu_seqlens, self.blockmask_converted,
-                    self.dropout_p if self.training else 0.0, max_s,
-                    softmax_scale=self.softmax_temp, causal=causal,
+                    qkv,
+                    cu_seqlens,
+                    self.blockmask_converted,
+                    self.dropout_p if self.training else 0.0,
+                    max_s,
+                    softmax_scale=self.softmax_temp,
+                    causal=causal,
                     convert_mask=False,
                 )

@@ -105,12 +152,22 @@ class FlashBlocksparseAttention(nn.Module):
 class FlashBlocksparseMHA(nn.Module):
-    def __init__(self, embed_dim, num_heads, sparsity_config, bias=True, batch_first=True,
-                 attention_dropout=0.0, causal=False, max_seq_length=2048,
-                 device=None, dtype=None, **kwargs) -> None:
+    def __init__(
+        self,
+        embed_dim,
+        num_heads,
+        sparsity_config,
+        bias=True,
+        batch_first=True,
+        attention_dropout=0.0,
+        causal=False,
+        max_seq_length=2048,
+        device=None,
+        dtype=None,
+        **kwargs,
+    ) -> None:
         assert batch_first
-        factory_kwargs = {'device': device, 'dtype': dtype}
+        factory_kwargs = {"device": device, "dtype": dtype}
         super().__init__()
         self.embed_dim = embed_dim
         self.causal = causal

@@ -122,15 +179,19 @@ class FlashBlocksparseMHA(nn.Module):
         self.Wqkv = nn.Linear(embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs)
         self.inner_attn = FlashBlocksparseAttention(
-            sparsity_config, attention_dropout=attention_dropout,
-            max_seq_length=max_seq_length, **factory_kwargs
+            sparsity_config,
+            attention_dropout=attention_dropout,
+            max_seq_length=max_seq_length,
+            **factory_kwargs,
         )
         self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs)

-    def forward(self, x, x_ignored_, x_ignored_1_, attn_mask=None, key_padding_mask=None,
-                need_weights=False):
+    def forward(
+        self, x, x_ignored_, x_ignored_1_, attn_mask=None, key_padding_mask=None, need_weights=False
+    ):
         qkv = self.Wqkv(x)
-        qkv = rearrange(qkv, 'b s (three h d) -> b s three h d', three=3, h=self.num_heads)
-        context, attn_weights = self.inner_attn(qkv, key_padding_mask=key_padding_mask,
-                                                need_weights=need_weights, causal=self.causal)
-        return self.out_proj(rearrange(context, 'b s h d -> b s (h d)')), attn_weights
+        qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, h=self.num_heads)
+        context, attn_weights = self.inner_attn(
+            qkv, key_padding_mask=key_padding_mask, need_weights=need_weights, causal=self.causal
+        )
+        return self.out_proj(rearrange(context, "b s h d -> b s (h d)")), attn_weights
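Two small computations from FlashBlocksparseAttention.forward above, worked with made-up numbers: the cu_seqlens tensor built for an unpadded, equal-length batch, and the rounding used to slice the block-sparse layout:

    import torch

    batch_size, seqlen = 3, 1000

    cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32)
    print(cu_seqlens)  # tensor([   0, 1000, 2000, 3000], dtype=torch.int32)

    seqlen_rounded = ((seqlen + 256 - 1) // 256) * 256
    # the layout is then sliced as layout[: seqlen_rounded // 16, : seqlen_rounded // 256]
    print(seqlen_rounded, seqlen_rounded // 16, seqlen_rounded // 256)  # 1024 64 4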
flash_attn/flash_blocksparse_attn_interface.py

 # Adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/fmha.py
-import flash_attn_cuda
 import torch
 import torch.nn as nn

+import flash_attn_cuda
+

 def convert_blockmask(blockmask, causal):
     """Convert from the 0-1 format to the format used by the CUDA code.

@@ -40,29 +39,51 @@ def convert_blockmask(blockmask, causal):
     return nonzero_idx.T.contiguous().to(dtype=torch.int32)


-def _flash_blocksparse_attn_forward(qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale,
-                                    causal, return_softmax):
-    context, softmax_lse, *rest = flash_attn_cuda.fwd_block(qkv, cu_seqlens, blockmask, dropout_p,
-                                                            max_s, softmax_scale, causal,
-                                                            return_softmax, None)
+def _flash_blocksparse_attn_forward(
+    qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale, causal, return_softmax
+):
+    context, softmax_lse, *rest = flash_attn_cuda.fwd_block(
+        qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale, causal, return_softmax, None
+    )
     # if context.isnan().any() or softmax_lse.isnan().any():
     #     breakpoint()
     S_dmask = rest[0] if return_softmax else None
     return context, softmax_lse, S_dmask


-def _flash_blocksparse_attn_backward(dout, qkv, out, S_dmask, softmax_lse, cu_seqlens, blockmask,
-                                     dropout_p, max_s, softmax_scale, causal):
-    dqkv, dp, softmax_d = flash_attn_cuda.bwd_block(dout, qkv, out, S_dmask, softmax_lse,
-                                                    cu_seqlens, blockmask, dropout_p,
-                                                    softmax_scale, max_s, causal, None)
+def _flash_blocksparse_attn_backward(
+    dout,
+    qkv,
+    out,
+    S_dmask,
+    softmax_lse,
+    cu_seqlens,
+    blockmask,
+    dropout_p,
+    max_s,
+    softmax_scale,
+    causal,
+):
+    dqkv, dp, softmax_d = flash_attn_cuda.bwd_block(
+        dout,
+        qkv,
+        out,
+        S_dmask,
+        softmax_lse,
+        cu_seqlens,
+        blockmask,
+        dropout_p,
+        softmax_scale,
+        max_s,
+        causal,
+        None,
+    )
     # if dqkv.isnan().any() or softmax_d.isnan().any():
     #     breakpoint()
     return dqkv


 class FlashBlocksparseAttnFun(torch.autograd.Function):
     @staticmethod
     def forward(ctx, qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale, causal):
         # Save rng_state because the backward pass will regenerate the dropout mask

@@ -70,8 +91,14 @@ class FlashBlocksparseAttnFun(torch.autograd.Function):
         if softmax_scale is None:
             softmax_scale = qkv.shape[-1] ** (-0.5)
         context, softmax_lse, S_dmask = _flash_blocksparse_attn_forward(
-            qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale, causal=causal,
-            return_softmax=False
+            qkv,
+            cu_seqlens,
+            blockmask,
+            dropout_p,
+            max_s,
+            softmax_scale,
+            causal=causal,
+            return_softmax=False,
         )
         ctx.save_for_backward(qkv, context, S_dmask, softmax_lse, cu_seqlens, blockmask, rng_state)
         ctx.dropout_p = dropout_p

@@ -88,8 +115,17 @@ class FlashBlocksparseAttnFun(torch.autograd.Function):
             torch.cuda.set_rng_state(rng_state)
         # S_dmask is None, temporarily use another tensor just to get it running
         dqkv = _flash_blocksparse_attn_backward(
-            dout, qkv, context, context, softmax_lse, cu_seqlens, blockmask,
-            ctx.dropout_p, ctx.max_s, ctx.softmax_scale, ctx.causal
+            dout,
+            qkv,
+            context,
+            context,
+            softmax_lse,
+            cu_seqlens,
+            blockmask,
+            ctx.dropout_p,
+            ctx.max_s,
+            ctx.softmax_scale,
+            ctx.causal,
         )
         if rng_state is not None:
             torch.cuda.set_rng_state(cur_rng_state)

@@ -99,7 +135,6 @@ class FlashBlocksparseAttnFun(torch.autograd.Function):
 # We duplicate code to return both the output and the softmax for testing
 # Returning both makes backward a bit slower, so we want to keep using the other version for speed.
 class FlashBlocksparseAttnFunWithS(torch.autograd.Function):
     @staticmethod
     def forward(ctx, qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale, causal):
         # Save rng_state because the backward pass is gonna regenerate the dropout mask

@@ -107,8 +142,14 @@ class FlashBlocksparseAttnFunWithS(torch.autograd.Function):
         if softmax_scale is None:
             softmax_scale = qkv.shape[-1] ** (-0.5)
         context, softmax_lse, S_dmask = _flash_blocksparse_attn_forward(
-            qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale, causal=causal,
-            return_softmax=True
+            qkv,
+            cu_seqlens,
+            blockmask,
+            dropout_p,
+            max_s,
+            softmax_scale,
+            causal=causal,
+            return_softmax=True,
         )
         ctx.save_for_backward(qkv, context, S_dmask, softmax_lse, cu_seqlens, blockmask, rng_state)
         ctx.dropout_p = dropout_p

@@ -124,18 +165,35 @@ class FlashBlocksparseAttnFunWithS(torch.autograd.Function):
         cur_rng_state = torch.cuda.get_rng_state()
         torch.cuda.set_rng_state(rng_state)
         dqkv = _flash_blocksparse_attn_backward(
-            dout, qkv, context, S_dmask, softmax_lse, cu_seqlens, blockmask,
-            ctx.dropout_p, ctx.max_s, ctx.softmax_scale, ctx.causal
+            dout,
+            qkv,
+            context,
+            S_dmask,
+            softmax_lse,
+            cu_seqlens,
+            blockmask,
+            ctx.dropout_p,
+            ctx.max_s,
+            ctx.softmax_scale,
+            ctx.causal,
         )
         if rng_state is not None:
             torch.cuda.set_rng_state(cur_rng_state)
         return dqkv, None, None, None, None, None, None


-def flash_blocksparse_attn_func(qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale=None,
-                                causal=False, return_attn_probs=False, convert_mask=True):
-    """dropout_p should be set to 0.0 during evaluation
-    """
+def flash_blocksparse_attn_func(
+    qkv,
+    cu_seqlens,
+    blockmask,
+    dropout_p,
+    max_s,
+    softmax_scale=None,
+    causal=False,
+    return_attn_probs=False,
+    convert_mask=True,
+):
+    """dropout_p should be set to 0.0 during evaluation"""
     func = FlashBlocksparseAttnFun if not return_attn_probs else FlashBlocksparseAttnFunWithS
     if convert_mask:
         blockmask = convert_blockmask(blockmask, causal=causal)
flash_attn/fused_softmax.py

@@ -17,13 +17,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import torch
 from apex._autocast_utils import _cast_if_autocast_enabled
 from apex.transformer.enums import AttnMaskType
-from fused_softmax_lib import scaled_masked_softmax_forward, scaled_masked_softmax_backward
-from fused_softmax_lib import scaled_masked_softmax_get_batch_per_block
-from fused_softmax_lib import scaled_upper_triang_masked_softmax_forward, scaled_upper_triang_masked_softmax_backward
+from fused_softmax_lib import (
+    scaled_masked_softmax_backward,
+    scaled_masked_softmax_forward,
+    scaled_masked_softmax_get_batch_per_block,
+    scaled_upper_triang_masked_softmax_backward,
+    scaled_upper_triang_masked_softmax_forward,
+)


 class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):

@@ -37,9 +39,7 @@ class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
     @staticmethod
     def forward(ctx, inputs, scale):
         scale_t = torch.tensor([scale])
-        softmax_results = scaled_upper_triang_masked_softmax_forward(
-            inputs, scale_t[0]
-        )
+        softmax_results = scaled_upper_triang_masked_softmax_forward(inputs, scale_t[0])
         ctx.save_for_backward(softmax_results, scale_t)
         return softmax_results

@@ -81,9 +81,7 @@ class ScaledMaskedSoftmax(torch.autograd.Function):
     @staticmethod
     def backward(ctx, output_grads):
         softmax_results, scale_t = ctx.saved_tensors
-        input_grads = scaled_masked_softmax_backward(
-            output_grads, softmax_results, scale_t[0]
-        )
+        input_grads = scaled_masked_softmax_backward(output_grads, softmax_results, scale_t[0])
         return input_grads, None, None

@@ -122,9 +120,7 @@ class FusedScaleMaskSoftmax(torch.nn.Module):
         self.input_in_fp16 = input_in_fp16
         self.input_in_bf16 = input_in_bf16
         if self.input_in_fp16 and self.input_in_bf16:
-            raise RuntimeError(
-                "both fp16 and bf16 flags cannot be active at the same time."
-            )
+            raise RuntimeError("both fp16 and bf16 flags cannot be active at the same time.")
         self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16
         self.attn_mask_type = attn_mask_type
         self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion
flash_attn/layers/patch_embed.py

@@ -4,11 +4,10 @@
 from functools import partial

 import torch.nn as nn
+from einops import rearrange
 from torch import _assert
 from torch.nn.modules.utils import _pair
-from einops import rearrange

 try:
     from flash_attn.ops.fused_dense import FusedDense
 except ImportError:

@@ -16,18 +15,18 @@ except ImportError:
 class PatchEmbed(nn.Module):
-    """ 2D Image to Patch Embedding
-    """
+    """2D Image to Patch Embedding
+    """

     def __init__(
-            self,
-            img_size=224,
-            patch_size=16,
-            in_chans=3,
-            embed_dim=768,
-            norm_layer=None,
-            flatten=True,
-            bias=True,
-            fused_bias_fc=False,
+        self,
+        img_size=224,
+        patch_size=16,
+        in_chans=3,
+        embed_dim=768,
+        norm_layer=None,
+        flatten=True,
+        bias=True,
+        fused_bias_fc=False,
     ):
         super().__init__()
         img_size = _pair(img_size)

@@ -38,7 +37,7 @@ class PatchEmbed(nn.Module):
         self.num_patches = self.grid_size[0] * self.grid_size[1]
         self.flatten = flatten
         if fused_bias_fc and FusedDense is None:
-            raise ImportError('fused_dense is not installed')
+            raise ImportError("fused_dense is not installed")
         linear_cls = nn.Linear if not fused_bias_fc or not bias else FusedDense
         self.proj = linear_cls(in_chans * patch_size[0] * patch_size[1], embed_dim, bias=bias)

@@ -46,11 +45,23 @@ class PatchEmbed(nn.Module):
     def forward(self, x):
         _, _, H, W = x.shape
-        _assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]}).")
-        _assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).")
-        x = self.proj(rearrange(x, 'b c (h p1) (w p2) -> b h w (c p1 p2)',
-                                p1=self.patch_size[0], p2=self.patch_size[1]))
+        _assert(
+            H == self.img_size[0],
+            f"Input image height ({H}) doesn't match model ({self.img_size[0]}).",
+        )
+        _assert(
+            W == self.img_size[1],
+            f"Input image width ({W}) doesn't match model ({self.img_size[1]}).",
+        )
+        x = self.proj(
+            rearrange(
+                x,
+                "b c (h p1) (w p2) -> b h w (c p1 p2)",
+                p1=self.patch_size[0],
+                p2=self.patch_size[1],
+            )
+        )
         if self.flatten:
-            x = rearrange(x, 'b h w c -> b (h w) c')
+            x = rearrange(x, "b h w c -> b (h w) c")
         x = self.norm(x)
         return x
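PatchEmbed.forward above turns an image into a grid of flattened patches with a single einops rearrange. A self-contained shape check (sizes chosen to match the defaults img_size=224, patch_size=16, in_chans=3):

    import torch
    from einops import rearrange

    x = torch.randn(2, 3, 224, 224)
    patches = rearrange(x, "b c (h p1) (w p2) -> b h w (c p1 p2)", p1=16, p2=16)
    print(patches.shape)  # torch.Size([2, 14, 14, 768]): a 14x14 grid of 3*16*16 = 768-dim patches

    flattened = rearrange(patches, "b h w c -> b (h w) c")
    print(flattened.shape)  # torch.Size([2, 196, 768])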
flash_attn/layers/rotary.py

 # Copyright (c) 2023, Tri Dao.

-from typing import Tuple, Optional
 import math
+from typing import Optional, Tuple

+import rotary_emb
 import torch
 from einops import rearrange, repeat
-import rotary_emb


 def rotate_half(x, interleaved=False):
     if not interleaved:

@@ -16,7 +14,7 @@ def rotate_half(x, interleaved=False):
         return torch.cat((-x2, x1), dim=-1)
     else:
         x1, x2 = x[..., ::2], x[..., 1::2]
-        return rearrange(torch.stack((-x2, x1), dim=-1), '... d two -> ... (d two)', two=2)
+        return rearrange(torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2)


 def apply_rotary_emb_torch(x, cos, sin, interleaved=False):
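rotate_half above implements both rotary conventions. A quick self-contained check of what each branch does to a 4-dim vector (the non-interleaved branch's chunk line is not visible in the hunk context and is filled in here to match the visible return statement):

    import torch
    from einops import rearrange


    def rotate_half(x, interleaved=False):
        if not interleaved:
            x1, x2 = x.chunk(2, dim=-1)
            return torch.cat((-x2, x1), dim=-1)
        x1, x2 = x[..., ::2], x[..., 1::2]
        return rearrange(torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2)


    x = torch.tensor([1.0, 2.0, 3.0, 4.0])
    print(rotate_half(x))                    # tensor([-3., -4.,  1.,  2.])  GPT-NeoX style halves
    print(rotate_half(x, interleaved=True))  # tensor([-2.,  1., -4.,  3.])  GPT-J style even/odd pairs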
@@ -26,14 +24,15 @@ def apply_rotary_emb_torch(x, cos, sin, interleaved=False):
     """
     ro_dim = cos.shape[-1] * 2
     assert ro_dim <= x.shape[-1]
-    cos = repeat(cos, 's d -> s 1 (2 d)')
-    sin = repeat(sin, 's d -> s 1 (2 d)')
-    return torch.cat([x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin,
-                      x[..., ro_dim:]], dim=-1)
+    cos = repeat(cos, "s d -> s 1 (2 d)")
+    sin = repeat(sin, "s d -> s 1 (2 d)")
+    return torch.cat(
+        [x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin, x[..., ro_dim:]],
+        dim=-1,
+    )


 class ApplyRotaryEmb(torch.autograd.Function):
     @staticmethod
     def forward(ctx, x, cos, sin, interleaved=False, inplace=False):
         """

@@ -57,10 +56,20 @@ class ApplyRotaryEmb(torch.autograd.Function):
         if inplace:
             o1, o2 = x1, x2
         else:
-            o1, o2 = (out_ro.chunk(2, dim=-1) if not interleaved
-                      else (out_ro[..., ::2], out_ro[..., 1::2]))
-        rotary_emb.apply_rotary(x1, x2, rearrange(cos[:seqlen], 's d -> s 1 d'),
-                                rearrange(sin[:seqlen], 's d -> s 1 d'), o1, o2, False)
+            o1, o2 = (
+                out_ro.chunk(2, dim=-1)
+                if not interleaved
+                else (out_ro[..., ::2], out_ro[..., 1::2])
+            )
+        rotary_emb.apply_rotary(
+            x1,
+            x2,
+            rearrange(cos[:seqlen], "s d -> s 1 d"),
+            rearrange(sin[:seqlen], "s d -> s 1 d"),
+            o1,
+            o2,
+            False,
+        )
         if not inplace and rotary_dim < headdim:
             out[..., rotary_dim:].copy_(x[..., rotary_dim:])
         ctx.save_for_backward(cos, sin)

@@ -76,17 +85,28 @@ class ApplyRotaryEmb(torch.autograd.Function):
         rotary_dim *= 2
         inplace = ctx.inplace
         do_ro = do[..., :rotary_dim]
-        do1, do2 = (do_ro.chunk(2, dim=-1) if not ctx.interleaved
-                    else (do_ro[..., ::2], do_ro[..., 1::2]))
+        do1, do2 = (
+            do_ro.chunk(2, dim=-1) if not ctx.interleaved else (do_ro[..., ::2], do_ro[..., 1::2])
+        )
         dx = torch.empty_like(do) if not inplace else do
         if inplace:
             dx1, dx2 = do1, do2
         else:
             dx_ro = dx[..., :rotary_dim]
-            dx1, dx2 = (dx_ro.chunk(2, dim=-1) if not ctx.interleaved
-                        else (dx_ro[..., ::2], dx_ro[..., 1::2]))
-        rotary_emb.apply_rotary(do1, do2, rearrange(cos[:seqlen], 's d -> s 1 d'),
-                                rearrange(sin[:seqlen], 's d -> s 1 d'), dx1, dx2, True)
+            dx1, dx2 = (
+                dx_ro.chunk(2, dim=-1)
+                if not ctx.interleaved
+                else (dx_ro[..., ::2], dx_ro[..., 1::2])
+            )
+        rotary_emb.apply_rotary(
+            do1,
+            do2,
+            rearrange(cos[:seqlen], "s d -> s 1 d"),
+            rearrange(sin[:seqlen], "s d -> s 1 d"),
+            dx1,
+            dx2,
+            True,
+        )
         if not inplace and rotary_dim < headdim:
             dx[..., rotary_dim:].copy_(do[..., rotary_dim:])
         return dx, None, None, None, None

@@ -96,7 +116,6 @@ apply_rotary_emb_func = ApplyRotaryEmb.apply
 class ApplyRotaryEmbQKV_(torch.autograd.Function):
     @staticmethod
     def forward(ctx, qkv, cos, sin, cos_k=None, sin_k=None, interleaved=False):
         """

@@ -119,12 +138,26 @@ class ApplyRotaryEmbQKV_(torch.autograd.Function):
         assert sin.shape == cos_k.shape == sin_k.shape == (rotary_seqlen, rotary_dim // 2)
         q_ro = qkv[:, :, 0, :, :rotary_dim]
         q1, q2 = q_ro.chunk(2, dim=-1) if not interleaved else (q_ro[..., ::2], q_ro[..., 1::2])
-        rotary_emb.apply_rotary(q1, q2, rearrange(cos[:seqlen], 's d -> s 1 d'),
-                                rearrange(sin[:seqlen], 's d -> s 1 d'), q1, q2, False)
+        rotary_emb.apply_rotary(
+            q1,
+            q2,
+            rearrange(cos[:seqlen], "s d -> s 1 d"),
+            rearrange(sin[:seqlen], "s d -> s 1 d"),
+            q1,
+            q2,
+            False,
+        )
         k_ro = qkv[:, :, 1, :, :rotary_dim]
         k1, k2 = k_ro.chunk(2, dim=-1) if not interleaved else (k_ro[..., ::2], k_ro[..., 1::2])
-        rotary_emb.apply_rotary(k1, k2, rearrange(cos_k[:seqlen], 's d -> s 1 d'),
-                                rearrange(sin_k[:seqlen], 's d -> s 1 d'), k1, k2, False)
+        rotary_emb.apply_rotary(
+            k1,
+            k2,
+            rearrange(cos_k[:seqlen], "s d -> s 1 d"),
+            rearrange(sin_k[:seqlen], "s d -> s 1 d"),
+            k1,
+            k2,
+            False,
+        )
         ctx.save_for_backward(cos, sin, cos_k, sin_k)
         ctx.interleaved = interleaved
         return qkv

@@ -136,15 +169,31 @@ class ApplyRotaryEmbQKV_(torch.autograd.Function):
         rotary_dim = cos.shape[-1]
         rotary_dim *= 2
         dq_ro = dqkv[:, :, 0, :, :rotary_dim]
-        dq1, dq2 = (dq_ro.chunk(2, dim=-1) if not ctx.interleaved
-                    else (dq_ro[..., ::2], dq_ro[..., 1::2]))
-        rotary_emb.apply_rotary(dq1, dq2, rearrange(cos[:seqlen], 's d -> s 1 d'),
-                                rearrange(sin[:seqlen], 's d -> s 1 d'), dq1, dq2, True)
+        dq1, dq2 = (
+            dq_ro.chunk(2, dim=-1) if not ctx.interleaved else (dq_ro[..., ::2], dq_ro[..., 1::2])
+        )
+        rotary_emb.apply_rotary(
+            dq1,
+            dq2,
+            rearrange(cos[:seqlen], "s d -> s 1 d"),
+            rearrange(sin[:seqlen], "s d -> s 1 d"),
+            dq1,
+            dq2,
+            True,
+        )
         dk_ro = dqkv[:, :, 1, :, :rotary_dim]
-        dk1, dk2 = (dk_ro.chunk(2, dim=-1) if not ctx.interleaved
-                    else (dk_ro[..., ::2], dk_ro[..., 1::2]))
-        rotary_emb.apply_rotary(dk1, dk2, rearrange(cos_k[:seqlen], 's d -> s 1 d'),
-                                rearrange(sin_k[:seqlen], 's d -> s 1 d'), dk1, dk2, True)
+        dk1, dk2 = (
+            dk_ro.chunk(2, dim=-1) if not ctx.interleaved else (dk_ro[..., ::2], dk_ro[..., 1::2])
+        )
+        rotary_emb.apply_rotary(
+            dk1,
+            dk2,
+            rearrange(cos_k[:seqlen], "s d -> s 1 d"),
+            rearrange(sin_k[:seqlen], "s d -> s 1 d"),
+            dk1,
+            dk2,
+            True,
+        )
         return dqkv, None, None, None, None, None

@@ -152,7 +201,6 @@ apply_rotary_emb_qkv_ = ApplyRotaryEmbQKV_.apply
 class ApplyRotaryEmbKV_(torch.autograd.Function):
     @staticmethod
     def forward(ctx, kv, cos, sin, interleaved=False):
         """

@@ -171,9 +219,15 @@ class ApplyRotaryEmbKV_(torch.autograd.Function):
         assert seqlen <= rotary_seqlen
         k_ro = kv[:, :, 0, :, :rotary_dim]
         k1, k2 = k_ro.chunk(2, dim=-1) if not interleaved else (k_ro[..., ::2], k_ro[..., 1::2])
-        rotary_emb.apply_rotary(k1, k2, rearrange(cos[:seqlen], 's d -> s 1 d'),
-                                rearrange(sin[:seqlen], 's d -> s 1 d'), k1, k2,
-                                False)  # conj=False since this is the forward pass
+        rotary_emb.apply_rotary(
+            k1,
+            k2,
+            rearrange(cos[:seqlen], "s d -> s 1 d"),
+            rearrange(sin[:seqlen], "s d -> s 1 d"),
+            k1,
+            k2,
+            False,
+        )  # conj=False since this is the forward pass
         ctx.save_for_backward(cos, sin)
         ctx.interleaved = interleaved
         return kv

@@ -185,11 +239,18 @@ class ApplyRotaryEmbKV_(torch.autograd.Function):
         rotary_dim = cos.shape[-1]
         rotary_dim *= 2
         dk_ro = dkv[:, :, 0, :, :rotary_dim]
-        dk1, dk2 = (dk_ro.chunk(2, dim=-1) if not ctx.interleaved
-                    else (dk_ro[..., ::2], dk_ro[..., 1::2]))
-        rotary_emb.apply_rotary(dk1, dk2, rearrange(cos[:seqlen], 's d -> s 1 d'),
-                                rearrange(sin[:seqlen], 's d -> s 1 d'), dk1, dk2,
-                                True)  # conj=True since this is the backward pass
+        dk1, dk2 = (
+            dk_ro.chunk(2, dim=-1) if not ctx.interleaved else (dk_ro[..., ::2], dk_ro[..., 1::2])
+        )
+        rotary_emb.apply_rotary(
+            dk1,
+            dk2,
+            rearrange(cos[:seqlen], "s d -> s 1 d"),
+            rearrange(sin[:seqlen], "s d -> s 1 d"),
+            dk1,
+            dk2,
+            True,
+        )  # conj=True since this is the backward pass
         return dkv, None, None, None

@@ -214,21 +275,28 @@ class RotaryEmbedding(torch.nn.Module):
     Reference: https://github.com/sunyt32/torchscale/blob/main/torchscale/component/xpos_relative_position.py
     """

-    def __init__(self, dim: int, base=10000.0, interleaved=False, scale_base=None,
-                 pos_idx_in_fp32=True, device=None):
+    def __init__(
+        self,
+        dim: int,
+        base=10000.0,
+        interleaved=False,
+        scale_base=None,
+        pos_idx_in_fp32=True,
+        device=None,
+    ):
         """
-            interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
-                of 1st half and 2nd half (GPT-NeoX style).
-            pos_idx_in_fp32: if True, the position indices [0.0, ..., seqlen - 1] are in fp32,
-                otherwise they might be in lower precision.
-                This option was added because previously (before 2023-07-02), when we construct
-                the position indices, we use the dtype of self.inv_freq. In most cases this would
-                be fp32, but if the model is trained in pure bf16 (not mixed precision), then
-                self.inv_freq would be bf16, and the position indices are also in bf16.
-                Because of the limited precision of bf16 (e.g. 1995.0 is rounded to 2000.0), the
-                embeddings for some positions will coincide.
-                To maintain compatibility with models previously trained in pure bf16,
-                we add this option.
+        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
+            of 1st half and 2nd half (GPT-NeoX style).
+        pos_idx_in_fp32: if True, the position indices [0.0, ..., seqlen - 1] are in fp32,
+            otherwise they might be in lower precision.
+            This option was added because previously (before 2023-07-02), when we construct
+            the position indices, we use the dtype of self.inv_freq. In most cases this would
+            be fp32, but if the model is trained in pure bf16 (not mixed precision), then
+            self.inv_freq would be bf16, and the position indices are also in bf16.
+            Because of the limited precision of bf16 (e.g. 1995.0 is rounded to 2000.0), the
+            embeddings for some positions will coincide.
+            To maintain compatibility with models previously trained in pure bf16,
+            we add this option.
         """
         super().__init__()
         self.dim = dim

@@ -239,8 +307,11 @@ class RotaryEmbedding(torch.nn.Module):
         self.register_buffer("inv_freq", inv_freq, persistent=False)
         self.interleaved = interleaved
         self.scale_base = scale_base
-        scale = ((torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim)
-                 / (1.4 * dim) if scale_base is not None else None)
+        scale = (
+            (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
+            if scale_base is not None
+            else None
+        )
         self.register_buffer("scale", scale, persistent=False)

         self._seq_len_cached = 0
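The pos_idx_in_fp32 note in the RotaryEmbedding docstring above is easy to reproduce: around position 2000, bf16 no longer has integer resolution, so neighboring position indices collapse onto the same value and their rotary embeddings coincide. A minimal check:

    import torch

    pos_fp32 = torch.arange(1990, 2000, dtype=torch.float32)
    pos_bf16 = pos_fp32.to(torch.bfloat16)
    print(pos_bf16.unique().numel(), "distinct values out of", pos_fp32.numel())
    # prints: 2 distinct values out of 10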
@@ -250,17 +321,21 @@
         self._sin_k_cached = None

     def _compute_inv_freq(self, device=None):
-        return 1.0 / (self.base ** (torch.arange(0, self.dim, 2, device=device,
-                                                 dtype=torch.float32) / self.dim))
+        return 1.0 / (
+            self.base
+            ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim)
+        )

     def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
         # Reset the tables if the sequence length has changed,
         # if we're on a new device (possibly due to tracing for instance),
         # or if we're switching from inference mode to training
-        if (seqlen > self._seq_len_cached or self._cos_cached.device != device
+        if (
+            seqlen > self._seq_len_cached
+            or self._cos_cached.device != device
             or self._cos_cached.dtype != dtype
-                or (self.training and self._cos_cached.is_inference())):
+            or (self.training and self._cos_cached.is_inference())
+        ):
             self._seq_len_cached = seqlen
             # We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16
             # And the output of arange can be quite large, so bf16 would lose a lot of precision.

@@ -285,17 +360,20 @@ class RotaryEmbedding(torch.nn.Module):
                 self._cos_cached = torch.cos(freqs).to(dtype)
                 self._sin_cached = torch.sin(freqs).to(dtype)
             else:
-                power = ((torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device)
-                          - seqlen // 2) / self.scale_base)
-                scale = self.scale.to(device=power.device) ** rearrange(power, 's -> s 1')
+                power = (
+                    torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device)
+                    - seqlen // 2
+                ) / self.scale_base
+                scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1")
                 # We want the multiplication by scale to happen in fp32
                 self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
                 self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
                 self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
                 self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)

-    def forward(self, qkv: torch.Tensor, kv: Optional[torch.Tensor] = None,
-                seqlen_offset: int = 0) -> Tuple[torch.Tensor, torch.Tensor]:
+    def forward(
+        self, qkv: torch.Tensor, kv: Optional[torch.Tensor] = None, seqlen_offset: int = 0
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         qkv: (batch, seqlen, 3, nheads, headdim) if kv is none,
             else it's just q of shape (batch, seqlen, nheads, headdim)

@@ -308,29 +386,43 @@ class RotaryEmbedding(torch.nn.Module):
         if kv is None:
             if self.scale is None:
                 return apply_rotary_emb_qkv_(
-                    qkv, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:],
-                    None, None, self.interleaved
+                    qkv,
+                    self._cos_cached[seqlen_offset:],
+                    self._sin_cached[seqlen_offset:],
+                    None,
+                    None,
+                    self.interleaved,
                 )
             else:
                 return apply_rotary_emb_qkv_(
-                    qkv, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:],
-                    self._cos_k_cached[seqlen_offset:], self._sin_k_cached[seqlen_offset:],
-                    self.interleaved
+                    qkv,
+                    self._cos_cached[seqlen_offset:],
+                    self._sin_cached[seqlen_offset:],
+                    self._cos_k_cached[seqlen_offset:],
+                    self._sin_k_cached[seqlen_offset:],
+                    self.interleaved,
                 )
         else:
             q = qkv
             q = apply_rotary_emb_func(
-                q, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:],
-                self.interleaved, True
+                q,
+                self._cos_cached[seqlen_offset:],
+                self._sin_cached[seqlen_offset:],
+                self.interleaved,
+                True,
             )
             if self.scale is None:
                 kv = apply_rotary_emb_kv_(
-                    kv, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:],
-                    self.interleaved
+                    kv,
+                    self._cos_cached[seqlen_offset:],
+                    self._sin_cached[seqlen_offset:],
+                    self.interleaved,
                 )
             else:
                 kv = apply_rotary_emb_kv_(
-                    kv, self._cos_k_cached[seqlen_offset:], self._sin_k_cached[seqlen_offset:],
-                    self.interleaved
+                    kv,
+                    self._cos_k_cached[seqlen_offset:],
+                    self._sin_k_cached[seqlen_offset:],
+                    self.interleaved,
                 )
         return q, kv
flash_attn/losses/cross_entropy.py

@@ -5,7 +5,6 @@
 # The original xentropy interface is here: https://github.com/NVIDIA/apex/blob/master/apex/contrib/xentropy/softmax_xentropy.py
 import torch
 import torch.nn as nn
-
 import xentropy_cuda_lib

 # `all_gather_into_tensor` and `reduce_scatter_tensor` are new placeholders for

@@ -17,10 +16,16 @@ if "all_gather_into_tensor" not in dir(torch.distributed):
 class SoftmaxCrossEntropyLossFn(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, logits, labels, smoothing=0.0, ignored_index=-100, inplace_backward=False,
-                process_group=None):
+    def forward(
+        ctx,
+        logits,
+        labels,
+        smoothing=0.0,
+        ignored_index=-100,
+        inplace_backward=False,
+        process_group=None,
+    ):
         """
         logits: (batch, vocab_size)
         labels: (batch,)

@@ -34,7 +39,7 @@ class SoftmaxCrossEntropyLossFn(torch.autograd.Function):
         if world_size == 1:
             losses, lse = xentropy_cuda_lib.forward(logits, labels, smoothing)
-            losses.masked_fill_(labels==ignored_index, 0)
+            losses.masked_fill_(labels == ignored_index, 0)
             labels_local = labels
         else:
             rank = torch.distributed.get_rank(process_group)

@@ -48,8 +53,9 @@ class SoftmaxCrossEntropyLossFn(torch.autograd.Function):
             # For tensor parallel cross entropy with smoothing, we want to pass in the total number
             # of classes so that smoothing can be applied correctly. If total_classes=-1, use the
             # last dimension of the input tensor.
-            losses, lse_local = xentropy_cuda_lib.forward(logits, labels_local, smoothing,
-                                                          world_size * vocab_size)
+            losses, lse_local = xentropy_cuda_lib.forward(
+                logits, labels_local, smoothing, world_size * vocab_size
+            )
             assert lse_local.shape == (batch,)
             assert losses.shape == (batch,)
             losses.masked_fill_(ignored_mask, 0)

@@ -61,10 +67,12 @@ class SoftmaxCrossEntropyLossFn(torch.autograd.Function):
             # For labels not in the vocab of this partition, losses contains
             # 0.1 * (lse_local - sum logit / total_classes).
-            lse_allgather = torch.empty(world_size, batch, dtype=lse_local.dtype,
-                                        device=lse_local.device)
-            torch.distributed.all_gather_into_tensor(lse_allgather, lse_local.contiguous(),
-                                                     group=process_group)
+            lse_allgather = torch.empty(
+                world_size, batch, dtype=lse_local.dtype, device=lse_local.device
+            )
+            torch.distributed.all_gather_into_tensor(
+                lse_allgather, lse_local.contiguous(), group=process_group
+            )
             handle_losses = torch.distributed.all_reduce(
                 losses, op=torch.distributed.ReduceOp.SUM, group=process_group, async_op=True
             )

@@ -74,16 +82,18 @@ class SoftmaxCrossEntropyLossFn(torch.autograd.Function):
             # If there's smoothing=0.1, the total losses are
             # 0.9 * (lse_local - predicted_logit) + 0.1 * (sum of all lse_local - sum logit / total_classes)
             # We want 0.9 * (lse - predicted_logit) + 0.1 * (lse - sum logit / total_classes).
-            rank_per_sample = torch.div(labels, vocab_size, rounding_mode='floor')
-            lse_local = lse_allgather[rank_per_sample,
-                                      torch.arange(batch, device=lse_allgather.device)]
+            rank_per_sample = torch.div(labels, vocab_size, rounding_mode="floor")
+            lse_local = lse_allgather[
+                rank_per_sample, torch.arange(batch, device=lse_allgather.device)
+            ]

             handle_losses.wait()
             if smoothing == 0.0:
                 losses += lse - lse_local
             else:
-                losses += ((1 - smoothing) * (lse - lse_local)
-                           + smoothing * (lse - lse_allgather.sum(dim=0)))
+                losses += (1 - smoothing) * (lse - lse_local) + smoothing * (
+                    lse - lse_allgather.sum(dim=0)
+                )
             losses.masked_fill_(ignored_mask, 0)

             ctx.save_for_backward(logits, lse, labels_local)

@@ -96,19 +106,24 @@ class SoftmaxCrossEntropyLossFn(torch.autograd.Function):
     def backward(ctx, grad_loss):
         logits, lse, labels = ctx.saved_tensors
         grad_loss = grad_loss.contiguous()
-        grad_loss.masked_fill_(labels==ctx.ignored_index, 0)
-        grad_logits = xentropy_cuda_lib.backward(grad_loss, logits, lse, labels,
-                                                 ctx.smoothing, ctx.inplace_backward, ctx.total_classes)
+        grad_loss.masked_fill_(labels == ctx.ignored_index, 0)
+        grad_logits = xentropy_cuda_lib.backward(
+            grad_loss, logits, lse, labels, ctx.smoothing, ctx.inplace_backward, ctx.total_classes
+        )
         return grad_logits, None, None, None, None, None, None


 class CrossEntropyLoss(nn.Module):
-    def __init__(self, ignore_index=-100, reduction='mean', label_smoothing=0.0,
-                 inplace_backward=False, process_group=None):
+    def __init__(
+        self,
+        ignore_index=-100,
+        reduction="mean",
+        label_smoothing=0.0,
+        inplace_backward=False,
+        process_group=None,
+    ):
         super().__init__()
-        if reduction not in ['mean', 'none']:
+        if reduction not in ["mean", "none"]:
             raise NotImplementedError("Only support reduction = 'mean' or 'none'")
         self.ignore_index = ignore_index
         self.reduction = reduction

@@ -120,10 +135,14 @@ class CrossEntropyLoss(nn.Module):
         assert input.is_cuda and target.is_cuda
         # SoftmaxCrossEntropyLoss implicitly casts to float
         loss = SoftmaxCrossEntropyLossFn.apply(
-            input, target, self.label_smoothing, self.ignore_index, self.inplace_backward,
-            self.process_group
+            input,
+            target,
+            self.label_smoothing,
+            self.ignore_index,
+            self.inplace_backward,
+            self.process_group,
         )
-        if self.reduction == 'mean':
+        if self.reduction == "mean":
             return loss.sum() / (target != self.ignore_index).sum()
         else:
             return loss
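The tensor-parallel branch above locates, for every label, the rank whose vocabulary shard contains it via a floor division. A tiny worked example with made-up sizes (vocabulary split into shards of 8192 entries):

    import torch

    vocab_size = 8192  # per-rank shard size
    labels = torch.tensor([5, 8200, 30000])
    rank_per_sample = torch.div(labels, vocab_size, rounding_mode="floor")
    print(rank_per_sample)  # tensor([0, 1, 3]): which rank holds each label's logit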
flash_attn/models/bert.py (diff collapsed, not shown)
flash_attn/models/falcon.py
View file @
f1a73d07
...
...
@@ -2,93 +2,114 @@
import
math
import
re
from
collections
import
OrderedDict
import
torch
import
torch.nn.functional
as
F
from
einops
import
rearrange
from
transformers
import
GPT2Config
,
FalconConfig
from
transformers
import
FalconConfig
,
GPT2Config
def
remap_state_dict_hf_falcon
(
state_dict
,
config
):
def
key_mapping_layers
(
key
):
return
re
.
sub
(
r
'^transformer.h.'
,
'transformer.layers.'
,
key
)
return
re
.
sub
(
r
"^transformer.h."
,
"transformer.layers."
,
key
)
state_dict
=
OrderedDict
((
key_mapping_layers
(
k
),
v
)
for
k
,
v
in
state_dict
.
items
())
# Word embedding
def
key_mapping_emb
(
key
):
return
re
.
sub
(
r
'^transformer.word_embeddings.'
,
'transformer.embeddings.word_embeddings.'
,
key
)
return
re
.
sub
(
r
"^transformer.word_embeddings."
,
"transformer.embeddings.word_embeddings."
,
key
)
state_dict
=
OrderedDict
((
key_mapping_emb
(
k
),
v
)
for
k
,
v
in
state_dict
.
items
())
word_embeddings
=
state_dict
.
pop
(
'
transformer.embeddings.word_embeddings.weight
'
)
word_embeddings
=
state_dict
.
pop
(
"
transformer.embeddings.word_embeddings.weight
"
)
# It's possible that vocab_size is padded to be a multiple of 8, for example.
pad_vocab_size_multiple
=
getattr
(
config
,
'
pad_vocab_size_multiple
'
,
1
)
vocab_size
=
(
math
.
ceil
(
config
.
vocab_size
/
pad_vocab_size_multiple
)
*
    pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
    vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
    state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad(
        word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
    )
    if getattr(config, "tie_word_embeddings"):
        state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"]
    else:
        output_embeddings = state_dict.pop("lm_head.weight")
        # It's possible that vocab_size is padded to be a multiple of 8, for example.
        state_dict["lm_head.weight"] = F.pad(
            output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0])
        )
        output_embeddings_bias = state_dict.pop("lm_head.bias")
        state_dict["lm_head.bias"] = F.pad(
            output_embeddings_bias, (0, vocab_size - output_embeddings_bias.shape[0])
        )

    # LayerNorm
    def key_mapping_ln(key):
        key = re.sub(
            r"^transformer.layers.(\d+).input_layernorm.", r"transformer.layers.\1.norm1.", key
        )
        key = re.sub(
            r"^transformer.layers.(\d+).post_attention_layernorm.",
            r"transformer.layers.\1.norm2.",
            key,
        )
        key = re.sub(r"^transformer.layers.(\d+).ln_attn.", r"transformer.layers.\1.norm1.", key)
        key = re.sub(r"^transformer.layers.(\d+).ln_mlp.", r"transformer.layers.\1.norm2.", key)
        return key

    state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())

    # MLP
    def key_mapping_mlp(key):
        key = re.sub(
            r"^transformer.layers.(\d+).mlp.dense_h_to_4h.", r"transformer.layers.\1.mlp.fc1.", key
        )
        key = re.sub(
            r"^transformer.layers.(\d+).mlp.dense_4h_to_h.", r"transformer.layers.\1.mlp.fc2.", key
        )
        return key

    state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())

    def key_mapping_attn(key):
        key = re.sub(
            r"^transformer.layers.(\d+).self_attention.query_key_value.",
            r"transformer.layers.\1.mixer.Wqkv.",
            key,
        )
        key = re.sub(
            r"^transformer.layers.(\d+).self_attention.dense.",
            r"transformer.layers.\1.mixer.out_proj.",
            key,
        )
        return key

    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())

    n_head = config.n_head
    n_head_kv = getattr(config, "n_head_kv", 1)
    headdim = config.hidden_size // n_head
    for l in range(config.n_layer):
        # The weights are stored in a different layout compared to our implementation
        Wqkv = rearrange(
            state_dict.pop(f"transformer.layers.{l}.mixer.Wqkv.weight"),
            "(group ratio headdim) ... -> group ratio headdim ...",
            ratio=n_head // n_head_kv + 2,
            headdim=headdim,
        )
        Wq = rearrange(Wqkv[:, :-2], "group ratio headdim ... -> (group ratio headdim) ...")
        Wk = rearrange(Wqkv[:, [-2]], "group ratio headdim ... -> (group ratio headdim) ...")
        Wv = rearrange(Wqkv[:, [-1]], "group ratio headdim ... -> (group ratio headdim) ...")
        state_dict[f"transformer.layers.{l}.mixer.Wqkv.weight"] = torch.cat([Wq, Wk, Wv], dim=0)
    return state_dict


def falcon_config_to_gpt2_config(falcon_config: FalconConfig) -> GPT2Config:
    # The 40b config uses "n_head_kv" instead of "num_kv_heads"
    n_head_kv = getattr(
        falcon_config,
        "n_head_kv",
        1 if getattr(falcon_config, "multi_query", False) else falcon_config.n_head,
    )
    # HACK: the 40b config has 2 LN per layer instead of 1, but that's not reflected in the config.
    # So we have to infer it from the number of heads in the key/value block
    parallel_block_tied_norm = n_head_kv == 1
...
...
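Falcon stores the fused query_key_value weight grouped by KV head: each group carries its n_head // n_head_kv query heads followed by one key head and one value head, and the loop above re-concatenates these into contiguous Q, K, V blocks. A minimal standalone sketch of the same einops pattern, using made-up toy sizes (not taken from the repo):

# Minimal sketch (not part of this commit) of the Wqkv split done in
# remap_state_dict_hf_falcon above, on a tiny made-up tensor.
import torch
from einops import rearrange

n_head, n_head_kv, headdim, hidden = 8, 2, 4, 16   # toy sizes, assumed for illustration
ratio = n_head // n_head_kv + 2                     # queries per KV group, plus one K and one V
Wqkv = torch.randn(n_head_kv * ratio * headdim, hidden)

blocks = rearrange(
    Wqkv, "(group ratio headdim) d -> group ratio headdim d", ratio=ratio, headdim=headdim
)
Wq = rearrange(blocks[:, :-2], "group r h d -> (group r h) d")   # all query heads
Wk = rearrange(blocks[:, [-2]], "group r h d -> (group r h) d")  # one K head per group
Wv = rearrange(blocks[:, [-1]], "group r h d -> (group r h) d")  # one V head per group

assert Wq.shape == (n_head * headdim, hidden)
assert Wk.shape == (n_head_kv * headdim, hidden)
assert Wv.shape == (n_head_kv * headdim, hidden)
print(torch.cat([Wq, Wk, Wv], dim=0).shape)  # layout expected by the MHA mixer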
flash_attn/models/gpt.py
View file @ f1a73d07
@@ -11,6 +11,8 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from transformers import GPT2Config

from flash_attn.models.falcon import remap_state_dict_hf_falcon
from flash_attn.models.gpt_neox import remap_state_dict_hf_gpt_neox
from flash_attn.models.gptj import remap_state_dict_hf_gptj
@@ -27,10 +29,9 @@ from flash_attn.modules.mlp import (
    ParallelMLP,
)
from flash_attn.ops.activations import sqrelu_fwd
from flash_attn.utils.distributed import all_gather_raw, get_dim_for_local_rank, sync_shared_params
from flash_attn.utils.generation import GenerationMixin
from flash_attn.utils.pretrained import state_dict_from_pretrained

try:
    from flash_attn.ops.fused_dense import ColumnParallelLinear
@@ -690,7 +691,7 @@ def shard_state_dict_tp(state_dict, config, world_size, rank):
        if key in state_dict:
            x = state_dict[key]
            dim = x.shape[0] // world_size
            state_dict[key] = x[rank * dim : (rank + 1) * dim]

    def shard_last_dim(state_dict, key, multiple_of=1):
        if key in state_dict:
@@ -707,17 +708,19 @@ def shard_state_dict_tp(state_dict, config, world_size, rank):
            x = state_dict[key]
            dim = x.shape[0] // world_size // 2
            state_dict[key] = rearrange(
                rearrange(x, "(two o) ... -> two o ...", two=2)[:, rank * dim : (rank + 1) * dim],
                "two o ... -> (two o) ...",
            )

    def shard_qkv_headdim(state_dict, key):
        if key in state_dict:
            n_head_each_rank = [
                get_dim_for_local_rank(n_head, world_size, local_rank)
                for local_rank in range(world_size)
            ]
            n_head_kv_each_rank = [
                get_dim_for_local_rank(n_head_kv, world_size, local_rank)
                for local_rank in range(world_size)
            ]
            beg_n_head = sum(n_head_each_rank[:rank])
@@ -729,7 +732,8 @@ def shard_state_dict_tp(state_dict, config, world_size, rank):
            if n_head_kv == n_head:
                x = rearrange(state_dict[key], "(three d) ... -> three d ...", three=3)
                state_dict[key] = rearrange(
                    x[:, beg_n_head * head_dim : end_n_head * head_dim],
                    "three d ... -> (three d) ...",
                )
            else:
                x = rearrange(
@@ -741,8 +745,14 @@ def shard_state_dict_tp(state_dict, config, world_size, rank):
                    torch.cat(
                        [
                            x[beg_n_head:end_n_head],
                            x[n_head + beg_n_head_kv : n_head + end_n_head_kv],
                            x[n_head + n_head_kv + beg_n_head_kv : n_head + n_head_kv + end_n_head_kv],
                        ],
                        dim=0,
                    ),
@@ -824,7 +834,7 @@ def combine_state_dicts_tp(state_dicts, config):
                torch.cat([x[:n_head_per_rank] for x in xs], dim=0),
                torch.cat(
                    [x[n_head_per_rank : n_head_per_rank + n_head_kv_per_rank] for x in xs],
                    dim=0,
...
...
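The shard_state_dict_tp hunk above shards a fused gated-MLP weight with a "(two o) ..." rearrange so that each tensor-parallel rank keeps matching row slices of both stacked halves. A small self-contained sketch of that slicing, with toy sizes and names that are assumptions rather than code from the repo:

# Illustrative sketch (not part of this commit) of the "(two o)" sharding used in
# shard_state_dict_tp above for fused gated-MLP weights; sizes and names are made up.
import torch
from einops import rearrange

world_size, out_features, in_features = 4, 32, 16
x = torch.randn(2 * out_features, in_features)       # two halves stacked along dim 0
dim = x.shape[0] // world_size // 2                   # rows of each half kept per rank

shards = []
for rank in range(world_size):
    shard = rearrange(
        rearrange(x, "(two o) d -> two o d", two=2)[:, rank * dim : (rank + 1) * dim],
        "two o d -> (two o) d",
    )
    shards.append(shard)

# Matching slices of both halves land on the same rank, so gathering the per-rank
# halves in order reconstructs the original tensor.
recombined = torch.cat(
    [torch.cat([s[:dim] for s in shards], dim=0), torch.cat([s[dim:] for s in shards], dim=0)],
    dim=0,
)
assert torch.equal(recombined, x)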
flash_attn/models/gpt_neox.py
View file @ f1a73d07
@@ -2,80 +2,100 @@
import math
import re
from collections import OrderedDict

import torch
import torch.nn.functional as F
from einops import rearrange
from transformers import GPT2Config, GPTNeoXConfig


def remap_state_dict_hf_gpt_neox(state_dict, config):
    def key_mapping_layers(key):
        return re.sub(r"^gpt_neox.", "transformer.", key)

    state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items())

    # Word embedding
    def key_mapping_emb(key):
        return re.sub(r"^transformer.embed_in.", "transformer.embeddings.word_embeddings.", key)

    state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
    word_embeddings = state_dict.pop("transformer.embeddings.word_embeddings.weight")
    # It's possible that vocab_size is padded to be a multiple of 8, for example.
    pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
    vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
    state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad(
        word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
    )
    if getattr(config, "tie_word_embeddings"):
        state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"]
    else:
        output_embeddings = state_dict.pop("embed_out.weight")
        # It's possible that vocab_size is padded to be a multiple of 8, for example.
        state_dict["lm_head.weight"] = F.pad(
            output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0])
        )

    # LayerNorm
    def key_mapping_ln(key):
        key = re.sub(r"^transformer.final_layer_norm.", r"transformer.ln_f.", key)
        key = re.sub(
            r"^transformer.layers.(\d+).input_layernorm.", r"transformer.layers.\1.norm1.", key
        )
        key = re.sub(
            r"^transformer.layers.(\d+).post_attention_layernorm.",
            r"transformer.layers.\1.norm2.",
            key,
        )
        return key

    state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())

    # MLP
    def key_mapping_mlp(key):
        key = re.sub(
            r"^transformer.layers.(\d+).mlp.dense_h_to_4h.", r"transformer.layers.\1.mlp.fc1.", key
        )
        key = re.sub(
            r"^transformer.layers.(\d+).mlp.dense_4h_to_h.", r"transformer.layers.\1.mlp.fc2.", key
        )
        return key

    state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())

    # Attention
    for l in range(config.n_layer):
        # We don't store these biases
        state_dict.pop(f"transformer.layers.{l}.attention.bias")
        state_dict.pop(f"transformer.layers.{l}.attention.masked_bias")
        # GPT-NeoX stores Wqkv as ((nheads 3 headdim), hidden_dim)
        # while we store Wqkv as ((3 nheads headdim), hidden_dim)
        headdim = config.hidden_size // config.num_attention_heads
        Wqkv = state_dict.pop(f"transformer.layers.{l}.attention.query_key_value.weight")
        state_dict[f"transformer.layers.{l}.mixer.Wqkv.weight"] = rearrange(
            Wqkv,
            "(nheads three headdim) ... -> (three nheads headdim) ...",
            three=3,
            headdim=headdim,
        )
        bqkv = state_dict.pop(f"transformer.layers.{l}.attention.query_key_value.bias")
        state_dict[f"transformer.layers.{l}.mixer.Wqkv.bias"] = rearrange(
            bqkv, "(nheads three headdim) -> (three nheads headdim)", three=3, headdim=headdim
        )

    def key_mapping_attn(key):
        key = re.sub(
            r"^transformer.layers.(\d+).attention.dense.",
            r"transformer.layers.\1.mixer.out_proj.",
            key,
        )
        key = re.sub(
            r"^transformer.layers.(\d+).attention.rotary_emb.",
            r"transformer.layers.\1.mixer.rotary_emb.",
            key,
        )
        return key

    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())

    return state_dict
...
...
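GPT-NeoX lays out the fused QKV parameters head-major ("(nheads three headdim)"), while this repo expects projection-major ("(three nheads headdim)"). The tiny sketch below is illustrative only; it shows the permutation on an integer vector so the reordering is visible:

# Standalone sketch (not from the repo) of the layout change applied to GPT-NeoX's
# fused QKV parameters above: per-head [q, k, v] blocks become
# [all q heads, all k heads, all v heads].
import torch
from einops import rearrange

nheads, headdim = 2, 3                               # toy sizes
# Head h contributes q_h, k_h, v_h back to back in the GPT-NeoX layout.
bqkv = torch.arange(nheads * 3 * headdim)
remapped = rearrange(
    bqkv, "(nheads three headdim) -> (three nheads headdim)", three=3, headdim=headdim
)
print(bqkv.tolist())      # [q0 | k0 | v0 | q1 | k1 | v1], flattened per head
print(remapped.tolist())  # [q0 | q1 | k0 | k1 | v0 | v1], flattened per projection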
flash_attn/models/gptj.py
View file @ f1a73d07
@@ -2,67 +2,78 @@
import math
import re
from collections import OrderedDict

import torch
import torch.nn.functional as F
from transformers import GPT2Config, GPTJConfig


def remap_state_dict_hf_gptj(state_dict, config):
    def key_mapping_layers(key):
        return re.sub(r"^transformer.h.", "transformer.layers.", key)

    state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items())

    # Word embedding
    def key_mapping_emb(key):
        return re.sub(r"^transformer.wte.", "transformer.embeddings.word_embeddings.", key)

    state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
    word_embeddings = state_dict.pop("transformer.embeddings.word_embeddings.weight")
    # It's possible that vocab_size is padded to be a multiple of 8, for example.
    pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
    vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
    state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad(
        word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
    )
    if getattr(config, "tie_word_embeddings"):
        state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"]
    else:
        output_embeddings = state_dict.pop("lm_head.weight")
        # It's possible that vocab_size is padded to be a multiple of 8, for example.
        state_dict["lm_head.weight"] = F.pad(
            output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0])
        )
        output_embeddings_bias = state_dict.pop("lm_head.bias")
        state_dict["lm_head.bias"] = F.pad(
            output_embeddings_bias, (0, vocab_size - output_embeddings_bias.shape[0])
        )

    # LayerNorm
    def key_mapping_ln(key):
        return re.sub(r"^transformer.layers.(\d+).ln_1.", r"transformer.layers.\1.norm1.", key)

    state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())

    # MLP
    def key_mapping_mlp(key):
        key = re.sub(
            r"^transformer.layers.(\d+).mlp.fc_in.", r"transformer.layers.\1.mlp.fc1.", key
        )
        key = re.sub(
            r"^transformer.layers.(\d+).mlp.fc_out.", r"transformer.layers.\1.mlp.fc2.", key
        )
        return key

    state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())

    # Attention
    for l in range(config.n_layer):
        Wq = state_dict.pop(f"transformer.layers.{l}.attn.q_proj.weight")
        Wk = state_dict.pop(f"transformer.layers.{l}.attn.k_proj.weight")
        Wv = state_dict.pop(f"transformer.layers.{l}.attn.v_proj.weight")
        state_dict[f"transformer.layers.{l}.mixer.Wqkv.weight"] = torch.cat([Wq, Wk, Wv], dim=0)
        # We don't store these biases
        state_dict.pop(f"transformer.layers.{l}.attn.bias")
        state_dict.pop(f"transformer.layers.{l}.attn.masked_bias")

    def key_mapping_attn(key):
        return re.sub(
            r"^transformer.layers.(\d+).attn.out_proj.",
            r"transformer.layers.\1.mixer.out_proj.",
            key,
        )

    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
    return state_dict
...
...
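All of these remap_state_dict_hf_* functions follow the same pattern: rebuild the OrderedDict by passing every key through a chain of anchored re.sub calls. A minimal, self-contained usage example with made-up sample keys (a sketch, not code from this commit):

# Sketch of the regex-based key remapping pattern used throughout these files;
# the sample checkpoint keys below are invented for illustration.
import re
from collections import OrderedDict

import torch

def key_mapping(key):
    key = re.sub(r"^transformer.h.", "transformer.layers.", key)
    key = re.sub(r"^transformer.layers.(\d+).ln_1.", r"transformer.layers.\1.norm1.", key)
    return key

state_dict = OrderedDict(
    [
        ("transformer.h.0.ln_1.weight", torch.ones(4)),
        ("transformer.h.0.attn.out_proj.weight", torch.ones(4, 4)),
    ]
)
state_dict = OrderedDict((key_mapping(k), v) for k, v in state_dict.items())
print(list(state_dict))
# ['transformer.layers.0.norm1.weight', 'transformer.layers.0.attn.out_proj.weight']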
flash_attn/models/llama.py
View file @ f1a73d07
@@ -15,63 +15,81 @@ from transformers import GPT2Config, LlamaConfig
def remap_state_dict_meta_llama(state_dict, config):
    def key_mapping_layers(key):
        return f"transformer.{key}" if not key.startswith("output.") else key

    state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items())

    # Word embedding
    def key_mapping_emb(key):
        return re.sub(
            r"^transformer.tok_embeddings.", "transformer.embeddings.word_embeddings.", key
        )

    state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
    word_embeddings = state_dict.pop("transformer.embeddings.word_embeddings.weight")
    # It's possible that vocab_size is padded to be a multiple of 8, for example.
    pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
    vocab_size = (
        math.ceil(word_embeddings.shape[0] / pad_vocab_size_multiple) * pad_vocab_size_multiple
    )
    state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad(
        word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
    )
    if getattr(config, "tie_word_embeddings"):
        state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"]
    else:
        output_embeddings = state_dict.pop("output.weight")
        # Need to recompute vocab_size since LLaMa shards the word embeddings and output embeddings
        # differently.
        vocab_size = (
            math.ceil(output_embeddings.shape[0] / pad_vocab_size_multiple)
            * pad_vocab_size_multiple
        )
        # It's possible that vocab_size is padded to be a multiple of 8, for example.
        state_dict["lm_head.weight"] = F.pad(
            output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0])
        )

    # LayerNorm
    def key_mapping_ln(key):
        key = re.sub(r"^transformer.norm.", r"transformer.ln_f.", key)
        key = re.sub(
            r"^transformer.layers.(\d+).attention_norm.", r"transformer.layers.\1.norm1.", key
        )
        key = re.sub(r"^transformer.layers.(\d+).ffn_norm.", r"transformer.layers.\1.norm2.", key)
        return key

    state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())

    # MLP
    for l in range(config.n_layer):
        w1 = state_dict.pop(f"transformer.layers.{l}.feed_forward.w1.weight")
        w3 = state_dict.pop(f"transformer.layers.{l}.feed_forward.w3.weight")
        # Our ordering is different
        state_dict[f"transformer.layers.{l}.mlp.fc1.weight"] = torch.cat([w3, w1], dim=0)

    def key_mapping_mlp(key):
        return re.sub(
            r"^transformer.layers.(\d+).feed_forward.w2.", r"transformer.layers.\1.mlp.fc2.", key
        )

    state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())

    # Attention
    for l in range(config.n_layer):
        Wq = state_dict.pop(f"transformer.layers.{l}.attention.wq.weight")
        Wk = state_dict.pop(f"transformer.layers.{l}.attention.wk.weight")
        Wv = state_dict.pop(f"transformer.layers.{l}.attention.wv.weight")
        state_dict[f"transformer.layers.{l}.mixer.Wqkv.weight"] = torch.cat([Wq, Wk, Wv], dim=0)
        # We don't store these
        state_dict.pop(f"transformer.layers.{l}.attention.inner_attention.rope.freqs", None)

    def key_mapping_attn(key):
        return re.sub(
            r"^transformer.layers.(\d+).attention.wo.",
            r"transformer.layers.\1.mixer.out_proj.",
            key,
        )

    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
    state_dict.pop("transformer.rope.freqs", None)
@@ -82,29 +100,32 @@ def remap_state_dict_meta_llama(state_dict, config):
def remap_state_dict_hf_llama(state_dict, config):
    # Embedding
    def key_mapping_emb(key):
        return re.sub(r"^model.embed_tokens.", "transformer.embeddings.word_embeddings.", key)

    state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
    word_embeddings = state_dict.pop("transformer.embeddings.word_embeddings.weight")
    # It's possible that vocab_size is padded to be a multiple of 8, for example.
    pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
    vocab_size = (
        math.ceil(word_embeddings.shape[0] / pad_vocab_size_multiple) * pad_vocab_size_multiple
    )
    state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad(
        word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
    )

    # LM head
    if getattr(config, "tie_word_embeddings"):
        state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"]
    else:
        output_embeddings = state_dict.pop("lm_head.weight")
        # Need to recompute vocab_size since LLaMa shards the word embeddings and output embeddings
        # differently.
        vocab_size = (
            math.ceil(output_embeddings.shape[0] / pad_vocab_size_multiple)
            * pad_vocab_size_multiple
        )
        # It's possible that vocab_size is padded to be a multiple of 8, for example.
        state_dict["lm_head.weight"] = F.pad(
            output_embeddings, (0, 0, 0, vocab_size - output_embeddings.shape[0])
        )
@@ -113,21 +134,22 @@ def remap_state_dict_hf_llama(state_dict, config):
        # Fusing weights this way based on difference in the following:
        # https://github.com/huggingface/transformers/blob/b42010bb1d3cbf262d27e0a328661885be46dfdb/src/transformers/models/llama/modeling_llama.py#L220
        # https://github.com/Dao-AILab/flash-attention/blob/c60851a8253257eb970e06a022c82517a8033e8c/flash_attn/modules/mlp.py#L115
        w1 = state_dict.pop(f"model.layers.{l}.mlp.gate_proj.weight")
        w3 = state_dict.pop(f"model.layers.{l}.mlp.up_proj.weight")
        state_dict[f"transformer.layers.{l}.mlp.fc1.weight"] = torch.cat([w3, w1], dim=0)

    def key_mapping_mlp(key):
        return re.sub(
            r"^model.layers.(\d+).mlp.down_proj.", r"transformer.layers.\1.mlp.fc2.", key
        )

    state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())

    # LayerNorm
    def key_mapping_ln(key):
        key = re.sub(r"^model.norm.", r"transformer.ln_f.", key)
        key = re.sub(r"^model.layers.(\d+).input_layernorm.", r"transformer.layers.\1.norm1.", key)
        key = re.sub(
            r"^model.layers.(\d+).post_attention_layernorm.", r"transformer.layers.\1.norm2.", key
        )
        return key

    state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())
@@ -135,42 +157,52 @@ def remap_state_dict_hf_llama(state_dict, config):
    def inv_permute(w):
        # Inverse of permute implemented in:
        # https://github.com/huggingface/transformers/blob/b42010bb1d3cbf262d27e0a328661885be46dfdb/src/transformers/models/llama/convert_llama_weights_to_hf.py#L114
        return (
            w.reshape(config.n_head, 2, config.n_embd // config.n_head // 2, config.n_embd)
            .transpose(1, 2)
            .reshape(config.n_embd, config.n_embd)
        )

    # Attention
    for l in range(config.n_layer):
        Wq = state_dict.pop(f"model.layers.{l}.self_attn.q_proj.weight")
        Wk = state_dict.pop(f"model.layers.{l}.self_attn.k_proj.weight")
        Wv = state_dict.pop(f"model.layers.{l}.self_attn.v_proj.weight")
        state_dict[f"transformer.layers.{l}.mixer.Wqkv.weight"] = torch.cat(
            [inv_permute(Wq), inv_permute(Wk), Wv], dim=0
        )
        # We don't store these
        state_dict.pop(f"model.layers.{l}.self_attn.rotary_emb.inv_freq", None)

    def key_mapping_attn(key):
        return re.sub(
            r"^model.layers.(\d+).self_attn.o_proj.", r"transformer.layers.\1.mixer.out_proj.", key
        )

    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
    return state_dict


def config_from_meta_checkpoint(
    checkpoint_path: Union[str, os.PathLike], model_name: str
) -> LlamaConfig:
    """Load a LlamaConfig from a checkpoint path."""
    with open(Path(checkpoint_path) / model_name / "params.json") as f:
        params = json.load(f)
    config = LlamaConfig(
        hidden_size=params["dim"],
        intermediate_size=None,
        num_attention_heads=params["n_heads"],
        num_hidden_layers=params["n_layers"],
        rms_norm_eps=params["norm_eps"],
    )
    return config


def config_from_hf_checkpoint(
    checkpoint_path: Union[str, os.PathLike], model_name: str
) -> LlamaConfig:
    return LlamaConfig.from_pretrained(Path(checkpoint_path) / f"{model_name}-hf" / "config.json")


def config_from_checkpoint(
@@ -182,10 +214,14 @@ def config_from_checkpoint(
    return config_from_hf_checkpoint(checkpoint_path, model_name)


def state_dicts_from_checkpoint(
    checkpoint_path: Union[str, os.PathLike], model_name: str
) -> list[dict]:
    # Need to sort, otherwise we mess up the ordering and the weights are wrong
    return [
        torch.load(path, map_location="cpu")
        for path in sorted((Path(checkpoint_path) / model_name).glob("consolidated.*.pth"))
    ]


def llama_config_to_gpt2_config(llama_config: LlamaConfig) -> GPT2Config:
@@ -196,7 +232,7 @@ def llama_config_to_gpt2_config(llama_config: LlamaConfig) -> GPT2Config:
        n_layer=llama_config.num_hidden_layers,
        n_head=llama_config.num_attention_heads,
        n_inner=llama_config.intermediate_size,
        activation_function="swiglu",  # Hardcode since HF calls it 'silu'
        # Llama doesn't have dropout, idk if it's because they only release the inference code
        resid_pdrop=0.0,
        embd_pdrop=0.0,
...
...
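inv_permute above undoes the head-dimension interleaving that HF's convert_llama_weights_to_hf.py applies to the q/k projection weights for its rotary-embedding convention. The quick check below is illustrative only; hf_permute is reconstructed from the linked conversion script's logic and is an assumption, not code from this repo:

# Toy check that inv_permute inverts the HF conversion script's permutation.
import torch

n_head, n_embd = 4, 16       # toy sizes
headdim = n_embd // n_head

def hf_permute(w):
    # Reconstruction of `permute` from the HF conversion script, for square q/k projections.
    return w.view(n_head, headdim // 2, 2, n_embd).transpose(1, 2).reshape(n_embd, n_embd)

def inv_permute(w):
    return (
        w.reshape(n_head, 2, n_embd // n_head // 2, n_embd)
        .transpose(1, 2)
        .reshape(n_embd, n_embd)
    )

W = torch.randn(n_embd, n_embd)
assert torch.equal(inv_permute(hf_permute(W)), W)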
flash_attn/models/opt.py
View file @ f1a73d07
@@ -2,75 +2,86 @@
import math
import re
from collections import OrderedDict

import torch
import torch.nn.functional as F
from transformers import GPT2Config, OPTConfig


def remap_state_dict_hf_opt(state_dict, config):
    def key_mapping_model(key):
        key = re.sub(r"^model.decoder.", "transformer.", key)
        # The OPT-350m model uses '^decoder' instead of '^model.decoder'
        key = re.sub(r"^decoder.", "transformer.", key)
        return key

    state_dict = OrderedDict((key_mapping_model(k), v) for k, v in state_dict.items())

    # Word embedding and position embedding
    def key_mapping_emb(key):
        key = re.sub(r"^transformer.embed_tokens.", "transformer.embeddings.word_embeddings.", key)
        # The OPT-350m model uses has project_in and project_out
        key = re.sub(r"^transformer.project_in.", "transformer.embeddings.project_in.", key)
        key = re.sub(r"^transformer.project_out.", "project_out.", key)
        key = re.sub(
            r"^transformer.embed_positions.", "transformer.embeddings.position_embeddings.", key
        )
        return key

    state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
    # OPT uses the first 2 indices of pos_emb for padding tokens
    pos_embeddings = state_dict.pop("transformer.embeddings.position_embeddings.weight")
    state_dict["transformer.embeddings.position_embeddings.weight"] = pos_embeddings[2:]
    word_embeddings = state_dict.pop("transformer.embeddings.word_embeddings.weight")
    # It's possible that vocab_size is padded to be a multiple of 8, for example.
    pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
    vocab_size = math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
    state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad(
        word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
    )
    state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"]

    # LayerNorm
    def key_mapping_ln(key):
        key = re.sub(r"^transformer.final_layer_norm.", r"transformer.ln_f.", key)
        # The OPT-175B checkpoint calls this 'decoder.layer_norm' instead of 'decoder.final_layer_norm'
        key = re.sub(r"^transformer.layer_norm.", r"transformer.ln_f.", key)
        key = re.sub(
            r"^transformer.layers.(\d+).self_attn_layer_norm.", r"transformer.layers.\1.norm1.", key
        )
        key = re.sub(
            r"^transformer.layers.(\d+).final_layer_norm.", r"transformer.layers.\1.norm2.", key
        )
        return key

    state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())

    # MLP
    def key_mapping_mlp(key):
        return re.sub(
            r"^transformer.layers.(\d+).fc(1|2).", r"transformer.layers.\1.mlp.fc\2.", key
        )

    state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())

    # Attention
    for l in range(config.n_layer):
        Wq = state_dict.pop(f"transformer.layers.{l}.self_attn.q_proj.weight")
        Wk = state_dict.pop(f"transformer.layers.{l}.self_attn.k_proj.weight")
        Wv = state_dict.pop(f"transformer.layers.{l}.self_attn.v_proj.weight")
        bq = state_dict.pop(f"transformer.layers.{l}.self_attn.q_proj.bias")
        bk = state_dict.pop(f"transformer.layers.{l}.self_attn.k_proj.bias")
        bv = state_dict.pop(f"transformer.layers.{l}.self_attn.v_proj.bias")
        state_dict[f"transformer.layers.{l}.mixer.Wqkv.weight"] = torch.cat([Wq, Wk, Wv], dim=0)
        state_dict[f"transformer.layers.{l}.mixer.Wqkv.bias"] = torch.cat([bq, bk, bv], dim=0)

    def key_mapping_attn(key):
        return re.sub(
            r"^transformer.layers.(\d+).self_attn.out_proj.",
            r"transformer.layers.\1.mixer.out_proj.",
            key,
        )

    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
    return state_dict
@@ -79,8 +90,11 @@ def remap_state_dict_hf_opt(state_dict, config):
def opt_config_to_gpt2_config(opt_config: OPTConfig) -> GPT2Config:
    assert opt_config.layerdrop == 0.0
    assert opt_config.layer_norm_elementwise_affine
    word_embed_proj_dim = (
        None
        if opt_config.word_embed_proj_dim == opt_config.hidden_size
        else opt_config.word_embed_proj_dim
    )
    return GPT2Config(
        vocab_size=opt_config.vocab_size,
        n_positions=opt_config.max_position_embeddings,
@@ -98,5 +112,5 @@ def opt_config_to_gpt2_config(opt_config: OPTConfig) -> GPT2Config:
        eos_token_id=opt_config.eos_token_id,
        # These are new arguments not in the original GPT2Config
        prenorm=opt_config.do_layer_norm_before,
        word_embed_proj_dim=word_embed_proj_dim,
    )
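The remap drops the first two rows of OPT's learned position embeddings because, as the comment above notes, they are reserved for padding; HF's OPTLearnedPositionalEmbedding looks positions up with an offset of 2. A small sketch of the equivalence, with made-up sizes (illustrative only, not code from this commit):

# Why slicing pos_embeddings[2:] keeps lookups consistent with HF's offset-by-2 convention.
import torch

max_positions, dim = 8, 4
hf_pos_emb = torch.randn(max_positions + 2, dim)   # HF OPT stores 2 extra leading rows
flash_pos_emb = hf_pos_emb[2:]                     # what the remap keeps

position = 5
# HF looks up row (position + 2); after slicing, a plain position index hits the same row.
assert torch.equal(hf_pos_emb[position + 2], flash_pos_emb[position])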
flash_attn/models/vit.py
View file @ f1a73d07
@@ -10,13 +10,14 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from timm.models.helpers import named_apply
from torch.nn.init import trunc_normal_
from torchvision.ops import StochasticDepth

from flash_attn.layers.patch_embed import PatchEmbed
from flash_attn.modules.block import Block
from flash_attn.modules.mha import MHA
from flash_attn.modules.mlp import FusedMLP, Mlp

try:
    from flash_attn.ops.layer_norm import dropout_add_layer_norm
...
...
flash_attn/modules/block.py
View file @ f1a73d07
This diff is collapsed.