Unverified commit f1a6df32, authored by Ekagra Ranjan, committed by GitHub
Browse files

Generate: Simplify is_pad_token_not_equal_to_eos_token_id (#18933)

parent 85125fcf
...@@ -1739,9 +1739,7 @@ class TFGenerationMixin: ...@@ -1739,9 +1739,7 @@ class TFGenerationMixin:
) -> tf.Tensor: ) -> tf.Tensor:
is_input_ids = len(inputs.shape) == 2 and inputs.dtype in (tf.int32, tf.int64) is_input_ids = len(inputs.shape) == 2 and inputs.dtype in (tf.int32, tf.int64)
is_pad_token_in_inputs = (pad_token_id is not None) and tf.math.reduce_any(inputs == pad_token_id) is_pad_token_in_inputs = (pad_token_id is not None) and tf.math.reduce_any(inputs == pad_token_id)
is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or ( is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or (pad_token_id != eos_token_id)
(eos_token_id is not None) and (pad_token_id != eos_token_id)
)
# Check if input is input_ids and padded -> only then is attention_mask defined # Check if input is input_ids and padded -> only then is attention_mask defined
if is_input_ids and is_pad_token_in_inputs and is_pad_token_not_equal_to_eos_token_id: if is_input_ids and is_pad_token_in_inputs and is_pad_token_not_equal_to_eos_token_id:
......
...@@ -495,9 +495,8 @@ class GenerationMixin: ...@@ -495,9 +495,8 @@ class GenerationMixin:
) -> torch.LongTensor: ) -> torch.LongTensor:
is_input_ids = len(inputs.shape) == 2 and inputs.dtype in [torch.int, torch.long] is_input_ids = len(inputs.shape) == 2 and inputs.dtype in [torch.int, torch.long]
is_pad_token_in_inputs = (pad_token_id is not None) and (pad_token_id in inputs) is_pad_token_in_inputs = (pad_token_id is not None) and (pad_token_id in inputs)
is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or ( is_pad_token_not_equal_to_eos_token_id = (eos_token_id is None) or (pad_token_id != eos_token_id)
(eos_token_id is not None) and (pad_token_id != eos_token_id)
)
# Check if input is input_ids and padded -> only then is attention_mask defined # Check if input is input_ids and padded -> only then is attention_mask defined
if is_input_ids and is_pad_token_in_inputs and is_pad_token_not_equal_to_eos_token_id: if is_input_ids and is_pad_token_in_inputs and is_pad_token_not_equal_to_eos_token_id:
return inputs.ne(pad_token_id).long() return inputs.ne(pad_token_id).long()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment