Unverified commit 69f453e5 authored by Qiaolin Yu, committed by GitHub

Use device_group for all_gather when disabling overlap scheduling (#8001)

parent 3bc43c68
@@ -271,12 +271,13 @@ def _maybe_prepare_mlp_sync_batch(batch: ScheduleBatch, model_runner):
         batch,
         dp_size=model_runner.server_args.dp_size,
         attn_tp_size=1,
-        tp_cpu_group=model_runner.tp_group.cpu_group,
+        tp_group=model_runner.tp_group,
         get_idle_batch=None,
         disable_cuda_graph=model_runner.server_args.disable_cuda_graph,
         spec_algorithm=SpeculativeAlgorithm.NONE,
         speculative_num_draft_tokens=None,
         require_mlp_tp_gather=require_mlp_tp_gather(model_runner.server_args),
+        disable_overlap_schedule=model_runner.server_args.disable_overlap_schedule,
     )
@@ -1945,7 +1945,7 @@ class Scheduler(
             local_batch,
             dp_size=self.server_args.dp_size,
             attn_tp_size=self.attn_tp_size,
-            tp_cpu_group=self.tp_cpu_group,
+            tp_group=self.tp_group,
             get_idle_batch=self.get_idle_batch,
             disable_cuda_graph=self.server_args.disable_cuda_graph,
             spec_algorithm=self.spec_algorithm,
@@ -1954,6 +1954,7 @@ class Scheduler(
             enable_deepep_moe=self.server_args.enable_deepep_moe,
             deepep_mode=DeepEPMode[self.server_args.deepep_mode],
             require_mlp_tp_gather=require_mlp_tp_gather(self.server_args),
+            disable_overlap_schedule=self.server_args.disable_overlap_schedule,
         )

     @staticmethod
@@ -1961,7 +1962,7 @@ class Scheduler(
         local_batch: ScheduleBatch,
         dp_size,
         attn_tp_size: int,
-        tp_cpu_group,
+        tp_group,
         get_idle_batch,
         disable_cuda_graph: bool,
         spec_algorithm,
@@ -1970,6 +1971,7 @@ class Scheduler(
         enable_deepep_moe: bool,
         deepep_mode: DeepEPMode,
         require_mlp_tp_gather: bool,
+        disable_overlap_schedule: bool,
     ):
         # Check if other DP workers have running batches
         if local_batch is None:
@@ -2000,6 +2002,12 @@ class Scheduler(
         )
         tbo_preparer = TboDPAttentionPreparer()
+        if disable_overlap_schedule:
+            group = tp_group.device_group
+            device = tp_group.device
+        else:
+            group = tp_group.cpu_group
+            device = "cpu"
         local_info = torch.tensor(
             [
@@ -2015,15 +2023,17 @@ class Scheduler(
                 ),
             ],
             dtype=torch.int64,
+            device=device,
         )
         global_info = torch.empty(
             (dp_size, attn_tp_size, 6),
             dtype=torch.int64,
+            device=device,
         )
         torch.distributed.all_gather_into_tensor(
             global_info.flatten(),
             local_info,
-            group=tp_cpu_group,
+            group=group,
         )
         global_num_tokens = global_info[:, 0, 0].tolist()
         can_cuda_graph = min(global_info[:, 0, 1].tolist())
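
For illustration, below is a minimal, self-contained sketch of the pattern this commit applies: choose the communication group and the tensor placement together, so the metadata all_gather runs over the device group when overlap scheduling is disabled and over the CPU group otherwise. The TpGroup dataclass and the gather_local_info helper are hypothetical stand-ins, not the sglang implementation; only the attribute names (device_group, cpu_group, device) and the torch.distributed.all_gather_into_tensor call mirror the change above.

from dataclasses import dataclass

import torch
import torch.distributed as dist


@dataclass
class TpGroup:
    # Hypothetical stand-in for the tp_group object passed around in the diff.
    device_group: dist.ProcessGroup  # device-backed group (typically NCCL)
    cpu_group: dist.ProcessGroup     # host-backed group (typically gloo)
    device: str                      # e.g. "cuda:0"


def gather_local_info(
    local_values: list[int],
    tp_group: TpGroup,
    world_size: int,
    disable_overlap_schedule: bool,
) -> torch.Tensor:
    # Pick the process group and the matching tensor device in one place:
    # a NCCL-backed group requires CUDA tensors, while the CPU group
    # operates on host tensors.
    if disable_overlap_schedule:
        group, device = tp_group.device_group, tp_group.device
    else:
        group, device = tp_group.cpu_group, "cpu"

    local_info = torch.tensor(local_values, dtype=torch.int64, device=device)
    global_info = torch.empty(
        world_size * len(local_values), dtype=torch.int64, device=device
    )
    dist.all_gather_into_tensor(global_info, local_info, group=group)
    return global_info.view(world_size, len(local_values))

Pairing the group with the tensor device is what the two new device=device arguments in the last hunk do; a gather over the device group would fail on CPU tensors, and keeping everything on the device when overlap scheduling is disabled appears to be the point of switching away from the CPU group here.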