"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "a1afec9e1759b0fdb256d41d429161cc15ecf500"
Unverified commit 3e116ed3, authored by Julien Plu and committed by GitHub

Making TF TransfoXL model compliant with AMP (#10264)

* Fix AMP

* Apply style

* Remove unused import
parent 86caeb76
@@ -59,6 +59,7 @@ class TFPositionalEmbedding(tf.keras.layers.Layer):
         self.inv_freq = 1 / (10000 ** (tf.range(0, demb, 2.0) / demb))
 
     def call(self, pos_seq, bsz=None):
+        self.inv_freq = tf.cast(self.inv_freq, dtype=pos_seq.dtype)
         sinusoid_inp = tf.einsum("i,j->ij", pos_seq, self.inv_freq)
         pos_emb = tf.concat([tf.sin(sinusoid_inp), tf.cos(sinusoid_inp)], -1)
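Note on the cast above: under a Keras mixed-precision policy, pos_seq arrives in the compute dtype (float16) while inv_freq was built in float32, and tf.einsum rejects mixed operand dtypes. A minimal standalone sketch of the pattern, assuming the TF 2.4+ policy API (the setup below is illustrative, not part of this diff):

    import tensorflow as tf

    tf.keras.mixed_precision.set_global_policy("mixed_float16")

    pos_seq = tf.cast(tf.range(10.0), tf.float16)              # positions in compute dtype
    inv_freq = 1 / (10000 ** (tf.range(0.0, 32.0, 2.0) / 32))  # float32 constant

    # Casting to the input dtype keeps both einsum operands in float16:
    inv_freq = tf.cast(inv_freq, dtype=pos_seq.dtype)
    sinusoid_inp = tf.einsum("i,j->ij", pos_seq, inv_freq)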
@@ -186,6 +187,7 @@ class TFRelPartialLearnableMultiHeadAttn(tf.keras.layers.Layer):
         qlen, rlen, bsz = shape_list(w)[0], shape_list(r)[0], shape_list(w)[1]
 
         if mems is not None:
+            mems = tf.cast(mems, dtype=w.dtype)
             cat = tf.concat([mems, w], 0)
             if self.pre_lnorm:
                 w_heads = self.qkv_net(self.layer_norm(cat))
@@ -227,7 +229,8 @@ class TFRelPartialLearnableMultiHeadAttn(tf.keras.layers.Layer):
         # compute attention probability
         if attn_mask is not None:
             attn_mask_t = attn_mask[:, :, None, None]
-            attn_score = attn_score * (1 - attn_mask_t) - 1e30 * attn_mask_t
+            attn_mask_t = tf.cast(attn_mask_t, dtype=attn_score.dtype)
+            attn_score = attn_score * (1.0 - attn_mask_t) - 1e30 * attn_mask_t
 
         # [qlen x klen x bsz x n_head]
         attn_prob = tf.nn.softmax(attn_score, axis=1)
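Note on the masking change: casting the additive mask to attn_score's dtype keeps the whole expression in one precision. A self-contained sketch of the same pattern with illustrative shapes (it uses a float16-representable constant instead of 1e30, which overflows float16):

    import tensorflow as tf

    attn_score = tf.random.normal((4, 4, 2, 8), dtype=tf.float16)        # [qlen, klen, bsz, n_head]
    attn_mask = tf.linalg.band_part(tf.ones((4, 4)), 0, -1) - tf.eye(4)  # strictly upper triangular, float32

    attn_mask_t = tf.cast(attn_mask[:, :, None, None], dtype=attn_score.dtype)
    attn_score = attn_score * (1.0 - attn_mask_t) - 65500.0 * attn_mask_t  # float16 max is ~65504
    attn_prob = tf.nn.softmax(attn_score, axis=1)  # masked positions get ~0 probability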
@@ -313,6 +316,27 @@ class TFRelPartialLearnableDecoderLayer(tf.keras.layers.Layer):
         return outputs
 
 
+class TFTransfoEmbeddings(tf.keras.layers.Layer):
+    def __init__(self, vocab_size, emb_size, init_std, **kwargs):
+        super().__init__(**kwargs)
+
+        self.vocab_size = vocab_size
+        self.emb_size = emb_size
+        self.init_std = init_std
+
+    def build(self, input_shape):
+        self.weight = self.add_weight(
+            shape=(self.vocab_size, self.emb_size),
+            initializer=get_initializer(self.init_std),
+            name="embeddings",
+        )
+
+        super().build(input_shape)
+
+    def call(self, inputs):
+        return tf.gather(self.weight, inputs)
+
+
 class TFAdaptiveEmbedding(tf.keras.layers.Layer):
     def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1, init_std=0.02, sample_softmax=False, **kwargs):
         super().__init__(**kwargs)
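The TFTransfoEmbeddings layer added above is a thin gather-based replacement for tf.keras.layers.Embedding: it owns its weight matrix via add_weight and looks rows up with a plain tf.gather. A minimal usage sketch (hypothetical sizes; get_initializer is assumed in scope, as it is in this file):

    # Hypothetical sizes for illustration only.
    emb = TFTransfoEmbeddings(vocab_size=1000, emb_size=64, init_std=0.02, name="emb")
    ids = tf.constant([[1, 5, 42]])
    vectors = emb(ids)  # builds the (1000, 64) weight, returns shape (1, 3, 64)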
@@ -331,6 +355,7 @@ class TFAdaptiveEmbedding(tf.keras.layers.Layer):
         self.emb_layers = []
         self.emb_projs = []
 
         if div_val == 1:
             raise NotImplementedError  # Removed these to avoid maintaining dead code - They are not used in our pretrained checkpoint
         else:
@@ -338,10 +363,10 @@ class TFAdaptiveEmbedding(tf.keras.layers.Layer):
                 l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]
                 d_emb_i = d_embed // (div_val ** i)
                 self.emb_layers.append(
-                    tf.keras.layers.Embedding(
+                    TFTransfoEmbeddings(
                         r_idx - l_idx,
                         d_emb_i,
-                        embeddings_initializer=get_initializer(init_std),
+                        init_std,
                         name="emb_layers_._{}".format(i),
                     )
                 )
@@ -357,6 +382,7 @@ class TFAdaptiveEmbedding(tf.keras.layers.Layer):
                         name="emb_projs_._{}".format(i),
                     )
                 )
 
         super().build(input_shape)
 
     def call(self, inp):
@@ -374,8 +400,10 @@ class TFAdaptiveEmbedding(tf.keras.layers.Layer):
                 emb_i = self.emb_layers[i](inp_i)
                 emb_i = tf.einsum("id,de->ie", emb_i, self.emb_projs[i])
 
-                mask_idx = tf.cast(tf.where(mask_i), dtype=tf.int64)
-                emb_flat += tf.scatter_nd(mask_idx, emb_i, tf.cast(shape_list(emb_flat), dtype=tf.int64))
+                mask_idx = tf.where(mask_i)
+                scatter = tf.scatter_nd(mask_idx, emb_i, shape_list(emb_flat))
+                emb_flat = tf.cast(emb_flat, dtype=scatter.dtype)
+                emb_flat += scatter
 
         embed_shape = shape_list(inp) + [self.d_proj]
         embed = tf.reshape(emb_flat, embed_shape)
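Note on the reworked scatter: tf.where already returns int64 indices and tf.scatter_nd infers the shape dtype from them, so the explicit int64 casts could be dropped; the accumulation buffer is instead cast to the updates' dtype before adding. A self-contained sketch with illustrative shapes:

    import tensorflow as tf

    emb_flat = tf.zeros((6, 4))                         # accumulation buffer, float32
    mask_i = tf.constant([False, True, False, True, False, False])
    emb_i = tf.ones((2, 4), dtype=tf.float16)           # updates in the compute dtype

    mask_idx = tf.where(mask_i)                         # int64 indices of the True rows
    scatter = tf.scatter_nd(mask_idx, emb_i, [6, 4])    # float16, zeros off-mask
    emb_flat = tf.cast(emb_flat, dtype=scatter.dtype)   # align buffer with updates
    emb_flat += scatter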
@@ -501,7 +529,7 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer):
             end_idx = mlen + tf.math.maximum(0, qlen)
             beg_idx = tf.math.maximum(0, end_idx - tf.convert_to_tensor(self.mem_len))
 
         for i in range(len(hids)):
+            mems[i] = tf.cast(mems[i], dtype=hids[i].dtype)
             cat = tf.concat([mems[i], hids[i]], axis=0)
             tf.stop_gradient(cat)
             new_mems.append(cat[beg_idx:end_idx])
@@ -1113,7 +1141,6 @@ class TFTransfoXLForSequenceClassification(TFTransfoXLPreTrainedModel, TFSequenceClassificationLoss):
         hidden_states = transformer_outputs[0]
         logits = self.score(hidden_states)
-        logits_shape = shape_list(logits)
         in_logits = None
         if self.config.pad_token_id is None:
             sequence_lengths = -1
@@ -1121,22 +1148,16 @@ class TFTransfoXLForSequenceClassification(TFTransfoXLPreTrainedModel, TFSequenceClassificationLoss):
             if inputs["input_ids"] is not None:
                 sequence_lengths = (
                     tf.reduce_sum(
-                        tf.cast(tf.math.not_equal(inputs["input_ids"], self.config.pad_token_id), tf.int32),
+                        tf.cast(
+                            tf.math.not_equal(inputs["input_ids"], self.config.pad_token_id),
+                            dtype=inputs["input_ids"].dtype,
+                        ),
                         -1,
                         keepdims=False,
                     )
                     - 1
                 )
-
-                def get_seq_element(sequence_position, input_batch):
-                    return tf.strided_slice(
-                        input_batch, [sequence_position, 0], [sequence_position + 1, input_batch.shape[-1]], [1, 1]
-                    )
-
-                result = tf.map_fn(
-                    fn=lambda t: get_seq_element(t[0], t[1]), elems=[sequence_lengths, logits], dtype="float"
-                )
-                in_logits = tf.reshape(result, [logits_shape[0], logits_shape[-1]])
+                in_logits = tf.gather(logits, sequence_lengths, batch_dims=1, axis=1)
             else:
                 sequence_lengths = -1
                 logger.warning(
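The tf.gather(..., batch_dims=1) call above replaces the earlier map_fn/strided_slice dance with a single op that picks, per batch row, the logits at the last non-padding position. A standalone sketch (illustrative tensors; pad_token_id assumed to be 0 here):

    import tensorflow as tf

    logits = tf.random.normal((2, 5, 3))                        # [batch, seq_len, num_labels]
    input_ids = tf.constant([[7, 8, 9, 0, 0], [5, 6, 0, 0, 0]])

    sequence_lengths = tf.reduce_sum(
        tf.cast(tf.math.not_equal(input_ids, 0), dtype=input_ids.dtype), -1
    ) - 1                                                       # values [2, 1]
    in_logits = tf.gather(logits, sequence_lengths, batch_dims=1, axis=1)  # shape [2, 3]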
...
@@ -131,7 +131,7 @@ class TFAdaptiveSoftmaxMask(tf.keras.layers.Layer):
         else:
             hidden_sizes = shape_list(hidden)
             out = []
-            loss = tf.zeros(hidden_sizes[:2], dtype=tf.float32)
+            loss = tf.zeros(hidden_sizes[:2])
             for i in range(len(self.cutoffs)):
                 l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1]
                 if target is not None:
@@ -168,7 +168,7 @@ class TFAdaptiveSoftmaxMask(tf.keras.layers.Layer):
                     cur_logprob = self._gather_logprob(cur_tail_logprob, cur_target)
                     cur_logprob += cur_head_logprob[:, self.cutoff_ends[1] + i - 1]
                 if target is not None:
-                    loss += tf.scatter_nd(mask_idx, -cur_logprob, tf.cast(shape_list(loss), dtype=tf.int64))
+                    loss += tf.scatter_nd(mask_idx, -cur_logprob, shape_list(loss))
             out = tf.concat(out, axis=-1)
 
         if target is not None:
...
@@ -205,10 +205,6 @@ class TFTransfoXLModelTest(TFModelTesterMixin, unittest.TestCase):
             name = model.get_bias()
             assert name is None
 
-    def test_mixed_precision(self):
-        # TODO JP: Make TransfoXL float16 compliant
-        pass
-
     def test_xla_mode(self):
         # TODO JP: Make TransfoXL XLA compliant
         pass
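With the test_mixed_precision override deleted, the shared mixed-precision test from TFModelTesterMixin runs for TransfoXL again. Roughly, such a test sets the global policy, runs a forward pass, and restores the default; a hedged sketch (not the actual test code, and the policy API name is the TF 2.4+ one):

    import tensorflow as tf
    from transformers import TFTransfoXLModel

    tf.keras.mixed_precision.set_global_policy("mixed_float16")
    try:
        model = TFTransfoXLModel.from_pretrained("transfo-xl-wt103")
        outputs = model(tf.constant([[14, 447, 448, 521]]))  # forward pass in float16
    finally:
        tf.keras.mixed_precision.set_global_policy("float32")  # restore default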