Unverified Commit 66875ac0 authored by Yury Sulsky's avatar Yury Sulsky Committed by GitHub
Browse files

Specify dtype=torch.bool to avoid xla error (#31191)

The StoppingCriteriaList allocates is_done without specifying dtype=torch.bool. On XLA this allocates a float tensor and causes a failure on the following line:

is_done = is_done | criteria(input_ids, scores, **kwargs)

by attempting to OR float with bool.
parent 8685b3c5
...@@ -502,7 +502,7 @@ class EosTokenCriteria(StoppingCriteria): ...@@ -502,7 +502,7 @@ class EosTokenCriteria(StoppingCriteria):
class StoppingCriteriaList(list): class StoppingCriteriaList(list):
@add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING) @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor: def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
is_done = torch.full((input_ids.shape[0],), False, device=input_ids.device) is_done = torch.full((input_ids.shape[0],), False, device=input_ids.device, dtype=torch.bool)
for criteria in self: for criteria in self:
is_done = is_done | criteria(input_ids, scores, **kwargs) is_done = is_done | criteria(input_ids, scores, **kwargs)
return is_done return is_done
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment