Commit f3d1f95b authored by 王敏's avatar 王敏
Browse files

[Feat]添加CPLB功能,支持PCP模式下负载均衡

parent 20254503
...@@ -201,6 +201,12 @@ class MultiHeadLatentAttentionWrapper(PluggableLayer): ...@@ -201,6 +201,12 @@ class MultiHeadLatentAttentionWrapper(PluggableLayer):
k_pe.contiguous(), 0 k_pe.contiguous(), 0
) )
gather_indexes_tensor = get_forward_context().gather_indexes_tensor
if envs.VLLM_MLA_CPLB and gather_indexes_tensor is not None:
# Reorder kv after pcp allgather.
kv_c_normed = torch.index_select(kv_c_normed, 0, gather_indexes_tensor)
k_pe = torch.index_select(k_pe, 0, gather_indexes_tensor)
attn_out = self.mla_attn( attn_out = self.mla_attn(
q, q,
kv_c_normed, kv_c_normed,
...@@ -244,6 +250,11 @@ class MultiHeadLatentAttentionWrapper(PluggableLayer): ...@@ -244,6 +250,11 @@ class MultiHeadLatentAttentionWrapper(PluggableLayer):
k_pe = tensor_model_parallel_all_gather( k_pe = tensor_model_parallel_all_gather(
k_pe.contiguous(), 0 k_pe.contiguous(), 0
) )
gather_indexes_tensor = get_forward_context().gather_indexes_tensor
if envs.VLLM_MLA_CPLB and gather_indexes_tensor is not None:
# Reorder kv after pcp allgather.
kv_c = torch.index_select(kv_c, 0, gather_indexes_tensor)
k_pe = torch.index_select(k_pe, 0, gather_indexes_tensor)
attn_out = self.mla_attn( attn_out = self.mla_attn(
q[..., self.qk_nope_head_dim:], q[..., self.qk_nope_head_dim:],
......
...@@ -200,15 +200,25 @@ class DeepSeekMultiTokenPredictor(nn.Module): ...@@ -200,15 +200,25 @@ class DeepSeekMultiTokenPredictor(nn.Module):
current_step_idx = spec_step_idx % self.num_mtp_layers current_step_idx = spec_step_idx % self.num_mtp_layers
enable_mla_cp = get_forward_context().enable_mla_cp#envs.VLLM_MLA_CP # and not get_forward_context().draft_model enable_mla_cp = get_forward_context().enable_mla_cp#envs.VLLM_MLA_CP # and not get_forward_context().draft_model
if enable_mla_cp: if enable_mla_cp:
inputs_embeds_per_rank = torch.chunk(inputs_embeds, chunks=self.tp_size, dim=0) scatter_indexes_tensor = get_forward_context().scatter_indexes_tensor
inputs_embeds = inputs_embeds_per_rank[self.tp_rank].contiguous() if scatter_indexes_tensor is None:
inputs_embeds_per_rank = torch.chunk(inputs_embeds, chunks=self.tp_size, dim=0)
inputs_embeds = inputs_embeds_per_rank[self.tp_rank].contiguous()
previous_hidden_states_per_rank = torch.chunk(previous_hidden_states, chunks=self.tp_size, dim=0) previous_hidden_states_per_rank = torch.chunk(previous_hidden_states, chunks=self.tp_size, dim=0)
previous_hidden_states = previous_hidden_states_per_rank[self.tp_rank].contiguous() previous_hidden_states = previous_hidden_states_per_rank[self.tp_rank].contiguous()
if positions is not None: if positions is not None:
positions_per_rank = torch.chunk(positions, chunks=self.tp_size, dim=0) positions_per_rank = torch.chunk(positions, chunks=self.tp_size, dim=0)
positions = positions_per_rank[self.tp_rank].contiguous() positions = positions_per_rank[self.tp_rank].contiguous()
else:
#scatter_indexes_tensor = scatter_indexes_tensor[scatter_indexes_tensor != -1]
scatter_indexes_tensor = torch.where(scatter_indexes_tensor == -1, 0, scatter_indexes_tensor)
inputs_embeds = torch.index_select(inputs_embeds, 0, scatter_indexes_tensor)
previous_hidden_states = torch.index_select(previous_hidden_states, 0, scatter_indexes_tensor)
if positions is not None:
positions = torch.index_select(positions, 0, scatter_indexes_tensor)
hidden_states = self.layers[str(self.mtp_start_layer_idx + current_step_idx)]( hidden_states = self.layers[str(self.mtp_start_layer_idx + current_step_idx)](
input_ids, input_ids,
...@@ -220,6 +230,9 @@ class DeepSeekMultiTokenPredictor(nn.Module): ...@@ -220,6 +230,9 @@ class DeepSeekMultiTokenPredictor(nn.Module):
if enable_mla_cp: if enable_mla_cp:
hidden_states = tensor_model_parallel_all_gather(hidden_states.contiguous(), dim=0) hidden_states = tensor_model_parallel_all_gather(hidden_states.contiguous(), dim=0)
gather_indexes_tensor = get_forward_context().gather_indexes_tensor
if gather_indexes_tensor is not None:
hidden_states = torch.index_select(hidden_states, 0, gather_indexes_tensor)
return hidden_states return hidden_states
......
...@@ -855,6 +855,9 @@ class Indexer(nn.Module): ...@@ -855,6 +855,9 @@ class Indexer(nn.Module):
k = tensor_model_parallel_all_gather( k = tensor_model_parallel_all_gather(
k.contiguous(), 0 k.contiguous(), 0
) )
gather_indexes_tensor = get_forward_context().gather_indexes_tensor
if envs.VLLM_MLA_CPLB and gather_indexes_tensor is not None:
k = torch.index_select(k, 0, gather_indexes_tensor)
# we only quant q here since k quant is fused with cache insertion # we only quant q here since k quant is fused with cache insertion
if not current_platform.is_rocm() or torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938": if not current_platform.is_rocm() or torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938":
...@@ -1397,16 +1400,27 @@ class DeepseekV2Model(nn.Module): ...@@ -1397,16 +1400,27 @@ class DeepseekV2Model(nn.Module):
enable_mla_cp = get_forward_context().enable_mla_cp#envs.VLLM_MLA_CP # and not get_forward_context().draft_model enable_mla_cp = get_forward_context().enable_mla_cp#envs.VLLM_MLA_CP # and not get_forward_context().draft_model
if enable_mla_cp: if enable_mla_cp:
hidden_states_per_rank = torch.chunk(hidden_states, chunks=self.tp_size, dim=0) scatter_indexes_tensor = get_forward_context().scatter_indexes_tensor
hidden_states = hidden_states_per_rank[self.tp_rank].contiguous() if scatter_indexes_tensor is None:
hidden_states_per_rank = torch.chunk(hidden_states, chunks=self.tp_size, dim=0)
hidden_states = hidden_states_per_rank[self.tp_rank].contiguous()
if residual is not None:
residual_per_rank = torch.chunk(residual, chunks=self.tp_size, dim=0)
residual = residual_per_rank[self.tp_rank].contiguous()
if positions is not None:
positions_per_rank = torch.chunk(positions, chunks=self.tp_size, dim=0)
positions = positions_per_rank[self.tp_rank].contiguous()
else:
scatter_indexes_tensor = torch.where(scatter_indexes_tensor == -1, 0, scatter_indexes_tensor)
hidden_states = torch.index_select(hidden_states, 0, scatter_indexes_tensor)
if residual is not None: if residual is not None:
residual_per_rank = torch.chunk(residual, chunks=self.tp_size, dim=0) residual = torch.index_select(residual, 0, scatter_indexes_tensor)
residual = residual_per_rank[self.tp_rank].contiguous()
if positions is not None: if positions is not None:
positions_per_rank = torch.chunk(positions, chunks=self.tp_size, dim=0) positions = torch.index_select(positions, 0, scatter_indexes_tensor)
positions = positions_per_rank[self.tp_rank].contiguous()
# Compute llama 4 scaling once per forward pass if enabled # Compute llama 4 scaling once per forward pass if enabled
llama_4_scaling_config = getattr(self.config, "llama_4_scaling", None) llama_4_scaling_config = getattr(self.config, "llama_4_scaling", None)
...@@ -1439,6 +1453,9 @@ class DeepseekV2Model(nn.Module): ...@@ -1439,6 +1453,9 @@ class DeepseekV2Model(nn.Module):
if enable_mla_cp: if enable_mla_cp:
hidden_states = tensor_model_parallel_all_gather(hidden_states.contiguous(), dim=0) hidden_states = tensor_model_parallel_all_gather(hidden_states.contiguous(), dim=0)
gather_indexes_tensor = get_forward_context().gather_indexes_tensor
if gather_indexes_tensor is not None:
hidden_states = torch.index_select(hidden_states, 0, gather_indexes_tensor)
return hidden_states return hidden_states
......
...@@ -5159,7 +5159,7 @@ class GPUModelRunner( ...@@ -5159,7 +5159,7 @@ class GPUModelRunner(
batch_descriptor=batch_desc, batch_descriptor=batch_desc,
ubatch_slices=ubatch_slices_padded, ubatch_slices=ubatch_slices_padded,
slot_mapping=slot_mappings, slot_mapping=slot_mappings,
enable_mla_cp=envs.VLLM_MLA_CP and num_tokens_unpadded > self.mla_cp_threshould enable_mla_cp=envs.VLLM_MLA_CP and num_tokens_unpadded > self.mla_cp_threshould,
), ),
): ):
outputs = self.model( outputs = self.model(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment