Unverified Commit 0d2bffad authored by Michal Szutenberg's avatar Michal Szutenberg Committed by GitHub
Browse files

Remove tf.roll wherever not needed (#12512)

It was used in shift_right.
After this change TF code is more similar to Pytorch implementations
Also, TF graphs are optimized (one node less)
parent 95425d54
...@@ -1571,9 +1571,8 @@ def tf_top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float("In ...@@ -1571,9 +1571,8 @@ def tf_top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float("In
) )
# Shift the indices to the right to keep also the first token above the threshold # Shift the indices to the right to keep also the first token above the threshold
sorted_indices_to_remove = tf.roll(sorted_indices_to_remove, 1, axis=-1)
sorted_indices_to_remove = tf.concat( sorted_indices_to_remove = tf.concat(
[tf.zeros_like(sorted_indices_to_remove[:, :1]), sorted_indices_to_remove[:, 1:]], [tf.zeros_like(sorted_indices_to_remove[:, :1]), sorted_indices_to_remove[:, :-1]],
-1, -1,
) )
# scatter sorted tensors to original indexing # scatter sorted tensors to original indexing
......
...@@ -61,9 +61,8 @@ LARGE_NEGATIVE = -1e8 ...@@ -61,9 +61,8 @@ LARGE_NEGATIVE = -1e8
def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_token_id: int): def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_token_id: int):
shifted_input_ids = tf.roll(input_ids, 1, axis=-1) start_tokens = tf.fill((shape_list(input_ids)[0], 1), decoder_start_token_id)
start_tokens = tf.fill((shape_list(shifted_input_ids)[0], 1), decoder_start_token_id) shifted_input_ids = tf.concat([start_tokens, input_ids[:, :-1]], -1)
shifted_input_ids = tf.concat([start_tokens, shifted_input_ids[:, 1:]], -1)
# replace possible -100 values in labels by `pad_token_id` # replace possible -100 values in labels by `pad_token_id`
shifted_input_ids = tf.where( shifted_input_ids = tf.where(
shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
......
...@@ -64,9 +64,8 @@ LARGE_NEGATIVE = -1e8 ...@@ -64,9 +64,8 @@ LARGE_NEGATIVE = -1e8
# Copied from transformers.models.bart.modeling_tf_bart.shift_tokens_right # Copied from transformers.models.bart.modeling_tf_bart.shift_tokens_right
def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_token_id: int): def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_token_id: int):
shifted_input_ids = tf.roll(input_ids, 1, axis=-1) start_tokens = tf.fill((shape_list(input_ids)[0], 1), decoder_start_token_id)
start_tokens = tf.fill((shape_list(shifted_input_ids)[0], 1), decoder_start_token_id) shifted_input_ids = tf.concat([start_tokens, input_ids[:, :-1]], -1)
shifted_input_ids = tf.concat([start_tokens, shifted_input_ids[:, 1:]], -1)
# replace possible -100 values in labels by `pad_token_id` # replace possible -100 values in labels by `pad_token_id`
shifted_input_ids = tf.where( shifted_input_ids = tf.where(
shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
......
...@@ -62,9 +62,8 @@ LARGE_NEGATIVE = -1e8 ...@@ -62,9 +62,8 @@ LARGE_NEGATIVE = -1e8
# Copied from transformers.models.bart.modeling_tf_bart.shift_tokens_right # Copied from transformers.models.bart.modeling_tf_bart.shift_tokens_right
def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_token_id: int): def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_token_id: int):
shifted_input_ids = tf.roll(input_ids, 1, axis=-1) start_tokens = tf.fill((shape_list(input_ids)[0], 1), decoder_start_token_id)
start_tokens = tf.fill((shape_list(shifted_input_ids)[0], 1), decoder_start_token_id) shifted_input_ids = tf.concat([start_tokens, input_ids[:, :-1]], -1)
shifted_input_ids = tf.concat([start_tokens, shifted_input_ids[:, 1:]], -1)
# replace possible -100 values in labels by `pad_token_id` # replace possible -100 values in labels by `pad_token_id`
shifted_input_ids = tf.where( shifted_input_ids = tf.where(
shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
......
...@@ -56,9 +56,8 @@ LARGE_NEGATIVE = -1e8 ...@@ -56,9 +56,8 @@ LARGE_NEGATIVE = -1e8
def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_token_id: int): def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_token_id: int):
shifted_input_ids = tf.roll(input_ids, 1, axis=-1) start_tokens = tf.fill((shape_list(input_ids)[0], 1), decoder_start_token_id)
start_tokens = tf.fill((shape_list(shifted_input_ids)[0], 1), decoder_start_token_id) shifted_input_ids = tf.concat([start_tokens, input_ids[:, :-1]], -1)
shifted_input_ids = tf.concat([start_tokens, shifted_input_ids[:, 1:]], -1)
# replace possible -100 values in labels by `pad_token_id` # replace possible -100 values in labels by `pad_token_id`
shifted_input_ids = tf.where( shifted_input_ids = tf.where(
shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
......
...@@ -63,9 +63,8 @@ LARGE_NEGATIVE = -1e8 ...@@ -63,9 +63,8 @@ LARGE_NEGATIVE = -1e8
# Copied from transformers.models.bart.modeling_tf_bart.shift_tokens_right # Copied from transformers.models.bart.modeling_tf_bart.shift_tokens_right
def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_token_id: int): def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_token_id: int):
shifted_input_ids = tf.roll(input_ids, 1, axis=-1) start_tokens = tf.fill((shape_list(input_ids)[0], 1), decoder_start_token_id)
start_tokens = tf.fill((shape_list(shifted_input_ids)[0], 1), decoder_start_token_id) shifted_input_ids = tf.concat([start_tokens, input_ids[:, :-1]], -1)
shifted_input_ids = tf.concat([start_tokens, shifted_input_ids[:, 1:]], -1)
# replace possible -100 values in labels by `pad_token_id` # replace possible -100 values in labels by `pad_token_id`
shifted_input_ids = tf.where( shifted_input_ids = tf.where(
shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
......
...@@ -63,9 +63,8 @@ LARGE_NEGATIVE = -1e8 ...@@ -63,9 +63,8 @@ LARGE_NEGATIVE = -1e8
# Copied from transformers.models.bart.modeling_tf_bart.shift_tokens_right # Copied from transformers.models.bart.modeling_tf_bart.shift_tokens_right
def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_token_id: int): def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_token_id: int):
shifted_input_ids = tf.roll(input_ids, 1, axis=-1) start_tokens = tf.fill((shape_list(input_ids)[0], 1), decoder_start_token_id)
start_tokens = tf.fill((shape_list(shifted_input_ids)[0], 1), decoder_start_token_id) shifted_input_ids = tf.concat([start_tokens, input_ids[:, :-1]], -1)
shifted_input_ids = tf.concat([start_tokens, shifted_input_ids[:, 1:]], -1)
# replace possible -100 values in labels by `pad_token_id` # replace possible -100 values in labels by `pad_token_id`
shifted_input_ids = tf.where( shifted_input_ids = tf.where(
shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
......
...@@ -1365,9 +1365,8 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss ...@@ -1365,9 +1365,8 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss
assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined." assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined."
shifted_input_ids = tf.cast(input_ids, tf.int32) shifted_input_ids = tf.cast(input_ids, tf.int32)
shifted_input_ids = tf.roll(shifted_input_ids, 1, axis=-1)
start_tokens = tf.fill((shape_list(shifted_input_ids)[0], 1), start_token_id) start_tokens = tf.fill((shape_list(shifted_input_ids)[0], 1), start_token_id)
shifted_input_ids = tf.concat([start_tokens, shifted_input_ids[:, 1:]], -1) shifted_input_ids = tf.concat([start_tokens, shifted_input_ids[:, :-1]], -1)
# replace possible -100 values in labels by `pad_token_id` # replace possible -100 values in labels by `pad_token_id`
shifted_input_ids = tf.where( shifted_input_ids = tf.where(
......
...@@ -873,9 +873,8 @@ class TFT5PreTrainedModel(TFPreTrainedModel): ...@@ -873,9 +873,8 @@ class TFT5PreTrainedModel(TFPreTrainedModel):
decoder_start_token_id is not None decoder_start_token_id is not None
), "self.model.config.decoder_start_token_id has to be defined. In TF T5 it is usually set to the pad_token_id. See T5 docs for more information" ), "self.model.config.decoder_start_token_id has to be defined. In TF T5 it is usually set to the pad_token_id. See T5 docs for more information"
shifted_input_ids = tf.roll(input_ids, 1, axis=-1) start_tokens = tf.fill((shape_list(input_ids)[0], 1), decoder_start_token_id)
start_tokens = tf.fill((shape_list(shifted_input_ids)[0], 1), decoder_start_token_id) shifted_input_ids = tf.concat([start_tokens, input_ids[:, :-1]], -1)
shifted_input_ids = tf.concat([start_tokens, shifted_input_ids[:, 1:]], -1)
assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined." assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined."
# replace possible -100 values in labels by `pad_token_id` # replace possible -100 values in labels by `pad_token_id`
......
...@@ -1514,9 +1514,8 @@ LARGE_NEGATIVE = -1e8 ...@@ -1514,9 +1514,8 @@ LARGE_NEGATIVE = -1e8
def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_token_id: int): def shift_tokens_right(input_ids: tf.Tensor, pad_token_id: int, decoder_start_token_id: int):
shifted_input_ids = tf.roll(input_ids, 1, axis=-1) start_tokens = tf.fill((shape_list(input_ids)[0], 1), decoder_start_token_id)
start_tokens = tf.fill((shape_list(shifted_input_ids)[0], 1), decoder_start_token_id) shifted_input_ids = tf.concat([start_tokens, input_ids[:, :-1]], -1)
shifted_input_ids = tf.concat([start_tokens, shifted_input_ids[:, 1:]], -1)
# replace possible -100 values in labels by `pad_token_id` # replace possible -100 values in labels by `pad_token_id`
shifted_input_ids = tf.where( shifted_input_ids = tf.where(
shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment