Unverified Commit 81c8191b authored by Younes Belkada, committed by GitHub

FIX [`Generation`] Fix some issues when running the MaxLength criteria on CPU (#29317)

Fix the bitwise-or issue: return the stopping-criteria result as an explicit `torch.bool` tensor so it can be combined with `|`.
parent e9476832
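For readers skimming the diff below: the change forces each criterion to return an explicit `torch.bool` tensor instead of relying on `torch.full`'s dtype inference, which is what the bitwise-or combination of several stopping criteria expects. A minimal sketch of the intended behaviour (the `batch_size` and `is_done` values are stand-ins, not taken from the diff):

```python
import torch

batch_size = 4   # stand-in for input_ids.shape[0]
is_done = True   # stand-in for e.g. input_ids.shape[-1] >= max_length

# With the explicit dtype, the criterion returns a proper BoolTensor ...
done = torch.full((batch_size,), is_done, dtype=torch.bool)

# ... so combining it with another criterion's boolean result via a
# bitwise or stays well defined on CPU as well as on GPU.
other_criterion_done = torch.zeros(batch_size, dtype=torch.bool)
print(other_criterion_done | done)  # tensor([True, True, True, True])
```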
@@ -73,7 +73,7 @@ class MaxLengthCriteria(StoppingCriteria):
                 f"maximum length ({self.max_position_embeddings}). Depending on the model, you may observe "
                 "exceptions, performance degradation, or nothing at all."
             )
-        return torch.full((input_ids.shape[0],), is_done, device=input_ids.device)
+        return torch.full((input_ids.shape[0],), is_done, device=input_ids.device, dtype=torch.bool)


 class MaxNewTokensCriteria(StoppingCriteria):
@@ -103,7 +103,7 @@ class MaxNewTokensCriteria(StoppingCriteria):
     @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
         is_done = input_ids.shape[-1] >= self.max_length
-        return torch.full((input_ids.shape[0],), is_done, device=input_ids.device)
+        return torch.full((input_ids.shape[0],), is_done, device=input_ids.device, dtype=torch.bool)


 class MaxTimeCriteria(StoppingCriteria):
@@ -126,7 +126,7 @@ class MaxTimeCriteria(StoppingCriteria):
     @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
         is_done = time.time() - self.initial_timestamp > self.max_time
-        return torch.full((input_ids.shape[0],), is_done, device=input_ids.device)
+        return torch.full((input_ids.shape[0],), is_done, device=input_ids.device, dtype=torch.bool)


 class StoppingCriteriaList(list):
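For context on where the bitwise or happens: `StoppingCriteriaList` (shown as trailing context in the last hunk) reduces the per-criterion results with `|`, roughly along the lines of the sketch below. This is an illustration under the assumption that each criterion's `__call__` returns a per-sequence `torch.bool` tensor; it is not the exact library code.

```python
import torch

def combine_criteria(criteria, input_ids, scores, **kwargs):
    # Start from "no sequence is done" and OR in each criterion's verdict.
    is_done = torch.full((input_ids.shape[0],), False, device=input_ids.device, dtype=torch.bool)
    for criterion in criteria:
        is_done = is_done | criterion(input_ids, scores, **kwargs)
    return is_done
```

If any criterion returned a non-bool tensor here, the `|` reduction could fail or silently yield a non-bool result, which is the symptom the commit title refers to.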