gaoqiong / flash-attention / Commits / 4c91621a

Unverified commit 4c91621a, authored Sep 09, 2023 by Kevin Hu, committed by GitHub on Sep 09, 2023

Inverse state dict for BERT (#527)

Parent: a86442f0

Showing 3 changed files with 171 additions and 16 deletions (+171 -16)
.gitignore                   +3    -0
flash_attn/models/bert.py    +142  -12
tests/models/test_bert.py    +26   -4
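Taken together, the commit adds inv_remap_state_dict to flash_attn/models/bert.py as the inverse of the existing remap_state_dict, plus a round-trip test. The sketch below shows how the pair is meant to be used; it is not part of the commit, and the checkpoint name is only an example.

# Round-trip sketch: Huggingface BERT checkpoint -> flash_attn layout -> back to HF key names.
from transformers import BertConfig

from flash_attn.models.bert import inv_remap_state_dict, remap_state_dict
from flash_attn.utils.pretrained import state_dict_from_pretrained

model_name = "bert-base-uncased"  # illustrative checkpoint
config = BertConfig.from_pretrained(model_name)
hf_state_dict = state_dict_from_pretrained(model_name)

# Existing direction: rename keys and pack Q/K/V into the layout flash_attn expects.
flash_state_dict = remap_state_dict(hf_state_dict, config)

# New direction added by this commit: map the flash_attn layout back to HF-compatible keys.
recovered_state_dict = inv_remap_state_dict(flash_state_dict, config)
print(list(recovered_state_dict)[:3])  # a few recovered HF-style key names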
.gitignore (view file @ 4c91621a)

@@ -22,3 +22,6 @@ var/

Three lines are appended at the end of the file, after the existing IDE section:

# IDE-related
.idea/

# Dev
venv
\ No newline at end of file
flash_attn/models/bert.py (view file @ 4c91621a)

@@ -10,23 +10,19 @@ import re

The import block is reworked: PretrainedConfig is now imported alongside BertConfig, typing's Any and Mapping are brought in, a trailing comma is added after BertForPreTrainingOutput, and the flash_attn.bert_padding import is reflowed. The block now reads:

from collections import OrderedDict
from collections.abc import Sequence
from functools import partial
from typing import Any, Mapping

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from transformers import BertConfig, PretrainedConfig
from transformers.models.bert.modeling_bert import (
    BaseModelOutputWithPoolingAndCrossAttentions,
    BertForPreTrainingOutput,
)

from flash_attn.bert_padding import (
    index_first_axis,
    index_first_axis_residual,
    pad_input,
    unpad_input,
)
from flash_attn.modules.block import Block
from flash_attn.modules.embedding import BertEmbeddings
from flash_attn.modules.mha import MHA
@@ -511,7 +507,11 @@ class BertForPreTraining(BertPreTrainedModel):

remap_state_dict gains a PretrainedConfig type annotation and a docstring; the surrounding body is unchanged:

def remap_state_dict(state_dict, config: PretrainedConfig):
    """
    Map the state_dict of a Huggingface BERT model to be flash_attn compatible.
    """

    # LayerNorm
    def key_mapping_ln_gamma_beta(key):
        key = re.sub(r"LayerNorm.gamma$", "LayerNorm.weight", key)
@@ -618,3 +618,133 @@ def remap_state_dict(state_dict, config):

The new inverse mapping is appended after remap_state_dict:

    )
    return state_dict


def inv_remap_state_dict(state_dict, config: PretrainedConfig):
    """
    Map the state_dict of a flash_attn model to be Huggingface BERT compatible.

    This function is meant to be the inverse of remap_state_dict.
    """
    # Word embedding
    pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
    if pad_vocab_size_multiple > 1:
        word_embeddings = state_dict["bert.embeddings.word_embeddings.weight"]
        decoder_weight = state_dict["cls.predictions.decoder.weight"]
        decoder_bias = state_dict["cls.predictions.decoder.bias"]
        # unpad embeddings
        state_dict["bert.embeddings.word_embeddings.weight"] = word_embeddings[
            : config.orig_vocab_size, :
        ]
        state_dict["cls.predictions.decoder.weight"] = decoder_weight[: config.orig_vocab_size, :]
        state_dict["cls.predictions.decoder.bias"] = decoder_bias[: config.orig_vocab_size]

    for d in range(config.num_hidden_layers):
        last_layer_subset = getattr(config, "last_layer_subset", False)
        if not last_layer_subset or d != (config.num_hidden_layers - 1):
            Wqkv_weights = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wqkv.weight")
            Wqkv_biases = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wqkv.bias")
            state_dict[f"bert.encoder.layers.{d}.attention.self.query.weight"] = Wqkv_weights[
                : Wqkv_weights.shape[0] // 3, :
            ]
            state_dict[f"bert.encoder.layers.{d}.attention.self.key.weight"] = Wqkv_weights[
                Wqkv_weights.shape[0] // 3 : 2 * Wqkv_weights.shape[0] // 3, :
            ]
            state_dict[f"bert.encoder.layers.{d}.attention.self.value.weight"] = Wqkv_weights[
                2 * Wqkv_weights.shape[0] // 3 :, :
            ]
            state_dict[f"bert.encoder.layers.{d}.attention.self.query.bias"] = Wqkv_biases[
                : Wqkv_biases.shape[0] // 3
            ]
            state_dict[f"bert.encoder.layers.{d}.attention.self.key.bias"] = Wqkv_biases[
                Wqkv_biases.shape[0] // 3 : 2 * Wqkv_biases.shape[0] // 3
            ]
            state_dict[f"bert.encoder.layers.{d}.attention.self.value.bias"] = Wqkv_biases[
                2 * Wqkv_biases.shape[0] // 3 :
            ]
        else:
            Wq_weight = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wq.weight")
            Wkv_weights = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wkv.weight")
            Wq_bias = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wq.bias")
            Wkv_biases = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wkv.bias")
            state_dict[f"bert.encoder.layers.{d}.attention.self.query.weight"] = Wq_weight
            state_dict[f"bert.encoder.layers.{d}.attention.self.key.weight"] = Wkv_weights[
                : Wkv_weights.shape[0] // 2, :
            ]
            state_dict[f"bert.encoder.layers.{d}.attention.self.value.weight"] = Wkv_weights[
                Wkv_weights.shape[0] // 2 :, :
            ]
            state_dict[f"bert.encoder.layers.{d}.attention.self.query.bias"] = Wq_bias
            state_dict[f"bert.encoder.layers.{d}.attention.self.key.bias"] = Wkv_biases[
                : Wkv_biases.shape[0] // 2
            ]
            state_dict[f"bert.encoder.layers.{d}.attention.self.value.bias"] = Wkv_biases[
                Wkv_biases.shape[0] // 2 :
            ]

    def inv_key_mapping_ln(key):
        key = re.sub(r"bert.emb_ln.", "bert.embeddings.LayerNorm.", key)
        key = re.sub(
            r"bert.encoder.layers.(\d+).norm1.(weight|bias)",
            r"bert.encoder.layers.\1.attention.output.LayerNorm.\2",
            key,
        )
        key = re.sub(
            r"bert.encoder.layers.(\d+).norm2.(weight|bias)",
            r"bert.encoder.layers.\1.output.LayerNorm.\2",
            key,
        )
        key = re.sub(
            r"cls.predictions.transform.layer_norm.(weight|bias)",
            r"cls.predictions.transform.LayerNorm.\1",
            key,
        )
        return key

    def inv_key_mapping_ln_gamma_beta(key):
        key = re.sub(r"LayerNorm.weight$", "LayerNorm.gamma", key)
        key = re.sub(r"LayerNorm.bias$", "LayerNorm.beta", key)
        return key

    def inv_key_mapping_layers(key):
        return re.sub(r"bert.encoder.layers.", "bert.encoder.layer.", key)

    def inv_key_mapping_mlp(key):
        key = re.sub(
            r"bert.encoder.layer.(\d+).mlp.fc1.(weight|bias)",
            r"bert.encoder.layer.\1.intermediate.dense.\2",
            key,
        )
        key = re.sub(
            r"bert.encoder.layer.(\d+).mlp.fc2.(weight|bias)",
            r"bert.encoder.layer.\1.output.dense.\2",
            key,
        )
        return key

    def inv_key_mapping_attn(key):
        return re.sub(
            r"bert.encoder.layer.(\d+).mixer.out_proj.(weight|bias)",
            r"bert.encoder.layer.\1.attention.output.dense.\2",
            key,
        )

    def inv_key_mapping_decoder_bias(key):
        return re.sub(r"cls.predictions.decoder.bias", "cls.predictions.bias", key)

    state_dict = OrderedDict((inv_key_mapping_ln(key), value) for key, value in state_dict.items())
    state_dict = OrderedDict(
        (inv_key_mapping_ln_gamma_beta(key), value) for key, value in state_dict.items()
    )
    state_dict = OrderedDict(
        (inv_key_mapping_layers(key), value) for key, value in state_dict.items()
    )
    state_dict = OrderedDict((inv_key_mapping_mlp(key), value) for key, value in state_dict.items())
    state_dict = OrderedDict(
        (inv_key_mapping_attn(key), value) for key, value in state_dict.items()
    )
    state_dict = OrderedDict(
        (inv_key_mapping_decoder_bias(key), value) for key, value in state_dict.items()
    )

    return state_dict
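The core of the inverse mapping is undoing the packed QKV projection: flash_attn stores the query, key, and value weights concatenated row-wise in a single Wqkv tensor, and the slicing above recovers each third. A small self-contained illustration follows; it is not part of the commit, and the toy hidden size is arbitrary.

import torch

hidden_size = 8  # toy size; BERT-base uses 768
Wqkv = torch.randn(3 * hidden_size, hidden_size)  # rows 0..H-1 = Q, H..2H-1 = K, 2H..3H-1 = V

third = Wqkv.shape[0] // 3
q_weight = Wqkv[:third, :]
k_weight = Wqkv[third : 2 * third, :]
v_weight = Wqkv[2 * third :, :]

# Concatenating the slices back row-wise reproduces the packed matrix exactly,
# which is why the round trip through inv_remap_state_dict is lossless.
assert torch.equal(torch.cat([q_weight, k_weight, v_weight], dim=0), Wqkv)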
tests/models/test_bert.py (view file @ 4c91621a)

@@ -5,12 +5,15 @@ import pytest

The flash_attn imports move below the transformers imports and now also pull in inv_remap_state_dict; the parametrize decorators that follow are unchanged context:

import torch
import torch.nn.functional as F
from einops import rearrange
from transformers import BertConfig
from transformers.models.bert.modeling_bert import \
    BertForPreTraining as BertForPreTrainingHF
from transformers.models.bert.modeling_bert import BertModel as BertModelHF

from flash_attn.models.bert import (
    BertForPreTraining,
    BertModel,
    inv_remap_state_dict,
    remap_state_dict,
)
from flash_attn.utils.pretrained import state_dict_from_pretrained


@pytest.mark.parametrize("model_name", ["bert-base-uncased", "bert-large-uncased"])
# @pytest.mark.parametrize('model_name', ["bert-base-uncased"])
@@ -43,7 +46,7 @@ def get_hf_models(model_name, config, dtype):

The parametrize decorator of test_bert_non_optimized switches from single to double quotes:

    return model_hf


@pytest.mark.parametrize("model_name", ["bert-base-uncased"])
def test_bert_non_optimized(model_name):
    """Check that our implementation of BERT (without any optimizations enabled) matches the
    HF implementation: the output of our forward pass in fp16 should be around the same as the HF
@@ -297,3 +300,22 @@ def test_bert_dense_seq_output(model_name, has_key_padding_mask, last_layer_subs

A round-trip test for the new function is appended at the end of the file:

    ).abs().max().item()
    # The loss calculation from HF is wrong: it doesn't ignore the labels that are 0.
    # assert (out.loss - out_ref.loss).abs().max().item() < 2 * (out_hf.loss - out_ref.loss).abs().max().item()


@pytest.mark.parametrize("model_name", ["bert-base-uncased", "bert-large-uncased"])
def test_inv_remap_state_dict(model_name: str):
    """
    Verify that we can convert a HF BERT model to flash_attn and back.
    """
    state_dict = state_dict_from_pretrained(model_name)
    config = BertConfig.from_pretrained(model_name)

    flash_state_dict = remap_state_dict(state_dict, config)
    recovered_state_dict = inv_remap_state_dict(flash_state_dict, config)

    assert set(state_dict.keys()) == set(recovered_state_dict.keys())

    for k in state_dict.keys():
        assert state_dict[k].shape == recovered_state_dict[k].shape
        torch.testing.assert_close(state_dict[k], recovered_state_dict[k], rtol=1e-6, atol=1e-6)
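For reference, here is a condensed, self-contained illustration of the key renaming that the inverse mapping performs. It is not part of the commit: the keys below are hypothetical examples and only a subset of the inv_key_mapping_* regexes is reproduced, applied in the same order as in the function.

import re

def to_hf_key(key):
    # norm2 -> output.LayerNorm (while keys still use "layers"), then layers -> layer, then mlp.fc1 -> intermediate.dense
    key = re.sub(
        r"bert.encoder.layers.(\d+).norm2.(weight|bias)",
        r"bert.encoder.layers.\1.output.LayerNorm.\2",
        key,
    )
    key = re.sub(r"bert.encoder.layers.", "bert.encoder.layer.", key)
    key = re.sub(
        r"bert.encoder.layer.(\d+).mlp.fc1.(weight|bias)",
        r"bert.encoder.layer.\1.intermediate.dense.\2",
        key,
    )
    return key

print(to_hf_key("bert.encoder.layers.0.norm2.bias"))      # bert.encoder.layer.0.output.LayerNorm.bias
print(to_hf_key("bert.encoder.layers.0.mlp.fc1.weight"))  # bert.encoder.layer.0.intermediate.dense.weight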