Unverified Commit 4b63d013 authored by Pedro Cuenca's avatar Pedro Cuenca Committed by GitHub
Browse files

Make EosTokenCriteria compatible with mps (#30376)

parent 57fc00f3
...@@ -481,6 +481,17 @@ class EosTokenCriteria(StoppingCriteria): ...@@ -481,6 +481,17 @@ class EosTokenCriteria(StoppingCriteria):
@add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING) @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor: def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
if input_ids.device.type == "mps":
# https://github.com/pytorch/pytorch/issues/77764#issuecomment-2067838075
is_done = (
input_ids[:, -1]
.tile(self.eos_token_id.shape[0], 1)
.eq(self.eos_token_id.unsqueeze(1).to(input_ids.device))
.sum(dim=0)
.bool()
.squeeze()
)
else:
is_done = torch.isin(input_ids[:, -1], self.eos_token_id.to(input_ids.device)) is_done = torch.isin(input_ids[:, -1], self.eos_token_id.to(input_ids.device))
return is_done return is_done
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment