Commit f3d1f95b authored by 王敏
Browse files

[Feat]添加CPLB功能,支持PCP模式下负载均衡

parent 20254503
......@@ -201,6 +201,12 @@ class MultiHeadLatentAttentionWrapper(PluggableLayer):
k_pe.contiguous(), 0
)
gather_indexes_tensor = get_forward_context().gather_indexes_tensor
if envs.VLLM_MLA_CPLB and gather_indexes_tensor is not None:
# Reorder kv after pcp allgather.
kv_c_normed = torch.index_select(kv_c_normed, 0, gather_indexes_tensor)
k_pe = torch.index_select(k_pe, 0, gather_indexes_tensor)
attn_out = self.mla_attn(
q,
kv_c_normed,
......@@ -244,6 +250,11 @@ class MultiHeadLatentAttentionWrapper(PluggableLayer):
k_pe = tensor_model_parallel_all_gather(
k_pe.contiguous(), 0
)
gather_indexes_tensor = get_forward_context().gather_indexes_tensor
if envs.VLLM_MLA_CPLB and gather_indexes_tensor is not None:
# Reorder kv after pcp allgather.
kv_c = torch.index_select(kv_c, 0, gather_indexes_tensor)
k_pe = torch.index_select(k_pe, 0, gather_indexes_tensor)
attn_out = self.mla_attn(
q[..., self.qk_nope_head_dim:],
......
......@@ -200,6 +200,8 @@ class DeepSeekMultiTokenPredictor(nn.Module):
current_step_idx = spec_step_idx % self.num_mtp_layers
enable_mla_cp = get_forward_context().enable_mla_cp#envs.VLLM_MLA_CP # and not get_forward_context().draft_model
if enable_mla_cp:
scatter_indexes_tensor = get_forward_context().scatter_indexes_tensor
if scatter_indexes_tensor is None:
inputs_embeds_per_rank = torch.chunk(inputs_embeds, chunks=self.tp_size, dim=0)
inputs_embeds = inputs_embeds_per_rank[self.tp_rank].contiguous()
......@@ -209,6 +211,14 @@ class DeepSeekMultiTokenPredictor(nn.Module):
if positions is not None:
positions_per_rank = torch.chunk(positions, chunks=self.tp_size, dim=0)
positions = positions_per_rank[self.tp_rank].contiguous()
else:
#scatter_indexes_tensor = scatter_indexes_tensor[scatter_indexes_tensor != -1]
scatter_indexes_tensor = torch.where(scatter_indexes_tensor == -1, 0, scatter_indexes_tensor)
inputs_embeds = torch.index_select(inputs_embeds, 0, scatter_indexes_tensor)
previous_hidden_states = torch.index_select(previous_hidden_states, 0, scatter_indexes_tensor)
if positions is not None:
positions = torch.index_select(positions, 0, scatter_indexes_tensor)
hidden_states = self.layers[str(self.mtp_start_layer_idx + current_step_idx)](
input_ids,
......@@ -220,6 +230,9 @@ class DeepSeekMultiTokenPredictor(nn.Module):
if enable_mla_cp:
hidden_states = tensor_model_parallel_all_gather(hidden_states.contiguous(), dim=0)
gather_indexes_tensor = get_forward_context().gather_indexes_tensor
if gather_indexes_tensor is not None:
hidden_states = torch.index_select(hidden_states, 0, gather_indexes_tensor)
return hidden_states
......
......@@ -855,6 +855,9 @@ class Indexer(nn.Module):
k = tensor_model_parallel_all_gather(
k.contiguous(), 0
)
gather_indexes_tensor = get_forward_context().gather_indexes_tensor
if envs.VLLM_MLA_CPLB and gather_indexes_tensor is not None:
k = torch.index_select(k, 0, gather_indexes_tensor)
# we only quant q here since k quant is fused with cache insertion
if not current_platform.is_rocm() or torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938":
......@@ -1397,6 +1400,8 @@ class DeepseekV2Model(nn.Module):
enable_mla_cp = get_forward_context().enable_mla_cp#envs.VLLM_MLA_CP # and not get_forward_context().draft_model
if enable_mla_cp:
scatter_indexes_tensor = get_forward_context().scatter_indexes_tensor
if scatter_indexes_tensor is None:
hidden_states_per_rank = torch.chunk(hidden_states, chunks=self.tp_size, dim=0)
hidden_states = hidden_states_per_rank[self.tp_rank].contiguous()
......@@ -1407,6 +1412,15 @@ class DeepseekV2Model(nn.Module):
if positions is not None:
positions_per_rank = torch.chunk(positions, chunks=self.tp_size, dim=0)
positions = positions_per_rank[self.tp_rank].contiguous()
else:
scatter_indexes_tensor = torch.where(scatter_indexes_tensor == -1, 0, scatter_indexes_tensor)
hidden_states = torch.index_select(hidden_states, 0, scatter_indexes_tensor)
if residual is not None:
residual = torch.index_select(residual, 0, scatter_indexes_tensor)
if positions is not None:
positions = torch.index_select(positions, 0, scatter_indexes_tensor)
# Compute llama 4 scaling once per forward pass if enabled
llama_4_scaling_config = getattr(self.config, "llama_4_scaling", None)
......@@ -1439,6 +1453,9 @@ class DeepseekV2Model(nn.Module):
if enable_mla_cp:
hidden_states = tensor_model_parallel_all_gather(hidden_states.contiguous(), dim=0)
gather_indexes_tensor = get_forward_context().gather_indexes_tensor
if gather_indexes_tensor is not None:
hidden_states = torch.index_select(hidden_states, 0, gather_indexes_tensor)
return hidden_states
......
......@@ -5159,7 +5159,7 @@ class GPUModelRunner(
batch_descriptor=batch_desc,
ubatch_slices=ubatch_slices_padded,
slot_mapping=slot_mappings,
enable_mla_cp=envs.VLLM_MLA_CP and num_tokens_unpadded > self.mla_cp_threshould
enable_mla_cp=envs.VLLM_MLA_CP and num_tokens_unpadded > self.mla_cp_threshould,
),
):
outputs = self.model(
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment