Commit b58514dd authored by 王敏's avatar 王敏
Browse files

[perf]1.优化pcp代码 2.优化ep低延迟模式调度,消除空泡

parent c462f3a0
...@@ -299,6 +299,13 @@ class ParallelConfig: ...@@ -299,6 +299,13 @@ class ParallelConfig:
should only be set by API server scale-out. should only be set by API server scale-out.
""" """
enable_lightly_cp: bool = False
"""Use lightly context parallel."""
enable_lightly_cplb: bool = False
"""Use lightly context parallel load balancing."""
@field_validator("disable_nccl_for_dp_synchronization", mode="wrap") @field_validator("disable_nccl_for_dp_synchronization", mode="wrap")
@classmethod @classmethod
def _skip_none_validation(cls, value: Any, handler: Callable) -> Any: def _skip_none_validation(cls, value: Any, handler: Callable) -> Any:
......
...@@ -1061,6 +1061,12 @@ class VllmConfig: ...@@ -1061,6 +1061,12 @@ class VllmConfig:
# Handle the KV connector configs # Handle the KV connector configs
self._post_init_kv_transfer_config() self._post_init_kv_transfer_config()
if self.parallel_config.enable_lightly_cp and not self.model_config.enforce_eager:
raise ValueError(
"Lightly context parallel currently only supports the eager mode!!!"
)
def update_sizes_for_sequence_parallelism(self, possible_sizes: list) -> list: def update_sizes_for_sequence_parallelism(self, possible_sizes: list) -> list:
# remove the sizes that not multiple of tp_size when # remove the sizes that not multiple of tp_size when
# enable sequence parallelism # enable sequence parallelism
...@@ -1186,8 +1192,7 @@ class VllmConfig: ...@@ -1186,8 +1192,7 @@ class VllmConfig:
if ( if (
self.parallel_config.tensor_parallel_size > 1 self.parallel_config.tensor_parallel_size > 1
and (self.compilation_config.pass_config.enable_sp) and self.compilation_config.pass_config.enable_sp
#or envs.VLLM_MLA_CP)
): ):
cudagraph_capture_sizes = self.update_sizes_for_sequence_parallelism( cudagraph_capture_sizes = self.update_sizes_for_sequence_parallelism(
cudagraph_capture_sizes cudagraph_capture_sizes
......
...@@ -582,6 +582,9 @@ class EngineArgs: ...@@ -582,6 +582,9 @@ class EngineArgs:
kv_offloading_backend: KVOffloadingBackend = CacheConfig.kv_offloading_backend kv_offloading_backend: KVOffloadingBackend = CacheConfig.kv_offloading_backend
tokens_only: bool = False tokens_only: bool = False
enable_lightly_cp: bool = ParallelConfig.enable_lightly_cp
enable_lightly_cplb: bool = ParallelConfig.enable_lightly_cplb
def __post_init__(self): def __post_init__(self):
# support `EngineArgs(compilation_config={...})` # support `EngineArgs(compilation_config={...})`
# without having to manually construct a # without having to manually construct a
...@@ -899,6 +902,15 @@ class EngineArgs: ...@@ -899,6 +902,15 @@ class EngineArgs:
"--worker-extension-cls", **parallel_kwargs["worker_extension_cls"] "--worker-extension-cls", **parallel_kwargs["worker_extension_cls"]
) )
parallel_group.add_argument(
"--enable-lightly-cp",
**parallel_kwargs["enable_lightly_cp"],
)
parallel_group.add_argument(
"--enable-lightly-cplb",
**parallel_kwargs["enable_lightly_cplb"],
)
# KV cache arguments # KV cache arguments
cache_kwargs = get_kwargs(CacheConfig) cache_kwargs = get_kwargs(CacheConfig)
cache_group = parser.add_argument_group( cache_group = parser.add_argument_group(
...@@ -1500,20 +1512,6 @@ class EngineArgs: ...@@ -1500,20 +1512,6 @@ class EngineArgs:
data_parallel_external_lb = ( data_parallel_external_lb = (
self.data_parallel_external_lb or self.data_parallel_rank is not None self.data_parallel_external_lb or self.data_parallel_rank is not None
) )
if (
envs.VLLM_MLA_CP
and self.max_num_batched_tokens is not None
and self.max_num_batched_tokens < self.tensor_parallel_size**3
):
raise ValueError(
"max_num_batched_tokens should be larger than "
"tensor_parallel_size ** 3 when enabled VLLM_MLA_CP"
)
logger.info("[MLACP] VLLM_MLA_CP is %s", envs.VLLM_MLA_CP)
logger.info("[MLACP] VLLM_MLA_CPLB is %s", envs.VLLM_MLA_CPLB)
# Local DP rank = 1, use pure-external LB. # Local DP rank = 1, use pure-external LB.
if data_parallel_external_lb: if data_parallel_external_lb:
assert self.data_parallel_rank is not None, ( assert self.data_parallel_rank is not None, (
...@@ -1644,6 +1642,8 @@ class EngineArgs: ...@@ -1644,6 +1642,8 @@ class EngineArgs:
cp_kv_cache_interleave_size=self.cp_kv_cache_interleave_size, cp_kv_cache_interleave_size=self.cp_kv_cache_interleave_size,
_api_process_count=self._api_process_count, _api_process_count=self._api_process_count,
_api_process_rank=self._api_process_rank, _api_process_rank=self._api_process_rank,
enable_lightly_cp=self.enable_lightly_cp,
enable_lightly_cplb=self.enable_lightly_cplb,
) )
speculative_config = self.create_speculative_config( speculative_config = self.create_speculative_config(
......
...@@ -324,8 +324,9 @@ if TYPE_CHECKING: ...@@ -324,8 +324,9 @@ if TYPE_CHECKING:
USE_LIGHTOP_TOPK: bool = False USE_LIGHTOP_TOPK: bool = False
USE_LIGHTOP_CONVERT_REQ_INDEX_TO_GLOBAL_INDEX: bool = False USE_LIGHTOP_CONVERT_REQ_INDEX_TO_GLOBAL_INDEX: bool = False
VLLM_DISABLE_DSA: bool = False VLLM_DISABLE_DSA: bool = False
VLLM_MLA_CP: bool = False VLLM_LIGHTLY_CP_THRESHOULD: int = 2048
VLLM_MLA_CPLB: bool = False
def get_default_cache_root(): def get_default_cache_root():
return os.getenv( return os.getenv(
"XDG_CACHE_HOME", "XDG_CACHE_HOME",
...@@ -2012,13 +2013,10 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -2012,13 +2013,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_DISABLE_DSA": "VLLM_DISABLE_DSA":
lambda: (os.environ.get("VLLM_DISABLE_DSA", "False").lower() in lambda: (os.environ.get("VLLM_DISABLE_DSA", "False").lower() in
("true", "1")), ("true", "1")),
# If set to 1/True, enable mla context parallel
"VLLM_MLA_CP": # MLA_CP open threshold
lambda: (os.environ.get("VLLM_MLA_CP", "False").lower() in "VLLM_LIGHTLY_CP_THRESHOULD":
("true", "1")), lambda: int(os.getenv("VLLM_LIGHTLY_CP_THRESHOULD", "2048")),
"VLLM_MLA_CPLB":
lambda: (os.environ.get("VLLM_MLA_CPLB", "False").lower() in
("true", "1")),
} }
# --8<-- [end:env-vars-definition] # --8<-- [end:env-vars-definition]
......
...@@ -242,7 +242,8 @@ class ForwardContext: ...@@ -242,7 +242,8 @@ class ForwardContext:
scatter_indexes_tensor: torch.Tensor | None = None scatter_indexes_tensor: torch.Tensor | None = None
gather_indexes_tensor: torch.Tensor | None = None gather_indexes_tensor: torch.Tensor | None = None
enable_mla_cp: bool = False enable_lightly_cp: bool = False
enable_lightly_cplb : bool = False
def __post_init__(self): def __post_init__(self):
assert self.cudagraph_runtime_mode.valid_runtime_modes(), ( assert self.cudagraph_runtime_mode.valid_runtime_modes(), (
...@@ -279,7 +280,8 @@ def create_forward_context( ...@@ -279,7 +280,8 @@ def create_forward_context(
skip_compiled: bool = False, skip_compiled: bool = False,
scatter_indexes_tensor: torch.Tensor | None = None, scatter_indexes_tensor: torch.Tensor | None = None,
gather_indexes_tensor: torch.Tensor | None = None, gather_indexes_tensor: torch.Tensor | None = None,
enable_mla_cp: bool = False enable_lightly_cp: bool = False,
enable_lightly_cplb: bool = False
): ):
if vllm_config.compilation_config.fast_moe_cold_start: if vllm_config.compilation_config.fast_moe_cold_start:
if vllm_config.speculative_config is None: if vllm_config.speculative_config is None:
...@@ -307,7 +309,8 @@ def create_forward_context( ...@@ -307,7 +309,8 @@ def create_forward_context(
skip_compiled=skip_compiled, skip_compiled=skip_compiled,
scatter_indexes_tensor=scatter_indexes_tensor, scatter_indexes_tensor=scatter_indexes_tensor,
gather_indexes_tensor=gather_indexes_tensor, gather_indexes_tensor=gather_indexes_tensor,
enable_mla_cp=enable_mla_cp, enable_lightly_cp=enable_lightly_cp,
enable_lightly_cplb=enable_lightly_cplb,
additional_kwargs=additional_kwargs or {}, additional_kwargs=additional_kwargs or {},
) )
...@@ -341,7 +344,8 @@ def set_forward_context( ...@@ -341,7 +344,8 @@ def set_forward_context(
skip_compiled: bool = False, skip_compiled: bool = False,
scatter_indexes_tensor: torch.Tensor | None = None, scatter_indexes_tensor: torch.Tensor | None = None,
gather_indexes_tensor: torch.Tensor | None = None, gather_indexes_tensor: torch.Tensor | None = None,
enable_mla_cp: bool = False, enable_lightly_cp: bool = False,
enable_lightly_cplb: bool = False,
): ):
"""A context manager that stores the current forward context, """A context manager that stores the current forward context,
can be attention metadata, etc. can be attention metadata, etc.
...@@ -353,7 +357,8 @@ def set_forward_context( ...@@ -353,7 +357,8 @@ def set_forward_context(
forward_start_time = time.perf_counter() forward_start_time = time.perf_counter()
dp_metadata: DPMetadata | None = None dp_metadata: DPMetadata | None = None
if vllm_config.parallel_config.data_parallel_size > 1 and ( if vllm_config.parallel_config.data_parallel_size > 1 and \
envs.VLLM_ALL2ALL_BACKEND != "deepep_low_latency" and (
attn_metadata is not None or num_tokens is not None attn_metadata is not None or num_tokens is not None
): ):
# If num_tokens_across_dp hasn't already been initialized, then # If num_tokens_across_dp hasn't already been initialized, then
...@@ -404,7 +409,8 @@ def set_forward_context( ...@@ -404,7 +409,8 @@ def set_forward_context(
skip_compiled, skip_compiled,
scatter_indexes_tensor, scatter_indexes_tensor,
gather_indexes_tensor, gather_indexes_tensor,
enable_mla_cp enable_lightly_cp,
enable_lightly_cplb
) )
try: try:
......
...@@ -205,20 +205,6 @@ def moe_grouped_gemm( ...@@ -205,20 +205,6 @@ def moe_grouped_gemm(
return output return output
def native_w8a8_perChannel_batch_matmul(q_a1_all, weight13, qa1_scale_all, w13_scale, output_dtype):
    """Pure-PyTorch reference for a scaled, batched ``A @ B^T`` product.

    Both operands are upcast to float32, multiplied batch-wise with
    ``torch.bmm`` (the second operand transposed on its last two dims), then
    rescaled by ``qa1_scale_all`` and by ``w13_scale`` transposed on dims
    ``(1, 2)``, and finally cast to ``output_dtype``.

    Args:
        q_a1_all: 3-D activation tensor ``[E, M, K]`` (``torch.bmm`` requires
            3-D inputs).
        weight13: 3-D weight tensor ``[E, N, K]``; its last dimension must
            match the last dimension of ``q_a1_all``.
        qa1_scale_all: Activation scale, broadcast against the ``[E, M, N]``
            product. Presumably ``[E, M, 1]`` (one scale per row) -- TODO
            confirm against callers.
        w13_scale: Weight scale, transposed on dims ``(1, 2)`` before being
            broadcast. Presumably ``[E, N, 1]`` (one scale per output
            channel) -- TODO confirm against callers.
        output_dtype: dtype of the returned tensor.

    Returns:
        The rescaled product cast to ``output_dtype``; ``[E, M, N]`` when the
        scales broadcast without enlarging the result.

    Raises:
        AssertionError: if the last dimensions of the two operands differ.
    """
    lhs = q_a1_all.to(torch.float32)
    rhs = weight13.to(torch.float32)
    assert lhs.shape[-1] == rhs.shape[-1], "Dimension mismatch"
    # [E, M, K] x [E, K, N] -> [E, M, N]
    product = torch.bmm(lhs, rhs.transpose(1, 2))
    # Row scale on the left, column scale (after the transpose) on the right.
    product = qa1_scale_all * product * w13_scale.transpose(1, 2)
    return product.to(output_dtype)
def scales_shape_stride_dtype( def scales_shape_stride_dtype(
E: int, T: int, G: int, quant_scale_fmt: DeepGemmQuantScaleFMT E: int, T: int, G: int, quant_scale_fmt: DeepGemmQuantScaleFMT
) -> tuple[tuple[int, ...], tuple[int, ...], torch.dtype]: ) -> tuple[tuple[int, ...], tuple[int, ...], torch.dtype]:
...@@ -589,7 +575,6 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -589,7 +575,6 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
workspace2: torch.Tensor, workspace2: torch.Tensor,
expert_tokens_meta: mk.ExpertTokensMetadata | None, expert_tokens_meta: mk.ExpertTokensMetadata | None,
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
use_nn_moe: bool | None = False,
**_ **_
): ):
assert expert_tokens_meta is not None assert expert_tokens_meta is not None
...@@ -612,11 +597,13 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -612,11 +597,13 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
workspace1 = _resize_cache(workspace13, (E, max_num_tokens, N)) workspace1 = _resize_cache(workspace13, (E, max_num_tokens, N))
expected_m = self.estimate_expected_m( # expected_m = self.estimate_expected_m(
global_num_experts=global_num_experts, # global_num_experts=global_num_experts,
max_tokens_per_expert=max_num_tokens, # max_tokens_per_expert=max_num_tokens,
topk=topk_ids.size(-1), # topk=topk_ids.size(-1),
) # )
expected_m = self.get_expected_m()
if self.quant_config.use_fp8_w8a16 or self.quant_config.use_fp8_w8a8: if self.quant_config.use_fp8_w8a16 or self.quant_config.use_fp8_w8a8:
fp8_m_grouped_gemm_nt_masked( fp8_m_grouped_gemm_nt_masked(
......
...@@ -854,7 +854,7 @@ class FusedMoE(CustomOp): ...@@ -854,7 +854,7 @@ class FusedMoE(CustomOp):
def use_dp_chunking(self) -> bool: def use_dp_chunking(self) -> bool:
return ( return (
self.moe_parallel_config.use_pplx_kernels self.moe_parallel_config.use_pplx_kernels
or self.moe_parallel_config.use_deepep_ll_kernels #or self.moe_parallel_config.use_deepep_ll_kernels
or self.moe_parallel_config.use_mori_kernels or self.moe_parallel_config.use_mori_kernels
or (self.dp_size > 1 and self.use_flashinfer_cutlass_kernels) or (self.dp_size > 1 and self.use_flashinfer_cutlass_kernels)
) and envs.VLLM_ENABLE_MOE_DP_CHUNK ) and envs.VLLM_ENABLE_MOE_DP_CHUNK
......
...@@ -406,6 +406,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC): ...@@ -406,6 +406,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
self.quant_config = quant_config self.quant_config = quant_config
self.max_num_tokens = max_num_tokens self.max_num_tokens = max_num_tokens
self.num_dispatchers = num_dispatchers self.num_dispatchers = num_dispatchers
self.expected_m = max_num_tokens
@staticmethod @staticmethod
def expects_unquantized_inputs( def expects_unquantized_inputs(
...@@ -774,6 +775,12 @@ class FusedMoEPermuteExpertsUnpermute(ABC): ...@@ -774,6 +775,12 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
chooses to do weight application. chooses to do weight application.
""" """
raise NotImplementedError raise NotImplementedError
def set_expected_m(self, expected_m):
    """Store ``expected_m`` on this instance, replacing any earlier value.

    The stored value is what ``get_expected_m`` hands back afterwards.
    """
    self.expected_m = expected_m
def get_expected_m(self):
    """Return the ``expected_m`` value currently stored on this instance."""
    return self.expected_m
def _slice_scales( def _slice_scales(
...@@ -1074,6 +1081,12 @@ class FusedMoEModularKernel(torch.nn.Module): ...@@ -1074,6 +1081,12 @@ class FusedMoEModularKernel(torch.nn.Module):
The _prepare method is a wrapper around self.prepare_finalize.prepare The _prepare method is a wrapper around self.prepare_finalize.prepare
that handles DBO and async. that handles DBO and async.
""" """
expected_m = (
hidden_states.shape[0] * self.fused_experts.num_dispatchers * topk_ids.shape[1]
+ global_num_experts
) // global_num_experts
self.fused_experts.set_expected_m(expected_m)
if not self.prepare_finalize.supports_async(): if not self.prepare_finalize.supports_async():
# We shouldn't be running an a2a kernel that doesn't # We shouldn't be running an a2a kernel that doesn't
# support async prepare/finalize # support async prepare/finalize
......
...@@ -3,7 +3,6 @@ ...@@ -3,7 +3,6 @@
from dataclasses import dataclass from dataclasses import dataclass
import torch import torch
from vllm.attention.layer import MLAAttention from vllm.attention.layer import MLAAttention
from vllm.config import CacheConfig from vllm.config import CacheConfig
import vllm.envs as envs import vllm.envs as envs
...@@ -115,6 +114,7 @@ class MultiHeadLatentAttentionWrapper(PluggableLayer): ...@@ -115,6 +114,7 @@ class MultiHeadLatentAttentionWrapper(PluggableLayer):
self.prefix = prefix self.prefix = prefix
def forward( def forward(
self, self,
positions: torch.Tensor, positions: torch.Tensor,
...@@ -189,11 +189,12 @@ class MultiHeadLatentAttentionWrapper(PluggableLayer): ...@@ -189,11 +189,12 @@ class MultiHeadLatentAttentionWrapper(PluggableLayer):
if llama_4_scaling is not None: if llama_4_scaling is not None:
q *= llama_4_scaling q *= llama_4_scaling
enable_mla_cp = get_forward_context().enable_mla_cp #envs.VLLM_MLA_CP # and not get_forward_context().draft_model enable_lightly_cp = get_forward_context().enable_lightly_cp
enable_lightly_cplb = get_forward_context().enable_lightly_cplb
# if not use_fused_rms_rope_concat: # if not use_fused_rms_rope_concat:
if not envs.VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT: if not envs.VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT:
if enable_mla_cp: if enable_lightly_cp:
kv_c_normed = tensor_model_parallel_all_gather( kv_c_normed = tensor_model_parallel_all_gather(
kv_c_normed.contiguous(), 0 kv_c_normed.contiguous(), 0
) )
...@@ -202,7 +203,7 @@ class MultiHeadLatentAttentionWrapper(PluggableLayer): ...@@ -202,7 +203,7 @@ class MultiHeadLatentAttentionWrapper(PluggableLayer):
) )
gather_indexes_tensor = get_forward_context().gather_indexes_tensor gather_indexes_tensor = get_forward_context().gather_indexes_tensor
if envs.VLLM_MLA_CPLB and gather_indexes_tensor is not None: if enable_lightly_cplb and gather_indexes_tensor is not None:
# Reorder kv after pcp allgather. # Reorder kv after pcp allgather.
kv_c_normed = torch.index_select(kv_c_normed, 0, gather_indexes_tensor) kv_c_normed = torch.index_select(kv_c_normed, 0, gather_indexes_tensor)
k_pe = torch.index_select(k_pe, 0, gather_indexes_tensor) k_pe = torch.index_select(k_pe, 0, gather_indexes_tensor)
...@@ -243,7 +244,7 @@ class MultiHeadLatentAttentionWrapper(PluggableLayer): ...@@ -243,7 +244,7 @@ class MultiHeadLatentAttentionWrapper(PluggableLayer):
"expose 'cos_sin_cache'." "expose 'cos_sin_cache'."
) )
if enable_mla_cp: if enable_lightly_cp:
kv_c = tensor_model_parallel_all_gather( kv_c = tensor_model_parallel_all_gather(
kv_c.contiguous(), 0 kv_c.contiguous(), 0
) )
...@@ -251,7 +252,7 @@ class MultiHeadLatentAttentionWrapper(PluggableLayer): ...@@ -251,7 +252,7 @@ class MultiHeadLatentAttentionWrapper(PluggableLayer):
k_pe.contiguous(), 0 k_pe.contiguous(), 0
) )
gather_indexes_tensor = get_forward_context().gather_indexes_tensor gather_indexes_tensor = get_forward_context().gather_indexes_tensor
if envs.VLLM_MLA_CPLB and gather_indexes_tensor is not None: if enable_lightly_cplb and gather_indexes_tensor is not None:
# Reorder kv after pcp allgather. # Reorder kv after pcp allgather.
kv_c = torch.index_select(kv_c, 0, gather_indexes_tensor) kv_c = torch.index_select(kv_c, 0, gather_indexes_tensor)
k_pe = torch.index_select(k_pe, 0, gather_indexes_tensor) k_pe = torch.index_select(k_pe, 0, gather_indexes_tensor)
......
...@@ -198,8 +198,8 @@ class DeepSeekMultiTokenPredictor(nn.Module): ...@@ -198,8 +198,8 @@ class DeepSeekMultiTokenPredictor(nn.Module):
if inputs_embeds is None: if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids) inputs_embeds = self.embed_tokens(input_ids)
current_step_idx = spec_step_idx % self.num_mtp_layers current_step_idx = spec_step_idx % self.num_mtp_layers
enable_mla_cp = get_forward_context().enable_mla_cp#envs.VLLM_MLA_CP # and not get_forward_context().draft_model enable_lightly_cp = get_forward_context().enable_lightly_cp
if enable_mla_cp: if enable_lightly_cp:
scatter_indexes_tensor = get_forward_context().scatter_indexes_tensor scatter_indexes_tensor = get_forward_context().scatter_indexes_tensor
if scatter_indexes_tensor is None: if scatter_indexes_tensor is None:
inputs_embeds_per_rank = torch.chunk(inputs_embeds, chunks=self.tp_size, dim=0) inputs_embeds_per_rank = torch.chunk(inputs_embeds, chunks=self.tp_size, dim=0)
...@@ -212,7 +212,6 @@ class DeepSeekMultiTokenPredictor(nn.Module): ...@@ -212,7 +212,6 @@ class DeepSeekMultiTokenPredictor(nn.Module):
positions_per_rank = torch.chunk(positions, chunks=self.tp_size, dim=0) positions_per_rank = torch.chunk(positions, chunks=self.tp_size, dim=0)
positions = positions_per_rank[self.tp_rank].contiguous() positions = positions_per_rank[self.tp_rank].contiguous()
else: else:
#scatter_indexes_tensor = scatter_indexes_tensor[scatter_indexes_tensor != -1]
scatter_indexes_tensor = torch.where(scatter_indexes_tensor == -1, 0, scatter_indexes_tensor) scatter_indexes_tensor = torch.where(scatter_indexes_tensor == -1, 0, scatter_indexes_tensor)
inputs_embeds = torch.index_select(inputs_embeds, 0, scatter_indexes_tensor) inputs_embeds = torch.index_select(inputs_embeds, 0, scatter_indexes_tensor)
...@@ -228,7 +227,7 @@ class DeepSeekMultiTokenPredictor(nn.Module): ...@@ -228,7 +227,7 @@ class DeepSeekMultiTokenPredictor(nn.Module):
current_step_idx, current_step_idx,
) )
if enable_mla_cp: if enable_lightly_cp:
hidden_states = tensor_model_parallel_all_gather(hidden_states.contiguous(), dim=0) hidden_states = tensor_model_parallel_all_gather(hidden_states.contiguous(), dim=0)
gather_indexes_tensor = get_forward_context().gather_indexes_tensor gather_indexes_tensor = get_forward_context().gather_indexes_tensor
if gather_indexes_tensor is not None: if gather_indexes_tensor is not None:
......
...@@ -183,10 +183,9 @@ class DeepseekAttention(nn.Module): ...@@ -183,10 +183,9 @@ class DeepseekAttention(nn.Module):
return output return output
def eff_2d_iqis_all_gather( def iqis_all_gather(
iqis: tuple[torch.Tensor, torch.Tensor], iqis: tuple[torch.Tensor, torch.Tensor],
tp_size: int | None = None, tp_size: int | None = None
tp_rank: int | None = None
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
assert iqis is not None assert iqis is not None
iq_tensor, is_tensor = iqis iq_tensor, is_tensor = iqis
...@@ -221,6 +220,7 @@ def eff_2d_iqis_all_gather( ...@@ -221,6 +220,7 @@ def eff_2d_iqis_all_gather(
is_gathered = is_gathered_int8.view(torch.float32) is_gathered = is_gathered_int8.view(torch.float32)
return (iq_gathered, is_gathered) return (iq_gathered, is_gathered)
class DeepseekV2MLP(nn.Module): class DeepseekV2MLP(nn.Module):
def __init__( def __init__(
self, self,
...@@ -267,15 +267,10 @@ class DeepseekV2MLP(nn.Module): ...@@ -267,15 +267,10 @@ class DeepseekV2MLP(nn.Module):
x, x,
*, iqis: tuple[torch.Tensor, torch.Tensor] | None = None *, iqis: tuple[torch.Tensor, torch.Tensor] | None = None
): ):
enable_mla_cp = get_forward_context().enable_mla_cp #envs.VLLM_MLA_CP# and not get_forward_context().draft_model enable_lightly_cp = get_forward_context().enable_lightly_cp
if enable_mla_cp: if enable_lightly_cp:
if iqis is not None and iqis[0] is not None and iqis[1] is not None: if iqis is not None and iqis[0] is not None and iqis[1] is not None:
if False: iqis = iqis_all_gather(iqis, tp_size=self.tp_size)
i_q_gahter = tensor_model_parallel_all_gather(iqis[0].contiguous(), 0)
i_s_gather = tensor_model_parallel_all_gather(iqis[1].contiguous(), 0)
iqis = (i_q_gahter, i_s_gather)
else:
iqis = eff_2d_iqis_all_gather(iqis, tp_size=self.tp_size, tp_rank=get_tensor_model_parallel_rank())
else: else:
x = tensor_model_parallel_all_gather(x.contiguous(), 0) x = tensor_model_parallel_all_gather(x.contiguous(), 0)
...@@ -293,72 +288,12 @@ class DeepseekV2MLP(nn.Module): ...@@ -293,72 +288,12 @@ class DeepseekV2MLP(nn.Module):
x = self.act_fn(gate_up) x = self.act_fn(gate_up)
x, _ = self.down_proj(x) x, _ = self.down_proj(x)
if enable_mla_cp: if enable_lightly_cp:
x = tensor_model_parallel_reduce_scatter(x.contiguous(), dim=0) x = tensor_model_parallel_reduce_scatter(x.contiguous(), dim=0)
return x return x
elif self.tp_size > 1: elif self.tp_size > 1:
x = tensor_model_parallel_all_reduce(x) x = tensor_model_parallel_all_reduce(x)
return x return x
class DeepseekV2SharedMLP(nn.Module):
    """Gated MLP (gate/up projection -> SiLU-and-mul -> down projection).

    When ``is_sequence_parallel`` is set, tensor parallelism is disabled on
    both projections (``disable_tp=True``): per the original design note, the
    inputs and outputs are then sharded across the ranks of the TP group, the
    weights are replicated, and no collective ops are needed. Otherwise the
    layers use standard TP, with the final reduction governed by
    ``reduce_results``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: QuantizationConfig | None = None,
        reduce_results: bool = True,
        is_sequence_parallel=False,
        prefix: str = "",
    ) -> None:
        super().__init__()

        # Options common to both projections.
        shared_kwargs = dict(
            bias=False,
            quant_config=quant_config,
            disable_tp=is_sequence_parallel,
        )

        # Gate and up projections are merged into one column-parallel layer
        # with two equally sized outputs.
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size] * 2,
            prefix=f"{prefix}.gate_up_proj",
            **shared_kwargs,
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            reduce_results=reduce_results,
            prefix=f"{prefix}.down_proj",
            **shared_kwargs,
        )

        # Only SiLU gating is implemented.
        if hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {hidden_act}. Only silu is supported for now."
            )
        self.act_fn = SiluAndMul()

    def forward(
        self,
        x,
        *,
        iqis: tuple[torch.Tensor, torch.Tensor] | None = None,
    ):
        """Apply the MLP to ``x``.

        ``iqis`` is forwarded to ``gate_up_proj`` only when
        ``envs.USE_FUSED_RMS_QUANT`` is enabled; otherwise it is ignored.
        NOTE(review): presumably ``iqis`` is a pre-quantized
        (values, scales) pair for the input -- confirm against the
        projection layers.
        """
        if not envs.USE_FUSED_RMS_QUANT:
            # Plain path: no pre-quantized inputs involved.
            gate_up, _ = self.gate_up_proj(x)
            result, _ = self.down_proj(self.act_fn(gate_up))
            return result

        gate_up, _ = self.gate_up_proj(x, iqis=iqis)

        if not envs.USE_FUSED_SILU_MUL_QUANT:
            result, _ = self.down_proj(self.act_fn(gate_up))
            return result

        # Fused SiLU-mul + quant: the kernel's two outputs are handed to
        # down_proj through ``iqis``; ``gate_up`` is still passed
        # positionally, exactly as before.
        from lmslim.quantize.quant_ops import lm_fuse_silu_mul_quant

        xq, xs = lm_fuse_silu_mul_quant(gate_up)
        result, _ = self.down_proj(gate_up, iqis=(xq, xs))
        return result
class DeepseekV2MoE(nn.Module): class DeepseekV2MoE(nn.Module):
...@@ -431,7 +366,7 @@ class DeepseekV2MoE(nn.Module): ...@@ -431,7 +366,7 @@ class DeepseekV2MoE(nn.Module):
else: else:
intermediate_size = config.moe_intermediate_size * config.n_shared_experts intermediate_size = config.moe_intermediate_size * config.n_shared_experts
self.shared_experts = DeepseekV2SharedMLP( self.shared_experts = DeepseekV2MLP(
hidden_size=config.hidden_size, hidden_size=config.hidden_size,
intermediate_size=intermediate_size, intermediate_size=intermediate_size,
hidden_act=config.hidden_act, hidden_act=config.hidden_act,
...@@ -477,8 +412,8 @@ class DeepseekV2MoE(nn.Module): ...@@ -477,8 +412,8 @@ class DeepseekV2MoE(nn.Module):
def forward(self, hidden_states: torch.Tensor, def forward(self, hidden_states: torch.Tensor,
*, iqis: tuple[torch.Tensor, torch.Tensor] | None = None *, iqis: tuple[torch.Tensor, torch.Tensor] | None = None
) -> torch.Tensor: ) -> torch.Tensor:
enable_mla_cp = get_forward_context().enable_mla_cp#envs.VLLM_MLA_CP #and not get_forward_context().draft_model enable_lightly_cp = get_forward_context().enable_lightly_cp
if enable_mla_cp: if enable_lightly_cp:
hidden_states = tensor_model_parallel_all_gather( hidden_states = tensor_model_parallel_all_gather(
hidden_states.contiguous(), 0 hidden_states.contiguous(), 0
) )
...@@ -553,7 +488,7 @@ class DeepseekV2MoE(nn.Module): ...@@ -553,7 +488,7 @@ class DeepseekV2MoE(nn.Module):
assert shared_output is not None assert shared_output is not None
final_hidden_states += shared_output final_hidden_states += shared_output
if enable_mla_cp: if enable_lightly_cp:
final_hidden_states = tensor_model_parallel_reduce_scatter( final_hidden_states = tensor_model_parallel_reduce_scatter(
final_hidden_states.contiguous(), 0 final_hidden_states.contiguous(), 0
) )
...@@ -889,13 +824,14 @@ class Indexer(nn.Module): ...@@ -889,13 +824,14 @@ class Indexer(nn.Module):
# `k_pe` is [num_tokens, 1, rope_dim] (MQA). # `k_pe` is [num_tokens, 1, rope_dim] (MQA).
k = torch.cat([k_pe.squeeze(-2), k_nope], dim=-1) k = torch.cat([k_pe.squeeze(-2), k_nope], dim=-1)
enable_mla_cp = get_forward_context().enable_mla_cp#envs.VLLM_MLA_CP # and not get_forward_context().draft_model enable_lightly_cp = get_forward_context().enable_lightly_cp
if enable_mla_cp: if enable_lightly_cp:
k = tensor_model_parallel_all_gather( k = tensor_model_parallel_all_gather(
k.contiguous(), 0 k.contiguous(), 0
) )
gather_indexes_tensor = get_forward_context().gather_indexes_tensor gather_indexes_tensor = get_forward_context().gather_indexes_tensor
if envs.VLLM_MLA_CPLB and gather_indexes_tensor is not None: enable_lightly_cplb = get_forward_context().enable_lightly_cplb
if enable_lightly_cplb and gather_indexes_tensor is not None:
k = torch.index_select(k, 0, gather_indexes_tensor) k = torch.index_select(k, 0, gather_indexes_tensor)
# we only quant q here since k quant is fused with cache insertion # we only quant q here since k quant is fused with cache insertion
...@@ -964,8 +900,8 @@ class DeepseekV2MLAAttention(nn.Module): ...@@ -964,8 +900,8 @@ class DeepseekV2MLAAttention(nn.Module):
self.num_heads = num_heads self.num_heads = num_heads
tp_size = get_tensor_model_parallel_world_size() tp_size = get_tensor_model_parallel_world_size()
assert num_heads % tp_size == 0 assert num_heads % tp_size == 0
#self.num_local_heads = num_heads // tp_size self.num_local_heads = num_heads // tp_size if not \
self.num_local_heads = num_heads // tp_size if not envs.VLLM_MLA_CP else self.num_heads vllm_config.parallel_config.enable_lightly_cp else self.num_heads
self.scaling = self.qk_head_dim**-0.5 self.scaling = self.qk_head_dim**-0.5
self.max_position_embeddings = max_position_embeddings self.max_position_embeddings = max_position_embeddings
...@@ -999,7 +935,7 @@ class DeepseekV2MLAAttention(nn.Module): ...@@ -999,7 +935,7 @@ class DeepseekV2MLAAttention(nn.Module):
bias=False, bias=False,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.q_b_proj", prefix=f"{prefix}.q_b_proj",
disable_tp=envs.VLLM_MLA_CP, disable_tp=vllm_config.parallel_config.enable_lightly_cp
) )
else: else:
self.q_proj = ColumnParallelLinear( self.q_proj = ColumnParallelLinear(
...@@ -1008,7 +944,7 @@ class DeepseekV2MLAAttention(nn.Module): ...@@ -1008,7 +944,7 @@ class DeepseekV2MLAAttention(nn.Module):
bias=False, bias=False,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.q_proj", prefix=f"{prefix}.q_proj",
disable_tp=envs.VLLM_MLA_CP, disable_tp=vllm_config.parallel_config.enable_lightly_cp,
) )
self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps) self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
self.kv_b_proj = ColumnParallelLinear( self.kv_b_proj = ColumnParallelLinear(
...@@ -1017,7 +953,7 @@ class DeepseekV2MLAAttention(nn.Module): ...@@ -1017,7 +953,7 @@ class DeepseekV2MLAAttention(nn.Module):
bias=False, bias=False,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.kv_b_proj", prefix=f"{prefix}.kv_b_proj",
disable_tp=envs.VLLM_MLA_CP, disable_tp=vllm_config.parallel_config.enable_lightly_cp,
) )
self.o_proj = RowParallelLinear( self.o_proj = RowParallelLinear(
self.num_heads * self.v_head_dim, self.num_heads * self.v_head_dim,
...@@ -1025,7 +961,7 @@ class DeepseekV2MLAAttention(nn.Module): ...@@ -1025,7 +961,7 @@ class DeepseekV2MLAAttention(nn.Module):
bias=False, bias=False,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.o_proj", prefix=f"{prefix}.o_proj",
disable_tp=envs.VLLM_MLA_CP, disable_tp=vllm_config.parallel_config.enable_lightly_cp,
) )
if config.rope_parameters["rope_type"] != "default": if config.rope_parameters["rope_type"] != "default":
...@@ -1262,8 +1198,7 @@ class DeepseekV2DecoderLayer(nn.Module): ...@@ -1262,8 +1198,7 @@ class DeepseekV2DecoderLayer(nn.Module):
residual *= 1.0 / self.routed_scaling_factor residual *= 1.0 / self.routed_scaling_factor
# Fully Connected # Fully Connected
enable_mla_cp = get_forward_context().enable_mla_cp enable_lightly_cp = get_forward_context().enable_lightly_cp
skip_moe_large_batch_size = enable_mla_cp
update_hs = True if isinstance(self.mlp, DeepseekV2MoE) else False update_hs = True if isinstance(self.mlp, DeepseekV2MoE) else False
assert self.post_attention_layernorm.has_weight is True assert self.post_attention_layernorm.has_weight is True
_i_q, _i_s, residual = self.post_attention_layernorm(x=hidden_states, _i_q, _i_s, residual = self.post_attention_layernorm(x=hidden_states,
...@@ -1272,7 +1207,7 @@ class DeepseekV2DecoderLayer(nn.Module): ...@@ -1272,7 +1207,7 @@ class DeepseekV2DecoderLayer(nn.Module):
update_input=update_hs update_input=update_hs
) )
new_resi = residual new_resi = residual
if skip_moe_large_batch_size and isinstance(self.mlp, DeepseekV2MoE): if enable_lightly_cp and isinstance(self.mlp, DeepseekV2MoE):
hidden_states = self.mlp(hidden_states) hidden_states = self.mlp(hidden_states)
else: else:
hidden_states = self.mlp(hidden_states, iqis=(_i_q, _i_s)) hidden_states = self.mlp(hidden_states, iqis=(_i_q, _i_s))
...@@ -1437,8 +1372,8 @@ class DeepseekV2Model(nn.Module): ...@@ -1437,8 +1372,8 @@ class DeepseekV2Model(nn.Module):
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
enable_mla_cp = get_forward_context().enable_mla_cp#envs.VLLM_MLA_CP # and not get_forward_context().draft_model enable_lightly_cp = get_forward_context().enable_lightly_cp
if enable_mla_cp: if enable_lightly_cp:
scatter_indexes_tensor = get_forward_context().scatter_indexes_tensor scatter_indexes_tensor = get_forward_context().scatter_indexes_tensor
if scatter_indexes_tensor is None: if scatter_indexes_tensor is None:
hidden_states_per_rank = torch.chunk(hidden_states, chunks=self.tp_size, dim=0) hidden_states_per_rank = torch.chunk(hidden_states, chunks=self.tp_size, dim=0)
...@@ -1481,7 +1416,7 @@ class DeepseekV2Model(nn.Module): ...@@ -1481,7 +1416,7 @@ class DeepseekV2Model(nn.Module):
) )
if not get_pp_group().is_last_rank: if not get_pp_group().is_last_rank:
if enable_mla_cp: if enable_lightly_cp:
hidden_states = tensor_model_parallel_all_gather(hidden_states.contiguous(), dim=0) hidden_states = tensor_model_parallel_all_gather(hidden_states.contiguous(), dim=0)
residual = tensor_model_parallel_all_gather(residual.contiguous(), dim=0) residual = tensor_model_parallel_all_gather(residual.contiguous(), dim=0)
return IntermediateTensors( return IntermediateTensors(
...@@ -1490,7 +1425,7 @@ class DeepseekV2Model(nn.Module): ...@@ -1490,7 +1425,7 @@ class DeepseekV2Model(nn.Module):
hidden_states, _ = self.norm(hidden_states, residual) hidden_states, _ = self.norm(hidden_states, residual)
if enable_mla_cp: if enable_lightly_cp:
hidden_states = tensor_model_parallel_all_gather(hidden_states.contiguous(), dim=0) hidden_states = tensor_model_parallel_all_gather(hidden_states.contiguous(), dim=0)
gather_indexes_tensor = get_forward_context().gather_indexes_tensor gather_indexes_tensor = get_forward_context().gather_indexes_tensor
if gather_indexes_tensor is not None: if gather_indexes_tensor is not None:
......
...@@ -332,7 +332,6 @@ class CommonAttentionMetadata: ...@@ -332,7 +332,6 @@ class CommonAttentionMetadata:
"""Number of requests""" """Number of requests"""
# TODO(lucas): rename to num_tokens since it may be padded and this is misleading # TODO(lucas): rename to num_tokens since it may be padded and this is misleading
num_actual_tokens: int num_actual_tokens: int
"""Total number of tokens in batch""" """Total number of tokens in batch"""
max_query_len: int max_query_len: int
"""Longest query in batch""" """Longest query in batch"""
...@@ -348,7 +347,7 @@ class CommonAttentionMetadata: ...@@ -348,7 +347,7 @@ class CommonAttentionMetadata:
scatter_indexes_tensor: torch.Tensor | None = None scatter_indexes_tensor: torch.Tensor | None = None
gather_indexes_tensor: torch.Tensor | None = None gather_indexes_tensor: torch.Tensor | None = None
cp_common_metadata: CpCommonAttentionMetadata | None = None cp_common_metadata: CpCommonAttentionMetadata | None = None
enable_mla_cp: bool = False enable_lightly_cp: bool = False
causal: bool = True causal: bool = True
......
...@@ -78,7 +78,8 @@ class SpecDecodeBaseProposer: ...@@ -78,7 +78,8 @@ class SpecDecodeBaseProposer:
self.dp_rank = vllm_config.parallel_config.data_parallel_rank self.dp_rank = vllm_config.parallel_config.data_parallel_rank
self.num_speculative_tokens = self.speculative_config.num_speculative_tokens self.num_speculative_tokens = self.speculative_config.num_speculative_tokens
# The drafter can get longer sequences than the target model. # The drafter can get longer sequences than the target model.
max_batch_size = vllm_config.scheduler_config.max_num_seqs if not envs.VLLM_MLA_CPLB \ max_batch_size = vllm_config.scheduler_config.max_num_seqs if not \
vllm_config.parallel_config.enable_lightly_cplb \
else vllm_config.scheduler_config.max_num_seqs * 2 else vllm_config.scheduler_config.max_num_seqs * 2
self.max_num_tokens = ( self.max_num_tokens = (
vllm_config.scheduler_config.max_num_batched_tokens + max_batch_size vllm_config.scheduler_config.max_num_batched_tokens + max_batch_size
...@@ -224,7 +225,10 @@ class SpecDecodeBaseProposer: ...@@ -224,7 +225,10 @@ class SpecDecodeBaseProposer:
self.scatter_indexes_tensor = None self.scatter_indexes_tensor = None
self.gather_indexes_tensor = None self.gather_indexes_tensor = None
if envs.VLLM_MLA_CP:
self.enable_lightly_cp = vllm_config.parallel_config.enable_lightly_cp
self.enable_lightly_cplb = self.enable_lightly_cp and vllm_config.parallel_config.enable_lightly_cplb
if self.enable_lightly_cp:
self.query_start_loc = CpuGpuBuffer( self.query_start_loc = CpuGpuBuffer(
max_batch_size + 1, max_batch_size + 1,
dtype=torch.int32, dtype=torch.int32,
...@@ -339,8 +343,8 @@ class SpecDecodeBaseProposer: ...@@ -339,8 +343,8 @@ class SpecDecodeBaseProposer:
num_tokens_unpadded=num_tokens, num_tokens_padded=num_tokens num_tokens_unpadded=num_tokens, num_tokens_padded=num_tokens
) )
enable_mla_cp = envs.VLLM_MLA_CP and num_tokens > self.runner.mla_cp_threshould enable_lightly_cp = self.enable_lightly_cp and num_tokens > self.runner.lightly_cp_threshould
if enable_mla_cp: if enable_lightly_cp:
num_tokens_dp_padded = self._pad_for_mla_cp(num_tokens_dp_padded) num_tokens_dp_padded = self._pad_for_mla_cp(num_tokens_dp_padded)
common_attn_metadata = self._prepare_cp_metadata( common_attn_metadata = self._prepare_cp_metadata(
...@@ -436,7 +440,8 @@ class SpecDecodeBaseProposer: ...@@ -436,7 +440,8 @@ class SpecDecodeBaseProposer:
), ),
scatter_indexes_tensor=self.scatter_indexes_tensor, scatter_indexes_tensor=self.scatter_indexes_tensor,
gather_indexes_tensor=self.gather_indexes_tensor, gather_indexes_tensor=self.gather_indexes_tensor,
enable_mla_cp=envs.VLLM_MLA_CP and num_tokens > self.runner.mla_cp_threshould, enable_lightly_cp=self.enable_lightly_cp and num_tokens > self.runner.lightly_cp_threshould,
enable_lightly_cplb=self.enable_lightly_cplb
): ):
ret_hidden_states = self.model(**model_kwargs) ret_hidden_states = self.model(**model_kwargs)
if not self.model_returns_tuple(): if not self.model_returns_tuple():
...@@ -513,7 +518,7 @@ class SpecDecodeBaseProposer: ...@@ -513,7 +518,7 @@ class SpecDecodeBaseProposer:
if batch_size_across_dp is not None: if batch_size_across_dp is not None:
batch_size_across_dp[self.dp_rank] = input_batch_size batch_size_across_dp[self.dp_rank] = input_batch_size
if enable_mla_cp: if enable_lightly_cp:
common_attn_metadata = common_attn_metadata.cp_common_metadata common_attn_metadata = common_attn_metadata.cp_common_metadata
common_attn_metadata.num_actual_tokens = batch_size common_attn_metadata.num_actual_tokens = batch_size
......
...@@ -6,6 +6,7 @@ import numpy as np ...@@ -6,6 +6,7 @@ import numpy as np
import torch import torch
import torch.distributed as dist import torch.distributed as dist
import vllm.envs as envs
from vllm.config import ParallelConfig from vllm.config import ParallelConfig
from vllm.distributed.parallel_state import get_dp_group from vllm.distributed.parallel_state import get_dp_group
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -208,7 +209,7 @@ def coordinate_batch_across_dp( ...@@ -208,7 +209,7 @@ def coordinate_batch_across_dp(
] ]
""" """
if parallel_config.data_parallel_size == 1: if parallel_config.data_parallel_size == 1 or envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency":
# Early exit. # Early exit.
return False, None, cudagraph_mode return False, None, cudagraph_mode
......
...@@ -189,6 +189,7 @@ from .utils import ( ...@@ -189,6 +189,7 @@ from .utils import (
sanity_check_mm_encoder_outputs, sanity_check_mm_encoder_outputs,
) )
from vllm.v1.spec_decode.utils import DraftProbs from vllm.v1.spec_decode.utils import DraftProbs
from vllm.utils.torch_utils import async_tensor_h2d
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
...@@ -382,12 +383,15 @@ class GPUModelRunner( ...@@ -382,12 +383,15 @@ class GPUModelRunner(
self.dcp_rank = 0 if self.dcp_world_size <= 1 else get_dcp_group().rank_in_group self.dcp_rank = 0 if self.dcp_world_size <= 1 else get_dcp_group().rank_in_group
self.max_num_tokens = scheduler_config.max_num_batched_tokens self.max_num_tokens = scheduler_config.max_num_batched_tokens
#self.max_num_reqs = scheduler_config.max_num_seqs #self.max_num_reqs = scheduler_config.max_num_seqs
self.enable_lightly_cp = self.parallel_config.enable_lightly_cp
self.enable_lightly_cplb = self.enable_lightly_cp and self.parallel_config.enable_lightly_cplb
self.max_num_reqs = ( self.max_num_reqs = (
scheduler_config.max_num_seqs scheduler_config.max_num_seqs
if not envs.VLLM_MLA_CPLB if not self.enable_lightly_cplb
else scheduler_config.max_num_seqs * 2 else scheduler_config.max_num_seqs * 2
) )
self.mla_cp_threshould = 512 self.lightly_cp_threshould = envs.VLLM_LIGHTLY_CP_THRESHOULD
# Broadcast PP output for external_launcher (torchrun) # Broadcast PP output for external_launcher (torchrun)
# to make sure we are synced across pp ranks # to make sure we are synced across pp ranks
...@@ -1525,7 +1529,7 @@ class GPUModelRunner( ...@@ -1525,7 +1529,7 @@ class GPUModelRunner(
local_scatter_indexes_tensor = None local_scatter_indexes_tensor = None
gather_indexes_tensor = None gather_indexes_tensor = None
if envs.VLLM_MLA_CPLB: if self.enable_lightly_cp:
rank_tokens = 0 rank_tokens = 0
rank_pad_tokens = 0 rank_pad_tokens = 0
accu_q_start = 0 accu_q_start = 0
...@@ -1736,7 +1740,7 @@ class GPUModelRunner( ...@@ -1736,7 +1740,7 @@ class GPUModelRunner(
cp_common_metadata=cp_common_metadata, cp_common_metadata=cp_common_metadata,
scatter_indexes_tensor=scatter_indexes_tensor, scatter_indexes_tensor=scatter_indexes_tensor,
gather_indexes_tensor=gather_indexes_tensor, gather_indexes_tensor=gather_indexes_tensor,
enable_mla_cp=True enable_lightly_cp=True
) )
return cm_base return cm_base
...@@ -2040,8 +2044,8 @@ class GPUModelRunner( ...@@ -2040,8 +2044,8 @@ class GPUModelRunner(
if self.model_config.enable_return_routed_experts: if self.model_config.enable_return_routed_experts:
self.slot_mapping = slot_mapping_gid_0[:num_tokens].cpu().numpy() self.slot_mapping = slot_mapping_gid_0[:num_tokens].cpu().numpy()
mla_cp_enable = envs.VLLM_MLA_CP and num_tokens > self.mla_cp_threshould enable_lightly_cp = self.enable_lightly_cp and num_tokens > self.lightly_cp_threshould
if not mla_cp_enable: if not enable_lightly_cp:
cm_base = CommonAttentionMetadata( cm_base = CommonAttentionMetadata(
query_start_loc=self.query_start_loc.gpu[: num_reqs_padded + 1], query_start_loc=self.query_start_loc.gpu[: num_reqs_padded + 1],
query_start_loc_cpu=self.query_start_loc.cpu[: num_reqs_padded + 1], query_start_loc_cpu=self.query_start_loc.cpu[: num_reqs_padded + 1],
...@@ -2183,19 +2187,19 @@ class GPUModelRunner( ...@@ -2183,19 +2187,19 @@ class GPUModelRunner(
cm.block_table_tensor = _get_block_table(kv_cache_gid) cm.block_table_tensor = _get_block_table(kv_cache_gid)
cm.slot_mapping = slot_mappings[kv_cache_gid] cm.slot_mapping = slot_mappings[kv_cache_gid]
if cm.seq_indexes_list is not None: if enable_lightly_cp and cm.seq_indexes_list is not None:
cm.block_table_tensor = cm.block_table_tensor[cm.seq_indexes_list] cm.block_table_tensor = cm.block_table_tensor[cm.seq_indexes_list]
if self.speculative_config and spec_decode_common_attn_metadata is None and hasattr(self, "drafter"): if self.speculative_config and spec_decode_common_attn_metadata is None and hasattr(self, "drafter"):
if isinstance(self.drafter, EagleProposer): if isinstance(self.drafter, EagleProposer):
if self.drafter.attn_layer_names[0] in kv_cache_group.layer_names: if self.drafter.attn_layer_names[0] in kv_cache_group.layer_names:
if mla_cp_enable: if enable_lightly_cp:
spec_decode_common_attn_metadata = cm.cp_common_metadata spec_decode_common_attn_metadata = cm.cp_common_metadata
else: else:
spec_decode_common_attn_metadata = cm spec_decode_common_attn_metadata = cm
#spec_decode_common_attn_metadata = cm #spec_decode_common_attn_metadata = cm
else: else:
if mla_cp_enable: if enable_lightly_cp:
spec_decode_common_attn_metadata = cm.cp_common_metadata spec_decode_common_attn_metadata = cm.cp_common_metadata
else: else:
spec_decode_common_attn_metadata = cm spec_decode_common_attn_metadata = cm
...@@ -2230,7 +2234,7 @@ class GPUModelRunner( ...@@ -2230,7 +2234,7 @@ class GPUModelRunner(
_metadata.mm_prefix_range = req_doc_ranges # type: ignore[attr-defined] _metadata.mm_prefix_range = req_doc_ranges # type: ignore[attr-defined]
if ( if (
(not envs.VLLM_MLA_CP) (not self.enable_lightly_cp)
and spec_decode_common_attn_metadata is not None and spec_decode_common_attn_metadata is not None
and (num_reqs != num_reqs_padded or num_tokens != num_tokens_padded) and (num_reqs != num_reqs_padded or num_tokens != num_tokens_padded)
): ):
...@@ -3110,7 +3114,7 @@ class GPUModelRunner( ...@@ -3110,7 +3114,7 @@ class GPUModelRunner(
# Pad tokens to multiple of tensor_parallel_size when # Pad tokens to multiple of tensor_parallel_size when
# enabled collective fusion for SP # enabled collective fusion for SP
if envs.VLLM_MLA_CP and num_scheduled_tokens > self.mla_cp_threshould: if self.enable_lightly_cp and num_scheduled_tokens > self.lightly_cp_threshould:
return self._pad_for_mla_cp(num_scheduled_tokens) return self._pad_for_mla_cp(num_scheduled_tokens)
tp_size = self.vllm_config.parallel_config.tensor_parallel_size tp_size = self.vllm_config.parallel_config.tensor_parallel_size
if self.compilation_config.pass_config.enable_sp and tp_size > 1: if self.compilation_config.pass_config.enable_sp and tp_size > 1:
...@@ -3808,7 +3812,7 @@ class GPUModelRunner( ...@@ -3808,7 +3812,7 @@ class GPUModelRunner(
) )
num_tokens_padded = batch_desc.num_tokens num_tokens_padded = batch_desc.num_tokens
if envs.VLLM_MLA_CP and num_tokens_unpadded > self.mla_cp_threshould: if self.enable_lightly_cp and num_tokens_unpadded > self.lightly_cp_threshould:
num_tokens_padded = self._pad_for_mla_cp(num_tokens_unpadded) num_tokens_padded = self._pad_for_mla_cp(num_tokens_unpadded)
num_reqs_padded = ( num_reqs_padded = (
batch_desc.num_reqs if batch_desc.num_reqs is not None else num_reqs batch_desc.num_reqs if batch_desc.num_reqs is not None else num_reqs
...@@ -3927,7 +3931,8 @@ class GPUModelRunner( ...@@ -3927,7 +3931,8 @@ class GPUModelRunner(
skip_compiled=has_encoder_input, skip_compiled=has_encoder_input,
scatter_indexes_tensor=scatter_indexes_tensor, scatter_indexes_tensor=scatter_indexes_tensor,
gather_indexes_tensor=gather_indexes_tensor, gather_indexes_tensor=gather_indexes_tensor,
enable_mla_cp=envs.VLLM_MLA_CP and num_tokens_unpadded > self.mla_cp_threshould, enable_lightly_cp=self.enable_lightly_cp and num_tokens_unpadded > self.lightly_cp_threshould,
enable_lightly_cplb=self.enable_lightly_cplb
), ),
record_function_or_nullcontext("gpu_model_runner: forward"), record_function_or_nullcontext("gpu_model_runner: forward"),
self.maybe_get_kv_connector_output( self.maybe_get_kv_connector_output(
...@@ -4421,7 +4426,7 @@ class GPUModelRunner( ...@@ -4421,7 +4426,7 @@ class GPUModelRunner(
) )
#total_num_tokens = common_attn_metadata.num_actual_tokens #total_num_tokens = common_attn_metadata.num_actual_tokens
if ( if (
envs.VLLM_MLA_CP self.enable_lightly_cp
and common_attn_metadata.cp_common_metadata is not None and common_attn_metadata.cp_common_metadata is not None
): ):
total_num_tokens = ( total_num_tokens = (
...@@ -4952,9 +4957,6 @@ class GPUModelRunner( ...@@ -4952,9 +4957,6 @@ class GPUModelRunner(
or cudagraph_runtime_mode.valid_runtime_modes() or cudagraph_runtime_mode.valid_runtime_modes()
) )
# if envs.VLLM_MLA_CP:
# num_tokens = max(self.tp_size, num_tokens)
# If cudagraph_mode.decode_mode() == FULL and # If cudagraph_mode.decode_mode() == FULL and
# cudagraph_mode.separate_routine(). This means that we are using # cudagraph_mode.separate_routine(). This means that we are using
# different graphs and/or modes for mixed prefill-decode batches vs. # different graphs and/or modes for mixed prefill-decode batches vs.
...@@ -5116,9 +5118,6 @@ class GPUModelRunner( ...@@ -5116,9 +5118,6 @@ class GPUModelRunner(
inputs_embeds = self.inputs_embeds.gpu[:num_tokens_padded] inputs_embeds = self.inputs_embeds.gpu[:num_tokens_padded]
model_kwargs = self._init_model_kwargs() model_kwargs = self._init_model_kwargs()
else: else:
self.input_ids.gpu[:num_tokens_padded] = torch.randint(0, self.model_config.get_vocab_size(),
(num_tokens_padded,),
dtype=torch.int32)
input_ids = self.input_ids.gpu[:num_tokens_padded] input_ids = self.input_ids.gpu[:num_tokens_padded]
inputs_embeds = None inputs_embeds = None
...@@ -5159,7 +5158,8 @@ class GPUModelRunner( ...@@ -5159,7 +5158,8 @@ class GPUModelRunner(
batch_descriptor=batch_desc, batch_descriptor=batch_desc,
ubatch_slices=ubatch_slices_padded, ubatch_slices=ubatch_slices_padded,
slot_mapping=slot_mappings, slot_mapping=slot_mappings,
enable_mla_cp=envs.VLLM_MLA_CP and num_tokens_unpadded > self.mla_cp_threshould, enable_lightly_cp=self.enable_lightly_cp and num_tokens_unpadded > self.lightly_cp_threshould,
enable_lightly_cplb=self.enable_lightly_cplb
), ),
): ):
outputs = self.model( outputs = self.model(
...@@ -5232,9 +5232,15 @@ class GPUModelRunner( ...@@ -5232,9 +5232,15 @@ class GPUModelRunner(
self.eplb_step(is_dummy=True, is_profile=is_profile) self.eplb_step(is_dummy=True, is_profile=is_profile)
logit_indices = np.cumsum(num_scheduled_tokens) - 1 logit_indices = np.cumsum(num_scheduled_tokens) - 1
logit_indices_device = torch.from_numpy(logit_indices).to( # logit_indices_device = torch.from_numpy(logit_indices).to(
self.device, non_blocking=True # self.device, non_blocking=True
) # )
logit_indices = logit_indices.tolist()
logit_indices_device = async_tensor_h2d(
logit_indices,
dtype=torch.int32,
target_device=self.device,
pin_memory=True)
return hidden_states, hidden_states[logit_indices_device] return hidden_states, hidden_states[logit_indices_device]
@torch.inference_mode() @torch.inference_mode()
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment