Commit 0beafe40 authored by 王敏
Browse files

[Feat]支持pcp+mtp

parent 09f318c1
......@@ -11,6 +11,7 @@ import torch
import torch.nn as nn
from transformers import PretrainedConfig
from vllm.forward_context import get_forward_context
from vllm._aiter_ops import rocm_aiter_ops
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
......@@ -36,6 +37,9 @@ from .deepseek_v2 import (
DeepseekV2MoE,
get_spec_layer_idx_from_weight_name,
)
from vllm.distributed import (tensor_model_parallel_all_gather,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from .utils import maybe_prefix
from .interfaces import SupportsPP
from vllm import _custom_ops as ops
......@@ -177,6 +181,9 @@ class DeepSeekMultiTokenPredictor(nn.Module):
)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.tp_rank = get_tensor_model_parallel_rank()
self.tp_size = get_tensor_model_parallel_world_size()
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
    """Return the result of passing ``input_ids`` through ``self.embed_tokens``."""
    token_embeddings = self.embed_tokens(input_ids)
    return token_embeddings
......@@ -191,7 +198,19 @@ class DeepSeekMultiTokenPredictor(nn.Module):
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
current_step_idx = spec_step_idx % self.num_mtp_layers
return self.layers[str(self.mtp_start_layer_idx + current_step_idx)](
enable_mla_cp = get_forward_context().enable_mla_cp#envs.VLLM_MLA_CP # and not get_forward_context().draft_model
if enable_mla_cp:
inputs_embeds_per_rank = torch.chunk(inputs_embeds, chunks=self.tp_size, dim=0)
inputs_embeds = inputs_embeds_per_rank[self.tp_rank].contiguous()
previous_hidden_states_per_rank = torch.chunk(previous_hidden_states, chunks=self.tp_size, dim=0)
previous_hidden_states = previous_hidden_states_per_rank[self.tp_rank].contiguous()
if positions is not None:
positions_per_rank = torch.chunk(positions, chunks=self.tp_size, dim=0)
positions = positions_per_rank[self.tp_rank].contiguous()
hidden_states = self.layers[str(self.mtp_start_layer_idx + current_step_idx)](
input_ids,
positions,
previous_hidden_states,
......@@ -199,6 +218,11 @@ class DeepSeekMultiTokenPredictor(nn.Module):
current_step_idx,
)
if enable_mla_cp:
hidden_states = tensor_model_parallel_all_gather(hidden_states.contiguous(), dim=0)
return hidden_states
def compute_logits(
self,
hidden_states: torch.Tensor,
......
......@@ -293,9 +293,26 @@ class CpCommonAttentionMetadata:
seq_lens: torch.Tensor
_seq_lens_cpu: torch.Tensor
num_actual_tokens: int
num_kv_actual_tokens: int
max_query_len: int
max_seq_len: int
num_reqs: int
req_ids: list[str]
block_table_tensor: torch.Tensor
slot_mapping: torch.Tensor
_num_computed_tokens_cpu: torch.Tensor
dcp_local_seq_lens: torch.Tensor | None = None
dcp_local_seq_lens_cpu: torch.Tensor | None = None
def batch_size(self) -> int:
    """Return the size of the first dimension of ``self.seq_lens``."""
    num_entries = self.seq_lens.shape[0]
    return num_entries
@property
def seq_lens_cpu(self) -> torch.Tensor:
    """Return the CPU copy of ``seq_lens``, creating and caching it on first use.

    When ``self._seq_lens_cpu`` is ``None`` the copy is produced with
    ``self.seq_lens.to("cpu")`` and stored back on the instance, so later
    accesses return the same tensor object.
    """
    cached = self._seq_lens_cpu
    if cached is not None:
        return cached
    cached = self.seq_lens.to("cpu")
    self._seq_lens_cpu = cached
    return cached
@dataclass
......
......@@ -14,7 +14,7 @@ from vllm.config import (
VllmConfig,
get_layers_from_vllm_config,
)
from vllm.distributed.parallel_state import get_pp_group
from vllm.distributed.parallel_state import get_pp_group, get_tensor_model_parallel_rank
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
......@@ -29,6 +29,7 @@ from vllm.utils.platform_utils import is_pin_memory_available
from vllm.v1.attention.backend import (
AttentionMetadataBuilder,
CommonAttentionMetadata,
CpCommonAttentionMetadata,
)
from vllm.v1.attention.backends.registry import AttentionBackendEnum
from vllm.v1.attention.backends.tree_attn import (
......@@ -48,6 +49,7 @@ from vllm.v1.spec_decode.utils import (
from vllm.v1.utils import CpuGpuBuffer
from vllm.v1.worker.dp_utils import coordinate_batch_across_dp
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm.utils.math_utils import cdiv, round_up
logger = init_logger(__name__)
......@@ -76,7 +78,8 @@ class SpecDecodeBaseProposer:
self.dp_rank = vllm_config.parallel_config.data_parallel_rank
self.num_speculative_tokens = self.speculative_config.num_speculative_tokens
# The drafter can get longer sequences than the target model.
max_batch_size = vllm_config.scheduler_config.max_num_seqs
max_batch_size = vllm_config.scheduler_config.max_num_seqs if not envs.VLLM_MLA_CPLB \
else vllm_config.scheduler_config.max_num_seqs * 2
self.max_num_tokens = (
vllm_config.scheduler_config.max_num_batched_tokens + max_batch_size
)
......@@ -219,6 +222,25 @@ class SpecDecodeBaseProposer:
1, len(self.tree_choices) + 1, device=device, dtype=torch.int32
).repeat(max_batch_size, 1)
if envs.VLLM_MLA_CP:
self.scatter_indexes_tensor = None
self.gather_indexes_tensor = None
self.query_start_loc = CpuGpuBuffer(
max_batch_size + 1,
dtype=torch.int32,
pin_memory=is_pin_memory_available(),
device=device,
with_numpy=True,
)
self.seq_lens = CpuGpuBuffer(
max_batch_size,
dtype=torch.int32,
pin_memory=is_pin_memory_available(),
device=device,
with_numpy=True,
)
def _get_positions(self, num_tokens: int):
if self.uses_mrope:
return self.mrope_positions[:, :num_tokens]
......@@ -270,6 +292,10 @@ class SpecDecodeBaseProposer:
self.cudagraph_dispatcher.initialize_cudagraph_keys(eagle_cudagraph_mode)
def _pad_for_mla_cp(self, num_scheduled_tokens: int) -> int:
    """Pad a token count for the MLA context-parallel path.

    The count is handed to ``round_up`` together with the tensor-parallel
    size read from ``self.vllm_config.parallel_config``.
    """
    parallel_config = self.vllm_config.parallel_config
    return round_up(num_scheduled_tokens, parallel_config.tensor_parallel_size)
def propose(
self,
# [num_tokens]
......@@ -309,12 +335,37 @@ class SpecDecodeBaseProposer:
)
)
num_tokens_dp_padded, num_tokens_across_dp = self._pad_batch_across_dp(
num_tokens_unpadded=num_tokens, num_tokens_padded=num_tokens
)
enable_mla_cp = envs.VLLM_MLA_CP and num_tokens > self.runner.mla_cp_threshould
if enable_mla_cp:
num_tokens_dp_padded = self._pad_for_mla_cp(num_tokens_dp_padded)
common_attn_metadata = self._prepare_cp_metadata(
num_reqs_padded=common_attn_metadata.num_reqs,
max_query_len=common_attn_metadata.max_query_len,
max_seq_len=common_attn_metadata.seq_lens_cpu.max().item(),
num_tokens=num_tokens,
block_table_gid_0=common_attn_metadata.block_table_tensor,
slot_mapping_gid_0=common_attn_metadata.slot_mapping,
query_start_loc=common_attn_metadata.query_start_loc,
query_start_loc_cpu=common_attn_metadata.query_start_loc_cpu,
seq_lens=common_attn_metadata.seq_lens,
seq_lens_cpu=common_attn_metadata.seq_lens_cpu,
num_computed_tokens_cpu=common_attn_metadata._num_computed_tokens_cpu,
)
self.scatter_indexes_tensor = common_attn_metadata.scatter_indexes_tensor
self.gather_indexes_tensor = common_attn_metadata.gather_indexes_tensor
assert self.runner is not None
if self.attn_metadata_builder is None:
attn_metadata_builder = self._get_attention_metadata_builder()
else:
attn_metadata_builder = self.attn_metadata_builder
attn_metadata_builder = self.attn_metadata_builder
attn_metadata = attn_metadata_builder.build_for_drafting(
common_attn_metadata=common_attn_metadata, draft_index=0
......@@ -339,10 +390,6 @@ class SpecDecodeBaseProposer:
assert draft_indexer_metadata is not None
per_layer_attn_metadata[layer_name] = draft_indexer_metadata
num_tokens_dp_padded, num_tokens_across_dp = self._pad_batch_across_dp(
num_tokens_unpadded=num_tokens, num_tokens_padded=num_tokens
)
cudagraph_runtime_mode, batch_desc = self.cudagraph_dispatcher.dispatch(
num_tokens_dp_padded
)
......@@ -387,6 +434,9 @@ class SpecDecodeBaseProposer:
slot_mapping=self._get_slot_mapping(
num_input_tokens, common_attn_metadata.slot_mapping
),
scatter_indexes_tensor=self.scatter_indexes_tensor,
gather_indexes_tensor=self.gather_indexes_tensor,
enable_mla_cp=envs.VLLM_MLA_CP and num_tokens > self.runner.mla_cp_threshould,
):
ret_hidden_states = self.model(**model_kwargs)
if not self.model_returns_tuple():
......@@ -463,6 +513,9 @@ class SpecDecodeBaseProposer:
if batch_size_across_dp is not None:
batch_size_across_dp[self.dp_rank] = input_batch_size
if enable_mla_cp:
common_attn_metadata = common_attn_metadata.cp_common_metadata
common_attn_metadata.num_actual_tokens = batch_size
common_attn_metadata.max_query_len = 1
common_attn_metadata.query_start_loc = self.arange[: batch_size + 1]
......@@ -989,6 +1042,104 @@ class SpecDecodeBaseProposer:
level_num_drafts = self.cu_drafts_per_level[level + 1] - total_num_drafts
total_num_drafts = self.cu_drafts_per_level[level + 1]
return draft_token_ids_list
def _prepare_cp_metadata(
    self,
    num_reqs_padded: int,
    max_query_len: int,
    max_seq_len: int,
    num_tokens: int,
    block_table_gid_0: torch.Tensor,
    slot_mapping_gid_0: torch.Tensor,
    query_start_loc: torch.Tensor,
    query_start_loc_cpu: torch.Tensor,
    seq_lens: torch.Tensor,
    seq_lens_cpu: torch.Tensor,
    num_computed_tokens_cpu: torch.Tensor,
) -> "CommonAttentionMetadata":
    """Build the per-rank attention metadata for the MLA context-parallel path.

    Two metadata objects are produced:

    * a ``CpCommonAttentionMetadata`` holding clones of the incoming
      (undistributed) batch description; it is attached to the result as
      ``cp_common_metadata``;
    * a ``CommonAttentionMetadata`` describing what
      ``self.runner._distribute_tokens_to_cp_ranks`` assigned to this
      tensor-parallel rank, which is returned.

    Args:
        num_reqs_padded: Request count recorded on the undistributed metadata.
        max_query_len: Longest query of the undistributed batch.
        max_seq_len: Longest sequence of the undistributed batch.
        num_tokens: Total token count of the undistributed batch.
        block_table_gid_0: Block table; indexed by the ``seq_indexes_list``
            returned from the distribution helper to get this rank's rows.
        slot_mapping_gid_0: Slot mapping, forwarded unchanged to both objects.
        query_start_loc: Device-side query offsets of the undistributed batch.
        query_start_loc_cpu: Host-side copy of ``query_start_loc``.
        seq_lens: Device-side sequence lengths of the undistributed batch.
        seq_lens_cpu: Host-side copy of ``seq_lens``.
        num_computed_tokens_cpu: Host-side computed-token counts of the
            undistributed batch.

    Returns:
        The ``CommonAttentionMetadata`` for this rank.

    Note:
        The per-rank ``query_start_loc`` / ``seq_lens`` tensors are slices of
        the persistent ``self.query_start_loc`` / ``self.seq_lens`` buffers,
        so they are overwritten by the next call.
    """
    tp_size = self.vllm_config.parallel_config.tensor_parallel_size
    tp_rank = get_tensor_model_parallel_rank()
    req_ids = self.runner.input_batch.req_ids

    # Snapshot of the batch before distribution. Tensors are cloned so later
    # in-place updates of the originals do not leak into the snapshot.
    cp_common_metadata = CpCommonAttentionMetadata(
        query_start_loc=query_start_loc.clone(),
        query_start_loc_cpu=query_start_loc_cpu.clone(),
        seq_lens=seq_lens.clone(),
        _seq_lens_cpu=seq_lens_cpu.clone(),
        max_query_len=max_query_len,
        max_seq_len=max_seq_len,
        num_reqs=num_reqs_padded,
        req_ids=req_ids,
        num_actual_tokens=num_tokens,
        num_kv_actual_tokens=num_tokens,
        block_table_tensor=block_table_gid_0,
        slot_mapping=slot_mapping_gid_0,
        _num_computed_tokens_cpu=num_computed_tokens_cpu,
    )

    # Per-request query lengths of the undistributed batch.
    global_q_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
    total_kv_len = num_tokens

    (
        total_q_len,
        q_lens_cpu,
        num_reqs,
        kv_lens_cpu,
        _local_req_ids,  # returned by the helper but not needed here
        scatter_indexes_tensor,
        gather_indexes_tensor,
        seq_indexes_list,
    ) = self.runner._distribute_tokens_to_cp_ranks(
        num_tokens,
        global_q_lens_cpu,
        seq_lens_cpu,
        tp_rank,
        tp_size,
        req_ids,
    )

    # Write this rank's cumulative query offsets into the persistent buffer;
    # the unused tail repeats the last offset.
    cu_num_tokens = np.cumsum(q_lens_cpu)
    self.query_start_loc.np[0] = 0
    self.query_start_loc.np[1 : num_reqs + 1] = cu_num_tokens
    self.query_start_loc.np[num_reqs + 1 :].fill(cu_num_tokens[-1])
    self.query_start_loc.copy_to_gpu()
    q_acc_lens = self.query_start_loc.gpu[: num_reqs + 1]
    q_acc_lens_cpu = self.query_start_loc.cpu[: num_reqs + 1]

    # Recover per-request query lengths from the offsets just written.
    # Taking the max over the cumulative offsets would yield the total token
    # count, not the longest query.
    local_q_lens_cpu = q_acc_lens_cpu[1:] - q_acc_lens_cpu[:-1]
    max_q_len = int(local_q_lens_cpu.max())

    # Same for sequence lengths; the unused tail of the buffer is zeroed.
    self.seq_lens.np[:num_reqs] = kv_lens_cpu
    self.seq_lens.np[num_reqs:].fill(0)
    self.seq_lens.copy_to_gpu()
    kv_lens = self.seq_lens.gpu[:num_reqs]
    kv_lens_cpu = self.seq_lens.cpu[:num_reqs]
    max_kv_len = int(kv_lens_cpu.max())

    # Computed tokens per request = sequence length - that request's own
    # query length (not the cumulative offset, which is only correct for the
    # first request).
    local_num_computed_tokens_cpu = kv_lens_cpu - local_q_lens_cpu

    return CommonAttentionMetadata(
        query_start_loc=q_acc_lens,
        query_start_loc_cpu=q_acc_lens_cpu,
        seq_lens=kv_lens,
        _seq_lens_cpu=kv_lens_cpu,
        _num_computed_tokens_cpu=local_num_computed_tokens_cpu,
        num_reqs=num_reqs,
        num_actual_tokens=total_q_len,
        max_query_len=max_q_len,
        max_seq_len=max_kv_len,
        block_table_tensor=block_table_gid_0[seq_indexes_list],
        slot_mapping=slot_mapping_gid_0,
        causal=True,
        num_kv_actual_tokens=total_kv_len,
        seq_indexes_list=seq_indexes_list,
        cp_common_metadata=cp_common_metadata,
        scatter_indexes_tensor=scatter_indexes_tensor,
        gather_indexes_tensor=gather_indexes_tensor,
    )
def prepare_inputs(
self,
......
......@@ -1644,9 +1644,11 @@ class GPUModelRunner(
self,
num_reqs_padded,
max_query_len,
max_seq_len,
num_tokens,
block_table_gid_0,
slot_mapping_gid_0,
num_computed_tokens_cpu
):
tp_size = self.vllm_config.parallel_config.tensor_parallel_size
tp_rank = get_tensor_model_parallel_rank()
......@@ -1657,9 +1659,14 @@ class GPUModelRunner(
seq_lens=self.seq_lens.gpu[:num_reqs_padded].clone(),
_seq_lens_cpu=self.seq_lens.cpu[:num_reqs_padded].clone(),
max_query_len=max_query_len,
max_seq_len=max_seq_len,
num_reqs=num_reqs_padded,
req_ids=self.input_batch.req_ids,
num_actual_tokens=num_tokens,
num_kv_actual_tokens=num_tokens,
block_table_tensor=block_table_gid_0,
slot_mapping=slot_mapping_gid_0,
_num_computed_tokens_cpu=num_computed_tokens_cpu
)
query_start_loc_cpu = self.query_start_loc.cpu[: num_reqs_padded + 1]
......@@ -1725,6 +1732,7 @@ class GPUModelRunner(
cp_common_metadata=cp_common_metadata,
scatter_indexes_tensor=scatter_indexes_tensor,
gather_indexes_tensor=gather_indexes_tensor,
enable_mla_cp=True
)
return cm_base
......@@ -2028,7 +2036,8 @@ class GPUModelRunner(
if self.model_config.enable_return_routed_experts:
self.slot_mapping = slot_mapping_gid_0[:num_tokens].cpu().numpy()
if not envs.VLLM_MLA_CP or num_tokens <= self.mla_cp_threshould:
mla_cp_enable = envs.VLLM_MLA_CP and num_tokens > self.mla_cp_threshould
if not mla_cp_enable:
cm_base = CommonAttentionMetadata(
query_start_loc=self.query_start_loc.gpu[: num_reqs_padded + 1],
query_start_loc_cpu=self.query_start_loc.cpu[: num_reqs_padded + 1],
......@@ -2050,9 +2059,13 @@ class GPUModelRunner(
cm_base = self._prepare_cp_metadata(
num_reqs_padded,
max_query_len,
max_seq_len,
num_tokens,
block_table_gid_0,
slot_mapping_gid_0,
self.input_batch.num_computed_tokens_cpu_tensor[
:num_reqs_padded
],
)
scatter_indexes_tensor = cm_base.scatter_indexes_tensor
gather_indexes_tensor = cm_base.gather_indexes_tensor
......@@ -2172,9 +2185,17 @@ class GPUModelRunner(
if self.speculative_config and spec_decode_common_attn_metadata is None and hasattr(self, "drafter"):
if isinstance(self.drafter, EagleProposer):
if self.drafter.attn_layer_names[0] in kv_cache_group.layer_names:
spec_decode_common_attn_metadata = cm
if mla_cp_enable:
spec_decode_common_attn_metadata = cm.cp_common_metadata
else:
spec_decode_common_attn_metadata = cm
#spec_decode_common_attn_metadata = cm
else:
spec_decode_common_attn_metadata = cm
if mla_cp_enable:
spec_decode_common_attn_metadata = cm.cp_common_metadata
else:
spec_decode_common_attn_metadata = cm
#spec_decode_common_attn_metadata = cm
for attn_gid in range(len(self.attn_groups[kv_cache_gid])):
if ubatch_slices is not None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment