Commit a17c410d authored by laibao's avatar laibao
Browse files

feat(sampler): 增加 reduced topk+topp 采样快速路径以降低全词表 softmax 开销

新增 VLLM_V1_USE_REDUCED_TOPK_TOPP_SAMPLER 开关并补充适用场景说明

在 V1 GPU 输入批预计算 max_top_k/has_any_no_top_k,native sampler 满足条件时走 reduced fast path,异常自动回退
parent 2544deb6
......@@ -301,6 +301,7 @@ if TYPE_CHECKING:
VLLM_REJECT_SAMPLE_OPT: bool = False
VLLM_USE_MOE_W16A16_TRITON: bool = False
VLLM_V1_FAST_TOKEN_ID_COPY: bool = False
VLLM_V1_USE_REDUCED_TOPK_TOPP_SAMPLER: bool = False
def get_default_cache_root():
......@@ -1883,6 +1884,19 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_V1_FAST_TOKEN_ID_COPY":
lambda: (os.environ.get("VLLM_V1_FAST_TOKEN_ID_COPY", "False").lower() in
("true", "1")),
# If set to 1/True, enable reduced top-k/top-p sampling fast path in the
# V1 PyTorch-native sampler path.
#
# Recommended when both top_k is enabled and top_p < 1.0 (nucleus
# sampling). Not recommended for top-k only (top_p == 1.0) due to
# potential behavior differences when the k-th logit is tied.
"VLLM_V1_USE_REDUCED_TOPK_TOPP_SAMPLER":
lambda: (
os.environ.get(
"VLLM_V1_USE_REDUCED_TOPK_TOPP_SAMPLER", "False"
).lower()
in ("true", "1")
),
}
# --8<-- [end:env-vars-definition]
......
......@@ -40,5 +40,9 @@ class SamplingMetadata:
# Loaded logits processors
logitsprocs: LogitsProcessors
# Optional host-side summaries for top-k fast paths.
max_top_k: int | None = None
has_any_no_top_k: bool = False
# Speculative token ids
spec_token_ids: list[list[int]] | None = None
......@@ -95,12 +95,44 @@ class TopKTopPSampler(nn.Module):
generators: dict[int, torch.Generator],
k: torch.Tensor | None,
p: torch.Tensor | None,
*,
max_top_k: int | None = None,
has_any_no_top_k: bool = False,
) -> tuple[torch.Tensor, torch.Tensor | None]:
"""
PyTorch-native implementation of top-k and top-p sampling.
The logits tensor may be updated in-place.
"""
# Fast path: when top-k is enabled, avoid full-vocab sort/softmax by
# sampling from only the reduced candidate set.
if (
self.logprobs_mode not in ("processed_logits", "processed_logprobs")
and envs.VLLM_V1_USE_REDUCED_TOPK_TOPP_SAMPLER
and k is not None
and p is not None
and max_top_k is not None
and not has_any_no_top_k
and max_top_k <= 4096
):
try:
return (
sample_top_k_top_p_reduced(
logits,
generators,
k,
p,
max_top_k=max_top_k,
),
None,
)
except Exception:
# Fall back to the reference implementation for safety.
logger.debug_once(
"Reduced top-k/top-p sampler failed; falling back to the "
"reference implementation."
)
logits = self.apply_top_k_top_p(logits, k, p)
logits_to_return = None
if self.logprobs_mode == "processed_logits":
......@@ -332,6 +364,47 @@ def random_sample(
return probs.div_(q).argmax(dim=-1).view(-1)
def sample_top_k_top_p_reduced(
logits: torch.Tensor,
generators: dict[int, torch.Generator],
k: torch.Tensor,
p: torch.Tensor | None,
*,
max_top_k: int,
) -> torch.Tensor:
"""Sample logits from only the top-k candidate set."""
vocab_size = logits.shape[-1]
# Guard for extreme values that can defeat the purpose of this fast path.
if max_top_k <= 0 or max_top_k >= vocab_size:
masked_logits = apply_top_k_top_p(logits, k, p)
probs = masked_logits.softmax(dim=-1, dtype=torch.float32)
return random_sample(probs, generators)
topk = logits.topk(max_top_k, dim=-1)
topk_logits = topk.values
topk_indices = topk.indices
# Apply per-row top-k on the reduced candidate set.
k = k.to(torch.long)
arange_k = torch.arange(max_top_k, device=logits.device).unsqueeze(0)
keep_k = arange_k < k.unsqueeze(1)
topk_logits = topk_logits.masked_fill(~keep_k, -float("inf"))
# Convert to probabilities over the reduced candidate set.
probs = topk_logits.softmax(dim=-1, dtype=torch.float32)
if p is not None:
# Apply top-p in descending-logit order within the reduced set.
cumprob = torch.cumsum(probs, dim=-1)
cumprob_prev = cumprob - probs
keep_p = cumprob_prev < p.unsqueeze(1)
probs = probs * keep_p
# Sample position in reduced set, then map back to vocab ids.
pos = random_sample(probs, generators)
return topk_indices.gather(1, pos.unsqueeze(1)).squeeze(1)
def flashinfer_sample(
logits: torch.Tensor,
k: torch.Tensor | None,
......
......@@ -5,6 +5,7 @@
import torch
import torch.nn as nn
from vllm import envs
from vllm.config.model import LogprobsMode
from vllm.utils.platform_utils import is_pin_memory_available
from vllm.v1.outputs import LogprobsTensors, SamplerOutput
......@@ -184,6 +185,23 @@ class Sampler(nn.Module):
logits = processor.apply(logits)
# Apply top_k and/or top_p.
if (
envs.VLLM_V1_USE_REDUCED_TOPK_TOPP_SAMPLER
and sampling_metadata.top_k is not None
and sampling_metadata.top_p is not None
and sampling_metadata.max_top_k is not None
and not sampling_metadata.has_any_no_top_k
and self.topk_topp_sampler.forward.__name__ == "forward_native"
):
random_sampled, processed_logprobs = self.topk_topp_sampler(
logits,
sampling_metadata.generators,
sampling_metadata.top_k,
sampling_metadata.top_p,
max_top_k=sampling_metadata.max_top_k,
has_any_no_top_k=sampling_metadata.has_any_no_top_k,
)
else:
random_sampled, processed_logprobs = self.topk_topp_sampler(
logits,
sampling_metadata.generators,
......
......@@ -812,6 +812,16 @@ class InputBatch:
def _make_sampling_metadata(self, repeat_counts: Optional[torch.Tensor] = None) -> SamplingMetadata:
num_reqs = self.num_reqs
# Host-side summaries for reduced top-k/top-p sampling.
# Compute before copy_slice(top_k), which may rewrite top_k_cpu_tensor
# when repeat_counts is provided.
max_top_k = None
has_any_no_top_k = False
if not self.no_top_k and num_reqs > 0:
top_k_cpu = self.top_k_cpu[:num_reqs]
max_top_k = int(top_k_cpu.max())
has_any_no_top_k = bool((top_k_cpu == self.vocab_size).any())
if not self.all_greedy:
temperature = copy_slice(
self.temperature_cpu_tensor, self.temperature,
......@@ -889,6 +899,8 @@ class InputBatch:
allowed_token_ids_mask=allowed_token_ids_mask,
bad_words_token_ids=self.bad_words_token_ids,
logitsprocs=self.logitsprocs,
max_top_k=max_top_k,
has_any_no_top_k=has_any_no_top_k,
)
def get_pooling_params(self) -> list[PoolingParams]:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment