Commit 3e116ed3 in chenpangpang/transformers (unverified)
Authored Feb 19, 2021 by Julien Plu; committed via GitHub on Feb 19, 2021
Parent: 86caeb76

Making TF TransfoXL model compliant with AMP (#10264)

* Fix AMP
* Apply style
* Remove unused import
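For context, "AMP compliant" here means the model can run under Keras mixed precision, where activations are computed in float16 while variables stay in float32, so any float32 constant or cached tensor has to be cast to the activation dtype before it is combined with float16 values. A minimal sketch of enabling the policy around the public model (standard Keras and transformers APIs, not part of this commit; it downloads the full transfo-xl-wt103 checkpoint):

import tensorflow as tf
from transformers import TFTransfoXLModel, TransfoXLTokenizer

# Compute in float16, keep variables in float32 (TF >= 2.4 API).
tf.keras.mixed_precision.set_global_policy("mixed_float16")

tokenizer = TransfoXLTokenizer.from_pretrained("transfo-xl-wt103")
model = TFTransfoXLModel.from_pretrained("transfo-xl-wt103")

inputs = tokenizer("Hello, my dog is cute", return_tensors="tf")
outputs = model(inputs)  # after this commit, the forward pass should no longer hit dtype-mismatch errors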
Showing 3 changed files with 41 additions and 24 deletions (+41, -24):
src/transformers/models/transfo_xl/modeling_tf_transfo_xl.py  +39 -18
src/transformers/models/transfo_xl/modeling_tf_transfo_xl_utilities.py  +2 -2
tests/test_modeling_tf_transfo_xl.py  +0 -4
src/transformers/models/transfo_xl/modeling_tf_transfo_xl.py
@@ -59,6 +59,7 @@ class TFPositionalEmbedding(tf.keras.layers.Layer):
         self.inv_freq = 1 / (10000 ** (tf.range(0, demb, 2.0) / demb))

     def call(self, pos_seq, bsz=None):
+        self.inv_freq = tf.cast(self.inv_freq, dtype=pos_seq.dtype)
         sinusoid_inp = tf.einsum("i,j->ij", pos_seq, self.inv_freq)
         pos_emb = tf.concat([tf.sin(sinusoid_inp), tf.cos(sinusoid_inp)], -1)
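The added cast keeps both einsum operands in one dtype: inv_freq is built once in __init__ as float32, while pos_seq can arrive in the policy's compute dtype, and tf.einsum (like most TF ops) will not mix float16 with float32. A small illustration with stand-in tensors (shapes and dtypes invented for the example):

import tensorflow as tf

inv_freq = 1 / (10000 ** (tf.range(0, 8, 2.0) / 8))   # float32, as built in __init__
pos_seq = tf.cast(tf.range(4.0), tf.float16)          # pretend an AMP-aware caller handed us half precision

# tf.einsum("i,j->ij", pos_seq, inv_freq) would raise a dtype-mismatch error here.
inv_freq = tf.cast(inv_freq, dtype=pos_seq.dtype)     # the fix: follow pos_seq's dtype
sinusoid_inp = tf.einsum("i,j->ij", pos_seq, inv_freq)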
@@ -186,6 +187,7 @@ class TFRelPartialLearnableMultiHeadAttn(tf.keras.layers.Layer):
         qlen, rlen, bsz = shape_list(w)[0], shape_list(r)[0], shape_list(w)[1]

         if mems is not None:
+            mems = tf.cast(mems, dtype=w.dtype)
             cat = tf.concat([mems, w], 0)

             if self.pre_lnorm:
                 w_heads = self.qkv_net(self.layer_norm(cat))
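Same idea for the attention cache: mems may be stored in a different dtype than the current hidden states w, and tf.concat requires every input to share one dtype, so the cache is cast to w.dtype first. A toy version:

import tensorflow as tf

mems = tf.zeros((3, 1, 8), dtype=tf.float32)  # cached states from a previous segment
w = tf.ones((2, 1, 8), dtype=tf.float16)      # current hidden states under mixed precision

mems = tf.cast(mems, dtype=w.dtype)           # without this, tf.concat raises a dtype error
cat = tf.concat([mems, w], 0)                 # (5, 1, 8), float16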
@@ -227,7 +229,8 @@ class TFRelPartialLearnableMultiHeadAttn(tf.keras.layers.Layer):
         # compute attention probability
         if attn_mask is not None:
             attn_mask_t = attn_mask[:, :, None, None]
-            attn_score = attn_score * (1 - attn_mask_t) - 1e30 * attn_mask_t
+            attn_mask_t = tf.cast(attn_mask_t, dtype=attn_score.dtype)
+            attn_score = attn_score * (1.0 - attn_mask_t) - 1e30 * attn_mask_t

         # [qlen x klen x bsz x n_head]
         attn_prob = tf.nn.softmax(attn_score, axis=1)
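Here the attention mask (a float tensor of zeros and ones) is cast to the score dtype, and the literal becomes 1.0, so the masking arithmetic stays in a single dtype under AMP. A sketch with made-up shapes:

import tensorflow as tf

attn_score = tf.random.normal((4, 4, 1, 2), dtype=tf.float16)  # scores in the compute dtype
attn_mask = tf.linalg.band_part(tf.ones((4, 4)), 0, -1)        # float32 mask of 0s and 1s

attn_mask_t = attn_mask[:, :, None, None]
attn_mask_t = tf.cast(attn_mask_t, dtype=attn_score.dtype)
# masked positions are pushed to a very large negative value before the softmax
attn_score = attn_score * (1.0 - attn_mask_t) - 1e30 * attn_mask_t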
@@ -313,6 +316,27 @@ class TFRelPartialLearnableDecoderLayer(tf.keras.layers.Layer):
         return outputs


+class TFTransfoEmbeddings(tf.keras.layers.Layer):
+    def __init__(self, vocab_size, emb_size, init_std, **kwargs):
+        super().__init__(**kwargs)
+
+        self.vocab_size = vocab_size
+        self.emb_size = emb_size
+        self.init_std = init_std
+
+    def build(self, input_shape):
+        self.weight = self.add_weight(
+            shape=(self.vocab_size, self.emb_size),
+            initializer=get_initializer(self.init_std),
+            name="embeddings",
+        )
+
+        super().build(input_shape)
+
+    def call(self, inputs):
+        return tf.gather(self.weight, inputs)
+
+
 class TFAdaptiveEmbedding(tf.keras.layers.Layer):
     def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1, init_std=0.02, sample_softmax=False, **kwargs):
         super().__init__(**kwargs)
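The new TFTransfoEmbeddings layer is a minimal gather-based stand-in for tf.keras.layers.Embedding: it owns one (vocab_size, emb_size) weight named "embeddings" and looks rows up by index. A quick usage sketch with made-up sizes, assuming the class defined above is in scope:

import tensorflow as tf

emb = TFTransfoEmbeddings(100, 16, 0.02, name="emb_layers_._0")
token_ids = tf.constant([[3, 7, 42]])
vectors = emb(token_ids)  # shape (1, 3, 16): one weight row per token id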
@@ -331,6 +355,7 @@ class TFAdaptiveEmbedding(tf.keras.layers.Layer):
         self.emb_layers = []
         self.emb_projs = []

         if div_val == 1:
             raise NotImplementedError  # Removed these to avoid maintaining dead code - They are not used in our pretrained checkpoint
         else:
@@ -338,10 +363,10 @@ class TFAdaptiveEmbedding(tf.keras.layers.Layer):
                 l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]
                 d_emb_i = d_embed // (div_val ** i)
                 self.emb_layers.append(
-                    tf.keras.layers.Embedding(
+                    TFTransfoEmbeddings(
                         r_idx - l_idx,
                         d_emb_i,
-                        embeddings_initializer=get_initializer(init_std),
+                        init_std,
                         name="emb_layers_._{}".format(i),
                     )
                 )
@@ -357,6 +382,7 @@ class TFAdaptiveEmbedding(tf.keras.layers.Layer):
                         name="emb_projs_._{}".format(i),
                     )
                 )

         super().build(input_shape)

     def call(self, inp):
@@ -374,8 +400,10 @@ class TFAdaptiveEmbedding(tf.keras.layers.Layer):
                 emb_i = self.emb_layers[i](inp_i)
                 emb_i = tf.einsum("id,de->ie", emb_i, self.emb_projs[i])

-                mask_idx = tf.cast(tf.where(mask_i), dtype=tf.int64)
-                emb_flat += tf.scatter_nd(mask_idx, emb_i, tf.cast(shape_list(emb_flat), dtype=tf.int64))
+                mask_idx = tf.where(mask_i)
+                scatter = tf.scatter_nd(mask_idx, emb_i, shape_list(emb_flat))
+                emb_flat = tf.cast(emb_flat, dtype=scatter.dtype)
+                emb_flat += scatter

         embed_shape = shape_list(inp) + [self.d_proj]
         embed = tf.reshape(emb_flat, embed_shape)
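tf.where with a single argument already returns int64 coordinates, which is what tf.scatter_nd expects, so the explicit int64 casts could be dropped; the new cast instead aligns emb_flat with the scatter result so the += still works when the embeddings come out in float16. A toy version of the pattern (tf.shape stands in here for the library's shape_list helper):

import tensorflow as tf

mask_i = tf.constant([True, False, True, False])
emb_i = tf.constant([[1.0, 1.0], [2.0, 2.0]], dtype=tf.float16)  # vectors for the two masked rows
emb_flat = tf.zeros((4, 2), dtype=tf.float32)

mask_idx = tf.where(mask_i)                                      # int64 indices of the True rows
scatter = tf.scatter_nd(mask_idx, emb_i, tf.shape(emb_flat, out_type=tf.int64))
emb_flat = tf.cast(emb_flat, dtype=scatter.dtype)                # match dtypes before accumulating
emb_flat += scatter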
@@ -501,7 +529,7 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer):
         end_idx = mlen + tf.math.maximum(0, qlen)
         beg_idx = tf.math.maximum(0, end_idx - tf.convert_to_tensor(self.mem_len))
         for i in range(len(hids)):
+            mems[i] = tf.cast(mems[i], dtype=hids[i].dtype)
             cat = tf.concat([mems[i], hids[i]], axis=0)
             tf.stop_gradient(cat)
             new_mems.append(cat[beg_idx:end_idx])
@@ -1113,7 +1141,6 @@ class TFTransfoXLForSequenceClassification(TFTransfoXLPreTrainedModel, TFSequenc
         hidden_states = transformer_outputs[0]
         logits = self.score(hidden_states)
-        logits_shape = shape_list(logits)
         in_logits = None

         if self.config.pad_token_id is None:
             sequence_lengths = -1
@@ -1121,22 +1148,16 @@ class TFTransfoXLForSequenceClassification(TFTransfoXLPreTrainedModel, TFSequenc
             if inputs["input_ids"] is not None:
                 sequence_lengths = (
                     tf.reduce_sum(
-                        tf.cast(tf.math.not_equal(inputs["input_ids"], self.config.pad_token_id), tf.int32),
+                        tf.cast(
+                            tf.math.not_equal(inputs["input_ids"], self.config.pad_token_id),
+                            dtype=inputs["input_ids"].dtype,
+                        ),
                         -1,
                         keepdims=False,
                     )
                     - 1
                 )
-
-                def get_seq_element(sequence_position, input_batch):
-                    return tf.strided_slice(
-                        input_batch, [sequence_position, 0], [sequence_position + 1, input_batch.shape[-1]], [1, 1]
-                    )
-
-                result = tf.map_fn(
-                    fn=lambda t: get_seq_element(t[0], t[1]), elems=[sequence_lengths, logits], dtype="float"
-                )
-                in_logits = tf.reshape(result, [logits_shape[0], logits_shape[-1]])
+                in_logits = tf.gather(logits, sequence_lengths, batch_dims=1, axis=1)
             else:
                 sequence_lengths = -1
                 logger.warning(
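The removed map_fn/strided_slice/reshape construction picked out, for each sequence, the logits at its last non-padding position; the replacement does the same with a single tf.gather using batch_dims=1. A toy illustration of that call (shapes invented for the example):

import tensorflow as tf

logits = tf.reshape(tf.range(24.0), (2, 4, 3))  # [batch=2, seq_len=4, num_labels=3]
sequence_lengths = tf.constant([1, 3])          # last real token index per sequence

# picks logits[b, sequence_lengths[b], :] for every batch element b
in_logits = tf.gather(logits, sequence_lengths, batch_dims=1, axis=1)  # shape (2, 3)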
src/transformers/models/transfo_xl/modeling_tf_transfo_xl_utilities.py
@@ -131,7 +131,7 @@ class TFAdaptiveSoftmaxMask(tf.keras.layers.Layer):
         else:
             hidden_sizes = shape_list(hidden)
             out = []
-            loss = tf.zeros(hidden_sizes[:2], dtype=tf.float32)
+            loss = tf.zeros(hidden_sizes[:2])

             for i in range(len(self.cutoffs)):
                 l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]

                 if target is not None:
@@ -168,7 +168,7 @@ class TFAdaptiveSoftmaxMask(tf.keras.layers.Layer):
                     cur_logprob = self._gather_logprob(cur_tail_logprob, cur_target)
                     cur_logprob += cur_head_logprob[:, self.cutoff_ends[1] + i - 1]
                 if target is not None:
-                    loss += tf.scatter_nd(mask_idx, -cur_logprob, tf.cast(shape_list(loss), dtype=tf.int64))
+                    loss += tf.scatter_nd(mask_idx, -cur_logprob, shape_list(loss))
             out = tf.concat(out, axis=-1)

             if target is not None:
tests/test_modeling_tf_transfo_xl.py
@@ -205,10 +205,6 @@ class TFTransfoXLModelTest(TFModelTesterMixin, unittest.TestCase):
                 name = model.get_bias()
                 assert name is None

-    def test_mixed_precision(self):
-        # TODO JP: Make TransfoXL float16 compliant
-        pass
-
     def test_xla_mode(self):
         # TODO JP: Make TransfoXL XLA compliant
         pass
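With the test_mixed_precision override deleted, TransfoXL now runs the shared mixed-precision test from TFModelTesterMixin instead of skipping it (the test_xla_mode override stays, since XLA support is still a TODO). Roughly what such a test exercises, sketched with a tiny hand-written config rather than the mixin's exact body:

import tensorflow as tf
from transformers import TFTransfoXLModel, TransfoXLConfig

tf.keras.mixed_precision.set_global_policy("mixed_float16")

config = TransfoXLConfig(
    vocab_size=100, cutoffs=[10, 50], d_model=64, d_embed=64,
    n_head=2, d_head=8, d_inner=128, n_layer=2, mem_len=30, clamp_len=15,
)
model = TFTransfoXLModel(config)
outputs = model(tf.constant([[1, 2, 3, 4]]))           # should run without dtype-mismatch errors

tf.keras.mixed_precision.set_global_policy("float32")  # restore the default policy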