"...git@developer.sourcefind.cn:2222/OpenDAS/vllm_cscc.git" did not exist on "0885aa25646d55d270b6a518d36861de2bec90d1"
Commit 5c7c09af authored by jthomson04's avatar jthomson04 Committed by Kevin H. Luu
Browse files

[Perf] Avoid pageable HtoD transfer in MinTokensLogitsProcessor (#29826)


Signed-off-by: default avatarjthomson04 <jwillthomson19@gmail.com>
(cherry picked from commit 1528e079)
parent 7f718169
......@@ -110,7 +110,7 @@ class MinPLogitsProcessor(LogitsProcessor):
# Identify valid tokens using threshold comparison
invalid_token_mask = probability_values < adjusted_min_p
# Apply mask using boolean indexing
logits[invalid_token_mask] = -float("inf")
logits.masked_fill_(invalid_token_mask, -float("inf"))
return logits
......@@ -178,6 +178,10 @@ class MinTokensLogitsProcessor(LogitsProcessor):
self._device_tensor([], torch.int32),
)
self.neg_inf_tensor = torch.tensor(
-float("inf"), dtype=torch.float32, device=self.device
)
def is_argmax_invariant(self) -> bool:
"""By censoring stop tokens, min-tokens can change the outcome
of the argmax operation in greedy sampling."""
......@@ -229,7 +233,7 @@ class MinTokensLogitsProcessor(LogitsProcessor):
def apply(self, logits: torch.Tensor) -> torch.Tensor:
if self.min_toks:
# Inhibit EOS token for requests which have not reached min length
logits[self.logits_slice] = -float("inf")
logits.index_put_(self.logits_slice, self.neg_inf_tensor)
return logits
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment