Unverified Commit cea17acd authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[Generate] Fix generate with inputs_embeds on GPU (#14564)

parent 25156eb2
...@@ -401,7 +401,7 @@ class GenerationMixin: ...@@ -401,7 +401,7 @@ class GenerationMixin:
# First if `inputs_embeds` are given, but no `attention_mask` assume that full attention_mask is used # First if `inputs_embeds` are given, but no `attention_mask` assume that full attention_mask is used
if inputs_embeds is not None: if inputs_embeds is not None:
return torch.ones((inputs_embeds.shape[0], inputs_embeds.shape[1]), dtype=torch.long) return torch.ones((inputs_embeds.shape[0], inputs_embeds.shape[1]), dtype=torch.long, device=self.device)
# Otherwise, use `input_ids` # Otherwise, use `input_ids`
is_pad_token_in_inputs_ids = (pad_token_id is not None) and (pad_token_id in input_ids) is_pad_token_in_inputs_ids = (pad_token_id is not None) and (pad_token_id in input_ids)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment