Unverified commit e879d8b7, authored by Cheng Wan and committed by GitHub

[Feature] Comprehensive Hybrid Parallelism Support (#6389)

parent 09988080
......@@ -71,6 +71,8 @@ from sglang.srt.utils import (
configure_logger,
get_bool_env_var,
kill_process_tree,
require_mlp_sync,
require_mlp_tp_gather,
set_gpu_proc_affinity,
suppress_other_loggers,
)
......@@ -243,7 +245,7 @@ def extend(reqs, model_runner):
enable_custom_logit_processor=False,
)
batch.prepare_for_extend()
_maybe_prepare_dp_attn_batch(batch, model_runner)
_maybe_prepare_mlp_sync_batch(batch, model_runner)
model_worker_batch = batch.get_model_worker_batch()
forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
logits_output, _ = model_runner.forward(forward_batch)
......@@ -255,7 +257,7 @@ def extend(reqs, model_runner):
def decode(input_token_ids, batch, model_runner):
batch.output_ids = input_token_ids
batch.prepare_for_decode()
_maybe_prepare_dp_attn_batch(batch, model_runner)
_maybe_prepare_mlp_sync_batch(batch, model_runner)
model_worker_batch = batch.get_model_worker_batch()
forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
logits_output, _ = model_runner.forward(forward_batch)
......@@ -263,18 +265,18 @@ def decode(input_token_ids, batch, model_runner):
return next_token_ids, logits_output.next_token_logits
def _maybe_prepare_dp_attn_batch(batch: ScheduleBatch, model_runner):
if model_runner.server_args.enable_dp_attention:
Scheduler.prepare_dp_attn_batch_raw(
def _maybe_prepare_mlp_sync_batch(batch: ScheduleBatch, model_runner):
if require_mlp_sync(model_runner.server_args):
Scheduler.prepare_mlp_sync_batch_raw(
batch,
dp_size=model_runner.server_args.dp_size,
attn_tp_size=1,
moe_dense_tp_size=model_runner.server_args.moe_dense_tp_size,
tp_cpu_group=model_runner.tp_group.cpu_group,
get_idle_batch=None,
disable_cuda_graph=model_runner.server_args.disable_cuda_graph,
spec_algorithm=SpeculativeAlgorithm.NONE,
speculative_num_draft_tokens=None,
require_mlp_tp_gather=require_mlp_tp_gather(model_runner.server_args),
)
......
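Illustrative sketch (not part of the commit): the gate at the call sites above is wider than the old `enable_dp_attention` check. The snippet below probes `require_mlp_sync` (the helper added by this commit) with a `SimpleNamespace` stand-in for the real `ServerArgs`; all values are hypothetical.

from types import SimpleNamespace

from sglang.srt.utils import require_mlp_sync  # added by this commit

# Hypothetical config: DeepEP MoE with dense layers at TP=1 and no DP attention.
args = SimpleNamespace(
    enable_dp_attention=False,
    dp_size=1,
    tp_size=8,
    moe_dense_tp_size=1,
    enable_deepep_moe=True,
)

# A gathered buffer is still required, so the benchmark now prepares the sync batch
# even though DP attention is disabled.
print(require_mlp_sync(args))  # True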
......@@ -55,6 +55,7 @@ from sglang.srt.mem_cache.memory_pool import (
)
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
from sglang.srt.utils import require_mlp_sync
logger = logging.getLogger(__name__)
......@@ -649,10 +650,7 @@ class SchedulerDisaggregationDecodeMixin:
batch = self.get_next_disagg_decode_batch_to_run()
self.cur_batch = batch
prepare_dp_attn_flag = (
self.server_args.enable_dp_attention
or self.server_args.enable_sp_layernorm
)
prepare_mlp_sync_flag = require_mlp_sync(self.server_args)
if batch:
# Generate fake extend output.
......@@ -661,14 +659,14 @@ class SchedulerDisaggregationDecodeMixin:
self.stream_output(
batch.reqs, any(req.return_logprob for req in batch.reqs)
)
if prepare_dp_attn_flag:
if prepare_mlp_sync_flag:
self._prepare_idle_batch_and_run(None)
else:
if prepare_dp_attn_flag:
self.prepare_dp_attn_batch(batch)
if prepare_mlp_sync_flag:
self.prepare_mlp_sync_batch(batch)
result = self.run_batch(batch)
self.process_batch_result(batch, result)
elif prepare_dp_attn_flag:
elif prepare_mlp_sync_flag:
batch, _ = self._prepare_idle_batch_and_run(None)
if batch is None and (
......@@ -699,10 +697,7 @@ class SchedulerDisaggregationDecodeMixin:
self.cur_batch = batch
last_batch_in_queue = False
prepare_dp_attn_flag = (
self.server_args.enable_dp_attention
or self.server_args.enable_sp_layernorm
)
prepare_mlp_sync_flag = require_mlp_sync(self.server_args)
if batch:
# Generate fake extend output.
......@@ -711,7 +706,7 @@ class SchedulerDisaggregationDecodeMixin:
self.stream_output(
batch.reqs, any(req.return_logprob for req in batch.reqs)
)
if prepare_dp_attn_flag:
if prepare_mlp_sync_flag:
batch_, result = self._prepare_idle_batch_and_run(
None, delay_process=True
)
......@@ -719,8 +714,8 @@ class SchedulerDisaggregationDecodeMixin:
result_queue.append((batch_.copy(), result))
last_batch_in_queue = True
else:
if prepare_dp_attn_flag:
self.prepare_dp_attn_batch(batch)
if prepare_mlp_sync_flag:
self.prepare_mlp_sync_batch(batch)
result = self.run_batch(batch)
result_queue.append((batch.copy(), result))
......@@ -735,7 +730,7 @@ class SchedulerDisaggregationDecodeMixin:
self.set_next_batch_sampling_info_done(tmp_batch)
last_batch_in_queue = True
elif prepare_dp_attn_flag:
elif prepare_mlp_sync_flag:
batch, result = self._prepare_idle_batch_and_run(
None, delay_process=True
)
......@@ -765,13 +760,13 @@ class SchedulerDisaggregationDecodeMixin:
self.last_batch = batch
self.last_batch_in_queue = last_batch_in_queue
def _prepare_idle_batch_and_run(self, batch, delay_process=False):
batch, _ = self.prepare_dp_attn_batch(batch)
def _prepare_idle_batch_and_run(self: Scheduler, batch, delay_process=False):
batch, _ = self.prepare_mlp_sync_batch(batch)
result = None
if batch:
result = self.run_batch(batch)
if not delay_process:
self.process_batch_result(batch, result)
self.prepare_mlp_sync_batch(batch, result)
return batch, result
def get_next_disagg_decode_batch_to_run(
......
......@@ -45,6 +45,7 @@ from sglang.srt.disaggregation.utils import (
)
from sglang.srt.managers.schedule_batch import FINISH_LENGTH, Req, ScheduleBatch
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.utils import require_mlp_sync
if TYPE_CHECKING:
from torch.distributed import ProcessGroup
......@@ -274,12 +275,8 @@ class SchedulerDisaggregationPrefillMixin:
self.process_prefill_chunk()
batch = self.get_new_batch_prefill()
# Handle DP attention
if (
self.server_args.enable_dp_attention
or self.server_args.enable_sp_layernorm
):
batch, _ = self.prepare_dp_attn_batch(batch)
if require_mlp_sync(self.server_args):
batch, _ = self.prepare_mlp_sync_batch(batch)
self.cur_batch = batch
if batch:
......@@ -312,12 +309,8 @@ class SchedulerDisaggregationPrefillMixin:
self.process_prefill_chunk()
batch = self.get_new_batch_prefill()
# Handle DP attention
if (
self.server_args.enable_dp_attention
or self.server_args.enable_sp_layernorm
):
batch, _ = self.prepare_dp_attn_batch(batch)
if require_mlp_sync(self.server_args):
batch, _ = self.prepare_mlp_sync_batch(batch)
self.cur_batch = batch
if batch:
result = self.run_batch(batch)
......
......@@ -28,9 +28,9 @@ from sglang.srt.layers.dp_attention import (
attn_tp_reduce_scatter,
dp_gather_partial,
dp_scatter,
get_attention_dp_size,
get_attention_tp_rank,
get_attention_tp_size,
get_local_attention_dp_size,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
......@@ -229,7 +229,7 @@ class CommunicateContext:
process_group_sizes: Dict[ScatterMode, int]
attn_tp_rank: int
attn_tp_size: int
local_attn_dp_size: int
attn_dp_size: int
tp_size: int
def is_same_group_size(self, a: ScatterMode, b: ScatterMode):
......@@ -239,7 +239,7 @@ class CommunicateContext:
def init_new(cls):
attn_tp_rank = get_attention_tp_rank()
attn_tp_size = get_attention_tp_size()
local_attn_dp_size = get_local_attention_dp_size()
attn_dp_size = get_attention_dp_size()
tp_size = get_tensor_model_parallel_world_size()
process_group_sizes = {
ScatterMode.SCATTERED: 1,
......@@ -251,7 +251,7 @@ class CommunicateContext:
process_group_sizes=process_group_sizes,
attn_tp_rank=attn_tp_rank,
attn_tp_size=attn_tp_size,
local_attn_dp_size=local_attn_dp_size,
attn_dp_size=attn_dp_size,
tp_size=tp_size,
)
......@@ -385,7 +385,7 @@ class CommunicateWithAllReduceAndLayerNormFn:
attn_tp_all_gather(
list(residual.tensor_split(context.attn_tp_size)), local_residual
)
if context.local_attn_dp_size != 1:
if context.attn_dp_size != 1:
if context.attn_tp_rank == 0:
hidden_states += residual
hidden_states, local_hidden_states = (
......
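Illustrative sketch (not part of the commit): rough arithmetic for the renamed size, assuming DP attention is enabled and using made-up numbers; the exact values come from sglang.srt.layers.dp_attention at runtime.

# Hypothetical layout: 8 TP ranks split into 4 attention DP groups.
tp_size = 8
dp_size = 4
attn_tp_size = tp_size // dp_size  # 2 ranks do TP inside each attention group
attn_dp_size = dp_size             # 4, so the `attn_dp_size != 1` branch shown above is taken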
......@@ -165,7 +165,8 @@ def disable_dp_size():
def get_dp_local_info(forward_batch: ForwardBatch):
dp_rank = get_local_attention_dp_rank()
# `get_dp_local_info` is only called in global DP gather and scatter. We use global DP rank here.
dp_rank = get_attention_dp_rank()
if forward_batch.dp_local_start_pos is None:
cumtokens = torch.cumsum(forward_batch.global_num_tokens_gpu, dim=0)
......
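Illustrative sketch (not part of the commit): a minimal, self-contained version of the slice computation that follows the cumulative sum above, with made-up token counts; the real values come from forward_batch.global_num_tokens_gpu.

import torch

# One entry per attention DP rank (hypothetical values).
global_num_tokens = torch.tensor([5, 3, 7, 2], dtype=torch.int32)
cumtokens = torch.cumsum(global_num_tokens, dim=0)  # [5, 8, 15, 17]

for dp_rank in range(global_num_tokens.numel()):
    # Rank 0 starts at 0; every other rank starts where the previous ranks end.
    dp_local_start_pos = 0 if dp_rank == 0 else int(cumtokens[dp_rank - 1])
    dp_local_num_tokens = int(global_num_tokens[dp_rank])
    print(dp_rank, dp_local_start_pos, dp_local_num_tokens)
# 0 0 5 / 1 5 3 / 2 8 7 / 3 15 2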
......@@ -30,9 +30,9 @@ from sglang.srt.layers.dp_attention import (
attn_tp_all_gather,
dp_gather_replicate,
dp_scatter,
get_attention_dp_rank,
get_attention_dp_size,
get_attention_tp_size,
get_local_attention_dp_rank,
get_local_attention_dp_size,
)
from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
......@@ -171,7 +171,7 @@ class LogitsMetadata:
return
cumtokens = torch.cumsum(self.global_num_tokens_for_logprob_gpu, dim=0)
dp_rank = get_local_attention_dp_rank()
dp_rank = get_attention_dp_rank()
if dp_rank == 0:
dp_local_start_pos = torch.zeros_like(
self.global_num_tokens_for_logprob_gpu[0]
......
......@@ -149,6 +149,8 @@ from sglang.srt.utils import (
kill_itself_when_parent_died,
point_to_point_pyobj,
pyspy_dump_schedulers,
require_mlp_sync,
require_mlp_tp_gather,
set_gpu_proc_affinity,
set_random_seed,
suppress_other_loggers,
......@@ -1471,9 +1473,8 @@ class Scheduler(
else:
ret = None
# Handle DP attention
if self.server_args.enable_dp_attention or self.server_args.enable_sp_layernorm:
ret, _ = self.prepare_dp_attn_batch(ret)
if require_mlp_sync(self.server_args):
ret, _ = self.prepare_mlp_sync_batch(ret)
return ret
......@@ -1775,12 +1776,11 @@ class Scheduler(
self.return_health_check_ct -= 1
self.send_to_tokenizer.send_pyobj(HealthCheckOutput())
def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
return self.prepare_dp_attn_batch_raw(
def prepare_mlp_sync_batch(self, local_batch: ScheduleBatch):
return self.prepare_mlp_sync_batch_raw(
local_batch,
dp_size=self.server_args.dp_size,
attn_tp_size=self.attn_tp_size,
moe_dense_tp_size=self.server_args.moe_dense_tp_size,
tp_cpu_group=self.tp_cpu_group,
get_idle_batch=self.get_idle_batch,
disable_cuda_graph=self.server_args.disable_cuda_graph,
......@@ -1789,14 +1789,14 @@ class Scheduler(
enable_two_batch_overlap=self.server_args.enable_two_batch_overlap,
enable_deepep_moe=self.server_args.enable_deepep_moe,
deepep_mode=DeepEPMode[self.server_args.deepep_mode],
require_mlp_tp_gather=require_mlp_tp_gather(self.server_args),
)
@staticmethod
def prepare_dp_attn_batch_raw(
def prepare_mlp_sync_batch_raw(
local_batch: ScheduleBatch,
dp_size,
attn_tp_size: int,
moe_dense_tp_size: Optional[int],
tp_cpu_group,
get_idle_batch,
disable_cuda_graph: bool,
......@@ -1805,6 +1805,7 @@ class Scheduler(
enable_two_batch_overlap: bool,
enable_deepep_moe: bool,
deepep_mode: DeepEPMode,
require_mlp_tp_gather: bool,
):
# Check if other DP workers have running batches
if local_batch is None:
......@@ -1879,7 +1880,7 @@ class Scheduler(
if local_batch is not None:
# TODO: handle the case when moe_dense_tp_size != 1
if moe_dense_tp_size == 1 and global_server_args_dict["enable_dp_lm_head"]:
if not require_mlp_tp_gather:
local_batch.global_num_tokens = [num_tokens]
local_batch.global_num_tokens_for_logprob = [num_tokens_for_logprob]
else:
......
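Illustrative sketch (not part of the commit): a hedged restatement of what the branch above stores, with hypothetical counts. The else branch is cut off in this hunk; the sketch assumes it keeps the per-rank counts gathered from all DP workers, as in the surrounding scheduler code.

num_tokens = 17                    # tokens in this rank's local batch (hypothetical)
gathered_counts = [5, 17, 9, 3]    # hypothetical per-DP-rank counts after the sync

require_mlp_tp_gather = False
if not require_mlp_tp_gather:
    # The MLP never sees other ranks' tokens, so only the local count is recorded.
    global_num_tokens = [num_tokens]        # [17]
else:
    # The MLP input is all-gathered across attention DP ranks, so every rank's
    # count is needed to size the gathered buffer and compute offsets.
    global_num_tokens = gathered_counts     # [5, 17, 9, 3]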
......@@ -46,6 +46,9 @@ from sglang.srt.utils import (
get_available_gpu_memory,
get_device_memory_capacity,
rank0_log,
require_attn_tp_gather,
require_gathered_buffer,
require_mlp_tp_gather,
)
logger = logging.getLogger(__name__)
......@@ -207,8 +210,9 @@ class CudaGraphRunner:
self.enable_torch_compile = model_runner.server_args.enable_torch_compile
self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
self.is_encoder_decoder = model_runner.model_config.is_encoder_decoder
self.enable_dp_attention = model_runner.server_args.enable_dp_attention
self.enable_sp_layernorm = model_runner.server_args.enable_sp_layernorm
self.require_gathered_buffer = require_gathered_buffer(model_runner.server_args)
self.require_mlp_tp_gather = require_mlp_tp_gather(model_runner.server_args)
self.require_attn_tp_gather = require_attn_tp_gather(model_runner.server_args)
self.enable_two_batch_overlap = (
model_runner.server_args.enable_two_batch_overlap
)
......@@ -299,18 +303,28 @@ class CudaGraphRunner:
else:
self.encoder_lens = None
if self.enable_dp_attention or self.enable_sp_layernorm:
# TODO(ch-wan): SP layernorm should use a different logic to manage gathered_buffer
self.gathered_buffer = torch.zeros(
(
self.max_bs * self.dp_size * self.num_tokens_per_bs,
self.model_runner.model_config.hidden_size,
),
dtype=self.model_runner.dtype,
)
self.global_num_tokens_gpu = torch.zeros(
(self.dp_size,), dtype=torch.int32
)
if self.require_gathered_buffer:
if self.require_mlp_tp_gather:
self.gathered_buffer = torch.zeros(
(
self.max_bs * self.dp_size * self.num_tokens_per_bs,
self.model_runner.model_config.hidden_size,
),
dtype=self.model_runner.dtype,
)
self.global_num_tokens_gpu = torch.zeros(
(self.dp_size,), dtype=torch.int32
)
else:
assert self.require_attn_tp_gather
self.gathered_buffer = torch.zeros(
(
self.max_bs * self.num_tokens_per_bs,
self.model_runner.model_config.hidden_size,
),
dtype=self.model_runner.dtype,
)
self.global_num_tokens_gpu = torch.zeros((1,), dtype=torch.int32)
# Capture
try:
......@@ -322,7 +336,7 @@ class CudaGraphRunner:
)
def can_run(self, forward_batch: ForwardBatch):
if self.enable_dp_attention or self.enable_sp_layernorm:
if self.require_mlp_tp_gather:
total_batch_size = (
sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
if self.model_runner.spec_algorithm.is_eagle()
......@@ -459,7 +473,7 @@ class CudaGraphRunner:
{k: v[:num_tokens] for k, v in self.pp_proxy_tensors.items()}
)
if self.enable_dp_attention or self.enable_sp_layernorm:
if self.require_mlp_tp_gather:
self.global_num_tokens_gpu.copy_(
torch.tensor(
[
......@@ -472,6 +486,16 @@ class CudaGraphRunner:
)
global_num_tokens = self.global_num_tokens_gpu
gathered_buffer = self.gathered_buffer[:num_tokens]
elif self.require_attn_tp_gather:
self.global_num_tokens_gpu.copy_(
torch.tensor(
[num_tokens],
dtype=torch.int32,
device=input_ids.device,
)
)
global_num_tokens = self.global_num_tokens_gpu
gathered_buffer = self.gathered_buffer[:num_tokens]
else:
global_num_tokens = None
gathered_buffer = None
......@@ -607,7 +631,7 @@ class CudaGraphRunner:
raw_num_token = raw_bs * self.num_tokens_per_bs
# Pad
if self.enable_dp_attention or self.enable_sp_layernorm:
if self.require_mlp_tp_gather:
total_batch_size = (
sum(forward_batch.global_num_tokens_cpu) / self.num_tokens_per_bs
if self.model_runner.spec_algorithm.is_eagle()
......@@ -642,7 +666,7 @@ class CudaGraphRunner:
self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
if forward_batch.mrope_positions is not None:
self.mrope_positions[:, :raw_bs].copy_(forward_batch.mrope_positions)
if self.enable_dp_attention or self.enable_sp_layernorm:
if self.require_gathered_buffer:
self.global_num_tokens_gpu.copy_(forward_batch.global_num_tokens_gpu)
if enable_num_token_non_padded(self.model_runner.server_args):
self.num_token_non_padded.copy_(forward_batch.num_token_non_padded)
......
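Illustrative sketch (not part of the commit): the two allocation shapes above, restated standalone with made-up sizes; hidden_size and dtype come from the model config in the real runner.

import torch

max_bs, num_tokens_per_bs, dp_size, hidden_size = 160, 1, 4, 4096
dtype = torch.bfloat16

# require_mlp_tp_gather: the MLP input is all-gathered across attention DP ranks,
# so the buffer must hold up to dp_size times the local capture size,
# with one token counter per DP rank.
mlp_gathered_buffer = torch.zeros((max_bs * dp_size * num_tokens_per_bs, hidden_size), dtype=dtype)
mlp_global_num_tokens = torch.zeros((dp_size,), dtype=torch.int32)

# require_attn_tp_gather only: tokens are gathered within this rank's attention
# group alone, so the local capture size and a single counter are enough.
attn_gathered_buffer = torch.zeros((max_bs * num_tokens_per_bs, hidden_size), dtype=dtype)
attn_global_num_tokens = torch.zeros((1,), dtype=torch.int32)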
......@@ -1621,8 +1621,6 @@ class DeepseekV2Model(nn.Module):
)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.dp_size = get_local_attention_dp_size()
def get_input_embeddings(self) -> torch.Tensor:
return self.embed_tokens
......@@ -1706,7 +1704,6 @@ class DeepseekV2ForCausalLM(nn.Module):
use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
)
self.logits_processor = LogitsProcessor(config)
self.dp_size = get_local_attention_dp_size()
self._routed_experts_weights_of_layer = LazyValue(
lambda: {
......
......@@ -387,7 +387,6 @@ class ServerArgs:
), "Please enable dp attention when setting enable_dp_attention. "
# DeepEP MoE
self.enable_sp_layernorm = False
if self.enable_deepep_moe:
if self.deepep_mode == "auto":
assert (
......@@ -397,9 +396,6 @@ class ServerArgs:
logger.warning("Cuda graph is disabled because deepep_mode=`normal`")
self.disable_cuda_graph = True
self.ep_size = self.tp_size
self.enable_sp_layernorm = (
self.dp_size < self.tp_size if self.enable_dp_attention else True
)
logger.warning(
f"DeepEP MoE is enabled. The expert parallel size is adjusted to be the same as the tensor parallel size[{self.tp_size}]."
)
......
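Illustrative sketch (not part of the commit): a hedged comparison, with made-up sizes, of the removed derived flag against the new predicate that appears to cover the same condition for a DeepEP MoE setup.

from types import SimpleNamespace

from sglang.srt.utils import require_attn_tp_gather  # added by this commit

# Hypothetical DeepEP MoE config with DP attention.
args = SimpleNamespace(
    enable_dp_attention=True,
    dp_size=2,
    tp_size=8,
    moe_dense_tp_size=None,
    enable_deepep_moe=True,
)

# The value the removed code used to derive for enable_sp_layernorm:
old_enable_sp_layernorm = args.dp_size < args.tp_size if args.enable_dp_attention else True

print(old_enable_sp_layernorm, require_attn_tp_gather(args))  # True True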
......@@ -20,6 +20,11 @@ from sglang.srt.model_executor.forward_batch_info import (
ForwardMode,
)
from sglang.srt.speculative.eagle_utils import EagleDraftInput
from sglang.srt.utils import (
require_attn_tp_gather,
require_gathered_buffer,
require_mlp_tp_gather,
)
if TYPE_CHECKING:
from sglang.srt.speculative.eagle_worker import EAGLEWorker
......@@ -39,8 +44,9 @@ class EAGLEDraftCudaGraphRunner:
self.enable_torch_compile = model_runner.server_args.enable_torch_compile
self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
self.is_encoder_decoder = model_runner.model_config.is_encoder_decoder
self.enable_dp_attention = model_runner.server_args.enable_dp_attention
self.enable_sp_layernorm = model_runner.server_args.enable_sp_layernorm
self.require_gathered_buffer = require_gathered_buffer(model_runner.server_args)
self.require_mlp_tp_gather = require_mlp_tp_gather(model_runner.server_args)
self.require_attn_tp_gather = require_attn_tp_gather(model_runner.server_args)
self.dp_size = self.model_runner.dp_size
self.tp_size = self.model_runner.tp_size
self.topk = model_runner.server_args.speculative_eagle_topk
......@@ -88,8 +94,7 @@ class EAGLEDraftCudaGraphRunner:
dtype=self.model_runner.dtype,
)
if self.enable_dp_attention or self.enable_sp_layernorm:
# TODO(ch-wan): SP layernorm should use a different logic to manage gathered_buffer
if self.require_gathered_buffer:
self.gathered_buffer = torch.zeros(
(
self.max_num_token,
......@@ -97,12 +102,19 @@ class EAGLEDraftCudaGraphRunner:
),
dtype=self.model_runner.dtype,
)
self.global_num_tokens_gpu = torch.zeros(
(self.dp_size,), dtype=torch.int32
)
self.global_num_tokens_for_logprob_gpu = torch.zeros(
(self.dp_size,), dtype=torch.int32
)
if self.require_mlp_tp_gather:
self.global_num_tokens_gpu = torch.zeros(
(self.dp_size,), dtype=torch.int32
)
self.global_num_tokens_for_logprob_gpu = torch.zeros(
(self.dp_size,), dtype=torch.int32
)
else:
assert self.require_attn_tp_gather
self.global_num_tokens_gpu = torch.zeros((1,), dtype=torch.int32)
self.global_num_tokens_for_logprob_gpu = torch.zeros(
(1,), dtype=torch.int32
)
# Capture
try:
......@@ -114,8 +126,7 @@ class EAGLEDraftCudaGraphRunner:
)
def can_run(self, forward_batch: ForwardBatch):
if self.enable_dp_attention:
# TODO(ch-wan): check --moe-dense-tp-size and --enable-dp-lm-head
if self.require_mlp_tp_gather:
if not forward_batch.can_run_dp_cuda_graph:
return False
total_batch_size = (
......@@ -153,7 +164,7 @@ class EAGLEDraftCudaGraphRunner:
topk_index = self.topk_index[:num_seqs]
hidden_states = self.hidden_states[:num_seqs]
if self.enable_dp_attention or self.enable_sp_layernorm:
if self.require_mlp_tp_gather:
self.global_num_tokens_gpu.copy_(
torch.tensor(
[
......@@ -177,6 +188,24 @@ class EAGLEDraftCudaGraphRunner:
global_num_tokens = self.global_num_tokens_gpu
gathered_buffer = self.gathered_buffer[:num_tokens]
global_num_tokens_for_logprob = self.global_num_tokens_for_logprob_gpu
elif self.require_attn_tp_gather:
self.global_num_tokens_gpu.copy_(
torch.tensor(
[num_tokens],
dtype=torch.int32,
device=self.input_ids.device,
)
)
self.global_num_tokens_for_logprob_gpu.copy_(
torch.tensor(
[num_tokens],
dtype=torch.int32,
device=self.input_ids.device,
)
)
global_num_tokens = self.global_num_tokens_gpu
gathered_buffer = self.gathered_buffer[:num_tokens]
global_num_tokens_for_logprob = self.global_num_tokens_for_logprob_gpu
else:
global_num_tokens = None
gathered_buffer = None
......@@ -259,7 +288,7 @@ class EAGLEDraftCudaGraphRunner:
raw_num_token = raw_bs * self.num_tokens_per_bs
# Pad
if self.enable_dp_attention or self.enable_sp_layernorm:
if self.require_mlp_tp_gather:
total_batch_size = (
sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
if self.model_runner.spec_algorithm.is_eagle()
......@@ -286,7 +315,7 @@ class EAGLEDraftCudaGraphRunner:
self.topk_index[:raw_bs].copy_(forward_batch.spec_info.topk_index)
self.hidden_states[:raw_bs].copy_(forward_batch.spec_info.hidden_states)
if self.enable_dp_attention or self.enable_sp_layernorm:
if self.require_gathered_buffer:
self.global_num_tokens_gpu.copy_(forward_batch.global_num_tokens_gpu)
self.global_num_tokens_for_logprob_gpu.copy_(
forward_batch.global_num_tokens_for_logprob_gpu
......
......@@ -21,6 +21,11 @@ from sglang.srt.model_executor.forward_batch_info import (
ForwardMode,
)
from sglang.srt.speculative.eagle_utils import EagleDraftInput, fast_topk
from sglang.srt.utils import (
require_attn_tp_gather,
require_gathered_buffer,
require_mlp_tp_gather,
)
if TYPE_CHECKING:
from sglang.srt.speculative.eagle_worker import EAGLEWorker
......@@ -35,8 +40,9 @@ class EAGLEDraftExtendCudaGraphRunner:
self.output_buffers = {}
self.enable_torch_compile = model_runner.server_args.enable_torch_compile
self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
self.enable_dp_attention = model_runner.server_args.enable_dp_attention
self.enable_sp_layernorm = model_runner.server_args.enable_sp_layernorm
self.require_gathered_buffer = require_gathered_buffer(model_runner.server_args)
self.require_mlp_tp_gather = require_mlp_tp_gather(model_runner.server_args)
self.require_attn_tp_gather = require_attn_tp_gather(model_runner.server_args)
self.tp_size = self.model_runner.tp_size
self.dp_size = model_runner.server_args.dp_size
self.speculative_num_steps = model_runner.server_args.speculative_num_steps
......@@ -92,7 +98,7 @@ class EAGLEDraftExtendCudaGraphRunner:
(self.max_bs,), self.num_tokens_per_bs, dtype=torch.int32
)
if self.enable_dp_attention or self.enable_sp_layernorm:
if self.require_gathered_buffer:
self.gathered_buffer = torch.zeros(
(
self.max_num_token,
......@@ -100,13 +106,19 @@ class EAGLEDraftExtendCudaGraphRunner:
),
dtype=self.model_runner.dtype,
)
self.global_num_tokens_gpu = torch.zeros(
(self.dp_size,), dtype=torch.int32
)
self.global_num_tokens_for_logprob_gpu = torch.zeros(
(self.dp_size,), dtype=torch.int32
)
if self.require_mlp_tp_gather:
self.global_num_tokens_gpu = torch.zeros(
(self.dp_size,), dtype=torch.int32
)
self.global_num_tokens_for_logprob_gpu = torch.zeros(
(self.dp_size,), dtype=torch.int32
)
else:
assert self.require_attn_tp_gather
self.global_num_tokens_gpu = torch.zeros((1,), dtype=torch.int32)
self.global_num_tokens_for_logprob_gpu = torch.zeros(
(1,), dtype=torch.int32
)
# Capture
try:
with model_capture_mode():
......@@ -117,7 +129,7 @@ class EAGLEDraftExtendCudaGraphRunner:
)
def can_run(self, forward_batch: ForwardBatch):
if self.enable_dp_attention or self.enable_sp_layernorm:
if self.require_mlp_tp_gather:
if not forward_batch.can_run_dp_cuda_graph:
return False
total_batch_size = (
......@@ -160,7 +172,7 @@ class EAGLEDraftExtendCudaGraphRunner:
positions = self.positions[:num_tokens]
hidden_states = self.hidden_states[:num_tokens]
if self.enable_dp_attention or self.enable_sp_layernorm:
if self.require_mlp_tp_gather:
self.global_num_tokens_gpu.copy_(
torch.tensor(
[
......@@ -184,6 +196,24 @@ class EAGLEDraftExtendCudaGraphRunner:
global_num_tokens = self.global_num_tokens_gpu
gathered_buffer = self.gathered_buffer[:num_tokens]
global_num_tokens_for_logprob = self.global_num_tokens_for_logprob_gpu
elif self.require_attn_tp_gather:
self.global_num_tokens_gpu.copy_(
torch.tensor(
[num_tokens],
dtype=torch.int32,
device=self.input_ids.device,
)
)
self.global_num_tokens_for_logprob_gpu.copy_(
torch.tensor(
[num_tokens],
dtype=torch.int32,
device=self.input_ids.device,
)
)
global_num_tokens = self.global_num_tokens_gpu
gathered_buffer = self.gathered_buffer[:num_tokens]
global_num_tokens_for_logprob = self.global_num_tokens_for_logprob_gpu
else:
global_num_tokens = None
gathered_buffer = None
......@@ -270,7 +300,7 @@ class EAGLEDraftExtendCudaGraphRunner:
# in the batch, which will not be counted as num_seqs
raw_bs = forward_batch.batch_size
num_tokens = forward_batch.input_ids.shape[0]
if self.enable_dp_attention or self.enable_sp_layernorm:
if self.require_mlp_tp_gather:
total_batch_size = (
sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
if self.model_runner.spec_algorithm.is_eagle()
......@@ -299,7 +329,7 @@ class EAGLEDraftExtendCudaGraphRunner:
self.accept_length[:raw_bs].copy_(forward_batch.spec_info.accept_length)
self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices)
if self.enable_dp_attention or self.enable_sp_layernorm:
if self.require_gathered_buffer:
self.global_num_tokens_gpu.copy_(forward_batch.global_num_tokens_gpu)
self.global_num_tokens_for_logprob_gpu.copy_(
forward_batch.global_num_tokens_for_logprob_gpu
......
......@@ -2303,6 +2303,51 @@ class Withable(Generic[T]):
self._value = None
def require_mlp_tp_gather(server_args):
"""
Check if the input of MLP is obtained by all-gather rather than all-reduce. This only happens when each MLP TP group contains multiple attention DP groups.
"""
if server_args.enable_dp_attention:
assert server_args.dp_size > 1, "dp_size must be greater than 1"
if (
server_args.moe_dense_tp_size is None
): # TODO(ch-wan): some MoE models do not have dense layers
return True
elif not server_args.enable_dp_lm_head:
return True
elif not server_args.enable_deepep_moe:
return True
else:
return (
server_args.moe_dense_tp_size
> server_args.tp_size // server_args.dp_size
)
else:
return False
def require_attn_tp_gather(server_args):
"""
Check if the input of attention is scattered.
"""
assert server_args.moe_dense_tp_size in [1, None]
if server_args.enable_deepep_moe or server_args.moe_dense_tp_size == 1:
if server_args.enable_dp_attention:
return server_args.dp_size < server_args.tp_size
else:
return True
else:
return False
def require_gathered_buffer(server_args):
return require_mlp_tp_gather(server_args) or require_attn_tp_gather(server_args)
def require_mlp_sync(server_args):
return server_args.enable_dp_attention or require_gathered_buffer(server_args)
def merge_bias_tensor(
lhs: Optional[torch.Tensor],
rhs: Optional[torch.Tensor],
......
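Illustrative sketch (not part of the commit): a minimal usage example of the four helpers above, probing them with a `SimpleNamespace` stand-in for `ServerArgs`. Only the attributes the helpers read are supplied, and all values are hypothetical.

from types import SimpleNamespace

from sglang.srt.utils import (
    require_attn_tp_gather,
    require_gathered_buffer,
    require_mlp_sync,
    require_mlp_tp_gather,
)

# Hypothetical config: DP attention (tp=8, dp=4), dense MLP at TP=1,
# DP LM head and DeepEP MoE enabled.
args = SimpleNamespace(
    enable_dp_attention=True,
    dp_size=4,
    tp_size=8,
    moe_dense_tp_size=1,
    enable_dp_lm_head=True,
    enable_deepep_moe=True,
)

print(require_mlp_tp_gather(args))    # False: moe_dense_tp_size (1) <= tp_size // dp_size (2)
print(require_attn_tp_gather(args))   # True: dp_size (4) < tp_size (8), so the attention input is scattered
print(require_gathered_buffer(args))  # True: either gather mode needs a gathered buffer
print(require_mlp_sync(args))         # True: DP attention always requires the batch-size sync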