"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "ef0ac063c9b9da3e4da759866736e266dbb44cfe"
Commit c4c4c999 authored by Patrick von Platen

make GPT2 and CTRL shape consistent between torch and TF

parent 2529b2d3
@@ -104,10 +104,10 @@ class TFMultiHeadAttention(tf.keras.layers.Layer):
         k = self.split_into_heads(k, batch_size)
         v = self.split_into_heads(v, batch_size)
         if layer_past is not None:
-            past_key, past_value = tf.unstack(layer_past, axis=1)
+            past_key, past_value = tf.unstack(layer_past, axis=0)
             k = tf.concat((past_key, k), axis=-2)
             v = tf.concat((past_value, v), axis=-2)
-        present = tf.stack((k, v), axis=1)
+        present = tf.stack((k, v), axis=0)
         output = scaled_dot_product_attention(q, k, v, mask, attention_mask, head_mask)
         scaled_attention = tf.transpose(output[0], perm=[0, 2, 1, 3])
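This hunk (CTRL's attention layer) and the next one (GPT2's TFAttention) make the same change: present is now stacked on axis 0 rather than axis 1, so the TF cache carries the key/value pair in the leading dimension, matching the layout of the PyTorch models. A minimal sketch of what the axis change does, with made-up shapes:

import tensorflow as tf

# Toy shapes (all values hypothetical): batch 2, 4 heads, 5 tokens, head size 8.
batch_size, num_heads, seq_len, head_dim = 2, 4, 5, 8
k = tf.zeros((batch_size, num_heads, seq_len, head_dim))
v = tf.zeros((batch_size, num_heads, seq_len, head_dim))

# axis=0 puts the key/value pair in the leading dimension, giving the
# (2, batch, heads, seq, head_dim) cache layout of the PyTorch models.
present = tf.stack((k, v), axis=0)
print(present.shape)  # (2, 2, 4, 5, 8)

# On the next decoding step the pair unpacks along the same axis, and each
# past tensor concatenates with the fresh keys/values on the sequence axis.
past_key, past_value = tf.unstack(present, axis=0)
k_next = tf.concat((past_key, k), axis=-2)
print(k_next.shape)  # (2, 4, 10, 8)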
@@ -139,10 +139,10 @@ class TFAttention(tf.keras.layers.Layer):
         key = self.split_heads(key)
         value = self.split_heads(value)
         if layer_past is not None:
-            past_key, past_value = tf.unstack(layer_past, axis=1)
+            past_key, past_value = tf.unstack(layer_past, axis=0)
             key = tf.concat([past_key, key], axis=-2)
             value = tf.concat([past_value, value], axis=-2)
-        present = tf.stack([key, value], axis=1)
+        present = tf.stack([key, value], axis=0)
         attn_outputs = self._attn([query, key, value, attention_mask, head_mask], training=training)
         a = attn_outputs[0]
@@ -658,7 +658,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
             # repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858)
             if repetition_penalty != 1.0:
-                next_token_logits_penalties = _create_next_token_logits_penalties(input_ids, next_token_logits, repetition_penalty)
+                next_token_logits_penalties = _create_next_token_logits_penalties(
+                    input_ids, next_token_logits, repetition_penalty
+                )
                 next_token_logits = tf.math.multiply(next_token_logits, next_token_logits_penalties)
             if do_sample:
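The penalty tensor itself comes from _create_next_token_logits_penalties, which this commit only re-wraps here and in the identical beam-search hunk below. For orientation, a hypothetical re-implementation of the CTRL-paper idea (not the library's exact code): tokens already present in input_ids get their logits scaled so they become less likely.

import numpy as np
import tensorflow as tf

def create_penalties(input_ids, logits, penalty):
    # Hypothetical sketch of the CTRL-style penalty; the library's
    # _create_next_token_logits_penalties may differ in detail.
    penalties = np.ones(logits.shape.as_list(), dtype=np.float32)
    for i, prev_ids in enumerate(input_ids.numpy()):
        prev_logits = logits.numpy()[i, prev_ids]
        # Scale previously seen tokens down: divide positive logits by the
        # penalty, multiply negative ones (both reduce their probability).
        penalties[i, prev_ids] = np.where(prev_logits < 0.0, penalty, 1.0 / penalty)
    return tf.convert_to_tensor(penalties)

# usage: next_token_logits = tf.math.multiply(next_token_logits, create_penalties(...))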
@@ -779,7 +781,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
             # repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858)
             if repetition_penalty != 1.0:
-                next_token_logits_penalties = _create_next_token_logits_penalties(input_ids, next_token_logits, repetition_penalty)
+                next_token_logits_penalties = _create_next_token_logits_penalties(
+                    input_ids, next_token_logits, repetition_penalty
+                )
                 next_token_logits = tf.math.multiply(next_token_logits, next_token_logits_penalties)
             if do_sample:
@@ -791,11 +795,15 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
                     next_token_logits, top_k=top_k, top_p=top_p, min_tokens_to_keep=2
                 )  # (batch_size * num_beams, vocab_size)
                 # Sample 2 next tokens for each beam (so we have some spare tokens and match output of greedy beam search)
-                next_tokens = tf.random.categorical(next_token_logits, dtype=tf.int32, num_samples=2)  # (batch_size * num_beams, vocab_size)
+                next_tokens = tf.random.categorical(
+                    next_token_logits, dtype=tf.int32, num_samples=2
+                )  # (batch_size * num_beams, vocab_size)
                 # Compute next scores
                 scores = tf.nn.log_softmax(next_token_logits, axis=-1)  # (batch_size * num_beams, vocab_size)
                 _scores = tf.gather(scores, next_tokens, batch_dims=1)  # (batch_size * num_beams, 2)
-                next_scores = _scores + tf.broadcast_to(beam_scores[:, None], (batch_size * num_beams, 2))  # (batch_size * num_beams, 2)
+                next_scores = _scores + tf.broadcast_to(
+                    beam_scores[:, None], (batch_size * num_beams, 2)
+                )  # (batch_size * num_beams, 2)
                 # Match shape of greedy beam search
                 next_tokens = tf.reshape(next_tokens, (batch_size, 2 * num_beams))  # (batch_size, 2 * num_beams)
                 next_scores = tf.reshape(next_scores, (batch_size, 2 * num_beams))  # (batch_size, 2 * num_beams)
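To see what the sampling branch does, here is a toy run of the two key calls with hypothetical sizes; tf.gather with batch_dims=1 pairs each row's sampled ids with that row's own log-probabilities:

import tensorflow as tf

# Hypothetical sizes for illustration.
batch_size, num_beams, vocab_size = 2, 3, 10
next_token_logits = tf.random.normal((batch_size * num_beams, vocab_size))

# Two candidates per beam, so the sampling path later reshapes to the same
# (batch_size, 2 * num_beams) layout that greedy beam search produces.
next_tokens = tf.random.categorical(next_token_logits, dtype=tf.int32, num_samples=2)

scores = tf.nn.log_softmax(next_token_logits, axis=-1)
# batch_dims=1: row i of next_tokens indexes into row i of scores.
_scores = tf.gather(scores, next_tokens, batch_dims=1)
print(next_tokens.shape, _scores.shape)  # (6, 2) (6, 2)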
@@ -804,10 +812,14 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
                 scores = tf.nn.log_softmax(next_token_logits, axis=-1)  # (batch_size * num_beams, vocab_size)
                 assert shape_list(scores) == [batch_size * num_beams, vocab_size]
                 # Add the log prob of the new beams to the log prob of the beginning of the sequence (sum of logs == log of the product)
-                next_scores = scores + tf.broadcast_to(beam_scores[:, None], (batch_size * num_beams, vocab_size))  # (batch_size * num_beams, vocab_size)
+                next_scores = scores + tf.broadcast_to(
+                    beam_scores[:, None], (batch_size * num_beams, vocab_size)
+                )  # (batch_size * num_beams, vocab_size)
                 # re-organize to group the beams together (we are keeping the top hypotheses across beams)
-                next_scores = tf.reshape(next_scores, (batch_size, num_beams * vocab_size))  # (batch_size, num_beams * vocab_size)
+                next_scores = tf.reshape(
+                    next_scores, (batch_size, num_beams * vocab_size)
+                )  # (batch_size, num_beams * vocab_size)
                 next_scores, next_tokens = tf.math.top_k(next_scores, 2 * num_beams, sorted=True)
                 assert shape_list(next_scores) == shape_list(next_tokens) == [batch_size, 2 * num_beams]
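The broadcast-and-reshape dance above is what lets a single top_k rank every candidate continuation of a batch element across all of its beams. A toy sketch with hypothetical sizes:

import tensorflow as tf

batch_size, num_beams, vocab_size = 2, 3, 10  # hypothetical sizes
scores = tf.random.normal((batch_size * num_beams, vocab_size))
beam_scores = tf.zeros((batch_size * num_beams,))

# Each candidate's score = its own log-prob + the running score of the
# beam it would extend (sum of logs == log of the product).
next_scores = scores + tf.broadcast_to(
    beam_scores[:, None], (batch_size * num_beams, vocab_size)
)

# Flattening the beam dimension lets one top_k compare all
# num_beams * vocab_size candidates of a batch element with each other.
next_scores = tf.reshape(next_scores, (batch_size, num_beams * vocab_size))
next_scores, next_tokens = tf.math.top_k(next_scores, 2 * num_beams, sorted=True)

# A flattened index encodes both the beam and the token:
beam_ids = next_tokens // vocab_size
token_ids = next_tokens % vocab_size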
@@ -909,7 +921,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
             best_hyp = sorted_hyps.pop()[1]
             sent_lengths_list.append(len(best_hyp))
             best.append(best_hyp)
-        assert output_batch_size == len(best), "Output batch size {} must match output beam hypotheses {}".format(output_batch_size, len(best))
+        assert output_batch_size == len(best), "Output batch size {} must match output beam hypotheses {}".format(
+            output_batch_size, len(best)
+        )
         sent_lengths = tf.convert_to_tensor(sent_lengths_list, dtype=tf.int32)
@@ -925,7 +939,11 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
                 decoded_hypo = tf.concat([hypo, padding], axis=0)
                 if sent_lengths[i] < max_length:
-                    decoded_hypo = tf.where(tf.range(max_length) == sent_lengths[i], eos_token_ids[0] * tf.ones((sent_max_len,), dtype=tf.int32), decoded_hypo)
+                    decoded_hypo = tf.where(
+                        tf.range(max_length) == sent_lengths[i],
+                        eos_token_ids[0] * tf.ones((sent_max_len,), dtype=tf.int32),
+                        decoded_hypo,
+                    )
                 decoded_list.append(decoded_hypo)
             decoded = tf.stack(decoded_list)
         else:
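The tf.where call writes the EOS token at position sent_lengths[i] without in-place assignment, which TF tensors don't support. A toy example with made-up values:

import tensorflow as tf

# Made-up values: a 3-token hypothesis padded to max_length 6, EOS id 99.
max_length, eos_token_id, sent_length = 6, 99, 3
hypo = tf.constant([11, 12, 13], dtype=tf.int32)
padding = tf.zeros((max_length - 3,), dtype=tf.int32)
decoded_hypo = tf.concat([hypo, padding], axis=0)  # [11 12 13  0  0  0]

# The boolean mask is True only at index sent_length, so tf.where takes
# the EOS token there and keeps the existing token everywhere else.
decoded_hypo = tf.where(
    tf.range(max_length) == sent_length,
    eos_token_id * tf.ones((max_length,), dtype=tf.int32),
    decoded_hypo,
)
print(decoded_hypo)  # [11 12 13 99  0  0]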
@@ -942,7 +960,6 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
             # get the correct batch idx from layer past batch dim
             # batch dim of `past` and `mems` is at 2nd position
             reordered_layer_past = [tf.identity(tf.expand_dims(layer_past[:, i], 1)) for i in beam_idx]
-            # TODO: check whether it is an error that TF past.shape != Torch past.shape
             reordered_layer_past = tf.concat(reordered_layer_past, axis=1)
             # check that shape matches
             assert shape_list(reordered_layer_past) == shape_list(layer_past)
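With the cache layout now consistent, the TODO above can go: after top_k picks which beams survive, each layer's cached keys/values are permuted along the batch dimension (position 1, since this commit moved the key/value pair to axis 0). A toy sketch with hypothetical shapes:

import tensorflow as tf

# Hypothetical cache: key/value pair on axis 0 (per this commit), a batch
# of 4 beams at position 1, then heads, sequence, head size.
layer_past = tf.random.normal((2, 4, 2, 5, 8))
beam_idx = [2, 0, 3, 1]  # which old beam each surviving hypothesis continues

# Select each chosen beam's slice along the batch dimension, restore the
# rank with expand_dims, and concatenate back into the original shape.
reordered = tf.concat(
    [tf.expand_dims(layer_past[:, i], 1) for i in beam_idx], axis=1
)
assert reordered.shape == layer_past.shape  # (2, 4, 2, 5, 8)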