"configs/datasets/vscode:/vscode.git/clone" did not exist on "66d3aa4c01212f59649e0496fcacfc55d0dbe08c"
Commit c4c4c999 authored by Patrick von Platen

make GPT2 and CTRL shape consistent between torch and TF

parent 2529b2d3
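
Both frameworks cache each layer's attention key/value as a single stacked tensor. PyTorch GPT2/CTRL stack along the first axis, so a layer's `past` has shape (2, batch, num_heads, seq_len, head_dim); before this commit the TF ports stacked along axis 1 instead. A minimal sketch of the layout the diff moves to, with toy sizes that are not from the diff:

    import tensorflow as tf

    batch, num_heads, seq_len, head_dim = 2, 4, 5, 8  # toy sizes
    k = tf.zeros((batch, num_heads, seq_len, head_dim))
    v = tf.zeros((batch, num_heads, seq_len, head_dim))

    present = tf.stack((k, v), axis=0)  # (2, batch, num_heads, seq_len, head_dim)
    past_key, past_value = tf.unstack(present, axis=0)  # each (batch, num_heads, seq_len, head_dim)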
@@ -104,10 +104,10 @@ class TFMultiHeadAttention(tf.keras.layers.Layer):
         k = self.split_into_heads(k, batch_size)
         v = self.split_into_heads(v, batch_size)
         if layer_past is not None:
-            past_key, past_value = tf.unstack(layer_past, axis=1)
+            past_key, past_value = tf.unstack(layer_past, axis=0)
             k = tf.concat((past_key, k), axis=-2)
             v = tf.concat((past_value, v), axis=-2)
-        present = tf.stack((k, v), axis=1)
+        present = tf.stack((k, v), axis=0)
         output = scaled_dot_product_attention(q, k, v, mask, attention_mask, head_mask)
         scaled_attention = tf.transpose(output[0], perm=[0, 2, 1, 3])
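
To see the parity this hunk establishes, the stacked TF cache can be compared against a PyTorch tensor stacked on dim 0 (torch's default). A hedged check, assuming both frameworks are installed and using toy shapes:

    import tensorflow as tf
    import torch

    batch, heads, seq, dim = 2, 3, 4, 8  # toy sizes
    tf_present = tf.stack((tf.zeros((batch, heads, seq, dim)),) * 2, axis=0)
    pt_present = torch.stack([torch.zeros(batch, heads, seq, dim)] * 2)  # dim=0 by default
    assert list(tf_present.shape) == list(pt_present.shape) == [2, batch, heads, seq, dim]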
@@ -139,10 +139,10 @@ class TFAttention(tf.keras.layers.Layer):
         key = self.split_heads(key)
         value = self.split_heads(value)
         if layer_past is not None:
-            past_key, past_value = tf.unstack(layer_past, axis=1)
+            past_key, past_value = tf.unstack(layer_past, axis=0)
             key = tf.concat([past_key, key], axis=-2)
             value = tf.concat([past_value, value], axis=-2)
-        present = tf.stack([key, value], axis=1)
+        present = tf.stack([key, value], axis=0)
         attn_outputs = self._attn([query, key, value, attention_mask, head_mask], training=training)
         a = attn_outputs[0]
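
Independent of the stacking axis, the cache grows along the sequence axis, which is axis -2 of each (batch, num_heads, seq_len, head_dim) tensor; that is what the unchanged `tf.concat(..., axis=-2)` lines above do. A toy illustration:

    import tensorflow as tf

    past_key = tf.zeros((1, 2, 5, 8))  # 5 cached positions (toy sizes)
    new_key = tf.zeros((1, 2, 1, 8))   # 1 freshly computed position
    key = tf.concat([past_key, new_key], axis=-2)
    print(key.shape)  # (1, 2, 6, 8): the cache grew by one step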
@@ -658,7 +658,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
             # repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858)
             if repetition_penalty != 1.0:
-                next_token_logits_penalties = _create_next_token_logits_penalties(input_ids, next_token_logits, repetition_penalty)
+                next_token_logits_penalties = _create_next_token_logits_penalties(
+                    input_ids, next_token_logits, repetition_penalty
+                )
                 next_token_logits = tf.math.multiply(next_token_logits, next_token_logits_penalties)
             if do_sample:
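
The hunk only reflows the call; behavior is unchanged. As a simplified sketch of a CTRL-style repetition penalty (not the library's exact helper): previously generated tokens get a multiplicative factor that lowers their logits, dividing when positive and multiplying when negative:

    import numpy as np

    def repetition_penalties(prev_tokens, logits, penalty):
        # 1.0 everywhere except tokens that already appeared in each sequence
        factors = np.ones_like(logits)
        for i, seq in enumerate(prev_tokens):
            for tok in set(seq):
                # dividing a positive logit or multiplying a negative one both lower it
                factors[i, tok] = 1.0 / penalty if logits[i, tok] > 0 else penalty
        return factors

    logits = np.array([[2.0, -1.0, 0.5]])
    print(logits * repetition_penalties([[0, 1]], logits, 1.2))  # [[~1.67, -1.2, 0.5]]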
@@ -779,7 +781,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
             # repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858)
             if repetition_penalty != 1.0:
-                next_token_logits_penalties = _create_next_token_logits_penalties(input_ids, next_token_logits, repetition_penalty)
+                next_token_logits_penalties = _create_next_token_logits_penalties(
+                    input_ids, next_token_logits, repetition_penalty
+                )
                 next_token_logits = tf.math.multiply(next_token_logits, next_token_logits_penalties)
             if do_sample:
@@ -791,11 +795,15 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
                     next_token_logits, top_k=top_k, top_p=top_p, min_tokens_to_keep=2
                 )  # (batch_size * num_beams, vocab_size)
                 # Sample 2 next tokens for each beam (so we have some spare tokens and match output of greedy beam search)
-                next_tokens = tf.random.categorical(next_token_logits, dtype=tf.int32, num_samples=2)  # (batch_size * num_beams, vocab_size)
+                next_tokens = tf.random.categorical(
+                    next_token_logits, dtype=tf.int32, num_samples=2
+                )  # (batch_size * num_beams, vocab_size)
                 # Compute next scores
                 scores = tf.nn.log_softmax(next_token_logits, axis=-1)  # (batch_size * num_beams, vocab_size)
                 _scores = tf.gather(scores, next_tokens, batch_dims=1)  # (batch_size * num_beams, 2)
-                next_scores = _scores + tf.broadcast_to(beam_scores[:, None], (batch_size * num_beams, 2))  # (batch_size * num_beams, 2)
+                next_scores = _scores + tf.broadcast_to(
+                    beam_scores[:, None], (batch_size * num_beams, 2)
+                )  # (batch_size * num_beams, 2)
                 # Match shape of greedy beam search
                 next_tokens = tf.reshape(next_tokens, (batch_size, 2 * num_beams))  # (batch_size, 2 * num_beams)
                 next_scores = tf.reshape(next_scores, (batch_size, 2 * num_beams))  # (batch_size, 2 * num_beams)
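
A toy run of the sampling path above, showing why `batch_dims=1` pairs each beam's two sampled ids with that same beam's row of log probabilities (tiny vocab, shapes only):

    import tensorflow as tf

    rows, vocab = 3, 5  # rows = batch_size * num_beams in the real code
    logits = tf.random.normal((rows, vocab))
    tokens = tf.random.categorical(logits, num_samples=2, dtype=tf.int32)  # (3, 2)
    log_probs = tf.nn.log_softmax(logits, axis=-1)  # (3, 5)
    picked = tf.gather(log_probs, tokens, batch_dims=1)  # (3, 2), row i indexed by tokens[i]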
@@ -804,10 +812,14 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
                 scores = tf.nn.log_softmax(next_token_logits, axis=-1)  # (batch_size * num_beams, vocab_size)
                 assert shape_list(scores) == [batch_size * num_beams, vocab_size]
                 # Add the log prob of the new beams to the log prob of the beginning of the sequence (sum of logs == log of the product)
-                next_scores = scores + tf.broadcast_to(beam_scores[:, None], (batch_size * num_beams, vocab_size))  # (batch_size * num_beams, vocab_size)
+                next_scores = scores + tf.broadcast_to(
+                    beam_scores[:, None], (batch_size * num_beams, vocab_size)
+                )  # (batch_size * num_beams, vocab_size)
                 # re-organize to group the beams together (we are keeping top hypotheses across beams)
-                next_scores = tf.reshape(next_scores, (batch_size, num_beams * vocab_size))  # (batch_size, num_beams * vocab_size)
+                next_scores = tf.reshape(
+                    next_scores, (batch_size, num_beams * vocab_size)
+                )  # (batch_size, num_beams * vocab_size)
                 next_scores, next_tokens = tf.math.top_k(next_scores, 2 * num_beams, sorted=True)
                 assert shape_list(next_scores) == shape_list(next_tokens) == [batch_size, 2 * num_beams]
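
The greedy branch keeps 2 * num_beams candidates per batch entry by flattening all beams into one score vector before top-k; a candidate id then encodes both beam and token. A toy illustration under the same convention:

    import tensorflow as tf

    batch_size, num_beams, vocab = 1, 2, 4
    beam_scores = tf.constant([0.0, -1.0])  # running log prob of each beam
    scores = tf.nn.log_softmax(tf.random.normal((batch_size * num_beams, vocab)), axis=-1)
    next_scores = scores + tf.broadcast_to(beam_scores[:, None], (batch_size * num_beams, vocab))
    flat = tf.reshape(next_scores, (batch_size, num_beams * vocab))  # (1, 8)
    top_scores, top_ids = tf.math.top_k(flat, 2 * num_beams)
    # recover beam and token: beam = top_ids // vocab, token = top_ids % vocab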
@@ -909,7 +921,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
                 best_hyp = sorted_hyps.pop()[1]
                 sent_lengths_list.append(len(best_hyp))
                 best.append(best_hyp)
-        assert output_batch_size == len(best), "Output batch size {} must match output beam hypotheses {}".format(output_batch_size, len(best))
+        assert output_batch_size == len(best), "Output batch size {} must match output beam hypotheses {}".format(
+            output_batch_size, len(best)
+        )
         sent_lengths = tf.convert_to_tensor(sent_lengths_list, dtype=tf.int32)
@@ -925,7 +939,11 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
                 decoded_hypo = tf.concat([hypo, padding], axis=0)
                 if sent_lengths[i] < max_length:
-                    decoded_hypo = tf.where(tf.range(max_length) == sent_lengths[i], eos_token_ids[0] * tf.ones((sent_max_len,), dtype=tf.int32), decoded_hypo)
+                    decoded_hypo = tf.where(
+                        tf.range(max_length) == sent_lengths[i],
+                        eos_token_ids[0] * tf.ones((sent_max_len,), dtype=tf.int32),
+                        decoded_hypo,
+                    )
                 decoded_list.append(decoded_hypo)
             decoded = tf.stack(decoded_list)
         else:
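
In isolation, the `tf.where` call writes the EOS id at position `sent_lengths[i]` of an already right-padded hypothesis and leaves every other position untouched. A toy version with made-up ids:

    import tensorflow as tf

    max_length = 6
    hypo = tf.constant([5, 8, 3, 0, 0, 0], dtype=tf.int32)  # padded to max_length
    eos_id, sent_len = 2, 3
    out = tf.where(
        tf.range(max_length) == sent_len,
        eos_id * tf.ones((max_length,), dtype=tf.int32),
        hypo,
    )
    print(out.numpy())  # [5 8 3 2 0 0]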
@@ -942,7 +960,6 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
             # get the correct batch idx from layer past batch dim
             # batch dim of `past` and `mems` is at 2nd position
             reordered_layer_past = [tf.identity(tf.expand_dims(layer_past[:, i], 1)) for i in beam_idx]
-            # TODO: check whether it is an error that TF past.shape != Torch past.shape
             reordered_layer_past = tf.concat(reordered_layer_past, axis=1)
             # check that shape matches
             assert shape_list(reordered_layer_past) == shape_list(layer_past)
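
With the stacking fix above, `layer_past` keeps its batch dimension at position 1 (shape (2, batch, num_heads, seq_len, head_dim)), which is why the cache is reordered by indexing `layer_past[:, i]` and concatenating on axis 1. A toy reorder:

    import tensorflow as tf

    layer_past = tf.random.normal((2, 3, 4, 5, 8))  # (2, batch=3, heads, seq, head_dim)
    beam_idx = [2, 0, 1]  # new beam order (toy)
    reordered = tf.concat([tf.expand_dims(layer_past[:, i], 1) for i in beam_idx], axis=1)
    assert reordered.shape == layer_past.shape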