gaoqiong / flash-attention

Commit ff34123b, authored Jan 15, 2023 by Tri Dao
Reorder LN in Block, support OPT
parent f1e01c27

Showing 7 changed files with 345 additions and 79 deletions (+345 -79)
flash_attn/models/bert.py         +2    -1
flash_attn/models/gpt.py          +106  -56
flash_attn/models/opt.py          +104  -0
flash_attn/modules/block.py       +41   -19
flash_attn/modules/embedding.py   +14   -3
tests/models/test_gpt.py          +1    -0
tests/models/test_opt.py          +77   -0
flash_attn/models/bert.py

@@ -94,7 +94,8 @@ def create_block(config, layer_idx=None):
     mlp_cls = create_mlp_cls(config, layer_idx, return_residual=return_residual)
     norm_cls = partial(nn.LayerNorm, eps=config.layer_norm_eps)
     block = Block(config.hidden_size, mixer_cls, mlp_cls, norm_cls=norm_cls,
-                  prenorm=False, resid_dropout=config.hidden_dropout_prob,
+                  prenorm=False, resid_dropout1=config.hidden_dropout_prob,
+                  resid_dropout2=config.hidden_dropout_prob,
                   fused_dropout_add_ln=getattr(config, 'fused_dropout_add_ln', False),
                   return_residual=return_residual)
     return block
flash_attn/models/gpt.py

-# Copyright (c) 2022, Tri Dao.
+# Copyright (c) 2023, Tri Dao.

 import logging
 import math

@@ -23,6 +23,7 @@ from flash_attn.modules.embedding import GPT2Embeddings, ParallelGPT2Embeddings
 from flash_attn.utils.distributed import sync_shared_params, all_gather_raw
 from flash_attn.utils.pretrained import state_dict_from_pretrained
 from flash_attn.utils.generation import GenerationMixin
+from flash_attn.models.opt import remap_state_dict_opt

 try:
     from flash_attn.ops.fused_dense import ColumnParallelLinear
@@ -88,9 +89,12 @@ def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtyp
     if process_group is not None:
         assert fused_dense_gelu_dense, 'Tensor Parallel is only implemented for FusedDenseGeluDense'
     if not fused_dense_gelu_dense and not fused_dense_sqrelu_dense:
-        approximate = 'tanh' if config.activation_function in ['gelu_new', 'gelu_fast'] else 'none'
-        mlp_cls = partial(Mlp, hidden_features=inner_dim,
-                          activation=partial(F.gelu, approximate=approximate), **factory_kwargs)
+        if config.activation_function == 'relu':
+            activation = partial(F.relu, inplace=True)
+        else:
+            approximate = 'tanh' if config.activation_function in ['gelu_new', 'gelu_fast'] else 'none'
+            activation = partial(F.gelu, approximate=approximate)
+        mlp_cls = partial(Mlp, hidden_features=inner_dim, activation=activation, **factory_kwargs)
     else:
         mlp_checkpoint_lvl = getattr(config, 'mlp_checkpoint_lvl', 0)
         # mlp_checkpoint_lvl could be a list, which contains the checkpoint_lvl for each layer
@@ -121,9 +125,14 @@ def create_block(config, layer_idx=None, process_group=None, device=None, dtype=
     mixer_cls = create_mixer_cls(config, layer_idx, process_group=process_group, **factory_kwargs)
     mlp_cls = create_mlp_cls(config, layer_idx, process_group=process_group, **factory_kwargs)
     norm_cls = partial(nn.LayerNorm, eps=config.layer_norm_epsilon, **factory_kwargs)
     # TD [2022-07-30]: Force residual in fp32, seems to make fp16 training more stable
     residual_in_fp32 = getattr(config, 'residual_in_fp32', False)
+    resid_dropout1 = config.resid_pdrop if layer_idx is None or layer_idx > 0 else config.embd_pdrop
+    prenorm = getattr(config, 'prenorm', True)
     block = Block(config.hidden_size, mixer_cls, mlp_cls, norm_cls=norm_cls,
-                  prenorm=True, resid_dropout=config.resid_pdrop,
+                  prenorm=prenorm, resid_dropout1=resid_dropout1,
+                  resid_dropout2=config.resid_pdrop,
                   fused_dropout_add_ln=getattr(config, 'fused_dropout_add_ln', False),
                   residual_in_fp32=residual_in_fp32,
                   sequence_parallel=sequence_parallel and process_group is not None,
                   mark_shared_params=process_group is not None)
     block.layer_idx = layer_idx
@@ -154,11 +163,16 @@ class GPTPreTrainedModel(nn.Module):
         """
         # Instantiate model.
         model = cls(config, *args, device=device, dtype=dtype, **kwargs)
-        state_dict = remap_state_dict_gpt2(
-            # If we're going to shard the model, then don't load fp32 weights to GPU.
-            state_dict_from_pretrained(model_name, device=device if world_size == 1 else None,
-                                       dtype=dtype),
-            config
-        )
+        state_dict = state_dict_from_pretrained(model_name, device=device if world_size == 1 else None,
+                                                dtype=dtype)
+        if model_name.startswith('gpt2'):
+            state_dict = remap_state_dict_gpt2(state_dict, config)
+        elif model_name.startswith('facebook/opt'):
+            state_dict = remap_state_dict_opt(state_dict, config)
+        else:
+            raise NotImplementedError(f'Model {model_name} not supported')
         if world_size > 1:
             state_dict = shard_state_dict_tp(state_dict, config, world_size, rank)
             state_dict = {k: v.to(device=device) for k, v in state_dict.items()}
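With this dispatch in place, an OPT checkpoint goes through the same from_pretrained path as GPT-2. A minimal usage sketch mirroring the new tests/models/test_opt.py added in this commit (assumes a CUDA device and that the HF checkpoints are available):

import torch
from transformers import OPTConfig

from flash_attn.models.gpt import GPTLMHeadModel
from flash_attn.models.opt import opt_config_to_gpt2_config

# Translate the HF OPT config into the GPT2Config that GPTLMHeadModel expects.
config = opt_config_to_gpt2_config(OPTConfig.from_pretrained('facebook/opt-125m'))
# remap_state_dict_opt is selected automatically because the name starts with 'facebook/opt'.
model = GPTLMHeadModel.from_pretrained('facebook/opt-125m', config,
                                       device='cuda', dtype=torch.float16)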
@@ -166,6 +180,7 @@ class GPTPreTrainedModel(nn.Module):
         logger.info(load_return)
         return model

 # https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454
 def _init_weights(module, n_layer, initializer_range=0.02, rescale_prenorm_residual=True):
     if isinstance(module, nn.Linear):
@@ -195,47 +210,53 @@ class GPTModel(GPTPreTrainedModel):
         factory_kwargs = {'device': device, 'dtype': dtype}
         self.process_group = process_group
         self.sequence_parallel = getattr(config, 'sequence_parallel', True)
-        assert config.activation_function in ['gelu', 'gelu_new', 'gelu_fast', 'sqrelu']
+        assert config.activation_function in ['gelu', 'gelu_new', 'gelu_fast', 'relu', 'sqrelu']
         pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
         vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple)
                       * pad_vocab_size_multiple)
         # TD [2022-07-30]: Force residual in fp32, seems to make fp16 training more stable
         self.residual_in_fp32 = getattr(config, 'residual_in_fp32', False)
+        # These 2 options are for OPT-350m
+        self.prenorm = getattr(config, 'prenorm', True)
+        word_embed_proj_dim = getattr(config, 'word_embed_proj_dim', None)

         if process_group is None:
-            self.embeddings = GPT2Embeddings(config.hidden_size, vocab_size,
-                                             config.max_position_embeddings, **factory_kwargs)
+            self.embeddings = GPT2Embeddings(config.hidden_size, vocab_size,
+                                             config.max_position_embeddings,
+                                             word_embed_proj_dim=word_embed_proj_dim,
+                                             **factory_kwargs)
         else:
             self.embeddings = ParallelGPT2Embeddings(
                 config.hidden_size, vocab_size, config.max_position_embeddings,
                 process_group=process_group, sequence_parallel=self.sequence_parallel,
                 **factory_kwargs
             )
-        self.emb_drop = nn.Dropout(config.embd_pdrop)

-        # We change the order of residual and layer norm:
+        # We change the order of dropout, residual and layer norm:
         # Instead of LN -> Attn / MLP -> Dropout -> Add, we do:
-        # Attn / MLP -> Dropout -> Add -> LN, returning both the residual branch (output of Add) and
-        # the main branch (output of LN). The model definition is unchanged, but the mapping of the
-        # nn.LayerNorm weights are changed.
+        # Dropout -> Add -> LN -> Attn / MLP, returning both the residual branch (output of Add) and
+        # the main branch (output of MLP). The model definition is unchanged, but the mapping of the
+        # nn.Dropout probabilities are changed.
         # This is for performance reason: we can fuse dropout + add + layer_norm.
+        self.layers = nn.ModuleList([create_block(config, layer_idx=i, process_group=process_group,
+                                                  **factory_kwargs)
+                                     for i in range(config.num_hidden_layers)])

         self.fused_dropout_add_ln = getattr(config, 'fused_dropout_add_ln', False)
         if self.fused_dropout_add_ln and dropout_add_layer_norm is None:
             raise ImportError('dropout_add_layer_norm is not installed')
-        # self.ln_0 is the first layer norm in the model, while self.ln_f (in the pretrained weight)
-        # is the final layer norm.
-        self.ln_0 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon,
-                                 **factory_kwargs)
+        if self.prenorm:
+            self.drop_f = nn.Dropout(config.resid_pdrop)
+            self.ln_f = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon,
+                                     **factory_kwargs)
         if process_group is not None:
-            for p in self.ln_0.parameters():
+            for p in self.ln_f.parameters():
                 # Mark the norm parameters as "shared_params" so that we sync their values at init.
                 p._shared_params = True
                 # Mark the norm params as "sequence_parallel" so we run all-reduce on their grads.
                 if self.sequence_parallel:
                     p._sequence_parallel = True
-        self.layers = nn.ModuleList([create_block(config, layer_idx=i, process_group=process_group,
-                                                  **factory_kwargs)
-                                     for i in range(config.num_hidden_layers)])

         self.apply(partial(_init_weights, n_layer=config.num_hidden_layers,
                            initializer_range=config.initializer_range))
         self.tie_weights()
@@ -251,23 +272,28 @@ class GPTModel(GPTPreTrainedModel):
         embedding_kwargs = ({'combine_batch_seqlen_dim': True}
                             if self.process_group is not None and self.sequence_parallel else {})
         hidden_states = self.embeddings(input_ids, position_ids=position_ids, **embedding_kwargs)
-        # TD [2022-07-30]: Force residual in fp32, seems to make fp16 training more stable
-        if not self.fused_dropout_add_ln:
-            residual = self.emb_drop(hidden_states)
-            hidden_states = self.ln_0(residual.to(dtype=self.ln_0.weight.dtype))
-            residual = residual.float()
-        else:
-            hidden_states, residual = dropout_add_layer_norm(
-                hidden_states, None, self.ln_0.weight, self.ln_0.bias,
-                self.emb_drop.p if self.training else 0.0, self.ln_0.eps, prenorm=True,
-                residual_in_fp32=True
-            )
+        residual = None
         mixer_kwargs = ({'seqlen': input_ids.shape[1]}
                         if self.process_group is not None and self.sequence_parallel else {})
         if inference_params is not None:
             mixer_kwargs['inference_params'] = inference_params
         for layer in self.layers:
-            hidden_states, residual = layer(hidden_states, residual, mixer_kwargs=mixer_kwargs)
+            if self.prenorm:
+                hidden_states, residual = layer(hidden_states, residual, mixer_kwargs=mixer_kwargs)
+            else:
+                hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
+        if self.prenorm:
+            if not self.fused_dropout_add_ln:
+                dropped = self.drop_f(hidden_states)
+                residual = (dropped + residual) if residual is not None else dropped
+                hidden_states = self.ln_f(residual.to(dtype=self.ln_f.weight.dtype))
+            else:
+                # Set prenorm=False here since we don't need the residual
+                hidden_states = dropout_add_layer_norm(
+                    hidden_states, residual, self.ln_f.weight, self.ln_f.bias,
+                    self.drop_f.p if self.training else 0.0, self.ln_f.eps, prenorm=False,
+                    residual_in_fp32=self.residual_in_fp32
+                )
         return hidden_states
@@ -281,13 +307,20 @@ class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):
         pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
         vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple)
                       * pad_vocab_size_multiple)
+        # This option is for OPT-350m
+        word_embed_proj_dim = getattr(config, 'word_embed_proj_dim', None)
+        embed_dim = config.n_embd if word_embed_proj_dim is None else word_embed_proj_dim
+        if word_embed_proj_dim is not None:
+            self.project_out = nn.Linear(config.n_embd, embed_dim, bias=False, **factory_kwargs)
+        else:
+            self.project_out = None
         if process_group is None:
-            self.lm_head = nn.Linear(config.n_embd, vocab_size, bias=False, **factory_kwargs)
+            self.lm_head = nn.Linear(embed_dim, vocab_size, bias=False, **factory_kwargs)
         else:
             if ColumnParallelLinear is None:
                 raise ImportError('fused_dense_lib is not installed')
-            self.lm_head = ColumnParallelLinear(config.n_embd, vocab_size, process_group, bias=False,
+            self.lm_head = ColumnParallelLinear(embed_dim, vocab_size, process_group, bias=False,
                                                 sequence_parallel=getattr(config, 'sequence_parallel', True),
                                                 **factory_kwargs)
         # Initialize weights and apply final processing
@@ -307,6 +340,8 @@ class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):
         """
         hidden_states = self.transformer(input_ids, position_ids=position_ids,
                                          inference_params=inference_params)
+        if self.project_out is not None:
+            hidden_states = self.project_out(hidden_states)
         lm_logits = self.lm_head(hidden_states)
         # During inference, we want the full logit for sampling
         if isinstance(self.lm_head, ColumnParallelLinear) and inference_params is not None:
@@ -315,6 +350,32 @@ class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):
         CausalLMOutput = namedtuple('CausalLMOutput', ['logits'])
         return CausalLMOutput(logits=lm_logits)

+    def load_state_dict(self, state_dict, strict=True):
+        # Remapping from our checkpoints that used a different ordering of layers in the block
+        # Previous: Attn / MLP -> Dropout -> Add -> LN
+        # Current: Dropout -> Add -> LN -> Attn / MLP
+        if 'transformer.ln_0.weight' in state_dict:
+            n_layers = self.config.num_hidden_layers
+            ln_weight = state_dict.pop(f'transformer.layers.{n_layers - 1}.norm2.weight')
+            ln_bias = state_dict.pop(f'transformer.layers.{n_layers - 1}.norm2.bias')
+            state_dict['transformer.ln_f.weight'] = ln_weight
+            state_dict['transformer.ln_f.bias'] = ln_bias
+            for l in reversed(range(n_layers)):
+                ln_weight = state_dict.pop(f'transformer.layers.{l}.norm1.weight')
+                ln_bias = state_dict.pop(f'transformer.layers.{l}.norm1.bias')
+                state_dict[f'transformer.layers.{l}.norm2.weight'] = ln_weight
+                state_dict[f'transformer.layers.{l}.norm2.bias'] = ln_bias
+                if l > 0:
+                    ln_weight = state_dict.pop(f'transformer.layers.{l - 1}.norm2.weight')
+                    ln_bias = state_dict.pop(f'transformer.layers.{l - 1}.norm2.bias')
+                    state_dict[f'transformer.layers.{l}.norm1.weight'] = ln_weight
+                    state_dict[f'transformer.layers.{l}.norm1.bias'] = ln_bias
+            ln_weight = state_dict.pop('transformer.ln_0.weight')
+            ln_bias = state_dict.pop('transformer.ln_0.bias')
+            state_dict[f'transformer.layers.0.norm1.weight'] = ln_weight
+            state_dict[f'transformer.layers.0.norm1.bias'] = ln_bias
+        return super().load_state_dict(state_dict, strict=strict)
+

 def remap_state_dict_gpt2(state_dict, config):
     # Word embedding and position embedding
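As a sanity check on the remapping above, here is an illustrative sketch (not part of the commit) that applies the same key shifts to a toy 2-layer checkpoint saved with the old ordering; only the weights are shown, biases follow the same pattern:

import torch

n_layers = 2
# Toy checkpoint in the old naming; each value records which norm the weight originally was.
sd = {'transformer.ln_0.weight': torch.tensor(0.)}
sd.update({f'transformer.layers.{i}.norm{j}.weight': torch.tensor(float(10 * i + j))
           for i in range(n_layers) for j in (1, 2)})

# Same shifting as GPTLMHeadModel.load_state_dict above.
sd['transformer.ln_f.weight'] = sd.pop(f'transformer.layers.{n_layers - 1}.norm2.weight')
for l in reversed(range(n_layers)):
    sd[f'transformer.layers.{l}.norm2.weight'] = sd.pop(f'transformer.layers.{l}.norm1.weight')
    if l > 0:
        sd[f'transformer.layers.{l}.norm1.weight'] = sd.pop(f'transformer.layers.{l - 1}.norm2.weight')
sd['transformer.layers.0.norm1.weight'] = sd.pop('transformer.ln_0.weight')

# Each norm slides one slot later in the new naming:
# ln_0 -> layers.0.norm1, layers.i.norm1 -> layers.i.norm2,
# layers.i.norm2 -> layers.{i+1}.norm1, layers.{n-1}.norm2 -> ln_f.
for k, v in sorted(sd.items()):
    print(k, v.item())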
@@ -331,22 +392,11 @@ def remap_state_dict_gpt2(state_dict, config):
         state_dict['lm_head.weight'] = state_dict['transformer.embeddings.word_embeddings.weight']

     # LayerNorm
-    ln_weight, ln_bias = state_dict.pop('ln_f.weight'), state_dict.pop('ln_f.bias')
-    state_dict[f'transformer.layers.{config.num_hidden_layers - 1}.norm2.weight'] = ln_weight
-    state_dict[f'transformer.layers.{config.num_hidden_layers - 1}.norm2.bias'] = ln_bias
-    ln_weight, ln_bias = state_dict.pop('h.0.ln_1.weight'), state_dict.pop('h.0.ln_1.bias')
-    state_dict['transformer.ln_0.weight'] = ln_weight
-    state_dict['transformer.ln_0.bias'] = ln_bias
-    for d in range(config.num_hidden_layers):
-        ln_weight = state_dict.pop(f'h.{d}.ln_2.weight')
-        ln_bias = state_dict.pop(f'h.{d}.ln_2.bias')
-        state_dict[f'transformer.layers.{d}.norm1.weight'] = ln_weight
-        state_dict[f'transformer.layers.{d}.norm1.bias'] = ln_bias
-        if d > 0:
-            ln_weight = state_dict.pop(f'h.{d}.ln_1.weight')
-            ln_bias = state_dict.pop(f'h.{d}.ln_1.bias')
-            state_dict[f'transformer.layers.{d - 1}.norm2.weight'] = ln_weight
-            state_dict[f'transformer.layers.{d - 1}.norm2.bias'] = ln_bias
+    def key_mapping_ln(key):
+        key = re.sub(r'^ln_f.(weight|bias)', r'transformer.ln_f.\1', key)
+        key = re.sub(r'^h.(\d+).ln_(1|2).(weight|bias)', r'transformer.layers.\1.norm\2.\3', key)
+        return key
+    state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())

     # MLP
     for d in range(config.num_hidden_layers):
flash_attn/models/opt.py  (new file, mode 100644)

# Copyright (c) 2023, Tri Dao.
import math
import re
from collections import OrderedDict

import torch
import torch.nn.functional as F

from transformers import GPT2Config, OPTConfig


def remap_state_dict_opt(state_dict, config):
    def key_mapping_model(key):
        key = re.sub(r'^model.decoder.', 'transformer.', key)
        # The OPT-350m model uses '^decoder' instead of '^model.decoder'
        key = re.sub(r'^decoder.', 'transformer.', key)
        return key
    state_dict = OrderedDict((key_mapping_model(k), v) for k, v in state_dict.items())

    # Word embedding and position embedding
    def key_mapping_emb(key):
        key = re.sub(r'^transformer.embed_tokens.', 'transformer.embeddings.word_embeddings.', key)
        # The OPT-350m model has project_in and project_out
        key = re.sub(r'^transformer.project_in.', 'transformer.embeddings.project_in.', key)
        key = re.sub(r'^transformer.project_out.', 'project_out.', key)
        key = re.sub(r'^transformer.embed_positions.',
                     'transformer.embeddings.position_embeddings.', key)
        return key
    state_dict = OrderedDict((key_mapping_emb(k), v) for k, v in state_dict.items())
    # OPT uses the first 2 indices of pos_emb for padding tokens
    pos_embeddings = state_dict.pop('transformer.embeddings.position_embeddings.weight')
    state_dict['transformer.embeddings.position_embeddings.weight'] = pos_embeddings[2:]
    word_embeddings = state_dict.pop('transformer.embeddings.word_embeddings.weight')
    # It's possible that vocab_size is padded to be a multiple of 8, for example.
    pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
    vocab_size = (math.ceil(config.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple)
    state_dict['transformer.embeddings.word_embeddings.weight'] = F.pad(
        word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
    )
    state_dict['lm_head.weight'] = state_dict['transformer.embeddings.word_embeddings.weight']

    # LayerNorm
    def key_mapping_ln(key):
        key = re.sub(r'^transformer.final_layer_norm.', r'transformer.ln_f.', key)
        key = re.sub(r'^transformer.layers.(\d+).self_attn_layer_norm.',
                     r'transformer.layers.\1.norm1.', key)
        key = re.sub(r'^transformer.layers.(\d+).final_layer_norm.',
                     r'transformer.layers.\1.norm2.', key)
        return key
    state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())

    # MLP
    def key_mapping_mlp(key):
        return re.sub(r'^transformer.layers.(\d+).fc(1|2).',
                      r'transformer.layers.\1.mlp.fc\2.', key)
    state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())

    # Attention
    for l in range(config.n_layer):
        Wq = state_dict.pop(f'transformer.layers.{l}.self_attn.q_proj.weight')
        Wk = state_dict.pop(f'transformer.layers.{l}.self_attn.k_proj.weight')
        Wv = state_dict.pop(f'transformer.layers.{l}.self_attn.v_proj.weight')
        bq = state_dict.pop(f'transformer.layers.{l}.self_attn.q_proj.bias')
        bk = state_dict.pop(f'transformer.layers.{l}.self_attn.k_proj.bias')
        bv = state_dict.pop(f'transformer.layers.{l}.self_attn.v_proj.bias')
        state_dict[f'transformer.layers.{l}.mixer.Wqkv.weight'] = torch.cat([Wq, Wk, Wv], dim=0)
        state_dict[f'transformer.layers.{l}.mixer.Wqkv.bias'] = torch.cat([bq, bk, bv], dim=0)
    def key_mapping_attn(key):
        return re.sub(r'^transformer.layers.(\d+).self_attn.out_proj.',
                      r'transformer.layers.\1.mixer.out_proj.', key)
    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())

    return state_dict


def opt_config_to_gpt2_config(opt_config: OPTConfig) -> GPT2Config:
    assert opt_config.layerdrop == 0.0
    assert opt_config.layer_norm_elementwise_affine
    word_embed_proj_dim = (None if opt_config.word_embed_proj_dim == opt_config.hidden_size
                           else opt_config.word_embed_proj_dim)
    return GPT2Config(
        vocab_size=opt_config.vocab_size,
        n_positions=opt_config.max_position_embeddings,
        n_embd=opt_config.hidden_size,
        n_layer=opt_config.num_hidden_layers,
        n_head=opt_config.num_attention_heads,
        n_inner=opt_config.ffn_dim,
        activation_function=opt_config.activation_function,
        resid_pdrop=opt_config.dropout,
        # HF's implementation of OPT doesn't seem to have embedding dropout
        embd_pdrop=opt_config.dropout,
        attn_pdrop=opt_config.attention_dropout,
        initializer_range=opt_config.init_std,
        bos_token_id=opt_config.bos_token_id,
        eos_token_id=opt_config.eos_token_id,
        # These are new arguments not in the original GPT2Config
        prenorm=opt_config.do_layer_norm_before,
        word_embed_proj_dim=word_embed_proj_dim
    )
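The attention section of remap_state_dict_opt packs OPT's separate q/k/v projections into the single Wqkv parameter that the MHA mixer expects. A minimal sketch of why that packing is equivalent, using dummy tensors (the hidden size here is illustrative only):

import torch

hidden = 8  # illustration; real OPT models use 768, 1024, 2048, ...
Wq, Wk, Wv = (torch.randn(hidden, hidden) for _ in range(3))
bq, bk, bv = (torch.randn(hidden) for _ in range(3))

# Same packing as above: the rows of the q, k, v projections stacked into one matrix.
Wqkv = torch.cat([Wq, Wk, Wv], dim=0)   # (3 * hidden, hidden)
bqkv = torch.cat([bq, bk, bv], dim=0)   # (3 * hidden,)

x = torch.randn(4, hidden)
# The packed projection equals the three separate projections concatenated feature-wise.
assert torch.allclose(x @ Wqkv.T + bqkv,
                      torch.cat([x @ Wq.T + bq, x @ Wk.T + bk, x @ Wv.T + bv], dim=-1))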
flash_attn/modules/block.py

@@ -22,10 +22,22 @@ except ImportError:

 class Block(nn.Module):

     def __init__(self, dim, mixer_cls=None, mlp_cls=None, norm_cls=nn.LayerNorm,
-                 dropout_cls=nn.Dropout, prenorm=True, resid_dropout=0., drop_path=0.,
-                 fused_dropout_add_ln=False, return_residual=False,
-                 sequence_parallel=False, mark_shared_params=False):
+                 dropout_cls=nn.Dropout, prenorm=True, resid_dropout1=0., resid_dropout2=0.,
+                 drop_path1=0., drop_path2=0., fused_dropout_add_ln=False, return_residual=False,
+                 residual_in_fp32=False, sequence_parallel=False, mark_shared_params=False):
         """
+        For prenorm=True, this Block has a slightly different structure compared to a regular
+        prenorm Transformer block.
+        The standard block is: LN -> MHA -> Dropout -> Add -> LN -> MLP -> Dropout -> Add.
+        [Ref: https://arxiv.org/abs/2002.04745]
+        Here we have: Dropout -> Add -> LN -> MHA -> Dropout -> Add -> LN -> MLP, returning both
+        the hidden_states (output of the MLP) and the residual.
+        This is for performance reasons, as we can fuse the dropout, add and LayerNorm.
+        The residual needs to be provided (except for the very first block).
+
+        For prenorm=False, this Block has the same structure as a regular postnorm Transformer
+        block: MHA -> Dropout -> Add -> LN -> MLP -> Dropout -> Add -> LN.
+
         return_residual: whether each of the sub-layers (mixer and mlp) will return the residual.
         This is for performance reason: for post-norm architecture, returning the input allows us
         to fuse the backward of nn.Linear with the residual connection.
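In plain PyTorch, ignoring the fused kernel, drop path, and the dtype/fp32-residual handling, the prenorm=True data flow described by this docstring is roughly the following sketch (norm1, dropout1, mixer, etc. refer to the Block's own submodules):

def prenorm_block_sketch(block, hidden_states, residual=None):
    # Dropout -> Add -> LN -> MHA
    dropped = block.dropout1(hidden_states)
    residual = dropped + residual if residual is not None else dropped
    hidden_states = block.norm1(residual)
    hidden_states = block.mixer(hidden_states)
    # Dropout -> Add -> LN -> MLP
    dropped = block.dropout2(hidden_states)
    residual = dropped + residual
    hidden_states = block.norm2(residual)
    hidden_states = block.mlp(hidden_states)
    # Both branches are returned so the next Block can continue the residual stream.
    return hidden_states, residual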
@@ -34,18 +46,21 @@ class Block(nn.Module):
         self.prenorm = prenorm
         self.fused_dropout_add_ln = fused_dropout_add_ln
         self.return_residual = return_residual
+        self.residual_in_fp32 = residual_in_fp32
+        if self.residual_in_fp32:
+            assert self.prenorm, 'residual_in_fp32 is only compatible with prenorm=True'
         if mixer_cls is None:
             mixer_cls = partial(MHA, num_heads=dim // 64)
         if mlp_cls is None:
             mlp_cls = partial(Mlp, hidden_features=4 * dim)
         self.mixer = mixer_cls(dim)
-        self.dropout1 = dropout_cls(resid_dropout)
-        self.drop_path1 = StochasticDepth(drop_path, mode='row')
+        self.dropout1 = dropout_cls(resid_dropout1)
+        self.drop_path1 = StochasticDepth(drop_path1, mode='row')
         self.norm1 = norm_cls(dim)
         self.mlp = mlp_cls(dim)
         if not isinstance(self.mlp, nn.Identity):
-            self.dropout2 = dropout_cls(resid_dropout)
-            self.drop_path2 = StochasticDepth(drop_path, mode='row')
+            self.dropout2 = dropout_cls(resid_dropout2)
+            self.drop_path2 = StochasticDepth(drop_path2, mode='row')
             self.norm2 = norm_cls(dim)

         if self.fused_dropout_add_ln:
@@ -82,41 +97,48 @@ class Block(nn.Module):
             residual: if postnorm, residual=None, If prenorm, hidden_states = LayerNorm(residual)
         """
         if self.prenorm:
-            assert residual is not None
-            mixer_out = self.mixer(hidden_states,
-                                   **(mixer_kwargs if mixer_kwargs is not None else {}))
             if not self.fused_dropout_add_ln:
-                residual = self.drop_path1(self.dropout1(mixer_out)) + residual
+                dropped = self.drop_path1(self.dropout1(hidden_states))
+                residual = (dropped + residual) if residual is not None else dropped
                 hidden_states = self.norm1(residual.to(dtype=self.norm1.weight.dtype))
+                if self.residual_in_fp32:
+                    residual = residual.to(torch.float32)
             else:
                 if self.drop_path1.p == 0 or not self.training:
                     rowscale1 = None
                 else:
                     rowscale1 = self.drop_path1(
-                        torch.ones(mixer_out.shape[:-1], device=mixer_out.device,
-                                   dtype=mixer_out.dtype)
+                        torch.ones(hidden_states.shape[:-1], device=hidden_states.device,
+                                   dtype=hidden_states.dtype)
                     )
                 hidden_states, residual = dropout_add_layer_norm(
-                    mixer_out, residual, self.norm1.weight, self.norm1.bias,
+                    hidden_states, residual, self.norm1.weight, self.norm1.bias,
                     self.dropout1.p if self.training else 0.0, self.norm1.eps,
-                    rowscale=rowscale1, prenorm=True
+                    rowscale=rowscale1, prenorm=True, residual_in_fp32=self.residual_in_fp32
                 )
+            hidden_states = self.mixer(hidden_states,
+                                       **(mixer_kwargs if mixer_kwargs is not None else {}))
             if not isinstance(self.mlp, nn.Identity):
-                mlp_out = self.mlp(hidden_states)
                 if not self.fused_dropout_add_ln:
-                    residual = self.drop_path2(self.dropout2(mlp_out)) + residual
+                    dropped = self.drop_path2(self.dropout2(hidden_states))
+                    residual = (dropped + residual) if residual is not None else dropped
                     hidden_states = self.norm2(residual.to(dtype=self.norm2.weight.dtype))
+                    if self.residual_in_fp32:
+                        residual = residual.to(torch.float32)
                 else:
                     if self.drop_path2.p == 0 or not self.training:
                         rowscale2 = None
                     else:
                         rowscale2 = self.drop_path2(
-                            torch.ones(mlp_out.shape[:-1], device=mlp_out.device,
-                                       dtype=mlp_out.dtype)
+                            torch.ones(hidden_states.shape[:-1], device=hidden_states.device,
+                                       dtype=hidden_states.dtype)
                         )
                     hidden_states, residual = dropout_add_layer_norm(
-                        mlp_out, residual, self.norm2.weight, self.norm2.bias,
+                        hidden_states, residual, self.norm2.weight, self.norm2.bias,
                         self.dropout2.p if self.training else 0.0, self.norm2.eps,
-                        rowscale=rowscale2, prenorm=True
+                        rowscale=rowscale2, prenorm=True, residual_in_fp32=self.residual_in_fp32
                     )
+                hidden_states = self.mlp(hidden_states)
             return hidden_states, residual
         else:
             assert residual is None
flash_attn/modules/embedding.py

@@ -12,14 +12,23 @@ from flash_attn.utils.distributed import reduce_scatter, all_reduce
 class GPT2Embeddings(nn.Module):

     def __init__(self, embed_dim, vocab_size, max_position_embeddings, padding_idx=None,
-                 device=None, dtype=None):
+                 word_embed_proj_dim=None, device=None, dtype=None):
         """
             If max_position_embeddings <= 0, there's no position embeddings
+            If word_embed_proj_dim is not None (e.g., OPT-350m), we embed to that dimension
+                then project up to embed_dim
         """
         factory_kwargs = {'device': device, 'dtype': dtype}
         super().__init__()
-        self.word_embeddings = nn.Embedding(vocab_size, embed_dim, padding_idx=padding_idx,
-                                            **factory_kwargs)
+        if word_embed_proj_dim is None:
+            self.word_embeddings = nn.Embedding(vocab_size, embed_dim, padding_idx=padding_idx,
+                                                **factory_kwargs)
+            self.project_in = None
+        else:
+            self.word_embeddings = nn.Embedding(vocab_size, word_embed_proj_dim,
+                                                padding_idx=padding_idx, **factory_kwargs)
+            self.project_in = nn.Linear(word_embed_proj_dim, embed_dim, bias=False,
+                                        **factory_kwargs)
         self.max_position_embeddings = max_position_embeddings
         if self.max_position_embeddings > 0:
             self.position_embeddings = nn.Embedding(max_position_embeddings, embed_dim,
@@ -32,6 +41,8 @@ class GPT2Embeddings(nn.Module):
         """
         batch_size, seqlen = input_ids.shape
         embeddings = self.word_embeddings(input_ids)
+        if self.project_in is not None:
+            embeddings = self.project_in(embeddings)
         if self.max_position_embeddings > 0:
             if position_ids is None:
                 position_ids = torch.arange(seqlen, dtype=torch.long, device=input_ids.device)
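For OPT-350m the HF config reportedly uses hidden_size 1024 with word_embed_proj_dim 512; treating those numbers as illustrative only, the shape flow through the new code path looks like this sketch:

import torch
from flash_attn.modules.embedding import GPT2Embeddings

# Tokens are embedded at 512 dims, then project_in maps them up to the model width of 1024.
emb = GPT2Embeddings(1024, 50272, 2048, word_embed_proj_dim=512)
input_ids = torch.randint(0, 50272, (2, 16))
print(emb.word_embeddings.weight.shape)  # torch.Size([50272, 512])
print(emb.project_in.weight.shape)       # torch.Size([1024, 512])
print(emb(input_ids).shape)              # torch.Size([2, 16, 1024])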
tests/models/test_gpt.py

@@ -84,6 +84,7 @@ def test_gpt2_optimized(model_name):
     config.fused_bias_fc = True
     config.fused_dense_gelu_dense = True
     config.fused_dropout_add_ln = True
+    config.residual_in_fp32 = True
     config.pad_vocab_size_multiple = 8
     model = GPTLMHeadModel.from_pretrained(model_name, config)
tests/models/test_opt.py  (new file, mode 100644)

import re

import torch
import pytest

from transformers import OPTConfig
from transformers.models.opt.modeling_opt import OPTForCausalLM

from flash_attn.models.gpt import GPTLMHeadModel
from flash_attn.models.opt import remap_state_dict_opt, opt_config_to_gpt2_config
from flash_attn.utils.pretrained import state_dict_from_pretrained


@pytest.mark.parametrize('model_name', ["facebook/opt-125m", "facebook/opt-350m", "facebook/opt-1.3b"])
# @pytest.mark.parametrize('model_name', ["facebook/opt-350m"])
def test_opt_state_dict(model_name):
    config = opt_config_to_gpt2_config(OPTConfig.from_pretrained(model_name))
    pretrained_state_dict = remap_state_dict_opt(state_dict_from_pretrained(model_name), config)
    model = GPTLMHeadModel(config)
    state_dict = model.state_dict()
    assert state_dict.keys() == pretrained_state_dict.keys()
    for k in state_dict.keys():
        assert state_dict[k].shape == pretrained_state_dict[k].shape


@pytest.mark.parametrize('model_name', ["facebook/opt-125m", "facebook/opt-350m", "facebook/opt-1.3b"])
# @pytest.mark.parametrize('model_name', ["facebook/opt-350m"])
def test_opt_optimized(model_name):
    """Check that our implementation of OPT (with the optimizations enabled) matches the
    HF implementation: the output of our forward pass in fp16 should be around the same as the HF
    forward pass in fp16, when compared to the HF forward pass in fp32.
    """
    dtype = torch.float16
    device = 'cuda'
    config = opt_config_to_gpt2_config(OPTConfig.from_pretrained(model_name))
    config.use_flash_attn = True
    config.fused_bias_fc = True
    config.fused_dropout_add_ln = True
    # Only prenorm supports residual_in_fp32
    config.residual_in_fp32 = getattr(config, 'prenorm', True)
    config.pad_vocab_size_multiple = 8

    model = GPTLMHeadModel.from_pretrained(model_name, config, device=device, dtype=dtype)

    model_ref = OPTForCausalLM.from_pretrained(model_name).to(device=device)
    model_hf = OPTForCausalLM.from_pretrained(model_name, torch_dtype=dtype).to(device=device)

    model.eval()
    model_ref.eval()
    model_hf.eval()

    torch.manual_seed(0)
    batch_size = 2
    max_seqlen = 256
    seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device='cuda')
    input_ids = torch.randint(0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long,
                              device='cuda')
    if model_name != 'facebook/opt-350m':  # The OPT-350m projects the embeddings to dimension 512
        out = model.transformer(input_ids)
        out_hf = model_hf.model(input_ids).last_hidden_state
        out_ref = model_ref.model(input_ids).last_hidden_state
        print(f'Output max diff: {(out - out_ref).abs().max().item()}')
        print(f'Output mean diff: {(out - out_ref).abs().mean().item()}')
        print(f'HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}')
        print(f'HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}')
        assert (out - out_ref).abs().max().item() < 3 * (out_hf - out_ref).abs().max().item()

    logits = model(input_ids).logits
    logits_hf = model_hf(input_ids).logits
    logits_ref = model_ref(input_ids).logits
    print(f'Logits max diff: {(logits - logits_ref).abs().max().item()}')
    print(f'Logits mean diff: {(logits - logits_ref).abs().mean().item()}')
    print(f'HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}')
    print(f'HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}')
    assert (logits - logits_ref).abs().max().item() < 3 * (logits_hf - logits_ref).abs().max().item()