"tests/models/data2vec/test_modeling_data2vec_text.py" did not exist on "c852036b4abca2c20e1adf92eda48472a7d84ef0"
Unverified Commit 3e116ed3 authored by Julien Plu's avatar Julien Plu Committed by GitHub
Browse files

Making TF TransfoXL model compliant with AMP (#10264)

* Fix AMP

* Apply style

* Remove unused import
parent 86caeb76
......@@ -59,6 +59,7 @@ class TFPositionalEmbedding(tf.keras.layers.Layer):
self.inv_freq = 1 / (10000 ** (tf.range(0, demb, 2.0) / demb))
def call(self, pos_seq, bsz=None):
self.inv_freq = tf.cast(self.inv_freq, dtype=pos_seq.dtype)
sinusoid_inp = tf.einsum("i,j->ij", pos_seq, self.inv_freq)
pos_emb = tf.concat([tf.sin(sinusoid_inp), tf.cos(sinusoid_inp)], -1)
......@@ -186,6 +187,7 @@ class TFRelPartialLearnableMultiHeadAttn(tf.keras.layers.Layer):
qlen, rlen, bsz = shape_list(w)[0], shape_list(r)[0], shape_list(w)[1]
if mems is not None:
mems = tf.cast(mems, dtype=w.dtype)
cat = tf.concat([mems, w], 0)
if self.pre_lnorm:
w_heads = self.qkv_net(self.layer_norm(cat))
......@@ -227,7 +229,8 @@ class TFRelPartialLearnableMultiHeadAttn(tf.keras.layers.Layer):
# compute attention probability
if attn_mask is not None:
attn_mask_t = attn_mask[:, :, None, None]
attn_score = attn_score * (1 - attn_mask_t) - 1e30 * attn_mask_t
attn_mask_t = tf.cast(attn_mask_t, dtype=attn_score.dtype)
attn_score = attn_score * (1.0 - attn_mask_t) - 1e30 * attn_mask_t
# [qlen x klen x bsz x n_head]
attn_prob = tf.nn.softmax(attn_score, axis=1)
......@@ -313,6 +316,27 @@ class TFRelPartialLearnableDecoderLayer(tf.keras.layers.Layer):
return outputs
class TFTransfoEmbeddings(tf.keras.layers.Layer):
def __init__(self, vocab_size, emb_size, init_std, **kwargs):
super().__init__(**kwargs)
self.vocab_size = vocab_size
self.emb_size = emb_size
self.init_std = init_std
def build(self, input_shape):
self.weight = self.add_weight(
shape=(self.vocab_size, self.emb_size),
initializer=get_initializer(self.init_std),
name="embeddings",
)
super().build(input_shape)
def call(self, inputs):
return tf.gather(self.weight, inputs)
class TFAdaptiveEmbedding(tf.keras.layers.Layer):
def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1, init_std=0.02, sample_softmax=False, **kwargs):
super().__init__(**kwargs)
......@@ -331,6 +355,7 @@ class TFAdaptiveEmbedding(tf.keras.layers.Layer):
self.emb_layers = []
self.emb_projs = []
if div_val == 1:
raise NotImplementedError # Removed these to avoid maintaining dead code - They are not used in our pretrained checkpoint
else:
......@@ -338,10 +363,10 @@ class TFAdaptiveEmbedding(tf.keras.layers.Layer):
l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]
d_emb_i = d_embed // (div_val ** i)
self.emb_layers.append(
tf.keras.layers.Embedding(
TFTransfoEmbeddings(
r_idx - l_idx,
d_emb_i,
embeddings_initializer=get_initializer(init_std),
init_std,
name="emb_layers_._{}".format(i),
)
)
......@@ -357,6 +382,7 @@ class TFAdaptiveEmbedding(tf.keras.layers.Layer):
name="emb_projs_._{}".format(i),
)
)
super().build(input_shape)
def call(self, inp):
......@@ -374,8 +400,10 @@ class TFAdaptiveEmbedding(tf.keras.layers.Layer):
emb_i = self.emb_layers[i](inp_i)
emb_i = tf.einsum("id,de->ie", emb_i, self.emb_projs[i])
mask_idx = tf.cast(tf.where(mask_i), dtype=tf.int64)
emb_flat += tf.scatter_nd(mask_idx, emb_i, tf.cast(shape_list(emb_flat), dtype=tf.int64))
mask_idx = tf.where(mask_i)
scatter = tf.scatter_nd(mask_idx, emb_i, shape_list(emb_flat))
emb_flat = tf.cast(emb_flat, dtype=scatter.dtype)
emb_flat += scatter
embed_shape = shape_list(inp) + [self.d_proj]
embed = tf.reshape(emb_flat, embed_shape)
......@@ -501,7 +529,7 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer):
end_idx = mlen + tf.math.maximum(0, qlen)
beg_idx = tf.math.maximum(0, end_idx - tf.convert_to_tensor(self.mem_len))
for i in range(len(hids)):
mems[i] = tf.cast(mems[i], dtype=hids[i].dtype)
cat = tf.concat([mems[i], hids[i]], axis=0)
tf.stop_gradient(cat)
new_mems.append(cat[beg_idx:end_idx])
......@@ -1113,7 +1141,6 @@ class TFTransfoXLForSequenceClassification(TFTransfoXLPreTrainedModel, TFSequenc
hidden_states = transformer_outputs[0]
logits = self.score(hidden_states)
logits_shape = shape_list(logits)
in_logits = None
if self.config.pad_token_id is None:
sequence_lengths = -1
......@@ -1121,22 +1148,16 @@ class TFTransfoXLForSequenceClassification(TFTransfoXLPreTrainedModel, TFSequenc
if inputs["input_ids"] is not None:
sequence_lengths = (
tf.reduce_sum(
tf.cast(tf.math.not_equal(inputs["input_ids"], self.config.pad_token_id), tf.int32),
tf.cast(
tf.math.not_equal(inputs["input_ids"], self.config.pad_token_id),
dtype=inputs["input_ids"].dtype,
),
-1,
keepdims=False,
)
- 1
)
def get_seq_element(sequence_position, input_batch):
return tf.strided_slice(
input_batch, [sequence_position, 0], [sequence_position + 1, input_batch.shape[-1]], [1, 1]
)
result = tf.map_fn(
fn=lambda t: get_seq_element(t[0], t[1]), elems=[sequence_lengths, logits], dtype="float"
)
in_logits = tf.reshape(result, [logits_shape[0], logits_shape[-1]])
in_logits = tf.gather(logits, sequence_lengths, batch_dims=1, axis=1)
else:
sequence_lengths = -1
logger.warning(
......
......@@ -131,7 +131,7 @@ class TFAdaptiveSoftmaxMask(tf.keras.layers.Layer):
else:
hidden_sizes = shape_list(hidden)
out = []
loss = tf.zeros(hidden_sizes[:2], dtype=tf.float32)
loss = tf.zeros(hidden_sizes[:2])
for i in range(len(self.cutoffs)):
l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]
if target is not None:
......@@ -168,7 +168,7 @@ class TFAdaptiveSoftmaxMask(tf.keras.layers.Layer):
cur_logprob = self._gather_logprob(cur_tail_logprob, cur_target)
cur_logprob += cur_head_logprob[:, self.cutoff_ends[1] + i - 1]
if target is not None:
loss += tf.scatter_nd(mask_idx, -cur_logprob, tf.cast(shape_list(loss), dtype=tf.int64))
loss += tf.scatter_nd(mask_idx, -cur_logprob, shape_list(loss))
out = tf.concat(out, axis=-1)
if target is not None:
......
......@@ -205,10 +205,6 @@ class TFTransfoXLModelTest(TFModelTesterMixin, unittest.TestCase):
name = model.get_bias()
assert name is None
def test_mixed_precision(self):
# TODO JP: Make TransfoXL float16 compliant
pass
def test_xla_mode(self):
# TODO JP: Make TransfoXL XLA compliant
pass
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment