Commit ae6a5b29 (Unverified)
Authored May 29, 2025 by fzyzcjy · Committed by GitHub May 28, 2025

Minor refactor two-batch overlap (#6682)

parent 4839999b
Showing 3 changed files with 60 additions and 36 deletions:

- python/sglang/srt/layers/attention/tbo_backend.py (+5, -14)
- python/sglang/srt/model_executor/cuda_graph_runner.py (+10, -6)
- python/sglang/srt/two_batch_overlap.py (+45, -16)
python/sglang/srt/layers/attention/tbo_backend.py

```diff
@@ -119,24 +119,15 @@ class TboAttnBackend(AttentionBackend):
         replay_seq_lens_sum: int = None,
         replay_seq_lens_cpu: Optional[torch.Tensor] = None,
     ):
         from sglang.srt.model_executor.forward_batch_info import ForwardMode

         if fn_name == "init_forward_metadata_capture_cuda_graph":
             assert capture_num_tokens == bs, "Only support num_tokens==bs currently"
         num_tokens = bs

-        forward_mode_for_tbo_split = (
-            forward_mode if forward_mode != ForwardMode.IDLE else ForwardMode.DECODE
-        )
-        tbo_split_seq_index = two_batch_overlap.compute_split_seq_index(
-            forward_mode=forward_mode_for_tbo_split,
-            num_tokens=num_tokens,
-            extend_lens=None,
-        )
-        tbo_split_token_index = two_batch_overlap.compute_split_token_index(
-            split_seq_index=tbo_split_seq_index,
-            forward_mode=forward_mode_for_tbo_split,
-            extend_seq_lens=None,
-        )
+        tbo_split_seq_index, tbo_split_token_index = (
+            two_batch_overlap.compute_split_indices_for_cuda_graph_replay(
+                forward_mode=forward_mode,
+                cuda_graph_num_tokens=num_tokens,
+            )
+        )

         num_tokens_child_left = tbo_split_token_index
```
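The attention backend's CUDA-graph replay path previously recomputed the two split indices inline, including the IDLE-to-DECODE mode substitution; after this commit it delegates to the new `two_batch_overlap.compute_split_indices_for_cuda_graph_replay` helper. Below is a minimal, self-contained sketch of that contract; the `ForwardMode` enum and the split rules here are simplified stand-ins, not the real sglang implementations.

```python
from enum import Enum, auto
from typing import Optional, Sequence, Tuple


class ForwardMode(Enum):
    # Simplified stand-in for sglang's ForwardMode.
    DECODE = auto()
    IDLE = auto()


def compute_split_seq_index(
    forward_mode: ForwardMode, num_tokens: int, extend_lens: Optional[Sequence[int]]
) -> Optional[int]:
    # Toy rule: a decode batch splits at the halfway point.
    return num_tokens // 2 if forward_mode == ForwardMode.DECODE else None


def compute_split_token_index(
    split_seq_index: int,
    forward_mode: ForwardMode,
    extend_seq_lens: Optional[Sequence[int]],
) -> int:
    # Toy rule: in decode, one token per sequence, so the indices coincide.
    return split_seq_index


def compute_split_indices_for_cuda_graph_replay(
    forward_mode: ForwardMode, cuda_graph_num_tokens: int
) -> Tuple[Optional[int], int]:
    # The previously duplicated logic, now in one place:
    # IDLE batches are split as if they were DECODE batches.
    forward_mode_for_tbo_split = (
        forward_mode if forward_mode != ForwardMode.IDLE else ForwardMode.DECODE
    )
    tbo_split_seq_index = compute_split_seq_index(
        forward_mode=forward_mode_for_tbo_split,
        num_tokens=cuda_graph_num_tokens,
        extend_lens=None,
    )
    tbo_split_token_index = compute_split_token_index(
        split_seq_index=tbo_split_seq_index,
        forward_mode=forward_mode_for_tbo_split,
        extend_seq_lens=None,
    )
    return tbo_split_seq_index, tbo_split_token_index


# The call site no longer needs the IDLE -> DECODE dance itself:
print(compute_split_indices_for_cuda_graph_replay(ForwardMode.IDLE, 8))  # (4, 4)
```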
python/sglang/srt/model_executor/cuda_graph_runner.py

```diff
@@ -40,7 +40,7 @@ from sglang.srt.model_executor.forward_batch_info import (
 )
 from sglang.srt.patch_torch import monkey_patch_torch_compile
 from sglang.srt.two_batch_overlap import (
-    TboCudaGraphRunnerUtils,
+    TboCudaGraphRunnerPlugin,
     TboForwardBatchPreparer,
 )
 from sglang.srt.utils import (
@@ -256,6 +256,7 @@ class CudaGraphRunner:
         self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64)
         self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.int64)
         self.num_token_non_padded = torch.zeros((1,), dtype=torch.int32)
+        self.tbo_plugin = TboCudaGraphRunnerPlugin()

         # pipeline parallelism
         if self.pp_size > 1:
@@ -481,12 +482,9 @@ class CudaGraphRunner:
             capture_hidden_mode=self.capture_hidden_mode,
             lora_paths=lora_paths,
             num_token_non_padded=self.num_token_non_padded,
-            tbo_split_seq_index=TboCudaGraphRunnerUtils.compute_tbo_split_seq_index(
-                self, num_tokens
-            ),
             global_forward_mode=self.capture_forward_mode,
         )
-        TboForwardBatchPreparer.prepare(forward_batch)
+        self.tbo_plugin.capture_one_batch_size(forward_batch, num_tokens=num_tokens)

         if lora_paths is not None:
             self.model_runner.lora_manager.prepare_lora_batch(forward_batch)
@@ -581,7 +579,13 @@ class CudaGraphRunner:
         self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens)
         self.out_cache_loc[:raw_num_token].copy_(forward_batch.out_cache_loc)
         self.positions[:raw_num_token].copy_(forward_batch.positions)
-        self.num_token_non_padded[...] = len(forward_batch.input_ids)
+        num_token_non_padded = len(forward_batch.input_ids)
+        self.num_token_non_padded[...] = num_token_non_padded
+        self.tbo_plugin.replay_prepare(
+            forward_mode=forward_batch.forward_mode,
+            bs=bs,
+            num_token_non_padded=num_token_non_padded,
+        )

         if forward_batch.seq_lens_cpu is not None:
             if bs != raw_bs:
                 self.seq_lens_cpu.fill_(1)
```
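The runner now owns a single `TboCudaGraphRunnerPlugin` instance and calls into it at two points: `capture_one_batch_size` while a graph is being captured, and `replay_prepare` just before a graph is replayed. The sketch below shows that wiring with toy stand-ins for `CudaGraphRunner`, `ForwardBatch`, and the server-args flag; it is a schematic of the hook points, not the real sglang classes.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class ForwardBatch:
    # Toy stand-in: only the fields the hooks touch.
    forward_mode: str
    input_ids: List[int]
    tbo_split_seq_index: Optional[int] = None


class TboCudaGraphRunnerPlugin:
    def __init__(self, enabled: bool):
        self.enabled = enabled  # stands in for the enable_two_batch_overlap flag

    def capture_one_batch_size(self, batch: ForwardBatch, num_tokens: int) -> None:
        if not self.enabled:  # hooks are no-ops when TBO is disabled
            return
        batch.tbo_split_seq_index = num_tokens // 2  # toy split rule

    def replay_prepare(
        self, forward_mode: str, bs: int, num_token_non_padded: int
    ) -> None:
        if not self.enabled:
            return
        # Mirrors the diff's "TODO add logic here" placeholder.


class CudaGraphRunner:
    def __init__(self, enable_two_batch_overlap: bool):
        # Created once, alongside the runner's other persistent state.
        self.tbo_plugin = TboCudaGraphRunnerPlugin(enable_two_batch_overlap)

    def capture_one_batch_size(self, batch: ForwardBatch, num_tokens: int) -> None:
        # ... build the forward batch, then hand it to the plugin ...
        self.tbo_plugin.capture_one_batch_size(batch, num_tokens=num_tokens)

    def replay(self, batch: ForwardBatch, bs: int) -> None:
        num_token_non_padded = len(batch.input_ids)
        self.tbo_plugin.replay_prepare(
            forward_mode=batch.forward_mode,
            bs=bs,
            num_token_non_padded=num_token_non_padded,
        )


runner = CudaGraphRunner(enable_two_batch_overlap=True)
batch = ForwardBatch(forward_mode="decode", input_ids=[1, 2, 3, 4])
runner.capture_one_batch_size(batch, num_tokens=4)
print(batch.tbo_split_seq_index)  # 2
runner.replay(batch, bs=4)
```

The design win is that `CudaGraphRunner` no longer branches on the TBO flag or reaches into TBO helpers at each call site; both hooks degrade to no-ops when the feature is off.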
python/sglang/srt/two_batch_overlap.py

```diff
@@ -85,25 +85,54 @@ def compute_split_token_index(
         raise NotImplementedError


+def compute_split_indices_for_cuda_graph_replay(
+    forward_mode: ForwardMode,
+    cuda_graph_num_tokens: int,
+):
+    forward_mode_for_tbo_split = (
+        forward_mode if forward_mode != ForwardMode.IDLE else ForwardMode.DECODE
+    )
+    tbo_split_seq_index = compute_split_seq_index(
+        forward_mode=forward_mode_for_tbo_split,
+        num_tokens=cuda_graph_num_tokens,
+        extend_lens=None,
+    )
+    tbo_split_token_index = compute_split_token_index(
+        split_seq_index=tbo_split_seq_index,
+        forward_mode=forward_mode_for_tbo_split,
+        extend_seq_lens=None,
+    )
+    return tbo_split_seq_index, tbo_split_token_index
+
+
 # -------------------------------- Preparation ---------------------------------------


-class TboCudaGraphRunnerUtils:
-    @staticmethod
-    def compute_tbo_split_seq_index(that: "CudaGraphRunner", num_tokens: int):
-        if that.model_runner.server_args.enable_two_batch_overlap:
-            tbo_split_seq_index = compute_split_seq_index(
-                forward_mode=that.capture_forward_mode,
-                num_tokens=num_tokens,
-                extend_lens=None,
-            )
-            # For simplicity, when two_batch_overlap is enabled, we only capture CUDA Graph for tbo=true
-            assert (
-                tbo_split_seq_index is not None
-            ), f"{that.capture_forward_mode=} {num_tokens=}"
-        else:
-            tbo_split_seq_index = None
-        return tbo_split_seq_index
+class TboCudaGraphRunnerPlugin:
+    def __init__(self):
+        pass  # TODO add logic here
+
+    def capture_one_batch_size(self, batch: ForwardBatch, num_tokens: int):
+        if not global_server_args_dict["enable_two_batch_overlap"]:
+            return
+
+        batch.tbo_split_seq_index = compute_split_seq_index(
+            forward_mode=batch.forward_mode,
+            num_tokens=num_tokens,
+            extend_lens=None,
+        )
+        # For simplicity, when two_batch_overlap is enabled, we only capture CUDA Graph for tbo=true
+        assert batch.tbo_split_seq_index is not None, f"{num_tokens=}"
+
+        TboForwardBatchPreparer.prepare(batch)
+
+    def replay_prepare(
+        self, forward_mode: ForwardMode, bs: int, num_token_non_padded: int
+    ):
+        if not global_server_args_dict["enable_two_batch_overlap"]:
+            return
+        pass  # TODO add logic here


 class TboDPAttentionPreparer:
```
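The deleted `TboCudaGraphRunnerUtils.compute_tbo_split_seq_index` was a static helper that reached into the runner (`that.model_runner.server_args`, `that.capture_forward_mode`) and returned the index to the caller; the plugin instead reads `global_server_args_dict`, takes explicit arguments, and writes the split index onto the batch itself. One behavior carried over is the capture-time contract: with two-batch overlap enabled, every captured shape must be splittable. A standalone toy illustration of that contract, not the sglang implementation:

```python
from typing import Optional


def compute_split_seq_index(num_tokens: int) -> Optional[int]:
    # Toy rule: a decode batch splits in half; zero tokens cannot be split.
    return num_tokens // 2 if num_tokens > 0 else None


def capture_one_batch_size(
    num_tokens: int, enable_two_batch_overlap: bool
) -> Optional[int]:
    if not enable_two_batch_overlap:
        return None  # hook is a no-op when the feature is off
    split = compute_split_seq_index(num_tokens)
    # "For simplicity, when two_batch_overlap is enabled, we only capture
    # CUDA Graph for tbo=true": an unsplittable shape is a capture-time bug,
    # not a fallback path, hence an assert rather than a None return.
    assert split is not None, f"{num_tokens=}"
    return split


for num_tokens in (2, 8, 160):
    print(num_tokens, capture_one_batch_size(num_tokens, enable_two_batch_overlap=True))
```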