Unverified Commit 5cce3076 authored by Joao Gante, committed by GitHub

TF: generate without `tf.TensorArray` (#17801)

parent ab223fc1
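
Context for the change: growing a `tf.TensorArray` inside the generation loop does not play well with XLA compilation, so generation is reworked to write into preallocated tensors instead. A minimal sketch of the before/after pattern (illustrative names and shapes, not the repo's actual code; assumes the `dynamic_update_slice` helper from TF's tf2xla module):

```python
import tensorflow as tf
from tensorflow.compiler.tf2xla.python.xla import dynamic_update_slice

batch_size, max_length = 2, 8

# Before: grow a tf.TensorArray one step at a time.
ta = tf.TensorArray(tf.int32, size=max_length)
for step in range(max_length):
    next_token = tf.fill((batch_size,), 1)  # stand-in for a sampled token id
    ta = ta.write(step, next_token)
generated = tf.transpose(ta.stack())  # (batch_size, max_length)

# After: preallocate the output buffer and overwrite one column per step,
# which lowers to a plain dynamic-update-slice under XLA.
generated = tf.zeros((batch_size, max_length), dtype=tf.int32)
for step in range(max_length):
    next_token = tf.fill((batch_size, 1), 1)
    generated = dynamic_update_slice(generated, next_token, tf.constant([0, step]))
```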
...
@@ -874,8 +874,9 @@ class TFGPT2LMHeadModel(TFGPT2PreTrainedModel, TFCausalLanguageModelingLoss):
         new_past = [None for _ in range(len(past))]
         slice_start_base = tf.constant([0, 0, 0, 1, 0])
         attention_mask_update_slice = tf.ones((batch_size, 1), dtype=attention_mask.dtype)
-        # correct 5 here
-        new_past_index = current_pos - 1
+        # -1 because current_pos has already been incremented before this function
+        # -1 again because last index = len - 1
+        new_past_index = current_pos - 2
         for i in range(len(past)):
             update_slice = past[i][:, :, :, -1:]
...
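
A toy check of the index arithmetic in this hunk, under the assumption stated by the new comments that `current_pos` has already been advanced past the freshly written token (the tensor shapes are illustrative, not GPT-2's real cache layout):

```python
import tensorflow as tf
from tensorflow.compiler.tf2xla.python.xla import dynamic_update_slice

# Toy cache entry shaped (2 = key/value, batch, heads, seq_len, head_dim).
past_layer = tf.zeros((2, 1, 1, 4, 1))
slice_start_base = tf.constant([0, 0, 0, 1, 0])  # offset applies to the seq axis

# Two tokens (indices 0 and 1) have been generated, and current_pos was then
# incremented to 3 before this code runs. The newest cache entry belongs at
# sequence index 1, i.e. current_pos - 2; the old current_pos - 1 would land
# one slot too far.
current_pos = tf.constant(3)
new_past_index = current_pos - 2

update_slice = tf.ones((2, 1, 1, 1, 1))  # stand-in for past_layer[:, :, :, -1:]
new_past_layer = dynamic_update_slice(
    past_layer, update_slice, slice_start_base * new_past_index
)
print(tf.squeeze(new_past_layer))  # [[0. 1. 0. 0.] [0. 1. 0. 0.]] -> written at index 1
```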
@@ -1202,7 +1202,6 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel, TFCausalLanguageModelingLoss):
     def prepare_inputs_for_generation(self, inputs, past=None, use_mems=None, **kwargs):
         # Add dummy token at the end (no attention on this one)
         effective_batch_size = inputs.shape[0]
         dummy_token = tf.zeros((effective_batch_size, 1), dtype=inputs.dtype)
...
@@ -1212,12 +1211,12 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel, TFCausalLanguageModelingLoss):
         offset = 2
         if past:
-            inputs = tf.concat([inputs[:, -offset:], dummy_token], axis=1)
+            input_ids = tf.concat([inputs[:, -offset:], dummy_token], axis=1)
         else:
-            inputs = tf.concat([inputs, dummy_token], axis=1)
+            input_ids = tf.concat([inputs, dummy_token], axis=1)

         # Build permutation mask so that previous tokens don't see last token
-        sequence_length = inputs.shape[1]
+        sequence_length = input_ids.shape[1]
         perm_mask = tf.zeros((effective_batch_size, sequence_length, sequence_length - 1))
         perm_mask_seq_end = tf.ones((effective_batch_size, sequence_length, 1))
         perm_mask = tf.concat([perm_mask, perm_mask_seq_end], axis=-1)
...
@@ -1228,7 +1227,7 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel, TFCausalLanguageModelingLoss):
         target_mapping = tf.concat([target_mapping, target_mapping_seq_end], axis=-1)

         inputs = {
-            "input_ids": inputs,
+            "input_ids": input_ids,
             "perm_mask": perm_mask,
             "target_mapping": target_mapping,
             "use_mems": use_mems,
...
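
Two things happen in the XLNet hunks: the concat result is renamed to `input_ids`, so the `inputs` argument is no longer shadowed before it is rebound to the returned dict, and `sequence_length` is read from the renamed tensor. For readers unfamiliar with XLNet's inputs, `perm_mask[b, i, j] = 1` means token `i` may not attend to token `j`. A toy sketch of the masks this function builds (sizes are illustrative, and the zeros part of `target_mapping` is an assumption based on the concat shown above):

```python
import tensorflow as tf

effective_batch_size, sequence_length = 1, 3  # e.g. two real tokens + the dummy

# No restrictions among the real positions...
perm_mask = tf.zeros((effective_batch_size, sequence_length, sequence_length - 1))
# ...but no position may attend to the trailing dummy token.
perm_mask_seq_end = tf.ones((effective_batch_size, sequence_length, 1))
perm_mask = tf.concat([perm_mask, perm_mask_seq_end], axis=-1)

# target_mapping marks the dummy position as the single prediction target.
target_mapping = tf.zeros((effective_batch_size, 1, sequence_length - 1))
target_mapping_seq_end = tf.ones((effective_batch_size, 1, 1))
target_mapping = tf.concat([target_mapping, target_mapping_seq_end], axis=-1)

print(perm_mask[0])       # last column all ones: the dummy token is invisible
print(target_mapping[0])  # [[0. 0. 1.]]: predict only at the dummy slot
```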