Commit 586f0eba authored by 王敏's avatar 王敏
Browse files

[perf]合入lightop topp_topk 融合算子

parent 2036eb73
......@@ -12,6 +12,14 @@ from vllm.config.model import LogprobsMode
from vllm.logger import init_logger
from vllm.platforms import CpuArchEnum, current_platform
# lightop is treated as an optional dependency: the flag starts out True and
# is flipped to False if any of the lightop sampling kernels cannot be
# imported, so the sampler can decide at runtime whether to use them.
HAS_LIGHTOP_OPT_KERNEL = True
try:
from lightop.sampling import top_k_top_p_sampling_from_probs as top_k_top_p_sampling_from_probs_lightop
from lightop.sampling import top_k_sampling_from_probs as top_k_sampling_from_probs_lightop
from lightop.sampling import top_p_sampling_from_probs as top_p_sampling_from_probs_lightop
except ImportError:
# Any missing kernel disables the whole lightop path.
HAS_LIGHTOP_OPT_KERNEL = False
logger = init_logger(__name__)
......@@ -86,6 +94,8 @@ class TopKTopPSampler(nn.Module):
self.forward = self.forward_native
else:
self.forward = self.forward_native
if HAS_LIGHTOP_OPT_KERNEL:
self.forward = self.forward_lightop_opt
self.apply_top_k_top_p = apply_top_k_top_p
......@@ -169,6 +179,19 @@ class TopKTopPSampler(nn.Module):
# because of slicing operation in logits_processor.
return flashinfer_sample(logits.contiguous(), k, p, generators), None
def forward_lightop_opt(
self,
logits: torch.Tensor,
generators: dict[int, torch.Generator],
k: torch.Tensor | None,
p: torch.Tensor | None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
"""Top-k and top-p sampling optimized by lightop."""
if (k is None and p is None) or generators:
return self.forward_native(logits, generators, k, p)
return lightop_sample(logits.contiguous(), k, p, generators), None
def forward_cpu(
self,
logits: torch.Tensor,
......@@ -453,6 +476,46 @@ def flashinfer_sample(
return next_token_ids.view(-1)
def lightop_sample(
logits: torch.Tensor,
k: torch.Tensor | None,
p: torch.Tensor | None,
generators: dict[int, torch.Generator],
) -> torch.Tensor:
"""Sample from the logits using lightop.
Statistically, this function is equivalent to the `random_sample` function.
However, this function is faster because it avoids sorting the logits tensor
via rejection sampling.
NOTE: The outputs of this function do not necessarily match the outputs of
the `random_sample` function. It only guarantees that the outputs are
statistically equivalent.
NOTE: This function includes CPU-GPU synchronization, while `random_sample`
does not. Call this function at the end of the forward pass to minimize
the synchronization overhead.
"""
assert not (k is None and p is None)
probs = logits.softmax(dim=-1, dtype=torch.float32)
if k is None:
# Top-p only.
next_token_ids = top_p_sampling_from_probs_lightop(
probs, p, deterministic=True
)
elif p is None:
# Top-k only.
next_token_ids = top_k_sampling_from_probs_lightop(
probs, k, deterministic=True
)
else:
# Both top-k and top-p.
next_token_ids = top_k_top_p_sampling_from_probs_lightop(
probs, k, p, deterministic=True
)
return next_token_ids.view(-1)
def _to_tensor_scalar_tuple(x):
if isinstance(x, torch.Tensor):
......
......@@ -98,8 +98,6 @@ class OptRejectionSampler(nn.Module):
# won't affect the original logits tensor.
assert logits is not None
sampling_metadata.all_greedy = True
sampling_metadata.all_random = False
sampler_output = self.sampler(
logits=logits,
sampling_metadata=replace(
......
......@@ -807,7 +807,7 @@ class InputBatch:
batch_update = self.batch_update_builder.get_and_reset(self.num_reqs)
for logit_proc in self.logitsprocs.all:
logit_proc.update_state(batch_update)
if batch_update:
if batch_update or repeat_counts is not None:
self.sampling_metadata = self._make_sampling_metadata(repeat_counts)
def _make_sampling_metadata(self, repeat_counts: Optional[torch.Tensor] = None) -> SamplingMetadata:
......
......@@ -4916,6 +4916,7 @@ class GPUModelRunner(
draft_probs = torch.randn(
num_reqs, self.speculative_config.num_speculative_tokens, logits.shape[-1], device=self.device,
dtype=logits.dtype)
dummy_metadata.all_greedy = True
logits = torch.randn(
num_tokens + num_reqs,
......@@ -5537,10 +5538,6 @@ class GPUModelRunner(
ValueError: If no valid block size found
"""
#exclude indexer backend
def _participates_in_block_size_selection(backend: type[AttentionBackend]) -> bool:
    """Return True unless ``backend`` opts out of block-size selection.

    A backend opts out by exposing a truthy
    ``exclude_from_block_size_selection`` attribute; backends without the
    attribute participate.
    """
    opted_out = getattr(backend, "exclude_from_block_size_selection", False)
    return not opted_out
def block_size_is_supported(
backends: list[type[AttentionBackend]], block_size: int
) -> bool:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment