"src/webui/vscode:/vscode.git/clone" did not exist on "252f36f8092abeb86d13fb5bfda186b32004b875"
Commit 9b1e03d4 authored by laibao's avatar laibao
Browse files

V1 采样器:新增 reduced top-k/top-p 采样路径

新增环境变量 VLLM_V1_USE_REDUCED_TOPK_TOPP_SAMPLER 用于开关控制
扩展 SamplingMetadata,增加 max_top_k 与 has_any_no_top_k
在 InputBatch 侧计算 top-k 的主机端汇总信息,避免 device 同步
更新 Sampler/TopKTopPSampler 传递并使用新参数以启用优化采样
parent 7d5faa43
...@@ -196,6 +196,7 @@ if TYPE_CHECKING: ...@@ -196,6 +196,7 @@ if TYPE_CHECKING:
VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT: bool = False VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT: bool = False
VLLM_USE_FUSED_RMS_ROPE: bool = False VLLM_USE_FUSED_RMS_ROPE: bool = False
VLLM_USE_MARLIN_W16A16_MOE:bool = False VLLM_USE_MARLIN_W16A16_MOE:bool = False
VLLM_V1_USE_REDUCED_TOPK_TOPP_SAMPLER: bool = False
VLLM_USE_FUSED_FILL_RMS_CAT:bool = False VLLM_USE_FUSED_FILL_RMS_CAT:bool = False
def get_default_cache_root(): def get_default_cache_root():
...@@ -1280,6 +1281,11 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1280,6 +1281,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_FUSED_FILL_RMS_CAT": "VLLM_USE_FUSED_FILL_RMS_CAT":
lambda: (os.environ.get("VLLM_USE_FUSED_FILL_RMS_CAT", "False").lower() in lambda: (os.environ.get("VLLM_USE_FUSED_FILL_RMS_CAT", "False").lower() in
("true", "1")), ("true", "1")),
# If set to 1/True, enable the reduced top-k/top-p sampling path in the
# V1 PyTorch-native sampler.
"VLLM_V1_USE_REDUCED_TOPK_TOPP_SAMPLER":
lambda: (os.getenv("VLLM_V1_USE_REDUCED_TOPK_TOPP_SAMPLER",
"0").lower() in ("true", "1")),
} }
# --8<-- [end:env-vars-definition] # --8<-- [end:env-vars-definition]
......
...@@ -41,3 +41,10 @@ class SamplingMetadata: ...@@ -41,3 +41,10 @@ class SamplingMetadata:
# Loaded logits processors # Loaded logits processors
logitsprocs: LogitsProcessorManager logitsprocs: LogitsProcessorManager
# Optional host-side summaries to avoid device sync in fast paths.
# When `top_k` is provided, `max_top_k` is the maximum top-k value across
# the batch on the host (Python int).
max_top_k: Optional[int] = None
# True if any request in the batch has top_k == vocab_size (i.e. no top-k).
has_any_no_top_k: bool = False
...@@ -72,12 +72,34 @@ class TopKTopPSampler(nn.Module): ...@@ -72,12 +72,34 @@ class TopKTopPSampler(nn.Module):
generators: dict[int, torch.Generator], generators: dict[int, torch.Generator],
k: Optional[torch.Tensor], k: Optional[torch.Tensor],
p: Optional[torch.Tensor], p: Optional[torch.Tensor],
*,
max_top_k: Optional[int] = None,
has_any_no_top_k: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
PyTorch-native implementation of top-k and top-p sampling. PyTorch-native implementation of top-k and top-p sampling.
The logits tensor may be updated in-place. The logits tensor may be updated in-place.
""" """
# Fast path: when top-k is enabled, avoid full-vocab sort/softmax by
# sampling only from the top-k candidates (and applying top-p within
# that set). This is especially important on ROCm where the PyTorch
# native sort path can be very expensive.
#
# NOTE: Do not branch on device tensors here; doing so triggers
# `aten::is_nonzero` and synchronizes the CPU with GPU.
if (envs.VLLM_V1_USE_REDUCED_TOPK_TOPP_SAMPLER and k is not None
and max_top_k is not None and not has_any_no_top_k
and max_top_k <= 4096):
try:
return sample_top_k_top_p_reduced(logits,
generators,
k,
p,
max_top_k=max_top_k)
except Exception:
# Fall back to the reference implementation for safety.
pass
logits = apply_top_k_top_p(logits, k, p) logits = apply_top_k_top_p(logits, k, p)
probs = logits.softmax(dim=-1, dtype=torch.float32) probs = logits.softmax(dim=-1, dtype=torch.float32)
return random_sample(probs, generators) return random_sample(probs, generators)
...@@ -88,6 +110,9 @@ class TopKTopPSampler(nn.Module): ...@@ -88,6 +110,9 @@ class TopKTopPSampler(nn.Module):
generators: dict[int, torch.Generator], generators: dict[int, torch.Generator],
k: Optional[torch.Tensor], k: Optional[torch.Tensor],
p: Optional[torch.Tensor], p: Optional[torch.Tensor],
*,
max_top_k: Optional[int] = None,
has_any_no_top_k: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
"""More optimized implementation for top-k and top-p sampling.""" """More optimized implementation for top-k and top-p sampling."""
if k is None and p is None: if k is None and p is None:
...@@ -100,7 +125,12 @@ class TopKTopPSampler(nn.Module): ...@@ -100,7 +125,12 @@ class TopKTopPSampler(nn.Module):
logger.warning("FlashInfer 0.2.3+ does not support " logger.warning("FlashInfer 0.2.3+ does not support "
"per-request generators. Falling back to " "per-request generators. Falling back to "
"PyTorch-native implementation.") "PyTorch-native implementation.")
return self.forward_native(logits, generators, k, p) return self.forward_native(logits,
generators,
k,
p,
max_top_k=max_top_k,
has_any_no_top_k=has_any_no_top_k)
# flashinfer sampling functions expect contiguous logits. # flashinfer sampling functions expect contiguous logits.
# In flex_attn/triton_attn fp32 inference, logits can be non-contiguous # In flex_attn/triton_attn fp32 inference, logits can be non-contiguous
# because of slicing operation in logits_processor. # because of slicing operation in logits_processor.
...@@ -257,6 +287,56 @@ def random_sample( ...@@ -257,6 +287,56 @@ def random_sample(
return probs.div_(q).argmax(dim=-1).view(-1) return probs.div_(q).argmax(dim=-1).view(-1)
def sample_top_k_top_p_reduced(
    logits: torch.Tensor,
    generators: dict[int, torch.Generator],
    k: torch.Tensor,
    p: Optional[torch.Tensor],
    *,
    max_top_k: int,
) -> torch.Tensor:
    """Draw one token id per row from the ``max_top_k`` largest logits.

    Rather than masking and normalizing the whole vocabulary, the row-wise
    top ``max_top_k`` logits are extracted with ``topk``. The per-row ``k``
    and, when given, ``p`` filters are applied inside that candidate set,
    and the sampled candidate position is translated back to a vocabulary
    id.

    Args:
        logits: Logits whose last dimension is treated as the vocabulary.
            The final gather on dim 1 implies a 2-D (rows, vocab) layout.
        generators: Forwarded unchanged to ``random_sample``.
        k: Per-row top-k values; cast to int64 before use. Each row keeps
            only its first ``k`` candidates.
        p: Optional per-row top-p thresholds. ``None`` skips the top-p step.
        max_top_k: Host-side integer giving how many candidates to extract
            per row.

    Returns:
        A 1-D tensor holding one vocabulary id per row.
    """
    num_vocab = logits.shape[-1]

    # The reduced path only makes sense for 0 < max_top_k < vocab size.
    # Anything else is routed through the full-vocabulary implementation.
    if not 0 < max_top_k < num_vocab:
        filtered = apply_top_k_top_p(logits, k, p)
        return random_sample(
            filtered.softmax(dim=-1, dtype=torch.float32), generators)

    # torch.topk returns values sorted in descending order by default, so
    # column j of the candidate tensors holds each row's (j + 1)-th largest
    # logit.
    cand_logits, cand_ids = logits.topk(max_top_k, dim=-1)

    # Rows whose own k is smaller than max_top_k drop their trailing
    # candidates: every column index >= k is set to -inf.
    per_row_k = k.to(torch.long).unsqueeze(1)
    column_rank = torch.arange(max_top_k, device=logits.device).unsqueeze(0)
    cand_logits = cand_logits.masked_fill(column_rank >= per_row_k,
                                          -float("inf"))

    # Probabilities over the candidate set only (computed in float32).
    probs = cand_logits.softmax(dim=-1, dtype=torch.float32)

    if p is not None:
        # Candidates are in descending order, so a running sum gives the
        # probability mass that precedes each candidate. A candidate stays
        # while the mass before it is <= p, which keeps the token that
        # crosses the threshold.
        # NOTE(review): the boundary here is inclusive (<=); confirm it
        # agrees with apply_top_k_top_p when the mass equals p exactly.
        mass_before = probs.cumsum(dim=-1) - probs
        probs = probs * (mass_before <= p.unsqueeze(1))
        # The surviving probabilities are not renormalized before sampling.

    # random_sample yields a column position inside the candidate set;
    # look up the matching vocabulary id for every row.
    picked = random_sample(probs, generators)
    return torch.gather(cand_ids, 1, picked.unsqueeze(1)).squeeze(1)
def flashinfer_sample( def flashinfer_sample(
logits: torch.Tensor, logits: torch.Tensor,
k: Optional[torch.Tensor], k: Optional[torch.Tensor],
......
...@@ -123,6 +123,8 @@ class Sampler(nn.Module): ...@@ -123,6 +123,8 @@ class Sampler(nn.Module):
sampling_metadata.generators, sampling_metadata.generators,
sampling_metadata.top_k, sampling_metadata.top_k,
sampling_metadata.top_p, sampling_metadata.top_p,
max_top_k=sampling_metadata.max_top_k,
has_any_no_top_k=sampling_metadata.has_any_no_top_k,
) )
if greedy_sampled is None: if greedy_sampled is None:
......
...@@ -660,6 +660,15 @@ class InputBatch: ...@@ -660,6 +660,15 @@ class InputBatch:
self.allowed_token_ids_mask, num_reqs) self.allowed_token_ids_mask, num_reqs)
allowed_token_ids_mask = self.allowed_token_ids_mask[:num_reqs] allowed_token_ids_mask = self.allowed_token_ids_mask[:num_reqs]
# Host-side summaries to avoid device synchronization in sampling
# fast paths (e.g. reduced top-k/top-p sampling).
max_top_k = None
has_any_no_top_k = False
if not self.no_top_k and num_reqs > 0:
top_k_cpu = self.top_k_cpu[:num_reqs]
max_top_k = int(top_k_cpu.max())
has_any_no_top_k = bool((top_k_cpu == self.vocab_size).any())
return SamplingMetadata( return SamplingMetadata(
temperature=temperature, temperature=temperature,
all_greedy=self.all_greedy, all_greedy=self.all_greedy,
...@@ -677,6 +686,8 @@ class InputBatch: ...@@ -677,6 +686,8 @@ class InputBatch:
allowed_token_ids_mask=allowed_token_ids_mask, allowed_token_ids_mask=allowed_token_ids_mask,
bad_words_token_ids=self.bad_words_token_ids, bad_words_token_ids=self.bad_words_token_ids,
logitsprocs=self.logitsprocs, logitsprocs=self.logitsprocs,
max_top_k=max_top_k,
has_any_no_top_k=has_any_no_top_k,
) )
def _make_sampling_metadata_expanded( def _make_sampling_metadata_expanded(
...@@ -751,6 +762,13 @@ class InputBatch: ...@@ -751,6 +762,13 @@ class InputBatch:
expanded_generators[row_idx] = generator expanded_generators[row_idx] = generator
row_idx += 1 row_idx += 1
max_top_k = None
has_any_no_top_k = False
if not self.no_top_k and num_reqs > 0:
top_k_cpu = self.top_k_cpu[:num_reqs]
max_top_k = int(top_k_cpu.max())
has_any_no_top_k = bool((top_k_cpu == self.vocab_size).any())
return SamplingMetadata( return SamplingMetadata(
temperature=_expand_cpu_to_gpu( temperature=_expand_cpu_to_gpu(
None if all_greedy else self.temperature_cpu_tensor), None if all_greedy else self.temperature_cpu_tensor),
...@@ -776,6 +794,8 @@ class InputBatch: ...@@ -776,6 +794,8 @@ class InputBatch:
allowed_token_ids_mask, dtype=torch.bool), allowed_token_ids_mask, dtype=torch.bool),
bad_words_token_ids=expanded_bad_words_token_ids, bad_words_token_ids=expanded_bad_words_token_ids,
logitsprocs=self.logitsprocs, logitsprocs=self.logitsprocs,
max_top_k=max_top_k,
has_any_no_top_k=has_any_no_top_k,
) )
@property @property
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment