Unverified commit e7523c2e, authored by Lukas Geiger, committed by GitHub
Browse files

[V1][Sampler] Improve performance of FlashInfer sampling by sampling logits...

[V1][Sampler] Improve performance of FlashInfer sampling by sampling logits instead of probs (#18608)
parent a869baca
...@@ -89,18 +89,18 @@ class TopKTopPSampler(nn.Module): ...@@ -89,18 +89,18 @@ class TopKTopPSampler(nn.Module):
p: Optional[torch.Tensor], p: Optional[torch.Tensor],
) -> torch.Tensor: ) -> torch.Tensor:
"""More optimized implementation for top-k and top-p sampling.""" """More optimized implementation for top-k and top-p sampling."""
probs = logits.softmax(dim=-1, dtype=torch.float32)
if k is None and p is None: if k is None and p is None:
# We prefer `random_sample` over `flashinfer_sample` when sorting is # We prefer `random_sample` over `flashinfer_sample` when sorting is
# not needed. This is because `random_sample` does not require # not needed. This is because `random_sample` does not require
# CPU-GPU synchronization while `flashinfer_sample` does. # CPU-GPU synchronization while `flashinfer_sample` does.
probs = logits.softmax(dim=-1, dtype=torch.float32)
return random_sample(probs, generators) return random_sample(probs, generators)
if generators: if generators:
logger.warning("FlashInfer 0.2.3+ does not support " logger.warning("FlashInfer 0.2.3+ does not support "
"per-request generators. Falling back to " "per-request generators. Falling back to "
"PyTorch-native implementation.") "PyTorch-native implementation.")
return self.forward_native(logits, generators, k, p) return self.forward_native(logits, generators, k, p)
return flashinfer_sample(probs, k, p, generators) return flashinfer_sample(logits, k, p, generators)
def forward_tpu( def forward_tpu(
self, self,
...@@ -254,12 +254,12 @@ def random_sample( ...@@ -254,12 +254,12 @@ def random_sample(
def flashinfer_sample( def flashinfer_sample(
probs: torch.Tensor, logits: torch.Tensor,
k: Optional[torch.Tensor], k: Optional[torch.Tensor],
p: Optional[torch.Tensor], p: Optional[torch.Tensor],
generators: dict[int, torch.Generator], generators: dict[int, torch.Generator],
) -> torch.Tensor: ) -> torch.Tensor:
"""Sample from the probabilities using FlashInfer. """Sample from the logits using FlashInfer.
Statistically, this function is equivalent to the `random_sample` function. Statistically, this function is equivalent to the `random_sample` function.
However, this function is faster because it avoids sorting the logits tensor However, this function is faster because it avoids sorting the logits tensor
...@@ -274,18 +274,19 @@ def flashinfer_sample( ...@@ -274,18 +274,19 @@ def flashinfer_sample(
the synchronization overhead. the synchronization overhead.
""" """
assert not (k is None and p is None) assert not (k is None and p is None)
if k is None: if k is None:
# Top-p only. # Top-p only.
probs = logits.softmax(dim=-1, dtype=torch.float32)
next_token_ids = flashinfer.sampling.top_p_sampling_from_probs( next_token_ids = flashinfer.sampling.top_p_sampling_from_probs(
probs, p, deterministic=True) probs, p, deterministic=True)
elif p is None: elif p is None:
# Top-k only. # Top-k only.
probs = logits.softmax(dim=-1, dtype=torch.float32)
next_token_ids = flashinfer.sampling.top_k_sampling_from_probs( next_token_ids = flashinfer.sampling.top_k_sampling_from_probs(
probs, k, deterministic=True) probs, k, deterministic=True)
else: else:
# Both top-k and top-p. # Both top-k and top-p.
next_token_ids = (flashinfer.sampling.top_k_top_p_sampling_from_probs( next_token_ids = flashinfer.sampling.top_k_top_p_sampling_from_logits(
probs, k, p, deterministic=True)) logits, k, p, deterministic=True)
return next_token_ids.view(-1) return next_token_ids.view(-1)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment