You need to sign in or sign up before continuing.
Unverified commit ac5010e0, authored by Cheng Wan, committed by GitHub
Browse files

Fix CUDA Graph Check under Deepep with DP FFN (#7451)

parent 3cee035e
......@@ -48,6 +48,7 @@ from sglang.srt.utils import (
rank0_log,
require_attn_tp_gather,
require_gathered_buffer,
require_mlp_sync,
require_mlp_tp_gather,
)
......@@ -212,6 +213,7 @@ class CudaGraphRunner:
self.is_encoder_decoder = model_runner.model_config.is_encoder_decoder
self.require_gathered_buffer = require_gathered_buffer(model_runner.server_args)
self.require_mlp_tp_gather = require_mlp_tp_gather(model_runner.server_args)
self.require_mlp_sync = require_mlp_sync(model_runner.server_args)
self.require_attn_tp_gather = require_attn_tp_gather(model_runner.server_args)
self.enable_two_batch_overlap = (
model_runner.server_args.enable_two_batch_overlap
......@@ -337,23 +339,23 @@ class CudaGraphRunner:
def can_run(self, forward_batch: ForwardBatch):
if self.require_mlp_tp_gather:
total_batch_size = (
cuda_graph_bs = (
sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
if self.model_runner.spec_algorithm.is_eagle()
else sum(forward_batch.global_num_tokens_cpu)
)
is_bs_supported = forward_batch.can_run_dp_cuda_graph and (
total_batch_size in self.graphs
if self.disable_padding
else total_batch_size <= self.max_bs
)
else:
cuda_graph_bs = forward_batch.batch_size
is_bs_supported = (
forward_batch.batch_size in self.graphs
cuda_graph_bs in self.graphs
if self.disable_padding
else forward_batch.batch_size <= self.max_bs
else cuda_graph_bs <= self.max_bs
)
if self.require_mlp_sync:
is_bs_supported = is_bs_supported and forward_batch.can_run_dp_cuda_graph
# NOTE: cuda graph cannot handle mixed batch (encoder_len = 0)
# If mixed batch cannot be supported, then encoder_lens can be removed in cuda graph
# because the full_text_row_masked_out_mask tensor will always be ones
......
......@@ -23,6 +23,7 @@ from sglang.srt.speculative.eagle_utils import EagleDraftInput
from sglang.srt.utils import (
require_attn_tp_gather,
require_gathered_buffer,
require_mlp_sync,
require_mlp_tp_gather,
)
......@@ -46,6 +47,7 @@ class EAGLEDraftCudaGraphRunner:
self.is_encoder_decoder = model_runner.model_config.is_encoder_decoder
self.require_gathered_buffer = require_gathered_buffer(model_runner.server_args)
self.require_mlp_tp_gather = require_mlp_tp_gather(model_runner.server_args)
self.require_mlp_sync = require_mlp_sync(model_runner.server_args)
self.require_attn_tp_gather = require_attn_tp_gather(model_runner.server_args)
self.dp_size = self.model_runner.dp_size
self.tp_size = self.model_runner.tp_size
......@@ -127,24 +129,23 @@ class EAGLEDraftCudaGraphRunner:
def can_run(self, forward_batch: ForwardBatch):
if self.require_mlp_tp_gather:
if not forward_batch.can_run_dp_cuda_graph:
return False
total_batch_size = (
cuda_graph_bs = (
sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
if self.model_runner.spec_algorithm.is_eagle()
else sum(forward_batch.global_num_tokens_cpu)
)
is_bs_supported = (
total_batch_size in self.graphs
if self.disable_padding
else total_batch_size <= self.max_bs
)
else:
cuda_graph_bs = forward_batch.batch_size
is_bs_supported = (
forward_batch.batch_size in self.graphs
cuda_graph_bs in self.graphs
if self.disable_padding
else forward_batch.batch_size <= self.max_bs
else cuda_graph_bs <= self.max_bs
)
if self.require_mlp_sync:
is_bs_supported = is_bs_supported and forward_batch.can_run_dp_cuda_graph
return is_bs_supported
def capture(self):
......
......@@ -24,6 +24,7 @@ from sglang.srt.speculative.eagle_utils import EagleDraftInput, fast_topk
from sglang.srt.utils import (
require_attn_tp_gather,
require_gathered_buffer,
require_mlp_sync,
require_mlp_tp_gather,
)
......@@ -42,6 +43,7 @@ class EAGLEDraftExtendCudaGraphRunner:
self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
self.require_gathered_buffer = require_gathered_buffer(model_runner.server_args)
self.require_mlp_tp_gather = require_mlp_tp_gather(model_runner.server_args)
self.require_mlp_sync = require_mlp_sync(model_runner.server_args)
self.require_attn_tp_gather = require_attn_tp_gather(model_runner.server_args)
self.tp_size = self.model_runner.tp_size
self.dp_size = model_runner.server_args.dp_size
......@@ -130,28 +132,23 @@ class EAGLEDraftExtendCudaGraphRunner:
def can_run(self, forward_batch: ForwardBatch):
if self.require_mlp_tp_gather:
if not forward_batch.can_run_dp_cuda_graph:
return False
total_batch_size = (
cuda_graph_bs = (
sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
if self.model_runner.spec_algorithm.is_eagle()
else sum(forward_batch.global_num_tokens_cpu)
)
is_bs_supported = (
total_batch_size in self.graphs
if self.disable_padding
else total_batch_size <= self.max_bs
)
return is_bs_supported
else:
batch_size = forward_batch.seq_lens.numel()
cuda_graph_bs = forward_batch.seq_lens.numel()
is_bs_supported = (
batch_size in self.graphs
cuda_graph_bs in self.graphs
if self.disable_padding
else batch_size <= self.max_bs
else cuda_graph_bs <= self.max_bs
)
if self.require_mlp_sync:
is_bs_supported = is_bs_supported and forward_batch.can_run_dp_cuda_graph
return is_bs_supported
def capture(self):
......
Markdown is supported
Attach a file by dragging & dropping, or click to upload.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment