Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
cdcdd5f0
Unverified
Commit
cdcdd5f0
authored
Feb 24, 2021
by
Julien Plu
Committed by
GitHub
Feb 24, 2021
Browse files
Rework casts (#10274)
parent
2d458b2c
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
24 additions
and
30 deletions
+24
-30
src/transformers/models/xlnet/modeling_tf_xlnet.py
src/transformers/models/xlnet/modeling_tf_xlnet.py
+24
-30
No files found.
src/transformers/models/xlnet/modeling_tf_xlnet.py
View file @
cdcdd5f0
...
...
@@ -150,7 +150,7 @@ class TFXLNetRelativeAttention(tf.keras.layers.Layer):
attn_score = (ac + bd + ef) * self.scale
if attn_mask is not None:
# attn_score = attn_score * (1 - attn_mask) - 1e30 * attn_mask
- if attn_mask.dtype == tf.float16:
+ if attn_mask.dtype == tf.float16 or attn_mask.dtype == tf.bfloat16:
attn_score = attn_score - 65500 * attn_mask
else:
attn_score = attn_score - 1e30 * attn_mask
...
...
@@ -476,7 +476,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
def _prune_heads(self, heads_to_prune):
raise NotImplementedError
- def create_mask(self, qlen, mlen, dtype=tf.float32):
+ def create_mask(self, qlen, mlen):
"""
Creates causal attention mask. Float mask where 1.0 indicates masked, 0.0 indicates not-masked.
...
...
@@ -495,10 +495,10 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
v [0 0 0 0 0 0 0 0 0] [1 1 1 1 0 0 0 0 0]
"""
- attn_mask = tf.ones([qlen, qlen], dtype=dtype)
+ attn_mask = tf.ones([qlen, qlen])
mask_u = tf.matrix_band_part(attn_mask, 0, -1)
mask_dia = tf.matrix_band_part(attn_mask, 0, 0)
- attn_mask_pad = tf.zeros([qlen, mlen], dtype=dtype)
+ attn_mask_pad = tf.zeros([qlen, mlen])
ret = tf.concat([attn_mask_pad, mask_u - mask_dia], 1)
if self.same_length:
mask_l = tf.matrix_band_part(attn_mask, -1, 0)
...
...
@@ -537,11 +537,9 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
return pos_emb
- def relative_positional_encoding(self, qlen, klen, bsz=None, dtype=None):
+ def relative_positional_encoding(self, qlen, klen, bsz=None):
"""create relative positional encoding."""
freq_seq = tf.range(0, self.d_model, 2.0)
- if dtype is not None and dtype != tf.float32:
- freq_seq = tf.cast(freq_seq, dtype=dtype)
inv_freq = 1 / (10000 ** (freq_seq / self.d_model))
if self.attn_type == "bi":
...
...
@@ -557,10 +555,6 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
fwd_pos_seq = tf.range(beg, end, -1.0)
bwd_pos_seq = tf.range(-beg, -end, 1.0)
- if dtype is not None and dtype != tf.float32:
- fwd_pos_seq = tf.cast(fwd_pos_seq, dtype=dtype)
- bwd_pos_seq = tf.cast(bwd_pos_seq, dtype=dtype)
if self.clamp_len > 0:
fwd_pos_seq = tf.clip_by_value(fwd_pos_seq, -self.clamp_len, self.clamp_len)
bwd_pos_seq = tf.clip_by_value(bwd_pos_seq, -self.clamp_len, self.clamp_len)
...
...
@@ -576,8 +570,6 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
pos_emb = tf.concat([fwd_pos_emb, bwd_pos_emb], axis=1)
else:
fwd_pos_seq = tf.range(beg, end, -1.0)
- if dtype is not None and dtype != tf.float32:
- fwd_pos_seq = tf.cast(fwd_pos_seq, dtype=dtype)
if self.clamp_len > 0:
fwd_pos_seq = tf.clip_by_value(fwd_pos_seq, -self.clamp_len, self.clamp_len)
pos_emb = self.positional_embedding(fwd_pos_seq, inv_freq, bsz)
...
...
@@ -661,8 +653,6 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
mlen = shape_list(inputs["mems"][0])[0] if inputs["mems"] is not None and inputs["mems"][0] is not None else 0
klen = mlen + qlen
- dtype_float = tf.bfloat16 if self.use_bfloat16 else tf.float32
# Attention mask
# causal attention mask
if self.attn_type == "uni":
...
...
@@ -679,7 +669,8 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
"or attention_mask (uses 0 for padding, added for compatibility with BERT). Please choose one."
)
if inputs["input_mask"] is None and inputs["attention_mask"] is not None:
- inputs["input_mask"] = 1.0 - tf.cast(inputs["attention_mask"], dtype=dtype_float)
+ one_cst = tf.constant(1.0)
+ inputs["input_mask"] = 1.0 - tf.cast(inputs["attention_mask"], dtype=one_cst.dtype)
if inputs["input_mask"] is not None and inputs["perm_mask"] is not None:
data_mask = inputs["input_mask"][None] + inputs["perm_mask"]
elif inputs["input_mask"] is not None and inputs["perm_mask"] is None:
...
...
@@ -692,7 +683,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
if data_mask is not None:
# all mems can be attended to
if mlen > 0:
- mems_mask = tf.zeros([shape_list(data_mask)[0], mlen, bsz], dtype=dtype_float)
+ mems_mask = tf.zeros([shape_list(data_mask)[0], mlen, bsz])
data_mask = tf.concat([mems_mask, data_mask], axis=1)
if attn_mask is None:
attn_mask = data_mask[:, :, :, None]
...
...
@@ -700,13 +691,13 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
attn_mask += data_mask[:, :, :, None]
if attn_mask is not None:
- attn_mask = tf.cast(attn_mask > 0, dtype=dtype_float)
+ attn_mask = tf.cast(attn_mask > 0, dtype=attn_mask.dtype)
if attn_mask is not None:
- non_tgt_mask = -tf.eye(qlen, dtype=dtype_float)
+ non_tgt_mask = -tf.eye(qlen)
if mlen > 0:
- non_tgt_mask = tf.concat([tf.zeros([qlen, mlen], dtype=dtype_float), non_tgt_mask], axis=-1)
- non_tgt_mask = tf.cast((attn_mask + non_tgt_mask[:, :, None, None]) > 0, dtype=dtype_float)
+ non_tgt_mask = tf.concat([tf.zeros([qlen, mlen]), non_tgt_mask], axis=-1)
+ non_tgt_mask = tf.cast((attn_mask + non_tgt_mask[:, :, None, None]) > 0, dtype=non_tgt_mask.dtype)
else:
non_tgt_mask = None
...
...
@@ -729,19 +720,22 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
if inputs["token_type_ids"] is not None:
# Convert `token_type_ids` to one-hot `seg_mat`
if mlen > 0:
- mem_pad = tf.zeros([mlen, bsz], dtype=tf.int32)
+ mem_pad = tf.zeros([mlen, bsz], dtype=inputs["token_type_ids"].dtype)
cat_ids = tf.concat([mem_pad, inputs["token_type_ids"]], 0)
else:
cat_ids = inputs["token_type_ids"]
# `1` indicates not in the same segment [qlen x klen x bsz]
- seg_mat = tf.cast(tf.logical_not(tf.equal(inputs["token_type_ids"][:, None], cat_ids[None, :])), tf.int32)
- seg_mat = tf.one_hot(seg_mat, 2, dtype=dtype_float)
+ seg_mat = tf.cast(
+ tf.logical_not(tf.equal(inputs["token_type_ids"][:, None], cat_ids[None, :])),
+ dtype=inputs["token_type_ids"].dtype,
+ )
+ seg_mat = tf.one_hot(seg_mat, 2)
else:
seg_mat = None
# Positional encoding
- pos_emb = self.relative_positional_encoding(qlen, klen, bsz=bsz, dtype=dtype_float)
+ pos_emb = self.relative_positional_encoding(qlen, klen, bsz=bsz)
pos_emb = self.dropout(pos_emb, training=inputs["training"])
# Prepare head mask if needed
...
...
@@ -1258,7 +1252,7 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel, TFCausalLanguageModelingLoss):
offset = 2
effective_batch_size = inputs.shape[0]
- dummy_token = tf.zeros((effective_batch_size, 1), dtype=tf.int32)
+ dummy_token = tf.zeros((effective_batch_size, 1), dtype=inputs.dtype)
if past:
inputs = tf.concat([inputs[:, -offset:], dummy_token], axis=1)
...
...
@@ -1267,13 +1261,13 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel, TFCausalLanguageModelingLoss):
# Build permutation mask so that previous tokens don't see last token
sequence_length = inputs.shape[1]
- perm_mask = tf.zeros((effective_batch_size, sequence_length, sequence_length - 1), dtype=tf.float32)
- perm_mask_seq_end = tf.ones((effective_batch_size, sequence_length, 1), dtype=tf.float32)
+ perm_mask = tf.zeros((effective_batch_size, sequence_length, sequence_length - 1))
+ perm_mask_seq_end = tf.ones((effective_batch_size, sequence_length, 1))
perm_mask = tf.concat([perm_mask, perm_mask_seq_end], axis=-1)
# We'll only predict the last token
- target_mapping = tf.zeros((effective_batch_size, 1, sequence_length - 1), dtype=tf.float32)
- target_mapping_seq_end = tf.ones((effective_batch_size, 1, 1), dtype=tf.float32)
+ target_mapping = tf.zeros((effective_batch_size, 1, sequence_length - 1))
+ target_mapping_seq_end = tf.ones((effective_batch_size, 1, 1))
target_mapping = tf.concat([target_mapping, target_mapping_seq_end], axis=-1)
inputs = {
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment