Unverified Commit 69fe3c97 authored by Zilin Zhu's avatar Zilin Zhu Committed by GitHub
Browse files

Manually flip deepep_mode for cuda_graph (#11666)

parent 8af84912
......@@ -235,6 +235,15 @@ class DeepEPBuffer:
cls.clean_buffer()
cls._dispatch_mode = DeepEPDispatchMode.LOW_LATENCY
@classmethod
def set_dispatch_mode(cls, mode: DeepEPMode):
    """Switch the buffer's dispatch mode to match a resolved DeepEP mode.

    Args:
        mode: A resolved ``DeepEPMode``; must report either low-latency
            or normal dispatch.

    Raises:
        ValueError: If ``mode`` is neither low-latency nor normal.
    """
    if mode.is_low_latency():
        cls.set_dispatch_mode_as_low_latency()
    elif mode.is_normal():
        cls.set_dispatch_mode_as_normal()
    else:
        # ValueError (a subclass of Exception) keeps any existing broad
        # handlers working while naming the offending mode for debugging.
        raise ValueError(f"unsupported DeepEP dispatch mode: {mode}")
class DeepEPConfig(BaseDispatcherConfig):
_instance = None
......
......@@ -40,6 +40,8 @@ from sglang.srt.layers.dp_attention import (
set_dp_buffer_len,
)
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.moe.token_dispatcher.deepep import DeepEPBuffer
from sglang.srt.layers.moe.utils import get_deepep_mode, get_moe_a2a_backend
from sglang.srt.layers.torchao_utils import save_gemlite_cache
from sglang.srt.model_executor.forward_batch_info import (
CaptureHiddenMode,
......@@ -240,6 +242,8 @@ class CudaGraphRunner:
self.attn_tp_size = get_attention_tp_size()
self.attn_tp_rank = get_attention_tp_rank()
self.deepep_adapter = DeepEPCudaGraphRunnerAdapter()
# Batch sizes to capture
self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner)
log_info_on_rank0(logger, f"Capture cuda graph bs {self.capture_bs}")
......@@ -653,6 +657,8 @@ class CudaGraphRunner:
)
return logits_output_or_pp_proxy_tensors
self.deepep_adapter.capture(is_extend_in_batch=False)
for _ in range(2):
self.device_module.synchronize()
self.model_runner.tp_group.barrier()
......@@ -796,6 +802,8 @@ class CudaGraphRunner:
skip_attn_backend_init: bool = False,
pp_proxy_tensors: Optional[PPProxyTensors] = None,
) -> Union[LogitsProcessorOutput, PPProxyTensors]:
self.deepep_adapter.replay()
if not skip_attn_backend_init:
self.replay_prepare(forward_batch, pp_proxy_tensors)
else:
......@@ -872,3 +880,23 @@ CUDA_GRAPH_CAPTURE_FAILED_MSG = (
"4. disable CUDA graph by --disable-cuda-graph. (Not recommended. Huge performance loss)\n"
"Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n"
)
class DeepEPCudaGraphRunnerAdapter:
    """Keeps DeepEP's dispatch mode consistent between CUDA graph capture and replay.

    The DeepEP dispatch mode chosen while a CUDA graph is captured must be
    re-applied before every replay of that graph, otherwise other code paths
    may have flipped the buffer's mode in between. This adapter records the
    mode at capture time and restores it on replay. All methods are no-ops
    when the DeepEP all-to-all backend is not active.
    """

    def __init__(self):
        # DeepEP mode resolved at capture time; stays None until capture()
        # runs with the DeepEP backend enabled.
        self._captured_deepep_mode = None

    def capture(self, is_extend_in_batch: bool):
        """Resolve and apply the DeepEP dispatch mode before graph capture.

        Args:
            is_extend_in_batch: Forwarded to ``DeepEPMode.resolve`` so the
                mode matches the kind of batch being captured.
        """
        if not get_moe_a2a_backend().is_deepep():
            return
        self._captured_deepep_mode = get_deepep_mode().resolve(
            is_extend_in_batch=is_extend_in_batch
        )
        DeepEPBuffer.set_dispatch_mode(self._captured_deepep_mode)

    def replay(self):
        """Re-apply the dispatch mode recorded during capture before replay.

        Raises:
            RuntimeError: If DeepEP is active but ``capture()`` was never
                called, so no mode was recorded.
        """
        if not get_moe_a2a_backend().is_deepep():
            return
        # Explicit check instead of `assert`: assertions are stripped under
        # `python -O`, which would let a None mode reach
        # DeepEPBuffer.set_dispatch_mode and fail far from the real cause.
        if self._captured_deepep_mode is None:
            raise RuntimeError(
                "DeepEPCudaGraphRunnerAdapter.replay() called before capture()"
            )
        DeepEPBuffer.set_dispatch_mode(self._captured_deepep_mode)
......@@ -9,6 +9,7 @@ from sglang.srt.layers.dp_attention import DpPaddingMode, set_dp_buffer_len
from sglang.srt.model_executor.cuda_graph_runner import (
CUDA_GRAPH_CAPTURE_FAILED_MSG,
CudaGraphRunner,
DeepEPCudaGraphRunnerAdapter,
get_batch_sizes_to_capture,
get_global_graph_memory_pool,
model_capture_mode,
......@@ -61,6 +62,7 @@ class EAGLEDraftCudaGraphRunner:
self.enable_profile_cuda_graph = (
model_runner.server_args.enable_profile_cuda_graph
)
self.deepep_adapter = DeepEPCudaGraphRunnerAdapter()
server_args = model_runner.server_args
# Batch sizes to capture
......@@ -264,6 +266,8 @@ class EAGLEDraftCudaGraphRunner:
forward_batch.spec_info.hidden_states = hidden_states_backup
return ret
self.deepep_adapter.capture(is_extend_in_batch=False)
for _ in range(2):
torch.cuda.synchronize()
self.model_runner.tp_group.barrier()
......@@ -285,6 +289,8 @@ class EAGLEDraftCudaGraphRunner:
def replay(self, forward_batch: ForwardBatch):
assert forward_batch.out_cache_loc is not None
self.deepep_adapter.replay()
raw_bs = forward_batch.batch_size
raw_num_token = raw_bs * self.num_tokens_per_bs
......
......@@ -9,6 +9,7 @@ from sglang.srt.layers.dp_attention import DpPaddingMode, set_dp_buffer_len
from sglang.srt.model_executor.cuda_graph_runner import (
CUDA_GRAPH_CAPTURE_FAILED_MSG,
CudaGraphRunner,
DeepEPCudaGraphRunnerAdapter,
LogitsProcessorOutput,
get_batch_sizes_to_capture,
get_global_graph_memory_pool,
......@@ -61,6 +62,7 @@ class EAGLEDraftExtendCudaGraphRunner:
)
self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner)
self.padded_static_len = -1
self.deepep_adapter = DeepEPCudaGraphRunnerAdapter()
# Attention backend
self.num_tokens_per_bs = self.speculative_num_steps + 1
......@@ -243,6 +245,8 @@ class EAGLEDraftExtendCudaGraphRunner:
)
spec_info.positions = None
self.deepep_adapter.capture(is_extend_in_batch=True)
# Forward batch
forward_batch = ForwardBatch(
forward_mode=ForwardMode.DRAFT_EXTEND,
......@@ -318,6 +322,8 @@ class EAGLEDraftExtendCudaGraphRunner:
def replay(self, forward_batch: ForwardBatch):
assert forward_batch.out_cache_loc is not None
self.deepep_adapter.replay()
# batch_size and num_seqs can be different in case there are finished examples
# in the batch, which will not be counted as num_seqs
raw_bs = forward_batch.batch_size
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment