Commit c4c4c999 authored by Patrick von Platen

make GPT2 and CTRL shape consistent between torch and TF

parent 2529b2d3
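The substance of the change is the axis on which the key/value cache (`present`) is stacked. A minimal sketch, not part of the commit and with made-up dimensions, of what the two choices produce:

import tensorflow as tf

k = tf.zeros((3, 4, 5, 8))  # (batch, heads, seq_len, head_dim)
v = tf.zeros((3, 4, 5, 8))

# Before this commit the TF models stacked on axis=1, batch dim first:
print(tf.stack([k, v], axis=1).shape)  # (3, 2, 4, 5, 8)

# After the commit they stack on axis=0, so the key/value pair dim comes
# first, matching the layout of `past` in the PyTorch GPT2 and CTRL models:
print(tf.stack([k, v], axis=0).shape)  # (2, 3, 4, 5, 8)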
@@ -104,10 +104,10 @@ class TFMultiHeadAttention(tf.keras.layers.Layer):
         k = self.split_into_heads(k, batch_size)
         v = self.split_into_heads(v, batch_size)
         if layer_past is not None:
-            past_key, past_value = tf.unstack(layer_past, axis=1)
+            past_key, past_value = tf.unstack(layer_past, axis=0)
             k = tf.concat((past_key, k), axis=-2)
             v = tf.concat((past_value, v), axis=-2)
-        present = tf.stack((k, v), axis=1)
+        present = tf.stack((k, v), axis=0)
 
         output = scaled_dot_product_attention(q, k, v, mask, attention_mask, head_mask)
         scaled_attention = tf.transpose(output[0], perm=[0, 2, 1, 3])
@@ -139,10 +139,10 @@ class TFAttention(tf.keras.layers.Layer):
         key = self.split_heads(key)
         value = self.split_heads(value)
         if layer_past is not None:
-            past_key, past_value = tf.unstack(layer_past, axis=1)
+            past_key, past_value = tf.unstack(layer_past, axis=0)
             key = tf.concat([past_key, key], axis=-2)
             value = tf.concat([past_value, value], axis=-2)
-        present = tf.stack([key, value], axis=1)
+        present = tf.stack([key, value], axis=0)
 
         attn_outputs = self._attn([query, key, value, attention_mask, head_mask], training=training)
         a = attn_outputs[0]
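For context, a self-contained sketch (shapes invented for illustration) of the cache round-trip this code performs during incremental decoding:

import tensorflow as tf

batch, heads, head_dim = 3, 4, 8
# Cache from 5 earlier tokens, stacked on axis 0 as after this commit:
layer_past = tf.zeros((2, batch, heads, 5, head_dim))
# Projections for the single new token of the current step:
key = tf.zeros((batch, heads, 1, head_dim))
value = tf.zeros((batch, heads, 1, head_dim))

past_key, past_value = tf.unstack(layer_past, axis=0)  # each (batch, heads, 5, head_dim)
key = tf.concat([past_key, key], axis=-2)              # (batch, heads, 6, head_dim)
value = tf.concat([past_value, value], axis=-2)        # (batch, heads, 6, head_dim)
present = tf.stack([key, value], axis=0)               # (2, batch, heads, 6, head_dim)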
@@ -658,7 +658,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
             # repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858)
             if repetition_penalty != 1.0:
-                next_token_logits_penalties = _create_next_token_logits_penalties(input_ids, next_token_logits, repetition_penalty)
+                next_token_logits_penalties = _create_next_token_logits_penalties(
+                    input_ids, next_token_logits, repetition_penalty
+                )
                 next_token_logits = tf.math.multiply(next_token_logits, next_token_logits_penalties)
 
             if do_sample:
@@ -779,7 +781,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
             # repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858)
             if repetition_penalty != 1.0:
-                next_token_logits_penalties = _create_next_token_logits_penalties(input_ids, next_token_logits, repetition_penalty)
+                next_token_logits_penalties = _create_next_token_logits_penalties(
+                    input_ids, next_token_logits, repetition_penalty
+                )
                 next_token_logits = tf.math.multiply(next_token_logits, next_token_logits_penalties)
 
             if do_sample:
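The body of `_create_next_token_logits_penalties`, called in both hunks above, is not part of this diff. A hypothetical sketch of what such a helper could look like, following the CTRL-paper rule that repeated tokens are discounted (positive logits divided by the penalty, negative logits multiplied by it):

import numpy as np
import tensorflow as tf

def create_next_token_logits_penalties_sketch(input_ids, logits, repetition_penalty):
    # Hypothetical re-implementation for illustration only.
    penalties = np.ones(logits.shape.as_list())
    for i, prev_ids in enumerate(input_ids.numpy()):
        logits_np = logits[i].numpy()
        for token in set(prev_ids.tolist()):
            # Shrink positive logits, push negative logits further down.
            penalties[i, token] = 1.0 / repetition_penalty if logits_np[token] > 0 else repetition_penalty
    return tf.convert_to_tensor(penalties, dtype=logits.dtype)

Multiplying the logits by this tensor, as the code above does, leaves unseen tokens untouched and makes every already-generated token less likely.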
...@@ -791,11 +795,15 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin): ...@@ -791,11 +795,15 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
next_token_logits, top_k=top_k, top_p=top_p, min_tokens_to_keep=2 next_token_logits, top_k=top_k, top_p=top_p, min_tokens_to_keep=2
) # (batch_size * num_beams, vocab_size) ) # (batch_size * num_beams, vocab_size)
# Sample 2 next tokens for each beam (so we have some spare tokens and match output of greedy beam search) # Sample 2 next tokens for each beam (so we have some spare tokens and match output of greedy beam search)
next_tokens = tf.random.categorical(next_token_logits, dtype=tf.int32, num_samples=2) # (batch_size * num_beams, vocab_size) next_tokens = tf.random.categorical(
next_token_logits, dtype=tf.int32, num_samples=2
) # (batch_size * num_beams, vocab_size)
# Compute next scores # Compute next scores
scores = tf.nn.log_softmax(next_token_logits, axis=-1) # (batch_size * num_beams, vocab_size) scores = tf.nn.log_softmax(next_token_logits, axis=-1) # (batch_size * num_beams, vocab_size)
_scores = tf.gather(scores, next_tokens, batch_dims=1) # (batch_size * num_beams, 2) _scores = tf.gather(scores, next_tokens, batch_dims=1) # (batch_size * num_beams, 2)
next_scores = _scores + tf.broadcast_to(beam_scores[:, None], (batch_size * num_beams, 2)) # (batch_size * num_beams, 2) next_scores = _scores + tf.broadcast_to(
beam_scores[:, None], (batch_size * num_beams, 2)
) # (batch_size * num_beams, 2)
# Match shape of greedy beam search # Match shape of greedy beam search
next_tokens = tf.reshape(next_tokens, (batch_size, 2 * num_beams)) # (batch_size, 2 * num_beams) next_tokens = tf.reshape(next_tokens, (batch_size, 2 * num_beams)) # (batch_size, 2 * num_beams)
next_scores = tf.reshape(next_scores, (batch_size, 2 * num_beams)) # (batch_size, 2 * num_beams) next_scores = tf.reshape(next_scores, (batch_size, 2 * num_beams)) # (batch_size, 2 * num_beams)
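A toy example (numbers invented) of the two-samples-per-beam pattern above, showing how `batch_dims=1` aligns each row's sampled token ids with their log probs:

import tensorflow as tf

# Pretend batch_size * num_beams = 2 and vocab_size = 5.
next_token_logits = tf.math.log([[0.1, 0.2, 0.3, 0.2, 0.2],
                                 [0.5, 0.1, 0.1, 0.2, 0.1]])
next_tokens = tf.random.categorical(next_token_logits, dtype=tf.int32, num_samples=2)  # (2, 2)
scores = tf.nn.log_softmax(next_token_logits, axis=-1)  # (2, 5)
# For each row, pick the log probs of that row's two sampled tokens:
_scores = tf.gather(scores, next_tokens, batch_dims=1)  # (2, 2)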
@@ -804,10 +812,14 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
                     scores = tf.nn.log_softmax(next_token_logits, axis=-1)  # (batch_size * num_beams, vocab_size)
                     assert shape_list(scores) == [batch_size * num_beams, vocab_size]
                     # Add the log prob of the new beams to the log prob of the beginning of the sequence (sum of logs == log of the product)
-                    next_scores = scores + tf.broadcast_to(beam_scores[:, None], (batch_size * num_beams, vocab_size))  # (batch_size * num_beams, vocab_size)
+                    next_scores = scores + tf.broadcast_to(
+                        beam_scores[:, None], (batch_size * num_beams, vocab_size)
+                    )  # (batch_size * num_beams, vocab_size)
                     # re-organize to group the beams together (we are keeping the top hypotheses across beams)
-                    next_scores = tf.reshape(next_scores, (batch_size, num_beams * vocab_size))  # (batch_size, num_beams * vocab_size)
+                    next_scores = tf.reshape(
+                        next_scores, (batch_size, num_beams * vocab_size)
+                    )  # (batch_size, num_beams * vocab_size)
                     next_scores, next_tokens = tf.math.top_k(next_scores, 2 * num_beams, sorted=True)
                     assert shape_list(next_scores) == shape_list(next_tokens) == [batch_size, 2 * num_beams]
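A small worked example (made-up scores) of why the reshape enables a single top-k across all beams, and how the flat indices decompose back into a beam index and a vocab id:

import tensorflow as tf

num_beams, vocab_size = 2, 3
next_scores = tf.constant([[[-1.0, -2.0, -0.5],
                            [-0.3, -4.0, -1.5]]])            # (batch_size=1, num_beams, vocab_size)
flat = tf.reshape(next_scores, (1, num_beams * vocab_size))  # (1, 6)
scores, tokens = tf.math.top_k(flat, 2 * num_beams, sorted=True)
beam_ids = tokens // vocab_size   # [[1, 0, 0, 1]] -> which beam each candidate extends
token_ids = tokens % vocab_size   # [[0, 2, 0, 2]] -> which vocab id it appends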
@@ -909,7 +921,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
                 best_hyp = sorted_hyps.pop()[1]
                 sent_lengths_list.append(len(best_hyp))
                 best.append(best_hyp)
-        assert output_batch_size == len(best), "Output batch size {} must match output beam hypotheses {}".format(output_batch_size, len(best))
+        assert output_batch_size == len(best), "Output batch size {} must match output beam hypotheses {}".format(
+            output_batch_size, len(best)
+        )
 
         sent_lengths = tf.convert_to_tensor(sent_lengths_list, dtype=tf.int32)
@@ -925,7 +939,11 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
                    decoded_hypo = tf.concat([hypo, padding], axis=0)
 
                if sent_lengths[i] < max_length:
-                   decoded_hypo = tf.where(tf.range(max_length) == sent_lengths[i], eos_token_ids[0] * tf.ones((sent_max_len,), dtype=tf.int32), decoded_hypo)
+                   decoded_hypo = tf.where(
+                       tf.range(max_length) == sent_lengths[i],
+                       eos_token_ids[0] * tf.ones((sent_max_len,), dtype=tf.int32),
+                       decoded_hypo,
+                   )
                decoded_list.append(decoded_hypo)
            decoded = tf.stack(decoded_list)
        else:
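A standalone sketch (values invented, and `sent_max_len` simplified to `max_length`) of the `tf.where` trick above: an eos token is written at position `sent_lengths[i]` and every other position keeps its current value:

import tensorflow as tf

max_length = 6
sent_length = 3
eos_token_id = 2
decoded_hypo = tf.constant([7, 8, 9, 0, 0, 0], dtype=tf.int32)

decoded_hypo = tf.where(
    tf.range(max_length) == sent_length,
    eos_token_id * tf.ones((max_length,), dtype=tf.int32),
    decoded_hypo,
)
# -> [7, 8, 9, 2, 0, 0]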
@@ -942,7 +960,6 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
             # get the correct batch idx from layer past batch dim
             # batch dim of `past` and `mems` is at 2nd position
             reordered_layer_past = [tf.identity(tf.expand_dims(layer_past[:, i], 1)) for i in beam_idx]
-            # TODO: check whether it is an error that TF past.shape != Torch past.shape
             reordered_layer_past = tf.concat(reordered_layer_past, axis=1)
             # check that shape matches
             assert shape_list(reordered_layer_past) == shape_list(layer_past)
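Finally, a sketch (sizes invented) of the cache reordering above: after top-k selection, each surviving beam must inherit the cached history of the beam it was expanded from. With the batch dim at position 1, the expand_dims/concat pair is equivalent to a gather along that axis:

import tensorflow as tf

num_beams, hidden = 4, 3
layer_past = tf.reshape(tf.range(2 * num_beams * hidden, dtype=tf.float32), (2, num_beams, hidden))
beam_idx = [2, 2, 0, 1]  # e.g. beams 0 and 1 both continue from old beam 2

reordered = [tf.identity(tf.expand_dims(layer_past[:, i], 1)) for i in beam_idx]
reordered = tf.concat(reordered, axis=1)
assert reordered.shape == layer_past.shape  # (2, 4, 3)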