Commit 0beafe40 authored by 王敏
Browse files

[Feat]支持pcp+mtp

parent 09f318c1
...@@ -11,6 +11,7 @@ import torch ...@@ -11,6 +11,7 @@ import torch
import torch.nn as nn import torch.nn as nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.forward_context import get_forward_context
from vllm._aiter_ops import rocm_aiter_ops from vllm._aiter_ops import rocm_aiter_ops
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig from vllm.config import VllmConfig
...@@ -36,6 +37,9 @@ from .deepseek_v2 import ( ...@@ -36,6 +37,9 @@ from .deepseek_v2 import (
DeepseekV2MoE, DeepseekV2MoE,
get_spec_layer_idx_from_weight_name, get_spec_layer_idx_from_weight_name,
) )
from vllm.distributed import (tensor_model_parallel_all_gather,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from .utils import maybe_prefix from .utils import maybe_prefix
from .interfaces import SupportsPP from .interfaces import SupportsPP
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
...@@ -177,6 +181,9 @@ class DeepSeekMultiTokenPredictor(nn.Module): ...@@ -177,6 +181,9 @@ class DeepSeekMultiTokenPredictor(nn.Module):
) )
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.tp_rank = get_tensor_model_parallel_rank()
self.tp_size = get_tensor_model_parallel_world_size()
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids) return self.embed_tokens(input_ids)
...@@ -191,7 +198,19 @@ class DeepSeekMultiTokenPredictor(nn.Module): ...@@ -191,7 +198,19 @@ class DeepSeekMultiTokenPredictor(nn.Module):
if inputs_embeds is None: if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids) inputs_embeds = self.embed_tokens(input_ids)
current_step_idx = spec_step_idx % self.num_mtp_layers current_step_idx = spec_step_idx % self.num_mtp_layers
return self.layers[str(self.mtp_start_layer_idx + current_step_idx)]( enable_mla_cp = get_forward_context().enable_mla_cp#envs.VLLM_MLA_CP # and not get_forward_context().draft_model
if enable_mla_cp:
inputs_embeds_per_rank = torch.chunk(inputs_embeds, chunks=self.tp_size, dim=0)
inputs_embeds = inputs_embeds_per_rank[self.tp_rank].contiguous()
previous_hidden_states_per_rank = torch.chunk(previous_hidden_states, chunks=self.tp_size, dim=0)
previous_hidden_states = previous_hidden_states_per_rank[self.tp_rank].contiguous()
if positions is not None:
positions_per_rank = torch.chunk(positions, chunks=self.tp_size, dim=0)
positions = positions_per_rank[self.tp_rank].contiguous()
hidden_states = self.layers[str(self.mtp_start_layer_idx + current_step_idx)](
input_ids, input_ids,
positions, positions,
previous_hidden_states, previous_hidden_states,
...@@ -199,6 +218,11 @@ class DeepSeekMultiTokenPredictor(nn.Module): ...@@ -199,6 +218,11 @@ class DeepSeekMultiTokenPredictor(nn.Module):
current_step_idx, current_step_idx,
) )
if enable_mla_cp:
hidden_states = tensor_model_parallel_all_gather(hidden_states.contiguous(), dim=0)
return hidden_states
def compute_logits( def compute_logits(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
......
...@@ -293,9 +293,26 @@ class CpCommonAttentionMetadata: ...@@ -293,9 +293,26 @@ class CpCommonAttentionMetadata:
seq_lens: torch.Tensor seq_lens: torch.Tensor
_seq_lens_cpu: torch.Tensor _seq_lens_cpu: torch.Tensor
num_actual_tokens: int num_actual_tokens: int
num_kv_actual_tokens: int
max_query_len: int max_query_len: int
max_seq_len: int
num_reqs: int num_reqs: int
req_ids: list[str] req_ids: list[str]
block_table_tensor: torch.Tensor
slot_mapping: torch.Tensor
_num_computed_tokens_cpu: torch.Tensor
dcp_local_seq_lens: torch.Tensor | None = None
dcp_local_seq_lens_cpu: torch.Tensor | None = None
def batch_size(self) -> int:
    """Return the size of the first dimension of ``seq_lens``."""
    seq_lens_shape = self.seq_lens.shape
    return seq_lens_shape[0]
@property
def seq_lens_cpu(self) -> torch.Tensor:
    """Return the CPU copy of ``seq_lens``, creating and caching it on first use.

    When ``_seq_lens_cpu`` is already set it is returned unchanged; otherwise
    ``seq_lens`` is moved to the CPU, stored in ``_seq_lens_cpu`` and returned.
    """
    cached = self._seq_lens_cpu
    if cached is not None:
        return cached
    cached = self.seq_lens.to("cpu")
    self._seq_lens_cpu = cached
    return cached
@dataclass @dataclass
......
...@@ -14,7 +14,7 @@ from vllm.config import ( ...@@ -14,7 +14,7 @@ from vllm.config import (
VllmConfig, VllmConfig,
get_layers_from_vllm_config, get_layers_from_vllm_config,
) )
from vllm.distributed.parallel_state import get_pp_group from vllm.distributed.parallel_state import get_pp_group, get_tensor_model_parallel_rank
from vllm.forward_context import set_forward_context from vllm.forward_context import set_forward_context
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
...@@ -29,6 +29,7 @@ from vllm.utils.platform_utils import is_pin_memory_available ...@@ -29,6 +29,7 @@ from vllm.utils.platform_utils import is_pin_memory_available
from vllm.v1.attention.backend import ( from vllm.v1.attention.backend import (
AttentionMetadataBuilder, AttentionMetadataBuilder,
CommonAttentionMetadata, CommonAttentionMetadata,
CpCommonAttentionMetadata,
) )
from vllm.v1.attention.backends.registry import AttentionBackendEnum from vllm.v1.attention.backends.registry import AttentionBackendEnum
from vllm.v1.attention.backends.tree_attn import ( from vllm.v1.attention.backends.tree_attn import (
...@@ -48,6 +49,7 @@ from vllm.v1.spec_decode.utils import ( ...@@ -48,6 +49,7 @@ from vllm.v1.spec_decode.utils import (
from vllm.v1.utils import CpuGpuBuffer from vllm.v1.utils import CpuGpuBuffer
from vllm.v1.worker.dp_utils import coordinate_batch_across_dp from vllm.v1.worker.dp_utils import coordinate_batch_across_dp
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm.utils.math_utils import cdiv, round_up
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -76,7 +78,8 @@ class SpecDecodeBaseProposer: ...@@ -76,7 +78,8 @@ class SpecDecodeBaseProposer:
self.dp_rank = vllm_config.parallel_config.data_parallel_rank self.dp_rank = vllm_config.parallel_config.data_parallel_rank
self.num_speculative_tokens = self.speculative_config.num_speculative_tokens self.num_speculative_tokens = self.speculative_config.num_speculative_tokens
# The drafter can get longer sequences than the target model. # The drafter can get longer sequences than the target model.
max_batch_size = vllm_config.scheduler_config.max_num_seqs max_batch_size = vllm_config.scheduler_config.max_num_seqs if not envs.VLLM_MLA_CPLB \
else vllm_config.scheduler_config.max_num_seqs * 2
self.max_num_tokens = ( self.max_num_tokens = (
vllm_config.scheduler_config.max_num_batched_tokens + max_batch_size vllm_config.scheduler_config.max_num_batched_tokens + max_batch_size
) )
...@@ -219,6 +222,25 @@ class SpecDecodeBaseProposer: ...@@ -219,6 +222,25 @@ class SpecDecodeBaseProposer:
1, len(self.tree_choices) + 1, device=device, dtype=torch.int32 1, len(self.tree_choices) + 1, device=device, dtype=torch.int32
).repeat(max_batch_size, 1) ).repeat(max_batch_size, 1)
if envs.VLLM_MLA_CP:
self.scatter_indexes_tensor = None
self.gather_indexes_tensor = None
self.query_start_loc = CpuGpuBuffer(
max_batch_size + 1,
dtype=torch.int32,
pin_memory=is_pin_memory_available(),
device=device,
with_numpy=True,
)
self.seq_lens = CpuGpuBuffer(
max_batch_size,
dtype=torch.int32,
pin_memory=is_pin_memory_available(),
device=device,
with_numpy=True,
)
def _get_positions(self, num_tokens: int): def _get_positions(self, num_tokens: int):
if self.uses_mrope: if self.uses_mrope:
return self.mrope_positions[:, :num_tokens] return self.mrope_positions[:, :num_tokens]
...@@ -270,6 +292,10 @@ class SpecDecodeBaseProposer: ...@@ -270,6 +292,10 @@ class SpecDecodeBaseProposer:
self.cudagraph_dispatcher.initialize_cudagraph_keys(eagle_cudagraph_mode) self.cudagraph_dispatcher.initialize_cudagraph_keys(eagle_cudagraph_mode)
def _pad_for_mla_cp(self, num_scheduled_tokens: int) -> int:
    """Pad a token count for the MLA context-parallel path.

    Passes ``num_scheduled_tokens`` and the configured tensor-parallel size to
    ``round_up`` and returns its result, so the padded batch can be split
    evenly across the tensor-parallel ranks.
    """
    parallel_config = self.vllm_config.parallel_config
    return round_up(num_scheduled_tokens, parallel_config.tensor_parallel_size)
def propose( def propose(
self, self,
# [num_tokens] # [num_tokens]
...@@ -309,6 +335,31 @@ class SpecDecodeBaseProposer: ...@@ -309,6 +335,31 @@ class SpecDecodeBaseProposer:
) )
) )
num_tokens_dp_padded, num_tokens_across_dp = self._pad_batch_across_dp(
num_tokens_unpadded=num_tokens, num_tokens_padded=num_tokens
)
enable_mla_cp = envs.VLLM_MLA_CP and num_tokens > self.runner.mla_cp_threshould
if enable_mla_cp:
num_tokens_dp_padded = self._pad_for_mla_cp(num_tokens_dp_padded)
common_attn_metadata = self._prepare_cp_metadata(
num_reqs_padded=common_attn_metadata.num_reqs,
max_query_len=common_attn_metadata.max_query_len,
max_seq_len=common_attn_metadata.seq_lens_cpu.max().item(),
num_tokens=num_tokens,
block_table_gid_0=common_attn_metadata.block_table_tensor,
slot_mapping_gid_0=common_attn_metadata.slot_mapping,
query_start_loc=common_attn_metadata.query_start_loc,
query_start_loc_cpu=common_attn_metadata.query_start_loc_cpu,
seq_lens=common_attn_metadata.seq_lens,
seq_lens_cpu=common_attn_metadata.seq_lens_cpu,
num_computed_tokens_cpu=common_attn_metadata._num_computed_tokens_cpu,
)
self.scatter_indexes_tensor = common_attn_metadata.scatter_indexes_tensor
self.gather_indexes_tensor = common_attn_metadata.gather_indexes_tensor
assert self.runner is not None assert self.runner is not None
if self.attn_metadata_builder is None: if self.attn_metadata_builder is None:
...@@ -339,10 +390,6 @@ class SpecDecodeBaseProposer: ...@@ -339,10 +390,6 @@ class SpecDecodeBaseProposer:
assert draft_indexer_metadata is not None assert draft_indexer_metadata is not None
per_layer_attn_metadata[layer_name] = draft_indexer_metadata per_layer_attn_metadata[layer_name] = draft_indexer_metadata
num_tokens_dp_padded, num_tokens_across_dp = self._pad_batch_across_dp(
num_tokens_unpadded=num_tokens, num_tokens_padded=num_tokens
)
cudagraph_runtime_mode, batch_desc = self.cudagraph_dispatcher.dispatch( cudagraph_runtime_mode, batch_desc = self.cudagraph_dispatcher.dispatch(
num_tokens_dp_padded num_tokens_dp_padded
) )
...@@ -387,6 +434,9 @@ class SpecDecodeBaseProposer: ...@@ -387,6 +434,9 @@ class SpecDecodeBaseProposer:
slot_mapping=self._get_slot_mapping( slot_mapping=self._get_slot_mapping(
num_input_tokens, common_attn_metadata.slot_mapping num_input_tokens, common_attn_metadata.slot_mapping
), ),
scatter_indexes_tensor=self.scatter_indexes_tensor,
gather_indexes_tensor=self.gather_indexes_tensor,
enable_mla_cp=envs.VLLM_MLA_CP and num_tokens > self.runner.mla_cp_threshould,
): ):
ret_hidden_states = self.model(**model_kwargs) ret_hidden_states = self.model(**model_kwargs)
if not self.model_returns_tuple(): if not self.model_returns_tuple():
...@@ -463,6 +513,9 @@ class SpecDecodeBaseProposer: ...@@ -463,6 +513,9 @@ class SpecDecodeBaseProposer:
if batch_size_across_dp is not None: if batch_size_across_dp is not None:
batch_size_across_dp[self.dp_rank] = input_batch_size batch_size_across_dp[self.dp_rank] = input_batch_size
if enable_mla_cp:
common_attn_metadata = common_attn_metadata.cp_common_metadata
common_attn_metadata.num_actual_tokens = batch_size common_attn_metadata.num_actual_tokens = batch_size
common_attn_metadata.max_query_len = 1 common_attn_metadata.max_query_len = 1
common_attn_metadata.query_start_loc = self.arange[: batch_size + 1] common_attn_metadata.query_start_loc = self.arange[: batch_size + 1]
...@@ -990,6 +1043,104 @@ class SpecDecodeBaseProposer: ...@@ -990,6 +1043,104 @@ class SpecDecodeBaseProposer:
total_num_drafts = self.cu_drafts_per_level[level + 1] total_num_drafts = self.cu_drafts_per_level[level + 1]
return draft_token_ids_list return draft_token_ids_list
def _prepare_cp_metadata(
    self,
    num_reqs_padded,
    max_query_len,
    max_seq_len,
    num_tokens,
    block_table_gid_0,
    slot_mapping_gid_0,
    query_start_loc,
    query_start_loc_cpu,
    seq_lens,
    seq_lens_cpu,
    num_computed_tokens_cpu,
):
    """Build this rank's attention metadata for the MLA context-parallel path.

    Two metadata objects are produced:

    * a ``CpCommonAttentionMetadata`` snapshot of the incoming (undistributed)
      batch, built from clones of the tensors passed in, and
    * a ``CommonAttentionMetadata`` describing the share of the batch that
      ``self.runner._distribute_tokens_to_cp_ranks`` assigns to this
      tensor-parallel rank. The snapshot is attached to it as
      ``cp_common_metadata``.

    Args:
        num_reqs_padded: Request count recorded in the snapshot.
        max_query_len: Max query length recorded in the snapshot.
        max_seq_len: Max sequence length recorded in the snapshot.
        num_tokens: Total token count of the incoming batch.
        block_table_gid_0: Block table; indexed with the sequence indexes
            returned by the distribution helper to get this rank's rows.
        slot_mapping_gid_0: Slot mapping, forwarded unchanged.
        query_start_loc: Device-side cumulative query offsets (cloned).
        query_start_loc_cpu: CPU cumulative query offsets; adjacent
            differences give the per-request query lengths.
        seq_lens: Device-side sequence lengths (cloned).
        seq_lens_cpu: CPU sequence lengths.
        num_computed_tokens_cpu: Per-request computed-token counts recorded
            in the snapshot.

    Returns:
        The rank-local ``CommonAttentionMetadata``.

    Note:
        The returned ``query_start_loc*`` / ``seq_lens*`` tensors are views of
        the persistent ``self.query_start_loc`` / ``self.seq_lens`` buffers
        and are overwritten by the next call.
    """
    tp_size = self.vllm_config.parallel_config.tensor_parallel_size
    tp_rank = get_tensor_model_parallel_rank()

    # Snapshot of the full batch. Clones keep it independent of the caller's
    # tensors, which may be modified afterwards.
    cp_common_metadata = CpCommonAttentionMetadata(
        query_start_loc=query_start_loc.clone(),
        query_start_loc_cpu=query_start_loc_cpu.clone(),
        seq_lens=seq_lens.clone(),
        _seq_lens_cpu=seq_lens_cpu.clone(),
        max_query_len=max_query_len,
        max_seq_len=max_seq_len,
        num_reqs=num_reqs_padded,
        req_ids=self.runner.input_batch.req_ids,
        num_actual_tokens=num_tokens,
        num_kv_actual_tokens=num_tokens,
        block_table_tensor=block_table_gid_0,
        slot_mapping=slot_mapping_gid_0,
        _num_computed_tokens_cpu=num_computed_tokens_cpu,
    )

    # Per-request query lengths of the undistributed batch.
    global_q_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
    (
        total_q_len,
        q_lens_cpu,
        seq_count,
        kv_lens_cpu,
        _local_req_ids,  # returned by the helper but not needed here
        scatter_indexes_tensor,
        gather_indexes_tensor,
        seq_indexes_list,
    ) = self.runner._distribute_tokens_to_cp_ranks(
        num_tokens,
        global_q_lens_cpu,
        seq_lens_cpu,
        tp_rank,
        tp_size,
        self.runner.input_batch.req_ids,
    )
    num_reqs = seq_count

    # Rebuild the cumulative query offsets for this rank in the persistent
    # buffer. Entries past the last request repeat the final offset.
    cu_num_tokens = np.cumsum(q_lens_cpu)
    self.query_start_loc.np[0] = 0
    self.query_start_loc.np[1 : num_reqs + 1] = cu_num_tokens
    self.query_start_loc.np[num_reqs + 1 :].fill(cu_num_tokens[-1])
    self.query_start_loc.copy_to_gpu()
    q_acc_lens = self.query_start_loc.gpu[: num_reqs + 1]
    q_acc_lens_cpu = self.query_start_loc.cpu[: num_reqs + 1]

    # Per-request query lengths are the adjacent differences of the
    # cumulative offsets. The max query length must come from these, not
    # from the offsets themselves (whose max is the total token count).
    local_q_lens_cpu = q_acc_lens_cpu[1:] - q_acc_lens_cpu[:-1]
    max_q_len = int(local_q_lens_cpu.max())

    # Same treatment for sequence lengths; unused tail entries are zeroed.
    self.seq_lens.np[:num_reqs] = kv_lens_cpu
    self.seq_lens.np[num_reqs:].fill(0)
    self.seq_lens.copy_to_gpu()
    kv_lens = self.seq_lens.gpu[:num_reqs]
    kv_lens_cpu = self.seq_lens.cpu[:num_reqs]
    max_kv_len = int(kv_lens_cpu.max())

    # Tokens already computed per request = sequence length - query length.
    num_computed_tokens_cpu = kv_lens_cpu - local_q_lens_cpu

    # Keep only the block-table rows of the sequences handled by this rank.
    blk_table_tensor = block_table_gid_0[seq_indexes_list]

    cm_base = CommonAttentionMetadata(
        query_start_loc=q_acc_lens,
        query_start_loc_cpu=q_acc_lens_cpu,
        seq_lens=kv_lens,
        _seq_lens_cpu=kv_lens_cpu,
        _num_computed_tokens_cpu=num_computed_tokens_cpu,
        num_reqs=num_reqs,
        num_actual_tokens=total_q_len,
        max_query_len=max_q_len,
        max_seq_len=max_kv_len,
        block_table_tensor=blk_table_tensor,
        slot_mapping=slot_mapping_gid_0,
        causal=True,
        num_kv_actual_tokens=num_tokens,
        seq_indexes_list=seq_indexes_list,
        cp_common_metadata=cp_common_metadata,
        scatter_indexes_tensor=scatter_indexes_tensor,
        gather_indexes_tensor=gather_indexes_tensor,
    )
    return cm_base
def prepare_inputs( def prepare_inputs(
self, self,
common_attn_metadata: CommonAttentionMetadata, common_attn_metadata: CommonAttentionMetadata,
......
...@@ -1644,9 +1644,11 @@ class GPUModelRunner( ...@@ -1644,9 +1644,11 @@ class GPUModelRunner(
self, self,
num_reqs_padded, num_reqs_padded,
max_query_len, max_query_len,
max_seq_len,
num_tokens, num_tokens,
block_table_gid_0, block_table_gid_0,
slot_mapping_gid_0, slot_mapping_gid_0,
num_computed_tokens_cpu
): ):
tp_size = self.vllm_config.parallel_config.tensor_parallel_size tp_size = self.vllm_config.parallel_config.tensor_parallel_size
tp_rank = get_tensor_model_parallel_rank() tp_rank = get_tensor_model_parallel_rank()
...@@ -1657,9 +1659,14 @@ class GPUModelRunner( ...@@ -1657,9 +1659,14 @@ class GPUModelRunner(
seq_lens=self.seq_lens.gpu[:num_reqs_padded].clone(), seq_lens=self.seq_lens.gpu[:num_reqs_padded].clone(),
_seq_lens_cpu=self.seq_lens.cpu[:num_reqs_padded].clone(), _seq_lens_cpu=self.seq_lens.cpu[:num_reqs_padded].clone(),
max_query_len=max_query_len, max_query_len=max_query_len,
max_seq_len=max_seq_len,
num_reqs=num_reqs_padded, num_reqs=num_reqs_padded,
req_ids=self.input_batch.req_ids, req_ids=self.input_batch.req_ids,
num_actual_tokens=num_tokens, num_actual_tokens=num_tokens,
num_kv_actual_tokens=num_tokens,
block_table_tensor=block_table_gid_0,
slot_mapping=slot_mapping_gid_0,
_num_computed_tokens_cpu=num_computed_tokens_cpu
) )
query_start_loc_cpu = self.query_start_loc.cpu[: num_reqs_padded + 1] query_start_loc_cpu = self.query_start_loc.cpu[: num_reqs_padded + 1]
...@@ -1725,6 +1732,7 @@ class GPUModelRunner( ...@@ -1725,6 +1732,7 @@ class GPUModelRunner(
cp_common_metadata=cp_common_metadata, cp_common_metadata=cp_common_metadata,
scatter_indexes_tensor=scatter_indexes_tensor, scatter_indexes_tensor=scatter_indexes_tensor,
gather_indexes_tensor=gather_indexes_tensor, gather_indexes_tensor=gather_indexes_tensor,
enable_mla_cp=True
) )
return cm_base return cm_base
...@@ -2028,7 +2036,8 @@ class GPUModelRunner( ...@@ -2028,7 +2036,8 @@ class GPUModelRunner(
if self.model_config.enable_return_routed_experts: if self.model_config.enable_return_routed_experts:
self.slot_mapping = slot_mapping_gid_0[:num_tokens].cpu().numpy() self.slot_mapping = slot_mapping_gid_0[:num_tokens].cpu().numpy()
if not envs.VLLM_MLA_CP or num_tokens <= self.mla_cp_threshould: mla_cp_enable = envs.VLLM_MLA_CP and num_tokens > self.mla_cp_threshould
if not mla_cp_enable:
cm_base = CommonAttentionMetadata( cm_base = CommonAttentionMetadata(
query_start_loc=self.query_start_loc.gpu[: num_reqs_padded + 1], query_start_loc=self.query_start_loc.gpu[: num_reqs_padded + 1],
query_start_loc_cpu=self.query_start_loc.cpu[: num_reqs_padded + 1], query_start_loc_cpu=self.query_start_loc.cpu[: num_reqs_padded + 1],
...@@ -2050,9 +2059,13 @@ class GPUModelRunner( ...@@ -2050,9 +2059,13 @@ class GPUModelRunner(
cm_base = self._prepare_cp_metadata( cm_base = self._prepare_cp_metadata(
num_reqs_padded, num_reqs_padded,
max_query_len, max_query_len,
max_seq_len,
num_tokens, num_tokens,
block_table_gid_0, block_table_gid_0,
slot_mapping_gid_0, slot_mapping_gid_0,
self.input_batch.num_computed_tokens_cpu_tensor[
:num_reqs_padded
],
) )
scatter_indexes_tensor = cm_base.scatter_indexes_tensor scatter_indexes_tensor = cm_base.scatter_indexes_tensor
gather_indexes_tensor = cm_base.gather_indexes_tensor gather_indexes_tensor = cm_base.gather_indexes_tensor
...@@ -2172,9 +2185,17 @@ class GPUModelRunner( ...@@ -2172,9 +2185,17 @@ class GPUModelRunner(
if self.speculative_config and spec_decode_common_attn_metadata is None and hasattr(self, "drafter"): if self.speculative_config and spec_decode_common_attn_metadata is None and hasattr(self, "drafter"):
if isinstance(self.drafter, EagleProposer): if isinstance(self.drafter, EagleProposer):
if self.drafter.attn_layer_names[0] in kv_cache_group.layer_names: if self.drafter.attn_layer_names[0] in kv_cache_group.layer_names:
if mla_cp_enable:
spec_decode_common_attn_metadata = cm.cp_common_metadata
else:
spec_decode_common_attn_metadata = cm spec_decode_common_attn_metadata = cm
#spec_decode_common_attn_metadata = cm
else:
if mla_cp_enable:
spec_decode_common_attn_metadata = cm.cp_common_metadata
else: else:
spec_decode_common_attn_metadata = cm spec_decode_common_attn_metadata = cm
#spec_decode_common_attn_metadata = cm
for attn_gid in range(len(self.attn_groups[kv_cache_gid])): for attn_gid in range(len(self.attn_groups[kv_cache_gid])):
if ubatch_slices is not None: if ubatch_slices is not None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment