Unverified Commit c3287ebd authored by Will Frey's avatar Will Frey Committed by GitHub
Browse files

Update typing in generation_logits_process.py (#12900)

Change `torch.Tensor` -> `torch.FloatTensor` in `TemperatureLogitsWarper` to be consistent with the `LogitsWarper` ABC signature annotation.
parent df55c2b9
......@@ -137,7 +137,7 @@ class TemperatureLogitsWarper(LogitsWarper):
self.temperature = temperature
def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor:
def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.FloatTensor:
scores = scores / self.temperature
return scores
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment