Unverified commit ac5010e0 authored by Cheng Wan, committed by GitHub

Fix CUDA Graph Check under Deepep with DP FFN (#7451)

parent 3cee035e
@@ -48,6 +48,7 @@ from sglang.srt.utils import (
     rank0_log,
     require_attn_tp_gather,
     require_gathered_buffer,
+    require_mlp_sync,
     require_mlp_tp_gather,
 )
@@ -212,6 +213,7 @@ class CudaGraphRunner:
         self.is_encoder_decoder = model_runner.model_config.is_encoder_decoder
         self.require_gathered_buffer = require_gathered_buffer(model_runner.server_args)
         self.require_mlp_tp_gather = require_mlp_tp_gather(model_runner.server_args)
+        self.require_mlp_sync = require_mlp_sync(model_runner.server_args)
         self.require_attn_tp_gather = require_attn_tp_gather(model_runner.server_args)
         self.enable_two_batch_overlap = (
             model_runner.server_args.enable_two_batch_overlap
@@ -337,23 +339,23 @@ class CudaGraphRunner:
     def can_run(self, forward_batch: ForwardBatch):
         if self.require_mlp_tp_gather:
-            total_batch_size = (
+            cuda_graph_bs = (
                 sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
                 if self.model_runner.spec_algorithm.is_eagle()
                 else sum(forward_batch.global_num_tokens_cpu)
             )
-            is_bs_supported = forward_batch.can_run_dp_cuda_graph and (
-                total_batch_size in self.graphs
-                if self.disable_padding
-                else total_batch_size <= self.max_bs
-            )
         else:
-            is_bs_supported = (
-                forward_batch.batch_size in self.graphs
-                if self.disable_padding
-                else forward_batch.batch_size <= self.max_bs
-            )
+            cuda_graph_bs = forward_batch.batch_size
+
+        is_bs_supported = (
+            cuda_graph_bs in self.graphs
+            if self.disable_padding
+            else cuda_graph_bs <= self.max_bs
+        )
+
+        if self.require_mlp_sync:
+            is_bs_supported = is_bs_supported and forward_batch.can_run_dp_cuda_graph
 
         # NOTE: cuda graph cannot handle mixed batch (encoder_len = 0)
         # If mixed batch cannot be supported, then encoder_lens can be removed in cuda graph
         # because the full_text_row_masked_out_mask tensor will always be ones
...
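The hunks above collapse the two copies of the batch-size check into one shared path and move the DP-graph flag into its own guard. Below is a minimal standalone sketch of the resulting logic; `ForwardBatchStub` and all concrete values are illustrative assumptions, not sglang's real types.

```python
# Minimal sketch of the consolidated can_run() logic from this commit.
# ForwardBatchStub and every concrete value are assumptions for
# illustration; sglang's real ForwardBatch carries much more state.
from dataclasses import dataclass
from typing import Dict, List, Optional


@dataclass
class ForwardBatchStub:
    batch_size: int
    can_run_dp_cuda_graph: bool
    global_num_tokens_cpu: Optional[List[int]] = None


def can_run(
    batch: ForwardBatchStub,
    graphs: Dict[int, object],
    max_bs: int,
    disable_padding: bool,
    require_mlp_tp_gather: bool,
    require_mlp_sync: bool,
    num_tokens_per_bs: int = 1,
    is_eagle: bool = False,
) -> bool:
    # Step 1: derive the batch size the captured graph would replay with.
    if require_mlp_tp_gather:
        total = sum(batch.global_num_tokens_cpu)
        cuda_graph_bs = total // num_tokens_per_bs if is_eagle else total
    else:
        cuda_graph_bs = batch.batch_size

    # Step 2: one shared size check. Without padding, only exactly-captured
    # sizes qualify; with padding, anything up to max_bs can be padded.
    is_bs_supported = (
        cuda_graph_bs in graphs if disable_padding else cuda_graph_bs <= max_bs
    )

    # Step 3 (the fix): consult the DP flag whenever MLP sync is required,
    # not only in the mlp_tp_gather branch. Under DeepEP with a DP FFN,
    # require_mlp_sync holds while require_mlp_tp_gather does not, so the
    # old code never looked at can_run_dp_cuda_graph.
    if require_mlp_sync:
        is_bs_supported = is_bs_supported and batch.can_run_dp_cuda_graph

    return is_bs_supported
```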
@@ -23,6 +23,7 @@ from sglang.srt.speculative.eagle_utils import EagleDraftInput
 from sglang.srt.utils import (
     require_attn_tp_gather,
     require_gathered_buffer,
+    require_mlp_sync,
     require_mlp_tp_gather,
 )
@@ -46,6 +47,7 @@ class EAGLEDraftCudaGraphRunner:
         self.is_encoder_decoder = model_runner.model_config.is_encoder_decoder
         self.require_gathered_buffer = require_gathered_buffer(model_runner.server_args)
         self.require_mlp_tp_gather = require_mlp_tp_gather(model_runner.server_args)
+        self.require_mlp_sync = require_mlp_sync(model_runner.server_args)
         self.require_attn_tp_gather = require_attn_tp_gather(model_runner.server_args)
         self.dp_size = self.model_runner.dp_size
         self.tp_size = self.model_runner.tp_size
@@ -127,24 +129,23 @@ class EAGLEDraftCudaGraphRunner:
     def can_run(self, forward_batch: ForwardBatch):
         if self.require_mlp_tp_gather:
-            if not forward_batch.can_run_dp_cuda_graph:
-                return False
-            total_batch_size = (
+            cuda_graph_bs = (
                 sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
                 if self.model_runner.spec_algorithm.is_eagle()
                 else sum(forward_batch.global_num_tokens_cpu)
             )
-            is_bs_supported = (
-                total_batch_size in self.graphs
-                if self.disable_padding
-                else total_batch_size <= self.max_bs
-            )
         else:
-            is_bs_supported = (
-                forward_batch.batch_size in self.graphs
-                if self.disable_padding
-                else forward_batch.batch_size <= self.max_bs
-            )
+            cuda_graph_bs = forward_batch.batch_size
+
+        is_bs_supported = (
+            cuda_graph_bs in self.graphs
+            if self.disable_padding
+            else cuda_graph_bs <= self.max_bs
+        )
+
+        if self.require_mlp_sync:
+            is_bs_supported = is_bs_supported and forward_batch.can_run_dp_cuda_graph
         return is_bs_supported
 
     def capture(self):
...
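The EAGLE draft runner above gets the same treatment; note that its old early `return False` only fired inside the `require_mlp_tp_gather` branch. Exercising the `can_run` sketch from after the first file's diff on the exact scenario named in the commit title (all values made up):

```python
# DeepEP with a DP FFN: MLP sync required, MLP TP gather not. This rank's
# batch would fit a captured graph, but the DP ranks disagree, so
# can_run_dp_cuda_graph is False and replay must be refused.
batch = ForwardBatchStub(batch_size=8, can_run_dp_cuda_graph=False)
ok = can_run(
    batch,
    graphs={1: None, 8: None},  # pretend graphs were captured for bs 1 and 8
    max_bs=8,
    disable_padding=False,
    require_mlp_tp_gather=False,
    require_mlp_sync=True,
)
assert not ok  # before this commit, this check passed incorrectly
```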
@@ -24,6 +24,7 @@ from sglang.srt.speculative.eagle_utils import EagleDraftInput, fast_topk
 from sglang.srt.utils import (
     require_attn_tp_gather,
     require_gathered_buffer,
+    require_mlp_sync,
     require_mlp_tp_gather,
 )
@@ -42,6 +43,7 @@ class EAGLEDraftExtendCudaGraphRunner:
         self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
         self.require_gathered_buffer = require_gathered_buffer(model_runner.server_args)
         self.require_mlp_tp_gather = require_mlp_tp_gather(model_runner.server_args)
+        self.require_mlp_sync = require_mlp_sync(model_runner.server_args)
         self.require_attn_tp_gather = require_attn_tp_gather(model_runner.server_args)
         self.tp_size = self.model_runner.tp_size
         self.dp_size = model_runner.server_args.dp_size
@@ -130,28 +132,23 @@ class EAGLEDraftExtendCudaGraphRunner:
     def can_run(self, forward_batch: ForwardBatch):
         if self.require_mlp_tp_gather:
-            if not forward_batch.can_run_dp_cuda_graph:
-                return False
-            total_batch_size = (
+            cuda_graph_bs = (
                 sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
                 if self.model_runner.spec_algorithm.is_eagle()
                 else sum(forward_batch.global_num_tokens_cpu)
             )
-            is_bs_supported = (
-                total_batch_size in self.graphs
-                if self.disable_padding
-                else total_batch_size <= self.max_bs
-            )
-            return is_bs_supported
         else:
-            batch_size = forward_batch.seq_lens.numel()
+            cuda_graph_bs = forward_batch.seq_lens.numel()
+
+        is_bs_supported = (
+            cuda_graph_bs in self.graphs
+            if self.disable_padding
+            else cuda_graph_bs <= self.max_bs
+        )
+
+        if self.require_mlp_sync:
+            is_bs_supported = is_bs_supported and forward_batch.can_run_dp_cuda_graph
         return is_bs_supported
 
     def capture(self):
...
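For orientation, one plausible shape for the newly imported predicate. This is an assumption for illustration only, not the actual definition in sglang.srt.utils, which may differ.

```python
# Assumed sketch only -- the real require_mlp_sync in sglang.srt.utils
# may be defined differently.
from dataclasses import dataclass


@dataclass
class ServerArgsStub:
    enable_dp_attention: bool = False
    enable_deepep_moe: bool = False


def require_gathered_buffer_stub(args: ServerArgsStub) -> bool:
    # Stand-in for sglang's require_gathered_buffer: modes whose MoE
    # dispatch operates on a globally gathered token buffer.
    return args.enable_deepep_moe


def require_mlp_sync_stub(args: ServerArgsStub) -> bool:
    # All DP ranks must agree on CUDA-graph eligibility whenever the MLP
    # stage is synchronized across ranks: DP attention, or any
    # gathered-buffer mode such as DeepEP MoE with a DP FFN.
    return args.enable_dp_attention or require_gathered_buffer_stub(args)


# The commit's scenario: DeepEP MoE needs MLP sync even without DP attention.
assert require_mlp_sync_stub(ServerArgsStub(enable_deepep_moe=True))
```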