Commit 77bec956 authored by zhuwenwen
Browse files

Merge branch 'v0.11.0-dev_tc_opt' into 'v0.11.0-dev'

V1 采样器:新增 reduced top-k/top-p 采样路径

See merge request dcutoolkit/deeplearing/vllm!350
parents 1c04646a 17f59521
......@@ -244,6 +244,7 @@ if TYPE_CHECKING:
VLLM_USE_FUSED_RMS_ROPE: bool = False
VLLM_USE_MARLIN_W16A16_MOE:bool = False
VLLM_V1_FAST_TOKEN_ID_COPY: bool = False
VLLM_V1_USE_REDUCED_TOPK_TOPP_SAMPLER: bool = False
def get_default_cache_root():
return os.getenv(
......@@ -1684,6 +1685,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_V1_FAST_TOKEN_ID_COPY":
lambda: (os.environ.get("VLLM_V1_FAST_TOKEN_ID_COPY", "False").lower() in
("true", "1")),
# If set to 1/True, enable the reduced top-k/top-p sampling path in the
# V1 PyTorch-native sampler.
"VLLM_V1_USE_REDUCED_TOPK_TOPP_SAMPLER":
lambda: (os.getenv("VLLM_V1_USE_REDUCED_TOPK_TOPP_SAMPLER",
"0").lower() in ("true", "1")),
}
# --8<-- [end:env-vars-definition]
......
......@@ -41,3 +41,10 @@ class SamplingMetadata:
# Loaded logits processors
logitsprocs: LogitsProcessors
# Optional host-side summaries to avoid device sync in fast paths.
# When `top_k` is provided, `max_top_k` is the maximum top-k value across
# the batch on the host (Python int).
max_top_k: Optional[int] = None
# True if any request in the batch has top_k == vocab_size (i.e. no top-k).
has_any_no_top_k: bool = False
......@@ -81,12 +81,35 @@ class TopKTopPSampler(nn.Module):
generators: dict[int, torch.Generator],
k: Optional[torch.Tensor],
p: Optional[torch.Tensor],
*,
max_top_k: Optional[int] = None,
has_any_no_top_k: bool = False,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
"""
PyTorch-native implementation of top-k and top-p sampling.
The logits tensor may be updated in-place.
"""
# Fast path: when top-k is enabled, avoid full-vocab sort/softmax by
# sampling only from the top-k candidates (and applying top-p within
# that set). This is especially important on ROCm where the PyTorch
# native sort path can be very expensive.
#
# NOTE: Do not branch on device tensors here; doing so triggers
# `aten::is_nonzero` and synchronizes the CPU with GPU.
if (self.logprobs_mode not in ("processed_logits", "processed_logprobs")
and envs.VLLM_V1_USE_REDUCED_TOPK_TOPP_SAMPLER and k is not None
and max_top_k is not None and not has_any_no_top_k
and max_top_k <= 4096):
try:
return (sample_top_k_top_p_reduced(logits,
generators,
k,
p,
max_top_k=max_top_k), None)
except Exception:
# Fall back to the reference implementation for safety.
pass
logits = self.apply_top_k_top_p(logits, k, p)
logits_to_return = None
if self.logprobs_mode == "processed_logits":
......@@ -102,6 +125,9 @@ class TopKTopPSampler(nn.Module):
generators: dict[int, torch.Generator],
k: Optional[torch.Tensor],
p: Optional[torch.Tensor],
*,
max_top_k: Optional[int] = None,
has_any_no_top_k: bool = False,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
"""More optimized implementation for top-k and top-p sampling."""
# We prefer `random_sample` over `flashinfer_sample` when sorting is
......@@ -112,7 +138,12 @@ class TopKTopPSampler(nn.Module):
logger.debug_once("FlashInfer 0.2.3+ does not support "
"per-request generators. Falling back to "
"PyTorch-native implementation.")
return self.forward_native(logits, generators, k, p)
return self.forward_native(logits,
generators,
k,
p,
max_top_k=max_top_k,
has_any_no_top_k=has_any_no_top_k)
assert self.logprobs_mode not in (
"processed_logits", "processed_logprobs"
), "FlashInfer does not support returning logits/logprobs"
......@@ -127,6 +158,9 @@ class TopKTopPSampler(nn.Module):
generators: dict[int, torch.Generator],
k: Optional[torch.Tensor],
p: Optional[torch.Tensor],
*,
max_top_k: Optional[int] = None,
has_any_no_top_k: bool = False,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
"""
PyTorch-native implementation of top-k and top-p sampling for CPU.
......@@ -253,6 +287,56 @@ def random_sample(
return probs.div_(q).argmax(dim=-1).view(-1)
def sample_top_k_top_p_reduced(
    logits: torch.Tensor,
    generators: dict[int, torch.Generator],
    k: torch.Tensor,
    p: Optional[torch.Tensor],
    *,
    max_top_k: int,
) -> torch.Tensor:
    """Draw one token id per row from the `max_top_k` largest logits.

    Only a `max_top_k`-wide slice of each row is normalised and sampled,
    so no sort or softmax over the whole vocabulary is performed on this
    path.

    Args:
        logits: Scores whose last dimension is the vocabulary; the row/column
            broadcasting below treats it as 2-D.
        generators: Forwarded unchanged to `random_sample`.
        k: Per-row top-k limits; exactly the first `k[i]` candidates of row
            `i` (in descending-logit order) stay eligible.
        p: Optional per-row top-p thresholds, applied to the probabilities
            of the surviving candidates.
        max_top_k: Host-side integer giving the candidate width to extract.

    Returns:
        A tensor of sampled vocabulary indices, one per row.
    """
    # When the requested width is non-positive or would cover the entire
    # vocabulary there is nothing to reduce: mask the full logits, normalise
    # them, and sample directly.
    if not 0 < max_top_k < logits.shape[-1]:
        full_probs = apply_top_k_top_p(logits, k, p).softmax(
            dim=-1, dtype=torch.float32)
        return random_sample(full_probs, generators)

    # `topk` returns candidates ordered from largest to smallest logit, so a
    # column index doubles as the candidate's rank within its row.
    cand_logits, cand_ids = logits.topk(max_top_k, dim=-1)

    # Rows whose own k is below max_top_k lose every candidate of rank >= k.
    ranks = torch.arange(max_top_k, device=logits.device)
    beyond_row_k = ranks.unsqueeze(0) >= k.to(torch.long).unsqueeze(1)
    cand_logits = cand_logits.masked_fill(beyond_row_k, float("-inf"))

    # Probabilities over the reduced candidate set (float32 accumulation).
    probs = cand_logits.softmax(dim=-1, dtype=torch.float32)

    if p is not None:
        # Probability mass that strictly precedes each candidate. A candidate
        # survives while that mass has not gone past p, which retains the
        # token that crosses the threshold.
        mass_before = probs.cumsum(dim=-1) - probs
        probs = probs * (mass_before <= p.unsqueeze(1))

    # `random_sample` picks a column within the reduced set; translate that
    # column back into a vocabulary id.
    picked_col = random_sample(probs, generators)
    return cand_ids.gather(1, picked_col.unsqueeze(1)).squeeze(1)
def flashinfer_sample(
logits: torch.Tensor,
k: Optional[torch.Tensor],
......
......@@ -182,6 +182,8 @@ class Sampler(nn.Module):
sampling_metadata.generators,
sampling_metadata.top_k,
sampling_metadata.top_p,
max_top_k=sampling_metadata.max_top_k,
has_any_no_top_k=sampling_metadata.has_any_no_top_k,
)
if greedy_sampled is None:
......
......@@ -802,6 +802,15 @@ class InputBatch:
self.allowed_token_ids_mask, num_reqs)
allowed_token_ids_mask = self.allowed_token_ids_mask[:num_reqs]
# Host-side summaries to avoid device synchronization in sampling
# fast paths (e.g. reduced top-k/top-p sampling).
max_top_k = None
has_any_no_top_k = False
if not self.no_top_k and num_reqs > 0:
top_k_cpu = self.top_k_cpu[:num_reqs]
max_top_k = int(top_k_cpu.max())
has_any_no_top_k = bool((top_k_cpu == self.vocab_size).any())
return SamplingMetadata(
temperature=temperature,
all_greedy=self.all_greedy,
......@@ -819,6 +828,8 @@ class InputBatch:
allowed_token_ids_mask=allowed_token_ids_mask,
bad_words_token_ids=self.bad_words_token_ids,
logitsprocs=self.logitsprocs,
max_top_k=max_top_k,
has_any_no_top_k=has_any_no_top_k,
)
def get_pooling_params(self) -> list[PoolingParams]:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment