chenpangpang / transformers · Commits

Commit cdcdd5f0 (unverified)
Authored Feb 24, 2021 by Julien Plu, committed by GitHub on Feb 24, 2021

Rework casts (#10274)

parent 2d458b2c
Showing 1 changed file with 24 additions and 30 deletions.

src/transformers/models/xlnet/modeling_tf_xlnet.py (+24, −30)
@@ -150,7 +150,7 @@ class TFXLNetRelativeAttention(tf.keras.layers.Layer):
         attn_score = (ac + bd + ef) * self.scale
         if attn_mask is not None:
             # attn_score = attn_score * (1 - attn_mask) - 1e30 * attn_mask
-            if attn_mask.dtype == tf.float16:
+            if attn_mask.dtype == tf.float16 or attn_mask.dtype == tf.bfloat16:
                 attn_score = attn_score - 65500 * attn_mask
             else:
                 attn_score = attn_score - 1e30 * attn_mask
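A minimal standalone sketch (not part of the file) of why the masking constant depends on dtype: float16 overflows past roughly 65504, so subtracting 1e30 would turn masked scores into -inf and can produce NaNs in the softmax, while 65500 stays finite.

```python
import tensorflow as tf

print(tf.cast(1e30, tf.float16))     # inf: 1e30 is not representable in float16
print(tf.cast(65500.0, tf.float16))  # still finite

mask = tf.constant([[0.0, 1.0]], dtype=tf.float16)
score = tf.zeros_like(mask)
masked = score - 65500 * mask        # large negative, but finite
print(tf.nn.softmax(masked))         # masked position gets ~0 probability
```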
@@ -476,7 +476,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
     def _prune_heads(self, heads_to_prune):
         raise NotImplementedError

-    def create_mask(self, qlen, mlen, dtype=tf.float32):
+    def create_mask(self, qlen, mlen):
         """
         Creates causal attention mask. Float mask where 1.0 indicates masked, 0.0 indicates not-masked.
@@ -495,10 +495,10 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
              v [0 0 0 0 0 0 0 0 0] [1 1 1 1 0 0 0 0 0]
         """
-        attn_mask = tf.ones([qlen, qlen], dtype=dtype)
+        attn_mask = tf.ones([qlen, qlen])
         mask_u = tf.matrix_band_part(attn_mask, 0, -1)
         mask_dia = tf.matrix_band_part(attn_mask, 0, 0)
-        attn_mask_pad = tf.zeros([qlen, mlen], dtype=dtype)
+        attn_mask_pad = tf.zeros([qlen, mlen])
         ret = tf.concat([attn_mask_pad, mask_u - mask_dia], 1)
         if self.same_length:
             mask_l = tf.matrix_band_part(attn_mask, -1, 0)
@@ -537,11 +537,9 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
         return pos_emb

-    def relative_positional_encoding(self, qlen, klen, bsz=None, dtype=None):
+    def relative_positional_encoding(self, qlen, klen, bsz=None):
         """create relative positional encoding."""
         freq_seq = tf.range(0, self.d_model, 2.0)
-        if dtype is not None and dtype != tf.float32:
-            freq_seq = tf.cast(freq_seq, dtype=dtype)
         inv_freq = 1 / (10000 ** (freq_seq / self.d_model))

         if self.attn_type == "bi":
@@ -557,10 +555,6 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
             fwd_pos_seq = tf.range(beg, end, -1.0)
             bwd_pos_seq = tf.range(-beg, -end, 1.0)

-            if dtype is not None and dtype != tf.float32:
-                fwd_pos_seq = tf.cast(fwd_pos_seq, dtype=dtype)
-                bwd_pos_seq = tf.cast(bwd_pos_seq, dtype=dtype)
-
             if self.clamp_len > 0:
                 fwd_pos_seq = tf.clip_by_value(fwd_pos_seq, -self.clamp_len, self.clamp_len)
                 bwd_pos_seq = tf.clip_by_value(bwd_pos_seq, -self.clamp_len, self.clamp_len)
@@ -576,8 +570,6 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
             pos_emb = tf.concat([fwd_pos_emb, bwd_pos_emb], axis=1)
         else:
             fwd_pos_seq = tf.range(beg, end, -1.0)
-            if dtype is not None and dtype != tf.float32:
-                fwd_pos_seq = tf.cast(fwd_pos_seq, dtype=dtype)
             if self.clamp_len > 0:
                 fwd_pos_seq = tf.clip_by_value(fwd_pos_seq, -self.clamp_len, self.clamp_len)
             pos_emb = self.positional_embedding(fwd_pos_seq, inv_freq, bsz)
@@ -661,8 +653,6 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
         mlen = shape_list(inputs["mems"][0])[0] if inputs["mems"] is not None and inputs["mems"][0] is not None else 0
         klen = mlen + qlen

-        dtype_float = tf.bfloat16 if self.use_bfloat16 else tf.float32
-
         # Attention mask
         # causal attention mask
         if self.attn_type == "uni":
@@ -679,7 +669,8 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
                 "or attention_mask (uses 0 for padding, added for compatibility with BERT). Please choose one."
             )
         if inputs["input_mask"] is None and inputs["attention_mask"] is not None:
-            inputs["input_mask"] = 1.0 - tf.cast(inputs["attention_mask"], dtype=dtype_float)
+            one_cst = tf.constant(1.0)
+            inputs["input_mask"] = 1.0 - tf.cast(inputs["attention_mask"], dtype=one_cst.dtype)
         if inputs["input_mask"] is not None and inputs["perm_mask"] is not None:
             data_mask = inputs["input_mask"][None] + inputs["perm_mask"]
         elif inputs["input_mask"] is not None and inputs["perm_mask"] is None:
@@ -692,7 +683,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
         if data_mask is not None:
             # all mems can be attended to
             if mlen > 0:
-                mems_mask = tf.zeros([shape_list(data_mask)[0], mlen, bsz], dtype=dtype_float)
+                mems_mask = tf.zeros([shape_list(data_mask)[0], mlen, bsz])
                 data_mask = tf.concat([mems_mask, data_mask], axis=1)
             if attn_mask is None:
                 attn_mask = data_mask[:, :, :, None]
@@ -700,13 +691,13 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
                 attn_mask += data_mask[:, :, :, None]

         if attn_mask is not None:
-            attn_mask = tf.cast(attn_mask > 0, dtype=dtype_float)
+            attn_mask = tf.cast(attn_mask > 0, dtype=attn_mask.dtype)

         if attn_mask is not None:
-            non_tgt_mask = -tf.eye(qlen, dtype=dtype_float)
+            non_tgt_mask = -tf.eye(qlen)
             if mlen > 0:
-                non_tgt_mask = tf.concat([tf.zeros([qlen, mlen], dtype=dtype_float), non_tgt_mask], axis=-1)
-            non_tgt_mask = tf.cast((attn_mask + non_tgt_mask[:, :, None, None]) > 0, dtype=dtype_float)
+                non_tgt_mask = tf.concat([tf.zeros([qlen, mlen]), non_tgt_mask], axis=-1)
+            non_tgt_mask = tf.cast((attn_mask + non_tgt_mask[:, :, None, None]) > 0, dtype=non_tgt_mask.dtype)
         else:
             non_tgt_mask = None
@@ -729,19 +720,22 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
         if inputs["token_type_ids"] is not None:
             # Convert `token_type_ids` to one-hot `seg_mat`
             if mlen > 0:
-                mem_pad = tf.zeros([mlen, bsz], dtype=tf.int32)
+                mem_pad = tf.zeros([mlen, bsz], dtype=inputs["token_type_ids"].dtype)
                 cat_ids = tf.concat([mem_pad, inputs["token_type_ids"]], 0)
             else:
                 cat_ids = inputs["token_type_ids"]

             # `1` indicates not in the same segment [qlen x klen x bsz]
-            seg_mat = tf.cast(tf.logical_not(tf.equal(inputs["token_type_ids"][:, None], cat_ids[None, :])), tf.int32)
-            seg_mat = tf.one_hot(seg_mat, 2, dtype=dtype_float)
+            seg_mat = tf.cast(
+                tf.logical_not(tf.equal(inputs["token_type_ids"][:, None], cat_ids[None, :])),
+                dtype=inputs["token_type_ids"].dtype,
+            )
+            seg_mat = tf.one_hot(seg_mat, 2)
         else:
             seg_mat = None

         # Positional encoding
-        pos_emb = self.relative_positional_encoding(qlen, klen, bsz=bsz, dtype=dtype_float)
+        pos_emb = self.relative_positional_encoding(qlen, klen, bsz=bsz)
         pos_emb = self.dropout(pos_emb, training=inputs["training"])

         # Prepare head mask if needed
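A toy 1-D sketch of the segment-matrix step: with no explicit dtype, `tf.one_hot` emits float32 by default, which is what the removed `dtype=dtype_float` argument used to override; the intermediate indices now stay in the dtype of `token_type_ids` itself.

```python
import tensorflow as tf

token_type_ids = tf.constant([0, 0, 1, 1])  # int32
cat_ids = token_type_ids                    # no memory, so mlen == 0
seg_mat = tf.cast(
    tf.logical_not(tf.equal(token_type_ids[:, None], cat_ids[None, :])),
    dtype=token_type_ids.dtype,
)
seg_mat = tf.one_hot(seg_mat, 2)            # defaults to float32
print(seg_mat.dtype, seg_mat.shape)         # float32 (4, 4, 2)
```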
@@ -1258,7 +1252,7 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel, TFCausalLanguageModelingLoss):
         offset = 2
         effective_batch_size = inputs.shape[0]
-        dummy_token = tf.zeros((effective_batch_size, 1), dtype=tf.int32)
+        dummy_token = tf.zeros((effective_batch_size, 1), dtype=inputs.dtype)

         if past:
             inputs = tf.concat([inputs[:, -offset:], dummy_token], axis=1)
@@ -1267,13 +1261,13 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel, TFCausalLanguageModelingLoss):
         # Build permutation mask so that previous tokens don't see last token
         sequence_length = inputs.shape[1]
-        perm_mask = tf.zeros((effective_batch_size, sequence_length, sequence_length - 1), dtype=tf.float32)
-        perm_mask_seq_end = tf.ones((effective_batch_size, sequence_length, 1), dtype=tf.float32)
+        perm_mask = tf.zeros((effective_batch_size, sequence_length, sequence_length - 1))
+        perm_mask_seq_end = tf.ones((effective_batch_size, sequence_length, 1))
         perm_mask = tf.concat([perm_mask, perm_mask_seq_end], axis=-1)

         # We'll only predict the last token
-        target_mapping = tf.zeros((effective_batch_size, 1, sequence_length - 1), dtype=tf.float32)
-        target_mapping_seq_end = tf.ones((effective_batch_size, 1, 1), dtype=tf.float32)
+        target_mapping = tf.zeros((effective_batch_size, 1, sequence_length - 1))
+        target_mapping_seq_end = tf.ones((effective_batch_size, 1, 1))
         target_mapping = tf.concat([target_mapping, target_mapping_seq_end], axis=-1)

         inputs = {