sglang · Commits · 0d477880

Support overlapping two batches (#4068)

Unverified commit 0d477880, authored May 25, 2025 by fzyzcjy; committed by GitHub May 24, 2025.
Parent: f4560373

Showing 13 changed files with 1146 additions and 130 deletions (+1146 −130):
- python/sglang/srt/layers/attention/tbo_backend.py (+241 −0, new file)
- python/sglang/srt/layers/quantization/deep_gemm.py (+13 −0)
- python/sglang/srt/managers/schedule_batch.py (+8 −0)
- python/sglang/srt/managers/scheduler.py (+25 −1)
- python/sglang/srt/model_executor/cuda_graph_runner.py (+21 −1)
- python/sglang/srt/model_executor/forward_batch_info.py (+18 −1)
- python/sglang/srt/model_executor/model_runner.py (+18 −9)
- python/sglang/srt/models/deepseek_v2.py (+118 −92)
- python/sglang/srt/operations.py (+37 −2)
- python/sglang/srt/operations_strategy.py (+107 −24)
- python/sglang/srt/server_args.py (+6 −0)
- python/sglang/srt/two_batch_overlap.py (+462 −0, new file)
- test/srt/test_two_batch_overlap.py (+72 −0)
python/sglang/srt/layers/attention/tbo_backend.py (new file, +241 −0)

```python
from typing import TYPE_CHECKING, Callable, List, Optional, Union

import torch

from sglang.srt import two_batch_overlap
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput

if TYPE_CHECKING:
    from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode


class TboAttnBackend(AttentionBackend):
    def __init__(self, primary: AttentionBackend, children: List[AttentionBackend]):
        super().__init__()
        self.primary = primary
        self.children = children

    @classmethod
    def init_new(cls, creator: Callable[[], AttentionBackend]):
        return cls(
            primary=creator(),
            children=[creator() for _ in range(2)],
        )

    def init_forward_metadata(self, forward_batch: "ForwardBatch"):
        self.primary.init_forward_metadata(forward_batch=forward_batch)
        if forward_batch.tbo_children is not None:
            for child, forward_batch_child in zip(
                self.children, forward_batch.tbo_children, strict=True
            ):
                if forward_batch_child.batch_size > 0:
                    child.init_forward_metadata(forward_batch=forward_batch_child)

    def init_cuda_graph_state(self, max_bs: int):
        self.primary.init_cuda_graph_state(max_bs=max_bs)
        for item in self.children:
            # TODO for children, maybe can provide *smaller* max_bs to optimize
            item.init_cuda_graph_state(max_bs=max_bs)

    def init_forward_metadata_capture_cuda_graph(
        self,
        bs: int,
        num_tokens: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: "ForwardMode",
        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
    ):
        self.primary.init_forward_metadata_capture_cuda_graph(
            bs=bs,
            num_tokens=num_tokens,
            req_pool_indices=req_pool_indices,
            seq_lens=seq_lens,
            encoder_lens=encoder_lens,
            forward_mode=forward_mode,
            spec_info=spec_info,
        )
        self._init_forward_metadata_cuda_graph_children(
            fn_name="init_forward_metadata_capture_cuda_graph",
            bs=bs,
            req_pool_indices=req_pool_indices,
            seq_lens=seq_lens,
            encoder_lens=encoder_lens,
            forward_mode=forward_mode,
            spec_info=spec_info,
            capture_num_tokens=num_tokens,
        )

    def init_forward_metadata_replay_cuda_graph(
        self,
        bs: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        seq_lens_sum: int,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: "ForwardMode",
        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
        seq_lens_cpu: Optional[torch.Tensor],
    ):
        self.primary.init_forward_metadata_replay_cuda_graph(
            bs=bs,
            req_pool_indices=req_pool_indices,
            seq_lens=seq_lens,
            seq_lens_sum=seq_lens_sum,
            encoder_lens=encoder_lens,
            forward_mode=forward_mode,
            spec_info=spec_info,
            seq_lens_cpu=seq_lens_cpu,
        )
        self._init_forward_metadata_cuda_graph_children(
            fn_name="init_forward_metadata_replay_cuda_graph",
            bs=bs,
            req_pool_indices=req_pool_indices,
            seq_lens=seq_lens,
            encoder_lens=encoder_lens,
            forward_mode=forward_mode,
            spec_info=spec_info,
            replay_seq_lens_sum=seq_lens_sum,
            replay_seq_lens_cpu=seq_lens_cpu,
        )

    def _init_forward_metadata_cuda_graph_children(
        self,
        fn_name: str,
        # common args
        bs: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: "ForwardMode",
        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
        # capture args
        capture_num_tokens: int = None,
        # replay args
        replay_seq_lens_sum: int = None,
        replay_seq_lens_cpu: Optional[torch.Tensor] = None,
    ):
        from sglang.srt.model_executor.forward_batch_info import ForwardMode

        if fn_name == "init_forward_metadata_capture_cuda_graph":
            assert capture_num_tokens == bs, "Only support num_tokens==bs currently"
        num_tokens = bs

        forward_mode_for_tbo_split = (
            forward_mode if forward_mode != ForwardMode.IDLE else ForwardMode.DECODE
        )
        tbo_split_seq_index = two_batch_overlap.compute_split_seq_index(
            forward_mode=forward_mode_for_tbo_split,
            num_tokens=num_tokens,
            extend_lens=None,
        )
        tbo_split_token_index = two_batch_overlap.compute_split_token_index(
            split_seq_index=tbo_split_seq_index,
            forward_mode=forward_mode_for_tbo_split,
            extend_seq_lens=None,
        )

        num_tokens_child_left = tbo_split_token_index
        num_tokens_child_right = num_tokens - tbo_split_token_index
        bs_child_left = num_tokens_child_left
        bs_child_right = num_tokens_child_right

        assert (
            num_tokens_child_left > 0 and num_tokens_child_right > 0
        ), f"{num_tokens_child_left=} {num_tokens_child_right=} {forward_mode=} {num_tokens=}"

        common_pre_split_args = dict(
            fn_name=fn_name,
            bs=bs,
            req_pool_indices=req_pool_indices,
            seq_lens=seq_lens,
            encoder_lens=encoder_lens,
            forward_mode=forward_mode,
            spec_info=spec_info,
            capture_num_tokens=capture_num_tokens,
            replay_seq_lens_sum=replay_seq_lens_sum,
            replay_seq_lens_cpu=replay_seq_lens_cpu,
        )

        args_left = _init_forward_metadata_cuda_graph_split(
            output_bs=bs_child_left,
            seq_slice=slice(None, tbo_split_seq_index),
            **common_pre_split_args,
        )
        args_right = _init_forward_metadata_cuda_graph_split(
            output_bs=bs_child_right,
            seq_slice=slice(tbo_split_seq_index, None),
            **common_pre_split_args,
        )

        child_left, child_right = self.children
        getattr(child_left, fn_name)(**args_left)
        getattr(child_right, fn_name)(**args_right)

    def get_cuda_graph_seq_len_fill_value(self):
        ans = self.primary.get_cuda_graph_seq_len_fill_value()
        for child in self.children:
            assert ans == child.get_cuda_graph_seq_len_fill_value()
        return ans

    def forward_extend(self, *args, **kwargs):
        return self.primary.forward_extend(*args, **kwargs)

    def forward_decode(self, *args, **kwargs):
        return self.primary.forward_decode(*args, **kwargs)


def _init_forward_metadata_cuda_graph_split(
    fn_name: str,
    seq_slice: slice,
    output_bs: int,
    # common args
    bs: int,
    req_pool_indices: torch.Tensor,
    seq_lens: torch.Tensor,
    encoder_lens: Optional[torch.Tensor],
    forward_mode: "ForwardMode",
    spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
    # capture args
    capture_num_tokens: int = None,
    # replay args
    replay_seq_lens_sum: int = None,
    replay_seq_lens_cpu: Optional[torch.Tensor] = None,
):
    assert encoder_lens is None, "encoder_lens is not supported yet"
    assert spec_info is None, "spec_info is not supported yet"

    ans = dict(
        bs=output_bs,
        req_pool_indices=req_pool_indices[seq_slice],
        seq_lens=seq_lens[seq_slice],
        # directly forward
        forward_mode=forward_mode,
        # ignore
        encoder_lens=None,
        spec_info=None,
    )

    if fn_name == "init_forward_metadata_capture_cuda_graph":
        assert capture_num_tokens == bs, "Only support num_tokens==bs currently"
        ans.update(
            dict(
                num_tokens=output_bs,
            )
        )
    elif fn_name == "init_forward_metadata_replay_cuda_graph":
        output_seq_lens_cpu = replay_seq_lens_cpu[seq_slice]
        ans.update(
            dict(
                seq_lens_sum=output_seq_lens_cpu.sum().item(),
                seq_lens_cpu=output_seq_lens_cpu,
            )
        )
    else:
        raise NotImplementedError

    return ans
```
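For orientation, the pattern above is one primary backend plus two children built by the same zero-argument factory; calls fan out to all three. A minimal standalone sketch of that pattern follows; `DummyBackend` and the free-standing `init_new` are illustrative stand-ins, not the real sglang classes.

```python
# Standalone sketch of the TboAttnBackend fan-out pattern (hypothetical names).
class DummyBackend:
    def init_cuda_graph_state(self, max_bs):
        print(f"init_cuda_graph_state(max_bs={max_bs}) on backend {id(self):#x}")


def init_new(creator):
    # Mirrors TboAttnBackend.init_new: one primary plus two children,
    # each produced by the same zero-argument factory.
    primary = creator()
    children = [creator() for _ in range(2)]
    return primary, children


primary, children = init_new(DummyBackend)
primary.init_cuda_graph_state(max_bs=8)
for child in children:
    # Children currently receive the same max_bs (see the TODO in the file above).
    child.init_cuda_graph_state(max_bs=8)
```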
python/sglang/srt/layers/quantization/deep_gemm.py (+13 −0)

```diff
@@ -391,3 +391,16 @@ def _log_jit_build(M: int, N: int, K: int, kernel_type: DeepGemmKernelType):
     RuntimeCache.get = __patched_func
     yield
     RuntimeCache.get = origin_func
+
+
+@contextmanager
+def configure_deep_gemm_num_sms(num_sms):
+    if num_sms is None:
+        yield
+    else:
+        original_num_sms = deep_gemm.get_num_sms()
+        deep_gemm.set_num_sms(num_sms)
+        try:
+            yield
+        finally:
+            deep_gemm.set_num_sms(original_num_sms)
```
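The new context manager is the standard save/override/restore idiom, with a None short-circuit so callers can pass an optional SM budget. A self-contained sketch of the same idiom, using a stand-in settings object instead of the real deep_gemm module:

```python
from contextlib import contextmanager


class _FakeDeepGemm:
    # Hypothetical stand-in for the deep_gemm module-level SM setting.
    _num_sms = 132

    @classmethod
    def get_num_sms(cls):
        return cls._num_sms

    @classmethod
    def set_num_sms(cls, n):
        cls._num_sms = n


@contextmanager
def configure_num_sms(num_sms):
    if num_sms is None:
        yield  # no-op when no override is requested
    else:
        original = _FakeDeepGemm.get_num_sms()
        _FakeDeepGemm.set_num_sms(num_sms)
        try:
            yield
        finally:
            _FakeDeepGemm.set_num_sms(original)  # always restore, even on error


with configure_num_sms(100):
    assert _FakeDeepGemm.get_num_sms() == 100
assert _FakeDeepGemm.get_num_sms() == 132
```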
python/sglang/srt/managers/schedule_batch.py (+8 −0)

```diff
@@ -78,6 +78,7 @@ global_server_args_dict = {
     "disable_radix_cache": ServerArgs.disable_radix_cache,
     "enable_deepep_moe": ServerArgs.enable_deepep_moe,
     "enable_dp_attention": ServerArgs.enable_dp_attention,
+    "enable_two_batch_overlap": ServerArgs.enable_two_batch_overlap,
     "enable_dp_lm_head": ServerArgs.enable_dp_lm_head,
     "enable_ep_moe": ServerArgs.enable_ep_moe,
     "deepep_config": ServerArgs.deepep_config,
@@ -831,6 +832,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     global_num_tokens: Optional[List[int]] = None
     global_num_tokens_for_logprob: Optional[List[int]] = None
     can_run_dp_cuda_graph: bool = False
+    tbo_split_seq_index: Optional[int] = None
+    global_forward_mode: Optional[ForwardMode] = None

     # For processing logprobs
     return_logprob: bool = False
@@ -1624,6 +1627,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             or global_server_args_dict["attention_backend"] == "flashmla"
             or global_server_args_dict["attention_backend"] == "fa3"
             or global_server_args_dict["attention_backend"] == "cutlass_mla"
+            or global_server_args_dict["enable_two_batch_overlap"]
         ):
             seq_lens_cpu = self.seq_lens.cpu()
         else:
@@ -1651,6 +1655,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             global_num_tokens=self.global_num_tokens,
             global_num_tokens_for_logprob=self.global_num_tokens_for_logprob,
             can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
+            tbo_split_seq_index=self.tbo_split_seq_index,
+            global_forward_mode=self.global_forward_mode,
             seq_lens_cpu=seq_lens_cpu,
             extend_num_tokens=self.extend_num_tokens,
             extend_seq_lens=extend_seq_lens,
@@ -1729,6 +1735,8 @@ class ModelWorkerBatch:
     global_num_tokens: Optional[List[int]]
     global_num_tokens_for_logprob: Optional[List[int]]
     can_run_dp_cuda_graph: bool
+    tbo_split_seq_index: Optional[int]
+    global_forward_mode: Optional[ForwardMode]

     # For extend
     extend_num_tokens: Optional[int]
```
python/sglang/srt/managers/scheduler.py (+25 −1)

```diff
@@ -34,6 +34,7 @@ import zmq
 from torch.distributed import barrier

 from sglang.global_config import global_config
+from sglang.srt import two_batch_overlap
 from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.constrained.base_grammar_backend import create_grammar_backend
 from sglang.srt.disaggregation.decode import (
@@ -132,7 +133,9 @@ from sglang.srt.reasoning_parser import ReasoningParser
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
+from sglang.srt.two_batch_overlap import TboDPAttentionPreparer
 from sglang.srt.utils import (
+    DeepEPMode,
     DynamicGradMode,
     broadcast_pyobj,
     configure_logger,
@@ -1648,6 +1651,9 @@ class Scheduler(
             disable_cuda_graph=self.server_args.disable_cuda_graph,
             spec_algorithm=self.spec_algorithm,
             speculative_num_draft_tokens=self.server_args.speculative_num_draft_tokens,
+            enable_two_batch_overlap=self.server_args.enable_two_batch_overlap,
+            enable_deepep_moe=self.server_args.enable_deepep_moe,
+            deepep_mode=DeepEPMode[self.server_args.deepep_mode],
         )

     @staticmethod
@@ -1661,6 +1667,9 @@ class Scheduler(
         disable_cuda_graph: bool,
         spec_algorithm,
         speculative_num_draft_tokens,
+        enable_two_batch_overlap: bool,
+        enable_deepep_moe: bool,
+        deepep_mode: DeepEPMode,
     ):
         # Check if other DP workers have running batches
         if local_batch is None:
@@ -1696,17 +1705,26 @@ class Scheduler(
         is_extend_in_batch = (
             local_batch.forward_mode.is_extend() if local_batch else False
         )
+
+        tbo_preparer = TboDPAttentionPreparer()
+
         local_info = torch.tensor(
             [
                 num_tokens,
                 can_cuda_graph,
                 num_tokens_for_logprob,
                 is_extend_in_batch,
+                *tbo_preparer.prepare_all_gather(
+                    local_batch,
+                    deepep_mode,
+                    enable_deepep_moe,
+                    enable_two_batch_overlap,
+                ),
             ],
             dtype=torch.int64,
         )
         global_info = torch.empty(
-            (dp_size, attn_tp_size, 4),
+            (dp_size, attn_tp_size, 6),
             dtype=torch.int64,
         )
         torch.distributed.all_gather_into_tensor(
@@ -1719,6 +1737,10 @@ class Scheduler(
         global_num_tokens_for_logprob = global_info[:, 0, 2].tolist()
         is_extend_in_batch = global_info[:, 0, 3].tolist()

+        tbo_split_seq_index, global_forward_mode = tbo_preparer.compute_output(
+            global_info[:, :, 4:6]
+        )
+
         if local_batch is None and max(global_num_tokens) > 0:
             local_batch = get_idle_batch()
@@ -1732,6 +1754,8 @@ class Scheduler(
             local_batch.global_num_tokens_for_logprob = (
                 global_num_tokens_for_logprob
             )
+            local_batch.tbo_split_seq_index = tbo_split_seq_index
+            local_batch.global_forward_mode = global_forward_mode

         # Check forward mode for cuda graph
         if not disable_cuda_graph:
```
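The all-gather payload widens from four to six int64 fields per rank; the two new columns carry each rank's TBO vote (can-run flag and forward mode) so every DP worker reaches the same split decision. A small illustration of that layout and the slice `compute_output` consumes (all values below are made up):

```python
import torch

dp_size, attn_tp_size = 2, 1
# Per-rank fields: [num_tokens, can_cuda_graph, num_tokens_for_logprob,
#                   is_extend_in_batch, local_can_run_tbo, local_forward_mode]
global_info = torch.tensor(
    [
        [[128, 1, 128, 0, 1, 2]],  # rank 0 (forward-mode code 2 is illustrative)
        [[96, 1, 96, 0, 1, 2]],    # rank 1
    ],
    dtype=torch.int64,
)
assert global_info.shape == (dp_size, attn_tp_size, 6)

global_num_tokens = global_info[:, 0, 0].tolist()   # [128, 96]
tbo_votes = global_info[:, :, 4:6]                   # what compute_output receives
print(global_num_tokens, tuple(tbo_votes.shape))     # [128, 96] (2, 1, 2)
```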
python/sglang/srt/model_executor/cuda_graph_runner.py (+21 −1)

```diff
@@ -24,6 +24,7 @@ from typing import TYPE_CHECKING, Callable, Optional, Union
 import torch
 import tqdm

+from sglang.srt import two_batch_overlap
 from sglang.srt.custom_op import CustomOp
 from sglang.srt.distributed import get_tensor_model_parallel_rank
 from sglang.srt.distributed.parallel_state import GroupCoordinator, graph_capture
@@ -38,6 +39,10 @@ from sglang.srt.model_executor.forward_batch_info import (
     PPProxyTensors,
 )
 from sglang.srt.patch_torch import monkey_patch_torch_compile
+from sglang.srt.two_batch_overlap import (
+    TboCudaGraphRunnerUtils,
+    TboForwardBatchPreparer,
+)
 from sglang.srt.utils import (
     get_available_gpu_memory,
     get_device_memory_capacity,
@@ -152,6 +157,9 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
             model_runner.req_to_token_pool.size
         ]

+    if server_args.enable_two_batch_overlap:
+        capture_bs = [bs for bs in capture_bs if bs >= 2]
+
     if server_args.cuda_graph_max_bs:
         capture_bs = [bs for bs in capture_bs if bs <= server_args.cuda_graph_max_bs]
         if max(capture_bs) < server_args.cuda_graph_max_bs:
@@ -349,7 +357,14 @@ class CudaGraphRunner:
             if self.is_encoder_decoder
             else True
         )
-        return is_bs_supported and is_encoder_lens_supported
+        is_tbo_supported = (
+            forward_batch.can_run_tbo
+            if self.model_runner.server_args.enable_two_batch_overlap
+            else True
+        )
+
+        return is_bs_supported and is_encoder_lens_supported and is_tbo_supported

     def capture(self):
         with graph_capture() as graph_capture_context:
@@ -466,7 +481,12 @@ class CudaGraphRunner:
             capture_hidden_mode=self.capture_hidden_mode,
             lora_paths=lora_paths,
             num_token_non_padded=self.num_token_non_padded,
+            tbo_split_seq_index=TboCudaGraphRunnerUtils.compute_tbo_split_seq_index(
+                self, num_tokens
+            ),
+            global_forward_mode=self.capture_forward_mode,
         )
+        TboForwardBatchPreparer.prepare(forward_batch)

         if lora_paths is not None:
            self.model_runner.lora_manager.prepare_lora_batch(forward_batch)
```
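The capture-list filter exists because a batch of one cannot be split into two non-empty micro-batches; here it is in isolation:

```python
# TBO-specific CUDA graph capture filter: drop batch size 1, which cannot
# be split into two non-empty micro-batches.
capture_bs = [1, 2, 4, 8, 16]
enable_two_batch_overlap = True
if enable_two_batch_overlap:
    capture_bs = [bs for bs in capture_bs if bs >= 2]
assert capture_bs == [2, 4, 8, 16]
```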
python/sglang/srt/model_executor/forward_batch_info.py (+18 −1)

```diff
@@ -29,9 +29,10 @@ ScheduleBatch -> ModelWorkerBatch -> ForwardBatch
 from __future__ import annotations

+import dataclasses
 from dataclasses import dataclass
 from enum import IntEnum, auto
-from typing import TYPE_CHECKING, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

 import torch
 import triton
@@ -239,6 +240,7 @@ class ForwardBatch:
     dp_local_num_tokens: Optional[torch.Tensor] = None  # cached info at runtime
     gathered_buffer: Optional[torch.Tensor] = None
     can_run_dp_cuda_graph: bool = False
+    global_forward_mode: Optional[ForwardMode] = None

     # Speculative decoding
     spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput]] = None
@@ -252,12 +254,18 @@ class ForwardBatch:
     # For Qwen2-VL
     mrope_positions: torch.Tensor = None

+    tbo_split_seq_index: Optional[int] = None
+    tbo_parent_token_range: Optional[Tuple[int, int]] = None
+    tbo_children: Optional[List["ForwardBatch"]] = None
+
     @classmethod
     def init_new(
         cls,
         batch: ModelWorkerBatch,
         model_runner: ModelRunner,
     ):
+        from sglang.srt.two_batch_overlap import TboForwardBatchPreparer
+
         device = model_runner.device
         extend_input_logprob_token_ids_gpu = None
         if batch.extend_input_logprob_token_ids is not None:
@@ -281,6 +289,7 @@ class ForwardBatch:
             top_logprobs_nums=batch.top_logprobs_nums,
             token_ids_logprobs=batch.token_ids_logprobs,
             can_run_dp_cuda_graph=batch.can_run_dp_cuda_graph,
+            global_forward_mode=batch.global_forward_mode,
             lora_paths=batch.lora_paths,
             sampling_info=batch.sampling_info,
             req_to_token_pool=model_runner.req_to_token_pool,
@@ -294,6 +303,7 @@ class ForwardBatch:
             num_token_non_padded=torch.tensor(
                 len(batch.input_ids), dtype=torch.int32
             ).to(device, non_blocking=True),
+            tbo_split_seq_index=batch.tbo_split_seq_index,
         )

         # For DP attention
@@ -316,6 +326,7 @@ class ForwardBatch:
         )
         if ret.forward_mode.is_idle():
             ret.positions = torch.empty((0,), device=device)
+            TboForwardBatchPreparer.prepare(ret)
             return ret

         # Override the positions with spec_info
@@ -364,6 +375,8 @@ class ForwardBatch:
         if model_runner.server_args.lora_paths is not None:
             model_runner.lora_manager.prepare_lora_batch(ret)

+        TboForwardBatchPreparer.prepare(ret)
+
         return ret

     def merge_mm_inputs(self) -> Optional[MultimodalInputs]:
@@ -588,6 +601,10 @@ class ForwardBatch:
         # Precompute the kv indices for each chunk
         self.prepare_chunked_kv_indices(device)

+    @property
+    def can_run_tbo(self):
+        return self.tbo_split_seq_index is not None
+

 class PPProxyTensors:
     # adapted from https://github.com/vllm-project/vllm/blob/d14e98d924724b284dc5eaf8070d935e214e50c0/vllm/sequence.py#L1103
```
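Note that `can_run_tbo` reduces to a None check on `tbo_split_seq_index`: the scheduler either computed a split (TBO possible) or left it unset. A tiny stand-in dataclass makes the contract explicit:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class _BatchLike:
    # Stand-in for ForwardBatch, reduced to the one field the property reads.
    tbo_split_seq_index: Optional[int] = None

    @property
    def can_run_tbo(self):
        return self.tbo_split_seq_index is not None


assert not _BatchLike().can_run_tbo
assert _BatchLike(tbo_split_seq_index=3).can_run_tbo
```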
python/sglang/srt/model_executor/model_runner.py (+18 −9)

```diff
@@ -37,6 +37,7 @@ from sglang.srt.distributed import (
     set_custom_all_reduce,
 )
 from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state
+from sglang.srt.layers.attention.tbo_backend import TboAttnBackend
 from sglang.srt.layers.dp_attention import (
     get_attention_tp_group,
     get_attention_tp_size,
@@ -198,6 +199,7 @@ class ModelRunner:
                 "disable_radix_cache": server_args.disable_radix_cache,
                 "enable_nan_detection": server_args.enable_nan_detection,
                 "enable_dp_attention": server_args.enable_dp_attention,
+                "enable_two_batch_overlap": server_args.enable_two_batch_overlap,
                 "enable_dp_lm_head": server_args.enable_dp_lm_head,
                 "enable_ep_moe": server_args.enable_ep_moe,
                 "enable_deepep_moe": server_args.enable_deepep_moe,
@@ -994,6 +996,13 @@ class ModelRunner:
     def init_attention_backend(self):
         """Init attention kernel backend."""
+        if self.server_args.enable_two_batch_overlap:
+            self.attn_backend = TboAttnBackend.init_new(self._get_attention_backend)
+        else:
+            self.attn_backend = self._get_attention_backend()
+
+    # TODO unify with 6338
+    def _get_attention_backend(self):
         if self.server_args.attention_backend == "flashinfer":
             if not self.use_mla_backend:
                 from sglang.srt.layers.attention.flashinfer_backend import (
@@ -1003,17 +1012,17 @@ class ModelRunner:
                 # Init streams
                 if self.server_args.speculative_algorithm == "EAGLE":
                     self.plan_stream_for_flashinfer = torch.cuda.Stream()
-                self.attn_backend = FlashInferAttnBackend(self)
+                return FlashInferAttnBackend(self)
             else:
                 from sglang.srt.layers.attention.flashinfer_mla_backend import (
                     FlashInferMLAAttnBackend,
                 )

-                self.attn_backend = FlashInferMLAAttnBackend(self)
+                return FlashInferMLAAttnBackend(self)
         elif self.server_args.attention_backend == "aiter":
             from sglang.srt.layers.attention.aiter_backend import AiterAttnBackend

-            self.attn_backend = AiterAttnBackend(self)
+            return AiterAttnBackend(self)
         elif self.server_args.attention_backend == "triton":
             assert self.sliding_window_size is None, (
                 "Window attention is not supported in the triton attention backend. "
@@ -1028,21 +1037,21 @@ class ModelRunner:
                     DoubleSparseAttnBackend,
                 )

-                self.attn_backend = DoubleSparseAttnBackend(self)
+                return DoubleSparseAttnBackend(self)
             else:
                 from sglang.srt.layers.attention.triton_backend import TritonAttnBackend

-                self.attn_backend = TritonAttnBackend(self)
+                return TritonAttnBackend(self)
         elif self.server_args.attention_backend == "torch_native":
             from sglang.srt.layers.attention.torch_native_backend import (
                 TorchNativeAttnBackend,
             )

-            self.attn_backend = TorchNativeAttnBackend(self)
+            return TorchNativeAttnBackend(self)
         elif self.server_args.attention_backend == "flashmla":
             from sglang.srt.layers.attention.flashmla_backend import FlashMLABackend

-            self.attn_backend = FlashMLABackend(self)
+            return FlashMLABackend(self)
         elif self.server_args.attention_backend == "fa3":
             assert (
                 torch.cuda.get_device_capability()[0] == 8 and not self.use_mla_backend
@@ -1054,13 +1063,13 @@ class ModelRunner:
                 FlashAttentionBackend,
             )

-            self.attn_backend = FlashAttentionBackend(self)
+            return FlashAttentionBackend(self)
         elif self.server_args.attention_backend == "cutlass_mla":
             from sglang.srt.layers.attention.cutlass_mla_backend import (
                 CutlassMLABackend,
             )

-            self.attn_backend = CutlassMLABackend(self)
+            return CutlassMLABackend(self)
         else:
             raise ValueError(
                 f"Invalid attention backend: {self.server_args.attention_backend}"
```
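The point of turning backend construction into a factory that returns (rather than assigns) is that `TboAttnBackend.init_new` can call it three times and get three independent backends, each with its own metadata buffers. A schematic of why that matters, under hypothetical names:

```python
def make_backend():
    # Stand-in for _get_attention_backend: each call yields a fresh instance.
    return {"metadata": None}


backends = [make_backend() for _ in range(3)]  # primary + two children
backends[0]["metadata"] = "primary-only state"
# The children are untouched: no shared mutable state between instances.
assert backends[1]["metadata"] is None and backends[2]["metadata"] is None
```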
python/sglang/srt/models/deepseek_v2.py (+118 −92)

```diff
@@ -83,8 +83,10 @@ from sglang.srt.managers.expert_location_dispatch import ExpertLocationDispatchI
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.operations import execute_operations
-from sglang.srt.operations_strategy import compute_layer_operations
+from sglang.srt.two_batch_overlap import (
+    MaybeTboDeepEPDispatcher,
+    model_forward_maybe_tbo,
+)
 from sglang.srt.utils import (
     BumpAllocator,
     DeepEPMode,
@@ -226,6 +228,7 @@ class DeepseekV2MoE(nn.Module):
         self.routed_scaling_factor = config.routed_scaling_factor
         self.n_shared_experts = config.n_shared_experts
         self.n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"]
+        self.config = config
         self.layer_id = layer_id

         if self.tp_size > config.n_routed_experts:
@@ -300,7 +303,7 @@ class DeepseekV2MoE(nn.Module):
             else None
         )

-        self.deepep_dispatcher = DeepEPDispatcher(
+        self.deepep_dispatcher = MaybeTboDeepEPDispatcher(
             group=parallel_state.get_tp_group().device_group,
             router_topk=self.top_k,
             permute_fusion=True,
@@ -309,13 +312,11 @@ class DeepseekV2MoE(nn.Module):
             hidden_size=config.hidden_size,
             params_dtype=config.torch_dtype,
             deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]],
-            async_finish=True,  # TODO
+            async_finish=True,
             return_recv_hook=True,
         )

-    @property
-    def _enable_deepep_moe(self):
-        return global_server_args_dict["enable_deepep_moe"]
+        self._enable_deepep_moe = global_server_args_dict["enable_deepep_moe"]

     def get_moe_weights(self):
         return [
@@ -423,7 +424,7 @@ class DeepseekV2MoE(nn.Module):
         return None

     def op_gate(self, state):
-        if (not self._enable_deepep_moe) or is_non_idle_and_non_empty(
+        if is_non_idle_and_non_empty(
             state.forward_batch.forward_mode, state.hidden_states_mlp_input
         ):
             # router_logits: (num_tokens, n_experts)
@@ -432,21 +433,18 @@ class DeepseekV2MoE(nn.Module):
         state.router_logits = None

     def op_shared_experts(self, state):
-        if (self.n_share_experts_fusion == 0) and (
-            (not self._enable_deepep_moe)
-            or is_non_idle_and_non_empty(
-                state.forward_batch.forward_mode, state.hidden_states_mlp_input
-            )
+        hidden_states_mlp_input = state.pop("hidden_states_mlp_input")
+
+        if (self.n_share_experts_fusion == 0) and is_non_idle_and_non_empty(
+            state.forward_batch.forward_mode, hidden_states_mlp_input
         ):
-            state.shared_output = self.shared_experts(state.hidden_states_mlp_input)
+            state.shared_output = self.shared_experts(hidden_states_mlp_input)
         else:
             state.shared_output = None

     def op_select_experts(self, state):
-        router_logits = state.router_logits
+        router_logits = state.pop("router_logits")
         hidden_states = state.hidden_states_mlp_input

-        if self._enable_deepep_moe:
-            if router_logits is not None:
-                state.topk_weights_local, state.topk_idx_local = select_experts(
-                    hidden_states=hidden_states,
+        if router_logits is not None:
+            state.topk_weights_local, state.topk_idx_local = select_experts(
+                hidden_states=hidden_states,
@@ -471,17 +469,21 @@ class DeepseekV2MoE(nn.Module):
             )

     def op_dispatch_a(self, state):
-        if self._enable_deepep_moe and (self.ep_size > 1):
+        if self.ep_size > 1:
             # TODO(ch-wan): allow users to set num_max_dispatch_tokens_per_rank value
             self.deepep_dispatcher.dispatch_a(
-                hidden_states=state.pop("hidden_states_mlp_input"),
+                hidden_states=state.hidden_states_mlp_input,
                 topk_idx=state.pop("topk_idx_local"),
                 topk_weights=state.pop("topk_weights_local"),
                 forward_mode=state.forward_batch.forward_mode,
+                tbo_subbatch_index=state.get("tbo_subbatch_index"),
             )

     def op_dispatch_b(self, state):
-        if self._enable_deepep_moe and (self.ep_size > 1):
+        if self.ep_size > 1:
             with get_global_expert_distribution_recorder().with_current_layer(
                 self.layer_id
             ):
                 (
                     state.hidden_states_experts_input,
                     state.topk_idx_dispatched,
@@ -491,11 +493,11 @@ class DeepseekV2MoE(nn.Module):
                     state.seg_indptr,
                     state.masked_m,
                     state.expected_m,
-                ) = self.deepep_dispatcher.dispatch_b()
+                ) = self.deepep_dispatcher.dispatch_b(
+                    tbo_subbatch_index=state.get("tbo_subbatch_index"),
+                )

     def op_experts(self, state):
-        if self._enable_deepep_moe:
-            state.pop("router_logits")
-            state.hidden_states_experts_output = self.experts(
-                hidden_states=state.pop("hidden_states_experts_input"),
-                topk_idx=state.topk_idx_dispatched,
+        state.hidden_states_experts_output = self.experts(
+            hidden_states=state.pop("hidden_states_experts_input"),
+            topk_idx=state.topk_idx_dispatched,
@@ -507,40 +509,29 @@ class DeepseekV2MoE(nn.Module):
             num_recv_tokens_per_expert=state.pop("num_recv_tokens_per_expert"),
             forward_mode=state.forward_batch.forward_mode,
         )
-        else:
-            state.hidden_states_experts_output = self.experts(
-                hidden_states=state.pop("hidden_states_mlp_input"),
-                router_logits=state.pop("router_logits"),
-            )

     def op_combine_a(self, state):
-        if self._enable_deepep_moe and (self.ep_size > 1):
+        if self.ep_size > 1:
             self.deepep_dispatcher.combine_a(
-                state.pop("hidden_states_experts_output"),
+                hidden_states=state.pop("hidden_states_experts_output"),
                 topk_idx=state.pop("topk_idx_dispatched"),
                 topk_weights=state.pop("topk_weights_dispatched"),
                 forward_mode=state.forward_batch.forward_mode,
+                tbo_subbatch_index=state.get("tbo_subbatch_index"),
             )

     def op_combine_b(self, state):
-        if self._enable_deepep_moe and (self.ep_size > 1):
-            state.hidden_states_after_combine = self.deepep_dispatcher.combine_b()
+        if self.ep_size > 1:
+            state.hidden_states_after_combine = self.deepep_dispatcher.combine_b(
+                tbo_subbatch_index=state.get("tbo_subbatch_index"),
+            )

     def op_output(self, state):
-        final_hidden_states = (
-            state.pop("hidden_states_after_combine")
-            if self._enable_deepep_moe
-            else state.pop("hidden_states_experts_output")
-        )
+        final_hidden_states = state.pop("hidden_states_after_combine")

         final_hidden_states *= self.routed_scaling_factor

         if (s := state.pop("shared_output")) is not None:
             final_hidden_states = final_hidden_states + s

-        if (not self._enable_deepep_moe) and (self.tp_size > 1):
-            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
-
         state.hidden_states_mlp_output = final_hidden_states
@@ -1482,6 +1473,7 @@ class DeepseekV2DecoderLayer(nn.Module):
         forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
         zero_allocator: BumpAllocator,
+        tbo_subbatch_index: Optional[int] = None,
     ):
         state.hidden_states_after_comm_pre_attn, state.residual_after_input_ln = (
             self.layer_communicator.prepare_attn(hidden_states, residual, forward_batch)
@@ -1491,6 +1483,7 @@ class DeepseekV2DecoderLayer(nn.Module):
             forward_batch=forward_batch,
             positions=positions,
             zero_allocator=zero_allocator,
+            tbo_subbatch_index=tbo_subbatch_index,
         )
     )
@@ -1523,8 +1516,24 @@ class DeepseekV2DecoderLayer(nn.Module):
             state.forward_batch,
         )

-        state.clear(expect_keys={"positions", "forward_batch", "zero_allocator"})
-        return hidden_states, residual
+        output = dict(
+            positions=state.positions,
+            hidden_states=hidden_states,
+            residual=residual,
+            forward_batch=state.forward_batch,
+            zero_allocator=state.zero_allocator,
+            tbo_subbatch_index=state.tbo_subbatch_index,
+        )
+
+        state.clear(
+            expect_keys={
+                "positions",
+                "forward_batch",
+                "zero_allocator",
+                "tbo_subbatch_index",
+            }
+        )
+        return output


 class DeepseekV2Model(nn.Module):
@@ -1539,6 +1548,7 @@ class DeepseekV2Model(nn.Module):
         super().__init__()
         self.padding_id = config.pad_token_id
         self.vocab_size = config.vocab_size
+        self.first_k_dense_replace = config.first_k_dense_replace

         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size,
@@ -1572,13 +1582,12 @@ class DeepseekV2Model(nn.Module):
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
+        total_num_layers = len(self.layers)
+        device = input_embeds.device if input_embeds is not None else input_ids.device
         zero_allocator = BumpAllocator(
-            buffer_size=len(self.layers) * 2,
+            # TODO for two-batch-overlap, we need a larger buffer size
+            buffer_size=total_num_layers * 2 * (2 if forward_batch.can_run_tbo else 1),
             dtype=torch.float32,
-            device=(
-                input_embeds.device if input_embeds is not None else input_ids.device
-            ),
+            device=device,
         )

         if input_embeds is None:
@@ -1587,12 +1596,30 @@ class DeepseekV2Model(nn.Module):
             hidden_states = input_embeds

         residual = None
-        for i in range(len(self.layers)):
+
+        normal_num_layers = (
+            self.first_k_dense_replace
+            if forward_batch.can_run_tbo
+            else total_num_layers
+        )
+        for i in range(normal_num_layers):
             with get_global_expert_distribution_recorder().with_current_layer(i):
                 layer = self.layers[i]
                 hidden_states, residual = layer(
                     positions, hidden_states, forward_batch, residual, zero_allocator
                 )

+        if normal_num_layers != total_num_layers:
+            hidden_states, residual = model_forward_maybe_tbo(
+                layers=self.layers[normal_num_layers:],
+                enable_tbo=True,
+                positions=positions,
+                forward_batch=forward_batch,
+                hidden_states=hidden_states,
+                residual=residual,
+                zero_allocator=zero_allocator,
+            )
+
         if not forward_batch.forward_mode.is_idle():
             if residual is None:
                 hidden_states = self.norm(hidden_states)
@@ -1674,7 +1701,6 @@ class DeepseekV2ForCausalLM(nn.Module):
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
-
         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)

         return self.logits_processor(
```
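The model forward is now partitioned: the first `first_k_dense_replace` (dense) layers run the normal sequential path, and only the remaining sparse MoE layers go through `model_forward_maybe_tbo`. A sketch of that split with illustrative numbers (not taken from any specific DeepSeek config):

```python
# Layer partition in DeepseekV2Model.forward, with made-up sizes.
total_num_layers = 60
first_k_dense_replace = 3   # illustrative; comes from the model config
can_run_tbo = True

normal_num_layers = first_k_dense_replace if can_run_tbo else total_num_layers
normal_layers = range(normal_num_layers)                 # plain per-layer forward
tbo_layers = range(normal_num_layers, total_num_layers)  # overlapped forward
print(len(list(normal_layers)), len(list(tbo_layers)))   # 3 57
```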
python/sglang/srt/operations.py (+37 −2)

```diff
@@ -12,7 +12,7 @@ if _ENABLE_PROFILE:
 def execute_operations(inputs, operations):
-    stages = _convert_operations_to_stages(decorate_operations(operations))
+    stages = _convert_operations_to_stages(operations)
     executor = _StageExecutor("primary", stages, inputs=inputs)
     for _ in range(executor.num_stages):
         executor.next()
@@ -20,6 +20,37 @@ def execute_operations(inputs, operations):
     return executor.output


+def execute_overlapped_operations(
+    inputs_arr: Sequence,
+    operations_arr: Sequence,
+    delta_stages: Sequence[int],
+) -> Sequence:
+    # Make it explicit for clarity; if we need multi-batch overlap, this can be generalized
+    inputs_a, inputs_b = inputs_arr
+    operations_a, operations_b = operations_arr
+    delta_stage_a, delta_stage_b = delta_stages
+    assert delta_stage_a == 0
+    delta_stage = delta_stage_b
+
+    stages_a = _convert_operations_to_stages(operations_a)
+    stages_b = _convert_operations_to_stages(operations_b)
+    executor_a = _StageExecutor("a", stages_a, inputs=inputs_a)
+    executor_b = _StageExecutor("b", stages_b, inputs=inputs_b)
+
+    for _ in range(delta_stage):
+        executor_a.next()
+
+    for _ in range(executor_a.num_stages - delta_stage):
+        executor_a.next()
+        executor_b.next()
+
+    for _ in range(delta_stage):
+        executor_b.next()
+
+    assert executor_a.done and executor_b.done
+
+    return [executor_a.output, executor_b.output]
+
+
 class YieldOperation:
     pass
@@ -109,6 +140,9 @@ class _StateDict:
         for k, v in values.items():
             setattr(self, k, v)

+    def get(self, item):
+        return self._data.get(item)
+
     def clear(self, expect_keys: Sequence[str]):
         if set(self._data.keys()) != set(expect_keys):
             raise Exception(
@@ -119,6 +153,7 @@ class _StateDict:
 def _convert_operations_to_stages(operations: List[Operation]) -> List[Stage]:
+    operations = _decorate_operations(operations)
     operation_chunks = list(
         _chunk_by_separator(operations, lambda op: isinstance(op, YieldOperation))
     )
@@ -140,7 +175,7 @@ def _chunk_by_separator(
     yield pending_items


-def decorate_operations(operations: List[Operation], debug_name_prefix: str = ""):
+def _decorate_operations(operations: List[Operation], debug_name_prefix: str = ""):
     return [_decorate_operation(op, debug_name_prefix) for op in operations]
```
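`execute_overlapped_operations` is the heart of the commit: batch A runs `delta_stage` stages ahead, then both batches advance in lockstep, then batch B drains its remaining stages. A toy trace of that schedule, with plain strings standing in for the real `_StageExecutor` machinery:

```python
def run_overlapped(stages_a, stages_b, delta_stage):
    """Return the interleaved schedule produced by the three loops above."""
    schedule = []
    ia = ib = 0
    for _ in range(delta_stage):  # warm-up: only batch A advances
        schedule.append(("A", stages_a[ia]))
        ia += 1
    for _ in range(len(stages_a) - delta_stage):  # steady state: lockstep
        schedule.append(("A", stages_a[ia]))
        ia += 1
        schedule.append(("B", stages_b[ib]))
        ib += 1
    for _ in range(delta_stage):  # drain: only batch B advances
        schedule.append(("B", stages_b[ib]))
        ib += 1
    return schedule


stages = ["attn", "dispatch", "experts", "combine"]  # illustrative stage names
for who, stage in run_overlapped(stages, stages, delta_stage=2):
    print(who, stage)
```

With `delta_stage=2`, batch B's compute stages line up against batch A's communication stages, which is what lets dispatch/combine latency hide behind the other micro-batch's work.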
python/sglang/srt/operations_strategy.py (+107 −24)

This file is largely rewritten: the former `compute_layer_operations` helper, which returned a flat per-layer operation list (with a "Will add TBO operation orders here" placeholder and a dense-layer branch using `layer.op_mlp`), is replaced by an `OperationsStrategy` dataclass plus per-forward-mode operation orders with explicit yield points:

```python
from dataclasses import dataclass
from typing import List, Optional

import torch

from sglang.srt import operations
from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPConfig
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.operations import Operation


@dataclass
class OperationsStrategy:
    operations: List[Operation]
    deep_gemm_num_sms: Optional[int] = None
    tbo_delta_stages: Optional[int] = None

    @classmethod
    def concat(cls, items: List["OperationsStrategy"]) -> "OperationsStrategy":
        return OperationsStrategy(
            operations=[x for item in items for x in item.operations],
            deep_gemm_num_sms=_assert_all_same(
                [item.deep_gemm_num_sms for item in items]
            ),
            tbo_delta_stages=_assert_all_same(
                [item.tbo_delta_stages for item in items]
            ),
        )

    @staticmethod
    def init_new_tbo(
        layers: torch.nn.ModuleList,
        forward_mode: ForwardMode,
    ) -> "OperationsStrategy":
        return OperationsStrategy.concat(
            [
                _compute_layer_operations_strategy_tbo(layer, forward_mode)
                for layer in layers
            ]
        )


def _assert_all_same(items: List):
    assert all(item == items[0] for item in items)
    return items[0]


# TODO can refactor to make it more fancy if we have more complex strategies
def _compute_layer_operations_strategy_tbo(
    layer: torch.nn.Module,
    forward_mode: ForwardMode,
) -> OperationsStrategy:
    assert layer.is_layer_sparse, "dense layer TBO not yet implemented"
    if forward_mode == ForwardMode.EXTEND:
        return _compute_moe_deepseek_blog_prefill(layer)
    elif forward_mode == ForwardMode.DECODE:
        return _compute_moe_deepseek_blog_decode(layer)
    else:
        raise NotImplementedError(f"Unsupported {forward_mode=}")


def _compute_moe_deepseek_blog_prefill(layer):
    device_properties = torch.cuda.get_device_properties(device="cuda")
    total_num_sms = device_properties.multi_processor_count
    deep_gemm_num_sms = total_num_sms - DeepEPConfig.get_instance().num_sms

    return OperationsStrategy(
        deep_gemm_num_sms=deep_gemm_num_sms,
        tbo_delta_stages=0,
        operations=[
            layer.op_comm_prepare_attn,
            layer.self_attn.op_prepare,
            layer.self_attn.op_core,
            layer.op_comm_prepare_mlp,
            layer.mlp.op_gate,
            layer.mlp.op_select_experts,
            layer.mlp.op_dispatch_a,
            operations.YieldOperation(),
            layer.mlp.op_dispatch_b,
            layer.mlp.op_experts,
            layer.mlp.op_combine_a,
            operations.YieldOperation(),
            layer.mlp.op_shared_experts,
            layer.mlp.op_combine_b,
            layer.mlp.op_output,
            layer.op_comm_postprocess_layer,
        ],
    )


def _compute_moe_deepseek_blog_decode(layer):
    return OperationsStrategy(
        deep_gemm_num_sms=None,
        tbo_delta_stages=2,
        operations=[
            layer.op_comm_prepare_attn,
            layer.self_attn.op_prepare,
            operations.YieldOperation(),
            layer.self_attn.op_core,
            layer.op_comm_prepare_mlp,
            layer.mlp.op_gate,
            layer.mlp.op_select_experts,
            operations.YieldOperation(),
            layer.mlp.op_dispatch_a,
            layer.mlp.op_shared_experts,
            operations.YieldOperation(),
            layer.mlp.op_dispatch_b,
            layer.mlp.op_experts,
            layer.mlp.op_combine_a,
            operations.YieldOperation(),
            layer.mlp.op_combine_b,
            layer.mlp.op_output,
            layer.op_comm_postprocess_layer,
            operations.YieldOperation(),
        ],
    )
```
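Each `YieldOperation` above is a stage boundary: `_chunk_by_separator` splits the operation list at those markers, and `tbo_delta_stages` says how many such stages the second micro-batch trails by. A sketch of the chunking, with condensed labels standing in for the real op names:

```python
class YieldOperation:
    pass


# Condensed version of the decode operation order above (labels are illustrative).
decode_ops = [
    "prep_attn", YieldOperation(),
    "attn_core+gate+select", YieldOperation(),
    "dispatch_a+shared", YieldOperation(),
    "dispatch_b+experts+combine_a", YieldOperation(),
    "combine_b+output", YieldOperation(),
]

stages, current = [], []
for op in decode_ops:
    if isinstance(op, YieldOperation):
        stages.append(current)
        current = []
    else:
        current.append(op)
if current:
    stages.append(current)

print(len(stages), "stages per decode layer")  # 5
```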
python/sglang/srt/server_args.py (+6 −0)

```diff
@@ -167,6 +167,7 @@ class ServerArgs:
     enable_mixed_chunk: bool = False
     enable_dp_attention: bool = False
     enable_dp_lm_head: bool = False
+    enable_two_batch_overlap: bool = False
     enable_ep_moe: bool = False
     enable_deepep_moe: bool = False
     deepep_mode: Optional[Literal["auto", "normal", "low_latency"]] = "auto"
@@ -1144,6 +1145,11 @@ class ServerArgs:
             action="store_true",
             help="Enabling expert parallelism for moe. The ep size is equal to the tp size.",
         )
+        parser.add_argument(
+            "--enable-two-batch-overlap",
+            action="store_true",
+            help="Enabling two micro batches to overlap.",
+        )
         parser.add_argument(
             "--enable-torch-compile",
             action="store_true",
```
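The feature is opt-in via the new flag; presumably it would be passed on the launch command line (for example `python -m sglang.launch_server ... --enable-two-batch-overlap`). The argparse wiring above, exercised in isolation:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--enable-two-batch-overlap",
    action="store_true",
    help="Enabling two micro batches to overlap.",
)
args = parser.parse_args(["--enable-two-batch-overlap"])
assert args.enable_two_batch_overlap  # dashes become underscores
```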
python/sglang/srt/two_batch_overlap.py
0 → 100644
View file @
0d477880
import dataclasses
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence

import torch

from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
from sglang.srt.layers.quantization.deep_gemm import configure_deep_gemm_num_sms
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.operations import execute_operations, execute_overlapped_operations
from sglang.srt.operations_strategy import OperationsStrategy
from sglang.srt.utils import BumpAllocator, DeepEPMode

if TYPE_CHECKING:
    from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
# -------------------------------- Compute Basic Info ---------------------------------------
# TODO: may smartly disable TBO when batch size is too small b/c it will slow down
def compute_split_seq_index(
    forward_mode: "ForwardMode",
    num_tokens: int,
    extend_lens: Optional[Sequence[int]],
) -> Optional[int]:
    if forward_mode.is_extend():
        assert extend_lens is not None
        return _split_array_by_half_sum(extend_lens)
    elif forward_mode.is_decode():
        return num_tokens // 2
    elif forward_mode.is_idle():
        assert num_tokens == 0
        return 0
    else:
        raise NotImplementedError
def _split_array_by_half_sum(arr: Sequence[int]) -> int:
    overall_sum = sum(arr)
    accumulator, split_index = 0, 0
    for value in arr[:-1]:
        accumulator += value
        split_index += 1
        if accumulator >= overall_sum // 2:
            break
    return split_index
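As a worked example (values arbitrary), the greedy scan stops at the first prefix whose sum reaches half of the total token count:

    # extend_lens = [5, 3, 8]: total 16, half 8; prefix sums are 5, then 8 >= 8,
    # so the split index is 2 (sequences 0-1 vs sequence 2).
    assert _split_array_by_half_sum([5, 3, 8]) == 2
    # With a single sequence, arr[:-1] is empty and the split index stays 0,
    # i.e. one micro-batch is empty.
    assert _split_array_by_half_sum([10]) == 0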
def compute_split_token_index(
    split_seq_index: int,
    forward_mode: "ForwardMode",
    extend_seq_lens: Optional[Sequence[int]],
) -> int:
    if forward_mode.is_extend():
        assert extend_seq_lens is not None
        return sum(extend_seq_lens[:split_seq_index])
    elif forward_mode.is_decode():
        return split_seq_index
    elif forward_mode.is_idle():
        assert split_seq_index == 0
        return 0
    else:
        raise NotImplementedError
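Continuing the same example in extend mode, the token-level split point is the total length of the sequences placed before the split:

    # Sequences 0-1 (lengths 5 and 3) go to micro-batch A, so A covers tokens
    # [0, 8) and B covers tokens [8, 16).
    assert compute_split_token_index(
        split_seq_index=2,
        forward_mode=ForwardMode.EXTEND,
        extend_seq_lens=[5, 3, 8],
    ) == 8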
# -------------------------------- Preparation ---------------------------------------
class TboCudaGraphRunnerUtils:
    @staticmethod
    def compute_tbo_split_seq_index(that: "CudaGraphRunner", num_tokens: int):
        if that.model_runner.server_args.enable_two_batch_overlap:
            tbo_split_seq_index = compute_split_seq_index(
                forward_mode=that.capture_forward_mode,
                num_tokens=num_tokens,
                extend_lens=None,
            )
            # For simplicity, when two_batch_overlap is enabled, we only capture CUDA Graph for tbo=true
            assert (
                tbo_split_seq_index is not None
            ), f"{that.capture_forward_mode=} {num_tokens=}"
        else:
            tbo_split_seq_index = None
        return tbo_split_seq_index
class TboDPAttentionPreparer:
    def prepare_all_gather(
        self, local_batch, deepep_mode, enable_deepep_moe, enable_two_batch_overlap
    ):
        self.enable_two_batch_overlap = enable_two_batch_overlap

        if local_batch is not None:
            self.local_tbo_split_seq_index = compute_split_seq_index(
                forward_mode=local_batch.forward_mode,
                num_tokens=local_batch.input_ids.shape[0],
                extend_lens=local_batch.extend_lens,
            )
            resolved_deepep_mode = deepep_mode.resolve(local_batch.forward_mode)
            local_can_run_tbo = (self.local_tbo_split_seq_index is not None) and not (
                local_batch.forward_mode.is_extend()
                and enable_deepep_moe
                and (resolved_deepep_mode == DeepEPMode.low_latency)
            )
        else:
            self.local_tbo_split_seq_index = 0
            local_can_run_tbo = True
        local_forward_mode = self._compute_local_forward_mode(local_batch)
        return local_can_run_tbo, local_forward_mode
    def compute_output(self, partial_global_info):
        local_can_run_tbo_aggregated = min(partial_global_info[:, 0, 0].tolist())
        forward_modes = partial_global_info[:, 0, 1].tolist()

        global_forward_mode, forward_mode_agree = self._compute_global_forward_mode(
            forward_modes
        )
        can_run_tbo = (
            self.enable_two_batch_overlap
            and local_can_run_tbo_aggregated
            and forward_mode_agree
        )

        tbo_split_seq_index = self.local_tbo_split_seq_index if can_run_tbo else None
        global_forward_mode = global_forward_mode if can_run_tbo else None
        return tbo_split_seq_index, global_forward_mode
    @staticmethod
    def _compute_local_forward_mode(local_batch):
        return (
            local_batch.forward_mode if local_batch is not None else ForwardMode.IDLE
        ).value

    @staticmethod
    def _compute_global_forward_mode(forward_modes):
        converted_forward_modes = [
            ForwardMode.DECODE.value if x == ForwardMode.IDLE.value else x
            for x in forward_modes
        ]
        forward_mode_agree = TboDPAttentionPreparer._is_all_same(
            converted_forward_modes
        )
        global_forward_mode = (
            ForwardMode(converted_forward_modes[0]) if forward_mode_agree else None
        )
        return global_forward_mode, forward_mode_agree

    @staticmethod
    def _is_all_same(x):
        return all(value == x[0] for value in x)
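One subtlety worth a concrete example: IDLE ranks are coerced to DECODE before the agreement check, so a DP group in which some ranks are idle and the rest are decoding can still run TBO, while a prefill/decode mix cannot. A hedged sketch of the same rule in isolation (enum values illustrative, not ForwardMode's real ones):

    IDLE, DECODE, EXTEND = 0, 1, 2  # illustrative values only

    def modes_agree(forward_modes):
        converted = [DECODE if m == IDLE else m for m in forward_modes]
        return all(m == converted[0] for m in converted)

    assert modes_agree([DECODE, IDLE, DECODE])  # idle ranks join the decode TBO
    assert not modes_agree([EXTEND, DECODE])    # prefill/decode mix disables TBO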
class TboForwardBatchPreparer:
    @classmethod
    def prepare(cls, batch: ForwardBatch):
        from sglang.srt.layers.attention.tbo_backend import TboAttnBackend

        if batch.tbo_split_seq_index is None:
            return

        tbo_split_token_index = compute_split_token_index(
            split_seq_index=batch.tbo_split_seq_index,
            forward_mode=batch.forward_mode,
            extend_seq_lens=batch.extend_seq_lens_cpu,
        )

        assert isinstance(batch.attn_backend, TboAttnBackend)
        attn_backend_child_a, attn_backend_child_b = batch.attn_backend.children

        child_a = cls.filter_batch(
            batch,
            start_token_index=0,
            end_token_index=tbo_split_token_index,
            start_seq_index=0,
            end_seq_index=batch.tbo_split_seq_index,
            output_attn_backend=attn_backend_child_a,
        )
        child_b = cls.filter_batch(
            batch,
            start_token_index=tbo_split_token_index,
            end_token_index=batch.input_ids.shape[0],
            start_seq_index=batch.tbo_split_seq_index,
            end_seq_index=batch.batch_size,
            output_attn_backend=attn_backend_child_b,
        )

        assert batch.tbo_children is None
        batch.tbo_children = [child_a, child_b]
    @classmethod
    def filter_batch(
        cls,
        batch: ForwardBatch,
        *,
        start_token_index: int,
        end_token_index: int,
        start_seq_index: int,
        end_seq_index: int,
        output_attn_backend: AttentionBackend,
    ):
        from sglang.srt.managers.schedule_batch import global_server_args_dict

        num_tokens = batch.input_ids.shape[0]
        num_seqs = batch.batch_size

        output_dict = dict()

        for key in [
            "input_ids",
            "positions",
            "out_cache_loc",
        ]:
            old_value = getattr(batch, key)
            assert (
                old_value.shape[0] == num_tokens
            ), f"{key=} {old_value=} {num_tokens=} {batch=}"
            output_dict[key] = old_value[start_token_index:end_token_index]

        for key in [
            "req_pool_indices",
            "seq_lens",
            "seq_lens_cpu",
            "extend_seq_lens",
            "extend_prefix_lens",
            "extend_start_loc",
            "extend_prefix_lens_cpu",
            "extend_seq_lens_cpu",
            "extend_logprob_start_lens_cpu",
            "lora_paths",
        ]:
            old_value = getattr(batch, key)
            if old_value is None:
                continue
            assert (
                len(old_value) == num_seqs
            ), f"{key=} {old_value=} {num_seqs=} {batch=}"
            output_dict[key] = old_value[start_seq_index:end_seq_index]

        for key in [
            "forward_mode",
            "return_logprob",
            "req_to_token_pool",
            "token_to_kv_pool",
            "can_run_dp_cuda_graph",
            "global_forward_mode",
            "spec_info",
            "spec_algorithm",
            "capture_hidden_mode",
            "padded_static_len",
            "mrope_positions",  # only used by qwen2-vl, thus not care
        ]:
            output_dict[key] = getattr(batch, key)

        assert (
            _compute_extend_num_tokens(batch.input_ids, batch.forward_mode)
            == batch.extend_num_tokens
        ), f"{batch=}"
        extend_num_tokens = _compute_extend_num_tokens(
            output_dict["input_ids"], output_dict["forward_mode"]
        )

        # TODO improve, e.g. unify w/ `init_raw`
        if global_server_args_dict["moe_dense_tp_size"] == 1:
            sum_len = end_token_index - start_token_index
            gathered_buffer = torch.zeros(
                (sum_len, batch.gathered_buffer.shape[1]),
                dtype=batch.gathered_buffer.dtype,
                device=batch.gathered_buffer.device,
            )
        else:
            gathered_buffer = None

        output_dict.update(
            dict(
                batch_size=end_seq_index - start_seq_index,
                seq_lens_sum=(
                    output_dict["seq_lens_cpu"].sum()
                    if "seq_lens_cpu" in output_dict
                    else None
                ),
                extend_num_tokens=extend_num_tokens,
                attn_backend=output_attn_backend,
                tbo_split_seq_index=None,
                tbo_parent_token_range=(start_token_index, end_token_index),
                tbo_children=None,
                global_num_tokens_gpu=None,
                global_num_tokens_cpu=None,
                gathered_buffer=gathered_buffer,
                global_num_tokens_for_logprob_gpu=None,
                global_num_tokens_for_logprob_cpu=None,
                sampling_info=None,
                # For logits and logprobs post processing, thus we do not care
                temp_scaled_logprobs=False,
                temperature=None,
                top_p_normalized_logprobs=False,
                top_p=None,
                mm_inputs=None,
                num_token_non_padded=None,
            )
        )

        errors = []
        for field in dataclasses.fields(ForwardBatch):
            if (
                getattr(batch, field.name) is not None
                and field.name not in output_dict
            ):
                errors.append(
                    f"Field {field.name} has value, but is not yet supported "
                    f"(value={getattr(batch, field.name)} batch={batch})"
                )
        if len(errors) > 0:
            raise Exception(f"{len(errors)} errors happen:\n" + "\n\n".join(errors))

        return ForwardBatch(**output_dict)
def _compute_extend_num_tokens(input_ids, forward_mode: ForwardMode):
    if forward_mode.is_extend():
        return input_ids.shape[0]
    elif forward_mode.is_decode() or forward_mode.is_idle():
        return None
    raise NotImplementedError
# -------------------------------- Execution ---------------------------------------
def model_forward_maybe_tbo(
    layers,
    enable_tbo: bool,
    positions: torch.Tensor,
    forward_batch: ForwardBatch,
    hidden_states: torch.Tensor,
    residual: Optional[torch.Tensor],
    zero_allocator: BumpAllocator,
):
    inputs = dict(
        positions=positions,
        hidden_states=hidden_states,
        forward_batch=forward_batch,
        residual=residual,
        zero_allocator=zero_allocator,
    )
    operations_strategy = OperationsStrategy.init_new_tbo(
        layers, forward_batch.global_forward_mode
    )
    if enable_tbo:
        return _model_forward_tbo(inputs, operations_strategy)
    else:
        return _model_forward_non_tbo(inputs, operations_strategy)
def _model_forward_tbo(inputs, operations_strategy: OperationsStrategy):
    # The attn_tp_size!=1 case is not yet extracted to master
    assert get_attention_tp_size() == 1

    inputs_arr = _model_forward_tbo_split_inputs(**inputs)
    del inputs

    with configure_deep_gemm_num_sms(operations_strategy.deep_gemm_num_sms):
        outputs_arr = execute_overlapped_operations(
            inputs_arr=inputs_arr,
            operations_arr=[operations_strategy.operations] * 2,
            delta_stages=[0, operations_strategy.tbo_delta_stages],
        )

    return _model_forward_tbo_merge_outputs(*outputs_arr)


def _model_forward_non_tbo(inputs, operations_strategy: OperationsStrategy):
    outputs = execute_operations(inputs, operations_strategy.operations)
    return outputs["hidden_states"], outputs["residual"]
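The asymmetric delta_stages=[0, tbo_delta_stages] is what staggers the two micro-batches: micro-batch B starts tbo_delta_stages yield-points behind A (2 in the decode strategy above), so by the time B reaches its communication stages, A is already back to computing. A toy schedule (hypothetical helper, not part of sglang) makes the offset visible:

    def overlapped_schedule(num_stages: int, delta_stages: int):
        # Toy model: each step runs at most one stage per micro-batch, with B
        # lagging A by `delta_stages` stages.
        rows = []
        for step in range(num_stages + delta_stages):
            a = step if step < num_stages else None
            b = step - delta_stages if step >= delta_stages else None
            rows.append((step, a, b))
        return rows

    # With the decode strategy's tbo_delta_stages=2 and 5 stages per batch:
    for step, a, b in overlapped_schedule(5, 2):
        print(f"step {step}: A={a} B={b}")
    # step 0: A=0 B=None ... step 2: A=2 B=0 ... step 6: A=None B=4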
def _model_forward_tbo_split_inputs(
    hidden_states: torch.Tensor,
    residual: torch.Tensor,
    positions: torch.Tensor,
    forward_batch: ForwardBatch,
    zero_allocator: BumpAllocator,
) -> List[Dict]:
    return [
        dict(
            **_model_forward_filter_inputs(
                hidden_states=hidden_states,
                residual=residual,
                positions=positions,
                output_forward_batch=output_forward_batch,
                tbo_subbatch_index=tbo_subbatch_index,
            ),
            zero_allocator=zero_allocator,
        )
        for tbo_subbatch_index, output_forward_batch in enumerate(
            forward_batch.tbo_children
        )
    ]
def _model_forward_filter_inputs(
    hidden_states: torch.Tensor,
    residual: torch.Tensor,
    positions: torch.Tensor,
    output_forward_batch: ForwardBatch,
    tbo_subbatch_index: int,
) -> Dict:
    token_slice = slice(*output_forward_batch.tbo_parent_token_range)
    return dict(
        hidden_states=hidden_states[token_slice],
        residual=None if residual is None else residual[token_slice],
        positions=positions[token_slice],
        forward_batch=output_forward_batch,
        tbo_subbatch_index=tbo_subbatch_index,
    )
def _model_forward_tbo_merge_outputs(output_a, output_b):
    def _handle_key(name):
        value_a = output_a[name]
        value_b = output_b[name]
        assert (value_a is None) == (value_b is None)
        if value_a is None:
            return None
        return torch.concat([value_a, value_b], dim=0)

    return _handle_key("hidden_states"), _handle_key("residual")
# -------------------------------- Utilities and wrappers ---------------------------------------
class MaybeTboDeepEPDispatcher:
    def __init__(self, **kwargs):
        num_inner_dispatchers = (
            2 if global_server_args_dict["enable_two_batch_overlap"] else 1
        )
        self._inners = [
            DeepEPDispatcher(**kwargs) for _ in range(num_inner_dispatchers)
        ]

    def _execute(self, name, tbo_subbatch_index: Optional[int] = None, **kwargs):
        return getattr(self._inners[tbo_subbatch_index or 0], name)(**kwargs)

    def dispatch(self, **kwargs):
        return self._execute("dispatch", **kwargs)

    def dispatch_a(self, **kwargs):
        return self._execute("dispatch_a", **kwargs)

    def dispatch_b(self, **kwargs):
        return self._execute("dispatch_b", **kwargs)

    def combine(self, **kwargs):
        return self._execute("combine", **kwargs)

    def combine_a(self, **kwargs):
        return self._execute("combine_a", **kwargs)

    def combine_b(self, **kwargs):
        return self._execute("combine_b", **kwargs)
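Each micro-batch needs its own DeepEP buffers, so tbo_subbatch_index routes every call to that sub-batch's inner dispatcher. A quick sanity sketch of the routing with stub inner dispatchers (avoiding a real DeepEPDispatcher construction; stub names are hypothetical):

    class _StubDispatcher:
        def __init__(self, tag):
            self.tag = tag

        def dispatch_a(self, **kwargs):
            return self.tag

    wrapper = MaybeTboDeepEPDispatcher.__new__(MaybeTboDeepEPDispatcher)
    wrapper._inners = [_StubDispatcher("sub0"), _StubDispatcher("sub1")]
    assert wrapper.dispatch_a(tbo_subbatch_index=0) == "sub0"
    assert wrapper.dispatch_a(tbo_subbatch_index=1) == "sub1"
    assert wrapper.dispatch_a() == "sub0"  # default: first inner dispatcher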
test/srt/test_two_batch_overlap.py
0 → 100644
View file @ 0d477880
import os
import unittest
from types import SimpleNamespace

import requests

from sglang.srt.utils import kill_process_tree
from sglang.test.run_eval import run_eval
from sglang.test.test_utils import (
    DEFAULT_MLA_MODEL_NAME_FOR_TEST,
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    popen_launch_server,
)


class TestTwoBatchOverlap(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=[
                "--trust-remote-code",
                "--tp",
                "2",
                "--dp",
                "2",
                "--enable-dp-attention",
                "--enable-deepep-moe",
                "--deepep-mode",
                "normal",
                "--disable-cuda-graph",  # DeepEP normal does not support CUDA Graph
                "--enable-two-batch-overlap",
            ],
            env={"SGL_ENABLE_JIT_DEEPGEMM": "0", **os.environ},
        )

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    def test_generate_single_prompt(self):
        response = requests.post(
            self.base_url + "/generate",
            # we use an uncommon start to minimise the chance that the cache is hit by chance
            json={
                "text": "_ 1+1=2, 1+2=3, 1+3=4, 1+4=",
                "sampling_params": {"temperature": 0, "max_new_tokens": 8},
            },
        )
        print(f"{response.json()=}")
        self.assertEqual(response.json()["text"], "5, 1+5=6")

    def test_mmlu(self):
        args = SimpleNamespace(
            base_url=self.base_url,
            model=self.model,
            eval_name="mmlu",
            num_examples=64,
            num_threads=32,
        )
        metrics = run_eval(args)
        self.assertGreater(metrics["score"], 0.5)


if __name__ == "__main__":
    unittest.main()