Unverified Commit 5a649ff3 authored by Sanchit Gandhi's avatar Sanchit Gandhi Committed by GitHub
Browse files

[generate] fix eos/pad id check on mps devices (#31695)


Co-authored-by: default avatarJoao Gante <joaofranciscocardosogante@gmail.com>
parent f2a1e3ca
...@@ -1542,10 +1542,7 @@ class GenerationMixin: ...@@ -1542,10 +1542,7 @@ class GenerationMixin:
logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{pad_token_tensor} for open-end generation.") logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{pad_token_tensor} for open-end generation.")
# we can't infer attn mask if pad token is set to be eos token in model's generation config # we can't infer attn mask if pad token is set to be eos token in model's generation config
if ( if eos_token_tensor is not None and pad_token_tensor in eos_token_tensor:
eos_token_tensor is not None
and torch.isin(elements=eos_token_tensor, test_elements=pad_token_tensor).any()
):
if kwargs_has_attention_mask is not None and not kwargs_has_attention_mask: if kwargs_has_attention_mask is not None and not kwargs_has_attention_mask:
logger.warning_once( logger.warning_once(
"The attention mask is not set and cannot be inferred from input because pad token is same as eos token." "The attention mask is not set and cannot be inferred from input because pad token is same as eos token."
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment