Unverified Commit ac5010e0 authored by Cheng Wan, committed by GitHub

Fix CUDA Graph Check under Deepep with DP FFN (#7451)

parent 3cee035e
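
What changed, in short: before this commit, only the require_mlp_tp_gather path consulted forward_batch.can_run_dp_cuda_graph, so a configuration that needs MLP synchronization without MLP TP gather (DeepEP with a data-parallel FFN) could replay a CUDA graph on a batch that is not synchronized across DP ranks. The diffs below first compute the effective CUDA graph batch size on both paths, then apply the DP-sync condition whenever require_mlp_sync holds. A minimal standalone sketch of the resulting check, written against the attribute names visible in the diff (not a drop-in copy of the real method):

    def can_run_sketch(runner, forward_batch):
        # Effective batch size used to look up a captured graph. Under MLP
        # TP gather it is derived from the global token count across DP
        # ranks; otherwise it is just the local batch size.
        if runner.require_mlp_tp_gather:
            cuda_graph_bs = (
                sum(forward_batch.global_num_tokens_cpu) // runner.num_tokens_per_bs
                if runner.model_runner.spec_algorithm.is_eagle()
                else sum(forward_batch.global_num_tokens_cpu)
            )
        else:
            cuda_graph_bs = forward_batch.batch_size

        # Without padding, only exactly-captured sizes can replay; with
        # padding, anything up to the largest captured size is fine.
        is_bs_supported = (
            cuda_graph_bs in runner.graphs
            if runner.disable_padding
            else cuda_graph_bs <= runner.max_bs
        )

        # The fix: enforce DP synchronization for every mode that needs
        # MLP sync, not only when MLP TP gather is enabled.
        if runner.require_mlp_sync:
            is_bs_supported = is_bs_supported and forward_batch.can_run_dp_cuda_graph

        return is_bs_supported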
@@ -48,6 +48,7 @@ from sglang.srt.utils import (
     rank0_log,
     require_attn_tp_gather,
     require_gathered_buffer,
+    require_mlp_sync,
     require_mlp_tp_gather,
 )
@@ -212,6 +213,7 @@ class CudaGraphRunner:
         self.is_encoder_decoder = model_runner.model_config.is_encoder_decoder
         self.require_gathered_buffer = require_gathered_buffer(model_runner.server_args)
         self.require_mlp_tp_gather = require_mlp_tp_gather(model_runner.server_args)
+        self.require_mlp_sync = require_mlp_sync(model_runner.server_args)
         self.require_attn_tp_gather = require_attn_tp_gather(model_runner.server_args)
         self.enable_two_batch_overlap = (
             model_runner.server_args.enable_two_batch_overlap
@@ -337,22 +339,22 @@ class CudaGraphRunner:
     def can_run(self, forward_batch: ForwardBatch):
         if self.require_mlp_tp_gather:
-            total_batch_size = (
+            cuda_graph_bs = (
                 sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
                 if self.model_runner.spec_algorithm.is_eagle()
                 else sum(forward_batch.global_num_tokens_cpu)
             )
-            is_bs_supported = forward_batch.can_run_dp_cuda_graph and (
-                total_batch_size in self.graphs
-                if self.disable_padding
-                else total_batch_size <= self.max_bs
-            )
         else:
-            is_bs_supported = (
-                forward_batch.batch_size in self.graphs
-                if self.disable_padding
-                else forward_batch.batch_size <= self.max_bs
-            )
+            cuda_graph_bs = forward_batch.batch_size
+
+        is_bs_supported = (
+            cuda_graph_bs in self.graphs
+            if self.disable_padding
+            else cuda_graph_bs <= self.max_bs
+        )
+
+        if self.require_mlp_sync:
+            is_bs_supported = is_bs_supported and forward_batch.can_run_dp_cuda_graph

         # NOTE: cuda graph cannot handle mixed batch (encoder_len = 0)
         # If mixed batch cannot be supported, then encoder_lens can be removed in cuda graph
......
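
Note that require_mlp_sync itself is imported from sglang.srt.utils, but its body is not part of this diff. From the way the runners use it, the predicate must hold for every configuration in which DP ranks need to agree on whether a batch can replay a graph. A hypothetical sketch of such a predicate, purely for orientation (the flag name enable_dp_attention is an assumption here, not taken from this diff):

    def require_mlp_sync(server_args) -> bool:
        # Hypothetical: the real implementation lives in sglang.srt.utils.
        # MLP sync is needed whenever the FFN stage consumes a batch shared
        # across DP ranks, e.g. DP attention or any gathered-buffer mode.
        return server_args.enable_dp_attention or require_gathered_buffer(server_args)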
@@ -23,6 +23,7 @@ from sglang.srt.speculative.eagle_utils import EagleDraftInput
 from sglang.srt.utils import (
     require_attn_tp_gather,
     require_gathered_buffer,
+    require_mlp_sync,
     require_mlp_tp_gather,
 )
@@ -46,6 +47,7 @@ class EAGLEDraftCudaGraphRunner:
         self.is_encoder_decoder = model_runner.model_config.is_encoder_decoder
         self.require_gathered_buffer = require_gathered_buffer(model_runner.server_args)
         self.require_mlp_tp_gather = require_mlp_tp_gather(model_runner.server_args)
+        self.require_mlp_sync = require_mlp_sync(model_runner.server_args)
         self.require_attn_tp_gather = require_attn_tp_gather(model_runner.server_args)
         self.dp_size = self.model_runner.dp_size
         self.tp_size = self.model_runner.tp_size
@@ -127,24 +129,23 @@ class EAGLEDraftCudaGraphRunner:
     def can_run(self, forward_batch: ForwardBatch):
         if self.require_mlp_tp_gather:
-            if not forward_batch.can_run_dp_cuda_graph:
-                return False
-            total_batch_size = (
+            cuda_graph_bs = (
                 sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
                 if self.model_runner.spec_algorithm.is_eagle()
                 else sum(forward_batch.global_num_tokens_cpu)
             )
-            is_bs_supported = (
-                total_batch_size in self.graphs
-                if self.disable_padding
-                else total_batch_size <= self.max_bs
-            )
         else:
-            is_bs_supported = (
-                forward_batch.batch_size in self.graphs
-                if self.disable_padding
-                else forward_batch.batch_size <= self.max_bs
-            )
+            cuda_graph_bs = forward_batch.batch_size
+
+        is_bs_supported = (
+            cuda_graph_bs in self.graphs
+            if self.disable_padding
+            else cuda_graph_bs <= self.max_bs
+        )
+
+        if self.require_mlp_sync:
+            is_bs_supported = is_bs_supported and forward_batch.can_run_dp_cuda_graph
+
         return is_bs_supported

     def capture(self):
......
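
To see what the relocated check buys, here is a usage sketch of the non-TP-gather path, using throwaway stand-ins (SimpleNamespace objects, not the real sglang classes) for the runner and the batch:

    from types import SimpleNamespace

    # Stand-ins for illustration only.
    runner = SimpleNamespace(
        require_mlp_tp_gather=False,
        require_mlp_sync=True,  # e.g. DeepEP with a DP FFN
        disable_padding=False,
        max_bs=160,
        graphs={1: None, 2: None, 4: None},
    )
    batch = SimpleNamespace(batch_size=3, can_run_dp_cuda_graph=False)

    # Same logic as the new can_run body above.
    cuda_graph_bs = batch.batch_size
    is_bs_supported = (
        cuda_graph_bs in runner.graphs
        if runner.disable_padding
        else cuda_graph_bs <= runner.max_bs
    )
    if runner.require_mlp_sync:
        is_bs_supported = is_bs_supported and batch.can_run_dp_cuda_graph

    print(is_bs_supported)  # False; before the fix this path ignored DP sync and said True

Batch size 3 fits under max_bs, but the batch is not synchronized across DP ranks, so the whole DP group must fall back to eager execution rather than replaying a graph.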
@@ -24,6 +24,7 @@ from sglang.srt.speculative.eagle_utils import EagleDraftInput, fast_topk
 from sglang.srt.utils import (
     require_attn_tp_gather,
     require_gathered_buffer,
+    require_mlp_sync,
     require_mlp_tp_gather,
 )
@@ -42,6 +43,7 @@ class EAGLEDraftExtendCudaGraphRunner:
         self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
         self.require_gathered_buffer = require_gathered_buffer(model_runner.server_args)
         self.require_mlp_tp_gather = require_mlp_tp_gather(model_runner.server_args)
+        self.require_mlp_sync = require_mlp_sync(model_runner.server_args)
         self.require_attn_tp_gather = require_attn_tp_gather(model_runner.server_args)
         self.tp_size = self.model_runner.tp_size
         self.dp_size = model_runner.server_args.dp_size
@@ -130,29 +132,24 @@ class EAGLEDraftExtendCudaGraphRunner:
     def can_run(self, forward_batch: ForwardBatch):
         if self.require_mlp_tp_gather:
-            if not forward_batch.can_run_dp_cuda_graph:
-                return False
-            total_batch_size = (
+            cuda_graph_bs = (
                 sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
                 if self.model_runner.spec_algorithm.is_eagle()
                 else sum(forward_batch.global_num_tokens_cpu)
             )
-            is_bs_supported = (
-                total_batch_size in self.graphs
-                if self.disable_padding
-                else total_batch_size <= self.max_bs
-            )
-            return is_bs_supported
         else:
-            batch_size = forward_batch.seq_lens.numel()
-            is_bs_supported = (
-                batch_size in self.graphs
-                if self.disable_padding
-                else batch_size <= self.max_bs
-            )
-            return is_bs_supported
+            cuda_graph_bs = forward_batch.seq_lens.numel()
+
+        is_bs_supported = (
+            cuda_graph_bs in self.graphs
+            if self.disable_padding
+            else cuda_graph_bs <= self.max_bs
+        )
+
+        if self.require_mlp_sync:
+            is_bs_supported = is_bs_supported and forward_batch.can_run_dp_cuda_graph
+
+        return is_bs_supported

     def capture(self):
         CudaGraphRunner.capture(self)
......