Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
4fca874e
Unverified
Commit
4fca874e
authored
Aug 25, 2020
by
Jay
Committed by
GitHub
Aug 25, 2020
Browse files
Remove hard-coded uses of float32 to fix mixed precision use (#6648)
parent
0344428f
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
13 additions
and
12 deletions
+13
-12
src/transformers/modeling_tf_bert.py
src/transformers/modeling_tf_bert.py
+6
-5
src/transformers/modeling_tf_electra.py
src/transformers/modeling_tf_electra.py
+7
-7
No files found.
src/transformers/modeling_tf_bert.py
View file @
4fca874e
...
...
@@ -215,8 +215,8 @@ class TFBertEmbeddings(tf.keras.layers.Layer):
if
inputs_embeds
is
None
:
inputs_embeds
=
tf
.
gather
(
self
.
word_embeddings
,
input_ids
)
position_embeddings
=
self
.
position_embeddings
(
position_ids
)
token_type_embeddings
=
self
.
token_type_embeddings
(
token_type_ids
)
position_embeddings
=
tf
.
cast
(
self
.
position_embeddings
(
position_ids
)
,
inputs_embeds
.
dtype
)
token_type_embeddings
=
tf
.
cast
(
self
.
token_type_embeddings
(
token_type_ids
)
,
inputs_embeds
.
dtype
)
embeddings
=
inputs_embeds
+
position_embeddings
+
token_type_embeddings
embeddings
=
self
.
LayerNorm
(
embeddings
)
embeddings
=
self
.
dropout
(
embeddings
,
training
=
training
)
...
...
@@ -281,7 +281,7 @@ class TFBertSelfAttention(tf.keras.layers.Layer):
attention_scores
=
tf
.
matmul
(
query_layer
,
key_layer
,
transpose_b
=
True
)
# (batch size, num_heads, seq_len_q, seq_len_k)
dk
=
tf
.
cast
(
shape_list
(
key_layer
)[
-
1
],
tf
.
float32
)
# scale attention_scores
dk
=
tf
.
cast
(
shape_list
(
key_layer
)[
-
1
],
attention_scores
.
dtype
)
# scale attention_scores
attention_scores
=
attention_scores
/
tf
.
math
.
sqrt
(
dk
)
if
attention_mask
is
not
None
:
...
...
@@ -613,6 +613,8 @@ class TFBertMainLayer(tf.keras.layers.Layer):
if
token_type_ids
is
None
:
token_type_ids
=
tf
.
fill
(
input_shape
,
0
)
embedding_output
=
self
.
embeddings
(
input_ids
,
position_ids
,
token_type_ids
,
inputs_embeds
,
training
=
training
)
# We create a 3D attention mask from a 2D tensor mask.
# Sizes are [batch_size, 1, 1, to_seq_length]
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
...
...
@@ -626,7 +628,7 @@ class TFBertMainLayer(tf.keras.layers.Layer):
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
extended_attention_mask
=
tf
.
cast
(
extended_attention_mask
,
tf
.
float32
)
extended_attention_mask
=
tf
.
cast
(
extended_attention_mask
,
embedding_output
.
dtype
)
extended_attention_mask
=
(
1.0
-
extended_attention_mask
)
*
-
10000.0
# Prepare head mask if needed
...
...
@@ -640,7 +642,6 @@ class TFBertMainLayer(tf.keras.layers.Layer):
head_mask
=
[
None
]
*
self
.
num_hidden_layers
# head_mask = tf.constant([0] * self.num_hidden_layers)
embedding_output
=
self
.
embeddings
(
input_ids
,
position_ids
,
token_type_ids
,
inputs_embeds
,
training
=
training
)
encoder_outputs
=
self
.
encoder
(
embedding_output
,
extended_attention_mask
,
...
...
src/transformers/modeling_tf_electra.py
View file @
4fca874e
...
...
@@ -134,8 +134,8 @@ class TFElectraEmbeddings(tf.keras.layers.Layer):
if
inputs_embeds
is
None
:
inputs_embeds
=
tf
.
gather
(
self
.
word_embeddings
,
input_ids
)
position_embeddings
=
self
.
position_embeddings
(
position_ids
)
token_type_embeddings
=
self
.
token_type_embeddings
(
token_type_ids
)
position_embeddings
=
tf
.
cast
(
self
.
position_embeddings
(
position_ids
)
,
inputs_embeds
.
dtype
)
token_type_embeddings
=
tf
.
cast
(
self
.
token_type_embeddings
(
token_type_ids
)
,
inputs_embeds
.
dtype
)
embeddings
=
inputs_embeds
+
position_embeddings
+
token_type_embeddings
embeddings
=
self
.
LayerNorm
(
embeddings
)
...
...
@@ -194,7 +194,7 @@ class TFElectraPreTrainedModel(TFBertPreTrainedModel):
config_class
=
ElectraConfig
base_model_prefix
=
"electra"
def
get_extended_attention_mask
(
self
,
attention_mask
,
input_shape
):
def
get_extended_attention_mask
(
self
,
attention_mask
,
input_shape
,
dtype
):
if
attention_mask
is
None
:
attention_mask
=
tf
.
fill
(
input_shape
,
1
)
...
...
@@ -211,7 +211,7 @@ class TFElectraPreTrainedModel(TFBertPreTrainedModel):
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
extended_attention_mask
=
tf
.
cast
(
extended_attention_mask
,
tf
.
float32
)
extended_attention_mask
=
tf
.
cast
(
extended_attention_mask
,
dtype
)
extended_attention_mask
=
(
1.0
-
extended_attention_mask
)
*
-
10000.0
return
extended_attention_mask
...
...
@@ -314,11 +314,11 @@ class TFElectraMainLayer(TFElectraPreTrainedModel):
if
token_type_ids
is
None
:
token_type_ids
=
tf
.
fill
(
input_shape
,
0
)
extended_attention_mask
=
self
.
get_extended_attention_mask
(
attention_mask
,
input_shape
)
head_mask
=
self
.
get_head_mask
(
head_mask
)
hidden_states
=
self
.
embeddings
(
input_ids
,
position_ids
,
token_type_ids
,
inputs_embeds
,
training
=
training
)
extended_attention_mask
=
self
.
get_extended_attention_mask
(
attention_mask
,
input_shape
,
hidden_states
.
dtype
)
head_mask
=
self
.
get_head_mask
(
head_mask
)
if
hasattr
(
self
,
"embeddings_project"
):
hidden_states
=
self
.
embeddings_project
(
hidden_states
,
training
=
training
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment