Unverified Commit 02bf31ef authored by Atream's avatar Atream Committed by GitHub
Browse files

[fix] PD disaggregation when enable mtp and tp!=dp (#7420)

parent 5ea5d221
...@@ -766,7 +766,7 @@ class SchedulerDisaggregationDecodeMixin: ...@@ -766,7 +766,7 @@ class SchedulerDisaggregationDecodeMixin:
if batch: if batch:
result = self.run_batch(batch) result = self.run_batch(batch)
if not delay_process: if not delay_process:
self.prepare_mlp_sync_batch(batch, result) self.process_batch_result(batch, result)
return batch, result return batch, result
def get_next_disagg_decode_batch_to_run( def get_next_disagg_decode_batch_to_run(
......
...@@ -310,4 +310,4 @@ def attn_tp_reduce_scatter( ...@@ -310,4 +310,4 @@ def attn_tp_reduce_scatter(
def attn_tp_all_gather(output_list: List[torch.Tensor], input_: torch.Tensor): def attn_tp_all_gather(output_list: List[torch.Tensor], input_: torch.Tensor):
return get_attention_tp_group().all_gather(input_, tensor_list=output_list) return get_attention_tp_group().all_gather(input_, output_tensor_list=output_list)
...@@ -1435,7 +1435,7 @@ class DeepseekV2DecoderLayer(nn.Module): ...@@ -1435,7 +1435,7 @@ class DeepseekV2DecoderLayer(nn.Module):
self.layer_scatter_modes = LayerScatterModes.init_new( self.layer_scatter_modes = LayerScatterModes.init_new(
layer_id=layer_id, layer_id=layer_id,
num_layers=config.num_hidden_layers, num_layers=1 if is_nextn else config.num_hidden_layers,
is_layer_sparse=self.is_layer_sparse, is_layer_sparse=self.is_layer_sparse,
is_previous_layer_sparse=is_previous_layer_sparse, is_previous_layer_sparse=is_previous_layer_sparse,
) )
...@@ -1488,6 +1488,7 @@ class DeepseekV2DecoderLayer(nn.Module): ...@@ -1488,6 +1488,7 @@ class DeepseekV2DecoderLayer(nn.Module):
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
zero_allocator: BumpAllocator, zero_allocator: BumpAllocator,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states, residual = self.layer_communicator.prepare_attn( hidden_states, residual = self.layer_communicator.prepare_attn(
hidden_states, residual, forward_batch hidden_states, residual, forward_batch
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment