Commit 094f1299 authored by yangql's avatar yangql
Browse files

修复auto模式乱码的问题

parent 7d4db7e8
......@@ -128,6 +128,7 @@ if TYPE_CHECKING:
VLLM_NIXL_SIDE_CHANNEL_HOST: str = "localhost"
VLLM_NIXL_SIDE_CHANNEL_PORT: int = 5557
VLLM_ALL2ALL_BACKEND: str = "naive"
VLLM_MOE_HT_THRESHOLD: int = 128
VLLM_ALLOW_MNNVL: bool = False
VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE: int = 163840
VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS: int = 1
......@@ -954,6 +955,9 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_ALL2ALL_BACKEND":
lambda: os.getenv("VLLM_ALL2ALL_BACKEND", "naive"),
# VLLM_MOE_HT_THRESHOLD
"VLLM_MOE_HT_THRESHOLD":
lambda: int(os.getenv("VLLM_MOE_HT_THRESHOLD", "128")),
# use ALLOW_MNNVL
"VLLM_ALLOW_MNNVL":
lambda: (os.environ.get("VLLM_ALLOW_MNNVL", "False").lower() in
......
......@@ -26,35 +26,34 @@ class DeepEPAutoPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
def _get_current_prepare_finalize(self) -> mk.FusedMoEPrepareAndFinalize:
"""Get the appropriate prepare_finalize based on current phase."""
# Try to infer phase from forward_context if available:
# - 有 decode tokens -> 使用 LL (decode)
# - 否则默认 HT (prefill)
try:
forward_context = get_forward_context()
attn_metadata = forward_context.attn_metadata
# Handle both v0 (single AttentionMetadata) and v1 (dict) formats
if isinstance(attn_metadata, dict):
if attn_metadata:
attn_metadata = next(iter(attn_metadata.values()))
else:
attn_metadata = None
# try:
# forward_context = get_forward_context()
# attn_metadata = forward_context.attn_metadata
# # Handle both v0 (single AttentionMetadata) and v1 (dict) formats
# if isinstance(attn_metadata, dict):
# if attn_metadata:
# attn_metadata = next(iter(attn_metadata.values()))
# else:
# attn_metadata = None
if attn_metadata is not None and hasattr(attn_metadata,
"num_decode_tokens"):
# 只根据 decode tokens 判定:有 decode -> decode,否则 prefill
self._current_phase = ("decode"
if attn_metadata.num_decode_tokens > 0
else "prefill")
except Exception:
# If forward_context is not available, use stored phase
pass
# if attn_metadata is not None and hasattr(attn_metadata,
# "num_decode_tokens"):
# # 只根据 decode tokens 判定:有 decode -> decode,否则 prefill
# self._current_phase = ("decode"
# if attn_metadata.num_decode_tokens > 0
# else "prefill")
# except Exception:
# # If forward_context is not available, use stored phase
# pass
# Prefill uses HT, decode uses LL
if self._current_phase == "prefill":
print("************prefill***********")
# return self.ht_prepare_finalize
# else:
# return self.ll_prepare_finalize
return self.ht_prepare_finalize
#rint("************prefill***********")
return self.ll_prepare_finalize
else:
# print("attn_metadata.num_decode_tokens",attn_metadata.num_decode_tokens)
return self.ht__prepare_finalize
#return self.ht_prepare_finalize
@property
def activation_format(self) -> mk.FusedMoEActivationFormat:
pf = self._get_current_prepare_finalize()
......
......@@ -183,6 +183,20 @@ class FusedMoEMethodBase(QuantizeMethodBase):
prepare_finalize = DeepEPAutoPrepareAndFinalize(
ht_prepare_finalize, ll_prepare_finalize)
experts_ht = self.select_gemm_impl(ht_prepare_finalize, moe)
experts_ll = self.select_gemm_impl(ll_prepare_finalize, moe)
self.topk_indices_dtype = ll_prepare_finalize.topk_indices_dtype()
self.fused_experts = DeepGemmDisabledFusedMoEModularKernel(
prepare_finalize,
experts_ll,
experts_ht=experts_ht,
experts_ll=experts_ll,
shared_experts=layer.shared_experts if hasattr(layer, "shared_experts") else None,
)
return
elif moe.use_deepep_ht_kernels:
assert moe.dp_size == all2all_manager.dp_world_size
......@@ -959,7 +973,10 @@ class FusedMoE(torch.nn.Module):
@property
def use_deepep_ll_kernels(self):
return self.moe_parallel_config.use_deepep_ll_kernels
@property
def use_deepep_auto_kernels(self):
return self.moe_parallel_config.use_deepep_auto_kernels
@property
def shared_experts(self) -> Optional[torch.nn.Module]:
return None
......@@ -1486,7 +1503,7 @@ class FusedMoE(torch.nn.Module):
early.
"""
return (self.use_pplx_kernels or self.use_deepep_ht_kernels
or self.use_deepep_ll_kernels)
or self.use_deepep_ll_kernels or self.use_deepep_auto_kernels)
def maybe_all_reduce_tensor_model_parallel(
self, final_hidden_states: torch.Tensor):
......@@ -1494,7 +1511,7 @@ class FusedMoE(torch.nn.Module):
The pplx combine kernel reduces across GPU ranks by default.
"""
if (self.use_pplx_kernels or self.use_deepep_ht_kernels
or self.use_deepep_ll_kernels):
or self.use_deepep_ll_kernels or self.use_deepep_auto_kernels):
return final_hidden_states
else:
return tensor_model_parallel_all_reduce(final_hidden_states)
......
......@@ -6,7 +6,9 @@ from math import prod
from typing import Optional, final
from dataclasses import dataclass
from collections.abc import Callable
from vllm.logger import init_logger
logger = init_logger(__name__)
import torch
import vllm.envs as envs
......@@ -828,11 +830,16 @@ class DeepGemmDisabledFusedMoEModularKernel(torch.nn.Module):
self,
prepare_finalize: FusedMoEPrepareAndFinalize,
fused_experts: CustomizedFusedMoEPermuteExpertsUnpermute,
experts_ht: CustomizedFusedMoEPermuteExpertsUnpermute = None,
experts_ll: CustomizedFusedMoEPermuteExpertsUnpermute = None,
shared_experts: Optional[torch.nn.Module] = None,
):
super().__init__()
self.prepare_finalize = prepare_finalize
self.fused_experts = fused_experts
self.fused_experts_ht = experts_ht
self.fused_experts_ll = experts_ll
self.shared_experts = shared_experts
# assert prepare_finalize.activation_format == \
......@@ -899,7 +906,29 @@ class DeepGemmDisabledFusedMoEModularKernel(torch.nn.Module):
Returns:
- torch.Tensor: The output tensor after applying the MoE layer.
"""
prepare_finalize = self.prepare_finalize
fused_experts = self.fused_experts
# from vllm.config import get_current_vllm_config
# vllm_cfg = get_current_vllm_config()
# max_tokens_for_cudagraph = vllm_cfg.compilation_config.max_capture_size
# num_ht_ll_tokens = max_tokens_for_cudagraph
if envs.VLLM_ALL2ALL_BACKEND == "deepep_auto":
num_ht_ll_tokens = envs.VLLM_MOE_HT_THRESHOLD
num_tokens = hidden_states.size(0)
logger.info("num_tokens=%d", num_tokens)
if num_tokens > num_ht_ll_tokens and False:
prepare_finalize = self.prepare_finalize.ht_prepare_finalize
fused_experts = self.fused_experts_ht
else:
prepare_finalize = self.prepare_finalize.ll_prepare_finalize
fused_experts = self.fused_experts_ll
a1 = hidden_states
if inplace and self.shared_experts is None:
......@@ -911,7 +940,7 @@ class DeepGemmDisabledFusedMoEModularKernel(torch.nn.Module):
if global_num_experts == -1:
global_num_experts = local_num_experts
prepare_ret = self.prepare_finalize.prepare_async(
prepare_ret = prepare_finalize.prepare_async(
a1,
a1_scale,
a2_scale,
......@@ -920,7 +949,7 @@ class DeepGemmDisabledFusedMoEModularKernel(torch.nn.Module):
global_num_experts,
expert_map,
apply_router_weight_on_input,
self.fused_experts.quant_config,
fused_experts.quant_config,
)
hook, receiver = (
prepare_ret if isinstance(prepare_ret, tuple) else (None, prepare_ret)
......@@ -951,7 +980,7 @@ class DeepGemmDisabledFusedMoEModularKernel(torch.nn.Module):
# and can never run into the tensor.numel() == 0 case.
fused_out = torch.empty_like(a1q).to(dtype=a1.dtype)
else:
fused_out = self.fused_experts.apply(
fused_out = fused_experts.apply(
None,
a1,
a1q,
......@@ -978,7 +1007,7 @@ class DeepGemmDisabledFusedMoEModularKernel(torch.nn.Module):
)
shared_output = None
hook = self.prepare_finalize.finalize_async(output, fused_out, topk_weights,
hook = prepare_finalize.finalize_async(output, fused_out, topk_weights,
topk_ids, apply_router_weight_on_input, apply_weights_and_reduce=True)
if self.shared_experts is not None:
......
......@@ -1237,8 +1237,10 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
# prefills, causing unnecessary and excessive padding of activations.
if dp_size == 1 or self.vllm_config.model_config.enforce_eager or envs.VLLM_ALL2ALL_BACKEND != 'naive':
# Early exit.
return 0, None
# auto
if not envs.VLLM_ALL2ALL_BACKEND == "deepep_auto":
# Early exit.
return 0, None
num_tokens_across_dp = DPMetadata.num_tokens_across_dp(
num_tokens, dp_size, dp_rank)
......@@ -1313,6 +1315,7 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
spec_decode_metadata,
num_scheduled_tokens_np) = (self._prepare_inputs(scheduler_output))
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
logger.info("***********self.cudagraph_batch_sizes_max",self.cudagraph_batch_sizes[-1])
if (self.use_cuda_graph
and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]):
# Use piecewise CUDA graphs.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment