Unverified Commit c3287ebd authored by Will Frey's avatar Will Frey Committed by GitHub
Browse files

Update typing in generation_logits_process.py (#12900)

Change `torch.Tensor` -> `torch.FloatTensor` in `TemperatureLogitsWarper` to be consistent with the `LogitsWarper` ABC signature annotation.
parent df55c2b9
...@@ -137,7 +137,7 @@ class TemperatureLogitsWarper(LogitsWarper): ...@@ -137,7 +137,7 @@ class TemperatureLogitsWarper(LogitsWarper):
self.temperature = temperature self.temperature = temperature
def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor: def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.FloatTensor:
scores = scores / self.temperature scores = scores / self.temperature
return scores return scores
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment