gaoqiong / flash-attention · Commits · 5fb6df0e

Commit 5fb6df0e authored Dec 18, 2022 by Tri Dao

Implement BERT

parent dc24c226

Showing 8 changed files with 784 additions and 37 deletions (+784, -37).
flash_attn/flash_attention.py    +1    -1
flash_attn/models/bert.py        +426  -0
flash_attn/models/gpt.py         +1    -1
flash_attn/modules/block.py      +14   -3
flash_attn/modules/embedding.py  +43   -9
flash_attn/modules/mha.py        +76   -23
flash_attn/ops/layer_norm.py     +4    -0
tests/models/test_bert.py        +219  -0
flash_attn/flash_attention.py

@@ -5,7 +5,7 @@ import torch.nn as nn
 from einops import rearrange
 
 from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
-from flash_attn.bert_padding import unpad_input, pad_input, index_first_axis
+from flash_attn.bert_padding import unpad_input, pad_input
 
 
 class FlashAttention(nn.Module):
...
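The bert_padding helpers named in this import are what the new BERT model below uses to strip and restore padding tokens around the FlashAttention kernels. As orientation, here is a minimal sketch (not part of this commit, and assuming flash-attn is installed) of the round trip, using the four-value unpad_input API that the code in this commit unpacks:

# Sketch only: illustrates the unpad_input/pad_input contract relied on below.
import torch
from flash_attn.bert_padding import unpad_input, pad_input

batch, seqlen, dim = 2, 4, 8
hidden_states = torch.randn(batch, seqlen, dim)
key_padding_mask = torch.tensor([[True, True, True, False],    # sequence 0: 3 real tokens
                                 [True, True, False, False]])  # sequence 1: 2 real tokens
unpadded, indices, cu_seqlens, max_seqlen = unpad_input(hidden_states, key_padding_mask)
print(unpadded.shape)  # (5, 8): only the non-padding tokens, packed along one axis
print(cu_seqlens)      # tensor([0, 3, 5], dtype=torch.int32)
repadded = pad_input(unpadded, indices, batch, seqlen)  # back to (2, 4, 8), pads zeroed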
flash_attn/models/bert.py
0 → 100644 (new file)

# Copyright (c) 2022, Tri Dao.
# This BERT implementation is based on our MLPerf 2.0 and MLPerf 2.1 BERT implementation.
# https://github.com/mlcommons/training_results_v2.0/blob/main/HazyResearch/benchmarks/bert/implementations/pytorch/modeling.py
# https://github.com/mlcommons/training_results_v2.1/blob/main/Azure-HazyResearch/benchmarks/bert/implementations/ND96amsr_A100_v4/modeling.py
# Inspired by https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/modeling_bert.py

import re
import logging
from functools import partial
from collections.abc import Sequence
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers import BertConfig

from einops import rearrange

from flash_attn.modules.mha import MHA
from flash_attn.modules.mlp import Mlp, FusedDenseGeluDense
from flash_attn.modules.block import Block
from flash_attn.modules.embedding import BertEmbeddings
from flash_attn.bert_padding import unpad_input, pad_input, index_first_axis

try:
    from flash_attn.ops.fused_dense import FusedDenseTD
except ImportError:
    FusedDenseTD = None

try:
    from flash_attn.ops.layer_norm import dropout_add_layer_norm, layer_norm
except ImportError:
    dropout_add_layer_norm, layer_norm = None, None

try:
    from flash_attn.losses.cross_entropy_apex import CrossEntropyLossApex
except ImportError:
    CrossEntropyLossApex = None

logger = logging.getLogger(__name__)
def create_mixer_cls(config):
    use_flash_attn = getattr(config, 'use_flash_attn', False)
    fused_bias_fc = getattr(config, 'fused_bias_fc', False)
    mixer_cls = partial(MHA, num_heads=config.num_attention_heads,
                        dropout=config.attention_probs_dropout_prob, causal=False,
                        fused_bias_fc=fused_bias_fc, use_flash_attn=use_flash_attn)
    return mixer_cls


def create_mlp_cls(config, layer_idx=None):
    inner_dim = config.intermediate_size
    fused_dense_gelu_dense = getattr(config, 'fused_dense_gelu_dense', False)
    if not fused_dense_gelu_dense:
        mlp_cls = partial(Mlp, hidden_features=inner_dim,
                          activation=partial(F.gelu, approximate='tanh'))
    else:
        mlp_checkpoint_lvl = getattr(config, 'mlp_checkpoint_lvl', 0)
        # mlp_checkpoint_lvl could be a list, which contains the checkpoint_lvl for each layer
        if isinstance(mlp_checkpoint_lvl, Sequence):
            assert layer_idx is not None
            mlp_checkpoint_lvl = mlp_checkpoint_lvl[layer_idx]
        mlp_cls = partial(FusedDenseGeluDense, hidden_features=inner_dim,
                          checkpoint_lvl=mlp_checkpoint_lvl)
    return mlp_cls


def create_block(config, layer_idx=None):
    mixer_cls = create_mixer_cls(config)
    mlp_cls = create_mlp_cls(config, layer_idx)
    norm_cls = partial(nn.LayerNorm, eps=config.layer_norm_eps)
    block = Block(config.hidden_size, mixer_cls, mlp_cls, norm_cls=norm_cls,
                  prenorm=False, resid_dropout=config.hidden_dropout_prob,
                  fused_dropout_add_ln=getattr(config, 'fused_dropout_add_ln', False))
    return block


# https://github.com/huggingface/transformers/blob/7032e0203262ebb2ebf55da8d2e01f873973e835/src/transformers/models/bert/modeling_bert.py#L748
def _init_weights(module, initializer_range=0.02):
    if isinstance(module, nn.Linear):
        nn.init.normal_(module.weight, std=initializer_range)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Embedding):
        nn.init.normal_(module.weight, std=initializer_range)
        if module.padding_idx is not None:
            nn.init.zeros_(module.weight[module.padding_idx])
class BertEncoder(nn.Module):

    def __init__(self, config: BertConfig):
        super().__init__()
        self.use_flash_attn = getattr(config, 'use_flash_attn', False)
        self.layers = nn.ModuleList([create_block(config, layer_idx=i)
                                     for i in range(config.num_hidden_layers)])

    def forward(self, hidden_states, key_padding_mask=None):
        if key_padding_mask is None or not self.use_flash_attn:
            mixer_kwargs = ({'key_padding_mask': key_padding_mask}
                            if key_padding_mask is not None else None)
            for layer in self.layers:
                hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
        else:
            batch, seqlen = hidden_states.shape[:2]
            hidden_states, indices, cu_seqlens, max_seqlen_in_batch = unpad_input(
                hidden_states, key_padding_mask
            )
            mixer_kwargs = {'cu_seqlens': cu_seqlens, 'max_seqlen': max_seqlen_in_batch}
            for layer in self.layers:
                hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
            hidden_states = pad_input(hidden_states, indices, batch, seqlen)
        return hidden_states


class BertPooler(nn.Module):

    def __init__(self, config):
        super().__init__()
        fused_bias_fc = getattr(config, 'fused_bias_fc', False)
        if fused_bias_fc and FusedDenseTD is None:
            raise ImportError('fused_dense is not installed')
        linear_cls = nn.Linear if not fused_bias_fc else FusedDenseTD
        self.dense = linear_cls(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states, pool=True):
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token.
        first_token_tensor = hidden_states[:, 0] if pool else hidden_states
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output
class BertPredictionHeadTransform(nn.Module):

    def __init__(self, config):
        super().__init__()
        fused_bias_fc = getattr(config, 'fused_bias_fc', False)
        if fused_bias_fc and FusedDenseTD is None:
            raise ImportError('fused_dense is not installed')
        self.fused_dropout_add_ln = getattr(config, 'fused_dropout_add_ln', False)
        if self.fused_dropout_add_ln and layer_norm is None:
            raise ImportError('dropout_add_layer_norm is not installed')
        linear_cls = nn.Linear if not fused_bias_fc else FusedDenseTD
        self.dense = linear_cls(config.hidden_size, config.hidden_size)
        self.transform_act_fn = nn.GELU(approximate='tanh')
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.transform_act_fn(hidden_states)
        if not self.fused_dropout_add_ln:
            hidden_states = self.layer_norm(hidden_states)
        else:
            hidden_states = layer_norm(hidden_states, self.layer_norm.weight,
                                       self.layer_norm.bias, self.layer_norm.eps)
        return hidden_states


class BertLMPredictionHead(nn.Module):

    def __init__(self, config):
        super().__init__()
        fused_bias_fc = getattr(config, 'fused_bias_fc', False)
        if fused_bias_fc and FusedDenseTD is None:
            raise ImportError('fused_dense is not installed')
        linear_cls = nn.Linear if not fused_bias_fc else FusedDenseTD
        self.transform = BertPredictionHeadTransform(config)
        # The output weights are the same as the input embeddings, but there is
        # an output-only bias for each token.
        self.decoder = linear_cls(config.hidden_size, config.vocab_size, bias=True)

    def forward(self, hidden_states):
        hidden_states = self.transform(hidden_states)
        hidden_states = self.decoder(hidden_states)
        return hidden_states
class BertPreTrainingHeads(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.predictions = BertLMPredictionHead(config)
        self.seq_relationship = nn.Linear(config.hidden_size, 2)

    def forward(self, sequence_output, pooled_output):
        prediction_scores = self.predictions(sequence_output)
        seq_relationship_score = self.seq_relationship(pooled_output)
        return prediction_scores, seq_relationship_score


class BertPreTrainedModel(nn.Module):
    """An abstract class to handle weights initialization and
    a simple interface for downloading and loading pretrained models.
    """

    def __init__(self, config, *inputs, **kwargs):
        super().__init__()
        if not isinstance(config, BertConfig):
            raise ValueError(
                "Parameter config in `{}(config)` should be an instance of class `BertConfig`. "
                "To create a model from a Google pretrained model use "
                "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
                    self.__class__.__name__, self.__class__.__name__
                ))
        self.config = config

    @classmethod
    def from_pretrained(cls, model_name, config, *inputs, **kwargs):
        """
        Instantiate a BertPreTrainedModel from a pre-trained model file or a pytorch state dict.
        Download and cache the pre-trained model file if needed.

        Params:
            pretrained_model_name_or_path: either:
                - a path or url to a pretrained model archive containing:
                    . `bert_config.json` a configuration file for the model
                    . `pytorch_model.bin` a PyTorch dump of a BertForPreTraining instance
                - a path or url to a pretrained model archive containing:
                    . `bert_config.json` a configuration file for the model
                    . `model.chkpt` a TensorFlow checkpoint
            *inputs, **kwargs: additional input for the specific Bert class
                (ex: num_labels for BertForSequenceClassification)
        """
        # Instantiate model.
        model = cls(config, *inputs, **kwargs)
        load_return = model.load_state_dict(
            remap_state_dict(state_dict_from_pretrained(model_name), config), strict=False
        )
        logger.info(load_return)
        return model
class BertModel(BertPreTrainedModel):

    def __init__(self, config: BertConfig, add_pooling_layer=True):
        super().__init__(config)
        self.pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
        if config.vocab_size % self.pad_vocab_size_multiple != 0:
            config.vocab_size += (self.pad_vocab_size_multiple
                                  - (config.vocab_size % self.pad_vocab_size_multiple))
        self.fused_dropout_add_ln = getattr(config, 'fused_dropout_add_ln', False)
        if self.fused_dropout_add_ln and dropout_add_layer_norm is None:
            raise ImportError('dropout_add_layer_norm is not installed')
        assert config.position_embedding_type == 'absolute'
        assert config.hidden_act in ['gelu', 'gelu_new', 'gelu_fast']

        self.embeddings = BertEmbeddings(config.hidden_size, config.vocab_size,
                                         config.max_position_embeddings,
                                         config.type_vocab_size,
                                         padding_idx=config.pad_token_id)
        self.emb_drop = nn.Dropout(config.hidden_dropout_prob)
        self.emb_ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.encoder = BertEncoder(config)
        self.pooler = BertPooler(config) if add_pooling_layer else None

        self.apply(partial(_init_weights, initializer_range=config.initializer_range))

    def forward(self, input_ids, position_ids=None, token_type_ids=None, attention_mask=None,
                masked_tokens_mask=None):
        hidden_states = self.embeddings(input_ids, position_ids=position_ids,
                                        token_type_ids=token_type_ids)
        # TD [2022-12-18]: Don't need to force residual in fp32
        if not self.fused_dropout_add_ln:
            hidden_states = self.emb_drop(hidden_states)
            hidden_states = self.emb_ln(hidden_states)
        else:
            hidden_states = dropout_add_layer_norm(
                hidden_states, None, self.emb_ln.weight, self.emb_ln.bias,
                self.emb_drop.p if self.training else 0.0, self.emb_ln.eps, prenorm=False,
            )
        sequence_output = self.encoder(hidden_states, key_padding_mask=attention_mask)
        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
        return sequence_output, pooled_output
class BertForPreTraining(BertPreTrainedModel):

    def __init__(self, config: BertConfig):
        super().__init__(config)
        # If dense_seq_output, we only need to pass the hidden states for the masked out tokens
        # (around 15%) to the classifier heads.
        self.dense_seq_output = getattr(config, 'dense_seq_output', False)
        # If last_layer_subset, we only need to compute the last layer for a subset of tokens
        # (e.g., the tokens we need to compute the masked LM loss and the next-sentence prediction).
        self.last_layer_subset = getattr(config, 'last_layer_subset', False)
        assert not self.last_layer_subset, 'last_layer_subset is not implemented yet'
        use_xentropy = getattr(config, 'use_xentropy', False)
        if use_xentropy and CrossEntropyLossApex is None:
            raise ImportError('xentropy_cuda is not installed')
        loss_cls = nn.CrossEntropyLoss if not use_xentropy else CrossEntropyLossApex

        self.bert = BertModel(config)
        self.cls = BertPreTrainingHeads(config)
        self.mlm_loss = loss_cls(ignore_index=0)
        self.nsp_loss = loss_cls(ignore_index=-1)

        # Initialize weights and apply final processing
        self.apply(partial(_init_weights, initializer_range=config.initializer_range))
        self.tie_weights()

    def tie_weights(self):
        self.cls.predictions.decoder.weight = self.bert.embeddings.word_embeddings.weight

    def forward(self, input_ids, position_ids=None, token_type_ids=None, attention_mask=None,
                labels=None, next_sentence_label=None):
        """
        Outputs:
            if `labels` and `next_sentence_label` are not `None`:
                Outputs the total_loss which is the sum of the masked language modeling loss and
                the next sentence classification loss.
            if `labels` or `next_sentence_label` is `None`:
                Outputs a tuple comprising
                - the masked language modeling logits of shape
                  [batch_size, sequence_length, vocab_size], and
                - the next sentence classification logits of shape [batch_size, 2].
        """
        masked_tokens_mask = (labels > 0 if (self.last_layer_subset and labels is not None)
                              else None)
        sequence_output, pooled_output = self.bert(
            input_ids, position_ids=position_ids, token_type_ids=token_type_ids,
            attention_mask=attention_mask.bool(), masked_tokens_mask=masked_tokens_mask
        )
        if self.dense_seq_output and labels is not None:
            masked_token_idx = torch.nonzero(labels.flatten() > 0, as_tuple=False).flatten()
            if not self.last_layer_subset:
                sequence_output = index_first_axis(
                    rearrange(sequence_output, 'b s d -> (b s) d'), masked_token_idx
                )
        prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
        if labels is not None and next_sentence_label is not None:
            if masked_token_idx is not None:  # prediction_scores are already flattened
                masked_lm_loss = self.mlm_loss(prediction_scores,
                                               labels.flatten()[masked_token_idx])
            else:
                masked_lm_loss = self.mlm_loss(rearrange(prediction_scores, '... v -> (...) v'),
                                               rearrange(labels, '... -> (...)'))
            next_sentence_loss = self.nsp_loss(
                rearrange(seq_relationship_score, '... t -> (...) t'),
                rearrange(next_sentence_label, '... -> (...)')
            )
            total_loss = (masked_lm_loss + next_sentence_loss).float()
            # Masked Language Model Accuracy
            masked_lm_labels_flat = labels.view(-1)
            mlm_labels = masked_lm_labels_flat[masked_lm_labels_flat != 0]
            if not self.dense_seq_output:
                prediction_scores_flat = rearrange(prediction_scores, '... v -> (...) v')
                mlm_predictions_scores = prediction_scores_flat[masked_lm_labels_flat != 0]
                mlm_predictions = mlm_predictions_scores.argmax(dim=-1)
            else:
                mlm_predictions = prediction_scores.argmax(dim=-1)
            mlm_acc = ((mlm_predictions == mlm_labels).sum(dtype=torch.float)
                       / mlm_labels.numel())
            return (total_loss, prediction_scores, seq_relationship_score, mlm_acc,
                    mlm_labels.numel())
        else:
            return prediction_scores, seq_relationship_score
def state_dict_from_pretrained(model_name):
    from transformers.utils import WEIGHTS_NAME
    from transformers.utils.hub import cached_file
    return torch.load(cached_file(model_name, WEIGHTS_NAME))


def remap_state_dict(state_dict, config):
    # LayerNorm
    def key_mapping_ln_gamma_beta(key):
        key = re.sub(r'LayerNorm.gamma$', 'LayerNorm.weight', key)
        key = re.sub(r'LayerNorm.beta$', 'LayerNorm.bias', key)
        return key
    state_dict = OrderedDict((key_mapping_ln_gamma_beta(k), v) for k, v in state_dict.items())

    # Layers
    def key_mapping_layers(key):
        return re.sub(r'^bert.encoder.layer.', 'bert.encoder.layers.', key)
    state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items())

    # LayerNorm
    def key_mapping_ln(key):
        key = re.sub(r'^bert.embeddings.LayerNorm.', 'bert.emb_ln.', key)
        key = re.sub(r'^bert.encoder.layers.(\d+).attention.output.LayerNorm.(weight|bias)',
                     r'bert.encoder.layers.\1.norm1.\2', key)
        key = re.sub(r'^bert.encoder.layers.(\d+).output.LayerNorm.(weight|bias)',
                     r'bert.encoder.layers.\1.norm2.\2', key)
        key = re.sub(r'^cls.predictions.transform.LayerNorm.(weight|bias)',
                     r'cls.predictions.transform.layer_norm.\1', key)
        return key
    state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())

    # MLP
    def key_mapping_mlp(key):
        key = re.sub(r'^bert.encoder.layers.(\d+).intermediate.dense.(weight|bias)',
                     r'bert.encoder.layers.\1.mlp.fc1.\2', key)
        key = re.sub(r'^bert.encoder.layers.(\d+).output.dense.(weight|bias)',
                     r'bert.encoder.layers.\1.mlp.fc2.\2', key)
        return key
    state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())

    # Attention
    for d in range(config.num_hidden_layers):
        Wq = state_dict.pop(f'bert.encoder.layers.{d}.attention.self.query.weight')
        Wk = state_dict.pop(f'bert.encoder.layers.{d}.attention.self.key.weight')
        Wv = state_dict.pop(f'bert.encoder.layers.{d}.attention.self.value.weight')
        bq = state_dict.pop(f'bert.encoder.layers.{d}.attention.self.query.bias')
        bk = state_dict.pop(f'bert.encoder.layers.{d}.attention.self.key.bias')
        bv = state_dict.pop(f'bert.encoder.layers.{d}.attention.self.value.bias')
        state_dict[f'bert.encoder.layers.{d}.mixer.Wqkv.weight'] = torch.cat([Wq, Wk, Wv], dim=0)
        state_dict[f'bert.encoder.layers.{d}.mixer.Wqkv.bias'] = torch.cat([bq, bk, bv], dim=0)

    def key_mapping_attn(key):
        return re.sub(r'^bert.encoder.layers.(\d+).attention.output.dense.(weight|bias)',
                      r'bert.encoder.layers.\1.mixer.out_proj.\2', key)
    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())

    def key_mapping_decoder_bias(key):
        return re.sub(r'^cls.predictions.bias', 'cls.predictions.decoder.bias', key)
    state_dict = OrderedDict((key_mapping_decoder_bias(k), v) for k, v in state_dict.items())

    return state_dict
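To see how the pieces of this file fit together, here is a short usage sketch. It is not part of the commit; it mirrors what tests/models/test_bert.py below does, loading HF weights through from_pretrained (which calls state_dict_from_pretrained and remap_state_dict) and enabling the optional fused code paths via ad-hoc config attributes:

# Sketch only: mirrors the test file. The fused options require the optional
# flash-attn CUDA extensions (flash_attn, fused_dense, dropout_add_layer_norm).
import torch
from transformers import BertConfig
from flash_attn.models.bert import BertForPreTraining

config = BertConfig.from_pretrained("bert-base-uncased")
config.hidden_act = "gelu_new"        # this implementation assumes tanh-approximate GELU
config.use_flash_attn = True
config.fused_bias_fc = True
config.fused_dense_gelu_dense = True
config.fused_dropout_add_ln = True
model = BertForPreTraining.from_pretrained("bert-base-uncased", config)
model = model.cuda().to(dtype=torch.float16).eval()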
flash_attn/models/gpt.py

@@ -10,7 +10,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-from transformers.models.gpt2.configuration_gpt2 import GPT2Config
+from transformers import GPT2Config
 
 from flash_attn.modules.mha import MHA
 from flash_attn.modules.mlp import Mlp, FusedDenseGeluDense
...
flash_attn/modules/block.py

@@ -23,10 +23,16 @@ class Block(nn.Module):
     def __init__(self, dim, mixer_cls=None, mlp_cls=None, norm_cls=nn.LayerNorm,
                  dropout_cls=nn.Dropout, prenorm=True, resid_dropout=0., drop_path=0.,
-                 fused_dropout_add_ln=False):
+                 fused_dropout_add_ln=False, return_residual=False):
+        """
+        return_residual: whether each of the sub-layers (mixer and mlp) will return the residual.
+        This is for performance reasons: for the post-norm architecture, returning the input allows
+        us to fuse the backward of nn.Linear with the residual connection.
+        """
         super().__init__()
         self.prenorm = prenorm
         self.fused_dropout_add_ln = fused_dropout_add_ln
+        self.return_residual = return_residual
         if mixer_cls is None:
             mixer_cls = partial(MHA, num_heads=dim // 64)
         if mlp_cls is None:
...

@@ -92,8 +98,11 @@ class Block(nn.Module):
             return hidden_states, residual
         else:
             assert residual is None
-            mixer_out = self.mixer(hidden_states, **(mixer_kwargs if mixer_kwargs is not None else {}))
+            mixer_out = self.mixer(hidden_states,
+                                   **(mixer_kwargs if mixer_kwargs is not None else {}))
+            if self.return_residual:  # mixer out is actually a pair here
+                mixer_out, hidden_states = mixer_out
             if not self.fused_dropout_add_ln:
                 hidden_states = self.norm1((self.drop_path1(self.dropout1(mixer_out))
                                             + hidden_states).to(dtype=self.norm1.weight.dtype))
...

@@ -111,6 +120,8 @@ class Block(nn.Module):
             )
             if not isinstance(self.mlp, nn.Identity):
                 mlp_out = self.mlp(hidden_states)
+                if self.return_residual:  # mlp out is actually a pair here
+                    mlp_out, hidden_states = mlp_out
                 if not self.fused_dropout_add_ln:
                     hidden_states = self.norm2((self.drop_path2(self.dropout2(mlp_out))
                                                 + hidden_states).to(dtype=self.norm2.weight.dtype))
...
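The return_residual docstring above describes a calling convention rather than a fixed API: when enabled, a sub-layer returns a pair (its output, its own input), and Block unpacks the pair to form the post-norm residual. A minimal sketch with a hypothetical sub-layer, not code from the commit:

# Sketch only: LinearReturnResidual is a hypothetical sub-layer illustrating the
# (output, residual input) pair that Block unpacks when return_residual=True.
import torch
import torch.nn as nn

class LinearReturnResidual(nn.Linear):
    def forward(self, x):
        return super().forward(x), x  # output plus the input needed for the residual add

mixer_out, hidden_states = LinearReturnResidual(16, 16)(torch.randn(2, 3, 16))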
flash_attn/modules/embedding.py

@@ -3,8 +3,6 @@
 import torch
 import torch.nn as nn
 
 from einops import repeat
 
 class GPT2Embeddings(nn.Module):
...

@@ -21,15 +19,51 @@ class GPT2Embeddings(nn.Module):
     def forward(self, input_ids, position_ids=None):
         """
         input_ids: (batch, seqlen)
         position_ids: (batch, seqlen)
         """
         batch_size, seqlen = input_ids.shape
-        input_embeddings = self.word_embeddings(input_ids)
+        embeddings = self.word_embeddings(input_ids)
         if self.max_position_embeddings > 0:
             if position_ids is None:
                 position_ids = torch.arange(seqlen, dtype=torch.long, device=input_ids.device)
             position_embeddings = self.position_embeddings(position_ids)
-            return input_embeddings + position_embeddings
-        else:
-            return input_embeddings
+            embeddings = embeddings + position_embeddings
+        return embeddings
+
+
+class BertEmbeddings(nn.Module):
+
+    def __init__(self, embed_dim, vocab_size, max_position_embeddings, type_vocab_size,
+                 padding_idx=None):
+        """
+        If max_position_embeddings <= 0, there's no position embeddings
+        If type_vocab_size <= 0, there's no token type embeddings
+        """
+        super().__init__()
+        self.word_embeddings = nn.Embedding(vocab_size, embed_dim, padding_idx=padding_idx)
+        self.max_position_embeddings = max_position_embeddings
+        self.type_vocab_size = type_vocab_size
+        if self.max_position_embeddings > 0:
+            self.position_embeddings = nn.Embedding(max_position_embeddings, embed_dim)
+        if self.type_vocab_size > 0:
+            self.token_type_embeddings = nn.Embedding(type_vocab_size, embed_dim)
+
+    def forward(self, input_ids, position_ids=None, token_type_ids=None):
+        """
+        input_ids: (batch, seqlen)
+        position_ids: (batch, seqlen)
+        token_type_ids: (batch, seqlen)
+        """
+        batch_size, seqlen = input_ids.shape
+        embeddings = self.word_embeddings(input_ids)
+        if self.max_position_embeddings > 0:
+            if position_ids is None:
+                position_ids = repeat(torch.arange(seqlen, dtype=torch.long,
+                                                   device=input_ids.device),
+                                      's -> b s', b=batch_size)
+            position_embeddings = self.position_embeddings(position_ids)
+            embeddings = embeddings + position_embeddings
+        if self.type_vocab_size > 0:
+            if token_type_ids is None:
+                token_type_ids = torch.zeros(seqlen, dtype=torch.long, device=input_ids.device)
+            token_type_embeddings = self.token_type_embeddings(token_type_ids)
+            embeddings = embeddings + token_type_embeddings
+        return embeddings
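As a quick shape check for the new BertEmbeddings (a sketch, not from the commit, assuming flash-attn is importable): position_ids and token_type_ids are synthesized internally when omitted, so a bare input_ids tensor suffices.

# Sketch only: BertEmbeddings with BERT-base-like sizes.
import torch
from flash_attn.modules.embedding import BertEmbeddings

emb = BertEmbeddings(embed_dim=768, vocab_size=30522, max_position_embeddings=512,
                     type_vocab_size=2, padding_idx=0)
out = emb(torch.randint(0, 30522, (2, 16)))
print(out.shape)  # torch.Size([2, 16, 768])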
flash_attn/modules/mha.py

@@ -53,28 +53,49 @@ class FlashSelfAttention(nn.Module):
         self.dropout_p = attention_dropout
         self.triton = triton
 
-    def forward(self, qkv):
+    def forward(self, qkv, cu_seqlens=None, max_seqlen=None):
         """Implements the multihead softmax attention.
         Arguments
         ---------
-            qkv: The tensor containing the query, key, and value. (B, S, 3, H, D)
+            qkv: The tensor containing the query, key, and value.
+                If cu_seqlens is None and max_seqlen is None, then qkv has shape (B, S, 3, H, D).
+                If cu_seqlens is not None and max_seqlen is not None, then qkv has shape
+                (total, 3, H, D), where total is the sum of the sequence lengths in the batch.
+            cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
+                of the sequences in the batch, used to index into qkv.
+            max_seqlen: int. Maximum sequence length in the batch.
+        Returns:
+        --------
+            out: (total, H, D) if cu_seqlens is not None and max_seqlen is not None,
+                else (B, S, H, D).
         """
         assert qkv.dtype in [torch.float16, torch.bfloat16]
         assert qkv.is_cuda
-        batch_size, seqlen = qkv.shape[0], qkv.shape[1]
-        if self.triton and (self.dropout_p == 0 or not self.training):  # Triton version doesn't support dropout
-            output = flash_attn_qkvpacked_func(qkv, None, self.causal, self.softmax_scale)
-        else:
-            qkv = rearrange(qkv, 'b s ... -> (b s) ...')
-            max_s = seqlen
-            cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
-                                      device=qkv.device)
-            output = flash_attn_unpadded_qkvpacked_func(
-                qkv, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
-                softmax_scale=self.softmax_scale, causal=self.causal
-            )
-            output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
-        return output
+        unpadded = cu_seqlens is not None
+        if unpadded:
+            assert cu_seqlens.dtype == torch.int32
+            assert max_seqlen is not None
+            assert isinstance(max_seqlen, int)
+            return flash_attn_unpadded_qkvpacked_func(
+                qkv, cu_seqlens, max_seqlen, self.dropout_p if self.training else 0.0,
+                softmax_scale=self.softmax_scale, causal=self.causal
+            )
+        else:
+            batch_size, seqlen = qkv.shape[0], qkv.shape[1]
+            # Triton version doesn't support dropout
+            if self.triton and (self.dropout_p == 0 or not self.training):
+                output = flash_attn_qkvpacked_func(qkv, None, self.causal, self.softmax_scale)
+            else:
+                qkv = rearrange(qkv, 'b s ... -> (b s) ...')
+                max_seqlen = seqlen
+                cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen,
+                                          dtype=torch.int32, device=qkv.device)
+                output = flash_attn_unpadded_qkvpacked_func(
+                    qkv, cu_seqlens, max_seqlen, self.dropout_p if self.training else 0.0,
+                    softmax_scale=self.softmax_scale, causal=self.causal
+                )
+                output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
+            return output
class FlashCrossAttention(nn.Module):
...

@@ -146,16 +167,24 @@ class SelfAttention(nn.Module):
         self.softmax_scale = softmax_scale
         self.dropout_p = attention_dropout
 
-    def forward(self, qkv):
+    def forward(self, qkv, key_padding_mask=None):
         """Implements the multihead softmax attention.
         Arguments
         ---------
             qkv: The tensor containing the query, key, and value. (B, S, 3, H, D)
+            key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
+                False means to mask out. (B, S)
         """
         batch_size, seqlen = qkv.shape[0], qkv.shape[1]
         q, k, v = qkv.unbind(dim=2)
         softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
         scores = torch.einsum('bthd,bshd->bhts', q, k * softmax_scale)
+        if key_padding_mask is not None:
+            padding_mask = torch.full((batch_size, seqlen), -10000.0, dtype=scores.dtype,
+                                      device=scores.device)
+            padding_mask.masked_fill_(key_padding_mask, 0.0)
+            # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
+            scores = scores + rearrange(padding_mask, 'b s -> b 1 1 s')
         if self.causal:
             # "triu_tril_cuda_template" not implemented for 'BFloat16'
             # So we have to construct the mask in float
...
@@ -239,6 +268,7 @@ class MHA(nn.Module):
         self.causal = causal
         self.dwconv = dwconv
         self.rotary_emb_dim = rotary_emb_dim
+        self.use_flash_attn = use_flash_attn
         self.return_residual = return_residual
         self.checkpointing = checkpointing
...

@@ -279,12 +309,35 @@ class MHA(nn.Module):
         # output projection always have the bias (for now)
         self.out_proj = linear_cls(embed_dim, embed_dim, **factory_kwargs)
 
-    def forward(self, x, x_kv=None):
+    def forward(self, x, x_kv=None, cu_seqlens=None, max_seqlen=None, key_padding_mask=None):
         """
         Arguments:
-            x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim)
+            x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if
+                cu_seqlens is None and max_seqlen is None, else (total, hidden_dim) where total
+                is the sum of the sequence lengths in the batch.
             x_kv: (batch, seqlen, hidden_dim), only applicable for cross-attention. If None, use x.
+            cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
+                of the sequences in the batch, used to index into x. Only applicable when using
+                FlashAttention.
+            max_seqlen: int. Maximum sequence length in the batch.
+            key_padding_mask: boolean mask, True means to keep, False means to mask out.
+                (batch, seqlen). Only applicable when not using FlashAttention.
         """
+        if cu_seqlens is not None:
+            assert max_seqlen is not None
+            assert key_padding_mask is None
+            assert self.use_flash_attn
+            assert not self.cross_attn, ('Unpadded FlashAttention code path for cross-attention '
+                                         'is not implemented yet')
+            assert not self.dwconv
+            assert self.rotary_emb_dim == 0
+        if key_padding_mask is not None:
+            assert cu_seqlens is None
+            assert max_seqlen is None
+            assert not self.use_flash_attn
+            assert not self.cross_attn, ('Key padding mask code path for cross-attention '
+                                         'is not implemented yet')
         if not self.cross_attn:
             if not self.return_residual:
                 qkv = self.Wqkv(x)
...

@@ -293,14 +346,15 @@ class MHA(nn.Module):
             if self.dwconv:
                 qkv = rearrange(self.dwconv_qkv(rearrange(qkv, 'b s d -> b d s'))[..., :-2],
                                 'b d s -> b s d').contiguous()
-            qkv = rearrange(qkv, 'b s (three h d) -> b s three h d', three=3, h=self.num_heads)
+            qkv = rearrange(qkv, '... (three h d) -> ... three h d', three=3, h=self.num_heads)
             if self.rotary_emb_dim > 0:
                 qkv = self.rotary_emb(qkv)
+            extra_kwargs = ({'cu_seqlens': cu_seqlens, 'max_seqlen': max_seqlen}
+                            if self.use_flash_attn else {'key_padding_mask': key_padding_mask})
             if not self.checkpointing:
-                context = self.inner_attn(qkv)
+                context = self.inner_attn(qkv, **extra_kwargs)
             else:
-                # context = torch.utils.checkpoint.checkpoint(self._inner_attention, qkv)
-                context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv)
+                context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **extra_kwargs)
         else:
             q = rearrange(self.Wq(x), 'b s (h d) -> b s h d', h=self.num_heads)
             kv = rearrange(self.Wkv(x if x_kv is None else x_kv), 'b s (two h d) -> b s two h d',
...

@@ -313,7 +367,6 @@ class MHA(nn.Module):
             if not self.checkpointing:
                 context = self.inner_attn(q, kv)
             else:
-                # context = torch.utils.checkpoint.checkpoint(self._inner_attention, qkv)
                 context = torch.utils.checkpoint.checkpoint(self.inner_attn, q, kv)
-        out = self.out_proj(rearrange(context, 'b s h d -> b s (h d)'))
+        out = self.out_proj(rearrange(context, '... h d -> ... (h d)'))
         return out if not self.return_residual else (out, x)
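The cu_seqlens convention in the docstrings above is just int32 cumulative sequence lengths with a leading zero: sequence i of the packed (total, ...) tensor occupies rows cu_seqlens[i]:cu_seqlens[i+1]. A small illustration (a sketch, not from the diff):

# Sketch only: building cu_seqlens/max_seqlen from per-sequence lengths.
import torch

seqlens = torch.tensor([3, 5, 2])
cu_seqlens = torch.cat([torch.zeros(1, dtype=torch.int32),
                        seqlens.cumsum(dim=0).to(torch.int32)])
print(cu_seqlens)                # tensor([ 0,  3,  8, 10], dtype=torch.int32)
max_seqlen = int(seqlens.max())  # 5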
flash_attn/ops/layer_norm.py

@@ -200,6 +200,10 @@ class DropoutAddLayerNormSubsetFn(torch.autograd.Function):
                 None, None)
 
 
+def layer_norm(x, weight, bias, epsilon):
+    return DropoutAddLayerNormFn.apply(x, None, weight, bias, None, None, 0.0, epsilon, False)
+
+
 def dropout_add_layer_norm(x0, x1, weight, bias, dropout_p, epsilon, rowscale=None,
                            layerscale=None, prenorm=False, residual_in_fp32=False,
                            return_dropout_mask=False):
...
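Functionally, the new layer_norm wrapper appears to be a plain layer norm: it invokes the fused dropout-add-layer-norm kernel with dropout_p=0.0 and no residual input. A reference sketch of the same computation in stock PyTorch (an inference from the call above, not code from the commit):

# Sketch only: what layer_norm(x, weight, bias, epsilon) should compute.
import torch.nn.functional as F

def layer_norm_ref(x, weight, bias, epsilon):
    return F.layer_norm(x, (x.shape[-1],), weight, bias, epsilon)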
tests/models/test_bert.py
0 → 100644 (new file)

import re
from collections import OrderedDict

import torch
import torch.nn.functional as F
import pytest

from einops import rearrange

from transformers import BertConfig
from transformers.models.bert.modeling_bert import BertModel as BertModelHF
from transformers.models.bert.modeling_bert import BertForPreTraining as BertForPreTrainingHF

from flash_attn.models.bert import BertModel, BertForPreTraining
from flash_attn.models.bert import state_dict_from_pretrained
from flash_attn.models.bert import remap_state_dict


@pytest.mark.parametrize('model_name', ["bert-base-uncased", "bert-large-uncased"])
# @pytest.mark.parametrize('model_name', ["bert-base-uncased"])
def test_bert_state_dict(model_name):
    config = BertConfig.from_pretrained(model_name)
    pretrained_state_dict = remap_state_dict(state_dict_from_pretrained(model_name), config)
    model = BertForPreTraining(config)
    state_dict = model.state_dict()
    assert state_dict.keys() == pretrained_state_dict.keys()
    for k in state_dict.keys():
        assert state_dict[k].shape == pretrained_state_dict[k].shape
def get_hf_models(model_name, config, dtype):
    pretrained_state_dict = state_dict_from_pretrained(model_name)

    def key_mapping_ln_gamma_beta(key):
        key = re.sub(r'LayerNorm.gamma$', 'LayerNorm.weight', key)
        key = re.sub(r'LayerNorm.beta$', 'LayerNorm.bias', key)
        return key
    pretrained_state_dict = OrderedDict((key_mapping_ln_gamma_beta(k), v)
                                        for k, v in pretrained_state_dict.items())
    model_hf = BertForPreTrainingHF(config)
    # Missing key(s) in state_dict: "bert.embeddings.position_ids", "cls.predictions.decoder.bias"
    # position_ids is a buffer, and predictions.decoder.bias is tied to predictions.bias.
    model_hf.load_state_dict(pretrained_state_dict, strict=False)
    model_hf.cuda().to(dtype=dtype)
    return model_hf
@pytest.mark.parametrize('model_name', ["bert-base-uncased", "bert-large-uncased"])
# @pytest.mark.parametrize('model_name', ["bert-base-uncased"])
def test_bert_non_optimized(model_name):
    """Check that our implementation of BERT (without any optimizations enabled) matches the
    HF implementation: the output of our forward pass in fp16 should be around the same as the HF
    forward pass in fp16, when compared to the HF forward pass in fp32.
    """
    dtype = torch.float16
    config = BertConfig.from_pretrained(model_name)
    # Our implementation assumes the activation is nn.GELU(approximate='tanh')
    # Huggingface calls it "gelu_new" or "gelu_fast".
    config.hidden_act = "gelu_new"

    model = BertForPreTraining.from_pretrained(model_name, config)
    model = model.cuda().to(dtype=dtype)

    model_ref = get_hf_models(model_name, config, torch.float32)
    model_hf = get_hf_models(model_name, config, torch.float16)

    model.eval()
    model_ref.eval()
    model_hf.eval()

    torch.manual_seed(0)
    batch_size = 4
    max_seqlen = 512
    seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device='cuda')
    attention_mask = torch.arange(max_seqlen, device='cuda')[None, :] < seqlens[:, None]
    input_ids = torch.randint(0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long,
                              device='cuda')

    sequence_output, pooled_output = model.bert(input_ids, attention_mask=attention_mask)
    out_hf = model_hf.bert(input_ids, attention_mask=attention_mask)
    sequence_output_hf, pooled_output_hf = out_hf.last_hidden_state, out_hf.pooler_output
    out_ref = model_ref.bert(input_ids, attention_mask=attention_mask)
    sequence_output_ref, pooled_output_ref = out_ref.last_hidden_state, out_ref.pooler_output

    print(f'Output max diff: {(sequence_output - sequence_output_ref).abs().max().item()}')
    print(f'Output mean diff: {(sequence_output - sequence_output_ref).abs().mean().item()}')
    print(f'HF fp16 max diff: {(sequence_output_hf - sequence_output_ref).abs().max().item()}')
    print(f'HF fp16 mean diff: {(sequence_output_hf - sequence_output_ref).abs().mean().item()}')
    assert (sequence_output - sequence_output_ref).abs().max().item() < 2 * (
        sequence_output_hf - sequence_output_ref).abs().max().item()
    assert (pooled_output - pooled_output_ref).abs().max().item() < 2 * (
        pooled_output_hf - pooled_output_ref).abs().max().item()
@pytest.mark.parametrize('model_name', ["bert-base-uncased", "bert-large-uncased"])
# @pytest.mark.parametrize('model_name', ["bert-base-uncased"])
def test_bert_optimized(model_name):
    """Check that our implementation of BERT (with all optimizations enabled) matches the
    HF implementation: the output of our forward pass in fp16 should be around the same as the HF
    forward pass in fp16, when compared to the HF forward pass in fp32.
    """
    dtype = torch.float16
    config = BertConfig.from_pretrained(model_name)
    # Our implementation assumes the activation is nn.GELU(approximate='tanh')
    # Huggingface calls it "gelu_new" or "gelu_fast".
    config.hidden_act = "gelu_new"
    config.use_flash_attn = True
    config.fused_bias_fc = True
    config.fused_dense_gelu_dense = True
    config.fused_dropout_add_ln = True

    model = BertForPreTraining.from_pretrained(model_name, config)
    model = model.cuda().to(dtype=dtype)

    model_ref = get_hf_models(model_name, config, torch.float32)
    model_hf = get_hf_models(model_name, config, torch.float16)

    model.eval()
    model_ref.eval()
    model_hf.eval()

    torch.manual_seed(0)
    batch_size = 4
    max_seqlen = 512
    seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device='cuda')
    attention_mask = torch.arange(max_seqlen, device='cuda')[None, :] < seqlens[:, None]
    input_ids = torch.randint(0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long,
                              device='cuda')

    sequence_output, pooled_output = model.bert(input_ids, attention_mask=attention_mask)
    out_hf = model_hf.bert(input_ids, attention_mask=attention_mask)
    sequence_output_hf, pooled_output_hf = out_hf.last_hidden_state, out_hf.pooler_output
    # Need to zero out the padded tokens in the sequence before comparison.
    sequence_output_hf[~attention_mask, :] = 0.0
    out_ref = model_ref.bert(input_ids, attention_mask=attention_mask)
    sequence_output_ref, pooled_output_ref = out_ref.last_hidden_state, out_ref.pooler_output
    sequence_output_ref[~attention_mask, :] = 0.0

    print(f'BertModel output max diff: {(sequence_output - sequence_output_ref).abs().max().item()}')
    print(f'BertModel output mean diff: {(sequence_output - sequence_output_ref).abs().mean().item()}')
    print(f'HF fp16 BertModel max diff: {(sequence_output_hf - sequence_output_ref).abs().max().item()}')
    print(f'HF fp16 BertModel mean diff: {(sequence_output_hf - sequence_output_ref).abs().mean().item()}')
    assert (sequence_output - sequence_output_ref).abs().max().item() < 4 * (
        sequence_output_hf - sequence_output_ref).abs().max().item()
    assert (pooled_output - pooled_output_ref).abs().max().item() < 4 * (
        pooled_output_hf - pooled_output_ref).abs().max().item()

    prediction_scores, seq_relationship_scores = model(input_ids, attention_mask=attention_mask)
    # Need to zero out the padded tokens in the sequence before comparison.
    prediction_scores = prediction_scores.clone()
    prediction_scores[~attention_mask, :] = 0.0
    out_hf = model_hf(input_ids, attention_mask=attention_mask)
    prediction_scores_hf, seq_relationship_scores_hf = (out_hf.prediction_logits,
                                                        out_hf.seq_relationship_logits)
    prediction_scores_hf[~attention_mask, :] = 0.0
    out_ref = model_ref(input_ids, attention_mask=attention_mask)
    prediction_scores_ref, seq_relationship_scores_ref = (out_ref.prediction_logits,
                                                          out_ref.seq_relationship_logits)
    prediction_scores_ref[~attention_mask, :] = 0.0

    print(f'prediction_scores max diff: {(prediction_scores - prediction_scores_ref).abs().max().item()}')
    print(f'prediction_scores mean diff: {(prediction_scores - prediction_scores_ref).abs().mean().item()}')
    print(f'HF fp16 prediction_scores max diff: {(prediction_scores_hf - prediction_scores_ref).abs().max().item()}')
    print(f'HF fp16 prediction_scores mean diff: {(prediction_scores_hf - prediction_scores_ref).abs().mean().item()}')
    assert (prediction_scores - prediction_scores_ref).abs().max().item() < 2 * (
        prediction_scores_hf - prediction_scores_ref).abs().max().item()
    assert (seq_relationship_scores - seq_relationship_scores_ref).abs().max().item() < 2 * (
        seq_relationship_scores_hf - seq_relationship_scores_ref).abs().max().item()
@pytest.mark.parametrize('model_name', ["bert-base-uncased", "bert-large-uncased"])
# @pytest.mark.parametrize('model_name', ["bert-base-uncased"])
def test_bert_dense_seq_output(model_name):
    """Check that our implementation of BERT (with all optimizations enabled) matches the
    HF implementation: the output of our forward pass in fp16 should be around the same as the HF
    forward pass in fp16, when compared to the HF forward pass in fp32.
    """
    dtype = torch.float16
    config = BertConfig.from_pretrained(model_name)
    # Our implementation assumes the activation is nn.GELU(approximate='tanh')
    # Huggingface calls it "gelu_new" or "gelu_fast".
    config.hidden_act = "gelu_new"
    config.use_flash_attn = True
    config.fused_bias_fc = True
    config.fused_dense_gelu_dense = True
    config.fused_dropout_add_ln = True
    config.dense_seq_output = True
    config.use_xentropy = True

    model = BertForPreTraining.from_pretrained(model_name, config)
    model = model.cuda().to(dtype=dtype)

    model_ref = get_hf_models(model_name, config, torch.float32)
    model_hf = get_hf_models(model_name, config, torch.float16)

    model.eval()
    model_ref.eval()
    model_hf.eval()

    torch.manual_seed(0)
    batch_size = 4
    max_seqlen = 512
    seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device='cuda')
    attention_mask = torch.arange(max_seqlen, device='cuda')[None, :] < seqlens[:, None]
    input_ids = torch.randint(0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long,
                              device='cuda')
    labels = torch.randint(0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long,
                           device='cuda')
    labels[(torch.rand(batch_size, max_seqlen, device='cuda') < 0.15) | ~attention_mask] = 0
    masked_tokens_mask = labels.flatten() > 0
    next_sequence_label = torch.randint(0, 2, (batch_size,), device='cuda')

    total_loss, prediction_scores, seq_relationship_scores, _, _ = model(
        input_ids, attention_mask=attention_mask, labels=labels,
        next_sentence_label=next_sequence_label
    )
    out_hf = model_hf(input_ids, attention_mask=attention_mask, labels=labels,
                      next_sentence_label=next_sequence_label)
    prediction_scores_hf, seq_relationship_scores_hf = (out_hf.prediction_logits,
                                                        out_hf.seq_relationship_logits)
    prediction_scores_hf = rearrange(prediction_scores_hf, 'b s d -> (b s) d')[masked_tokens_mask]
    out_ref = model_ref(input_ids, attention_mask=attention_mask, labels=labels,
                        next_sentence_label=next_sequence_label)
    prediction_scores_ref, seq_relationship_scores_ref = (out_ref.prediction_logits,
                                                          out_ref.seq_relationship_logits)
    prediction_scores_ref = rearrange(prediction_scores_ref, 'b s d -> (b s) d')[masked_tokens_mask]

    print(f'prediction_scores max diff: {(prediction_scores - prediction_scores_ref).abs().max().item()}')
    print(f'prediction_scores mean diff: {(prediction_scores - prediction_scores_ref).abs().mean().item()}')
    print(f'HF fp16 prediction_scores max diff: {(prediction_scores_hf - prediction_scores_ref).abs().max().item()}')
    print(f'HF fp16 prediction_scores mean diff: {(prediction_scores_hf - prediction_scores_ref).abs().mean().item()}')
    assert (prediction_scores - prediction_scores_ref).abs().max().item() < 2 * (
        prediction_scores_hf - prediction_scores_ref).abs().max().item()