Commit 26645e58 authored by 王敏's avatar 王敏
Browse files

[feat]基于mla sp实现pcp

parent d1fd831b
......@@ -1500,6 +1500,20 @@ class EngineArgs:
data_parallel_external_lb = (
self.data_parallel_external_lb or self.data_parallel_rank is not None
)
if (
envs.VLLM_MLA_CP
and self.max_num_batched_tokens is not None
and self.max_num_batched_tokens < self.tensor_parallel_size**3
):
raise ValueError(
"max_num_batched_tokens should be larger than "
"tensor_parallel_size ** 3 when enabled VLLM_MLA_CP"
)
logger.info("[MLACP] VLLM_MLA_CP is %s", envs.VLLM_MLA_CP)
logger.info("[MLACP] VLLM_MLA_CPLB is %s", envs.VLLM_MLA_CPLB)
# Local DP rank = 1, use pure-external LB.
if data_parallel_external_lb:
assert self.data_parallel_rank is not None, (
......
......@@ -240,6 +240,9 @@ class ForwardContext:
additional_kwargs: dict[str, Any] = field(default_factory=dict)
scatter_indexes_tensor: torch.Tensor | None = None
gather_indexes_tensor: torch.Tensor | None = None
def __post_init__(self):
assert self.cudagraph_runtime_mode.valid_runtime_modes(), (
f"Invalid cudagraph runtime mode: {self.cudagraph_runtime_mode}"
......@@ -273,6 +276,8 @@ def create_forward_context(
slot_mapping: dict[str, torch.Tensor] | None = None,
additional_kwargs: dict[str, Any] | None = None,
skip_compiled: bool = False,
scatter_indexes_tensor: torch.Tensor | None = None,
gather_indexes_tensor: torch.Tensor | None = None,
):
if vllm_config.compilation_config.fast_moe_cold_start:
if vllm_config.speculative_config is None:
......@@ -298,6 +303,8 @@ def create_forward_context(
batch_descriptor=batch_descriptor,
ubatch_slices=ubatch_slices,
skip_compiled=skip_compiled,
scatter_indexes_tensor=scatter_indexes_tensor,
gather_indexes_tensor=gather_indexes_tensor,
additional_kwargs=additional_kwargs or {},
)
......@@ -329,6 +336,8 @@ def set_forward_context(
ubatch_slices: UBatchSlices | None = None,
slot_mapping: dict[str, torch.Tensor] | list[dict[str, torch.Tensor]] | None = None,
skip_compiled: bool = False,
scatter_indexes_tensor: torch.Tensor | None = None,
gather_indexes_tensor: torch.Tensor | None = None,
):
"""A context manager that stores the current forward context,
can be attention metadata, etc.
......@@ -389,6 +398,8 @@ def set_forward_context(
slot_mapping,
additional_kwargs,
skip_compiled,
scatter_indexes_tensor,
gather_indexes_tensor,
)
try:
......
......@@ -9,6 +9,9 @@ from vllm.config import CacheConfig
import vllm.envs as envs
from vllm.model_executor.custom_op import PluggableLayer
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.distributed import (
tensor_model_parallel_all_gather,
)
@dataclass
......@@ -183,8 +186,19 @@ class MultiHeadLatentAttentionWrapper(PluggableLayer):
if llama_4_scaling is not None:
q *= llama_4_scaling
enable_mla_cp = envs.VLLM_MLA_CP # and not get_forward_context().draft_model
# if not use_fused_rms_rope_concat:
if not envs.VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT:
if enable_mla_cp:
kv_c_normed = tensor_model_parallel_all_gather(
kv_c_normed.contiguous(), 0
)
k_pe = tensor_model_parallel_all_gather(
k_pe.contiguous(), 0
)
attn_out = self.mla_attn(
q,
kv_c_normed,
......@@ -220,6 +234,15 @@ class MultiHeadLatentAttentionWrapper(PluggableLayer):
"VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT requires rotary_emb to "
"expose 'cos_sin_cache'."
)
if enable_mla_cp:
kv_c = tensor_model_parallel_all_gather(
kv_c.contiguous(), 0
)
k_pe = tensor_model_parallel_all_gather(
k_pe.contiguous(), 0
)
attn_out = self.mla_attn(
q[..., self.qk_nope_head_dim:],
kv_c,
......
......@@ -71,7 +71,7 @@ def sparse_attn_indexer(
)
attn_metadata = attn_metadata[k_cache_prefix]
assert isinstance(attn_metadata, DeepseekV32IndexerMetadata)
slot_mapping = attn_metadata.slot_mapping
slot_mapping = attn_metadata.slot_mapping[:attn_metadata.num_kv_actual_tokens]
has_decode = attn_metadata.num_decodes > 0
has_prefill = attn_metadata.num_prefills > 0
num_decode_tokens = attn_metadata.num_decode_tokens
......
......@@ -46,6 +46,7 @@ from vllm.distributed import (
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather,
)
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce, tensor_model_parallel_reduce_scatter
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
......@@ -211,10 +212,82 @@ class DeepseekV2MLP(nn.Module):
hidden_size,
bias=False,
quant_config=quant_config,
reduce_results=reduce_results,
#reduce_results=reduce_results,
reduce_results=False,
disable_tp=is_sequence_parallel,
prefix=f"{prefix}.down_proj",
)
self.tp_size = get_tensor_model_parallel_world_size()
if hidden_act != "silu":
raise ValueError(
f"Unsupported activation: {hidden_act}. Only silu is supported for now."
)
self.act_fn = SiluAndMul()
def forward(self,
x,
*, iqis: tuple[torch.Tensor, torch.Tensor] | None = None
):
enable_mla_cp = envs.VLLM_MLA_CP# and not get_forward_context().draft_model
if enable_mla_cp:
x = tensor_model_parallel_all_gather(
x.contiguous(), 0
)
if envs.USE_FUSED_RMS_QUANT:
gate_up, _ = self.gate_up_proj(x, iqis=iqis)
if envs.USE_FUSED_SILU_MUL_QUANT:
from lmslim.quantize.quant_ops import lm_fuse_silu_mul_quant
xq, xs = lm_fuse_silu_mul_quant(gate_up)
x, _ = self.down_proj(gate_up, iqis=(xq, xs))
else:
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
else:
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
if enable_mla_cp:
x = tensor_model_parallel_reduce_scatter(x.contiguous(), dim=0)
elif self.tp_size > 1:
x = tensor_model_parallel_all_reduce(x)
return x
class DeepseekV2SharedMLP(nn.Module):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
hidden_act: str,
quant_config: QuantizationConfig | None = None,
reduce_results: bool = True,
is_sequence_parallel=False,
prefix: str = "",
) -> None:
super().__init__()
# If is_sequence_parallel, the input and output tensors are sharded
# across the ranks within the tp_group. In this case the weights are
# replicated and no collective ops are needed.
# Otherwise we use standard TP with an allreduce at the end.
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size,
[intermediate_size] * 2,
bias=False,
quant_config=quant_config,
disable_tp=is_sequence_parallel,
prefix=f"{prefix}.gate_up_proj",
)
self.down_proj = RowParallelLinear(
intermediate_size,
hidden_size,
bias=False,
quant_config=quant_config,
reduce_results=reduce_results,
disable_tp=is_sequence_parallel,
prefix=f"{prefix}.down_proj"
)
if hidden_act != "silu":
raise ValueError(
f"Unsupported activation: {hidden_act}. Only silu is supported for now."
......@@ -311,7 +384,7 @@ class DeepseekV2MoE(nn.Module):
else:
intermediate_size = config.moe_intermediate_size * config.n_shared_experts
self.shared_experts = DeepseekV2MLP(
self.shared_experts = DeepseekV2SharedMLP(
hidden_size=config.hidden_size,
intermediate_size=intermediate_size,
hidden_act=config.hidden_act,
......@@ -357,6 +430,11 @@ class DeepseekV2MoE(nn.Module):
def forward(self, hidden_states: torch.Tensor,
*, iqis: tuple[torch.Tensor, torch.Tensor] | None = None
) -> torch.Tensor:
enable_mla_cp = envs.VLLM_MLA_CP #and not get_forward_context().draft_model
if enable_mla_cp:
hidden_states = tensor_model_parallel_all_gather(
hidden_states.contiguous(), 0
)
num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
......@@ -428,7 +506,12 @@ class DeepseekV2MoE(nn.Module):
assert shared_output is not None
final_hidden_states += shared_output
if self.is_sequence_parallel:
if enable_mla_cp:
final_hidden_states = tensor_model_parallel_reduce_scatter(
final_hidden_states.contiguous(), 0
)
return final_hidden_states
elif self.is_sequence_parallel:
final_hidden_states = tensor_model_parallel_all_gather(
final_hidden_states, 0
)
......@@ -756,6 +839,12 @@ class Indexer(nn.Module):
# `k_pe` is [num_tokens, 1, rope_dim] (MQA).
k = torch.cat([k_pe.squeeze(-2), k_nope], dim=-1)
enable_mla_cp = envs.VLLM_MLA_CP # and not get_forward_context().draft_model
if enable_mla_cp:
k = tensor_model_parallel_all_gather(
k.contiguous(), 0
)
# we only quant q here since k quant is fused with cache insertion
if not current_platform.is_rocm() or torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938":
q = q.view(-1, self.head_dim)
......@@ -819,7 +908,8 @@ class DeepseekV2MLAAttention(nn.Module):
self.num_heads = num_heads
tp_size = get_tensor_model_parallel_world_size()
assert num_heads % tp_size == 0
self.num_local_heads = num_heads // tp_size
#self.num_local_heads = num_heads // tp_size
self.num_local_heads = num_heads // tp_size if not envs.VLLM_MLA_CP else self.num_heads
self.scaling = self.qk_head_dim**-0.5
self.max_position_embeddings = max_position_embeddings
......@@ -853,6 +943,7 @@ class DeepseekV2MLAAttention(nn.Module):
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.q_b_proj",
disable_tp=envs.VLLM_MLA_CP,
)
else:
self.q_proj = ColumnParallelLinear(
......@@ -861,6 +952,7 @@ class DeepseekV2MLAAttention(nn.Module):
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.q_proj",
disable_tp=envs.VLLM_MLA_CP,
)
self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
self.kv_b_proj = ColumnParallelLinear(
......@@ -869,6 +961,7 @@ class DeepseekV2MLAAttention(nn.Module):
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.kv_b_proj",
disable_tp=envs.VLLM_MLA_CP,
)
self.o_proj = RowParallelLinear(
self.num_heads * self.v_head_dim,
......@@ -876,6 +969,7 @@ class DeepseekV2MLAAttention(nn.Module):
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
disable_tp=envs.VLLM_MLA_CP,
)
if config.rope_parameters["rope_type"] != "default":
......@@ -1217,6 +1311,9 @@ class DeepseekV2Model(nn.Module):
self.config = config
self.device = current_platform.device_type
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
self.vocab_size = config.vocab_size
#添加判断,默认开启DSA
force_disable_dsa = os.environ.get("VLLM_DISABLE_DSA", "0") == "1"
......@@ -1279,6 +1376,19 @@ class DeepseekV2Model(nn.Module):
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
enable_mla_cp = envs.VLLM_MLA_CP # and not get_forward_context().draft_model
if enable_mla_cp:
hidden_states_per_rank = torch.chunk(hidden_states, chunks=self.tp_size, dim=0)
hidden_states = hidden_states_per_rank[self.tp_rank].contiguous()
if residual is not None:
residual_per_rank = torch.chunk(residual, chunks=self.tp_size, dim=0)
residual = residual_per_rank[self.tp_rank].contiguous()
if positions is not None:
positions_per_rank = torch.chunk(positions, chunks=self.tp_size, dim=0)
positions = positions_per_rank[self.tp_rank].contiguous()
# Compute llama 4 scaling once per forward pass if enabled
llama_4_scaling_config = getattr(self.config, "llama_4_scaling", None)
llama_4_scaling: torch.Tensor | None
......@@ -1304,6 +1414,10 @@ class DeepseekV2Model(nn.Module):
)
hidden_states, _ = self.norm(hidden_states, residual)
if enable_mla_cp:
hidden_states = tensor_model_parallel_all_gather(hidden_states.contiguous(), dim=0)
return hidden_states
......
......@@ -285,6 +285,18 @@ class AttentionMetadata:
T = TypeVar("T", bound=AttentionMetadata)
@dataclass
class CpCommonAttentionMetadata:
# sp related metadata
query_start_loc: torch.Tensor
query_start_loc_cpu: torch.Tensor
seq_lens: torch.Tensor
_seq_lens_cpu: torch.Tensor
num_actual_tokens: int
max_query_len: int
num_reqs: int
req_ids: list[str]
@dataclass
class CommonAttentionMetadata:
......@@ -306,6 +318,7 @@ class CommonAttentionMetadata:
"""Number of requests"""
# TODO(lucas): rename to num_tokens since it may be padded and this is misleading
num_actual_tokens: int
"""Total number of tokens in batch"""
max_query_len: int
"""Longest query in batch"""
......@@ -315,6 +328,14 @@ class CommonAttentionMetadata:
block_table_tensor: torch.Tensor
slot_mapping: torch.Tensor
num_kv_actual_tokens: int
seq_indexes_list: list[int] | None = None
scatter_indexes_tensor: torch.Tensor | None = None
gather_indexes_tensor: torch.Tensor | None = None
cp_common_metadata: CpCommonAttentionMetadata | None = None
enable_mla_cp: bool = False
causal: bool = True
# Needed by FastPrefillAttentionBuilder
......
......@@ -138,6 +138,7 @@ class FlashMLASparseMetadata(AttentionMetadata):
max_seq_len: int
num_actual_tokens: int # Number of tokens excluding padding.
num_kv_actual_tokens: int
query_start_loc: torch.Tensor
slot_mapping: torch.Tensor
......@@ -693,6 +694,7 @@ class FlashMLASparseMetadataBuilder(AttentionMetadataBuilder[FlashMLASparseMetad
max_query_len=cm.max_query_len,
max_seq_len=cm.max_seq_len,
num_actual_tokens=cm.num_actual_tokens,
num_kv_actual_tokens=cm.num_kv_actual_tokens,
query_start_loc=cm.query_start_loc,
slot_mapping=cm.slot_mapping,
block_table=cm.block_table_tensor,
......@@ -1024,12 +1026,13 @@ class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]):
return output.fill_(0)
num_actual_toks = attn_metadata.num_actual_tokens
num_kv_actual_toks = attn_metadata.num_kv_actual_tokens
# Inputs and outputs may be padded for CUDA graphs
q = q[:num_actual_toks, ...]
k_c_normed = k_c_normed[:num_actual_toks, ...]
k_pe = k_pe[:num_actual_toks, ...]
k_c_normed = k_c_normed[:num_kv_actual_toks, ...]
k_pe = k_pe[:num_kv_actual_toks, ...]
assert self.topk_indices_buffer is not None
topk_indices = self.topk_indices_buffer[:num_actual_toks]
......
......@@ -105,6 +105,7 @@ class DeepseekV32IndexerMetadata:
max_seq_len: int
num_actual_tokens: int # Number of tokens excluding padding.
num_kv_actual_tokens: int
query_start_loc: torch.Tensor
slot_mapping: torch.Tensor
# The dimension of the attention heads
......@@ -437,6 +438,7 @@ class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder):
max_query_len=common_attn_metadata.max_query_len,
max_seq_len=common_attn_metadata.max_seq_len,
num_actual_tokens=common_attn_metadata.num_actual_tokens,
num_kv_actual_tokens=common_attn_metadata.num_kv_actual_tokens,
query_start_loc=common_attn_metadata.query_start_loc,
slot_mapping=common_attn_metadata.slot_mapping,
head_dim=128,
......
......@@ -802,6 +802,7 @@ class SpecDecodeBaseProposer:
_num_computed_tokens_cpu=common_attn_metadata._num_computed_tokens_cpu,
num_reqs=common_attn_metadata.num_reqs,
num_actual_tokens=total_num_tokens,
num_kv_actual_tokens=total_num_tokens,
max_query_len=new_query_len_per_req.max().item(),
max_seq_len=common_attn_metadata.seq_lens_cpu.max().item(),
block_table_tensor=common_attn_metadata.block_table_tensor,
......
......@@ -234,6 +234,10 @@ class BlockTable:
"""Returns the device tensor of the block table."""
return self.block_table.gpu[:num_reqs]
def get_device_tensor_range(self, start_req: int, end_req: int) -> torch.Tensor:
"""Returns the device tensor of the block table."""
return self.block_table.gpu[start_req:end_req]
def get_cpu_tensor(self) -> torch.Tensor:
"""Returns the CPU tensor of the block table."""
return self.block_table.cpu
......
......@@ -42,8 +42,13 @@ from vllm.distributed.parallel_state import (
get_tp_group,
graph_capture,
is_global_first_rank,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
prepare_communication_buffer_for_model,
)
from vllm.distributed import (
tensor_model_parallel_all_gather
)
from vllm.forward_context import (
BatchDescriptor,
set_forward_context,
......@@ -104,6 +109,7 @@ from vllm.v1.attention.backend import (
AttentionMetadataBuilder,
AttentionType,
CommonAttentionMetadata,
CpCommonAttentionMetadata,
MultipleOf,
)
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder
......@@ -371,10 +377,16 @@ class GPUModelRunner(
# Always set to false after the first forward pass
self.calculate_kv_scales = self.cache_config.calculate_kv_scales
self.tp_size = self.parallel_config.tensor_parallel_size
self.dcp_world_size = self.parallel_config.decode_context_parallel_size
self.dcp_rank = 0 if self.dcp_world_size <= 1 else get_dcp_group().rank_in_group
self.max_num_tokens = scheduler_config.max_num_batched_tokens
self.max_num_reqs = scheduler_config.max_num_seqs
#self.max_num_reqs = scheduler_config.max_num_seqs
self.max_num_reqs = (
scheduler_config.max_num_seqs
if not envs.VLLM_MLA_CPLB
else scheduler_config.max_num_seqs * 2
)
# Broadcast PP output for external_launcher (torchrun)
# to make sure we are synced across pp ranks
......@@ -1485,6 +1497,236 @@ class GPUModelRunner(
return encoder_seq_lens, encoder_seq_lens_cpu
def _distribute_tokens_to_cp_ranks(
self,
total_q_len: int,
q_lens_cpu: np.ndarray,
kv_lens_cpu: np.ndarray,
tp_rank: int,
tp_size: int,
req_ids: list[str],
):
tokens_per_rank = (total_q_len + tp_size - 1) // tp_size
start_token = tp_rank * tokens_per_rank
end_token = min((tp_rank + 1) * tokens_per_rank, total_q_len)
q_lens = []
seq_count = 0
seq_indexes = []
kv_lens = []
local_req_ids = []
local_scatter_indexes_tensor = None
gather_indexes_tensor = None
if envs.VLLM_MLA_CPLB:
rank_tokens = 0
rank_pad_tokens = 0
accu_q_start = 0
scatter_indexes: list[int] = []
num_requests = len(q_lens_cpu)
for i in range(num_requests):
req_q_len = q_lens_cpu[i]
req_pad_q_len = round_up(q_lens_cpu[i], 2 * tp_size)
kv_len = kv_lens_cpu[i]
chunk_q_len = req_pad_q_len // (2 * tp_size)
q_1_start = tp_rank * chunk_q_len
q_1_end = (tp_rank + 1) * chunk_q_len
q_2_start = req_pad_q_len - (tp_rank + 1) * chunk_q_len
q_2_end = req_pad_q_len - tp_rank * chunk_q_len
q_len_1 = (
chunk_q_len
if q_1_end <= req_q_len
else max(0, req_q_len - q_1_start)
)
q_len_2 = (
chunk_q_len
if q_2_end <= req_q_len
else max(0, req_q_len - q_2_start)
)
kv_len_1 = kv_len - req_q_len + min(req_q_len, q_1_end)
kv_len_2 = kv_len - req_q_len + min(req_q_len, q_2_end)
scatter_index1 = range(
accu_q_start + q_1_start, accu_q_start + q_1_start + q_len_1
)
scatter_index2 = range(
accu_q_start + q_2_start, accu_q_start + q_2_start + q_len_2
)
accu_q_start += req_q_len
if q_len_1 > 0:
q_lens.append(q_len_1)
kv_lens.append(kv_len_1)
seq_indexes.append(i)
local_req_ids.append(req_ids[i])
scatter_indexes.extend(scatter_index1)
seq_count += 1
rank_tokens += q_len_1
if q_len_2 > 0:
q_lens.append(q_len_2)
kv_lens.append(kv_len_2)
seq_indexes.append(i)
local_req_ids.append(req_ids[i])
scatter_indexes.extend(scatter_index2)
seq_count += 1
rank_tokens += q_len_2
rank_pad_tokens += chunk_q_len * 2
if len(scatter_indexes) < rank_pad_tokens:
scatter_indexes.extend([-1] * (rank_pad_tokens - len(scatter_indexes)))
local_scatter_indexes_tensor = torch.tensor(
scatter_indexes, dtype=torch.int64, device=self.device
)
global_scatter_indexes_tensor = tensor_model_parallel_all_gather(
local_scatter_indexes_tensor.contiguous(), dim=0
)
non_neg_mask = global_scatter_indexes_tensor != -1
non_neg_values = global_scatter_indexes_tensor[non_neg_mask]
non_neg_positions = torch.where(non_neg_mask)[0]
sorted_indices = torch.argsort(non_neg_values)
gather_indexes_tensor = non_neg_positions[sorted_indices]
if isinstance(rank_tokens, torch.Tensor):
rank_tokens = rank_tokens.item()
else:
current_seq = 0
current_pos = 0
rank_tokens = min(tokens_per_rank, end_token - start_token)
while start_token < end_token and current_seq < len(q_lens_cpu):
q_len = q_lens_cpu[current_seq]
q_start = current_pos
q_end = current_pos + q_len
kv_len = kv_lens_cpu[current_seq]
# Find overlap between this sequence and rank's token range
overlap_start = max(start_token, q_start)
overlap_end = min(end_token, q_end)
if overlap_start < overlap_end:
# This sequence contributes tokens to this rank
token_count = overlap_end - overlap_start
q_lens.append(token_count)
start_token = overlap_end
seq_count += 1
seq_indexes.append(current_seq)
local_req_ids.append(req_ids[current_seq])
if q_end <= end_token:
kv_lens.append(kv_len)
else:
kv_lens.append(kv_len - (q_end - end_token))
current_pos = q_end
current_seq += 1
return (
rank_tokens,
np.array(q_lens, dtype=np.int32),
seq_count,
np.array(kv_lens, dtype=np.int32),
np.array(local_req_ids, dtype=str),
local_scatter_indexes_tensor,
gather_indexes_tensor,
seq_indexes,
)
def _prepare_cp_metadata(
self,
num_reqs_padded,
max_query_len,
num_tokens,
block_table_gid_0,
slot_mapping_gid_0,
):
tp_size = self.vllm_config.parallel_config.tensor_parallel_size
tp_rank = get_tensor_model_parallel_rank()
cp_common_metadata = CpCommonAttentionMetadata(
query_start_loc=self.query_start_loc.gpu[: num_reqs_padded + 1].clone(),
query_start_loc_cpu=self.query_start_loc.cpu[: num_reqs_padded + 1].clone(),
seq_lens=self.seq_lens.gpu[:num_reqs_padded].clone(),
_seq_lens_cpu=self.seq_lens.cpu[:num_reqs_padded].clone(),
max_query_len=max_query_len,
num_reqs=num_reqs_padded,
req_ids=self.input_batch.req_ids,
num_actual_tokens=num_tokens,
)
query_start_loc_cpu = self.query_start_loc.cpu[: num_reqs_padded + 1]
q_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
kv_lens_cpu = self.seq_lens.cpu[:num_reqs_padded]
total_q_len = num_tokens
total_kv_len = num_tokens
(
total_q_len,
q_lens_cpu,
seq_count,
kv_lens_cpu,
local_req_ids,
scatter_indexes_tensor,
gather_indexes_tensor,
seq_indexes_list,
) = self._distribute_tokens_to_cp_ranks(
total_q_len,
q_lens_cpu,
kv_lens_cpu,
tp_rank,
tp_size,
self.input_batch.req_ids,
)
num_reqs = seq_count
cu_num_tokens = np.cumsum(q_lens_cpu)
self.query_start_loc.np[0] = 0
self.query_start_loc.np[1 : num_reqs + 1] = cu_num_tokens
self.query_start_loc.np[num_reqs + 1 :].fill(cu_num_tokens[-1])
self.query_start_loc.copy_to_gpu()
q_acc_lens = self.query_start_loc.gpu[: num_reqs + 1]
q_acc_lens_cpu = self.query_start_loc.cpu[: num_reqs + 1]
max_q_len = max(q_acc_lens_cpu)
self.seq_lens.np[:num_reqs] = kv_lens_cpu
self.seq_lens.np[num_reqs:].fill(0)
self.seq_lens.copy_to_gpu()
kv_lens = self.seq_lens.gpu[:num_reqs]
kv_lens_cpu = self.seq_lens.cpu[:num_reqs]
max_kv_len = max(kv_lens_cpu)
num_computed_tokens_cpu = kv_lens_cpu - q_acc_lens_cpu[1:]
blk_table_tensor = block_table_gid_0[seq_indexes_list]
cm_base = CommonAttentionMetadata(
query_start_loc=q_acc_lens,
query_start_loc_cpu=q_acc_lens_cpu,
seq_lens=kv_lens,
_seq_lens_cpu=kv_lens_cpu,
_num_computed_tokens_cpu=num_computed_tokens_cpu,
num_reqs=num_reqs,
num_actual_tokens=total_q_len,
max_query_len=max_q_len,
max_seq_len=max_kv_len,
block_table_tensor=blk_table_tensor,
slot_mapping=slot_mapping_gid_0,
causal=True,
num_kv_actual_tokens=total_kv_len,
seq_indexes_list=seq_indexes_list,
cp_common_metadata=cp_common_metadata,
scatter_indexes_tensor=scatter_indexes_tensor,
gather_indexes_tensor=gather_indexes_tensor,
)
return cm_base
def _prepare_inputs(
self,
scheduler_output: "SchedulerOutput",
......@@ -1718,13 +1960,20 @@ class GPUModelRunner(
num_scheduled_tokens: dict[str, int] | None = None,
cascade_attn_prefix_lens: list[list[int]] | None = None,
slot_mappings: dict[int, torch.Tensor] | None = None,
) -> tuple[PerLayerAttnMetadata, CommonAttentionMetadata | None]:
) -> tuple[
PerLayerAttnMetadata,
CommonAttentionMetadata | None,
torch.Tensor | None,
torch.Tensor | None,
]:
"""
:return: tuple[attn_metadata, spec_decode_common_attn_metadata]
"""
# Attention metadata is not needed for attention free models
if len(self.kv_cache_config.kv_cache_groups) == 0:
return {}, None
return {}, None, None, None
tp_size = self.vllm_config.parallel_config.tensor_parallel_size
num_tokens_padded = num_tokens_padded or num_tokens
num_reqs_padded = num_reqs_padded or num_reqs
......@@ -1772,9 +2021,13 @@ class GPUModelRunner(
assert slot_mappings is not None
block_table_gid_0 = _get_block_table(0)
slot_mapping_gid_0 = slot_mappings[0]
scatter_indexes_tensor = None
gather_indexes_tensor = None
if self.model_config.enable_return_routed_experts:
self.slot_mapping = slot_mapping_gid_0[:num_tokens].cpu().numpy()
if not envs.VLLM_MLA_CP or num_tokens <= tp_size * tp_size:
cm_base = CommonAttentionMetadata(
query_start_loc=self.query_start_loc.gpu[: num_reqs_padded + 1],
query_start_loc_cpu=self.query_start_loc.cpu[: num_reqs_padded + 1],
......@@ -1785,12 +2038,23 @@ class GPUModelRunner(
],
num_reqs=num_reqs_padded,
num_actual_tokens=num_tokens_padded,
num_kv_actual_tokens=num_tokens_padded,
max_query_len=max_query_len,
max_seq_len=max_seq_len,
block_table_tensor=block_table_gid_0,
slot_mapping=slot_mapping_gid_0,
causal=True,
)
else:
cm_base = self._prepare_cp_metadata(
num_reqs_padded,
max_query_len,
num_tokens,
block_table_gid_0,
slot_mapping_gid_0,
)
scatter_indexes_tensor = cm_base.scatter_indexes_tensor
gather_indexes_tensor = cm_base.gather_indexes_tensor
if self.dcp_world_size > 1:
self.dcp_local_seq_lens.cpu[:num_reqs] = get_dcp_local_seq_lens(
......@@ -1901,6 +2165,9 @@ class GPUModelRunner(
cm.block_table_tensor = _get_block_table(kv_cache_gid)
cm.slot_mapping = slot_mappings[kv_cache_gid]
if cm.seq_indexes_list is not None:
cm.block_table_tensor = cm.block_table_tensor[cm.seq_indexes_list]
if self.speculative_config and spec_decode_common_attn_metadata is None and hasattr(self, "drafter"):
if isinstance(self.drafter, EagleProposer):
if self.drafter.attn_layer_names[0] in kv_cache_group.layer_names:
......@@ -1936,8 +2203,10 @@ class GPUModelRunner(
for _metadata in attn_metadata.values():
_metadata.mm_prefix_range = req_doc_ranges # type: ignore[attr-defined]
if spec_decode_common_attn_metadata is not None and (
num_reqs != num_reqs_padded or num_tokens != num_tokens_padded
if (
(not envs.VLLM_MLA_CP)
and spec_decode_common_attn_metadata is not None
and (num_reqs != num_reqs_padded or num_tokens != num_tokens_padded)
):
# Currently the drafter still only uses piecewise cudagraphs (and modifies
# the attention metadata in directly), and therefore does not want to use
......@@ -1946,7 +2215,12 @@ class GPUModelRunner(
spec_decode_common_attn_metadata.unpadded(num_tokens, num_reqs)
)
return attn_metadata, spec_decode_common_attn_metadata
return (
attn_metadata,
spec_decode_common_attn_metadata,
scatter_indexes_tensor,
gather_indexes_tensor
)
def _compute_cascade_attn_prefix_lens(
self,
......@@ -2798,9 +3072,19 @@ class GPUModelRunner(
return model_runner_output
def _pad_for_mla_cp(self, num_scheduled_tokens: int) -> int:
tp_size = self.vllm_config.parallel_config.tensor_parallel_size
if num_scheduled_tokens <= tp_size * tp_size:
return num_scheduled_tokens * tp_size
else:
return round_up(num_scheduled_tokens, tp_size)
def _pad_for_sequence_parallelism(self, num_scheduled_tokens: int) -> int:
# Pad tokens to multiple of tensor_parallel_size when
# enabled collective fusion for SP
if envs.VLLM_MLA_CP:
return self._pad_for_mla_cp(num_scheduled_tokens)
tp_size = self.vllm_config.parallel_config.tensor_parallel_size
if self.compilation_config.pass_config.enable_sp and tp_size > 1:
return round_up(num_scheduled_tokens, tp_size)
......@@ -3497,6 +3781,8 @@ class GPUModelRunner(
)
num_tokens_padded = batch_desc.num_tokens
if envs.VLLM_MLA_CP:
num_tokens_padded = self._pad_for_mla_cp(num_tokens_unpadded)
num_reqs_padded = (
batch_desc.num_reqs if batch_desc.num_reqs is not None else num_reqs
)
......@@ -3553,8 +3839,12 @@ class GPUModelRunner(
ubatch_slices=ubatch_slices_padded,
)
attn_metadata, spec_decode_common_attn_metadata = (
self._build_attention_metadata(
(
attn_metadata,
spec_decode_common_attn_metadata,
scatter_indexes_tensor,
gather_indexes_tensor,
) = self._build_attention_metadata(
num_tokens=num_tokens_unpadded,
num_tokens_padded=num_tokens_padded if pad_attn else None,
num_reqs=num_reqs,
......@@ -3567,7 +3857,6 @@ class GPUModelRunner(
cascade_attn_prefix_lens=cascade_attn_prefix_lens,
slot_mappings=slot_mappings_by_group,
)
)
(
input_ids,
......@@ -3608,6 +3897,8 @@ class GPUModelRunner(
ubatch_slices=ubatch_slices_padded,
slot_mapping=slot_mappings,
skip_compiled=has_encoder_input,
scatter_indexes_tensor=scatter_indexes_tensor,
gather_indexes_tensor=gather_indexes_tensor,
),
record_function_or_nullcontext("gpu_model_runner: forward"),
self.maybe_get_kv_connector_output(scheduler_output) as kv_connector_output,
......@@ -4094,6 +4385,15 @@ class GPUModelRunner(
spec_decode_metadata,
valid_sampled_tokens_count,
)
#total_num_tokens = common_attn_metadata.num_actual_tokens
if (
envs.VLLM_MLA_CP
and common_attn_metadata.cp_common_metadata is not None
):
total_num_tokens = (
common_attn_metadata.cp_common_metadata.num_actual_tokens
)
else:
total_num_tokens = common_attn_metadata.num_actual_tokens
# When padding the batch, token_indices is just a range
target_token_ids = self.input_ids.gpu[:total_num_tokens]
......@@ -4618,6 +4918,9 @@ class GPUModelRunner(
or cudagraph_runtime_mode.valid_runtime_modes()
)
if envs.VLLM_MLA_CP:
num_tokens = max(self.tp_size, num_tokens)
# If cudagraph_mode.decode_mode() == FULL and
# cudagraph_mode.separate_routine(). This means that we are using
# different graphs and/or modes for mixed prefill-decode batches vs.
......@@ -4748,7 +5051,7 @@ class GPUModelRunner(
self.query_start_loc.copy_to_gpu()
pad_attn = cudagraph_runtime_mode == CUDAGraphMode.FULL
attn_metadata, _ = self._build_attention_metadata(
attn_metadata, _, _, _ = self._build_attention_metadata(
num_tokens=num_tokens_unpadded,
num_reqs=num_reqs_padded,
max_query_len=max_query_len,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment