Unverified Commit 02bf31ef authored by Atream's avatar Atream Committed by GitHub
Browse files

[fix] PD disaggregation when enable mtp and tp!=dp (#7420)

parent 5ea5d221
......@@ -766,7 +766,7 @@ class SchedulerDisaggregationDecodeMixin:
if batch:
result = self.run_batch(batch)
if not delay_process:
self.prepare_mlp_sync_batch(batch, result)
self.process_batch_result(batch, result)
return batch, result
def get_next_disagg_decode_batch_to_run(
......
......@@ -310,4 +310,4 @@ def attn_tp_reduce_scatter(
def attn_tp_all_gather(output_list: List[torch.Tensor], input_: torch.Tensor):
return get_attention_tp_group().all_gather(input_, tensor_list=output_list)
return get_attention_tp_group().all_gather(input_, output_tensor_list=output_list)
......@@ -1435,7 +1435,7 @@ class DeepseekV2DecoderLayer(nn.Module):
self.layer_scatter_modes = LayerScatterModes.init_new(
layer_id=layer_id,
num_layers=config.num_hidden_layers,
num_layers=1 if is_nextn else config.num_hidden_layers,
is_layer_sparse=self.is_layer_sparse,
is_previous_layer_sparse=is_previous_layer_sparse,
)
......@@ -1488,6 +1488,7 @@ class DeepseekV2DecoderLayer(nn.Module):
residual: Optional[torch.Tensor],
zero_allocator: BumpAllocator,
) -> torch.Tensor:
hidden_states, residual = self.layer_communicator.prepare_attn(
hidden_states, residual, forward_batch
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment