gaoqiong / flash-attention

Commit 714c1b4f, authored Jan 01, 2023 by Tri Dao
Parent: ef1ba918

[Bert] Fix embedding layer norm before embedding dropout
Showing 2 changed files with 10 additions and 8 deletions:

    flash_attn/models/bert.py   +5  -6
    flash_attn/models/gpt.py    +5  -2
flash_attn/models/bert.py

@@ -295,7 +295,7 @@ class BertModel(BertPreTrainedModel):
             config.vocab_size += (self.pad_vocab_size_multiple
                                   - (config.vocab_size % self.pad_vocab_size_multiple))
         self.fused_dropout_add_ln = getattr(config, 'fused_dropout_add_ln', False)
-        if self.fused_dropout_add_ln and dropout_add_layer_norm is None:
+        if self.fused_dropout_add_ln and layer_norm is None:
             raise ImportError('dropout_add_layer_norm is not installed')
         assert config.position_embedding_type == 'absolute'
         assert config.hidden_act in ['gelu', 'gelu_new', 'gelu_fast']

@@ -320,14 +320,13 @@ class BertModel(BertPreTrainedModel):
         hidden_states = self.embeddings(input_ids, position_ids=position_ids,
                                         token_type_ids=token_type_ids)
         # TD [2022-12:18]: Don't need to force residual in fp32
+        # BERT puts embedding LayerNorm before embedding dropout.
         if not self.fused_dropout_add_ln:
-            hidden_states = self.emb_drop(hidden_states)
             hidden_states = self.emb_ln(hidden_states)
         else:
-            hidden_states = dropout_add_layer_norm(
-                hidden_states, None, self.emb_ln.weight, self.emb_ln.bias,
-                self.emb_drop.p if self.training else 0.0, self.emb_ln.eps, prenorm=False,
-            )
+            hidden_states = layer_norm(hidden_states, self.emb_ln.weight, self.emb_ln.bias,
+                                       self.emb_ln.eps)
+        hidden_states = self.emb_drop(hidden_states)

         if masked_tokens_mask is not None:
             batch_size, seqlen = input_ids.shape[:2]
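The substance of the bert.py change is the ordering of the embedding LayerNorm and dropout: BERT normalizes the summed embeddings first and applies dropout afterwards, so the fused dropout_add_layer_norm call (which drops out before normalizing) is replaced by a fused layer_norm followed by the ordinary emb_drop module. A minimal sketch of the two orderings, using a toy module with an invented name (ToyEmbeddings) rather than the repository's BertEmbeddings:

import torch
import torch.nn as nn

class ToyEmbeddings(nn.Module):
    """Illustrative only: contrasts the old and fixed LayerNorm/dropout ordering."""
    def __init__(self, vocab_size=128, dim=32, p=0.1, ln_before_dropout=True):
        super().__init__()
        self.word_embeddings = nn.Embedding(vocab_size, dim)
        self.emb_ln = nn.LayerNorm(dim)
        self.emb_drop = nn.Dropout(p)
        self.ln_before_dropout = ln_before_dropout

    def forward(self, input_ids):
        hidden_states = self.word_embeddings(input_ids)
        if self.ln_before_dropout:
            # Fixed ordering: LayerNorm first, then dropout (what BERT does).
            hidden_states = self.emb_ln(hidden_states)
            hidden_states = self.emb_drop(hidden_states)
        else:
            # Old ordering: dropout first, then LayerNorm.
            hidden_states = self.emb_drop(hidden_states)
            hidden_states = self.emb_ln(hidden_states)
        return hidden_states

In eval mode dropout is the identity, so the two orderings coincide; the difference only shows up during training, where the old order normalized an already-dropped-out tensor.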
flash_attn/models/gpt.py

@@ -220,6 +220,9 @@ class GPTModel(GPTPreTrainedModel):
         self.apply(partial(_init_weights, n_layer=config.num_hidden_layers,
                            initializer_range=config.initializer_range))
+        self.tie_weights()
+
+    def tie_weights(self):
         if self.process_group is not None:
             sync_sequence_parallel_params(self, self.process_group)

@@ -266,11 +269,11 @@ class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):
         self.apply(partial(_init_weights, n_layer=config.num_hidden_layers,
                            initializer_range=config.initializer_range))
         self.tie_weights()
-        if self.process_group is not None:
-            sync_sequence_parallel_params(self, self.process_group)

     def tie_weights(self):
         self.lm_head.weight = self.transformer.embeddings.word_embeddings.weight
+        if self.process_group is not None:
+            sync_sequence_parallel_params(self, self.process_group)

     def forward(self, input_ids, position_ids=None, inference_params=None):
         """
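The gpt.py change is a matching refactor in both classes: GPTModel gains a tie_weights() hook that __init__ calls after apply(_init_weights), and GPTLMHeadModel moves the sync_sequence_parallel_params call from __init__ into tie_weights(), after the lm_head/embedding tying, presumably so any parameter synchronization across the process group happens on the tensor that is actually shared. A minimal sketch of the tying pattern itself, with an invented ToyLMHeadModel standing in for the real classes and no tensor-parallel plumbing:

import torch.nn as nn

class ToyLMHeadModel(nn.Module):
    """Illustrative only: the tie_weights() pattern this commit factors out."""
    def __init__(self, vocab_size=128, dim=32):
        super().__init__()
        self.word_embeddings = nn.Embedding(vocab_size, dim)
        self.lm_head = nn.Linear(dim, vocab_size, bias=False)
        # Initialize first, then tie, so the tied tensor is the embedding's.
        self.tie_weights()

    def tie_weights(self):
        # Share a single Parameter between the input embedding and the output
        # projection; anything done to word_embeddings.weight afterwards is
        # seen by lm_head as well.
        self.lm_head.weight = self.word_embeddings.weight

model = ToyLMHeadModel()
assert model.lm_head.weight is model.word_embeddings.weight

Because tie_weights() overwrites lm_head.weight, any broadcast or sync of shared parameters should happen (or be repeated) after tying, which matches the ordering in the diff above.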