Commit b58514dd authored by 王敏's avatar 王敏
Browse files

[perf]1.优化pcp代码 2.优化ep低延迟模式调度,消除空泡

parent c462f3a0
......@@ -299,6 +299,13 @@ class ParallelConfig:
should only be set by API server scale-out.
"""
enable_lightly_cp: bool = False
"""Use lightly context parallel."""
enable_lightly_cplb: bool = False
"""Use lightly context parallel load balancing."""
@field_validator("disable_nccl_for_dp_synchronization", mode="wrap")
@classmethod
def _skip_none_validation(cls, value: Any, handler: Callable) -> Any:
......
......@@ -1061,6 +1061,12 @@ class VllmConfig:
# Handle the KV connector configs
self._post_init_kv_transfer_config()
if self.parallel_config.enable_lightly_cp and not self.model_config.enforce_eager:
raise ValueError(
"Lightly context parallel currently only supports the eager mode!!!"
)
def update_sizes_for_sequence_parallelism(self, possible_sizes: list) -> list:
# remove the sizes that not multiple of tp_size when
# enable sequence parallelism
......@@ -1186,8 +1192,7 @@ class VllmConfig:
if (
self.parallel_config.tensor_parallel_size > 1
and (self.compilation_config.pass_config.enable_sp)
#or envs.VLLM_MLA_CP)
and self.compilation_config.pass_config.enable_sp
):
cudagraph_capture_sizes = self.update_sizes_for_sequence_parallelism(
cudagraph_capture_sizes
......
......@@ -582,6 +582,9 @@ class EngineArgs:
kv_offloading_backend: KVOffloadingBackend = CacheConfig.kv_offloading_backend
tokens_only: bool = False
enable_lightly_cp: bool = ParallelConfig.enable_lightly_cp
enable_lightly_cplb: bool = ParallelConfig.enable_lightly_cplb
def __post_init__(self):
# support `EngineArgs(compilation_config={...})`
# without having to manually construct a
......@@ -899,6 +902,15 @@ class EngineArgs:
"--worker-extension-cls", **parallel_kwargs["worker_extension_cls"]
)
parallel_group.add_argument(
"--enable-lightly-cp",
**parallel_kwargs["enable_lightly_cp"],
)
parallel_group.add_argument(
"--enable-lightly-cplb",
**parallel_kwargs["enable_lightly_cplb"],
)
# KV cache arguments
cache_kwargs = get_kwargs(CacheConfig)
cache_group = parser.add_argument_group(
......@@ -1500,20 +1512,6 @@ class EngineArgs:
data_parallel_external_lb = (
self.data_parallel_external_lb or self.data_parallel_rank is not None
)
if (
envs.VLLM_MLA_CP
and self.max_num_batched_tokens is not None
and self.max_num_batched_tokens < self.tensor_parallel_size**3
):
raise ValueError(
"max_num_batched_tokens should be larger than "
"tensor_parallel_size ** 3 when enabled VLLM_MLA_CP"
)
logger.info("[MLACP] VLLM_MLA_CP is %s", envs.VLLM_MLA_CP)
logger.info("[MLACP] VLLM_MLA_CPLB is %s", envs.VLLM_MLA_CPLB)
# Local DP rank = 1, use pure-external LB.
if data_parallel_external_lb:
assert self.data_parallel_rank is not None, (
......@@ -1644,6 +1642,8 @@ class EngineArgs:
cp_kv_cache_interleave_size=self.cp_kv_cache_interleave_size,
_api_process_count=self._api_process_count,
_api_process_rank=self._api_process_rank,
enable_lightly_cp=self.enable_lightly_cp,
enable_lightly_cplb=self.enable_lightly_cplb,
)
speculative_config = self.create_speculative_config(
......
......@@ -324,8 +324,9 @@ if TYPE_CHECKING:
USE_LIGHTOP_TOPK: bool = False
USE_LIGHTOP_CONVERT_REQ_INDEX_TO_GLOBAL_INDEX: bool = False
VLLM_DISABLE_DSA: bool = False
VLLM_MLA_CP: bool = False
VLLM_MLA_CPLB: bool = False
VLLM_LIGHTLY_CP_THRESHOULD: int = 2048
def get_default_cache_root():
return os.getenv(
"XDG_CACHE_HOME",
......@@ -2012,13 +2013,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_DISABLE_DSA":
lambda: (os.environ.get("VLLM_DISABLE_DSA", "False").lower() in
("true", "1")),
# If set to 1/True, enable mla context parallel
"VLLM_MLA_CP":
lambda: (os.environ.get("VLLM_MLA_CP", "False").lower() in
("true", "1")),
"VLLM_MLA_CPLB":
lambda: (os.environ.get("VLLM_MLA_CPLB", "False").lower() in
("true", "1")),
# Token-count threshold above which lightly context parallel is enabled
"VLLM_LIGHTLY_CP_THRESHOULD":
lambda: int(os.getenv("VLLM_LIGHTLY_CP_THRESHOULD", "2048")),
}
# --8<-- [end:env-vars-definition]
......
......@@ -242,7 +242,8 @@ class ForwardContext:
scatter_indexes_tensor: torch.Tensor | None = None
gather_indexes_tensor: torch.Tensor | None = None
enable_mla_cp: bool = False
enable_lightly_cp: bool = False
enable_lightly_cplb : bool = False
def __post_init__(self):
assert self.cudagraph_runtime_mode.valid_runtime_modes(), (
......@@ -279,7 +280,8 @@ def create_forward_context(
skip_compiled: bool = False,
scatter_indexes_tensor: torch.Tensor | None = None,
gather_indexes_tensor: torch.Tensor | None = None,
enable_mla_cp: bool = False
enable_lightly_cp: bool = False,
enable_lightly_cplb: bool = False
):
if vllm_config.compilation_config.fast_moe_cold_start:
if vllm_config.speculative_config is None:
......@@ -307,7 +309,8 @@ def create_forward_context(
skip_compiled=skip_compiled,
scatter_indexes_tensor=scatter_indexes_tensor,
gather_indexes_tensor=gather_indexes_tensor,
enable_mla_cp=enable_mla_cp,
enable_lightly_cp=enable_lightly_cp,
enable_lightly_cplb=enable_lightly_cplb,
additional_kwargs=additional_kwargs or {},
)
......@@ -341,7 +344,8 @@ def set_forward_context(
skip_compiled: bool = False,
scatter_indexes_tensor: torch.Tensor | None = None,
gather_indexes_tensor: torch.Tensor | None = None,
enable_mla_cp: bool = False,
enable_lightly_cp: bool = False,
enable_lightly_cplb: bool = False,
):
"""A context manager that stores the current forward context,
can be attention metadata, etc.
......@@ -353,7 +357,8 @@ def set_forward_context(
forward_start_time = time.perf_counter()
dp_metadata: DPMetadata | None = None
if vllm_config.parallel_config.data_parallel_size > 1 and (
if vllm_config.parallel_config.data_parallel_size > 1 and \
envs.VLLM_ALL2ALL_BACKEND != "deepep_low_latency" and (
attn_metadata is not None or num_tokens is not None
):
# If num_tokens_across_dp hasn't already been initialized, then
......@@ -404,7 +409,8 @@ def set_forward_context(
skip_compiled,
scatter_indexes_tensor,
gather_indexes_tensor,
enable_mla_cp
enable_lightly_cp,
enable_lightly_cplb
)
try:
......
......@@ -205,20 +205,6 @@ def moe_grouped_gemm(
return output
def native_w8a8_perChannel_batch_matmul(q_a1_all, weight13, qa1_scale_all, w13_scale, output_dtype):
    """Reference batched matmul with scale application, computed in float32.

    Both operands are upcast to float32, multiplied per batch entry as
    ``q_a1_all @ weight13^T``, and the product is rescaled by the two
    scale tensors before being cast to ``output_dtype``.

    Args:
        q_a1_all: 3-D left operand, ``[E, M, K]``.
        weight13: 3-D right operand, ``[E, N, K]``; its last dimension
            must equal that of ``q_a1_all``.
        qa1_scale_all: Scale multiplied onto the product by broadcasting
            (presumably ``[E, M, 1]`` per-row scales -- TODO confirm).
        w13_scale: Scale whose dims 1 and 2 are swapped before being
            broadcast onto the product (presumably ``[E, N, 1]``
            per-output-channel scales -- TODO confirm).
        output_dtype: dtype of the returned tensor.

    Returns:
        Tensor of shape ``[E, M, N]`` in ``output_dtype``.

    Raises:
        AssertionError: If the last dimensions of the operands differ.
    """
    lhs = q_a1_all.to(torch.float32)
    rhs = weight13.to(torch.float32)
    assert lhs.shape[-1] == rhs.shape[-1], "Dimension mismatch"
    # [E, M, K] x [E, K, N] -> [E, M, N]
    acc = torch.bmm(lhs, rhs.transpose(1, 2))
    # Same left-to-right evaluation order as before: (row scale * acc) * column scale.
    scaled = qa1_scale_all * acc * w13_scale.transpose(1, 2)
    return scaled.to(output_dtype)
def scales_shape_stride_dtype(
E: int, T: int, G: int, quant_scale_fmt: DeepGemmQuantScaleFMT
) -> tuple[tuple[int, ...], tuple[int, ...], torch.dtype]:
......@@ -589,7 +575,6 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
workspace2: torch.Tensor,
expert_tokens_meta: mk.ExpertTokensMetadata | None,
apply_router_weight_on_input: bool,
use_nn_moe: bool | None = False,
**_
):
assert expert_tokens_meta is not None
......@@ -612,11 +597,13 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
workspace1 = _resize_cache(workspace13, (E, max_num_tokens, N))
expected_m = self.estimate_expected_m(
global_num_experts=global_num_experts,
max_tokens_per_expert=max_num_tokens,
topk=topk_ids.size(-1),
)
# expected_m = self.estimate_expected_m(
# global_num_experts=global_num_experts,
# max_tokens_per_expert=max_num_tokens,
# topk=topk_ids.size(-1),
# )
expected_m = self.get_expected_m()
if self.quant_config.use_fp8_w8a16 or self.quant_config.use_fp8_w8a8:
fp8_m_grouped_gemm_nt_masked(
......
......@@ -854,7 +854,7 @@ class FusedMoE(CustomOp):
def use_dp_chunking(self) -> bool:
return (
self.moe_parallel_config.use_pplx_kernels
or self.moe_parallel_config.use_deepep_ll_kernels
#or self.moe_parallel_config.use_deepep_ll_kernels
or self.moe_parallel_config.use_mori_kernels
or (self.dp_size > 1 and self.use_flashinfer_cutlass_kernels)
) and envs.VLLM_ENABLE_MOE_DP_CHUNK
......
......@@ -406,6 +406,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
self.quant_config = quant_config
self.max_num_tokens = max_num_tokens
self.num_dispatchers = num_dispatchers
self.expected_m = max_num_tokens
@staticmethod
def expects_unquantized_inputs(
......@@ -775,6 +776,12 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
"""
raise NotImplementedError
def set_expected_m(self, expected_m):
    """Store ``expected_m`` on this object for later retrieval.

    The value is kept verbatim in ``self.expected_m`` and is what
    ``get_expected_m`` returns afterwards.
    """
    self.expected_m = expected_m
def get_expected_m(self):
    """Return the value currently held in ``self.expected_m``."""
    current = self.expected_m
    return current
def _slice_scales(
scales: torch.Tensor | None, start: int, end: int
......@@ -1074,6 +1081,12 @@ class FusedMoEModularKernel(torch.nn.Module):
The _prepare method is a wrapper around self.prepare_finalize.prepare
that handles DBO and async.
"""
expected_m = (
hidden_states.shape[0] * self.fused_experts.num_dispatchers * topk_ids.shape[1]
+ global_num_experts
) // global_num_experts
self.fused_experts.set_expected_m(expected_m)
if not self.prepare_finalize.supports_async():
# We shouldn't be running an a2a kernel that doesn't
# support async prepare/finalize
......
......@@ -3,7 +3,6 @@
from dataclasses import dataclass
import torch
from vllm.attention.layer import MLAAttention
from vllm.config import CacheConfig
import vllm.envs as envs
......@@ -115,6 +114,7 @@ class MultiHeadLatentAttentionWrapper(PluggableLayer):
self.prefix = prefix
def forward(
self,
positions: torch.Tensor,
......@@ -189,11 +189,12 @@ class MultiHeadLatentAttentionWrapper(PluggableLayer):
if llama_4_scaling is not None:
q *= llama_4_scaling
enable_mla_cp = get_forward_context().enable_mla_cp #envs.VLLM_MLA_CP # and not get_forward_context().draft_model
enable_lightly_cp = get_forward_context().enable_lightly_cp
enable_lightly_cplb = get_forward_context().enable_lightly_cplb
# if not use_fused_rms_rope_concat:
if not envs.VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT:
if enable_mla_cp:
if enable_lightly_cp:
kv_c_normed = tensor_model_parallel_all_gather(
kv_c_normed.contiguous(), 0
)
......@@ -202,7 +203,7 @@ class MultiHeadLatentAttentionWrapper(PluggableLayer):
)
gather_indexes_tensor = get_forward_context().gather_indexes_tensor
if envs.VLLM_MLA_CPLB and gather_indexes_tensor is not None:
if enable_lightly_cplb and gather_indexes_tensor is not None:
# Reorder kv after pcp allgather.
kv_c_normed = torch.index_select(kv_c_normed, 0, gather_indexes_tensor)
k_pe = torch.index_select(k_pe, 0, gather_indexes_tensor)
......@@ -243,7 +244,7 @@ class MultiHeadLatentAttentionWrapper(PluggableLayer):
"expose 'cos_sin_cache'."
)
if enable_mla_cp:
if enable_lightly_cp:
kv_c = tensor_model_parallel_all_gather(
kv_c.contiguous(), 0
)
......@@ -251,7 +252,7 @@ class MultiHeadLatentAttentionWrapper(PluggableLayer):
k_pe.contiguous(), 0
)
gather_indexes_tensor = get_forward_context().gather_indexes_tensor
if envs.VLLM_MLA_CPLB and gather_indexes_tensor is not None:
if enable_lightly_cplb and gather_indexes_tensor is not None:
# Reorder kv after pcp allgather.
kv_c = torch.index_select(kv_c, 0, gather_indexes_tensor)
k_pe = torch.index_select(k_pe, 0, gather_indexes_tensor)
......
......@@ -198,8 +198,8 @@ class DeepSeekMultiTokenPredictor(nn.Module):
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
current_step_idx = spec_step_idx % self.num_mtp_layers
enable_mla_cp = get_forward_context().enable_mla_cp#envs.VLLM_MLA_CP # and not get_forward_context().draft_model
if enable_mla_cp:
enable_lightly_cp = get_forward_context().enable_lightly_cp
if enable_lightly_cp:
scatter_indexes_tensor = get_forward_context().scatter_indexes_tensor
if scatter_indexes_tensor is None:
inputs_embeds_per_rank = torch.chunk(inputs_embeds, chunks=self.tp_size, dim=0)
......@@ -212,7 +212,6 @@ class DeepSeekMultiTokenPredictor(nn.Module):
positions_per_rank = torch.chunk(positions, chunks=self.tp_size, dim=0)
positions = positions_per_rank[self.tp_rank].contiguous()
else:
#scatter_indexes_tensor = scatter_indexes_tensor[scatter_indexes_tensor != -1]
scatter_indexes_tensor = torch.where(scatter_indexes_tensor == -1, 0, scatter_indexes_tensor)
inputs_embeds = torch.index_select(inputs_embeds, 0, scatter_indexes_tensor)
......@@ -228,7 +227,7 @@ class DeepSeekMultiTokenPredictor(nn.Module):
current_step_idx,
)
if enable_mla_cp:
if enable_lightly_cp:
hidden_states = tensor_model_parallel_all_gather(hidden_states.contiguous(), dim=0)
gather_indexes_tensor = get_forward_context().gather_indexes_tensor
if gather_indexes_tensor is not None:
......
......@@ -183,10 +183,9 @@ class DeepseekAttention(nn.Module):
return output
def eff_2d_iqis_all_gather(
def iqis_all_gather(
iqis: tuple[torch.Tensor, torch.Tensor],
tp_size: int | None = None,
tp_rank: int | None = None
tp_size: int | None = None
) -> tuple[torch.Tensor, torch.Tensor]:
assert iqis is not None
iq_tensor, is_tensor = iqis
......@@ -221,6 +220,7 @@ def eff_2d_iqis_all_gather(
is_gathered = is_gathered_int8.view(torch.float32)
return (iq_gathered, is_gathered)
class DeepseekV2MLP(nn.Module):
def __init__(
self,
......@@ -267,15 +267,10 @@ class DeepseekV2MLP(nn.Module):
x,
*, iqis: tuple[torch.Tensor, torch.Tensor] | None = None
):
enable_mla_cp = get_forward_context().enable_mla_cp #envs.VLLM_MLA_CP# and not get_forward_context().draft_model
if enable_mla_cp:
enable_lightly_cp = get_forward_context().enable_lightly_cp
if enable_lightly_cp:
if iqis is not None and iqis[0] is not None and iqis[1] is not None:
if False:
i_q_gahter = tensor_model_parallel_all_gather(iqis[0].contiguous(), 0)
i_s_gather = tensor_model_parallel_all_gather(iqis[1].contiguous(), 0)
iqis = (i_q_gahter, i_s_gather)
else:
iqis = eff_2d_iqis_all_gather(iqis, tp_size=self.tp_size, tp_rank=get_tensor_model_parallel_rank())
iqis = iqis_all_gather(iqis, tp_size=self.tp_size)
else:
x = tensor_model_parallel_all_gather(x.contiguous(), 0)
......@@ -293,7 +288,7 @@ class DeepseekV2MLP(nn.Module):
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
if enable_mla_cp:
if enable_lightly_cp:
x = tensor_model_parallel_reduce_scatter(x.contiguous(), dim=0)
return x
elif self.tp_size > 1:
......@@ -301,66 +296,6 @@ class DeepseekV2MLP(nn.Module):
return x
class DeepseekV2SharedMLP(nn.Module):
    """Gated MLP: merged gate/up projection, SiLU-and-mul, down projection.

    The forward path depends on two ``envs`` switches:
    ``USE_FUSED_RMS_QUANT`` decides whether ``iqis`` is forwarded to the
    gate/up projection, and ``USE_FUSED_SILU_MUL_QUANT`` decides whether
    the activation is replaced by the fused ``lm_fuse_silu_mul_quant`` op.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: QuantizationConfig | None = None,
        reduce_results: bool = True,
        is_sequence_parallel=False,
        prefix: str = "",
    ) -> None:
        """Build the projections and the activation.

        Raises:
            ValueError: If ``hidden_act`` is anything other than ``"silu"``.
        """
        super().__init__()

        # With is_sequence_parallel the inputs/outputs are sharded across
        # the ranks of the tp_group, the weights are replicated and no
        # collective ops are needed. Otherwise this is standard TP with an
        # allreduce at the end.
        gate_up_sizes = [intermediate_size, intermediate_size]
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            gate_up_sizes,
            bias=False,
            quant_config=quant_config,
            disable_tp=is_sequence_parallel,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            reduce_results=reduce_results,
            disable_tp=is_sequence_parallel,
            prefix=f"{prefix}.down_proj",
        )

        # The activation is validated only after both projections exist,
        # matching the original construction order.
        if hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {hidden_act}. Only silu is supported for now."
            )
        self.act_fn = SiluAndMul()

    def forward(
        self,
        x,
        *, iqis: tuple[torch.Tensor, torch.Tensor] | None = None
    ):
        """Apply the gated MLP to ``x``.

        Args:
            x: Input passed to the gate/up projection.
            iqis: Optional tensor pair forwarded to the gate/up projection;
                only used when ``envs.USE_FUSED_RMS_QUANT`` is set
                (presumably pre-quantized input and its scales -- TODO confirm).

        Returns:
            The first element of the down projection's output tuple.
        """
        if not envs.USE_FUSED_RMS_QUANT:
            # Plain path: ``iqis`` is ignored entirely.
            gate_up, _ = self.gate_up_proj(x)
            x, _ = self.down_proj(self.act_fn(gate_up))
            return x

        gate_up, _ = self.gate_up_proj(x, iqis=iqis)
        if envs.USE_FUSED_SILU_MUL_QUANT:
            from lmslim.quantize.quant_ops import lm_fuse_silu_mul_quant

            xq, xs = lm_fuse_silu_mul_quant(gate_up)
            # NOTE(review): ``gate_up`` (not an activated tensor) is the
            # positional input here, with the fused result passed via
            # ``iqis`` -- confirm down_proj reads only ``iqis`` in this mode.
            x, _ = self.down_proj(gate_up, iqis=(xq, xs))
            return x

        x, _ = self.down_proj(self.act_fn(gate_up))
        return x
class DeepseekV2MoE(nn.Module):
def __init__(
self,
......@@ -431,7 +366,7 @@ class DeepseekV2MoE(nn.Module):
else:
intermediate_size = config.moe_intermediate_size * config.n_shared_experts
self.shared_experts = DeepseekV2SharedMLP(
self.shared_experts = DeepseekV2MLP(
hidden_size=config.hidden_size,
intermediate_size=intermediate_size,
hidden_act=config.hidden_act,
......@@ -477,8 +412,8 @@ class DeepseekV2MoE(nn.Module):
def forward(self, hidden_states: torch.Tensor,
*, iqis: tuple[torch.Tensor, torch.Tensor] | None = None
) -> torch.Tensor:
enable_mla_cp = get_forward_context().enable_mla_cp#envs.VLLM_MLA_CP #and not get_forward_context().draft_model
if enable_mla_cp:
enable_lightly_cp = get_forward_context().enable_lightly_cp
if enable_lightly_cp:
hidden_states = tensor_model_parallel_all_gather(
hidden_states.contiguous(), 0
)
......@@ -553,7 +488,7 @@ class DeepseekV2MoE(nn.Module):
assert shared_output is not None
final_hidden_states += shared_output
if enable_mla_cp:
if enable_lightly_cp:
final_hidden_states = tensor_model_parallel_reduce_scatter(
final_hidden_states.contiguous(), 0
)
......@@ -889,13 +824,14 @@ class Indexer(nn.Module):
# `k_pe` is [num_tokens, 1, rope_dim] (MQA).
k = torch.cat([k_pe.squeeze(-2), k_nope], dim=-1)
enable_mla_cp = get_forward_context().enable_mla_cp#envs.VLLM_MLA_CP # and not get_forward_context().draft_model
if enable_mla_cp:
enable_lightly_cp = get_forward_context().enable_lightly_cp
if enable_lightly_cp:
k = tensor_model_parallel_all_gather(
k.contiguous(), 0
)
gather_indexes_tensor = get_forward_context().gather_indexes_tensor
if envs.VLLM_MLA_CPLB and gather_indexes_tensor is not None:
enable_lightly_cplb = get_forward_context().enable_lightly_cplb
if enable_lightly_cplb and gather_indexes_tensor is not None:
k = torch.index_select(k, 0, gather_indexes_tensor)
# we only quant q here since k quant is fused with cache insertion
......@@ -964,8 +900,8 @@ class DeepseekV2MLAAttention(nn.Module):
self.num_heads = num_heads
tp_size = get_tensor_model_parallel_world_size()
assert num_heads % tp_size == 0
#self.num_local_heads = num_heads // tp_size
self.num_local_heads = num_heads // tp_size if not envs.VLLM_MLA_CP else self.num_heads
self.num_local_heads = num_heads // tp_size if not \
vllm_config.parallel_config.enable_lightly_cp else self.num_heads
self.scaling = self.qk_head_dim**-0.5
self.max_position_embeddings = max_position_embeddings
......@@ -999,7 +935,7 @@ class DeepseekV2MLAAttention(nn.Module):
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.q_b_proj",
disable_tp=envs.VLLM_MLA_CP,
disable_tp=vllm_config.parallel_config.enable_lightly_cp
)
else:
self.q_proj = ColumnParallelLinear(
......@@ -1008,7 +944,7 @@ class DeepseekV2MLAAttention(nn.Module):
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.q_proj",
disable_tp=envs.VLLM_MLA_CP,
disable_tp=vllm_config.parallel_config.enable_lightly_cp,
)
self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
self.kv_b_proj = ColumnParallelLinear(
......@@ -1017,7 +953,7 @@ class DeepseekV2MLAAttention(nn.Module):
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.kv_b_proj",
disable_tp=envs.VLLM_MLA_CP,
disable_tp=vllm_config.parallel_config.enable_lightly_cp,
)
self.o_proj = RowParallelLinear(
self.num_heads * self.v_head_dim,
......@@ -1025,7 +961,7 @@ class DeepseekV2MLAAttention(nn.Module):
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
disable_tp=envs.VLLM_MLA_CP,
disable_tp=vllm_config.parallel_config.enable_lightly_cp,
)
if config.rope_parameters["rope_type"] != "default":
......@@ -1262,8 +1198,7 @@ class DeepseekV2DecoderLayer(nn.Module):
residual *= 1.0 / self.routed_scaling_factor
# Fully Connected
enable_mla_cp = get_forward_context().enable_mla_cp
skip_moe_large_batch_size = enable_mla_cp
enable_lightly_cp = get_forward_context().enable_lightly_cp
update_hs = True if isinstance(self.mlp, DeepseekV2MoE) else False
assert self.post_attention_layernorm.has_weight is True
_i_q, _i_s, residual = self.post_attention_layernorm(x=hidden_states,
......@@ -1272,7 +1207,7 @@ class DeepseekV2DecoderLayer(nn.Module):
update_input=update_hs
)
new_resi = residual
if skip_moe_large_batch_size and isinstance(self.mlp, DeepseekV2MoE):
if enable_lightly_cp and isinstance(self.mlp, DeepseekV2MoE):
hidden_states = self.mlp(hidden_states)
else:
hidden_states = self.mlp(hidden_states, iqis=(_i_q, _i_s))
......@@ -1437,8 +1372,8 @@ class DeepseekV2Model(nn.Module):
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
enable_mla_cp = get_forward_context().enable_mla_cp#envs.VLLM_MLA_CP # and not get_forward_context().draft_model
if enable_mla_cp:
enable_lightly_cp = get_forward_context().enable_lightly_cp
if enable_lightly_cp:
scatter_indexes_tensor = get_forward_context().scatter_indexes_tensor
if scatter_indexes_tensor is None:
hidden_states_per_rank = torch.chunk(hidden_states, chunks=self.tp_size, dim=0)
......@@ -1481,7 +1416,7 @@ class DeepseekV2Model(nn.Module):
)
if not get_pp_group().is_last_rank:
if enable_mla_cp:
if enable_lightly_cp:
hidden_states = tensor_model_parallel_all_gather(hidden_states.contiguous(), dim=0)
residual = tensor_model_parallel_all_gather(residual.contiguous(), dim=0)
return IntermediateTensors(
......@@ -1490,7 +1425,7 @@ class DeepseekV2Model(nn.Module):
hidden_states, _ = self.norm(hidden_states, residual)
if enable_mla_cp:
if enable_lightly_cp:
hidden_states = tensor_model_parallel_all_gather(hidden_states.contiguous(), dim=0)
gather_indexes_tensor = get_forward_context().gather_indexes_tensor
if gather_indexes_tensor is not None:
......
......@@ -332,7 +332,6 @@ class CommonAttentionMetadata:
"""Number of requests"""
# TODO(lucas): rename to num_tokens since it may be padded and this is misleading
num_actual_tokens: int
"""Total number of tokens in batch"""
max_query_len: int
"""Longest query in batch"""
......@@ -348,7 +347,7 @@ class CommonAttentionMetadata:
scatter_indexes_tensor: torch.Tensor | None = None
gather_indexes_tensor: torch.Tensor | None = None
cp_common_metadata: CpCommonAttentionMetadata | None = None
enable_mla_cp: bool = False
enable_lightly_cp: bool = False
causal: bool = True
......
......@@ -78,7 +78,8 @@ class SpecDecodeBaseProposer:
self.dp_rank = vllm_config.parallel_config.data_parallel_rank
self.num_speculative_tokens = self.speculative_config.num_speculative_tokens
# The drafter can get longer sequences than the target model.
max_batch_size = vllm_config.scheduler_config.max_num_seqs if not envs.VLLM_MLA_CPLB \
max_batch_size = vllm_config.scheduler_config.max_num_seqs if not \
vllm_config.parallel_config.enable_lightly_cplb \
else vllm_config.scheduler_config.max_num_seqs * 2
self.max_num_tokens = (
vllm_config.scheduler_config.max_num_batched_tokens + max_batch_size
......@@ -224,7 +225,10 @@ class SpecDecodeBaseProposer:
self.scatter_indexes_tensor = None
self.gather_indexes_tensor = None
if envs.VLLM_MLA_CP:
self.enable_lightly_cp = vllm_config.parallel_config.enable_lightly_cp
self.enable_lightly_cplb = self.enable_lightly_cp and vllm_config.parallel_config.enable_lightly_cplb
if self.enable_lightly_cp:
self.query_start_loc = CpuGpuBuffer(
max_batch_size + 1,
dtype=torch.int32,
......@@ -339,8 +343,8 @@ class SpecDecodeBaseProposer:
num_tokens_unpadded=num_tokens, num_tokens_padded=num_tokens
)
enable_mla_cp = envs.VLLM_MLA_CP and num_tokens > self.runner.mla_cp_threshould
if enable_mla_cp:
enable_lightly_cp = self.enable_lightly_cp and num_tokens > self.runner.lightly_cp_threshould
if enable_lightly_cp:
num_tokens_dp_padded = self._pad_for_mla_cp(num_tokens_dp_padded)
common_attn_metadata = self._prepare_cp_metadata(
......@@ -436,7 +440,8 @@ class SpecDecodeBaseProposer:
),
scatter_indexes_tensor=self.scatter_indexes_tensor,
gather_indexes_tensor=self.gather_indexes_tensor,
enable_mla_cp=envs.VLLM_MLA_CP and num_tokens > self.runner.mla_cp_threshould,
enable_lightly_cp=self.enable_lightly_cp and num_tokens > self.runner.lightly_cp_threshould,
enable_lightly_cplb=self.enable_lightly_cplb
):
ret_hidden_states = self.model(**model_kwargs)
if not self.model_returns_tuple():
......@@ -513,7 +518,7 @@ class SpecDecodeBaseProposer:
if batch_size_across_dp is not None:
batch_size_across_dp[self.dp_rank] = input_batch_size
if enable_mla_cp:
if enable_lightly_cp:
common_attn_metadata = common_attn_metadata.cp_common_metadata
common_attn_metadata.num_actual_tokens = batch_size
......
......@@ -6,6 +6,7 @@ import numpy as np
import torch
import torch.distributed as dist
import vllm.envs as envs
from vllm.config import ParallelConfig
from vllm.distributed.parallel_state import get_dp_group
from vllm.logger import init_logger
......@@ -208,7 +209,7 @@ def coordinate_batch_across_dp(
]
"""
if parallel_config.data_parallel_size == 1:
if parallel_config.data_parallel_size == 1 or envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency":
# Early exit.
return False, None, cudagraph_mode
......
......@@ -189,6 +189,7 @@ from .utils import (
sanity_check_mm_encoder_outputs,
)
from vllm.v1.spec_decode.utils import DraftProbs
from vllm.utils.torch_utils import async_tensor_h2d
if TYPE_CHECKING:
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
......@@ -382,12 +383,15 @@ class GPUModelRunner(
self.dcp_rank = 0 if self.dcp_world_size <= 1 else get_dcp_group().rank_in_group
self.max_num_tokens = scheduler_config.max_num_batched_tokens
#self.max_num_reqs = scheduler_config.max_num_seqs
self.enable_lightly_cp = self.parallel_config.enable_lightly_cp
self.enable_lightly_cplb = self.enable_lightly_cp and self.parallel_config.enable_lightly_cplb
self.max_num_reqs = (
scheduler_config.max_num_seqs
if not envs.VLLM_MLA_CPLB
if not self.enable_lightly_cplb
else scheduler_config.max_num_seqs * 2
)
self.mla_cp_threshould = 512
self.lightly_cp_threshould = envs.VLLM_LIGHTLY_CP_THRESHOULD
# Broadcast PP output for external_launcher (torchrun)
# to make sure we are synced across pp ranks
......@@ -1525,7 +1529,7 @@ class GPUModelRunner(
local_scatter_indexes_tensor = None
gather_indexes_tensor = None
if envs.VLLM_MLA_CPLB:
if self.enable_lightly_cp:
rank_tokens = 0
rank_pad_tokens = 0
accu_q_start = 0
......@@ -1736,7 +1740,7 @@ class GPUModelRunner(
cp_common_metadata=cp_common_metadata,
scatter_indexes_tensor=scatter_indexes_tensor,
gather_indexes_tensor=gather_indexes_tensor,
enable_mla_cp=True
enable_lightly_cp=True
)
return cm_base
......@@ -2040,8 +2044,8 @@ class GPUModelRunner(
if self.model_config.enable_return_routed_experts:
self.slot_mapping = slot_mapping_gid_0[:num_tokens].cpu().numpy()
mla_cp_enable = envs.VLLM_MLA_CP and num_tokens > self.mla_cp_threshould
if not mla_cp_enable:
enable_lightly_cp = self.enable_lightly_cp and num_tokens > self.lightly_cp_threshould
if not enable_lightly_cp:
cm_base = CommonAttentionMetadata(
query_start_loc=self.query_start_loc.gpu[: num_reqs_padded + 1],
query_start_loc_cpu=self.query_start_loc.cpu[: num_reqs_padded + 1],
......@@ -2183,19 +2187,19 @@ class GPUModelRunner(
cm.block_table_tensor = _get_block_table(kv_cache_gid)
cm.slot_mapping = slot_mappings[kv_cache_gid]
if cm.seq_indexes_list is not None:
if enable_lightly_cp and cm.seq_indexes_list is not None:
cm.block_table_tensor = cm.block_table_tensor[cm.seq_indexes_list]
if self.speculative_config and spec_decode_common_attn_metadata is None and hasattr(self, "drafter"):
if isinstance(self.drafter, EagleProposer):
if self.drafter.attn_layer_names[0] in kv_cache_group.layer_names:
if mla_cp_enable:
if enable_lightly_cp:
spec_decode_common_attn_metadata = cm.cp_common_metadata
else:
spec_decode_common_attn_metadata = cm
#spec_decode_common_attn_metadata = cm
else:
if mla_cp_enable:
if enable_lightly_cp:
spec_decode_common_attn_metadata = cm.cp_common_metadata
else:
spec_decode_common_attn_metadata = cm
......@@ -2230,7 +2234,7 @@ class GPUModelRunner(
_metadata.mm_prefix_range = req_doc_ranges # type: ignore[attr-defined]
if (
(not envs.VLLM_MLA_CP)
(not self.enable_lightly_cp)
and spec_decode_common_attn_metadata is not None
and (num_reqs != num_reqs_padded or num_tokens != num_tokens_padded)
):
......@@ -3110,7 +3114,7 @@ class GPUModelRunner(
# Pad tokens to multiple of tensor_parallel_size when
# enabled collective fusion for SP
if envs.VLLM_MLA_CP and num_scheduled_tokens > self.mla_cp_threshould:
if self.enable_lightly_cp and num_scheduled_tokens > self.lightly_cp_threshould:
return self._pad_for_mla_cp(num_scheduled_tokens)
tp_size = self.vllm_config.parallel_config.tensor_parallel_size
if self.compilation_config.pass_config.enable_sp and tp_size > 1:
......@@ -3808,7 +3812,7 @@ class GPUModelRunner(
)
num_tokens_padded = batch_desc.num_tokens
if envs.VLLM_MLA_CP and num_tokens_unpadded > self.mla_cp_threshould:
if self.enable_lightly_cp and num_tokens_unpadded > self.lightly_cp_threshould:
num_tokens_padded = self._pad_for_mla_cp(num_tokens_unpadded)
num_reqs_padded = (
batch_desc.num_reqs if batch_desc.num_reqs is not None else num_reqs
......@@ -3927,7 +3931,8 @@ class GPUModelRunner(
skip_compiled=has_encoder_input,
scatter_indexes_tensor=scatter_indexes_tensor,
gather_indexes_tensor=gather_indexes_tensor,
enable_mla_cp=envs.VLLM_MLA_CP and num_tokens_unpadded > self.mla_cp_threshould,
enable_lightly_cp=self.enable_lightly_cp and num_tokens_unpadded > self.lightly_cp_threshould,
enable_lightly_cplb=self.enable_lightly_cplb
),
record_function_or_nullcontext("gpu_model_runner: forward"),
self.maybe_get_kv_connector_output(
......@@ -4421,7 +4426,7 @@ class GPUModelRunner(
)
#total_num_tokens = common_attn_metadata.num_actual_tokens
if (
envs.VLLM_MLA_CP
self.enable_lightly_cp
and common_attn_metadata.cp_common_metadata is not None
):
total_num_tokens = (
......@@ -4952,9 +4957,6 @@ class GPUModelRunner(
or cudagraph_runtime_mode.valid_runtime_modes()
)
# if envs.VLLM_MLA_CP:
# num_tokens = max(self.tp_size, num_tokens)
# If cudagraph_mode.decode_mode() == FULL and
# cudagraph_mode.separate_routine(). This means that we are using
# different graphs and/or modes for mixed prefill-decode batches vs.
......@@ -5116,9 +5118,6 @@ class GPUModelRunner(
inputs_embeds = self.inputs_embeds.gpu[:num_tokens_padded]
model_kwargs = self._init_model_kwargs()
else:
self.input_ids.gpu[:num_tokens_padded] = torch.randint(0, self.model_config.get_vocab_size(),
(num_tokens_padded,),
dtype=torch.int32)
input_ids = self.input_ids.gpu[:num_tokens_padded]
inputs_embeds = None
......@@ -5159,7 +5158,8 @@ class GPUModelRunner(
batch_descriptor=batch_desc,
ubatch_slices=ubatch_slices_padded,
slot_mapping=slot_mappings,
enable_mla_cp=envs.VLLM_MLA_CP and num_tokens_unpadded > self.mla_cp_threshould,
enable_lightly_cp=self.enable_lightly_cp and num_tokens_unpadded > self.lightly_cp_threshould,
enable_lightly_cplb=self.enable_lightly_cplb
),
):
outputs = self.model(
......@@ -5232,9 +5232,15 @@ class GPUModelRunner(
self.eplb_step(is_dummy=True, is_profile=is_profile)
logit_indices = np.cumsum(num_scheduled_tokens) - 1
logit_indices_device = torch.from_numpy(logit_indices).to(
self.device, non_blocking=True
)
# logit_indices_device = torch.from_numpy(logit_indices).to(
# self.device, non_blocking=True
# )
logit_indices = logit_indices.tolist()
logit_indices_device = async_tensor_h2d(
logit_indices,
dtype=torch.int32,
target_device=self.device,
pin_memory=True)
return hidden_states, hidden_states[logit_indices_device]
@torch.inference_mode()
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment