Unverified commit 5a649ff3 authored by Sanchit Gandhi, committed by GitHub

[generate] fix eos/pad id check on mps devices (#31695)


Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
parent f2a1e3ca
@@ -1542,10 +1542,7 @@ class GenerationMixin:
             logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{pad_token_tensor} for open-end generation.")
         # we can't infer attn mask if pad token is set to be eos token in model's generation config
-        if (
-            eos_token_tensor is not None
-            and torch.isin(elements=eos_token_tensor, test_elements=pad_token_tensor).any()
-        ):
+        if eos_token_tensor is not None and pad_token_tensor in eos_token_tensor:
             if kwargs_has_attention_mask is not None and not kwargs_has_attention_mask:
                 logger.warning_once(
                     "The attention mask is not set and cannot be inferred from input because pad token is same as eos token."