Unverified commit 06f1692b, authored by Maurice Gonzenbach and committed by GitHub
Browse files

Fix _shift_right function in TFT5PreTrainedModel (#6214)

parent 0b418673
...@@ -783,8 +783,7 @@ class TFT5PreTrainedModel(TFPreTrainedModel): ...@@ -783,8 +783,7 @@ class TFT5PreTrainedModel(TFPreTrainedModel):
decoder_start_token_id is not None decoder_start_token_id is not None
), "self.model.config.decoder_start_token_id has to be defined. In TF T5 it is usually set to the pad_token_id. See T5 docs for more information" ), "self.model.config.decoder_start_token_id has to be defined. In TF T5 it is usually set to the pad_token_id. See T5 docs for more information"
# shift inputs to the right shifted_input_ids = tf.cast(input_ids, tf.int32)
shifted_input_ids = tf.zeros_like(input_ids, dtype=tf.int32)
shifted_input_ids = tf.roll(shifted_input_ids, 1, axis=-1) shifted_input_ids = tf.roll(shifted_input_ids, 1, axis=-1)
start_tokens = tf.fill((shape_list(shifted_input_ids)[0], 1), decoder_start_token_id) start_tokens = tf.fill((shape_list(shifted_input_ids)[0], 1), decoder_start_token_id)
shifted_input_ids = tf.concat([start_tokens, shifted_input_ids[:, 1:]], -1) shifted_input_ids = tf.concat([start_tokens, shifted_input_ids[:, 1:]], -1)
...@@ -795,9 +794,12 @@ class TFT5PreTrainedModel(TFPreTrainedModel): ...@@ -795,9 +794,12 @@ class TFT5PreTrainedModel(TFPreTrainedModel):
shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids shifted_input_ids == -100, tf.fill(shape_list(shifted_input_ids), pad_token_id), shifted_input_ids
) )
assert tf.math.reduce_any( # "Verify that `labels` has only positive values and -100"
shifted_input_ids >= 0 assert_gte0 = tf.debugging.assert_greater_equal(shifted_input_ids, tf.cast(0, tf.int32))
).numpy(), "Verify that `labels` has only positive values and -100"
# Make sure the assertion op is called by wrapping the result in an identity no-op
with tf.control_dependencies([assert_gte0]):
shifted_input_ids = tf.identity(shifted_input_ids)
return shifted_input_ids return shifted_input_ids
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment