Commit 3e116ed3 (unverified)
Authored Feb 19, 2021 by Julien Plu; committed by GitHub on Feb 19, 2021
Making TF TransfoXL model compliant with AMP (#10264)
* Fix AMP
* Apply style
* Remove unused import
Parent: 86caeb76
Showing 3 changed files with 41 additions and 24 deletions:
* src/transformers/models/transfo_xl/modeling_tf_transfo_xl.py (+39, -18)
* src/transformers/models/transfo_xl/modeling_tf_transfo_xl_utilities.py (+2, -2)
* tests/test_modeling_tf_transfo_xl.py (+0, -4)
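For context, "AMP" here means Keras mixed precision: computing in float16 while keeping variables in float32. A minimal smoke test along the following lines exercises the code paths this commit fixes; the tiny config values are illustrative assumptions, not the pretrained transfo-xl-wt103 sizes:

```python
import tensorflow as tf
from transformers import TFTransfoXLModel, TransfoXLConfig

# Enable mixed precision globally (TF >= 2.4 API): float16 compute, float32 variables.
tf.keras.mixed_precision.set_global_policy("mixed_float16")

# Tiny randomly initialized model; sizes chosen only to keep the check cheap.
config = TransfoXLConfig(
    vocab_size=100, cutoffs=[10, 50], n_layer=2, d_model=32, d_embed=32,
    n_head=4, d_head=8, d_inner=64,
)
model = TFTransfoXLModel(config)

input_ids = tf.random.uniform((1, 8), maxval=config.vocab_size, dtype=tf.int32)
outputs = model(input_ids)  # this forward pass is what the commit makes work under AMP
print(outputs.last_hidden_state.dtype)
```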
src/transformers/models/transfo_xl/modeling_tf_transfo_xl.py
```diff
@@ -59,6 +59,7 @@ class TFPositionalEmbedding(tf.keras.layers.Layer):
         self.inv_freq = 1 / (10000 ** (tf.range(0, demb, 2.0) / demb))

     def call(self, pos_seq, bsz=None):
+        self.inv_freq = tf.cast(self.inv_freq, dtype=pos_seq.dtype)
         sinusoid_inp = tf.einsum("i,j->ij", pos_seq, self.inv_freq)
         pos_emb = tf.concat([tf.sin(sinusoid_inp), tf.cos(sinusoid_inp)], -1)
```
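The recurring pattern in this file is casting float32 constants and caches to the incoming activation dtype: under the mixed_float16 policy, pos_seq arrives as float16 while inv_freq was built as float32, and tf.einsum rejects mixed operand dtypes. A standalone sketch of the failure and the fix (demb and the sequence length are arbitrary):

```python
import tensorflow as tf

demb = 8
inv_freq = 1 / (10000 ** (tf.range(0, demb, 2.0) / demb))  # float32 by construction
pos_seq = tf.range(5, dtype=tf.float16)                    # what AMP hands the layer

# Without the added cast, this mixed-dtype einsum raises an error:
# tf.einsum("i,j->ij", pos_seq, inv_freq)

inv_freq = tf.cast(inv_freq, dtype=pos_seq.dtype)          # the added line
sinusoid_inp = tf.einsum("i,j->ij", pos_seq, inv_freq)
print(sinusoid_inp.dtype)  # float16
```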
```diff
@@ -186,6 +187,7 @@ class TFRelPartialLearnableMultiHeadAttn(tf.keras.layers.Layer):
         qlen, rlen, bsz = shape_list(w)[0], shape_list(r)[0], shape_list(w)[1]

         if mems is not None:
+            mems = tf.cast(mems, dtype=w.dtype)
             cat = tf.concat([mems, w], 0)

             if self.pre_lnorm:
                 w_heads = self.qkv_net(self.layer_norm(cat))
```
```diff
@@ -227,7 +229,8 @@ class TFRelPartialLearnableMultiHeadAttn(tf.keras.layers.Layer):
         # compute attention probability
         if attn_mask is not None:
             attn_mask_t = attn_mask[:, :, None, None]
-            attn_score = attn_score * (1 - attn_mask_t) - 1e30 * attn_mask_t
+            attn_mask_t = tf.cast(attn_mask_t, dtype=attn_score.dtype)
+            attn_score = attn_score * (1.0 - attn_mask_t) - 1e30 * attn_mask_t

         # [qlen x klen x bsz x n_head]
         attn_prob = tf.nn.softmax(attn_score, axis=1)
```
```diff
@@ -313,6 +316,27 @@ class TFRelPartialLearnableDecoderLayer(tf.keras.layers.Layer):
         return outputs


+class TFTransfoEmbeddings(tf.keras.layers.Layer):
+    def __init__(self, vocab_size, emb_size, init_std, **kwargs):
+        super().__init__(**kwargs)
+
+        self.vocab_size = vocab_size
+        self.emb_size = emb_size
+        self.init_std = init_std
+
+    def build(self, input_shape):
+        self.weight = self.add_weight(
+            shape=(self.vocab_size, self.emb_size),
+            initializer=get_initializer(self.init_std),
+            name="embeddings",
+        )
+
+        super().build(input_shape)
+
+    def call(self, inputs):
+        return tf.gather(self.weight, inputs)
+
+
 class TFAdaptiveEmbedding(tf.keras.layers.Layer):
     def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1, init_std=0.02, sample_softmax=False, **kwargs):
         super().__init__(**kwargs)
```
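The new TFTransfoEmbeddings is a thin replacement for tf.keras.layers.Embedding: an add_weight matrix plus an explicit tf.gather lookup. The commit message does not spell out the motivation; presumably the plain variable gives the surrounding adaptive-embedding code tighter control over variable naming and dtype handling under mixed precision. A hypothetical standalone use (the layer name "demo" and the sizes are made up; the import path assumes the post-commit module layout):

```python
import tensorflow as tf
from transformers.models.transfo_xl.modeling_tf_transfo_xl import TFTransfoEmbeddings

# One embedding cluster: a (vocab_size, emb_size) matrix looked up with tf.gather.
emb = TFTransfoEmbeddings(vocab_size=50, emb_size=16, init_std=0.02, name="demo")
ids = tf.constant([[1, 4, 7]])
vectors = emb(ids)    # first call builds the "embeddings" weight
print(vectors.shape)  # (1, 3, 16)
```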
```diff
@@ -331,6 +355,7 @@ class TFAdaptiveEmbedding(tf.keras.layers.Layer):
         self.emb_layers = []
         self.emb_projs = []

         if div_val == 1:
             raise NotImplementedError  # Removed these to avoid maintaining dead code - They are not used in our pretrained checkpoint
         else:
```
```diff
@@ -338,10 +363,10 @@ class TFAdaptiveEmbedding(tf.keras.layers.Layer):
                 l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]
                 d_emb_i = d_embed // (div_val ** i)
                 self.emb_layers.append(
-                    tf.keras.layers.Embedding(
+                    TFTransfoEmbeddings(
                         r_idx - l_idx,
                         d_emb_i,
-                        embeddings_initializer=get_initializer(init_std),
+                        init_std,
                         name="emb_layers_._{}".format(i),
                     )
                 )
```
```diff
@@ -357,6 +382,7 @@ class TFAdaptiveEmbedding(tf.keras.layers.Layer):
                         name="emb_projs_._{}".format(i),
                     )
                 )

         super().build(input_shape)

     def call(self, inp):
```
```diff
@@ -374,8 +400,10 @@ class TFAdaptiveEmbedding(tf.keras.layers.Layer):
                 emb_i = self.emb_layers[i](inp_i)
                 emb_i = tf.einsum("id,de->ie", emb_i, self.emb_projs[i])

-                mask_idx = tf.cast(tf.where(mask_i), dtype=tf.int64)
-                emb_flat += tf.scatter_nd(mask_idx, emb_i, tf.cast(shape_list(emb_flat), dtype=tf.int64))
+                mask_idx = tf.where(mask_i)
+                scatter = tf.scatter_nd(mask_idx, emb_i, shape_list(emb_flat))
+                emb_flat = tf.cast(emb_flat, dtype=scatter.dtype)
+                emb_flat += scatter

             embed_shape = shape_list(inp) + [self.d_proj]
             embed = tf.reshape(emb_flat, embed_shape)
```
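The rewritten accumulation drops the explicit int64 casts and instead aligns the accumulator's dtype with the scattered updates: tf.scatter_nd fixes its output dtype from the updates, while `emb_flat += scatter` requires both sides to share a dtype. A toy version with hypothetical shapes:

```python
import tensorflow as tf

mask_i = tf.constant([True, False, True])
emb_i = tf.ones((2, 4), dtype=tf.float16)       # cluster embeddings under AMP
emb_flat = tf.zeros((3, 4), dtype=tf.float32)   # accumulator starts as float32

mask_idx = tf.where(mask_i)                        # already int64, no cast needed
scatter = tf.scatter_nd(mask_idx, emb_i, [3, 4])   # output dtype follows emb_i
emb_flat = tf.cast(emb_flat, dtype=scatter.dtype)  # match the updates' dtype
emb_flat += scatter
print(emb_flat.dtype)  # float16
```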
```diff
@@ -501,7 +529,7 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer):
         end_idx = mlen + tf.math.maximum(0, qlen)
         beg_idx = tf.math.maximum(0, end_idx - tf.convert_to_tensor(self.mem_len))
         for i in range(len(hids)):
+            mems[i] = tf.cast(mems[i], dtype=hids[i].dtype)
             cat = tf.concat([mems[i], hids[i]], axis=0)
             tf.stop_gradient(cat)
             new_mems.append(cat[beg_idx:end_idx])
```
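The same cast pattern appears in the memory update: the cached mems can be float32 while fresh hidden states come out of the layers in float16, and tf.concat refuses mixed dtypes. A sketch with hypothetical shapes:

```python
import tensorflow as tf

mems_i = tf.zeros((4, 1, 8), dtype=tf.float32)  # cached segment from earlier steps
hids_i = tf.zeros((4, 1, 8), dtype=tf.float16)  # current hidden states under AMP

mems_i = tf.cast(mems_i, dtype=hids_i.dtype)    # the added line, applied per layer
cat = tf.concat([mems_i, hids_i], axis=0)       # dtypes now agree
print(cat.dtype)  # float16
```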
```diff
@@ -1113,7 +1141,6 @@ class TFTransfoXLForSequenceClassification(TFTransfoXLPreTrainedModel, TFSequenceClassificationLoss):
         hidden_states = transformer_outputs[0]
         logits = self.score(hidden_states)
-        logits_shape = shape_list(logits)
         in_logits = None
         if self.config.pad_token_id is None:
             sequence_lengths = -1
```

(`logits_shape` is deleted because its only consumer, the map_fn branch removed in the next hunk, goes away with it.)
```diff
@@ -1121,22 +1148,16 @@ class TFTransfoXLForSequenceClassification(TFTransfoXLPreTrainedModel, TFSequenceClassificationLoss):
             if inputs["input_ids"] is not None:
                 sequence_lengths = (
                     tf.reduce_sum(
-                        tf.cast(tf.math.not_equal(inputs["input_ids"], self.config.pad_token_id), tf.int32),
+                        tf.cast(
+                            tf.math.not_equal(inputs["input_ids"], self.config.pad_token_id),
+                            dtype=inputs["input_ids"].dtype,
+                        ),
                         -1,
                         keepdims=False,
                     )
                     - 1
                 )
-
-                def get_seq_element(sequence_position, input_batch):
-                    return tf.strided_slice(
-                        input_batch, [sequence_position, 0], [sequence_position + 1, input_batch.shape[-1]], [1, 1]
-                    )
-
-                result = tf.map_fn(
-                    fn=lambda t: get_seq_element(t[0], t[1]), elems=[sequence_lengths, logits], dtype="float"
-                )
-                in_logits = tf.reshape(result, [logits_shape[0], logits_shape[-1]])
+                in_logits = tf.gather(logits, sequence_lengths, batch_dims=1, axis=1)
             else:
                 sequence_lengths = -1
                 logger.warning(
```
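The classification head's last-token pooling is also simplified: the old map_fn/strided_slice construction (and the logits_shape it needed) is replaced by one tf.gather with batch_dims=1, which picks a single logits row per batch entry. A toy illustration, assuming a hypothetical pad_token_id of 0:

```python
import tensorflow as tf

pad_token_id = 0
input_ids = tf.constant([[5, 6, 0, 0],
                         [7, 8, 9, 0]])
logits = tf.random.normal((2, 4, 3))   # (batch, seq_len, num_labels)

# Index of the last non-padding token in each row: [1, 2].
sequence_lengths = tf.reduce_sum(
    tf.cast(tf.math.not_equal(input_ids, pad_token_id), dtype=input_ids.dtype), -1
) - 1

in_logits = tf.gather(logits, sequence_lengths, batch_dims=1, axis=1)
print(in_logits.shape)  # (2, 3): one logits row per batch entry
```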
src/transformers/models/transfo_xl/modeling_tf_transfo_xl_utilities.py
```diff
@@ -131,7 +131,7 @@ class TFAdaptiveSoftmaxMask(tf.keras.layers.Layer):
         else:
             hidden_sizes = shape_list(hidden)
             out = []
-            loss = tf.zeros(hidden_sizes[:2], dtype=tf.float32)
+            loss = tf.zeros(hidden_sizes[:2])
             for i in range(len(self.cutoffs)):
                 l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]
                 if target is not None:
```
```diff
@@ -168,7 +168,7 @@ class TFAdaptiveSoftmaxMask(tf.keras.layers.Layer):
                     cur_logprob = self._gather_logprob(cur_tail_logprob, cur_target)
                     cur_logprob += cur_head_logprob[:, self.cutoff_ends[1] + i - 1]
                 if target is not None:
-                    loss += tf.scatter_nd(mask_idx, -cur_logprob, tf.cast(shape_list(loss), dtype=tf.int64))
+                    loss += tf.scatter_nd(mask_idx, -cur_logprob, shape_list(loss))
             out = tf.concat(out, axis=-1)

         if target is not None:
```
tests/test_modeling_tf_transfo_xl.py
```diff
@@ -205,10 +205,6 @@ class TFTransfoXLModelTest(TFModelTesterMixin, unittest.TestCase):
         name = model.get_bias()
         assert name is None

-    def test_mixed_precision(self):
-        # TODO JP: Make TransfoXL float16 compliant
-        pass
-
     def test_xla_mode(self):
         # TODO JP: Make TransfoXL XLA compliant
         pass
```

With the empty override gone, TransfoXL now runs the shared mixed-precision test inherited from TFModelTesterMixin instead of skipping it; the XLA override remains as a TODO.