gaoqiong / flash-attention

Commit 13cdceb3, authored Dec 19, 2022 by Tri Dao

    Implement last_layer_subset optimization for BERT

Parent: 5fb6df0e
Showing 4 changed files, with 262 additions and 102 deletions:

- flash_attn/models/bert.py   (+151, -44)
- flash_attn/modules/mha.py   (+70, -34)
- flash_attn/modules/mlp.py   (+6, -5)
- tests/models/test_bert.py   (+35, -19)
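The change wires a `last_layer_subset` option through the BERT stack: when pre-training with `dense_seq_output`, the last transformer layer runs as cross-attention over only the masked tokens (plus the CLS position), so the final layer's compute scales with the number of masked positions rather than the full sequence. A minimal sketch of how the option is enabled, following the config flags exercised in `tests/models/test_bert.py` in this commit (the exact set of supported flags is defined by the config handling in `bert.py`):

```python
# Sketch only: config flags as exercised by tests/models/test_bert.py in this commit.
from transformers import BertConfig
from flash_attn.models.bert import BertForPreTraining

config = BertConfig.from_pretrained("bert-base-uncased")
config.use_flash_attn = True
config.fused_bias_fc = True
config.fused_dense_gelu_dense = True
config.fused_dropout_add_ln = True
config.dense_seq_output = True   # only compute the MLM head on masked positions
config.last_layer_subset = True  # new: last layer attends only from the subset of tokens
config.use_xentropy = True

model = BertForPreTraining.from_pretrained("bert-base-uncased", config).cuda().half()
```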
flash_attn/models/bert.py:

```diff
@@ -17,6 +17,8 @@ import torch.nn as nn
 import torch.nn.functional as F

 from transformers import BertConfig
+from transformers.models.bert.modeling_bert import BaseModelOutputWithPoolingAndCrossAttentions
+from transformers.models.bert.modeling_bert import BertForPreTrainingOutput
 from einops import rearrange
@@ -24,7 +26,8 @@ from flash_attn.modules.mha import MHA
 from flash_attn.modules.mlp import Mlp, FusedDenseGeluDense
 from flash_attn.modules.block import Block
 from flash_attn.modules.embedding import BertEmbeddings
-from flash_attn.bert_padding import unpad_input, pad_input, index_first_axis
+from flash_attn.bert_padding import unpad_input, pad_input
+from flash_attn.bert_padding import index_first_axis, index_first_axis_residual

 try:
     from flash_attn.ops.fused_dense import FusedDenseTD
```
```diff
@@ -45,21 +48,27 @@ except ImportError:
 logger = logging.getLogger(__name__)


-def create_mixer_cls(config):
+def create_mixer_cls(config, cross_attn=False, return_residual=False):
     use_flash_attn = getattr(config, 'use_flash_attn', False)
     fused_bias_fc = getattr(config, 'fused_bias_fc', False)
-    mixer_cls = partial(MHA, num_heads=config.num_attention_heads,
+    mixer_cls = partial(MHA, num_heads=config.num_attention_heads, cross_attn=cross_attn,
                         dropout=config.attention_probs_dropout_prob, causal=False,
-                        fused_bias_fc=fused_bias_fc, use_flash_attn=use_flash_attn)
+                        fused_bias_fc=fused_bias_fc, use_flash_attn=use_flash_attn,
+                        return_residual=return_residual)
     return mixer_cls


-def create_mlp_cls(config, layer_idx=None):
+def create_mlp_cls(config, layer_idx=None, return_residual=False):
     inner_dim = config.intermediate_size
     fused_dense_gelu_dense = getattr(config, 'fused_dense_gelu_dense', False)
     if fused_dense_gelu_dense:
         assert config.hidden_act in ['gelu_new', 'gelu_fast'], ('fused_dense_gelu_dense only '
                                                                 'supports approximate gelu')
     if not fused_dense_gelu_dense:
+        approximate = 'tanh' if config.hidden_act in ['gelu_new', 'gelu_fast'] else 'none'
         mlp_cls = partial(Mlp, hidden_features=inner_dim,
-                          activation=partial(F.gelu, approximate='tanh'))
+                          activation=partial(F.gelu, approximate=approximate),
+                          return_residual=return_residual)
     else:
         mlp_checkpoint_lvl = getattr(config, 'mlp_checkpoint_lvl', 0)
         # mlp_checkpoint_lvl could be a list, which contains the checkpoint_lvl for each layer
```
```diff
@@ -67,17 +76,24 @@ def create_mlp_cls(config, layer_idx=None):
             assert layer_idx is not None
             mlp_checkpoint_lvl = mlp_checkpoint_lvl[layer_idx]
         mlp_cls = partial(FusedDenseGeluDense, hidden_features=inner_dim,
-                          checkpoint_lvl=mlp_checkpoint_lvl)
+                          checkpoint_lvl=mlp_checkpoint_lvl, return_residual=return_residual)
     return mlp_cls


 def create_block(config, layer_idx=None):
-    mixer_cls = create_mixer_cls(config)
-    mlp_cls = create_mlp_cls(config, layer_idx)
+    last_layer_subset = getattr(config, 'last_layer_subset', False)
+    cross_attn = last_layer_subset and layer_idx == config.num_hidden_layers - 1
+    # TD [2022-12-19]: For cross attention (last layer), we actually want to return the
+    # residual x_kv, not residual x. But it's annoying to change the API (and it only affects
+    # one layer) so we just choose not to return residual in this case.
+    return_residual = not cross_attn
+    mixer_cls = create_mixer_cls(config, cross_attn, return_residual=return_residual)
+    mlp_cls = create_mlp_cls(config, layer_idx, return_residual=return_residual)
     norm_cls = partial(nn.LayerNorm, eps=config.layer_norm_eps)
     block = Block(config.hidden_size, mixer_cls, mlp_cls, norm_cls=norm_cls,
                   prenorm=False, resid_dropout=config.hidden_dropout_prob,
-                  fused_dropout_add_ln=getattr(config, 'fused_dropout_add_ln', False))
+                  fused_dropout_add_ln=getattr(config, 'fused_dropout_add_ln', False),
+                  return_residual=return_residual)
     return block
```
```diff
@@ -101,21 +117,49 @@ class BertEncoder(nn.Module):
         self.layers = nn.ModuleList([create_block(config, layer_idx=i)
                                      for i in range(config.num_hidden_layers)])

-    def forward(self, hidden_states, key_padding_mask=None):
+    def forward(self, hidden_states, key_padding_mask=None, subset_mask=None):
+        """If subset_mask is not None, we only want output for the subset of the sequence.
+        This means that we only compute the last layer output for these tokens.
+        subset_mask: (batch, seqlen), dtype=torch.bool
+        """
         if key_padding_mask is None or not self.use_flash_attn:
             mixer_kwargs = ({'key_padding_mask': key_padding_mask}
                             if key_padding_mask is not None else None)
             for layer in self.layers:
                 hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
+            if subset_mask is not None:
+                hidden_states = hidden_states[subset_mask]
         else:
             batch, seqlen = hidden_states.shape[:2]
             hidden_states, indices, cu_seqlens, max_seqlen_in_batch = unpad_input(
                 hidden_states, key_padding_mask
             )
             mixer_kwargs = {'cu_seqlens': cu_seqlens, 'max_seqlen': max_seqlen_in_batch}
-            for layer in self.layers:
-                hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
-            hidden_states = pad_input(hidden_states, indices, batch, seqlen)
+            if subset_mask is None:
+                for layer in self.layers:
+                    hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
+                hidden_states = pad_input(hidden_states, indices, batch, seqlen)
+            else:
+                for layer in self.layers[:-1]:
+                    hidden_states = layer(hidden_states, mixer_kwargs=mixer_kwargs)
+                if key_padding_mask is not None:
+                    subset_idx = torch.nonzero(subset_mask[key_padding_mask],
+                                               as_tuple=False).flatten()
+                    subset_seqlens = (subset_mask & key_padding_mask).sum(dim=-1, dtype=torch.int32)
+                    subset_cu_seqlens = F.pad(torch.cumsum(subset_seqlens, dim=0,
+                                                           dtype=torch.torch.int32), (1, 0))
+                else:
+                    subset_idx = torch.nonzero(subset_mask, as_tuple=False).flatten()
+                    subset_seqlens = subset_mask.sum(dim=-1, dtype=torch.int32)
+                    subset_cu_seqlens = F.pad(torch.cumsum(subset_seqlens, dim=0,
+                                                           dtype=torch.torch.int32), (1, 0))
+                hidden_states_subset, hidden_states = index_first_axis_residual(
+                    hidden_states, subset_idx
+                )
+                # It's ok to set max_seqlen_q to be much larger
+                mixer_kwargs = {'x_kv': hidden_states,
+                                'cu_seqlens': subset_cu_seqlens, 'max_seqlen': max_seqlen_in_batch,
+                                'cu_seqlens_k': cu_seqlens, 'max_seqlen_k': max_seqlen_in_batch}
+                hidden_states = self.layers[-1](hidden_states_subset, mixer_kwargs=mixer_kwargs)
         return hidden_states
```
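The subset path above feeds the last layer varlen cross-attention arguments in the FlashAttention `cu_seqlens` format: per-example token counts turned into a cumulative, zero-prefixed int32 tensor. A small illustrative example of that construction (the counts are made up, not from the commit):

```python
import torch
import torch.nn.functional as F

# Suppose 3 sequences keep 2, 5 and 3 masked tokens respectively.
subset_seqlens = torch.tensor([2, 5, 3], dtype=torch.int32)
subset_cu_seqlens = F.pad(torch.cumsum(subset_seqlens, dim=0, dtype=torch.int32), (1, 0))
print(subset_cu_seqlens)  # tensor([ 0,  2,  7, 10], dtype=torch.int32)
# Row i of the flattened subset occupies slots cu_seqlens[i]:cu_seqlens[i+1].
```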
```diff
@@ -151,7 +195,8 @@ class BertPredictionHeadTransform(nn.Module):
             raise ImportError('dropout_add_layer_norm is not installed')
         linear_cls = nn.Linear if not fused_bias_fc else FusedDenseTD
         self.dense = linear_cls(config.hidden_size, config.hidden_size)
-        self.transform_act_fn = nn.GELU(approximate='tanh')
+        approximate = 'tanh' if config.hidden_act in ['gelu_new', 'gelu_fast'] else 'none'
+        self.transform_act_fn = nn.GELU(approximate=approximate)
         self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
```
```diff
@@ -264,6 +309,11 @@ class BertModel(BertPreTrainedModel):
-    def forward(self, input_ids, position_ids=None, token_type_ids=None, attention_mask=None):
+    def forward(self, input_ids, position_ids=None, token_type_ids=None, attention_mask=None,
+                masked_tokens_mask=None):
+        """If masked_tokens_mask is not None (i.e. last_layer_subset == True in BertForPreTraining),
+        we only want the output for the masked tokens. This means that we only compute the last
+        layer output for these tokens.
+        masked_tokens_mask: (batch, seqlen), dtype=torch.bool
+        """
         hidden_states = self.embeddings(input_ids, position_ids=position_ids,
                                         token_type_ids=token_type_ids)
         # TD [2022-12:18]: Don't need to force residual in fp32
```
```diff
@@ -275,9 +325,38 @@ class BertModel(BertPreTrainedModel):
                 hidden_states, None, self.emb_ln.weight, self.emb_ln.bias,
                 self.emb_drop.p if self.training else 0.0, self.emb_ln.eps, prenorm=False,
             )
-        sequence_output = self.encoder(hidden_states, key_padding_mask=attention_mask)
-        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
-        return sequence_output, pooled_output
+
+        if masked_tokens_mask is not None:
+            batch_size, seqlen = input_ids.shape[:2]
+            # We also need the first column for the CLS token
+            first_col_mask = torch.zeros(batch_size, seqlen, dtype=torch.bool,
+                                         device=input_ids.device)
+            first_col_mask[:, 0] = True
+            subset_mask = masked_tokens_mask | first_col_mask
+        else:
+            subset_mask = None
+
+        sequence_output = self.encoder(hidden_states, key_padding_mask=attention_mask,
+                                       subset_mask=subset_mask)
+
+        if masked_tokens_mask is None:
+            pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
+        else:
+            # TD [2022-03-01]: the indexing here is very tricky.
+            if attention_mask is not None:
+                subset_idx = subset_mask[attention_mask]
+                pool_input = sequence_output[first_col_mask[attention_mask][subset_idx]]
+                sequence_output = sequence_output[masked_tokens_mask[attention_mask][subset_idx]]
+            else:
+                pool_input = sequence_output[first_col_mask[subset_mask]]
+                sequence_output = sequence_output[masked_tokens_mask[subset_mask]]
+            pooled_output = (self.pooler(pool_input, pool=False)
+                             if self.pooler is not None else None)
+
+        return BaseModelOutputWithPoolingAndCrossAttentions(
+            last_hidden_state=sequence_output,
+            pooler_output=pooled_output,
+        )


 class BertForPreTraining(BertPreTrainedModel):
```
```diff
@@ -290,11 +369,13 @@ class BertForPreTraining(BertPreTrainedModel):
         # If last_layer_subset, we only need the compute the last layer for a subset of tokens
         # (e.g., the tokens we need to compute the masked LM loss and the next-sentence prediction).
         self.last_layer_subset = getattr(config, 'last_layer_subset', False)
-        assert not self.last_layer_subset, 'last_layer_subset is not implemented yet'
+        if self.last_layer_subset:
+            assert self.dense_seq_output, 'last_layer_subset requires dense_seq_output'
         use_xentropy = getattr(config, 'use_xentropy', False)
         if use_xentropy and CrossEntropyLossApex is None:
             raise ImportError('xentropy_cuda is not installed')
-        loss_cls = nn.CrossEntropyLoss if not use_xentropy else CrossEntropyLossApex
+        loss_cls = (nn.CrossEntropyLoss if not use_xentropy
+                    else partial(CrossEntropyLossApex, inplace_backward=True))

         self.bert = BertModel(config)
         self.cls = BertPreTrainingHeads(config)
@@ -311,6 +392,8 @@ class BertForPreTraining(BertPreTrainedModel):
     def forward(self, input_ids, position_ids=None, token_type_ids=None, attention_mask=None,
                 labels=None, next_sentence_label=None):
         """
        If labels are provided, they must be 0 for masked out tokens (as specified in the attention
        mask).
        Outputs:
            if `labels` and `next_sentence_label` are not `None`:
                Outputs the total_loss which is the sum of the masked language modeling loss and the next
```
```diff
@@ -322,10 +405,12 @@ class BertForPreTraining(BertPreTrainedModel):
         """
+        masked_tokens_mask = labels > 0 if (self.last_layer_subset and labels is not None) else None
-        sequence_output, pooled_output = self.bert(
+        outputs = self.bert(
             input_ids, position_ids=position_ids, token_type_ids=token_type_ids,
-            attention_mask=attention_mask.bool(),
+            attention_mask=attention_mask.bool() if attention_mask is not None else None,
+            masked_tokens_mask=masked_tokens_mask
         )
+        sequence_output, pooled_output = outputs.last_hidden_state, outputs.pooler_output
         if self.dense_seq_output and labels is not None:
             masked_token_idx = torch.nonzero(labels.flatten() > 0, as_tuple=False).flatten()
+            if not self.last_layer_subset:
@@ -333,8 +418,9 @@ class BertForPreTraining(BertPreTrainedModel):
                                                masked_token_idx)
         prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)

         total_loss = None
         if labels is not None and next_sentence_label is not None:
-            if masked_token_idx is not None:  # prediction_scores are already flattened
+            if self.dense_seq_output and labels is not None:
+                # prediction_scores are already flattened
                 masked_lm_loss = self.mlm_loss(prediction_scores,
                                                labels.flatten()[masked_token_idx])
             else:
```
```diff
@@ -342,22 +428,13 @@ class BertForPreTraining(BertPreTrainedModel):
                                           rearrange(labels, '... -> (...)'))
             next_sentence_loss = self.nsp_loss(rearrange(seq_relationship_score, '... t -> (...) t'),
                                                rearrange(next_sentence_label, '... -> (...)'))
-            total_loss = (masked_lm_loss + next_sentence_loss).float()
-            # Masked Language Model Accuracy
-            masked_lm_labels_flat = labels.view(-1)
-            mlm_labels = masked_lm_labels_flat[masked_lm_labels_flat != 0]
-            if not self.dense_seq_output:
-                prediction_scores_flat = rearrange(prediction_scores, '... v -> (...) v')
-                mlm_predictions_scores = prediction_scores_flat[masked_lm_labels_flat != 0]
-                mlm_predictions = mlm_predictions_scores.argmax(dim=-1)
-            else:
-                mlm_predictions = prediction_scores.argmax(dim=-1)
-            mlm_acc = (mlm_predictions == mlm_labels).sum(dtype=torch.float) / mlm_labels.numel()
-            return total_loss, prediction_scores, seq_relationship_score, mlm_acc, mlm_labels.numel()
-        else:
-            return prediction_scores, seq_relationship_score
+            total_loss = masked_lm_loss.float() + next_sentence_loss.float()
+
+        return BertForPreTrainingOutput(
+            loss=total_loss,
+            prediction_logits=prediction_scores,
+            seq_relationship_logits=seq_relationship_score,
+        )


 def state_dict_from_pretrained(model_name):
```
```diff
@@ -401,6 +478,7 @@ def remap_state_dict(state_dict, config):
     state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())

     # Attention
+    last_layer_subset = getattr(config, 'last_layer_subset', False)
     for d in range(config.num_hidden_layers):
         Wq = state_dict.pop(f'bert.encoder.layers.{d}.attention.self.query.weight')
         Wk = state_dict.pop(f'bert.encoder.layers.{d}.attention.self.key.weight')
@@ -408,12 +486,22 @@ def remap_state_dict(state_dict, config):
         bq = state_dict.pop(f'bert.encoder.layers.{d}.attention.self.query.bias')
         bk = state_dict.pop(f'bert.encoder.layers.{d}.attention.self.key.bias')
         bv = state_dict.pop(f'bert.encoder.layers.{d}.attention.self.value.bias')
-        state_dict[f'bert.encoder.layers.{d}.mixer.Wqkv.weight'] = torch.cat([Wq, Wk, Wv], dim=0)
-        state_dict[f'bert.encoder.layers.{d}.mixer.Wqkv.bias'] = torch.cat([bq, bk, bv], dim=0)
+        if not (last_layer_subset and d == config.num_hidden_layers - 1):
+            state_dict[f'bert.encoder.layers.{d}.mixer.Wqkv.weight'] = torch.cat([Wq, Wk, Wv], dim=0)
+            state_dict[f'bert.encoder.layers.{d}.mixer.Wqkv.bias'] = torch.cat([bq, bk, bv], dim=0)
+        else:
+            state_dict[f'bert.encoder.layers.{d}.mixer.Wq.weight'] = Wq
+            state_dict[f'bert.encoder.layers.{d}.mixer.Wkv.weight'] = torch.cat([Wk, Wv], dim=0)
+            state_dict[f'bert.encoder.layers.{d}.mixer.Wq.bias'] = bq
+            state_dict[f'bert.encoder.layers.{d}.mixer.Wkv.bias'] = torch.cat([bk, bv], dim=0)

     def key_mapping_attn(key):
         return re.sub(r'^bert.encoder.layers.(\d+).attention.output.dense.(weight|bias)',
                       r'bert.encoder.layers.\1.mixer.out_proj.\2',
                       key)
@@ -423,4 +511,23 @@ def remap_state_dict(state_dict, config):
         return re.sub(r'^cls.predictions.bias', 'cls.predictions.decoder.bias', key)
     state_dict = OrderedDict((key_mapping_decoder_bias(k), v) for k, v in state_dict.items())

+    # Word embedding
+    pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
+    if pad_vocab_size_multiple > 1:
+        word_embeddings = state_dict['bert.embeddings.word_embeddings.weight']
+        state_dict['bert.embeddings.word_embeddings.weight'] = F.pad(
+            word_embeddings, (0, 0, 0, config.vocab_size - word_embeddings.shape[0])
+        )
+        decoder_weight = state_dict['cls.predictions.decoder.weight']
+        state_dict['cls.predictions.decoder.weight'] = F.pad(
+            decoder_weight, (0, 0, 0, config.vocab_size - decoder_weight.shape[0])
+        )
+        # If the vocab was padded, we want to set the decoder bias for those padded indices to be
+        # strongly negative (i.e. the decoder shouldn't predict those indices).
+        # TD [2022-05-09]: I don't think it affects the MLPerf training.
+        decoder_bias = state_dict['cls.predictions.decoder.bias']
+        state_dict['cls.predictions.decoder.bias'] = F.pad(
+            decoder_bias, (0, config.vocab_size - decoder_bias.shape[0]), value=-100.0
+        )
+
     return state_dict
```
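With `last_layer_subset`, the final layer is a cross-attention mixer, so its checkpoint weights are remapped into separate `Wq` and `Wkv` projections instead of the packed `Wqkv` used by the other layers. A rough shape sketch of what that remapping produces (the hidden size is hypothetical, chosen only for illustration):

```python
import torch

h = 768  # hypothetical hidden size
Wq, Wk, Wv = (torch.randn(h, h) for _ in range(3))

Wqkv = torch.cat([Wq, Wk, Wv], dim=0)  # packed self-attention weight, shape (3*h, h)
Wkv = torch.cat([Wk, Wv], dim=0)       # split cross-attention weight, shape (2*h, h)
print(Wqkv.shape, Wkv.shape)           # torch.Size([2304, 768]) torch.Size([1536, 768])
```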
flash_attn/modules/mha.py:

```diff
@@ -120,34 +120,55 @@ class FlashCrossAttention(nn.Module):
         self.dropout_p = attention_dropout
         self.triton = triton

-    def forward(self, q, kv):
+    def forward(self, q, kv, cu_seqlens=None, max_seqlen=None, cu_seqlens_k=None,
+                max_seqlen_k=None):
         """Implements the multihead softmax attention.
         Arguments
         ---------
             q: The tensor containing the query. (B, Sq, H, D)
             kv: The tensor containing the key and value. (B, Sk, 2, H, D)
+            cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
+                of the sequences in the batch, used to index into q.
+            max_seqlen: int. Maximum sequence length in the batch of q.
+            cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
+                of the sequences in the batch, used to index into kv.
+            max_seqlen_k: int. Maximum sequence length in the batch of k and v.
         """
         assert q.dtype in [torch.float16, torch.bfloat16]
         assert q.is_cuda and kv.is_cuda
-        batch_size, seqlen_q = q.shape[0], q.shape[1]
-        seqlen_k = kv.shape[1]
-        assert kv.shape[0] == batch_size and kv.shape[3] == q.shape[2] and kv.shape[4] == q.shape[3]
-        if self.triton and (self.dropout_p == 0.0 or not self.training):
-            # Triton version doesn't support dropout
-            output = flash_attn_kvpacked_func(q, kv, None, self.causal, self.softmax_scale)
+        unpadded = cu_seqlens is not None
+        if unpadded:
+            assert cu_seqlens.dtype == torch.int32
+            assert max_seqlen is not None
+            assert isinstance(max_seqlen, int)
+            assert cu_seqlens_k is not None
+            assert cu_seqlens_k.dtype == torch.int32
+            assert max_seqlen_k is not None
+            assert isinstance(max_seqlen, int)
+            return flash_attn_unpadded_kvpacked_func(
+                q, kv, cu_seqlens, cu_seqlens_k, max_seqlen, max_seqlen_k,
+                self.dropout_p if self.training else 0.0,
+                softmax_scale=self.softmax_scale, causal=self.causal
+            )
         else:
-            q = rearrange(q, 'b s ... -> (b s) ...')
-            kv = rearrange(kv, 'b s ... -> (b s) ...')
-            cu_seqlens_q = torch.arange(0, (batch_size + 1) * seqlen_q, step=seqlen_q,
-                                        dtype=torch.int32, device=q.device)
-            cu_seqlens_k = torch.arange(0, (batch_size + 1) * seqlen_k, step=seqlen_k,
-                                        dtype=torch.int32, device=kv.device)
-            output = flash_attn_unpadded_kvpacked_func(
-                q, kv, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen_k,
-                self.dropout_p if self.training else 0.0,
-                softmax_scale=self.softmax_scale, causal=self.causal
-            )
-            output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
-        return output
+            batch_size, seqlen_q = q.shape[0], q.shape[1]
+            seqlen_k = kv.shape[1]
+            assert kv.shape[0] == batch_size and kv.shape[3] == q.shape[2] and kv.shape[4] == q.shape[3]
+            if self.triton and (self.dropout_p == 0.0 or not self.training):
+                # Triton version doesn't support dropout
+                output = flash_attn_kvpacked_func(q, kv, None, self.causal, self.softmax_scale)
+            else:
+                q = rearrange(q, 'b s ... -> (b s) ...')
+                kv = rearrange(kv, 'b s ... -> (b s) ...')
+                cu_seqlens_q = torch.arange(0, (batch_size + 1) * seqlen_q, step=seqlen_q,
+                                            dtype=torch.int32, device=q.device)
+                cu_seqlens_k = torch.arange(0, (batch_size + 1) * seqlen_k, step=seqlen_k,
+                                            dtype=torch.int32, device=kv.device)
+                output = flash_attn_unpadded_kvpacked_func(
+                    q, kv, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen_k,
+                    self.dropout_p if self.training else 0.0,
+                    softmax_scale=self.softmax_scale, causal=self.causal
+                )
+                output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
+            return output


 class SelfAttention(nn.Module):
```
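For the varlen path added above, q and kv are packed as (total_tokens, nheads, headdim) tensors described by cumulative sequence lengths. A hedged sketch of what a caller would have to provide (the tensor sizes are made up; the real call site is `BertEncoder.forward` in this commit):

```python
import torch
import torch.nn.functional as F

# Hypothetical sizes, for illustration only; requires a CUDA device like the rest of the module.
nheads, headdim = 12, 64
seqlens_q = torch.tensor([2, 5, 3], dtype=torch.int32, device='cuda')       # subset tokens per example
seqlens_k = torch.tensor([128, 96, 110], dtype=torch.int32, device='cuda')  # full unpadded lengths

cu_seqlens = F.pad(torch.cumsum(seqlens_q, 0, dtype=torch.int32), (1, 0))
cu_seqlens_k = F.pad(torch.cumsum(seqlens_k, 0, dtype=torch.int32), (1, 0))

q = torch.randn(int(seqlens_q.sum()), nheads, headdim, device='cuda', dtype=torch.float16)
kv = torch.randn(int(seqlens_k.sum()), 2, nheads, headdim, device='cuda', dtype=torch.float16)

# Roughly, with cross_attn an instance of FlashCrossAttention:
# out = cross_attn(q, kv, cu_seqlens=cu_seqlens, max_seqlen=int(seqlens_q.max()),
#                  cu_seqlens_k=cu_seqlens_k, max_seqlen_k=int(seqlens_k.max()))
```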
```diff
@@ -214,12 +235,14 @@ class CrossAttention(nn.Module):
         self.softmax_scale = softmax_scale
         self.dropout_p = attention_dropout

-    def forward(self, q, kv):
+    def forward(self, q, kv, key_padding_mask=None):
         """Implements the multihead softmax attention.
         Arguments
         ---------
             q: The tensor containing the query. (B, Sq, H, D)
             kv: The tensor containing the key and value. (B, Sk, 2, H, D)
+            key_padding_mask: boolean mask to apply to the attention weights. True means to keep,
+                False means to mask out. (B, Sk)
         """
         batch_size, seqlen_q = q.shape[0], q.shape[1]
         seqlen_k = kv.shape[1]
@@ -227,6 +250,12 @@ class CrossAttention(nn.Module):
         k, v = kv.unbind(dim=2)
         softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
         scores = torch.einsum('bthd,bshd->bhts', q, k * softmax_scale)
+        if key_padding_mask is not None:
+            padding_mask = torch.full((batch_size, seqlen_k), -10000.0, dtype=scores.dtype,
+                                      device=scores.device)
+            padding_mask.masked_fill_(key_padding_mask, 0.0)
+            # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
+            scores = scores + rearrange(padding_mask, 'b s -> b 1 1 s')
         if self.causal:
             # "triu_tril_cuda_template" not implemented for 'BFloat16'
             # So we have to construct the mask in float
```
```diff
@@ -295,9 +324,11 @@ class MHA(nn.Module):
                                             groups=3 * embed_dim)
             inner_attn_cls = FlashSelfAttention if use_flash_attn else SelfAttention
         else:
             # TODO: use the residual linear class for Wq
             self.Wq = linear_cls(embed_dim, embed_dim, bias=bias, **factory_kwargs)
-            self.Wkv = linear_cls(embed_dim, 2 * embed_dim, bias=bias, **factory_kwargs)
+            if not self.return_residual:
+                self.Wkv = linear_cls(embed_dim, 2 * embed_dim, bias=bias, **factory_kwargs)
+            else:
+                self.Wkv = linear_resid_cls(embed_dim, 2 * embed_dim, bias=bias, **factory_kwargs)
             if self.dwconv:
                 self.dwconv_q = nn.Conv1d(embed_dim, embed_dim, kernel_size=3, padding=2,
                                           groups=embed_dim)
@@ -309,7 +340,8 @@ class MHA(nn.Module):
         # output projection always have the bias (for now)
         self.out_proj = linear_cls(embed_dim, embed_dim, **factory_kwargs)

-    def forward(self, x, x_kv=None, cu_seqlens=None, max_seqlen=None, key_padding_mask=None):
+    def forward(self, x, x_kv=None, key_padding_mask=None, cu_seqlens=None, max_seqlen=None,
+                **kwargs):
         """
         Arguments:
             x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if
```
```diff
@@ -327,17 +359,15 @@ class MHA(nn.Module):
             assert max_seqlen is not None
             assert key_padding_mask is None
             assert self.use_flash_attn
-            assert not self.cross_attn, ('Unpadded FlashAttention code path for cross-attention'
-                                         'is not implemented yet')
             assert not self.dwconv
             assert self.rotary_emb_dim == 0
         if key_padding_mask is not None:
             assert cu_seqlens is None
             assert max_seqlen is None
             assert not self.use_flash_attn
-            assert not self.cross_attn, ('Key padding mask code path for cross-attention'
-                                         'is not implemented yet')
+        kwargs = ({'cu_seqlens': cu_seqlens, 'max_seqlen': max_seqlen, **kwargs}
+                  if self.use_flash_attn else {'key_padding_mask': key_padding_mask, **kwargs})
         if not self.cross_attn:
             if not self.return_residual:
                 qkv = self.Wqkv(x)
@@ -349,24 +379,30 @@ class MHA(nn.Module):
             qkv = rearrange(qkv, '... (three h d) -> ... three h d', three=3, h=self.num_heads)
             if self.rotary_emb_dim > 0:
                 qkv = self.rotary_emb(qkv)
-            extra_kwargs = ({'cu_seqlens': cu_seqlens, 'max_seqlen': max_seqlen}
-                            if self.use_flash_attn else {'key_padding_mask': key_padding_mask})
             if not self.checkpointing:
-                context = self.inner_attn(qkv, **extra_kwargs)
+                context = self.inner_attn(qkv, **kwargs)
             else:
-                context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **extra_kwargs)
+                context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **kwargs)
         else:
-            q = rearrange(self.Wq(x), 'b s (h d) -> b s h d', h=self.num_heads)
-            kv = rearrange(self.Wkv(x if x_kv is None else x_kv),
-                           'b s (two h d) -> b s two h d', two=2, h=self.num_heads)
+            if not self.return_residual:
+                q = self.Wq(x)
+                kv = self.Wkv(x_kv if x_kv is not None else x)
+            else:
+                if x_kv is not None:
+                    kv, x_kv = self.Wkv(x_kv)
+                else:
+                    kv, x = self.Wkv(x)
+                q = self.Wq(x)
+            q = rearrange(q, '... (h d) -> ... h d', h=self.num_heads)
+            kv = rearrange(kv, '... (two h d) -> ... two h d', two=2, h=self.num_heads)
             if self.dwconv:
                 q = rearrange(self.dwconv_q(rearrange(q, 'b s d -> b d s'))[..., :-2],
                               'b d s -> b s d').contiguous()
                 kv = rearrange(self.dwconv_kv(rearrange(kv, 'b s d -> b d s'))[..., :-2],
                                'b d s -> b s d').contiguous()
             if not self.checkpointing:
-                context = self.inner_attn(q, kv)
+                context = self.inner_attn(q, kv, **kwargs)
             else:
-                context = torch.utils.checkpoint.checkpoint(self.inner_attn, q, kv)
+                context = torch.utils.checkpoint.checkpoint(self.inner_attn, q, kv, **kwargs)
         out = self.out_proj(rearrange(context, '... h d -> ... (h d)'))
         return out if not self.return_residual else (out, x)
```
flash_attn/modules/mlp.py:

```diff
@@ -15,20 +15,21 @@ except ImportError:
 class Mlp(nn.Module):

     def __init__(self, in_features, hidden_features=None, out_features=None, activation=F.gelu,
-                 device=None, dtype=None):
+                 return_residual=False, device=None, dtype=None):
         factory_kwargs = {'device': device, 'dtype': dtype}
         super().__init__()
         out_features = out_features or in_features
         hidden_features = hidden_features or in_features
+        self.return_residual = return_residual
         self.fc1 = nn.Linear(in_features, hidden_features, **factory_kwargs)
         self.activation = activation
         self.fc2 = nn.Linear(hidden_features, out_features, **factory_kwargs)

     def forward(self, x):
-        x = self.fc1(x)
-        x = self.activation(x)
-        x = self.fc2(x)
-        return x
+        y = self.fc1(x)
+        y = self.activation(y)
+        y = self.fc2(y)
+        return y if not self.return_residual else (y, x)


 class FusedDenseGeluDense(nn.Module):
```
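The `return_residual` flag lets `Mlp` (and, per the changes above, the MHA mixer as well) hand its input back to the caller together with its output, so the enclosing `Block` can feed the residual into a fused dropout-add-layernorm without recomputing or re-reading it. A minimal sketch (shapes arbitrary):

```python
import torch
from flash_attn.modules.mlp import Mlp

mlp = Mlp(in_features=256, hidden_features=1024, return_residual=True)
x = torch.randn(2, 16, 256)
y, residual = mlp(x)   # residual is the original input x, returned alongside the output
assert residual is x and y.shape == x.shape
```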
tests/models/test_bert.py:

```diff
@@ -53,15 +53,12 @@ def test_bert_non_optimized(model_name):
     """
     dtype = torch.float16
     config = BertConfig.from_pretrained(model_name)
-    # Our implementation assumes the activation is nn.GELU(approximate='tanh')
-    # Huggingface calls it "gelu_new" or "gelu_fast".
-    config.hidden_act = "gelu_new"
     model = BertForPreTraining.from_pretrained(model_name, config)
     model = model.cuda().to(dtype=dtype)

     model_ref = get_hf_models(model_name, config, torch.float32)
-    model_hf = get_hf_models(model_name, config, torch.float16)
+    model_hf = get_hf_models(model_name, config, dtype)

     model.eval()
     model_ref.eval()
```
```diff
@@ -74,7 +71,8 @@ def test_bert_non_optimized(model_name):
     attention_mask = torch.arange(max_seqlen, device='cuda')[None, :] < seqlens[:, None]
     input_ids = torch.randint(0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long,
                               device='cuda')
-    sequence_output, pooled_output = model.bert(input_ids, attention_mask=attention_mask)
+    out = model.bert(input_ids, attention_mask=attention_mask)
+    sequence_output, pooled_output = out.last_hidden_state, out.pooler_output
     out_hf = model_hf.bert(input_ids, attention_mask=attention_mask)
     sequence_output_hf, pooled_output_hf = out_hf.last_hidden_state, out_hf.pooler_output
     out_ref = model_ref.bert(input_ids, attention_mask=attention_mask)
@@ -84,8 +82,8 @@ def test_bert_non_optimized(model_name):
     print(f'Output mean diff: {(sequence_output - sequence_output_ref).abs().mean().item()}')
     print(f'HF fp16 max diff: {(sequence_output_hf - sequence_output_ref).abs().max().item()}')
     print(f'HF fp16 mean diff: {(sequence_output_hf - sequence_output_ref).abs().mean().item()}')
-    assert (sequence_output - sequence_output_ref).abs().max().item() < 2 * (sequence_output_hf - sequence_output_ref).abs().max().item()
-    assert (pooled_output - pooled_output_ref).abs().max().item() < 2 * (pooled_output_hf - pooled_output_ref).abs().max().item()
+    assert (sequence_output - sequence_output_ref).abs().max().item() < 3 * (sequence_output_hf - sequence_output_ref).abs().max().item()
+    assert (pooled_output - pooled_output_ref).abs().max().item() < 3 * (pooled_output_hf - pooled_output_ref).abs().max().item()


 @pytest.mark.parametrize('model_name', ["bert-base-uncased", "bert-large-uncased"])
```
```diff
@@ -97,8 +95,9 @@ def test_bert_optimized(model_name):
     """
     dtype = torch.float16
     config = BertConfig.from_pretrained(model_name)
-    # Our implementation assumes the activation is nn.GELU(approximate='tanh')
-    # Huggingface calls it "gelu_new" or "gelu_fast".
+    # Our implementation of fused_dense_gelu_dense assumes the activation is
+    # nn.GELU(approximate='tanh'). Huggingface calls it "gelu_new" or "gelu_fast".
+    # If you just want "gelu", disable fused_dense_gelu_dense.
     config.hidden_act = "gelu_new"
     config.use_flash_attn = True
     config.fused_bias_fc = True
@@ -109,7 +108,7 @@ def test_bert_optimized(model_name):
     model = model.cuda().to(dtype=dtype)

     model_ref = get_hf_models(model_name, config, torch.float32)
-    model_hf = get_hf_models(model_name, config, torch.float16)
+    model_hf = get_hf_models(model_name, config, dtype)

     model.eval()
     model_ref.eval()
```
```diff
@@ -122,7 +121,8 @@ def test_bert_optimized(model_name):
     attention_mask = torch.arange(max_seqlen, device='cuda')[None, :] < seqlens[:, None]
     input_ids = torch.randint(0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long,
                               device='cuda')
-    sequence_output, pooled_output = model.bert(input_ids, attention_mask=attention_mask)
+    out = model.bert(input_ids, attention_mask=attention_mask)
+    sequence_output, pooled_output = out.last_hidden_state, out.pooler_output
     out_hf = model_hf.bert(input_ids, attention_mask=attention_mask)
     sequence_output_hf, pooled_output_hf = out_hf.last_hidden_state, out_hf.pooler_output
     # Need to zero out the padded tokens in the sequence before comparison.
```
```diff
@@ -138,7 +138,8 @@ def test_bert_optimized(model_name):
     assert (sequence_output - sequence_output_ref).abs().max().item() < 4 * (sequence_output_hf - sequence_output_ref).abs().max().item()
     assert (pooled_output - pooled_output_ref).abs().max().item() < 4 * (pooled_output_hf - pooled_output_ref).abs().max().item()

-    prediction_scores, seq_relationship_scores = model(input_ids, attention_mask=attention_mask)
+    out = model(input_ids, attention_mask=attention_mask)
+    prediction_scores, seq_relationship_scores = out.prediction_logits, out.seq_relationship_logits
     # Need to zero out the padded tokens in the sequence before comparison.
     prediction_scores = prediction_scores.clone()
     prediction_scores[~attention_mask, :] = 0.0
```
```diff
@@ -157,30 +158,36 @@ def test_bert_optimized(model_name):
     assert (seq_relationship_scores - seq_relationship_scores_ref).abs().max().item() < 2 * (seq_relationship_scores_hf - seq_relationship_scores_ref).abs().max().item()


+@pytest.mark.parametrize('last_layer_subset', [False, True])
+# @pytest.mark.parametrize('last_layer_subset', [True])
+@pytest.mark.parametrize('has_key_padding_mask', [True, False])
+# @pytest.mark.parametrize('has_key_padding_mask', [True])
 @pytest.mark.parametrize('model_name', ["bert-base-uncased", "bert-large-uncased"])
 # @pytest.mark.parametrize('model_name', ["bert-base-uncased"])
-def test_bert_dense_seq_output(model_name):
+def test_bert_dense_seq_output(model_name, has_key_padding_mask, last_layer_subset):
     """Check that our implementation of BERT (with all optimizations enabled) matches the
     HF implementation: the output of our forward pass in fp16 should be around the same as the HF
     forward pass in fp16, when compared to the HF forward pass in fp32.
     """
     dtype = torch.float16
     config = BertConfig.from_pretrained(model_name)
-    # Our implementation assumes the activation is nn.GELU(approximate='tanh')
-    # Huggingface calls it "gelu_new" or "gelu_fast".
+    # Our implementation of fused_dense_gelu_dense assumes the activation is
+    # nn.GELU(approximate='tanh'). Huggingface calls it "gelu_new" or "gelu_fast".
+    # If you just want "gelu", disable fused_dense_gelu_dense.
     config.hidden_act = "gelu_new"
     config.use_flash_attn = True
     config.fused_bias_fc = True
     config.fused_dense_gelu_dense = True
     config.fused_dropout_add_ln = True
     config.dense_seq_output = True
+    config.last_layer_subset = last_layer_subset
     config.use_xentropy = True

     model = BertForPreTraining.from_pretrained(model_name, config)
     model = model.cuda().to(dtype=dtype)

     model_ref = get_hf_models(model_name, config, torch.float32)
-    model_hf = get_hf_models(model_name, config, torch.float16)
+    model_hf = get_hf_models(model_name, config, dtype)

     model.eval()
     model_ref.eval()
```
```diff
@@ -190,19 +197,25 @@ def test_bert_dense_seq_output(model_name):
     batch_size = 4
     max_seqlen = 512
     seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device='cuda')
-    attention_mask = torch.arange(max_seqlen, device='cuda')[None, :] < seqlens[:, None]
+    if has_key_padding_mask:
+        attention_mask = torch.arange(max_seqlen, device='cuda')[None, :] < seqlens[:, None]
+    else:
+        attention_mask = None
     input_ids = torch.randint(0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long,
                               device='cuda')
     labels = torch.randint(0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long,
                            device='cuda')
-    labels[(torch.rand(batch_size, max_seqlen, device='cuda') < 0.15) | ~attention_mask] = 0
+    if attention_mask is not None:
+        labels[~attention_mask] = 0
+    labels[(torch.rand(batch_size, max_seqlen, device='cuda') > 0.15)] = 0
     masked_tokens_mask = labels.flatten() > 0
     next_sequence_label = torch.randint(0, 2, (batch_size,), device='cuda')

-    total_loss, prediction_scores, seq_relationship_scores, _, _ = model(
+    out = model(
         input_ids, attention_mask=attention_mask, labels=labels,
         next_sentence_label=next_sequence_label
     )
+    prediction_scores, seq_relationship_scores = out.prediction_logits, out.seq_relationship_logits

     out_hf = model_hf(input_ids, attention_mask=attention_mask, labels=labels,
                       next_sentence_label=next_sequence_label)
     prediction_scores_hf, seq_relationship_scores_hf = out_hf.prediction_logits, out_hf.seq_relationship_logits
```
```diff
@@ -217,3 +230,6 @@ def test_bert_dense_seq_output(model_name):
     print(f'HF fp16 prediction_scores max diff: {(prediction_scores_hf - prediction_scores_ref).abs().max().item()}')
     print(f'HF fp16 prediction_scores mean diff: {(prediction_scores_hf - prediction_scores_ref).abs().mean().item()}')
     assert (prediction_scores - prediction_scores_ref).abs().max().item() < 2 * (prediction_scores_hf - prediction_scores_ref).abs().max().item()
     assert (seq_relationship_scores - seq_relationship_scores_ref).abs().max().item() < 2 * (seq_relationship_scores_hf - seq_relationship_scores_ref).abs().max().item()
+    # The loss calculation from HF is wrong: it doesn't ignore the labels that are 0.
+    # assert (out.loss - out_ref.loss).abs().max().item() < 2 * (out_hf.loss - out_ref.loss).abs().max().item()
```
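The commented-out loss check aside, the test's label construction is what drives the subset path: positions with label 0 are ignored, and `masked_tokens_mask = labels > 0` (as computed in `BertForPreTraining.forward`) is exactly the set of tokens the last layer has to produce output for. A tiny illustration with made-up labels:

```python
import torch

# Toy example: labels are 0 everywhere except the ~15% of positions that were masked for MLM.
labels = torch.tensor([[0, 523, 0, 0, 87],
                       [0, 0, 991, 0, 0]])
masked_tokens_mask = labels > 0           # positions the MLM loss (and the last layer) must cover
print(masked_tokens_mask.sum(dim=-1))     # tensor([2, 1]) masked tokens per example
```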