Unverified Commit e879d8b7 authored by Cheng Wan, committed by GitHub

[Feature] Comprehensive Hybrid Parallelism Support (#6389)

parent 09988080
@@ -71,6 +71,8 @@ from sglang.srt.utils import (
     configure_logger,
     get_bool_env_var,
     kill_process_tree,
+    require_mlp_sync,
+    require_mlp_tp_gather,
     set_gpu_proc_affinity,
     suppress_other_loggers,
 )
@@ -243,7 +245,7 @@ def extend(reqs, model_runner):
         enable_custom_logit_processor=False,
     )
     batch.prepare_for_extend()
-    _maybe_prepare_dp_attn_batch(batch, model_runner)
+    _maybe_prepare_mlp_sync_batch(batch, model_runner)
     model_worker_batch = batch.get_model_worker_batch()
     forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
     logits_output, _ = model_runner.forward(forward_batch)
@@ -255,7 +257,7 @@ def extend(reqs, model_runner):
 def decode(input_token_ids, batch, model_runner):
     batch.output_ids = input_token_ids
     batch.prepare_for_decode()
-    _maybe_prepare_dp_attn_batch(batch, model_runner)
+    _maybe_prepare_mlp_sync_batch(batch, model_runner)
     model_worker_batch = batch.get_model_worker_batch()
     forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
     logits_output, _ = model_runner.forward(forward_batch)
@@ -263,18 +265,18 @@ def decode(input_token_ids, batch, model_runner):
     return next_token_ids, logits_output.next_token_logits


-def _maybe_prepare_dp_attn_batch(batch: ScheduleBatch, model_runner):
-    if model_runner.server_args.enable_dp_attention:
-        Scheduler.prepare_dp_attn_batch_raw(
+def _maybe_prepare_mlp_sync_batch(batch: ScheduleBatch, model_runner):
+    if require_mlp_sync(model_runner.server_args):
+        Scheduler.prepare_mlp_sync_batch_raw(
             batch,
             dp_size=model_runner.server_args.dp_size,
             attn_tp_size=1,
-            moe_dense_tp_size=model_runner.server_args.moe_dense_tp_size,
             tp_cpu_group=model_runner.tp_group.cpu_group,
             get_idle_batch=None,
             disable_cuda_graph=model_runner.server_args.disable_cuda_graph,
             spec_algorithm=SpeculativeAlgorithm.NONE,
             speculative_num_draft_tokens=None,
+            require_mlp_tp_gather=require_mlp_tp_gather(model_runner.server_args),
         )
...
@@ -55,6 +55,7 @@ from sglang.srt.mem_cache.memory_pool import (
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
+from sglang.srt.utils import require_mlp_sync

 logger = logging.getLogger(__name__)
@@ -649,10 +650,7 @@ class SchedulerDisaggregationDecodeMixin:
             batch = self.get_next_disagg_decode_batch_to_run()
             self.cur_batch = batch

-            prepare_dp_attn_flag = (
-                self.server_args.enable_dp_attention
-                or self.server_args.enable_sp_layernorm
-            )
+            prepare_mlp_sync_flag = require_mlp_sync(self.server_args)

             if batch:
                 # Generate fake extend output.
@@ -661,14 +659,14 @@
                     self.stream_output(
                         batch.reqs, any(req.return_logprob for req in batch.reqs)
                     )
-                    if prepare_dp_attn_flag:
+                    if prepare_mlp_sync_flag:
                         self._prepare_idle_batch_and_run(None)
                 else:
-                    if prepare_dp_attn_flag:
-                        self.prepare_dp_attn_batch(batch)
+                    if prepare_mlp_sync_flag:
+                        self.prepare_mlp_sync_batch(batch)
                     result = self.run_batch(batch)
                     self.process_batch_result(batch, result)
-            elif prepare_dp_attn_flag:
+            elif prepare_mlp_sync_flag:
                 batch, _ = self._prepare_idle_batch_and_run(None)

             if batch is None and (
@@ -699,10 +697,7 @@
             self.cur_batch = batch
             last_batch_in_queue = False

-            prepare_dp_attn_flag = (
-                self.server_args.enable_dp_attention
-                or self.server_args.enable_sp_layernorm
-            )
+            prepare_mlp_sync_flag = require_mlp_sync(self.server_args)

             if batch:
                 # Generate fake extend output.
@@ -711,7 +706,7 @@
                     self.stream_output(
                         batch.reqs, any(req.return_logprob for req in batch.reqs)
                     )
-                    if prepare_dp_attn_flag:
+                    if prepare_mlp_sync_flag:
                         batch_, result = self._prepare_idle_batch_and_run(
                             None, delay_process=True
                         )
@@ -719,8 +714,8 @@
                             result_queue.append((batch_.copy(), result))
                             last_batch_in_queue = True
                 else:
-                    if prepare_dp_attn_flag:
-                        self.prepare_dp_attn_batch(batch)
+                    if prepare_mlp_sync_flag:
+                        self.prepare_mlp_sync_batch(batch)
                     result = self.run_batch(batch)
                     result_queue.append((batch.copy(), result))
@@ -735,7 +730,7 @@
                         self.set_next_batch_sampling_info_done(tmp_batch)
                     last_batch_in_queue = True
-            elif prepare_dp_attn_flag:
+            elif prepare_mlp_sync_flag:
                 batch, result = self._prepare_idle_batch_and_run(
                     None, delay_process=True
                 )
@@ -765,13 +760,13 @@
             self.last_batch = batch
             self.last_batch_in_queue = last_batch_in_queue

-    def _prepare_idle_batch_and_run(self, batch, delay_process=False):
-        batch, _ = self.prepare_dp_attn_batch(batch)
+    def _prepare_idle_batch_and_run(self: Scheduler, batch, delay_process=False):
+        batch, _ = self.prepare_mlp_sync_batch(batch)
         result = None
         if batch:
             result = self.run_batch(batch)
             if not delay_process:
                 self.process_batch_result(batch, result)
         return batch, result

     def get_next_disagg_decode_batch_to_run(
...
@@ -45,6 +45,7 @@ from sglang.srt.disaggregation.utils import (
 )
 from sglang.srt.managers.schedule_batch import FINISH_LENGTH, Req, ScheduleBatch
 from sglang.srt.model_executor.forward_batch_info import ForwardMode
+from sglang.srt.utils import require_mlp_sync

 if TYPE_CHECKING:
     from torch.distributed import ProcessGroup
@@ -274,12 +275,8 @@ class SchedulerDisaggregationPrefillMixin:
             self.process_prefill_chunk()
             batch = self.get_new_batch_prefill()

-            # Handle DP attention
-            if (
-                self.server_args.enable_dp_attention
-                or self.server_args.enable_sp_layernorm
-            ):
-                batch, _ = self.prepare_dp_attn_batch(batch)
+            if require_mlp_sync(self.server_args):
+                batch, _ = self.prepare_mlp_sync_batch(batch)

             self.cur_batch = batch
             if batch:
@@ -312,12 +309,8 @@ class SchedulerDisaggregationPrefillMixin:
             self.process_prefill_chunk()
             batch = self.get_new_batch_prefill()

-            # Handle DP attention
-            if (
-                self.server_args.enable_dp_attention
-                or self.server_args.enable_sp_layernorm
-            ):
-                batch, _ = self.prepare_dp_attn_batch(batch)
+            if require_mlp_sync(self.server_args):
+                batch, _ = self.prepare_mlp_sync_batch(batch)

             self.cur_batch = batch
             if batch:
                 result = self.run_batch(batch)
...
@@ -28,9 +28,9 @@ from sglang.srt.layers.dp_attention import (
     attn_tp_reduce_scatter,
     dp_gather_partial,
     dp_scatter,
+    get_attention_dp_size,
     get_attention_tp_rank,
     get_attention_tp_size,
-    get_local_attention_dp_size,
 )
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
@@ -229,7 +229,7 @@ class CommunicateContext:
     process_group_sizes: Dict[ScatterMode, int]
     attn_tp_rank: int
     attn_tp_size: int
-    local_attn_dp_size: int
+    attn_dp_size: int
     tp_size: int

     def is_same_group_size(self, a: ScatterMode, b: ScatterMode):
@@ -239,7 +239,7 @@
     def init_new(cls):
         attn_tp_rank = get_attention_tp_rank()
         attn_tp_size = get_attention_tp_size()
-        local_attn_dp_size = get_local_attention_dp_size()
+        attn_dp_size = get_attention_dp_size()
         tp_size = get_tensor_model_parallel_world_size()
         process_group_sizes = {
             ScatterMode.SCATTERED: 1,
@@ -251,7 +251,7 @@
             process_group_sizes=process_group_sizes,
             attn_tp_rank=attn_tp_rank,
             attn_tp_size=attn_tp_size,
-            local_attn_dp_size=local_attn_dp_size,
+            attn_dp_size=attn_dp_size,
             tp_size=tp_size,
         )
@@ -385,7 +385,7 @@ class CommunicateWithAllReduceAndLayerNormFn:
             attn_tp_all_gather(
                 list(residual.tensor_split(context.attn_tp_size)), local_residual
             )
-        if context.local_attn_dp_size != 1:
+        if context.attn_dp_size != 1:
             if context.attn_tp_rank == 0:
                 hidden_states += residual
             hidden_states, local_hidden_states = (
...
@@ -165,7 +165,8 @@ def disable_dp_size():
 def get_dp_local_info(forward_batch: ForwardBatch):
-    dp_rank = get_local_attention_dp_rank()
+    # `get_dp_local_info` is only called in global DP gather and scatter. We use global DP rank here.
+    dp_rank = get_attention_dp_rank()

     if forward_batch.dp_local_start_pos is None:
         cumtokens = torch.cumsum(forward_batch.global_num_tokens_gpu, dim=0)
...
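A rough illustration (not part of the commit) of the bookkeeping this function performs: each attention DP rank derives its local slice of the gathered token tensor from a cumulative sum over the global token counts. The counts below are invented.

# Sketch only: how a rank's [start, length) slice falls out of torch.cumsum,
# mirroring the dp_local_start_pos / dp_local_num_tokens logic in get_dp_local_info.
import torch

global_num_tokens = torch.tensor([5, 3, 7, 2])      # hypothetical tokens per attention DP rank
cumtokens = torch.cumsum(global_num_tokens, dim=0)  # tensor([ 5,  8, 15, 17])

dp_rank = 2  # the global attention DP rank, as returned by the new get_attention_dp_rank()
dp_local_start_pos = 0 if dp_rank == 0 else int(cumtokens[dp_rank - 1])
dp_local_num_tokens = int(global_num_tokens[dp_rank])

# Rank 2 owns rows [8, 15) of the globally gathered hidden-state tensor.
print(dp_local_start_pos, dp_local_num_tokens)  # 8 7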
@@ -30,9 +30,9 @@ from sglang.srt.layers.dp_attention import (
     attn_tp_all_gather,
     dp_gather_replicate,
     dp_scatter,
+    get_attention_dp_rank,
     get_attention_dp_size,
     get_attention_tp_size,
-    get_local_attention_dp_rank,
     get_local_attention_dp_size,
 )
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
@@ -171,7 +171,7 @@ class LogitsMetadata:
             return
         cumtokens = torch.cumsum(self.global_num_tokens_for_logprob_gpu, dim=0)
-        dp_rank = get_local_attention_dp_rank()
+        dp_rank = get_attention_dp_rank()
         if dp_rank == 0:
             dp_local_start_pos = torch.zeros_like(
                 self.global_num_tokens_for_logprob_gpu[0]
...
@@ -149,6 +149,8 @@ from sglang.srt.utils import (
     kill_itself_when_parent_died,
     point_to_point_pyobj,
     pyspy_dump_schedulers,
+    require_mlp_sync,
+    require_mlp_tp_gather,
     set_gpu_proc_affinity,
     set_random_seed,
     suppress_other_loggers,
@@ -1471,9 +1473,8 @@ class Scheduler(
         else:
             ret = None

-        # Handle DP attention
-        if self.server_args.enable_dp_attention or self.server_args.enable_sp_layernorm:
-            ret, _ = self.prepare_dp_attn_batch(ret)
+        if require_mlp_sync(self.server_args):
+            ret, _ = self.prepare_mlp_sync_batch(ret)

         return ret
@@ -1775,12 +1776,11 @@
             self.return_health_check_ct -= 1
             self.send_to_tokenizer.send_pyobj(HealthCheckOutput())

-    def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
-        return self.prepare_dp_attn_batch_raw(
+    def prepare_mlp_sync_batch(self, local_batch: ScheduleBatch):
+        return self.prepare_mlp_sync_batch_raw(
             local_batch,
             dp_size=self.server_args.dp_size,
             attn_tp_size=self.attn_tp_size,
-            moe_dense_tp_size=self.server_args.moe_dense_tp_size,
             tp_cpu_group=self.tp_cpu_group,
             get_idle_batch=self.get_idle_batch,
             disable_cuda_graph=self.server_args.disable_cuda_graph,
@@ -1789,14 +1789,14 @@
             enable_two_batch_overlap=self.server_args.enable_two_batch_overlap,
             enable_deepep_moe=self.server_args.enable_deepep_moe,
             deepep_mode=DeepEPMode[self.server_args.deepep_mode],
+            require_mlp_tp_gather=require_mlp_tp_gather(self.server_args),
         )

     @staticmethod
-    def prepare_dp_attn_batch_raw(
+    def prepare_mlp_sync_batch_raw(
         local_batch: ScheduleBatch,
         dp_size,
         attn_tp_size: int,
-        moe_dense_tp_size: Optional[int],
         tp_cpu_group,
         get_idle_batch,
         disable_cuda_graph: bool,
@@ -1805,6 +1805,7 @@
         enable_two_batch_overlap: bool,
         enable_deepep_moe: bool,
         deepep_mode: DeepEPMode,
+        require_mlp_tp_gather: bool,
     ):
         # Check if other DP workers have running batches
         if local_batch is None:
@@ -1879,7 +1880,7 @@
         if local_batch is not None:
             # TODO: handle the case when moe_dense_tp_size != 1
-            if moe_dense_tp_size == 1 and global_server_args_dict["enable_dp_lm_head"]:
+            if not require_mlp_tp_gather:
                 local_batch.global_num_tokens = [num_tokens]
                 local_batch.global_num_tokens_for_logprob = [num_tokens_for_logprob]
             else:
...
@@ -46,6 +46,9 @@ from sglang.srt.utils import (
     get_available_gpu_memory,
     get_device_memory_capacity,
     rank0_log,
+    require_attn_tp_gather,
+    require_gathered_buffer,
+    require_mlp_tp_gather,
 )

 logger = logging.getLogger(__name__)
@@ -207,8 +210,9 @@ class CudaGraphRunner:
         self.enable_torch_compile = model_runner.server_args.enable_torch_compile
         self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
         self.is_encoder_decoder = model_runner.model_config.is_encoder_decoder
-        self.enable_dp_attention = model_runner.server_args.enable_dp_attention
-        self.enable_sp_layernorm = model_runner.server_args.enable_sp_layernorm
+        self.require_gathered_buffer = require_gathered_buffer(model_runner.server_args)
+        self.require_mlp_tp_gather = require_mlp_tp_gather(model_runner.server_args)
+        self.require_attn_tp_gather = require_attn_tp_gather(model_runner.server_args)
         self.enable_two_batch_overlap = (
             model_runner.server_args.enable_two_batch_overlap
         )
@@ -299,18 +303,28 @@
         else:
            self.encoder_lens = None

-        if self.enable_dp_attention or self.enable_sp_layernorm:
-            # TODO(ch-wan): SP layernorm should use a different logic to manage gathered_buffer
-            self.gathered_buffer = torch.zeros(
-                (
-                    self.max_bs * self.dp_size * self.num_tokens_per_bs,
-                    self.model_runner.model_config.hidden_size,
-                ),
-                dtype=self.model_runner.dtype,
-            )
-            self.global_num_tokens_gpu = torch.zeros(
-                (self.dp_size,), dtype=torch.int32
-            )
+        if self.require_gathered_buffer:
+            if self.require_mlp_tp_gather:
+                self.gathered_buffer = torch.zeros(
+                    (
+                        self.max_bs * self.dp_size * self.num_tokens_per_bs,
+                        self.model_runner.model_config.hidden_size,
+                    ),
+                    dtype=self.model_runner.dtype,
+                )
+                self.global_num_tokens_gpu = torch.zeros(
+                    (self.dp_size,), dtype=torch.int32
+                )
+            else:
+                assert self.require_attn_tp_gather
+                self.gathered_buffer = torch.zeros(
+                    (
+                        self.max_bs * self.num_tokens_per_bs,
+                        self.model_runner.model_config.hidden_size,
+                    ),
+                    dtype=self.model_runner.dtype,
+                )
+                self.global_num_tokens_gpu = torch.zeros((1,), dtype=torch.int32)

         # Capture
         try:
@@ -322,7 +336,7 @@
         )

     def can_run(self, forward_batch: ForwardBatch):
-        if self.enable_dp_attention or self.enable_sp_layernorm:
+        if self.require_mlp_tp_gather:
             total_batch_size = (
                 sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
                 if self.model_runner.spec_algorithm.is_eagle()
@@ -459,7 +473,7 @@
                 {k: v[:num_tokens] for k, v in self.pp_proxy_tensors.items()}
             )

-        if self.enable_dp_attention or self.enable_sp_layernorm:
+        if self.require_mlp_tp_gather:
             self.global_num_tokens_gpu.copy_(
                 torch.tensor(
                     [
@@ -472,6 +486,16 @@
                 )
             )
             global_num_tokens = self.global_num_tokens_gpu
             gathered_buffer = self.gathered_buffer[:num_tokens]
+        elif self.require_attn_tp_gather:
+            self.global_num_tokens_gpu.copy_(
+                torch.tensor(
+                    [num_tokens],
+                    dtype=torch.int32,
+                    device=input_ids.device,
+                )
+            )
+            global_num_tokens = self.global_num_tokens_gpu
+            gathered_buffer = self.gathered_buffer[:num_tokens]
         else:
             global_num_tokens = None
             gathered_buffer = None
@@ -607,7 +631,7 @@
         raw_num_token = raw_bs * self.num_tokens_per_bs

         # Pad
-        if self.enable_dp_attention or self.enable_sp_layernorm:
+        if self.require_mlp_tp_gather:
             total_batch_size = (
                 sum(forward_batch.global_num_tokens_cpu) / self.num_tokens_per_bs
                 if self.model_runner.spec_algorithm.is_eagle()
@@ -642,7 +666,7 @@
             self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
         if forward_batch.mrope_positions is not None:
             self.mrope_positions[:, :raw_bs].copy_(forward_batch.mrope_positions)
-        if self.enable_dp_attention or self.enable_sp_layernorm:
+        if self.require_gathered_buffer:
             self.global_num_tokens_gpu.copy_(forward_batch.global_num_tokens_gpu)
         if enable_num_token_non_padded(self.model_runner.server_args):
             self.num_token_non_padded.copy_(forward_batch.num_token_non_padded)
...
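To make the new allocation branch concrete, here is a standalone sketch (not from the commit; the sizes are hypothetical) of the two `gathered_buffer` shapes the runner now sets up:

# Sketch with made-up sizes: the buffer when MLP input is gathered across DP ranks
# (require_mlp_tp_gather) vs. when only attention-TP gathering is needed.
import torch

max_bs, dp_size, num_tokens_per_bs, hidden_size = 160, 2, 1, 4096

# MLP TP gather: rows for every DP rank's tokens live in one shared buffer,
# and the token-count tensor has one entry per DP rank.
mlp_tp_gather_buffer = torch.zeros(
    (max_bs * dp_size * num_tokens_per_bs, hidden_size), dtype=torch.bfloat16
)
global_num_tokens_mlp = torch.zeros((dp_size,), dtype=torch.int32)

# Attention TP gather only: tokens stay local to the rank, so dp_size drops out
# of the shape and the token-count tensor collapses to a single entry.
attn_tp_gather_buffer = torch.zeros(
    (max_bs * num_tokens_per_bs, hidden_size), dtype=torch.bfloat16
)
global_num_tokens_attn = torch.zeros((1,), dtype=torch.int32)

print(mlp_tp_gather_buffer.shape, attn_tp_gather_buffer.shape)  # torch.Size([320, 4096]) torch.Size([160, 4096])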
@@ -1621,8 +1621,6 @@ class DeepseekV2Model(nn.Module):
         )
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

-        self.dp_size = get_local_attention_dp_size()

     def get_input_embeddings(self) -> torch.Tensor:
         return self.embed_tokens
@@ -1706,7 +1704,6 @@ class DeepseekV2ForCausalLM(nn.Module):
             use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
         )
         self.logits_processor = LogitsProcessor(config)
-        self.dp_size = get_local_attention_dp_size()

         self._routed_experts_weights_of_layer = LazyValue(
             lambda: {
...
@@ -387,7 +387,6 @@ class ServerArgs:
         ), "Please enable dp attention when setting enable_dp_attention. "

         # DeepEP MoE
-        self.enable_sp_layernorm = False
         if self.enable_deepep_moe:
             if self.deepep_mode == "auto":
                 assert (
@@ -397,9 +396,6 @@
                 logger.warning("Cuda graph is disabled because deepep_mode=`normal`")
                 self.disable_cuda_graph = True
             self.ep_size = self.tp_size
-            self.enable_sp_layernorm = (
-                self.dp_size < self.tp_size if self.enable_dp_attention else True
-            )
             logger.warning(
                 f"DeepEP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]."
             )
...
@@ -20,6 +20,11 @@ from sglang.srt.model_executor.forward_batch_info import (
     ForwardMode,
 )
 from sglang.srt.speculative.eagle_utils import EagleDraftInput
+from sglang.srt.utils import (
+    require_attn_tp_gather,
+    require_gathered_buffer,
+    require_mlp_tp_gather,
+)

 if TYPE_CHECKING:
     from sglang.srt.speculative.eagle_worker import EAGLEWorker
@@ -39,8 +44,9 @@ class EAGLEDraftCudaGraphRunner:
         self.enable_torch_compile = model_runner.server_args.enable_torch_compile
         self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
         self.is_encoder_decoder = model_runner.model_config.is_encoder_decoder
-        self.enable_dp_attention = model_runner.server_args.enable_dp_attention
-        self.enable_sp_layernorm = model_runner.server_args.enable_sp_layernorm
+        self.require_gathered_buffer = require_gathered_buffer(model_runner.server_args)
+        self.require_mlp_tp_gather = require_mlp_tp_gather(model_runner.server_args)
+        self.require_attn_tp_gather = require_attn_tp_gather(model_runner.server_args)
         self.dp_size = self.model_runner.dp_size
         self.tp_size = self.model_runner.tp_size
         self.topk = model_runner.server_args.speculative_eagle_topk
@@ -88,8 +94,7 @@
             dtype=self.model_runner.dtype,
         )

-        if self.enable_dp_attention or self.enable_sp_layernorm:
-            # TODO(ch-wan): SP layernorm should use a different logic to manage gathered_buffer
+        if self.require_gathered_buffer:
             self.gathered_buffer = torch.zeros(
                 (
                     self.max_num_token,
@@ -97,12 +102,19 @@
                 ),
                 dtype=self.model_runner.dtype,
             )
-            self.global_num_tokens_gpu = torch.zeros(
-                (self.dp_size,), dtype=torch.int32
-            )
-            self.global_num_tokens_for_logprob_gpu = torch.zeros(
-                (self.dp_size,), dtype=torch.int32
-            )
+            if self.require_mlp_tp_gather:
+                self.global_num_tokens_gpu = torch.zeros(
+                    (self.dp_size,), dtype=torch.int32
+                )
+                self.global_num_tokens_for_logprob_gpu = torch.zeros(
+                    (self.dp_size,), dtype=torch.int32
+                )
+            else:
+                assert self.require_attn_tp_gather
+                self.global_num_tokens_gpu = torch.zeros((1,), dtype=torch.int32)
+                self.global_num_tokens_for_logprob_gpu = torch.zeros(
+                    (1,), dtype=torch.int32
+                )

         # Capture
         try:
@@ -114,8 +126,7 @@
         )

     def can_run(self, forward_batch: ForwardBatch):
-        if self.enable_dp_attention:
-            # TODO(ch-wan): check --moe-dense-tp-size and --enable-dp-lm-head
+        if self.require_mlp_tp_gather:
             if not forward_batch.can_run_dp_cuda_graph:
                 return False
             total_batch_size = (
@@ -153,7 +164,7 @@
         topk_index = self.topk_index[:num_seqs]
         hidden_states = self.hidden_states[:num_seqs]

-        if self.enable_dp_attention or self.enable_sp_layernorm:
+        if self.require_mlp_tp_gather:
             self.global_num_tokens_gpu.copy_(
                 torch.tensor(
                     [
@@ -177,6 +188,24 @@
             global_num_tokens = self.global_num_tokens_gpu
             gathered_buffer = self.gathered_buffer[:num_tokens]
             global_num_tokens_for_logprob = self.global_num_tokens_for_logprob_gpu
+        elif self.require_attn_tp_gather:
+            self.global_num_tokens_gpu.copy_(
+                torch.tensor(
+                    [num_tokens],
+                    dtype=torch.int32,
+                    device=self.input_ids.device,
+                )
+            )
+            self.global_num_tokens_for_logprob_gpu.copy_(
+                torch.tensor(
+                    [num_tokens],
+                    dtype=torch.int32,
+                    device=self.input_ids.device,
+                )
+            )
+            global_num_tokens = self.global_num_tokens_gpu
+            gathered_buffer = self.gathered_buffer[:num_tokens]
+            global_num_tokens_for_logprob = self.global_num_tokens_for_logprob_gpu
         else:
             global_num_tokens = None
             gathered_buffer = None
@@ -259,7 +288,7 @@
         raw_num_token = raw_bs * self.num_tokens_per_bs

         # Pad
-        if self.enable_dp_attention or self.enable_sp_layernorm:
+        if self.require_mlp_tp_gather:
             total_batch_size = (
                 sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
                 if self.model_runner.spec_algorithm.is_eagle()
@@ -286,7 +315,7 @@
         self.topk_index[:raw_bs].copy_(forward_batch.spec_info.topk_index)
         self.hidden_states[:raw_bs].copy_(forward_batch.spec_info.hidden_states)

-        if self.enable_dp_attention or self.enable_sp_layernorm:
+        if self.require_gathered_buffer:
             self.global_num_tokens_gpu.copy_(forward_batch.global_num_tokens_gpu)
             self.global_num_tokens_for_logprob_gpu.copy_(
                 forward_batch.global_num_tokens_for_logprob_gpu
...
@@ -21,6 +21,11 @@ from sglang.srt.model_executor.forward_batch_info import (
     ForwardMode,
 )
 from sglang.srt.speculative.eagle_utils import EagleDraftInput, fast_topk
+from sglang.srt.utils import (
+    require_attn_tp_gather,
+    require_gathered_buffer,
+    require_mlp_tp_gather,
+)

 if TYPE_CHECKING:
     from sglang.srt.speculative.eagle_worker import EAGLEWorker
@@ -35,8 +40,9 @@ class EAGLEDraftExtendCudaGraphRunner:
         self.output_buffers = {}
         self.enable_torch_compile = model_runner.server_args.enable_torch_compile
         self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
-        self.enable_dp_attention = model_runner.server_args.enable_dp_attention
-        self.enable_sp_layernorm = model_runner.server_args.enable_sp_layernorm
+        self.require_gathered_buffer = require_gathered_buffer(model_runner.server_args)
+        self.require_mlp_tp_gather = require_mlp_tp_gather(model_runner.server_args)
+        self.require_attn_tp_gather = require_attn_tp_gather(model_runner.server_args)
         self.tp_size = self.model_runner.tp_size
         self.dp_size = model_runner.server_args.dp_size
         self.speculative_num_steps = model_runner.server_args.speculative_num_steps
@@ -92,7 +98,7 @@
             (self.max_bs,), self.num_tokens_per_bs, dtype=torch.int32
         )

-        if self.enable_dp_attention or self.enable_sp_layernorm:
+        if self.require_gathered_buffer:
             self.gathered_buffer = torch.zeros(
                 (
                     self.max_num_token,
@@ -100,13 +106,19 @@
                 ),
                 dtype=self.model_runner.dtype,
             )
-            self.global_num_tokens_gpu = torch.zeros(
-                (self.dp_size,), dtype=torch.int32
-            )
-            self.global_num_tokens_for_logprob_gpu = torch.zeros(
-                (self.dp_size,), dtype=torch.int32
-            )
+            if self.require_mlp_tp_gather:
+                self.global_num_tokens_gpu = torch.zeros(
+                    (self.dp_size,), dtype=torch.int32
+                )
+                self.global_num_tokens_for_logprob_gpu = torch.zeros(
+                    (self.dp_size,), dtype=torch.int32
+                )
+            else:
+                assert self.require_attn_tp_gather
+                self.global_num_tokens_gpu = torch.zeros((1,), dtype=torch.int32)
+                self.global_num_tokens_for_logprob_gpu = torch.zeros(
+                    (1,), dtype=torch.int32
+                )

         # Capture
         try:
             with model_capture_mode():
@@ -117,7 +129,7 @@
         )

     def can_run(self, forward_batch: ForwardBatch):
-        if self.enable_dp_attention or self.enable_sp_layernorm:
+        if self.require_mlp_tp_gather:
             if not forward_batch.can_run_dp_cuda_graph:
                 return False
             total_batch_size = (
@@ -160,7 +172,7 @@
         positions = self.positions[:num_tokens]
         hidden_states = self.hidden_states[:num_tokens]

-        if self.enable_dp_attention or self.enable_sp_layernorm:
+        if self.require_mlp_tp_gather:
             self.global_num_tokens_gpu.copy_(
                 torch.tensor(
                     [
@@ -184,6 +196,24 @@
             global_num_tokens = self.global_num_tokens_gpu
             gathered_buffer = self.gathered_buffer[:num_tokens]
             global_num_tokens_for_logprob = self.global_num_tokens_for_logprob_gpu
+        elif self.require_attn_tp_gather:
+            self.global_num_tokens_gpu.copy_(
+                torch.tensor(
+                    [num_tokens],
+                    dtype=torch.int32,
+                    device=self.input_ids.device,
+                )
+            )
+            self.global_num_tokens_for_logprob_gpu.copy_(
+                torch.tensor(
+                    [num_tokens],
+                    dtype=torch.int32,
+                    device=self.input_ids.device,
+                )
+            )
+            global_num_tokens = self.global_num_tokens_gpu
+            gathered_buffer = self.gathered_buffer[:num_tokens]
+            global_num_tokens_for_logprob = self.global_num_tokens_for_logprob_gpu
         else:
             global_num_tokens = None
             gathered_buffer = None
@@ -270,7 +300,7 @@
         # in the batch, which will not be counted as num_seqs
         raw_bs = forward_batch.batch_size
         num_tokens = forward_batch.input_ids.shape[0]

-        if self.enable_dp_attention or self.enable_sp_layernorm:
+        if self.require_mlp_tp_gather:
             total_batch_size = (
                 sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
                 if self.model_runner.spec_algorithm.is_eagle()
@@ -299,7 +329,7 @@
         self.accept_length[:raw_bs].copy_(forward_batch.spec_info.accept_length)
         self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices)

-        if self.enable_dp_attention or self.enable_sp_layernorm:
+        if self.require_gathered_buffer:
             self.global_num_tokens_gpu.copy_(forward_batch.global_num_tokens_gpu)
             self.global_num_tokens_for_logprob_gpu.copy_(
                 forward_batch.global_num_tokens_for_logprob_gpu
...
@@ -2303,6 +2303,51 @@ class Withable(Generic[T]):
         self._value = None
+
+
+def require_mlp_tp_gather(server_args):
+    """
+    Check if the input of MLP is obtained by all-gather rather than all-reduce. This only happens when each MLP TP group contains multiple attention DP groups.
+    """
+    if server_args.enable_dp_attention:
+        assert server_args.dp_size > 1, "dp_size must be greater than 1"
+        if (
+            server_args.moe_dense_tp_size is None
+        ):  # TODO(ch-wan): some MoE models do not have dense layers
+            return True
+        elif not server_args.enable_dp_lm_head:
+            return True
+        elif not server_args.enable_deepep_moe:
+            return True
+        else:
+            return (
+                server_args.moe_dense_tp_size
+                > server_args.tp_size // server_args.dp_size
+            )
+    else:
+        return False
+
+
+def require_attn_tp_gather(server_args):
+    """
+    Check if the input of attention is scattered.
+    """
+    assert server_args.moe_dense_tp_size in [1, None]
+    if server_args.enable_deepep_moe or server_args.moe_dense_tp_size == 1:
+        if server_args.enable_dp_attention:
+            return server_args.dp_size < server_args.tp_size
+        else:
+            return True
+    else:
+        return False
+
+
+def require_gathered_buffer(server_args):
+    return require_mlp_tp_gather(server_args) or require_attn_tp_gather(server_args)
+
+
+def require_mlp_sync(server_args):
+    return server_args.enable_dp_attention or require_gathered_buffer(server_args)


 def merge_bias_tensor(
     lhs: Optional[torch.Tensor],
     rhs: Optional[torch.Tensor],
...
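A minimal sketch of how the new predicates compose, assuming only that the duck-typed stand-in below exposes the fields the functions read; the parallelism values are invented for illustration:

# Sketch only: exercise the new predicates with a SimpleNamespace standing in for ServerArgs.
from types import SimpleNamespace

from sglang.srt.utils import (
    require_attn_tp_gather,
    require_gathered_buffer,
    require_mlp_sync,
    require_mlp_tp_gather,
)

# Hypothetical hybrid setup: tp_size=8 with DP attention over 2 groups,
# dense MLP kept at TP=1, DP LM head and DeepEP MoE enabled.
args = SimpleNamespace(
    enable_dp_attention=True,
    dp_size=2,
    tp_size=8,
    moe_dense_tp_size=1,
    enable_dp_lm_head=True,
    enable_deepep_moe=True,
)

print(require_mlp_tp_gather(args))    # False: moe_dense_tp_size (1) <= tp_size // dp_size (4)
print(require_attn_tp_gather(args))   # True: dp_size (2) < tp_size (8), so attention input is scattered
print(require_gathered_buffer(args))  # True: either gather condition suffices
print(require_mlp_sync(args))         # True: DP attention alone already requires synced batches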