Unverified commit e68a2b5b, authored by Zilin Zhu and committed by GitHub
Browse files

[RL] use cpu group to prepare_mlp_sync_batch_raw when the server is offloaded (#10152)

parent 31b9f19e
...@@ -320,6 +320,7 @@ def _maybe_prepare_mlp_sync_batch(batch: ScheduleBatch, model_runner): ...@@ -320,6 +320,7 @@ def _maybe_prepare_mlp_sync_batch(batch: ScheduleBatch, model_runner):
speculative_num_draft_tokens=None, speculative_num_draft_tokens=None,
require_mlp_tp_gather=require_mlp_tp_gather(model_runner.server_args), require_mlp_tp_gather=require_mlp_tp_gather(model_runner.server_args),
disable_overlap_schedule=model_runner.server_args.disable_overlap_schedule, disable_overlap_schedule=model_runner.server_args.disable_overlap_schedule,
offload_tags=set(),
) )
......
...@@ -2339,6 +2339,7 @@ class Scheduler( ...@@ -2339,6 +2339,7 @@ class Scheduler(
speculative_num_draft_tokens=self.server_args.speculative_num_draft_tokens, speculative_num_draft_tokens=self.server_args.speculative_num_draft_tokens,
require_mlp_tp_gather=require_mlp_tp_gather(self.server_args), require_mlp_tp_gather=require_mlp_tp_gather(self.server_args),
disable_overlap_schedule=self.server_args.disable_overlap_schedule, disable_overlap_schedule=self.server_args.disable_overlap_schedule,
offload_tags=self.offload_tags,
) )
@staticmethod @staticmethod
...@@ -2353,6 +2354,7 @@ class Scheduler( ...@@ -2353,6 +2354,7 @@ class Scheduler(
speculative_num_draft_tokens, speculative_num_draft_tokens,
require_mlp_tp_gather: bool, require_mlp_tp_gather: bool,
disable_overlap_schedule: bool, disable_overlap_schedule: bool,
offload_tags: set[str],
): ):
# Check if other DP workers have running batches # Check if other DP workers have running batches
if local_batch is None: if local_batch is None:
...@@ -2383,7 +2385,7 @@ class Scheduler( ...@@ -2383,7 +2385,7 @@ class Scheduler(
) )
tbo_preparer = TboDPAttentionPreparer() tbo_preparer = TboDPAttentionPreparer()
if disable_overlap_schedule: if len(offload_tags) == 0 and disable_overlap_schedule:
group = tp_group.device_group group = tp_group.device_group
device = tp_group.device device = tp_group.device
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment