Commit 26645e58 authored by 王敏's avatar 王敏
Browse files

[feat]基于mla sp实现pcp

parent d1fd831b
...@@ -1500,6 +1500,20 @@ class EngineArgs: ...@@ -1500,6 +1500,20 @@ class EngineArgs:
data_parallel_external_lb = ( data_parallel_external_lb = (
self.data_parallel_external_lb or self.data_parallel_rank is not None self.data_parallel_external_lb or self.data_parallel_rank is not None
) )
if (
envs.VLLM_MLA_CP
and self.max_num_batched_tokens is not None
and self.max_num_batched_tokens < self.tensor_parallel_size**3
):
raise ValueError(
"max_num_batched_tokens should be larger than "
"tensor_parallel_size ** 3 when enabled VLLM_MLA_CP"
)
logger.info("[MLACP] VLLM_MLA_CP is %s", envs.VLLM_MLA_CP)
logger.info("[MLACP] VLLM_MLA_CPLB is %s", envs.VLLM_MLA_CPLB)
# Local DP rank = 1, use pure-external LB. # Local DP rank = 1, use pure-external LB.
if data_parallel_external_lb: if data_parallel_external_lb:
assert self.data_parallel_rank is not None, ( assert self.data_parallel_rank is not None, (
......
...@@ -240,6 +240,9 @@ class ForwardContext: ...@@ -240,6 +240,9 @@ class ForwardContext:
additional_kwargs: dict[str, Any] = field(default_factory=dict) additional_kwargs: dict[str, Any] = field(default_factory=dict)
scatter_indexes_tensor: torch.Tensor | None = None
gather_indexes_tensor: torch.Tensor | None = None
def __post_init__(self): def __post_init__(self):
assert self.cudagraph_runtime_mode.valid_runtime_modes(), ( assert self.cudagraph_runtime_mode.valid_runtime_modes(), (
f"Invalid cudagraph runtime mode: {self.cudagraph_runtime_mode}" f"Invalid cudagraph runtime mode: {self.cudagraph_runtime_mode}"
...@@ -273,6 +276,8 @@ def create_forward_context( ...@@ -273,6 +276,8 @@ def create_forward_context(
slot_mapping: dict[str, torch.Tensor] | None = None, slot_mapping: dict[str, torch.Tensor] | None = None,
additional_kwargs: dict[str, Any] | None = None, additional_kwargs: dict[str, Any] | None = None,
skip_compiled: bool = False, skip_compiled: bool = False,
scatter_indexes_tensor: torch.Tensor | None = None,
gather_indexes_tensor: torch.Tensor | None = None,
): ):
if vllm_config.compilation_config.fast_moe_cold_start: if vllm_config.compilation_config.fast_moe_cold_start:
if vllm_config.speculative_config is None: if vllm_config.speculative_config is None:
...@@ -298,6 +303,8 @@ def create_forward_context( ...@@ -298,6 +303,8 @@ def create_forward_context(
batch_descriptor=batch_descriptor, batch_descriptor=batch_descriptor,
ubatch_slices=ubatch_slices, ubatch_slices=ubatch_slices,
skip_compiled=skip_compiled, skip_compiled=skip_compiled,
scatter_indexes_tensor=scatter_indexes_tensor,
gather_indexes_tensor=gather_indexes_tensor,
additional_kwargs=additional_kwargs or {}, additional_kwargs=additional_kwargs or {},
) )
...@@ -329,6 +336,8 @@ def set_forward_context( ...@@ -329,6 +336,8 @@ def set_forward_context(
ubatch_slices: UBatchSlices | None = None, ubatch_slices: UBatchSlices | None = None,
slot_mapping: dict[str, torch.Tensor] | list[dict[str, torch.Tensor]] | None = None, slot_mapping: dict[str, torch.Tensor] | list[dict[str, torch.Tensor]] | None = None,
skip_compiled: bool = False, skip_compiled: bool = False,
scatter_indexes_tensor: torch.Tensor | None = None,
gather_indexes_tensor: torch.Tensor | None = None,
): ):
"""A context manager that stores the current forward context, """A context manager that stores the current forward context,
can be attention metadata, etc. can be attention metadata, etc.
...@@ -389,6 +398,8 @@ def set_forward_context( ...@@ -389,6 +398,8 @@ def set_forward_context(
slot_mapping, slot_mapping,
additional_kwargs, additional_kwargs,
skip_compiled, skip_compiled,
scatter_indexes_tensor,
gather_indexes_tensor,
) )
try: try:
......
...@@ -9,6 +9,9 @@ from vllm.config import CacheConfig ...@@ -9,6 +9,9 @@ from vllm.config import CacheConfig
import vllm.envs as envs import vllm.envs as envs
from vllm.model_executor.custom_op import PluggableLayer from vllm.model_executor.custom_op import PluggableLayer
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.distributed import (
tensor_model_parallel_all_gather,
)
@dataclass @dataclass
...@@ -183,8 +186,19 @@ class MultiHeadLatentAttentionWrapper(PluggableLayer): ...@@ -183,8 +186,19 @@ class MultiHeadLatentAttentionWrapper(PluggableLayer):
if llama_4_scaling is not None: if llama_4_scaling is not None:
q *= llama_4_scaling q *= llama_4_scaling
enable_mla_cp = envs.VLLM_MLA_CP # and not get_forward_context().draft_model
# if not use_fused_rms_rope_concat: # if not use_fused_rms_rope_concat:
if not envs.VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT: if not envs.VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT:
if enable_mla_cp:
kv_c_normed = tensor_model_parallel_all_gather(
kv_c_normed.contiguous(), 0
)
k_pe = tensor_model_parallel_all_gather(
k_pe.contiguous(), 0
)
attn_out = self.mla_attn( attn_out = self.mla_attn(
q, q,
kv_c_normed, kv_c_normed,
...@@ -220,6 +234,15 @@ class MultiHeadLatentAttentionWrapper(PluggableLayer): ...@@ -220,6 +234,15 @@ class MultiHeadLatentAttentionWrapper(PluggableLayer):
"VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT requires rotary_emb to " "VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT requires rotary_emb to "
"expose 'cos_sin_cache'." "expose 'cos_sin_cache'."
) )
if enable_mla_cp:
kv_c = tensor_model_parallel_all_gather(
kv_c.contiguous(), 0
)
k_pe = tensor_model_parallel_all_gather(
k_pe.contiguous(), 0
)
attn_out = self.mla_attn( attn_out = self.mla_attn(
q[..., self.qk_nope_head_dim:], q[..., self.qk_nope_head_dim:],
kv_c, kv_c,
......
...@@ -71,7 +71,7 @@ def sparse_attn_indexer( ...@@ -71,7 +71,7 @@ def sparse_attn_indexer(
) )
attn_metadata = attn_metadata[k_cache_prefix] attn_metadata = attn_metadata[k_cache_prefix]
assert isinstance(attn_metadata, DeepseekV32IndexerMetadata) assert isinstance(attn_metadata, DeepseekV32IndexerMetadata)
slot_mapping = attn_metadata.slot_mapping slot_mapping = attn_metadata.slot_mapping[:attn_metadata.num_kv_actual_tokens]
has_decode = attn_metadata.num_decodes > 0 has_decode = attn_metadata.num_decodes > 0
has_prefill = attn_metadata.num_prefills > 0 has_prefill = attn_metadata.num_prefills > 0
num_decode_tokens = attn_metadata.num_decode_tokens num_decode_tokens = attn_metadata.num_decode_tokens
......
...@@ -46,6 +46,7 @@ from vllm.distributed import ( ...@@ -46,6 +46,7 @@ from vllm.distributed import (
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather, tensor_model_parallel_all_gather,
) )
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce, tensor_model_parallel_reduce_scatter
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
...@@ -211,10 +212,82 @@ class DeepseekV2MLP(nn.Module): ...@@ -211,10 +212,82 @@ class DeepseekV2MLP(nn.Module):
hidden_size, hidden_size,
bias=False, bias=False,
quant_config=quant_config, quant_config=quant_config,
reduce_results=reduce_results, #reduce_results=reduce_results,
reduce_results=False,
disable_tp=is_sequence_parallel, disable_tp=is_sequence_parallel,
prefix=f"{prefix}.down_proj", prefix=f"{prefix}.down_proj",
) )
self.tp_size = get_tensor_model_parallel_world_size()
if hidden_act != "silu":
raise ValueError(
f"Unsupported activation: {hidden_act}. Only silu is supported for now."
)
self.act_fn = SiluAndMul()
def forward(self,
            x,
            *, iqis: tuple[torch.Tensor, torch.Tensor] | None = None
            ):
    """Run the gated MLP and reduce the result across the TP group.

    When ``envs.VLLM_MLA_CP`` is set, ``x`` is all-gathered along dim 0
    before the projections and the output is reduce-scattered along dim 0
    afterwards. Otherwise the output is all-reduced whenever
    ``self.tp_size > 1``.

    Args:
        x: Input activations.
        iqis: Optional pair forwarded to ``gate_up_proj`` when
            ``envs.USE_FUSED_RMS_QUANT`` is set (presumably a
            pre-quantized input and its scales -- TODO confirm).

    Returns:
        The MLP output after the collective described above.

    NOTE(review): ``down_proj`` is built with ``reduce_results=False``, so
    this method owns the reduction. The all-reduce branch does not look at
    the constructor's ``reduce_results`` / ``is_sequence_parallel``
    arguments -- confirm that no caller relied on them.
    """
    cp_enabled = envs.VLLM_MLA_CP  # and not get_forward_context().draft_model
    if cp_enabled:
        x = tensor_model_parallel_all_gather(x.contiguous(), 0)

    if not envs.USE_FUSED_RMS_QUANT:
        # Plain path: project, activate, project back.
        projected, _ = self.gate_up_proj(x)
        x, _ = self.down_proj(self.act_fn(projected))
    else:
        projected, _ = self.gate_up_proj(x, iqis=iqis)
        if not envs.USE_FUSED_SILU_MUL_QUANT:
            x, _ = self.down_proj(self.act_fn(projected))
        else:
            # Imported lazily so the optional kernel package is only
            # required when the fused path is actually enabled.
            from lmslim.quantize.quant_ops import lm_fuse_silu_mul_quant

            quantized, scales = lm_fuse_silu_mul_quant(projected)
            x, _ = self.down_proj(projected, iqis=(quantized, scales))

    if cp_enabled:
        return tensor_model_parallel_reduce_scatter(x.contiguous(), dim=0)
    if self.tp_size > 1:
        return tensor_model_parallel_all_reduce(x)
    return x
class DeepseekV2SharedMLP(nn.Module):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
hidden_act: str,
quant_config: QuantizationConfig | None = None,
reduce_results: bool = True,
is_sequence_parallel=False,
prefix: str = "",
) -> None:
super().__init__()
# If is_sequence_parallel, the input and output tensors are sharded
# across the ranks within the tp_group. In this case the weights are
# replicated and no collective ops are needed.
# Otherwise we use standard TP with an allreduce at the end.
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size,
[intermediate_size] * 2,
bias=False,
quant_config=quant_config,
disable_tp=is_sequence_parallel,
prefix=f"{prefix}.gate_up_proj",
)
self.down_proj = RowParallelLinear(
intermediate_size,
hidden_size,
bias=False,
quant_config=quant_config,
reduce_results=reduce_results,
disable_tp=is_sequence_parallel,
prefix=f"{prefix}.down_proj"
)
if hidden_act != "silu": if hidden_act != "silu":
raise ValueError( raise ValueError(
f"Unsupported activation: {hidden_act}. Only silu is supported for now." f"Unsupported activation: {hidden_act}. Only silu is supported for now."
...@@ -311,7 +384,7 @@ class DeepseekV2MoE(nn.Module): ...@@ -311,7 +384,7 @@ class DeepseekV2MoE(nn.Module):
else: else:
intermediate_size = config.moe_intermediate_size * config.n_shared_experts intermediate_size = config.moe_intermediate_size * config.n_shared_experts
self.shared_experts = DeepseekV2MLP( self.shared_experts = DeepseekV2SharedMLP(
hidden_size=config.hidden_size, hidden_size=config.hidden_size,
intermediate_size=intermediate_size, intermediate_size=intermediate_size,
hidden_act=config.hidden_act, hidden_act=config.hidden_act,
...@@ -357,6 +430,11 @@ class DeepseekV2MoE(nn.Module): ...@@ -357,6 +430,11 @@ class DeepseekV2MoE(nn.Module):
def forward(self, hidden_states: torch.Tensor, def forward(self, hidden_states: torch.Tensor,
*, iqis: tuple[torch.Tensor, torch.Tensor] | None = None *, iqis: tuple[torch.Tensor, torch.Tensor] | None = None
) -> torch.Tensor: ) -> torch.Tensor:
enable_mla_cp = envs.VLLM_MLA_CP #and not get_forward_context().draft_model
if enable_mla_cp:
hidden_states = tensor_model_parallel_all_gather(
hidden_states.contiguous(), 0
)
num_tokens, hidden_dim = hidden_states.shape num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim) hidden_states = hidden_states.view(-1, hidden_dim)
...@@ -428,7 +506,12 @@ class DeepseekV2MoE(nn.Module): ...@@ -428,7 +506,12 @@ class DeepseekV2MoE(nn.Module):
assert shared_output is not None assert shared_output is not None
final_hidden_states += shared_output final_hidden_states += shared_output
if self.is_sequence_parallel: if enable_mla_cp:
final_hidden_states = tensor_model_parallel_reduce_scatter(
final_hidden_states.contiguous(), 0
)
return final_hidden_states
elif self.is_sequence_parallel:
final_hidden_states = tensor_model_parallel_all_gather( final_hidden_states = tensor_model_parallel_all_gather(
final_hidden_states, 0 final_hidden_states, 0
) )
...@@ -756,6 +839,12 @@ class Indexer(nn.Module): ...@@ -756,6 +839,12 @@ class Indexer(nn.Module):
# `k_pe` is [num_tokens, 1, rope_dim] (MQA). # `k_pe` is [num_tokens, 1, rope_dim] (MQA).
k = torch.cat([k_pe.squeeze(-2), k_nope], dim=-1) k = torch.cat([k_pe.squeeze(-2), k_nope], dim=-1)
enable_mla_cp = envs.VLLM_MLA_CP # and not get_forward_context().draft_model
if enable_mla_cp:
k = tensor_model_parallel_all_gather(
k.contiguous(), 0
)
# we only quant q here since k quant is fused with cache insertion # we only quant q here since k quant is fused with cache insertion
if not current_platform.is_rocm() or torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938": if not current_platform.is_rocm() or torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938":
q = q.view(-1, self.head_dim) q = q.view(-1, self.head_dim)
...@@ -819,7 +908,8 @@ class DeepseekV2MLAAttention(nn.Module): ...@@ -819,7 +908,8 @@ class DeepseekV2MLAAttention(nn.Module):
self.num_heads = num_heads self.num_heads = num_heads
tp_size = get_tensor_model_parallel_world_size() tp_size = get_tensor_model_parallel_world_size()
assert num_heads % tp_size == 0 assert num_heads % tp_size == 0
self.num_local_heads = num_heads // tp_size #self.num_local_heads = num_heads // tp_size
self.num_local_heads = num_heads // tp_size if not envs.VLLM_MLA_CP else self.num_heads
self.scaling = self.qk_head_dim**-0.5 self.scaling = self.qk_head_dim**-0.5
self.max_position_embeddings = max_position_embeddings self.max_position_embeddings = max_position_embeddings
...@@ -853,6 +943,7 @@ class DeepseekV2MLAAttention(nn.Module): ...@@ -853,6 +943,7 @@ class DeepseekV2MLAAttention(nn.Module):
bias=False, bias=False,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.q_b_proj", prefix=f"{prefix}.q_b_proj",
disable_tp=envs.VLLM_MLA_CP,
) )
else: else:
self.q_proj = ColumnParallelLinear( self.q_proj = ColumnParallelLinear(
...@@ -861,6 +952,7 @@ class DeepseekV2MLAAttention(nn.Module): ...@@ -861,6 +952,7 @@ class DeepseekV2MLAAttention(nn.Module):
bias=False, bias=False,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.q_proj", prefix=f"{prefix}.q_proj",
disable_tp=envs.VLLM_MLA_CP,
) )
self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps) self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
self.kv_b_proj = ColumnParallelLinear( self.kv_b_proj = ColumnParallelLinear(
...@@ -869,6 +961,7 @@ class DeepseekV2MLAAttention(nn.Module): ...@@ -869,6 +961,7 @@ class DeepseekV2MLAAttention(nn.Module):
bias=False, bias=False,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.kv_b_proj", prefix=f"{prefix}.kv_b_proj",
disable_tp=envs.VLLM_MLA_CP,
) )
self.o_proj = RowParallelLinear( self.o_proj = RowParallelLinear(
self.num_heads * self.v_head_dim, self.num_heads * self.v_head_dim,
...@@ -876,6 +969,7 @@ class DeepseekV2MLAAttention(nn.Module): ...@@ -876,6 +969,7 @@ class DeepseekV2MLAAttention(nn.Module):
bias=False, bias=False,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.o_proj", prefix=f"{prefix}.o_proj",
disable_tp=envs.VLLM_MLA_CP,
) )
if config.rope_parameters["rope_type"] != "default": if config.rope_parameters["rope_type"] != "default":
...@@ -1217,6 +1311,9 @@ class DeepseekV2Model(nn.Module): ...@@ -1217,6 +1311,9 @@ class DeepseekV2Model(nn.Module):
self.config = config self.config = config
self.device = current_platform.device_type self.device = current_platform.device_type
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
self.vocab_size = config.vocab_size self.vocab_size = config.vocab_size
#添加判断,默认开启DSA #添加判断,默认开启DSA
force_disable_dsa = os.environ.get("VLLM_DISABLE_DSA", "0") == "1" force_disable_dsa = os.environ.get("VLLM_DISABLE_DSA", "0") == "1"
...@@ -1279,6 +1376,19 @@ class DeepseekV2Model(nn.Module): ...@@ -1279,6 +1376,19 @@ class DeepseekV2Model(nn.Module):
hidden_states = intermediate_tensors["hidden_states"] hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"] residual = intermediate_tensors["residual"]
enable_mla_cp = envs.VLLM_MLA_CP # and not get_forward_context().draft_model
if enable_mla_cp:
hidden_states_per_rank = torch.chunk(hidden_states, chunks=self.tp_size, dim=0)
hidden_states = hidden_states_per_rank[self.tp_rank].contiguous()
if residual is not None:
residual_per_rank = torch.chunk(residual, chunks=self.tp_size, dim=0)
residual = residual_per_rank[self.tp_rank].contiguous()
if positions is not None:
positions_per_rank = torch.chunk(positions, chunks=self.tp_size, dim=0)
positions = positions_per_rank[self.tp_rank].contiguous()
# Compute llama 4 scaling once per forward pass if enabled # Compute llama 4 scaling once per forward pass if enabled
llama_4_scaling_config = getattr(self.config, "llama_4_scaling", None) llama_4_scaling_config = getattr(self.config, "llama_4_scaling", None)
llama_4_scaling: torch.Tensor | None llama_4_scaling: torch.Tensor | None
...@@ -1304,6 +1414,10 @@ class DeepseekV2Model(nn.Module): ...@@ -1304,6 +1414,10 @@ class DeepseekV2Model(nn.Module):
) )
hidden_states, _ = self.norm(hidden_states, residual) hidden_states, _ = self.norm(hidden_states, residual)
if enable_mla_cp:
hidden_states = tensor_model_parallel_all_gather(hidden_states.contiguous(), dim=0)
return hidden_states return hidden_states
......
...@@ -285,6 +285,18 @@ class AttentionMetadata: ...@@ -285,6 +285,18 @@ class AttentionMetadata:
T = TypeVar("T", bound=AttentionMetadata) T = TypeVar("T", bound=AttentionMetadata)
@dataclass
class CpCommonAttentionMetadata:
    """Batch-level attention metadata carried next to the per-rank view.

    ``GPUModelRunner._prepare_cp_metadata`` fills this from the runner's
    ``query_start_loc`` / ``seq_lens`` buffers before it overwrites those
    buffers with the values for the local rank, and attaches it to
    ``CommonAttentionMetadata.cp_common_metadata``.
    """

    # sp related metadata
    # Clone of the runner's query_start_loc GPU buffer (num_reqs + 1 entries).
    query_start_loc: torch.Tensor
    # Clone of the same buffer's CPU copy.
    query_start_loc_cpu: torch.Tensor
    # Clone of the runner's seq_lens GPU buffer (num_reqs entries).
    seq_lens: torch.Tensor
    # Clone of the seq_lens CPU copy.
    _seq_lens_cpu: torch.Tensor
    # Token count of the batch as passed to _prepare_cp_metadata.
    num_actual_tokens: int
    # max_query_len as passed to _prepare_cp_metadata.
    max_query_len: int
    # Number of (padded) requests the tensors above were sliced to.
    num_reqs: int
    # Request ids of the input batch.
    req_ids: list[str]
@dataclass @dataclass
class CommonAttentionMetadata: class CommonAttentionMetadata:
...@@ -306,6 +318,7 @@ class CommonAttentionMetadata: ...@@ -306,6 +318,7 @@ class CommonAttentionMetadata:
"""Number of requests""" """Number of requests"""
# TODO(lucas): rename to num_tokens since it may be padded and this is misleading # TODO(lucas): rename to num_tokens since it may be padded and this is misleading
num_actual_tokens: int num_actual_tokens: int
"""Total number of tokens in batch""" """Total number of tokens in batch"""
max_query_len: int max_query_len: int
"""Longest query in batch""" """Longest query in batch"""
...@@ -315,6 +328,14 @@ class CommonAttentionMetadata: ...@@ -315,6 +328,14 @@ class CommonAttentionMetadata:
block_table_tensor: torch.Tensor block_table_tensor: torch.Tensor
slot_mapping: torch.Tensor slot_mapping: torch.Tensor
num_kv_actual_tokens: int
seq_indexes_list: list[int] | None = None
scatter_indexes_tensor: torch.Tensor | None = None
gather_indexes_tensor: torch.Tensor | None = None
cp_common_metadata: CpCommonAttentionMetadata | None = None
enable_mla_cp: bool = False
causal: bool = True causal: bool = True
# Needed by FastPrefillAttentionBuilder # Needed by FastPrefillAttentionBuilder
......
...@@ -138,6 +138,7 @@ class FlashMLASparseMetadata(AttentionMetadata): ...@@ -138,6 +138,7 @@ class FlashMLASparseMetadata(AttentionMetadata):
max_seq_len: int max_seq_len: int
num_actual_tokens: int # Number of tokens excluding padding. num_actual_tokens: int # Number of tokens excluding padding.
num_kv_actual_tokens: int
query_start_loc: torch.Tensor query_start_loc: torch.Tensor
slot_mapping: torch.Tensor slot_mapping: torch.Tensor
...@@ -693,6 +694,7 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad ...@@ -693,6 +694,7 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad
max_query_len=cm.max_query_len, max_query_len=cm.max_query_len,
max_seq_len=cm.max_seq_len, max_seq_len=cm.max_seq_len,
num_actual_tokens=cm.num_actual_tokens, num_actual_tokens=cm.num_actual_tokens,
num_kv_actual_tokens=cm.num_kv_actual_tokens,
query_start_loc=cm.query_start_loc, query_start_loc=cm.query_start_loc,
slot_mapping=cm.slot_mapping, slot_mapping=cm.slot_mapping,
block_table=cm.block_table_tensor, block_table=cm.block_table_tensor,
...@@ -1024,12 +1026,13 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]): ...@@ -1024,12 +1026,13 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
return output.fill_(0) return output.fill_(0)
num_actual_toks = attn_metadata.num_actual_tokens num_actual_toks = attn_metadata.num_actual_tokens
num_kv_actual_toks = attn_metadata.num_kv_actual_tokens
# Inputs and outputs may be padded for CUDA graphs # Inputs and outputs may be padded for CUDA graphs
q = q[:num_actual_toks, ...] q = q[:num_actual_toks, ...]
k_c_normed = k_c_normed[:num_actual_toks, ...] k_c_normed = k_c_normed[:num_kv_actual_toks, ...]
k_pe = k_pe[:num_actual_toks, ...] k_pe = k_pe[:num_kv_actual_toks, ...]
assert self.topk_indices_buffer is not None assert self.topk_indices_buffer is not None
topk_indices = self.topk_indices_buffer[:num_actual_toks] topk_indices = self.topk_indices_buffer[:num_actual_toks]
......
...@@ -105,6 +105,7 @@ class DeepseekV32IndexerMetadata: ...@@ -105,6 +105,7 @@ class DeepseekV32IndexerMetadata:
max_seq_len: int max_seq_len: int
num_actual_tokens: int # Number of tokens excluding padding. num_actual_tokens: int # Number of tokens excluding padding.
num_kv_actual_tokens: int
query_start_loc: torch.Tensor query_start_loc: torch.Tensor
slot_mapping: torch.Tensor slot_mapping: torch.Tensor
# The dimension of the attention heads # The dimension of the attention heads
...@@ -437,6 +438,7 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder): ...@@ -437,6 +438,7 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
max_query_len=common_attn_metadata.max_query_len, max_query_len=common_attn_metadata.max_query_len,
max_seq_len=common_attn_metadata.max_seq_len, max_seq_len=common_attn_metadata.max_seq_len,
num_actual_tokens=common_attn_metadata.num_actual_tokens, num_actual_tokens=common_attn_metadata.num_actual_tokens,
num_kv_actual_tokens=common_attn_metadata.num_kv_actual_tokens,
query_start_loc=common_attn_metadata.query_start_loc, query_start_loc=common_attn_metadata.query_start_loc,
slot_mapping=common_attn_metadata.slot_mapping, slot_mapping=common_attn_metadata.slot_mapping,
head_dim=128, head_dim=128,
......
...@@ -802,6 +802,7 @@ class SpecDecodeBaseProposer: ...@@ -802,6 +802,7 @@ class SpecDecodeBaseProposer:
_num_computed_tokens_cpu=common_attn_metadata._num_computed_tokens_cpu, _num_computed_tokens_cpu=common_attn_metadata._num_computed_tokens_cpu,
num_reqs=common_attn_metadata.num_reqs, num_reqs=common_attn_metadata.num_reqs,
num_actual_tokens=total_num_tokens, num_actual_tokens=total_num_tokens,
num_kv_actual_tokens=total_num_tokens,
max_query_len=new_query_len_per_req.max().item(), max_query_len=new_query_len_per_req.max().item(),
max_seq_len=common_attn_metadata.seq_lens_cpu.max().item(), max_seq_len=common_attn_metadata.seq_lens_cpu.max().item(),
block_table_tensor=common_attn_metadata.block_table_tensor, block_table_tensor=common_attn_metadata.block_table_tensor,
......
...@@ -233,6 +233,10 @@ class BlockTable: ...@@ -233,6 +233,10 @@ class BlockTable:
def get_device_tensor(self, num_reqs: int) -> torch.Tensor: def get_device_tensor(self, num_reqs: int) -> torch.Tensor:
"""Returns the device tensor of the block table.""" """Returns the device tensor of the block table."""
return self.block_table.gpu[:num_reqs] return self.block_table.gpu[:num_reqs]
def get_device_tensor_range(self, start_req: int, end_req: int) -> torch.Tensor:
    """Return rows ``[start_req, end_req)`` of the device block table."""
    rows = slice(start_req, end_req)
    return self.block_table.gpu[rows]
def get_cpu_tensor(self) -> torch.Tensor: def get_cpu_tensor(self) -> torch.Tensor:
"""Returns the CPU tensor of the block table.""" """Returns the CPU tensor of the block table."""
......
...@@ -42,8 +42,13 @@ from vllm.distributed.parallel_state import ( ...@@ -42,8 +42,13 @@ from vllm.distributed.parallel_state import (
get_tp_group, get_tp_group,
graph_capture, graph_capture,
is_global_first_rank, is_global_first_rank,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
prepare_communication_buffer_for_model, prepare_communication_buffer_for_model,
) )
from vllm.distributed import (
tensor_model_parallel_all_gather
)
from vllm.forward_context import ( from vllm.forward_context import (
BatchDescriptor, BatchDescriptor,
set_forward_context, set_forward_context,
...@@ -104,6 +109,7 @@ from vllm.v1.attention.backend import ( ...@@ -104,6 +109,7 @@ from vllm.v1.attention.backend import (
AttentionMetadataBuilder, AttentionMetadataBuilder,
AttentionType, AttentionType,
CommonAttentionMetadata, CommonAttentionMetadata,
CpCommonAttentionMetadata,
MultipleOf, MultipleOf,
) )
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder
...@@ -371,10 +377,16 @@ class GPUModelRunner( ...@@ -371,10 +377,16 @@ class GPUModelRunner(
# Always set to false after the first forward pass # Always set to false after the first forward pass
self.calculate_kv_scales = self.cache_config.calculate_kv_scales self.calculate_kv_scales = self.cache_config.calculate_kv_scales
self.tp_size = self.parallel_config.tensor_parallel_size
self.dcp_world_size = self.parallel_config.decode_context_parallel_size self.dcp_world_size = self.parallel_config.decode_context_parallel_size
self.dcp_rank = 0 if self.dcp_world_size <= 1 else get_dcp_group().rank_in_group self.dcp_rank = 0 if self.dcp_world_size <= 1 else get_dcp_group().rank_in_group
self.max_num_tokens = scheduler_config.max_num_batched_tokens self.max_num_tokens = scheduler_config.max_num_batched_tokens
self.max_num_reqs = scheduler_config.max_num_seqs #self.max_num_reqs = scheduler_config.max_num_seqs
self.max_num_reqs = (
scheduler_config.max_num_seqs
if not envs.VLLM_MLA_CPLB
else scheduler_config.max_num_seqs * 2
)
# Broadcast PP output for external_launcher (torchrun) # Broadcast PP output for external_launcher (torchrun)
# to make sure we are synced across pp ranks # to make sure we are synced across pp ranks
...@@ -1485,6 +1497,236 @@ class GPUModelRunner( ...@@ -1485,6 +1497,236 @@ class GPUModelRunner(
return encoder_seq_lens, encoder_seq_lens_cpu return encoder_seq_lens, encoder_seq_lens_cpu
def _distribute_tokens_to_cp_ranks(
    self,
    total_q_len: int,
    q_lens_cpu: np.ndarray,
    kv_lens_cpu: np.ndarray,
    tp_rank: int,
    tp_size: int,
    req_ids: list[str],
):
    """Select the query tokens of the batch that this TP rank processes.

    Two strategies, chosen by ``envs.VLLM_MLA_CPLB``:

    * Off: the flattened token range ``[0, total_q_len)`` is cut into
      ``tp_size`` contiguous slices of ``ceil(total_q_len / tp_size)``
      tokens and the rank takes slice ``tp_rank``. A request that crosses
      the end of the slice is split, and its kv length is shortened by
      the query tokens that fall after the slice.
    * On: every request is padded to a multiple of ``2 * tp_size`` tokens
      and cut into ``2 * tp_size`` equal chunks. The rank takes chunk
      ``tp_rank`` and the mirrored chunk ``2 * tp_size - 1 - tp_rank``
      (presumably to balance causal-attention work -- TODO confirm).
      Each non-empty chunk becomes its own local "request". The local
      token indexes are padded with ``-1`` and all-gathered so that a
      permutation restoring the original token order can be derived.

    Args:
        total_q_len: Number of query tokens in the whole batch.
        q_lens_cpu: Per-request query lengths (indexable; elements may be
            numpy scalars or 0-d tensors).
        kv_lens_cpu: Per-request sequence (kv) lengths.
        tp_rank: Rank of this worker inside the TP group.
        tp_size: Size of the TP group.
        req_ids: Request ids, indexed like ``q_lens_cpu``.

    Returns:
        An 8-tuple ``(rank_tokens, q_lens, seq_count, kv_lens,
        local_req_ids, local_scatter_indexes_tensor,
        gather_indexes_tensor, seq_indexes)``. The two tensors are
        ``None`` unless ``envs.VLLM_MLA_CPLB`` is set.
    """
    tokens_per_rank = (total_q_len + tp_size - 1) // tp_size
    start_token = tp_rank * tokens_per_rank
    end_token = min((tp_rank + 1) * tokens_per_rank, total_q_len)

    q_lens = []
    seq_count = 0
    seq_indexes = []
    kv_lens = []
    local_req_ids = []
    local_scatter_indexes_tensor = None
    gather_indexes_tensor = None

    if envs.VLLM_MLA_CPLB:
        rank_tokens = 0
        rank_pad_tokens = 0
        accu_q_start = 0
        scatter_indexes: list[int] = []
        for i in range(len(q_lens_cpu)):
            # Plain ints keep range()/np.array() independent of whether
            # the caller passed numpy scalars or 0-d tensors.
            req_q_len = int(q_lens_cpu[i])
            kv_len = int(kv_lens_cpu[i])
            req_pad_q_len = round_up(req_q_len, 2 * tp_size)
            chunk_q_len = req_pad_q_len // (2 * tp_size)

            # Head chunk and its mirror image counted from the padded end.
            q_1_start = tp_rank * chunk_q_len
            q_1_end = (tp_rank + 1) * chunk_q_len
            q_2_start = req_pad_q_len - (tp_rank + 1) * chunk_q_len
            q_2_end = req_pad_q_len - tp_rank * chunk_q_len

            # Chunks reaching into the padding are clipped to real tokens.
            q_len_1 = (
                chunk_q_len
                if q_1_end <= req_q_len
                else max(0, req_q_len - q_1_start)
            )
            q_len_2 = (
                chunk_q_len
                if q_2_end <= req_q_len
                else max(0, req_q_len - q_2_start)
            )

            # kv length seen by a chunk: the context before this step's
            # query tokens plus the query tokens up to the chunk's end.
            kv_len_1 = kv_len - req_q_len + min(req_q_len, q_1_end)
            kv_len_2 = kv_len - req_q_len + min(req_q_len, q_2_end)

            scatter_index1 = range(
                accu_q_start + q_1_start, accu_q_start + q_1_start + q_len_1
            )
            scatter_index2 = range(
                accu_q_start + q_2_start, accu_q_start + q_2_start + q_len_2
            )
            accu_q_start += req_q_len

            if q_len_1 > 0:
                q_lens.append(q_len_1)
                kv_lens.append(kv_len_1)
                seq_indexes.append(i)
                local_req_ids.append(req_ids[i])
                scatter_indexes.extend(scatter_index1)
                seq_count += 1
                rank_tokens += q_len_1

            if q_len_2 > 0:
                q_lens.append(q_len_2)
                kv_lens.append(kv_len_2)
                seq_indexes.append(i)
                local_req_ids.append(req_ids[i])
                scatter_indexes.extend(scatter_index2)
                seq_count += 1
                rank_tokens += q_len_2

            rank_pad_tokens += chunk_q_len * 2

        # Every rank must contribute the same number of entries to the
        # all-gather, so pad with -1 up to the padded per-rank count.
        if len(scatter_indexes) < rank_pad_tokens:
            scatter_indexes.extend([-1] * (rank_pad_tokens - len(scatter_indexes)))

        local_scatter_indexes_tensor = torch.tensor(
            scatter_indexes, dtype=torch.int64, device=self.device
        )
        global_scatter_indexes_tensor = tensor_model_parallel_all_gather(
            local_scatter_indexes_tensor.contiguous(), dim=0
        )

        # Positions of real tokens in the gathered layout, ordered by
        # their original token index.
        non_neg_mask = global_scatter_indexes_tensor != -1
        non_neg_values = global_scatter_indexes_tensor[non_neg_mask]
        non_neg_positions = torch.where(non_neg_mask)[0]
        sorted_indices = torch.argsort(non_neg_values)
        gather_indexes_tensor = non_neg_positions[sorted_indices]
    else:
        current_seq = 0
        current_pos = 0
        # A rank whose slice starts past the end of the batch owns no
        # tokens; clamp so the count never goes negative.
        rank_tokens = max(0, min(tokens_per_rank, end_token - start_token))

        while start_token < end_token and current_seq < len(q_lens_cpu):
            q_len = int(q_lens_cpu[current_seq])
            q_start = current_pos
            q_end = current_pos + q_len
            kv_len = int(kv_lens_cpu[current_seq])

            # Find overlap between this sequence and rank's token range
            overlap_start = max(start_token, q_start)
            overlap_end = min(end_token, q_end)

            if overlap_start < overlap_end:
                # This sequence contributes tokens to this rank
                token_count = overlap_end - overlap_start
                q_lens.append(token_count)
                start_token = overlap_end
                seq_count += 1
                seq_indexes.append(current_seq)
                local_req_ids.append(req_ids[current_seq])
                if q_end <= end_token:
                    kv_lens.append(kv_len)
                else:
                    # Request continues on the next rank: hide the query
                    # tokens that lie beyond this rank's slice.
                    kv_lens.append(kv_len - (q_end - end_token))

            current_pos = q_end
            current_seq += 1

    return (
        rank_tokens,
        np.array(q_lens, dtype=np.int32),
        seq_count,
        np.array(kv_lens, dtype=np.int32),
        np.array(local_req_ids, dtype=str),
        local_scatter_indexes_tensor,
        gather_indexes_tensor,
        seq_indexes,
    )
def _prepare_cp_metadata(
    self,
    num_reqs_padded,
    max_query_len,
    num_tokens,
    block_table_gid_0,
    slot_mapping_gid_0,
):
    """Build the ``CommonAttentionMetadata`` for this rank's token share.

    The batch-wide ``query_start_loc`` / ``seq_lens`` values are first
    saved in a ``CpCommonAttentionMetadata``. The runner's buffers are
    then overwritten in place with the per-rank values returned by
    ``_distribute_tokens_to_cp_ranks``.

    Args:
        num_reqs_padded: Number of (padded) requests in the batch.
        max_query_len: Longest query of the whole batch; stored only in
            the saved batch-wide metadata.
        num_tokens: Number of tokens in the whole batch. It is also used
            as ``num_kv_actual_tokens`` of the result.
        block_table_gid_0: Block table of KV-cache group 0, indexed by
            batch request index.
        slot_mapping_gid_0: Slot mapping of KV-cache group 0; passed
            through unchanged.

    Returns:
        A ``CommonAttentionMetadata`` describing the local requests, with
        the batch-wide snapshot in ``cp_common_metadata``.
    """
    tp_size = self.vllm_config.parallel_config.tensor_parallel_size
    tp_rank = get_tensor_model_parallel_rank()

    # Snapshot the batch-wide values; the buffers are rewritten below.
    cp_common_metadata = CpCommonAttentionMetadata(
        query_start_loc=self.query_start_loc.gpu[: num_reqs_padded + 1].clone(),
        query_start_loc_cpu=self.query_start_loc.cpu[: num_reqs_padded + 1].clone(),
        seq_lens=self.seq_lens.gpu[:num_reqs_padded].clone(),
        _seq_lens_cpu=self.seq_lens.cpu[:num_reqs_padded].clone(),
        max_query_len=max_query_len,
        num_reqs=num_reqs_padded,
        req_ids=self.input_batch.req_ids,
        num_actual_tokens=num_tokens,
    )

    query_start_loc_cpu = self.query_start_loc.cpu[: num_reqs_padded + 1]
    q_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
    kv_lens_cpu = self.seq_lens.cpu[:num_reqs_padded]
    total_kv_len = num_tokens

    (
        total_q_len,
        local_q_lens,
        num_reqs,
        local_kv_lens,
        _local_req_ids,
        scatter_indexes_tensor,
        gather_indexes_tensor,
        seq_indexes_list,
    ) = self._distribute_tokens_to_cp_ranks(
        num_tokens,
        q_lens_cpu,
        kv_lens_cpu,
        tp_rank,
        tp_size,
        self.input_batch.req_ids,
    )

    # A rank may receive no request at all; avoid indexing/maxing an
    # empty array in that case.
    has_local_reqs = num_reqs > 0

    # Rewrite query_start_loc with the local cumulative query lengths.
    # The code writes through .np and reads .cpu, so it presumably relies
    # on both aliasing the same host buffer -- TODO confirm.
    cu_num_tokens = np.cumsum(local_q_lens)
    last_offset = cu_num_tokens[-1] if has_local_reqs else 0
    self.query_start_loc.np[0] = 0
    self.query_start_loc.np[1 : num_reqs + 1] = cu_num_tokens
    self.query_start_loc.np[num_reqs + 1 :].fill(last_offset)
    self.query_start_loc.copy_to_gpu()
    q_acc_lens = self.query_start_loc.gpu[: num_reqs + 1]
    q_acc_lens_cpu = self.query_start_loc.cpu[: num_reqs + 1]
    # Longest local query, taken from the per-request lengths. The
    # cumulative offsets would give the total local token count instead.
    max_q_len = int(local_q_lens.max()) if has_local_reqs else 0

    # Rewrite seq_lens with the local kv lengths.
    self.seq_lens.np[:num_reqs] = local_kv_lens
    self.seq_lens.np[num_reqs:].fill(0)
    self.seq_lens.copy_to_gpu()
    kv_lens = self.seq_lens.gpu[:num_reqs]
    kv_lens_cpu = self.seq_lens.cpu[:num_reqs]
    max_kv_len = int(local_kv_lens.max()) if has_local_reqs else 0

    # Computed tokens = sequence length minus this request's own query
    # length (difference of consecutive offsets, not the raw offsets).
    local_q_lens_cpu = q_acc_lens_cpu[1:] - q_acc_lens_cpu[:-1]
    num_computed_tokens_cpu = kv_lens_cpu - local_q_lens_cpu

    # One block-table row per local request; a batch request split into
    # two local chunks appears twice in seq_indexes_list.
    blk_table_tensor = block_table_gid_0[seq_indexes_list]

    cm_base = CommonAttentionMetadata(
        query_start_loc=q_acc_lens,
        query_start_loc_cpu=q_acc_lens_cpu,
        seq_lens=kv_lens,
        _seq_lens_cpu=kv_lens_cpu,
        _num_computed_tokens_cpu=num_computed_tokens_cpu,
        num_reqs=num_reqs,
        num_actual_tokens=total_q_len,
        max_query_len=max_q_len,
        max_seq_len=max_kv_len,
        block_table_tensor=blk_table_tensor,
        slot_mapping=slot_mapping_gid_0,
        causal=True,
        num_kv_actual_tokens=total_kv_len,
        seq_indexes_list=seq_indexes_list,
        cp_common_metadata=cp_common_metadata,
        scatter_indexes_tensor=scatter_indexes_tensor,
        gather_indexes_tensor=gather_indexes_tensor,
    )
    return cm_base
def _prepare_inputs( def _prepare_inputs(
self, self,
scheduler_output: "SchedulerOutput", scheduler_output: "SchedulerOutput",
...@@ -1718,13 +1960,20 @@ class GPUModelRunner( ...@@ -1718,13 +1960,20 @@ class GPUModelRunner(
num_scheduled_tokens: dict[str, int] | None = None, num_scheduled_tokens: dict[str, int] | None = None,
cascade_attn_prefix_lens: list[list[int]] | None = None, cascade_attn_prefix_lens: list[list[int]] | None = None,
slot_mappings: dict[int, torch.Tensor] | None = None, slot_mappings: dict[int, torch.Tensor] | None = None,
) -> tuple[PerLayerAttnMetadata, CommonAttentionMetadata | None]: ) -> tuple[
PerLayerAttnMetadata,
CommonAttentionMetadata | None,
torch.Tensor | None,
torch.Tensor | None,
]:
""" """
:return: tuple[attn_metadata, spec_decode_common_attn_metadata] :return: tuple[attn_metadata, spec_decode_common_attn_metadata]
""" """
# Attention metadata is not needed for attention free models # Attention metadata is not needed for attention free models
if len(self.kv_cache_config.kv_cache_groups) == 0: if len(self.kv_cache_config.kv_cache_groups) == 0:
return {}, None return {}, None, None, None
tp_size = self.vllm_config.parallel_config.tensor_parallel_size
num_tokens_padded = num_tokens_padded or num_tokens num_tokens_padded = num_tokens_padded or num_tokens
num_reqs_padded = num_reqs_padded or num_reqs num_reqs_padded = num_reqs_padded or num_reqs
...@@ -1772,25 +2021,40 @@ class GPUModelRunner( ...@@ -1772,25 +2021,40 @@ class GPUModelRunner(
assert slot_mappings is not None assert slot_mappings is not None
block_table_gid_0 = _get_block_table(0) block_table_gid_0 = _get_block_table(0)
slot_mapping_gid_0 = slot_mappings[0] slot_mapping_gid_0 = slot_mappings[0]
scatter_indexes_tensor = None
gather_indexes_tensor = None
if self.model_config.enable_return_routed_experts: if self.model_config.enable_return_routed_experts:
self.slot_mapping = slot_mapping_gid_0[:num_tokens].cpu().numpy() self.slot_mapping = slot_mapping_gid_0[:num_tokens].cpu().numpy()
cm_base = CommonAttentionMetadata(
query_start_loc=self.query_start_loc.gpu[: num_reqs_padded + 1], if not envs.VLLM_MLA_CP or num_tokens <= tp_size * tp_size:
query_start_loc_cpu=self.query_start_loc.cpu[: num_reqs_padded + 1], cm_base = CommonAttentionMetadata(
seq_lens=self.seq_lens.gpu[:num_reqs_padded], query_start_loc=self.query_start_loc.gpu[: num_reqs_padded + 1],
_seq_lens_cpu=self.seq_lens.cpu[:num_reqs_padded], query_start_loc_cpu=self.query_start_loc.cpu[: num_reqs_padded + 1],
_num_computed_tokens_cpu=self.input_batch.num_computed_tokens_cpu_tensor[ seq_lens=self.seq_lens.gpu[:num_reqs_padded],
:num_reqs_padded _seq_lens_cpu=self.seq_lens.cpu[:num_reqs_padded],
], _num_computed_tokens_cpu=self.input_batch.num_computed_tokens_cpu_tensor[
num_reqs=num_reqs_padded, :num_reqs_padded
num_actual_tokens=num_tokens_padded, ],
max_query_len=max_query_len, num_reqs=num_reqs_padded,
max_seq_len=max_seq_len, num_actual_tokens=num_tokens_padded,
block_table_tensor=block_table_gid_0, num_kv_actual_tokens=num_tokens_padded,
slot_mapping=slot_mapping_gid_0, max_query_len=max_query_len,
causal=True, max_seq_len=max_seq_len,
) block_table_tensor=block_table_gid_0,
slot_mapping=slot_mapping_gid_0,
causal=True,
)
else:
cm_base = self._prepare_cp_metadata(
num_reqs_padded,
max_query_len,
num_tokens,
block_table_gid_0,
slot_mapping_gid_0,
)
scatter_indexes_tensor = cm_base.scatter_indexes_tensor
gather_indexes_tensor = cm_base.gather_indexes_tensor
if self.dcp_world_size > 1: if self.dcp_world_size > 1:
self.dcp_local_seq_lens.cpu[:num_reqs] = get_dcp_local_seq_lens( self.dcp_local_seq_lens.cpu[:num_reqs] = get_dcp_local_seq_lens(
...@@ -1901,6 +2165,9 @@ class GPUModelRunner( ...@@ -1901,6 +2165,9 @@ class GPUModelRunner(
cm.block_table_tensor = _get_block_table(kv_cache_gid) cm.block_table_tensor = _get_block_table(kv_cache_gid)
cm.slot_mapping = slot_mappings[kv_cache_gid] cm.slot_mapping = slot_mappings[kv_cache_gid]
if cm.seq_indexes_list is not None:
cm.block_table_tensor = cm.block_table_tensor[cm.seq_indexes_list]
if self.speculative_config and spec_decode_common_attn_metadata is None and hasattr(self, "drafter"): if self.speculative_config and spec_decode_common_attn_metadata is None and hasattr(self, "drafter"):
if isinstance(self.drafter, EagleProposer): if isinstance(self.drafter, EagleProposer):
if self.drafter.attn_layer_names[0] in kv_cache_group.layer_names: if self.drafter.attn_layer_names[0] in kv_cache_group.layer_names:
...@@ -1936,8 +2203,10 @@ class GPUModelRunner( ...@@ -1936,8 +2203,10 @@ class GPUModelRunner(
for _metadata in attn_metadata.values(): for _metadata in attn_metadata.values():
_metadata.mm_prefix_range = req_doc_ranges # type: ignore[attr-defined] _metadata.mm_prefix_range = req_doc_ranges # type: ignore[attr-defined]
if spec_decode_common_attn_metadata is not None and ( if (
num_reqs != num_reqs_padded or num_tokens != num_tokens_padded (not envs.VLLM_MLA_CP)
and spec_decode_common_attn_metadata is not None
and (num_reqs != num_reqs_padded or num_tokens != num_tokens_padded)
): ):
# Currently the drafter still only uses piecewise cudagraphs (and modifies # Currently the drafter still only uses piecewise cudagraphs (and modifies
# the attention metadata in directly), and therefore does not want to use # the attention metadata in directly), and therefore does not want to use
...@@ -1946,7 +2215,12 @@ class GPUModelRunner( ...@@ -1946,7 +2215,12 @@ class GPUModelRunner(
spec_decode_common_attn_metadata.unpadded(num_tokens, num_reqs) spec_decode_common_attn_metadata.unpadded(num_tokens, num_reqs)
) )
return attn_metadata, spec_decode_common_attn_metadata return (
attn_metadata,
spec_decode_common_attn_metadata,
scatter_indexes_tensor,
gather_indexes_tensor
)
def _compute_cascade_attn_prefix_lens( def _compute_cascade_attn_prefix_lens(
self, self,
...@@ -2798,9 +3072,19 @@ class GPUModelRunner( ...@@ -2798,9 +3072,19 @@ class GPUModelRunner(
return model_runner_output return model_runner_output
def _pad_for_mla_cp(self, num_scheduled_tokens: int) -> int:
    """Compute the padded token count for the MLA context-parallel path.

    The tensor-parallel size is read from
    ``self.vllm_config.parallel_config.tensor_parallel_size``.

    Args:
        num_scheduled_tokens: Number of tokens scheduled for this step.

    Returns:
        ``num_scheduled_tokens * tp_size`` when the batch holds at most
        ``tp_size * tp_size`` tokens; otherwise the result of
        ``round_up(num_scheduled_tokens, tp_size)``.
    """
    world = self.vllm_config.parallel_config.tensor_parallel_size
    small_batch_limit = world * world
    # Larger batches go through round_up; small ones are scaled instead.
    if num_scheduled_tokens > small_batch_limit:
        return round_up(num_scheduled_tokens, world)
    return num_scheduled_tokens * world
def _pad_for_sequence_parallelism(self, num_scheduled_tokens: int) -> int: def _pad_for_sequence_parallelism(self, num_scheduled_tokens: int) -> int:
# Pad tokens to multiple of tensor_parallel_size when # Pad tokens to multiple of tensor_parallel_size when
# enabled collective fusion for SP # enabled collective fusion for SP
if envs.VLLM_MLA_CP:
return self._pad_for_mla_cp(num_scheduled_tokens)
tp_size = self.vllm_config.parallel_config.tensor_parallel_size tp_size = self.vllm_config.parallel_config.tensor_parallel_size
if self.compilation_config.pass_config.enable_sp and tp_size > 1: if self.compilation_config.pass_config.enable_sp and tp_size > 1:
return round_up(num_scheduled_tokens, tp_size) return round_up(num_scheduled_tokens, tp_size)
...@@ -3497,6 +3781,8 @@ class GPUModelRunner( ...@@ -3497,6 +3781,8 @@ class GPUModelRunner(
) )
num_tokens_padded = batch_desc.num_tokens num_tokens_padded = batch_desc.num_tokens
if envs.VLLM_MLA_CP:
num_tokens_padded = self._pad_for_mla_cp(num_tokens_unpadded)
num_reqs_padded = ( num_reqs_padded = (
batch_desc.num_reqs if batch_desc.num_reqs is not None else num_reqs batch_desc.num_reqs if batch_desc.num_reqs is not None else num_reqs
) )
...@@ -3553,8 +3839,12 @@ class GPUModelRunner( ...@@ -3553,8 +3839,12 @@ class GPUModelRunner(
ubatch_slices=ubatch_slices_padded, ubatch_slices=ubatch_slices_padded,
) )
attn_metadata, spec_decode_common_attn_metadata = ( (
self._build_attention_metadata( attn_metadata,
spec_decode_common_attn_metadata,
scatter_indexes_tensor,
gather_indexes_tensor,
) = self._build_attention_metadata(
num_tokens=num_tokens_unpadded, num_tokens=num_tokens_unpadded,
num_tokens_padded=num_tokens_padded if pad_attn else None, num_tokens_padded=num_tokens_padded if pad_attn else None,
num_reqs=num_reqs, num_reqs=num_reqs,
...@@ -3567,7 +3857,6 @@ class GPUModelRunner( ...@@ -3567,7 +3857,6 @@ class GPUModelRunner(
cascade_attn_prefix_lens=cascade_attn_prefix_lens, cascade_attn_prefix_lens=cascade_attn_prefix_lens,
slot_mappings=slot_mappings_by_group, slot_mappings=slot_mappings_by_group,
) )
)
( (
input_ids, input_ids,
...@@ -3608,6 +3897,8 @@ class GPUModelRunner( ...@@ -3608,6 +3897,8 @@ class GPUModelRunner(
ubatch_slices=ubatch_slices_padded, ubatch_slices=ubatch_slices_padded,
slot_mapping=slot_mappings, slot_mapping=slot_mappings,
skip_compiled=has_encoder_input, skip_compiled=has_encoder_input,
scatter_indexes_tensor=scatter_indexes_tensor,
gather_indexes_tensor=gather_indexes_tensor,
), ),
record_function_or_nullcontext("gpu_model_runner: forward"), record_function_or_nullcontext("gpu_model_runner: forward"),
self.maybe_get_kv_connector_output(scheduler_output) as kv_connector_output, self.maybe_get_kv_connector_output(scheduler_output) as kv_connector_output,
...@@ -4094,7 +4385,16 @@ class GPUModelRunner( ...@@ -4094,7 +4385,16 @@ class GPUModelRunner(
spec_decode_metadata, spec_decode_metadata,
valid_sampled_tokens_count, valid_sampled_tokens_count,
) )
total_num_tokens = common_attn_metadata.num_actual_tokens #total_num_tokens = common_attn_metadata.num_actual_tokens
if (
envs.VLLM_MLA_CP
and common_attn_metadata.cp_common_metadata is not None
):
total_num_tokens = (
common_attn_metadata.cp_common_metadata.num_actual_tokens
)
else:
total_num_tokens = common_attn_metadata.num_actual_tokens
# When padding the batch, token_indices is just a range # When padding the batch, token_indices is just a range
target_token_ids = self.input_ids.gpu[:total_num_tokens] target_token_ids = self.input_ids.gpu[:total_num_tokens]
target_positions = self._get_positions(total_num_tokens) target_positions = self._get_positions(total_num_tokens)
...@@ -4618,6 +4918,9 @@ class GPUModelRunner( ...@@ -4618,6 +4918,9 @@ class GPUModelRunner(
or cudagraph_runtime_mode.valid_runtime_modes() or cudagraph_runtime_mode.valid_runtime_modes()
) )
if envs.VLLM_MLA_CP:
num_tokens = max(self.tp_size, num_tokens)
# If cudagraph_mode.decode_mode() == FULL and # If cudagraph_mode.decode_mode() == FULL and
# cudagraph_mode.separate_routine(). This means that we are using # cudagraph_mode.separate_routine(). This means that we are using
# different graphs and/or modes for mixed prefill-decode batches vs. # different graphs and/or modes for mixed prefill-decode batches vs.
...@@ -4748,7 +5051,7 @@ class GPUModelRunner( ...@@ -4748,7 +5051,7 @@ class GPUModelRunner(
self.query_start_loc.copy_to_gpu() self.query_start_loc.copy_to_gpu()
pad_attn = cudagraph_runtime_mode == CUDAGraphMode.FULL pad_attn = cudagraph_runtime_mode == CUDAGraphMode.FULL
attn_metadata, _ = self._build_attention_metadata( attn_metadata, _, _, _ = self._build_attention_metadata(
num_tokens=num_tokens_unpadded, num_tokens=num_tokens_unpadded,
num_reqs=num_reqs_padded, num_reqs=num_reqs_padded,
max_query_len=max_query_len, max_query_len=max_query_len,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment