chenpangpang / transformers · Commits · 14ed3b97

Unverified commit 14ed3b97, authored Feb 18, 2021 by Julien Plu, committed by GitHub on Feb 18, 2021

Fix AMP (#10216)
Parent: bdf1669e
Showing 2 changed files with 12 additions and 24 deletions:

    src/transformers/models/funnel/modeling_tf_funnel.py    +12 / -16
    tests/test_modeling_tf_funnel.py                          +0 / -8
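"Fix AMP" refers to TensorFlow's automatic mixed precision: under a mixed_float16 Keras policy, layers compute in float16 while keeping float32 variables, so code that hard-codes tf.float32 casts or dtype arguments can end up mixing float16 and float32 tensors in a single op. A minimal sketch of how mixed precision is typically enabled (standard Keras API, illustration only, not part of this commit):

import tensorflow as tf

# Standard Keras mixed precision setup (illustration only, not from this diff).
# Under this policy a layer's compute dtype is float16 and its variable dtype float32,
# which is why hard-coded tf.float32 casts inside a layer can cause dtype mismatches.
tf.keras.mixed_precision.set_global_policy("mixed_float16")

layer = tf.keras.layers.Dense(8)
y = layer(tf.random.uniform((2, 4)))  # inputs are cast to the layer's compute dtype
print(layer.compute_dtype, layer.variable_dtype, y.dtype)  # float16 float32 float16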
src/transformers/models/funnel/modeling_tf_funnel.py
@@ -144,7 +144,7 @@ class TFFunnelAttentionStructure:
         # attention_mask and token_type_ids have shape batch_size x seq_len
         self.pooling_mult = 1
         self.seq_len = seq_len = shape_list(inputs_embeds)[1]
-        position_embeds = self.get_position_embeds(seq_len, dtype=inputs_embeds.dtype, training=training)
+        position_embeds = self.get_position_embeds(seq_len, training=training)
         token_type_mat = self.token_type_ids_to_mat(token_type_ids) if token_type_ids is not None else None
         cls_mask = (
             tf.pad(tf.ones([seq_len - 1, seq_len - 1], dtype=inputs_embeds.dtype), [[1, 0], [1, 0]])
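The shape_list helper used above comes from transformers.modeling_tf_utils. Roughly, it returns static dimensions as plain Python ints where they are known and falls back to symbolic tf.shape entries otherwise, which is why a result like shape_list(inputs_embeds)[1] can be used directly as seq_len. A rough sketch of that behavior (an approximation, not the library's exact code):

import tensorflow as tf

def shape_list_sketch(tensor):
    # Approximation of transformers' shape_list: prefer static (Python int) dims,
    # fall back to dynamic tf.shape() entries for dimensions unknown at trace time.
    static = tensor.shape.as_list()
    dynamic = tf.shape(tensor)
    return [dynamic[i] if dim is None else dim for i, dim in enumerate(static)]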
@@ -161,7 +161,7 @@ class TFFunnelAttentionStructure:
         cls_mat = tf.logical_or(tf.expand_dims(cls_ids, -1), tf.expand_dims(cls_ids, -2))
         return tf.logical_or(cls_mat, token_type_mat)

-    def get_position_embeds(self, seq_len, dtype=tf.float32, training=False):
+    def get_position_embeds(self, seq_len, training=False):
         """
         Create and cache inputs related to relative position encoding. Those are very different depending on whether we
         are using the factorized or the relative shift attention:
@@ -177,8 +177,8 @@ class TFFunnelAttentionStructure:
         if self.attention_type == "factorized":
             # Notations from the paper, appending A.2.2, final formula.
             # We need to create and return the matrices phi, psi, pi and omega.
-            pos_seq = tf.range(0, seq_len, 1.0, dtype=dtype)
-            freq_seq = tf.range(0, self.d_model // 2, 1.0, dtype=dtype)
+            pos_seq = tf.range(0, seq_len, 1.0)
+            freq_seq = tf.range(0, self.d_model // 2, 1.0)
             inv_freq = 1 / (10000 ** (freq_seq / (self.d_model // 2)))
             sinusoid = tf.einsum("i,d->id", pos_seq, inv_freq)
@@ -195,17 +195,17 @@ class TFFunnelAttentionStructure:
         else:
             # Notations from the paper, appending A.2.1, final formula.
             # We need to create and return all the possible vectors R for all blocks and shifts.
-            freq_seq = tf.range(0, self.d_model // 2, 1.0, dtype=dtype)
+            freq_seq = tf.range(0, self.d_model // 2, 1.0)
             inv_freq = 1 / (10000 ** (freq_seq / (self.d_model // 2)))
             # Maximum relative positions for the first input
-            rel_pos_id = tf.range(-seq_len * 2, seq_len * 2, 1.0, dtype=dtype)
+            rel_pos_id = tf.range(-seq_len * 2, seq_len * 2, 1.0)
             zero_offset = seq_len * tf.constant(2)
             sinusoid = tf.einsum("i,d->id", rel_pos_id, inv_freq)
             sin_embed = self.sin_dropout(tf.sin(sinusoid), training=training)
             cos_embed = self.cos_dropout(tf.cos(sinusoid), training=training)
             pos_embed = tf.concat([sin_embed, cos_embed], axis=-1)
-            pos = tf.range(0, seq_len, dtype=dtype)
+            pos = tf.range(0, seq_len)
             pooled_pos = pos
             position_embeds_list = []
             for block_index in range(0, self.num_blocks):
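For reference, the sinusoid table built in this hunk follows the usual pattern: an outer product of relative positions and inverse frequencies via tf.einsum, then sin and cos concatenated along the feature axis. A self-contained sketch with placeholder sizes (illustrative only, not the library function):

import tensorflow as tf

# Placeholder sizes for illustration; the model uses its configured d_model and the
# actual sequence length.
seq_len, d_model = 8, 16
freq_seq = tf.range(0, d_model // 2, 1.0)                    # (d_model/2,)
inv_freq = 1 / (10000 ** (freq_seq / (d_model // 2)))        # (d_model/2,)
rel_pos_id = tf.range(-seq_len * 2.0, seq_len * 2.0, 1.0)    # (4 * seq_len,)
sinusoid = tf.einsum("i,d->id", rel_pos_id, inv_freq)        # (4 * seq_len, d_model/2)
pos_embed = tf.concat([tf.sin(sinusoid), tf.cos(sinusoid)], axis=-1)  # (4 * seq_len, d_model)
print(pos_embed.shape)  # (32, 16)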
@@ -258,7 +258,7 @@ class TFFunnelAttentionStructure:
         else:
             return pos_id[::2]

-    def relative_pos(self, pos, stride, pooled_pos=None, shift=1.0):
+    def relative_pos(self, pos, stride, pooled_pos=None, shift=1):
         """
         Build the relative positional vector between `pos` and `pooled_pos`.
         """
@@ -266,7 +266,7 @@ class TFFunnelAttentionStructure:
             pooled_pos = pos

         ref_point = pooled_pos[0] - pos[0]
-        num_remove = shift * tf.cast(shape_list(pooled_pos)[0], dtype=ref_point.dtype)
+        num_remove = shift * shape_list(pooled_pos)[0]
         max_dist = ref_point + num_remove * stride
         min_dist = pooled_pos[0] - pos[-1]
@@ -522,17 +522,13 @@ class TFFunnelRelMultiheadAttention(tf.keras.layers.Layer):
         # merge attention scores
         attn_score = content_score + positional_attn + token_type_attn

-        # precision safe in case of mixed precision training
-        dtype = attn_score.dtype
-        if dtype != tf.float32:
-            attn_score = tf.cast(attn_score, tf.float32)
         # perform masking
         if attention_mask is not None:
-            attn_score = attn_score - INF * (1 - tf.cast(attention_mask[:, None, None], tf.float32))
+            attention_mask = tf.cast(attention_mask, dtype=attn_score.dtype)
+            attn_score = attn_score - (INF * (1 - attention_mask[:, None, None]))

         # attention probability
         attn_prob = tf.nn.softmax(attn_score, axis=-1)
-        if dtype != tf.float32:
-            attn_prob = tf.cast(attn_prob, dtype)
         attn_prob = self.attention_dropout(attn_prob, training=training)

         # attention output, shape batch_size x seq_len x n_head x d_head
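The masking change above drops the float32 round-trip and instead casts the attention mask to whatever dtype the scores already have, which is the usual way to keep an op dtype-agnostic under mixed precision. A hedged, self-contained sketch of the same pattern (the INF value here is a placeholder, not the module's own constant, and the [:, None, None] broadcast assumes a batch-first mask as in the surrounding layer):

import tensorflow as tf

INF = 1e4  # placeholder masking magnitude; the module defines its own INF constant

def mask_attention_scores(attn_score, attention_mask):
    # Cast the mask to the scores' dtype so the subtraction never mixes
    # float16 (mixed precision) and float32 tensors.
    attention_mask = tf.cast(attention_mask, dtype=attn_score.dtype)
    return attn_score - INF * (1 - attention_mask[:, None, None])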
tests/test_modeling_tf_funnel.py
@@ -372,10 +372,6 @@ class TFFunnelModelTest(TFModelTesterMixin, unittest.TestCase):
         # This test is too long (>30sec) and makes fail the CI
         pass

-    def test_mixed_precision(self):
-        # TODO JP: Make Funnel float16 compliant
-        pass

 @require_tf
 class TFFunnelBaseModelTest(TFModelTesterMixin, unittest.TestCase):
@@ -407,7 +403,3 @@ class TFFunnelBaseModelTest(TFModelTesterMixin, unittest.TestCase):
     def test_saved_model_creation(self):
         # This test is too long (>30sec) and makes fail the CI
         pass
-
-    def test_mixed_precision(self):
-        # TODO JP: Make Funnel float16 compliant
-        pass
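Both removed overrides were stubs that skipped the shared test_mixed_precision check while Funnel was not float16 compliant; deleting them lets the common TFModelTesterMixin test run again for these models. As a rough idea of what such a check involves (a hedged sketch, not the actual mixin code):

import tensorflow as tf

def mixed_precision_smoke_test(build_model, sample_inputs):
    # Hedged sketch of a mixed-precision check: switch the global policy, run a
    # forward pass and make sure nothing raises, then restore the default policy.
    tf.keras.mixed_precision.set_global_policy("mixed_float16")
    try:
        model = build_model()
        _ = model(sample_inputs)
    finally:
        tf.keras.mixed_precision.set_global_policy("float32")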