Commit 46b9d30f authored by zhuwenwen
Browse files

Merge branch 'v0.11.0-dev' of http://10.16.6.30/dcutoolkit/deeplearing/vllm into v0.11.0-dev

parents ef5ebdbf 77bec956
......@@ -230,6 +230,7 @@ if TYPE_CHECKING:
VLLM_USE_OPT_ZEROS: bool = False
VLLM_USE_OPT_CAT: bool = False
VLLM_USE_OPT_MOE_SUM: bool = False
VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD: bool = False
VLLM_USE_LIGHTOP_MOE_SUM: bool = False
VLLM_USE_LIGHTOP_MOE_ALIGN: bool = False
VLLM_USE_MERGE_ATTN_STATES_OPT: bool = False
......@@ -243,6 +244,7 @@ if TYPE_CHECKING:
VLLM_USE_FUSED_RMS_ROPE: bool = False
VLLM_USE_MARLIN_W16A16_MOE:bool = False
VLLM_V1_FAST_TOKEN_ID_COPY: bool = False
VLLM_V1_USE_REDUCED_TOPK_TOPP_SAMPLER: bool = False
def get_default_cache_root():
return os.getenv(
......@@ -1625,6 +1627,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_OPT_MOE_SUM":
lambda: (os.environ.get("VLLM_USE_OPT_MOE_SUM", "False").lower() in
("true", "1")),
# vLLM will use lightop moe_sum_mul_add
"VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD":
lambda: (os.environ.get("VLLM_USE_LIGHTOP_MOE_SUM_MUL_ADD",
"False").lower() in ("true", "1")),
# vLLM will use lightop moe_sum
"VLLM_USE_LIGHTOP_MOE_SUM":
lambda: (os.environ.get("VLLM_USE_LIGHTOP_MOE_SUM", "False").lower() in
......@@ -1679,6 +1685,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_V1_FAST_TOKEN_ID_COPY":
lambda: (os.environ.get("VLLM_V1_FAST_TOKEN_ID_COPY", "False").lower() in
("true", "1")),
# If set to 1/True, enable the reduced top-k/top-p sampling path in the
# V1 PyTorch-native sampler.
"VLLM_V1_USE_REDUCED_TOPK_TOPP_SAMPLER":
lambda: (os.getenv("VLLM_V1_USE_REDUCED_TOPK_TOPP_SAMPLER",
"0").lower() in ("true", "1")),
}
# --8<-- [end:env-vars-definition]
......
......@@ -240,7 +240,7 @@ def fused_experts_impl_w16a16_marlin(hidden_states: torch.Tensor,
assert hidden_states.dtype in [torch.bfloat16,torch.float16]
compute_type = hidden_states.dtype
assert use_lightop, (
"only BW and set LMSLIM_USE_LIGHTOP=1 support Marlin W16A16 MoE")
"only BW and set VLLM_USE_LIGHTOP=1 support Marlin W16A16 MoE")
num_tokens, K = hidden_states.shape
E, twoN, K_w1 = w1.shape
......
......@@ -59,6 +59,26 @@ logger = init_logger(__name__)
if envs.VLLM_USE_GLOBAL_CACHE13:
moe_cache_singleton = None
# Cache Marlin-packed weights so we only reorder once per weight tensor.
_marlin_weight_cache: Dict[Tuple[int, torch.device, torch.dtype, torch.Size], torch.Tensor] = {}
# Cache packed W16A16 Marlin weights by parameter identity so we can offload
# original layouts from GPU without losing the packed copies.
_w16a16_marlin_weight_cache: Dict[int, Tuple[torch.Tensor, torch.Tensor]] = {}
def _get_marlin_packed_weight(weight: torch.Tensor,
pack_fn: Callable[[torch.Tensor], torch.Tensor]
) -> torch.Tensor:
key = (weight.data_ptr(), weight.device, weight.dtype, weight.shape)
cached = _marlin_weight_cache.get(key)
if cached is not None:
return cached
# Marlin packing is done per expert and reshaped back to original dims.
packed = torch.stack([pack_fn(weight[i]).contiguous()
for i in range(weight.shape[0])],
dim=0)
_marlin_weight_cache[key] = packed
return packed
# Architecture name reported for the "cuda" device, keeping only the text
# before the first ':'.
# NOTE(review): `gcnArchName` looks ROCm-specific — confirm this module is
# only imported on builds where that attribute exists.
arch_name = torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0]
# Multiprocessor count reported for the current CUDA device.
arch_cu = torch.cuda.get_device_properties(torch.cuda.current_device()).multi_processor_count
......@@ -1966,7 +1986,51 @@ def fused_experts_impl(
cache13 = get_moe_cache(top_k_num, N,K if not use_nn_moe else w2.shape[2], device=hidden_states.device, dtype=hidden_states.dtype)
else:
cache13 = torch.empty(M * top_k_num * max(N, K if not use_nn_moe else w2.shape[2]), device=hidden_states.device, dtype=hidden_states.dtype)
from vllm.model_executor.layers.fused_moe.fuse_moe_w16a16_marlin import fused_experts_impl_w16a16_marlin
if (envs.VLLM_USE_MARLIN_W16A16_MOE
and fused_experts_impl_w16a16_marlin is not None):
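# NOTE: no shape validation is performed here. w1/w2 are expected to follow
# the [E, 2N, K] / [E, K, N/2] layout contract; once this branch is taken the
# Marlin path is always used and returns directly, so there is no fallback to
# the non-Marlin paths below.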
from vllm.model_executor.layers.fused_moe.marlin_quant import w16a16_marlin_weight
cache_key = id(w1)
cached_marlin = _w16a16_marlin_weight_cache.get(cache_key)
if cached_marlin is None:
w1_marlin = _get_marlin_packed_weight(w1, w16a16_marlin_weight)
w2_marlin = _get_marlin_packed_weight(w2, w16a16_marlin_weight)
# Offload original layout weights from GPU to avoid double residency.
with torch.no_grad():
w1_cpu = w1.detach().to("cpu")
w2_cpu = w2.detach().to("cpu")
if hasattr(w1, "data"):
w1.data = w1_cpu # type: ignore[attr-defined]
else:
w1 = w1_cpu
if hasattr(w2, "data"):
w2.data = w2_cpu # type: ignore[attr-defined]
else:
w2 = w2_cpu
_w16a16_marlin_weight_cache[cache_key] = (w1_marlin, w2_marlin)
else:
w1_marlin, w2_marlin = cached_marlin
return fused_experts_impl_w16a16_marlin(
hidden_states=hidden_states,
w1=w1,
w2=w2,
w1_marlin=w1_marlin,
w2_marlin=w2_marlin,
topk_weights=topk_weights,
topk_ids=topk_ids,
cache13=cache13,
inplace=inplace,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map,
use_nn_moe=False,
)
if use_int8_w8a8 is True:
return fused_experts_impl_int8(hidden_states=hidden_states,
w1=w1,
......
......@@ -41,3 +41,10 @@ class SamplingMetadata:
# Loaded logits processors
logitsprocs: LogitsProcessors
# Optional host-side summaries to avoid device sync in fast paths.
# When `top_k` is provided, `max_top_k` is the maximum top-k value across
# the batch on the host (Python int).
max_top_k: Optional[int] = None
# True if any request in the batch has top_k == vocab_size (i.e. no top-k).
has_any_no_top_k: bool = False
......@@ -81,12 +81,35 @@ class TopKTopPSampler(nn.Module):
generators: dict[int, torch.Generator],
k: Optional[torch.Tensor],
p: Optional[torch.Tensor],
*,
max_top_k: Optional[int] = None,
has_any_no_top_k: bool = False,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
"""
PyTorch-native implementation of top-k and top-p sampling.
The logits tensor may be updated in-place.
"""
# Fast path: when top-k is enabled, avoid full-vocab sort/softmax by
# sampling only from the top-k candidates (and applying top-p within
# that set). This is especially important on ROCm where the PyTorch
# native sort path can be very expensive.
#
# NOTE: Do not branch on device tensors here; doing so triggers
# `aten::is_nonzero` and synchronizes the CPU with GPU.
if (self.logprobs_mode not in ("processed_logits", "processed_logprobs")
and envs.VLLM_V1_USE_REDUCED_TOPK_TOPP_SAMPLER and k is not None
and max_top_k is not None and not has_any_no_top_k
and max_top_k <= 4096):
try:
return (sample_top_k_top_p_reduced(logits,
generators,
k,
p,
max_top_k=max_top_k), None)
except Exception:
# Fall back to the reference implementation for safety.
pass
logits = self.apply_top_k_top_p(logits, k, p)
logits_to_return = None
if self.logprobs_mode == "processed_logits":
......@@ -102,6 +125,9 @@ class TopKTopPSampler(nn.Module):
generators: dict[int, torch.Generator],
k: Optional[torch.Tensor],
p: Optional[torch.Tensor],
*,
max_top_k: Optional[int] = None,
has_any_no_top_k: bool = False,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
"""More optimized implementation for top-k and top-p sampling."""
# We prefer `random_sample` over `flashinfer_sample` when sorting is
......@@ -112,7 +138,12 @@ class TopKTopPSampler(nn.Module):
logger.debug_once("FlashInfer 0.2.3+ does not support "
"per-request generators. Falling back to "
"PyTorch-native implementation.")
return self.forward_native(logits, generators, k, p)
return self.forward_native(logits,
generators,
k,
p,
max_top_k=max_top_k,
has_any_no_top_k=has_any_no_top_k)
assert self.logprobs_mode not in (
"processed_logits", "processed_logprobs"
), "FlashInfer does not support returning logits/logprobs"
......@@ -127,6 +158,9 @@ class TopKTopPSampler(nn.Module):
generators: dict[int, torch.Generator],
k: Optional[torch.Tensor],
p: Optional[torch.Tensor],
*,
max_top_k: Optional[int] = None,
has_any_no_top_k: bool = False,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
"""
PyTorch-native implementation of top-k and top-p sampling for CPU.
......@@ -253,6 +287,56 @@ def random_sample(
return probs.div_(q).argmax(dim=-1).view(-1)
def sample_top_k_top_p_reduced(
    logits: torch.Tensor,
    generators: dict[int, torch.Generator],
    k: torch.Tensor,
    p: Optional[torch.Tensor],
    *,
    max_top_k: int,
) -> torch.Tensor:
    """Draw one token id per row, working only on the top-k candidates.

    Rather than sorting and soft-maxing the whole vocabulary, the
    ``max_top_k`` largest logits of every row are selected once, and the
    per-row top-k limit and the optional top-p limit are applied inside that
    reduced candidate set. The result is intended to match applying top-k,
    then top-p, then sampling.

    Args:
        logits: Logits indexed as (row, vocab).
        generators: Per-request generators, forwarded to ``random_sample``.
        k: Per-row top-k values, one entry per row.
        p: Optional per-row top-p values; ``None`` disables top-p.
        max_top_k: Host-side upper bound used as the candidate-set width.

    Returns:
        A 1-D tensor with the sampled vocabulary id for every row.
    """
    num_vocab = logits.shape[-1]
    if not 0 < max_top_k < num_vocab:
        # A non-positive bound, or one that covers the whole vocabulary,
        # gives no reduction: use the full-vocab masking + softmax path.
        full_probs = apply_top_k_top_p(logits, k, p).softmax(
            dim=-1, dtype=torch.float32)
        return random_sample(full_probs, generators)

    # torch.topk returns candidates in descending order by default, so the
    # column index doubles as the candidate's rank within its row.
    cand_logits, cand_ids = logits.topk(max_top_k, dim=-1)

    # Rows whose own k is smaller than max_top_k drop the candidates ranked
    # at or beyond their k.
    rank = torch.arange(max_top_k, device=logits.device).unsqueeze(0)
    beyond_row_k = rank >= k.to(torch.long).unsqueeze(1)
    cand_logits = cand_logits.masked_fill(beyond_row_k, -float("inf"))

    # Probabilities over the reduced candidate set only.
    cand_probs = cand_logits.softmax(dim=-1, dtype=torch.float32)

    if p is not None:
        # Top-p in descending order: a candidate survives while the mass that
        # precedes it does not exceed p, so the token that crosses the
        # boundary is still included.
        mass_before = torch.cumsum(cand_probs, dim=-1) - cand_probs
        within_top_p = mass_before <= p.unsqueeze(1)
        cand_probs = cand_probs * within_top_p

    # Sample a column in the candidate set, then translate it to a vocab id.
    chosen = random_sample(cand_probs, generators)
    return cand_ids.gather(1, chosen.unsqueeze(1)).squeeze(1)
def flashinfer_sample(
logits: torch.Tensor,
k: Optional[torch.Tensor],
......
......@@ -182,6 +182,8 @@ class Sampler(nn.Module):
sampling_metadata.generators,
sampling_metadata.top_k,
sampling_metadata.top_p,
max_top_k=sampling_metadata.max_top_k,
has_any_no_top_k=sampling_metadata.has_any_no_top_k,
)
if greedy_sampled is None:
......
......@@ -802,6 +802,15 @@ class InputBatch:
self.allowed_token_ids_mask, num_reqs)
allowed_token_ids_mask = self.allowed_token_ids_mask[:num_reqs]
# Host-side summaries to avoid device synchronization in sampling
# fast paths (e.g. reduced top-k/top-p sampling).
max_top_k = None
has_any_no_top_k = False
if not self.no_top_k and num_reqs > 0:
top_k_cpu = self.top_k_cpu[:num_reqs]
max_top_k = int(top_k_cpu.max())
has_any_no_top_k = bool((top_k_cpu == self.vocab_size).any())
return SamplingMetadata(
temperature=temperature,
all_greedy=self.all_greedy,
......@@ -819,6 +828,8 @@ class InputBatch:
allowed_token_ids_mask=allowed_token_ids_mask,
bad_words_token_ids=self.bad_words_token_ids,
logitsprocs=self.logitsprocs,
max_top_k=max_top_k,
has_any_no_top_k=has_any_no_top_k,
)
def get_pooling_params(self) -> list[PoolingParams]:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment