"docs/vscode:/vscode.git/clone" did not exist on "4eafc729285e459a5fc96efd6f7b313b155cad48"
Unverified Commit 1528e079 authored by jthomson04's avatar jthomson04 Committed by GitHub
Browse files

[Perf] Avoid pageable HtoD transfer in MinTokensLogitsProcessor (#29826)


Signed-off-by: default avatarjthomson04 <jwillthomson19@gmail.com>
parent afb1e5b3
...@@ -110,7 +110,7 @@ class MinPLogitsProcessor(LogitsProcessor): ...@@ -110,7 +110,7 @@ class MinPLogitsProcessor(LogitsProcessor):
# Identify valid tokens using threshold comparison # Identify valid tokens using threshold comparison
invalid_token_mask = probability_values < adjusted_min_p invalid_token_mask = probability_values < adjusted_min_p
# Apply mask using boolean indexing # Apply mask using boolean indexing
logits[invalid_token_mask] = -float("inf") logits.masked_fill_(invalid_token_mask, -float("inf"))
return logits return logits
...@@ -178,6 +178,10 @@ class MinTokensLogitsProcessor(LogitsProcessor): ...@@ -178,6 +178,10 @@ class MinTokensLogitsProcessor(LogitsProcessor):
self._device_tensor([], torch.int32), self._device_tensor([], torch.int32),
) )
self.neg_inf_tensor = torch.tensor(
-float("inf"), dtype=torch.float32, device=self.device
)
def is_argmax_invariant(self) -> bool: def is_argmax_invariant(self) -> bool:
"""By censoring stop tokens, min-tokens can change the outcome """By censoring stop tokens, min-tokens can change the outcome
of the argmax operation in greedy sampling.""" of the argmax operation in greedy sampling."""
...@@ -229,7 +233,7 @@ class MinTokensLogitsProcessor(LogitsProcessor): ...@@ -229,7 +233,7 @@ class MinTokensLogitsProcessor(LogitsProcessor):
def apply(self, logits: torch.Tensor) -> torch.Tensor: def apply(self, logits: torch.Tensor) -> torch.Tensor:
if self.min_toks: if self.min_toks:
# Inhibit EOS token for requests which have not reached min length # Inhibit EOS token for requests which have not reached min length
logits[self.logits_slice] = -float("inf") logits.index_put_(self.logits_slice, self.neg_inf_tensor)
return logits return logits
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment