Commit 0d477880 (unverified signature)
Authored May 25, 2025 by fzyzcjy; committed by GitHub May 24, 2025

Support overlapping two batches (#4068)

Parent: f4560373
Showing 13 changed files with 1146 additions and 130 deletions (+1146, -130)
python/sglang/srt/layers/attention/tbo_backend.py        +241    -0
python/sglang/srt/layers/quantization/deep_gemm.py        +13    -0
python/sglang/srt/managers/schedule_batch.py               +8    -0
python/sglang/srt/managers/scheduler.py                   +25    -1
python/sglang/srt/model_executor/cuda_graph_runner.py     +21    -1
python/sglang/srt/model_executor/forward_batch_info.py    +18    -1
python/sglang/srt/model_executor/model_runner.py          +18    -9
python/sglang/srt/models/deepseek_v2.py                  +118   -92
python/sglang/srt/operations.py                           +37    -2
python/sglang/srt/operations_strategy.py                 +107   -24
python/sglang/srt/server_args.py                           +6    -0
python/sglang/srt/two_batch_overlap.py                   +462    -0
test/srt/test_two_batch_overlap.py                        +72    -0
python/sglang/srt/layers/attention/tbo_backend.py (new file, mode 100644)

from typing import TYPE_CHECKING, Callable, List, Optional, Union

import torch

from sglang.srt import two_batch_overlap
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput

if TYPE_CHECKING:
    from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode


class TboAttnBackend(AttentionBackend):
    def __init__(self, primary: AttentionBackend, children: List[AttentionBackend]):
        super().__init__()
        self.primary = primary
        self.children = children

    @classmethod
    def init_new(cls, creator: Callable[[], AttentionBackend]):
        return cls(
            primary=creator(),
            children=[creator() for _ in range(2)],
        )

    def init_forward_metadata(self, forward_batch: "ForwardBatch"):
        self.primary.init_forward_metadata(forward_batch=forward_batch)
        if forward_batch.tbo_children is not None:
            for child, forward_batch_child in zip(
                self.children, forward_batch.tbo_children, strict=True
            ):
                if forward_batch_child.batch_size > 0:
                    child.init_forward_metadata(forward_batch=forward_batch_child)

    def init_cuda_graph_state(self, max_bs: int):
        self.primary.init_cuda_graph_state(max_bs=max_bs)
        for item in self.children:
            # TODO for children, maybe can provide *smaller* max_bs to optimize
            item.init_cuda_graph_state(max_bs=max_bs)

    def init_forward_metadata_capture_cuda_graph(
        self,
        bs: int,
        num_tokens: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: "ForwardMode",
        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
    ):
        self.primary.init_forward_metadata_capture_cuda_graph(
            bs=bs,
            num_tokens=num_tokens,
            req_pool_indices=req_pool_indices,
            seq_lens=seq_lens,
            encoder_lens=encoder_lens,
            forward_mode=forward_mode,
            spec_info=spec_info,
        )
        self._init_forward_metadata_cuda_graph_children(
            fn_name="init_forward_metadata_capture_cuda_graph",
            bs=bs,
            req_pool_indices=req_pool_indices,
            seq_lens=seq_lens,
            encoder_lens=encoder_lens,
            forward_mode=forward_mode,
            spec_info=spec_info,
            capture_num_tokens=num_tokens,
        )

    def init_forward_metadata_replay_cuda_graph(
        self,
        bs: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        seq_lens_sum: int,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: "ForwardMode",
        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
        seq_lens_cpu: Optional[torch.Tensor],
    ):
        self.primary.init_forward_metadata_replay_cuda_graph(
            bs=bs,
            req_pool_indices=req_pool_indices,
            seq_lens=seq_lens,
            seq_lens_sum=seq_lens_sum,
            encoder_lens=encoder_lens,
            forward_mode=forward_mode,
            spec_info=spec_info,
            seq_lens_cpu=seq_lens_cpu,
        )
        self._init_forward_metadata_cuda_graph_children(
            fn_name="init_forward_metadata_replay_cuda_graph",
            bs=bs,
            req_pool_indices=req_pool_indices,
            seq_lens=seq_lens,
            encoder_lens=encoder_lens,
            forward_mode=forward_mode,
            spec_info=spec_info,
            replay_seq_lens_sum=seq_lens_sum,
            replay_seq_lens_cpu=seq_lens_cpu,
        )

    def _init_forward_metadata_cuda_graph_children(
        self,
        fn_name: str,
        # common args
        bs: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: "ForwardMode",
        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
        # capture args
        capture_num_tokens: int = None,
        # replay args
        replay_seq_lens_sum: int = None,
        replay_seq_lens_cpu: Optional[torch.Tensor] = None,
    ):
        from sglang.srt.model_executor.forward_batch_info import ForwardMode

        if fn_name == "init_forward_metadata_capture_cuda_graph":
            assert capture_num_tokens == bs, "Only support num_tokens==bs currently"
        num_tokens = bs

        forward_mode_for_tbo_split = (
            forward_mode if forward_mode != ForwardMode.IDLE else ForwardMode.DECODE
        )
        tbo_split_seq_index = two_batch_overlap.compute_split_seq_index(
            forward_mode=forward_mode_for_tbo_split,
            num_tokens=num_tokens,
            extend_lens=None,
        )
        tbo_split_token_index = two_batch_overlap.compute_split_token_index(
            split_seq_index=tbo_split_seq_index,
            forward_mode=forward_mode_for_tbo_split,
            extend_seq_lens=None,
        )

        num_tokens_child_left = tbo_split_token_index
        num_tokens_child_right = num_tokens - tbo_split_token_index
        bs_child_left = num_tokens_child_left
        bs_child_right = num_tokens_child_right

        assert (
            num_tokens_child_left > 0 and num_tokens_child_right > 0
        ), f"{num_tokens_child_left=} {num_tokens_child_right=} {forward_mode=} {num_tokens=}"

        common_pre_split_args = dict(
            fn_name=fn_name,
            bs=bs,
            req_pool_indices=req_pool_indices,
            seq_lens=seq_lens,
            encoder_lens=encoder_lens,
            forward_mode=forward_mode,
            spec_info=spec_info,
            capture_num_tokens=capture_num_tokens,
            replay_seq_lens_sum=replay_seq_lens_sum,
            replay_seq_lens_cpu=replay_seq_lens_cpu,
        )

        args_left = _init_forward_metadata_cuda_graph_split(
            output_bs=bs_child_left,
            seq_slice=slice(None, tbo_split_seq_index),
            **common_pre_split_args,
        )
        args_right = _init_forward_metadata_cuda_graph_split(
            output_bs=bs_child_right,
            seq_slice=slice(tbo_split_seq_index, None),
            **common_pre_split_args,
        )

        child_left, child_right = self.children
        getattr(child_left, fn_name)(**args_left)
        getattr(child_right, fn_name)(**args_right)

    def get_cuda_graph_seq_len_fill_value(self):
        ans = self.primary.get_cuda_graph_seq_len_fill_value()
        for child in self.children:
            assert ans == child.get_cuda_graph_seq_len_fill_value()
        return ans

    def forward_extend(self, *args, **kwargs):
        return self.primary.forward_extend(*args, **kwargs)

    def forward_decode(self, *args, **kwargs):
        return self.primary.forward_decode(*args, **kwargs)


def _init_forward_metadata_cuda_graph_split(
    fn_name: str,
    seq_slice: slice,
    output_bs: int,
    # common args
    bs: int,
    req_pool_indices: torch.Tensor,
    seq_lens: torch.Tensor,
    encoder_lens: Optional[torch.Tensor],
    forward_mode: "ForwardMode",
    spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
    # capture args
    capture_num_tokens: int = None,
    # replay args
    replay_seq_lens_sum: int = None,
    replay_seq_lens_cpu: Optional[torch.Tensor] = None,
):
    assert encoder_lens is None, "encoder_lens is not supported yet"
    assert spec_info is None, "spec_info is not supported yet"

    ans = dict(
        bs=output_bs,
        req_pool_indices=req_pool_indices[seq_slice],
        seq_lens=seq_lens[seq_slice],
        # directly forward
        forward_mode=forward_mode,
        # ignore
        encoder_lens=None,
        spec_info=None,
    )

    if fn_name == "init_forward_metadata_capture_cuda_graph":
        assert capture_num_tokens == bs, "Only support num_tokens==bs currently"
        ans.update(
            dict(
                num_tokens=output_bs,
            )
        )
    elif fn_name == "init_forward_metadata_replay_cuda_graph":
        output_seq_lens_cpu = replay_seq_lens_cpu[seq_slice]
        ans.update(
            dict(
                seq_lens_sum=output_seq_lens_cpu.sum().item(),
                seq_lens_cpu=output_seq_lens_cpu,
            )
        )
    else:
        raise NotImplementedError

    return ans
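For orientation: the wrapper delegates the actual attention computation to the primary backend and uses the two children only to hold per-micro-batch metadata. A minimal construction sketch, assuming the classes above are importable (DummyBackend is hypothetical and stands in for any concrete backend factory):

class DummyBackend(AttentionBackend):
    def init_forward_metadata(self, forward_batch):
        pass  # a real backend would build its kernel-launch metadata here

backend = TboAttnBackend.init_new(lambda: DummyBackend())
assert len(backend.children) == 2  # one child per overlapped micro-batch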
python/sglang/srt/layers/quantization/deep_gemm.py

@@ -391,3 +391,16 @@ def _log_jit_build(M: int, N: int, K: int, kernel_type: DeepGemmKernelType):
     RuntimeCache.get = __patched_func
     yield
     RuntimeCache.get = origin_func
+
+
+@contextmanager
+def configure_deep_gemm_num_sms(num_sms):
+    if num_sms is None:
+        yield
+    else:
+        original_num_sms = deep_gemm.get_num_sms()
+        deep_gemm.set_num_sms(num_sms)
+        try:
+            yield
+        finally:
+            deep_gemm.set_num_sms(original_num_sms)
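The context manager follows the standard save/set/restore pattern, and passing num_sms=None makes it a no-op. A hedged usage sketch of the API added above (the block body is a stand-in; any DeepGEMM-backed kernel launched inside would see at most the configured SM count):

with configure_deep_gemm_num_sms(108):
    pass  # stand-in for a DeepGEMM-backed kernel launch, limited to 108 SMs
# on exit the previous SM count is restored, even if the body raised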
python/sglang/srt/managers/schedule_batch.py

@@ -78,6 +78,7 @@ global_server_args_dict = {
     "disable_radix_cache": ServerArgs.disable_radix_cache,
     "enable_deepep_moe": ServerArgs.enable_deepep_moe,
     "enable_dp_attention": ServerArgs.enable_dp_attention,
+    "enable_two_batch_overlap": ServerArgs.enable_two_batch_overlap,
     "enable_dp_lm_head": ServerArgs.enable_dp_lm_head,
     "enable_ep_moe": ServerArgs.enable_ep_moe,
     "deepep_config": ServerArgs.deepep_config,

@@ -831,6 +832,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     global_num_tokens: Optional[List[int]] = None
     global_num_tokens_for_logprob: Optional[List[int]] = None
     can_run_dp_cuda_graph: bool = False
+    tbo_split_seq_index: Optional[int] = None
+    global_forward_mode: Optional[ForwardMode] = None

     # For processing logprobs
     return_logprob: bool = False

@@ -1624,6 +1627,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             or global_server_args_dict["attention_backend"] == "flashmla"
             or global_server_args_dict["attention_backend"] == "fa3"
             or global_server_args_dict["attention_backend"] == "cutlass_mla"
+            or global_server_args_dict["enable_two_batch_overlap"]
         ):
             seq_lens_cpu = self.seq_lens.cpu()
         else:

@@ -1651,6 +1655,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             global_num_tokens=self.global_num_tokens,
             global_num_tokens_for_logprob=self.global_num_tokens_for_logprob,
             can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
+            tbo_split_seq_index=self.tbo_split_seq_index,
+            global_forward_mode=self.global_forward_mode,
             seq_lens_cpu=seq_lens_cpu,
             extend_num_tokens=self.extend_num_tokens,
             extend_seq_lens=extend_seq_lens,

@@ -1729,6 +1735,8 @@ class ModelWorkerBatch:
     global_num_tokens: Optional[List[int]]
     global_num_tokens_for_logprob: Optional[List[int]]
     can_run_dp_cuda_graph: bool
+    tbo_split_seq_index: Optional[int]
+    global_forward_mode: Optional[ForwardMode]

     # For extend
     extend_num_tokens: Optional[int]
python/sglang/srt/managers/scheduler.py

@@ -34,6 +34,7 @@ import zmq
 from torch.distributed import barrier

 from sglang.global_config import global_config
+from sglang.srt import two_batch_overlap
 from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.constrained.base_grammar_backend import create_grammar_backend
 from sglang.srt.disaggregation.decode import (

@@ -132,7 +133,9 @@ from sglang.srt.reasoning_parser import ReasoningParser
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
+from sglang.srt.two_batch_overlap import TboDPAttentionPreparer
 from sglang.srt.utils import (
+    DeepEPMode,
     DynamicGradMode,
     broadcast_pyobj,
     configure_logger,

@@ -1648,6 +1651,9 @@ class Scheduler(
             disable_cuda_graph=self.server_args.disable_cuda_graph,
             spec_algorithm=self.spec_algorithm,
             speculative_num_draft_tokens=self.server_args.speculative_num_draft_tokens,
+            enable_two_batch_overlap=self.server_args.enable_two_batch_overlap,
+            enable_deepep_moe=self.server_args.enable_deepep_moe,
+            deepep_mode=DeepEPMode[self.server_args.deepep_mode],
         )

     @staticmethod

@@ -1661,6 +1667,9 @@ class Scheduler(
         disable_cuda_graph: bool,
         spec_algorithm,
         speculative_num_draft_tokens,
+        enable_two_batch_overlap: bool,
+        enable_deepep_moe: bool,
+        deepep_mode: DeepEPMode,
     ):
         # Check if other DP workers have running batches
         if local_batch is None:

@@ -1696,17 +1705,26 @@ class Scheduler(
         is_extend_in_batch = (
             local_batch.forward_mode.is_extend() if local_batch else False
         )
+
+        tbo_preparer = TboDPAttentionPreparer()
+
         local_info = torch.tensor(
             [
                 num_tokens,
                 can_cuda_graph,
                 num_tokens_for_logprob,
                 is_extend_in_batch,
+                *tbo_preparer.prepare_all_gather(
+                    local_batch,
+                    deepep_mode,
+                    enable_deepep_moe,
+                    enable_two_batch_overlap,
+                ),
             ],
             dtype=torch.int64,
         )
         global_info = torch.empty(
-            (dp_size, attn_tp_size, 4),
+            (dp_size, attn_tp_size, 6),
             dtype=torch.int64,
         )
         torch.distributed.all_gather_into_tensor(

@@ -1719,6 +1737,10 @@ class Scheduler(
         global_num_tokens_for_logprob = global_info[:, 0, 2].tolist()
         is_extend_in_batch = global_info[:, 0, 3].tolist()

+        tbo_split_seq_index, global_forward_mode = tbo_preparer.compute_output(
+            global_info[:, :, 4:6]
+        )
+
         if local_batch is None and max(global_num_tokens) > 0:
             local_batch = get_idle_batch()

@@ -1732,6 +1754,8 @@ class Scheduler(
             local_batch.global_num_tokens_for_logprob = (
                 global_num_tokens_for_logprob
             )
+            local_batch.tbo_split_seq_index = tbo_split_seq_index
+            local_batch.global_forward_mode = global_forward_mode

             # Check forward mode for cuda graph
             if not disable_cuda_graph:
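The widened all-gather payload is worth a sketch: each (DP rank, attention-TP rank) pair now contributes six int64 scalars instead of four, and the two extra columns are produced by prepare_all_gather and read back by compute_output. A toy illustration (the meaning of columns 4:6 is internal to TboDPAttentionPreparer and treated as opaque here):

import torch

dp_size, attn_tp_size = 2, 1
global_info = torch.zeros((dp_size, attn_tp_size, 6), dtype=torch.int64)
# columns 0..3: num_tokens, can_cuda_graph, num_tokens_for_logprob, is_extend_in_batch
global_num_tokens = global_info[:, 0, 0].tolist()
tbo_columns = global_info[:, :, 4:6]  # what tbo_preparer.compute_output(...) consumes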
python/sglang/srt/model_executor/cuda_graph_runner.py

@@ -24,6 +24,7 @@ from typing import TYPE_CHECKING, Callable, Optional, Union
 import torch
 import tqdm

+from sglang.srt import two_batch_overlap
 from sglang.srt.custom_op import CustomOp
 from sglang.srt.distributed import get_tensor_model_parallel_rank
 from sglang.srt.distributed.parallel_state import GroupCoordinator, graph_capture

@@ -38,6 +39,10 @@ from sglang.srt.model_executor.forward_batch_info import (
     PPProxyTensors,
 )
 from sglang.srt.patch_torch import monkey_patch_torch_compile
+from sglang.srt.two_batch_overlap import (
+    TboCudaGraphRunnerUtils,
+    TboForwardBatchPreparer,
+)
 from sglang.srt.utils import (
     get_available_gpu_memory,
     get_device_memory_capacity,

@@ -152,6 +157,9 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
             model_runner.req_to_token_pool.size
         ]
+    if server_args.enable_two_batch_overlap:
+        capture_bs = [bs for bs in capture_bs if bs >= 2]
+
     if server_args.cuda_graph_max_bs:
         capture_bs = [bs for bs in capture_bs if bs <= server_args.cuda_graph_max_bs]
         if max(capture_bs) < server_args.cuda_graph_max_bs:

@@ -349,7 +357,14 @@ class CudaGraphRunner:
             if self.is_encoder_decoder
             else True
         )
-        return is_bs_supported and is_encoder_lens_supported
+        is_tbo_supported = (
+            forward_batch.can_run_tbo
+            if self.model_runner.server_args.enable_two_batch_overlap
+            else True
+        )
+
+        return is_bs_supported and is_encoder_lens_supported and is_tbo_supported

     def capture(self):
         with graph_capture() as graph_capture_context:

@@ -466,7 +481,12 @@ class CudaGraphRunner:
             capture_hidden_mode=self.capture_hidden_mode,
             lora_paths=lora_paths,
             num_token_non_padded=self.num_token_non_padded,
+            tbo_split_seq_index=TboCudaGraphRunnerUtils.compute_tbo_split_seq_index(
+                self, num_tokens
+            ),
+            global_forward_mode=self.capture_forward_mode,
         )
+        TboForwardBatchPreparer.prepare(forward_batch)
         if lora_paths is not None:
             self.model_runner.lora_manager.prepare_lora_batch(forward_batch)
python/sglang/srt/model_executor/forward_batch_info.py

@@ -29,9 +29,10 @@ ScheduleBatch -> ModelWorkerBatch -> ForwardBatch
 from __future__ import annotations

+import dataclasses
 from dataclasses import dataclass
 from enum import IntEnum, auto
-from typing import TYPE_CHECKING, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

 import torch
 import triton

@@ -239,6 +240,7 @@ class ForwardBatch:
     dp_local_num_tokens: Optional[torch.Tensor] = None
     # cached info at runtime
     gathered_buffer: Optional[torch.Tensor] = None
     can_run_dp_cuda_graph: bool = False
+    global_forward_mode: Optional[ForwardMode] = None

     # Speculative decoding
     spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput]] = None

@@ -252,12 +254,18 @@ class ForwardBatch:
     # For Qwen2-VL
     mrope_positions: torch.Tensor = None

+    tbo_split_seq_index: Optional[int] = None
+    tbo_parent_token_range: Optional[Tuple[int, int]] = None
+    tbo_children: Optional[List["ForwardBatch"]] = None
+
     @classmethod
     def init_new(
         cls,
         batch: ModelWorkerBatch,
         model_runner: ModelRunner,
     ):
+        from sglang.srt.two_batch_overlap import TboForwardBatchPreparer
+
         device = model_runner.device
         extend_input_logprob_token_ids_gpu = None
         if batch.extend_input_logprob_token_ids is not None:

@@ -281,6 +289,7 @@ class ForwardBatch:
             top_logprobs_nums=batch.top_logprobs_nums,
             token_ids_logprobs=batch.token_ids_logprobs,
             can_run_dp_cuda_graph=batch.can_run_dp_cuda_graph,
+            global_forward_mode=batch.global_forward_mode,
             lora_paths=batch.lora_paths,
             sampling_info=batch.sampling_info,
             req_to_token_pool=model_runner.req_to_token_pool,

@@ -294,6 +303,7 @@ class ForwardBatch:
             num_token_non_padded=torch.tensor(
                 len(batch.input_ids), dtype=torch.int32
             ).to(device, non_blocking=True),
+            tbo_split_seq_index=batch.tbo_split_seq_index,
         )

         # For DP attention

@@ -316,6 +326,7 @@ class ForwardBatch:
         )
         if ret.forward_mode.is_idle():
             ret.positions = torch.empty((0,), device=device)
+            TboForwardBatchPreparer.prepare(ret)
             return ret

         # Override the positions with spec_info

@@ -364,6 +375,8 @@ class ForwardBatch:
         if model_runner.server_args.lora_paths is not None:
             model_runner.lora_manager.prepare_lora_batch(ret)

+        TboForwardBatchPreparer.prepare(ret)
+
         return ret

     def merge_mm_inputs(self) -> Optional[MultimodalInputs]:

@@ -588,6 +601,10 @@ class ForwardBatch:
             # Precompute the kv indices for each chunk
             self.prepare_chunked_kv_indices(device)

+    @property
+    def can_run_tbo(self):
+        return self.tbo_split_seq_index is not None
+

 class PPProxyTensors:
     # adapted from https://github.com/vllm-project/vllm/blob/d14e98d924724b284dc5eaf8070d935e214e50c0/vllm/sequence.py#L1103
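The new can_run_tbo property makes "TBO is active for this batch" equivalent to "a split index was assigned". A self-contained toy illustration (ToyBatch is not sglang's ForwardBatch; it only mirrors the fields and property added above):

from dataclasses import dataclass
from typing import List, Optional, Tuple

@dataclass
class ToyBatch:
    tbo_split_seq_index: Optional[int] = None
    tbo_parent_token_range: Optional[Tuple[int, int]] = None
    tbo_children: Optional[List["ToyBatch"]] = None

    @property
    def can_run_tbo(self):
        return self.tbo_split_seq_index is not None

print(ToyBatch().can_run_tbo)                       # False: no split assigned
print(ToyBatch(tbo_split_seq_index=4).can_run_tbo)  # True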
python/sglang/srt/model_executor/model_runner.py

@@ -37,6 +37,7 @@ from sglang.srt.distributed import (
     set_custom_all_reduce,
 )
 from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state
+from sglang.srt.layers.attention.tbo_backend import TboAttnBackend
 from sglang.srt.layers.dp_attention import (
     get_attention_tp_group,
     get_attention_tp_size,

@@ -198,6 +199,7 @@ class ModelRunner:
                 "disable_radix_cache": server_args.disable_radix_cache,
                 "enable_nan_detection": server_args.enable_nan_detection,
                 "enable_dp_attention": server_args.enable_dp_attention,
+                "enable_two_batch_overlap": server_args.enable_two_batch_overlap,
                 "enable_dp_lm_head": server_args.enable_dp_lm_head,
                 "enable_ep_moe": server_args.enable_ep_moe,
                 "enable_deepep_moe": server_args.enable_deepep_moe,

@@ -994,6 +996,13 @@ class ModelRunner:
     def init_attention_backend(self):
         """Init attention kernel backend."""
+        if self.server_args.enable_two_batch_overlap:
+            self.attn_backend = TboAttnBackend.init_new(self._get_attention_backend)
+        else:
+            self.attn_backend = self._get_attention_backend()
+
+    # TODO unify with 6338
+    def _get_attention_backend(self):
         if self.server_args.attention_backend == "flashinfer":
             if not self.use_mla_backend:
                 from sglang.srt.layers.attention.flashinfer_backend import (

@@ -1003,17 +1012,17 @@ class ModelRunner:
                 # Init streams
                 if self.server_args.speculative_algorithm == "EAGLE":
                     self.plan_stream_for_flashinfer = torch.cuda.Stream()
-                self.attn_backend = FlashInferAttnBackend(self)
+                return FlashInferAttnBackend(self)
             else:
                 from sglang.srt.layers.attention.flashinfer_mla_backend import (
                     FlashInferMLAAttnBackend,
                 )

-                self.attn_backend = FlashInferMLAAttnBackend(self)
+                return FlashInferMLAAttnBackend(self)
         elif self.server_args.attention_backend == "aiter":
             from sglang.srt.layers.attention.aiter_backend import AiterAttnBackend

-            self.attn_backend = AiterAttnBackend(self)
+            return AiterAttnBackend(self)
         elif self.server_args.attention_backend == "triton":
             assert self.sliding_window_size is None, (
                 "Window attention is not supported in the triton attention backend. "

@@ -1028,21 +1037,21 @@ class ModelRunner:
                     DoubleSparseAttnBackend,
                 )

-                self.attn_backend = DoubleSparseAttnBackend(self)
+                return DoubleSparseAttnBackend(self)
             else:
                 from sglang.srt.layers.attention.triton_backend import TritonAttnBackend

-                self.attn_backend = TritonAttnBackend(self)
+                return TritonAttnBackend(self)
         elif self.server_args.attention_backend == "torch_native":
             from sglang.srt.layers.attention.torch_native_backend import (
                 TorchNativeAttnBackend,
             )

-            self.attn_backend = TorchNativeAttnBackend(self)
+            return TorchNativeAttnBackend(self)
         elif self.server_args.attention_backend == "flashmla":
             from sglang.srt.layers.attention.flashmla_backend import FlashMLABackend

-            self.attn_backend = FlashMLABackend(self)
+            return FlashMLABackend(self)
         elif self.server_args.attention_backend == "fa3":
             assert (
                 torch.cuda.get_device_capability()[0] == 8 and not self.use_mla_backend

@@ -1054,13 +1063,13 @@ class ModelRunner:
                 FlashAttentionBackend,
             )

-            self.attn_backend = FlashAttentionBackend(self)
+            return FlashAttentionBackend(self)
         elif self.server_args.attention_backend == "cutlass_mla":
             from sglang.srt.layers.attention.cutlass_mla_backend import (
                 CutlassMLABackend,
             )

-            self.attn_backend = CutlassMLABackend(self)
+            return CutlassMLABackend(self)
         else:
             raise ValueError(
                 f"Invalid attention backend: {self.server_args.attention_backend}"
            )
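One detail worth spelling out: previously each branch assigned self.attn_backend directly, so backend construction could only happen once per ModelRunner. Returning the instance from _get_attention_backend turns the method into a factory, which is exactly what TboAttnBackend.init_new needs; it calls the factory three times to build the primary backend plus the two per-micro-batch children.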
python/sglang/srt/models/deepseek_v2.py

@@ -83,8 +83,10 @@ from sglang.srt.managers.expert_location_dispatch import ExpertLocationDispatchI
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.operations import execute_operations
-from sglang.srt.operations_strategy import compute_layer_operations
+from sglang.srt.two_batch_overlap import (
+    MaybeTboDeepEPDispatcher,
+    model_forward_maybe_tbo,
+)
 from sglang.srt.utils import (
     BumpAllocator,
     DeepEPMode,

@@ -226,6 +228,7 @@ class DeepseekV2MoE(nn.Module):
         self.routed_scaling_factor = config.routed_scaling_factor
         self.n_shared_experts = config.n_shared_experts
         self.n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"]
+        self.config = config
         self.layer_id = layer_id

         if self.tp_size > config.n_routed_experts:

@@ -300,7 +303,7 @@ class DeepseekV2MoE(nn.Module):
                 else None
             )

-            self.deepep_dispatcher = DeepEPDispatcher(
+            self.deepep_dispatcher = MaybeTboDeepEPDispatcher(
                 group=parallel_state.get_tp_group().device_group,
                 router_topk=self.top_k,
                 permute_fusion=True,

@@ -309,13 +312,11 @@ class DeepseekV2MoE(nn.Module):
                 hidden_size=config.hidden_size,
                 params_dtype=config.torch_dtype,
                 deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]],
-                async_finish=True,  # TODO
+                async_finish=True,
                 return_recv_hook=True,
             )

-    @property
-    def _enable_deepep_moe(self):
-        return global_server_args_dict["enable_deepep_moe"]
+        self._enable_deepep_moe = global_server_args_dict["enable_deepep_moe"]

     def get_moe_weights(self):
         return [

@@ -423,7 +424,7 @@ class DeepseekV2MoE(nn.Module):
             return None

     def op_gate(self, state):
-        if (not self._enable_deepep_moe) or is_non_idle_and_non_empty(
+        if is_non_idle_and_non_empty(
             state.forward_batch.forward_mode, state.hidden_states_mlp_input
         ):
             # router_logits: (num_tokens, n_experts)

@@ -432,115 +433,105 @@ class DeepseekV2MoE(nn.Module):
             state.router_logits = None

     def op_shared_experts(self, state):
-        if (self.n_share_experts_fusion == 0) and (
-            (not self._enable_deepep_moe)
-            or is_non_idle_and_non_empty(
-                state.forward_batch.forward_mode, state.hidden_states_mlp_input
-            )
+        hidden_states_mlp_input = state.pop("hidden_states_mlp_input")
+
+        if (self.n_share_experts_fusion == 0) and is_non_idle_and_non_empty(
+            state.forward_batch.forward_mode, hidden_states_mlp_input
         ):
-            state.shared_output = self.shared_experts(state.hidden_states_mlp_input)
+            state.shared_output = self.shared_experts(hidden_states_mlp_input)
         else:
             state.shared_output = None

     def op_select_experts(self, state):
-        router_logits = state.router_logits
+        router_logits = state.pop("router_logits")
         hidden_states = state.hidden_states_mlp_input

-        if self._enable_deepep_moe:
-            if router_logits is not None:
-                state.topk_weights_local, state.topk_idx_local = select_experts(
-                    hidden_states=hidden_states,
-                    router_logits=router_logits,
-                    top_k=self.top_k,
-                    use_grouped_topk=True,
-                    renormalize=self.renormalize,
-                    topk_group=self.topk_group,
-                    num_expert_group=self.num_expert_group,
-                    correction_bias=self.correction_bias,
-                    routed_scaling_factor=self.routed_scaling_factor,
-                    expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
-                        layer_id=self.layer_id,
-                    ),
-                )
-            else:
-                state.topk_idx_local = torch.full(
-                    (0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
-                )
-                state.topk_weights_local = torch.empty(
-                    (0, self.top_k), dtype=torch.float32, device=hidden_states.device
-                )
+        if router_logits is not None:
+            state.topk_weights_local, state.topk_idx_local = select_experts(
+                hidden_states=hidden_states,
+                router_logits=router_logits,
+                top_k=self.top_k,
+                use_grouped_topk=True,
+                renormalize=self.renormalize,
+                topk_group=self.topk_group,
+                num_expert_group=self.num_expert_group,
+                correction_bias=self.correction_bias,
+                routed_scaling_factor=self.routed_scaling_factor,
+                expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
+                    layer_id=self.layer_id,
+                ),
+            )
+        else:
+            state.topk_idx_local = torch.full(
+                (0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
+            )
+            state.topk_weights_local = torch.empty(
+                (0, self.top_k), dtype=torch.float32, device=hidden_states.device
+            )

     def op_dispatch_a(self, state):
-        if self._enable_deepep_moe and (self.ep_size > 1):
+        if self.ep_size > 1:
             # TODO(ch-wan): allow users to set num_max_dispatch_tokens_per_rank value
             self.deepep_dispatcher.dispatch_a(
-                hidden_states=state.pop("hidden_states_mlp_input"),
+                hidden_states=state.hidden_states_mlp_input,
                 topk_idx=state.pop("topk_idx_local"),
                 topk_weights=state.pop("topk_weights_local"),
                 forward_mode=state.forward_batch.forward_mode,
+                tbo_subbatch_index=state.get("tbo_subbatch_index"),
             )

     def op_dispatch_b(self, state):
-        if self._enable_deepep_moe and (self.ep_size > 1):
-            (
-                state.hidden_states_experts_input,
-                state.topk_idx_dispatched,
-                state.topk_weights_dispatched,
-                state.reorder_topk_ids,
-                state.num_recv_tokens_per_expert,
-                state.seg_indptr,
-                state.masked_m,
-                state.expected_m,
-            ) = self.deepep_dispatcher.dispatch_b()
+        if self.ep_size > 1:
+            with get_global_expert_distribution_recorder().with_current_layer(
+                self.layer_id
+            ):
+                (
+                    state.hidden_states_experts_input,
+                    state.topk_idx_dispatched,
+                    state.topk_weights_dispatched,
+                    state.reorder_topk_ids,
+                    state.num_recv_tokens_per_expert,
+                    state.seg_indptr,
+                    state.masked_m,
+                    state.expected_m,
+                ) = self.deepep_dispatcher.dispatch_b(
+                    tbo_subbatch_index=state.get("tbo_subbatch_index"),
+                )

     def op_experts(self, state):
-        if self._enable_deepep_moe:
-            state.pop("router_logits")
-            state.hidden_states_experts_output = self.experts(
-                hidden_states=state.pop("hidden_states_experts_input"),
-                topk_idx=state.topk_idx_dispatched,
-                topk_weights=state.topk_weights_dispatched,
-                reorder_topk_ids=state.pop("reorder_topk_ids"),
-                seg_indptr=state.pop("seg_indptr"),
-                masked_m=state.pop("masked_m"),
-                expected_m=state.pop("expected_m"),
-                num_recv_tokens_per_expert=state.pop("num_recv_tokens_per_expert"),
-                forward_mode=state.forward_batch.forward_mode,
-            )
-        else:
-            state.hidden_states_experts_output = self.experts(
-                hidden_states=state.pop("hidden_states_mlp_input"),
-                router_logits=state.pop("router_logits"),
-            )
+        state.hidden_states_experts_output = self.experts(
+            hidden_states=state.pop("hidden_states_experts_input"),
+            topk_idx=state.topk_idx_dispatched,
+            topk_weights=state.topk_weights_dispatched,
+            reorder_topk_ids=state.pop("reorder_topk_ids"),
+            seg_indptr=state.pop("seg_indptr"),
+            masked_m=state.pop("masked_m"),
+            expected_m=state.pop("expected_m"),
+            num_recv_tokens_per_expert=state.pop("num_recv_tokens_per_expert"),
+            forward_mode=state.forward_batch.forward_mode,
+        )

     def op_combine_a(self, state):
-        if self._enable_deepep_moe and (self.ep_size > 1):
+        if self.ep_size > 1:
             self.deepep_dispatcher.combine_a(
-                state.pop("hidden_states_experts_output"),
+                hidden_states=state.pop("hidden_states_experts_output"),
                 topk_idx=state.pop("topk_idx_dispatched"),
                 topk_weights=state.pop("topk_weights_dispatched"),
                 forward_mode=state.forward_batch.forward_mode,
+                tbo_subbatch_index=state.get("tbo_subbatch_index"),
             )

     def op_combine_b(self, state):
-        if self._enable_deepep_moe and (self.ep_size > 1):
-            state.hidden_states_after_combine = self.deepep_dispatcher.combine_b()
+        if self.ep_size > 1:
+            state.hidden_states_after_combine = self.deepep_dispatcher.combine_b(
+                tbo_subbatch_index=state.get("tbo_subbatch_index"),
+            )

     def op_output(self, state):
-        final_hidden_states = (
-            state.pop("hidden_states_after_combine")
-            if self._enable_deepep_moe
-            else state.pop("hidden_states_experts_output")
-        )
+        final_hidden_states = state.pop("hidden_states_after_combine")

         final_hidden_states *= self.routed_scaling_factor

         if (s := state.pop("shared_output")) is not None:
             final_hidden_states = final_hidden_states + s

-        if (not self._enable_deepep_moe) and (self.tp_size > 1):
-            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
-
         state.hidden_states_mlp_output = final_hidden_states

@@ -1482,6 +1473,7 @@ class DeepseekV2DecoderLayer(nn.Module):
         forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
         zero_allocator: BumpAllocator,
+        tbo_subbatch_index: Optional[int] = None,
     ):
         state.hidden_states_after_comm_pre_attn, state.residual_after_input_ln = (
             self.layer_communicator.prepare_attn(hidden_states, residual, forward_batch)

@@ -1491,6 +1483,7 @@ class DeepseekV2DecoderLayer(nn.Module):
                 forward_batch=forward_batch,
                 positions=positions,
                 zero_allocator=zero_allocator,
+                tbo_subbatch_index=tbo_subbatch_index,
             )
         )

@@ -1523,8 +1516,24 @@ class DeepseekV2DecoderLayer(nn.Module):
             state.forward_batch,
         )

-        state.clear(expect_keys={"positions", "forward_batch", "zero_allocator"})
-        return hidden_states, residual
+        output = dict(
+            positions=state.positions,
+            hidden_states=hidden_states,
+            residual=residual,
+            forward_batch=state.forward_batch,
+            zero_allocator=state.zero_allocator,
+            tbo_subbatch_index=state.tbo_subbatch_index,
+        )
+
+        state.clear(
+            expect_keys={
+                "positions",
+                "forward_batch",
+                "zero_allocator",
+                "tbo_subbatch_index",
+            }
+        )
+        return output


 class DeepseekV2Model(nn.Module):

@@ -1539,6 +1548,7 @@ class DeepseekV2Model(nn.Module):
         super().__init__()
         self.padding_id = config.pad_token_id
         self.vocab_size = config.vocab_size
+        self.first_k_dense_replace = config.first_k_dense_replace

         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size,

@@ -1572,13 +1582,12 @@ class DeepseekV2Model(nn.Module):
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
+        total_num_layers = len(self.layers)
+        device = input_embeds.device if input_embeds is not None else input_ids.device
         zero_allocator = BumpAllocator(
-            buffer_size=len(self.layers) * 2,
+            # TODO for two-batch-overlap, we need a larger buffer size
+            buffer_size=total_num_layers * 2 * (2 if forward_batch.can_run_tbo else 1),
             dtype=torch.float32,
-            device=(
-                input_embeds.device if input_embeds is not None else input_ids.device
-            ),
+            device=device,
         )

         if input_embeds is None:

@@ -1587,12 +1596,30 @@ class DeepseekV2Model(nn.Module):
             hidden_states = input_embeds
         residual = None

-        for i in range(len(self.layers)):
+        normal_num_layers = (
+            self.first_k_dense_replace
+            if forward_batch.can_run_tbo
+            else total_num_layers
+        )
+        for i in range(normal_num_layers):
             with get_global_expert_distribution_recorder().with_current_layer(i):
                 layer = self.layers[i]
                 hidden_states, residual = layer(
                     positions, hidden_states, forward_batch, residual, zero_allocator
                 )

+        if normal_num_layers != total_num_layers:
+            hidden_states, residual = model_forward_maybe_tbo(
+                layers=self.layers[normal_num_layers:],
+                enable_tbo=True,
+                positions=positions,
+                forward_batch=forward_batch,
+                hidden_states=hidden_states,
+                residual=residual,
+                zero_allocator=zero_allocator,
+            )
+
         if not forward_batch.forward_mode.is_idle():
             if residual is None:
                 hidden_states = self.norm(hidden_states)

@@ -1674,7 +1701,6 @@ class DeepseekV2ForCausalLM(nn.Module):
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         return self.logits_processor(
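The new buffer sizing is simple arithmetic but easy to misread; a worked example (the layer count is illustrative, not tied to any particular checkpoint):

total_num_layers = 61
can_run_tbo = True
buffer_size = total_num_layers * 2 * (2 if can_run_tbo else 1)
print(buffer_size)  # 244: two zero-buffer slots per layer, doubled for the two micro-batches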
python/sglang/srt/operations.py

@@ -12,7 +12,7 @@ if _ENABLE_PROFILE:

 def execute_operations(inputs, operations):
-    stages = _convert_operations_to_stages(decorate_operations(operations))
+    stages = _convert_operations_to_stages(operations)
     executor = _StageExecutor("primary", stages, inputs=inputs)
     for _ in range(executor.num_stages):
         executor.next()
@@ -20,6 +20,37 @@ def execute_operations(inputs, operations):
     return executor.output


+def execute_overlapped_operations(
+    inputs_arr: Sequence,
+    operations_arr: Sequence,
+    delta_stages: Sequence[int],
+) -> Sequence:
+    # Make it explicit for clarity; if we need multi-batch overlap, this can be generalized
+    inputs_a, inputs_b = inputs_arr
+    operations_a, operations_b = operations_arr
+    delta_stage_a, delta_stage_b = delta_stages
+    assert delta_stage_a == 0
+    delta_stage = delta_stage_b
+
+    stages_a = _convert_operations_to_stages(operations_a)
+    stages_b = _convert_operations_to_stages(operations_b)
+    executor_a = _StageExecutor("a", stages_a, inputs=inputs_a)
+    executor_b = _StageExecutor("b", stages_b, inputs=inputs_b)
+
+    for _ in range(delta_stage):
+        executor_a.next()
+
+    for _ in range(executor_a.num_stages - delta_stage):
+        executor_a.next()
+        executor_b.next()
+
+    for _ in range(delta_stage):
+        executor_b.next()
+
+    assert executor_a.done and executor_b.done
+    return [executor_a.output, executor_b.output]
+
+
 class YieldOperation:
     pass
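The three loops in execute_overlapped_operations implement a staggered pipeline: micro-batch A runs alone for delta_stage stages, then A and B advance in lockstep, then B drains alone. A toy simulation of the resulting timeline (plain Python, no sglang classes):

num_stages, delta_stage = 5, 2
timeline = []
for t in range(delta_stage):                # A warms up alone
    timeline.append(("a", t))
for t in range(num_stages - delta_stage):   # A and B advance together
    timeline.append(("a", delta_stage + t))
    timeline.append(("b", t))
for t in range(delta_stage):                # B drains alone
    timeline.append(("b", num_stages - delta_stage + t))
print(timeline)  # while both run, B is always delta_stage stages behind A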
@@ -109,6 +140,9 @@ class _StateDict:
         for k, v in values.items():
             setattr(self, k, v)

+    def get(self, item):
+        return self._data.get(item)
+
     def clear(self, expect_keys: Sequence[str]):
         if set(self._data.keys()) != set(expect_keys):
             raise Exception(

@@ -119,6 +153,7 @@ class _StateDict:

 def _convert_operations_to_stages(operations: List[Operation]) -> List[Stage]:
+    operations = _decorate_operations(operations)
     operation_chunks = list(
         _chunk_by_separator(operations, lambda op: isinstance(op, YieldOperation))
     )

@@ -140,7 +175,7 @@ def _chunk_by_separator(
     yield pending_items


-def decorate_operations(operations: List[Operation], debug_name_prefix: str = ""):
+def _decorate_operations(operations: List[Operation], debug_name_prefix: str = ""):
     return [_decorate_operation(op, debug_name_prefix) for op in operations]
python/sglang/srt/operations_strategy.py

@@ -1,24 +1,107 @@
+from dataclasses import dataclass
+from typing import List, Optional
+
 import torch
+
+from sglang.srt import operations
+from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPConfig
+from sglang.srt.model_executor.forward_batch_info import ForwardMode
+from sglang.srt.operations import Operation
+
+
+@dataclass
+class OperationsStrategy:
+    operations: List[Operation]
+    deep_gemm_num_sms: Optional[int] = None
+    tbo_delta_stages: Optional[int] = None
+
+    @classmethod
+    def concat(cls, items: List["OperationsStrategy"]) -> "OperationsStrategy":
+        return OperationsStrategy(
+            operations=[x for item in items for x in item.operations],
+            deep_gemm_num_sms=_assert_all_same(
+                [item.deep_gemm_num_sms for item in items]
+            ),
+            tbo_delta_stages=_assert_all_same(
+                [item.tbo_delta_stages for item in items]
+            ),
+        )
+
+    @staticmethod
+    def init_new_tbo(
+        layers: torch.nn.ModuleList,
+        forward_mode: ForwardMode,
+    ) -> "OperationsStrategy":
+        return OperationsStrategy.concat(
+            [
+                _compute_layer_operations_strategy_tbo(layer, forward_mode)
+                for layer in layers
+            ]
+        )
+
+
+def _assert_all_same(items: List):
+    assert all(item == items[0] for item in items)
+    return items[0]
+
+
+# TODO can refactor to make it more fancy if we have more complex strategies
+def _compute_layer_operations_strategy_tbo(
+    layer: torch.nn.Module,
+    forward_mode: ForwardMode,
+) -> OperationsStrategy:
+    assert layer.is_layer_sparse, "dense layer TBO not yet implemented"
+    if forward_mode == ForwardMode.EXTEND:
+        return _compute_moe_deepseek_blog_prefill(layer)
+    elif forward_mode == ForwardMode.DECODE:
+        return _compute_moe_deepseek_blog_decode(layer)
+    else:
+        raise NotImplementedError(f"Unsupported {forward_mode=}")
+
+
+def _compute_moe_deepseek_blog_prefill(layer):
+    device_properties = torch.cuda.get_device_properties(device="cuda")
+    total_num_sms = device_properties.multi_processor_count
+    deep_gemm_num_sms = total_num_sms - DeepEPConfig.get_instance().num_sms
+
+    return OperationsStrategy(
+        deep_gemm_num_sms=deep_gemm_num_sms,
+        tbo_delta_stages=0,
+        operations=[
+            layer.op_comm_prepare_attn,
+            layer.self_attn.op_prepare,
+            layer.self_attn.op_core,
+            layer.op_comm_prepare_mlp,
+            layer.mlp.op_gate,
+            layer.mlp.op_select_experts,
+            layer.mlp.op_dispatch_a,
+            operations.YieldOperation(),
+            layer.mlp.op_dispatch_b,
+            layer.mlp.op_experts,
+            layer.mlp.op_combine_a,
+            operations.YieldOperation(),
+            layer.mlp.op_shared_experts,
+            layer.mlp.op_combine_b,
+            layer.mlp.op_output,
+            layer.op_comm_postprocess_layer,
+        ],
+    )
+
+
+def _compute_moe_deepseek_blog_decode(layer):
+    return OperationsStrategy(
+        deep_gemm_num_sms=None,
+        tbo_delta_stages=2,
+        operations=[
+            layer.op_comm_prepare_attn,
+            layer.self_attn.op_prepare,
+            operations.YieldOperation(),
+            layer.self_attn.op_core,
+            layer.op_comm_prepare_mlp,
+            layer.mlp.op_gate,
+            layer.mlp.op_select_experts,
+            operations.YieldOperation(),
+            layer.mlp.op_dispatch_a,
+            layer.mlp.op_shared_experts,
+            operations.YieldOperation(),
+            layer.mlp.op_dispatch_b,
+            layer.mlp.op_experts,
+            layer.mlp.op_combine_a,
+            operations.YieldOperation(),
+            layer.mlp.op_combine_b,
+            layer.mlp.op_output,
+            layer.op_comm_postprocess_layer,
+            operations.YieldOperation(),
+        ],
+    )
-
-def compute_layer_operations(
-    layer: torch.nn.Module,
-):
-    if not layer.is_layer_sparse:
-        return [
-            layer.op_comm_prepare_attn,
-            layer.self_attn.op_prepare,
-            layer.self_attn.op_core,
-            layer.op_comm_prepare_mlp,
-            layer.op_mlp,
-            layer.op_comm_postprocess_layer,
-        ]
-
-    # Will add TBO operation orders here
-    return [
-        layer.op_comm_prepare_attn,
-        layer.self_attn.op_prepare,
-        layer.self_attn.op_core,
-        layer.op_comm_prepare_mlp,
-        layer.mlp.op_gate,
-        layer.mlp.op_shared_experts,
-        layer.mlp.op_select_experts,
-        layer.mlp.op_dispatch_a,
-        layer.mlp.op_dispatch_b,
-        layer.mlp.op_experts,
-        layer.mlp.op_combine_a,
-        layer.mlp.op_combine_b,
-        layer.mlp.op_output,
-        layer.op_comm_postprocess_layer,
-    ]
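The prefill strategy's SM budget is a simple subtraction: whatever DeepEP reserves for communication is taken away from DeepGEMM. Illustrative numbers only (both values depend on the GPU and the DeepEP configuration):

total_num_sms = 132    # e.g. an H100 SXM reports 132 multiprocessors (assumed here)
deepep_num_sms = 24    # assumed value of DeepEPConfig.get_instance().num_sms
deep_gemm_num_sms = total_num_sms - deepep_num_sms
print(deep_gemm_num_sms)  # 108 SMs remain for DeepGEMM while dispatch/combine run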
python/sglang/srt/server_args.py
View file @
0d477880
...
@@ -167,6 +167,7 @@ class ServerArgs:
     enable_mixed_chunk: bool = False
     enable_dp_attention: bool = False
     enable_dp_lm_head: bool = False
+    enable_two_batch_overlap: bool = False
     enable_ep_moe: bool = False
     enable_deepep_moe: bool = False
     deepep_mode: Optional[Literal["auto", "normal", "low_latency"]] = "auto"
...
@@ -1144,6 +1145,11 @@ class ServerArgs:
             action="store_true",
             help="Enabling expert parallelism for moe. The ep size is equal to the tp size.",
         )
+        parser.add_argument(
+            "--enable-two-batch-overlap",
+            action="store_true",
+            help="Enabling two micro batches to overlap.",
+        )
         parser.add_argument(
             "--enable-torch-compile",
             action="store_true",
...
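Taken together with the dataclass field above, the feature is strictly opt-in: a server would be launched with something like `python -m sglang.launch_server --model-path <model> --enable-two-batch-overlap ...` (see the test at the bottom of this commit for a full argument list that pairs it with `--enable-dp-attention` and `--enable-deepep-moe`).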
python/sglang/srt/two_batch_overlap.py
0 → 100644
View file @
0d477880
import dataclasses
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence

import torch

from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
from sglang.srt.layers.quantization.deep_gemm import configure_deep_gemm_num_sms
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.operations import execute_operations, execute_overlapped_operations
from sglang.srt.operations_strategy import OperationsStrategy
from sglang.srt.utils import BumpAllocator, DeepEPMode

if TYPE_CHECKING:
    from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
# -------------------------------- Compute Basic Info ---------------------------------------
# TODO: may smartly disable TBO when batch size is too small b/c it will slow down
def compute_split_seq_index(
    forward_mode: "ForwardMode",
    num_tokens: int,
    extend_lens: Optional[Sequence[int]],
) -> Optional[int]:
    if forward_mode.is_extend():
        assert extend_lens is not None
        return _split_array_by_half_sum(extend_lens)
    elif forward_mode.is_decode():
        return num_tokens // 2
    elif forward_mode.is_idle():
        assert num_tokens == 0
        return 0
    else:
        raise NotImplementedError

def _split_array_by_half_sum(arr: Sequence[int]) -> int:
    overall_sum = sum(arr)
    accumulator, split_index = 0, 0
    for value in arr[:-1]:
        accumulator += value
        split_index += 1
        if accumulator >= overall_sum // 2:
            break
    return split_index

def compute_split_token_index(
    split_seq_index: int,
    forward_mode: "ForwardMode",
    extend_seq_lens: Optional[Sequence[int]],
) -> int:
    if forward_mode.is_extend():
        assert extend_seq_lens is not None
        return sum(extend_seq_lens[:split_seq_index])
    elif forward_mode.is_decode():
        return split_seq_index
    elif forward_mode.is_idle():
        assert split_seq_index == 0
        return 0
    else:
        raise NotImplementedError
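A small worked example of the two split helpers, assuming a prefill batch with per-request extend lengths [5, 3, 4, 8] (20 tokens in total):

extend_lens = [5, 3, 4, 8]  # 20 tokens total; half sum is 10

split_seq_index = compute_split_seq_index(
    forward_mode=ForwardMode.EXTEND, num_tokens=20, extend_lens=extend_lens
)
# accumulates 5, 8, 12 and stops once >= 10, so split_seq_index == 3

split_token_index = compute_split_token_index(
    split_seq_index=split_seq_index,
    forward_mode=ForwardMode.EXTEND,
    extend_seq_lens=extend_lens,
)
# sum of the first 3 extend lengths, so split_token_index == 12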
# -------------------------------- Preparation ---------------------------------------
class TboCudaGraphRunnerUtils:
    @staticmethod
    def compute_tbo_split_seq_index(that: "CudaGraphRunner", num_tokens: int):
        if that.model_runner.server_args.enable_two_batch_overlap:
            tbo_split_seq_index = compute_split_seq_index(
                forward_mode=that.capture_forward_mode,
                num_tokens=num_tokens,
                extend_lens=None,
            )
            # For simplicity, when two_batch_overlap is enabled, we only capture CUDA Graph for tbo=true
            assert (
                tbo_split_seq_index is not None
            ), f"{that.capture_forward_mode=} {num_tokens=}"
        else:
            tbo_split_seq_index = None
        return tbo_split_seq_index

class TboDPAttentionPreparer:
    def prepare_all_gather(
        self, local_batch, deepep_mode, enable_deepep_moe, enable_two_batch_overlap
    ):
        self.enable_two_batch_overlap = enable_two_batch_overlap

        if local_batch is not None:
            self.local_tbo_split_seq_index = compute_split_seq_index(
                forward_mode=local_batch.forward_mode,
                num_tokens=local_batch.input_ids.shape[0],
                extend_lens=local_batch.extend_lens,
            )
            resolved_deepep_mode = deepep_mode.resolve(local_batch.forward_mode)
            local_can_run_tbo = (self.local_tbo_split_seq_index is not None) and not (
                local_batch.forward_mode.is_extend()
                and enable_deepep_moe
                and (resolved_deepep_mode == DeepEPMode.low_latency)
            )
        else:
            self.local_tbo_split_seq_index = 0
            local_can_run_tbo = True

        local_forward_mode = self._compute_local_forward_mode(local_batch)
        return local_can_run_tbo, local_forward_mode

    def compute_output(self, partial_global_info):
        local_can_run_tbo_aggregated = min(partial_global_info[:, 0, 0].tolist())
        forward_modes = partial_global_info[:, 0, 1].tolist()

        global_forward_mode, forward_mode_agree = self._compute_global_forward_mode(
            forward_modes
        )
        can_run_tbo = (
            self.enable_two_batch_overlap
            and local_can_run_tbo_aggregated
            and forward_mode_agree
        )

        tbo_split_seq_index = self.local_tbo_split_seq_index if can_run_tbo else None
        global_forward_mode = global_forward_mode if can_run_tbo else None
        return tbo_split_seq_index, global_forward_mode

    @staticmethod
    def _compute_local_forward_mode(local_batch):
        return (
            local_batch.forward_mode if local_batch is not None else ForwardMode.IDLE
        ).value

    @staticmethod
    def _compute_global_forward_mode(forward_modes):
        converted_forward_modes = [
            ForwardMode.DECODE.value if x == ForwardMode.IDLE.value else x
            for x in forward_modes
        ]
        forward_mode_agree = TboDPAttentionPreparer._is_all_same(
            converted_forward_modes
        )
        global_forward_mode = (
            ForwardMode(converted_forward_modes[0]) if forward_mode_agree else None
        )
        return global_forward_mode, forward_mode_agree

    @staticmethod
    def _is_all_same(x):
        return all(value == x[0] for value in x)
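Under DP attention each rank gathers every rank's (can_run_tbo, forward_mode) pair, and TBO only turns on when every rank can run it and all non-idle modes agree. A minimal sketch of compute_output, with a hand-built partial_global_info tensor standing in for the real all-gather result:

import torch

preparer = TboDPAttentionPreparer()
preparer.enable_two_batch_overlap = True
preparer.local_tbo_split_seq_index = 4  # pretend this rank split at sequence 4

# Shape (num_ranks, 1, 2): column 0 = local can_run_tbo, column 1 = forward mode.
# An idle rank is coerced to DECODE, so these two ranks still "agree".
partial_global_info = torch.tensor(
    [
        [[1, ForwardMode.DECODE.value]],
        [[1, ForwardMode.IDLE.value]],
    ]
)

tbo_split_seq_index, global_forward_mode = preparer.compute_output(partial_global_info)
assert tbo_split_seq_index == 4 and global_forward_mode == ForwardMode.DECODE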

class TboForwardBatchPreparer:
    @classmethod
    def prepare(cls, batch: ForwardBatch):
        from sglang.srt.layers.attention.tbo_backend import TboAttnBackend

        if batch.tbo_split_seq_index is None:
            return

        tbo_split_token_index = compute_split_token_index(
            split_seq_index=batch.tbo_split_seq_index,
            forward_mode=batch.forward_mode,
            extend_seq_lens=batch.extend_seq_lens_cpu,
        )

        assert isinstance(batch.attn_backend, TboAttnBackend)
        attn_backend_child_a, attn_backend_child_b = batch.attn_backend.children

        child_a = cls.filter_batch(
            batch,
            start_token_index=0,
            end_token_index=tbo_split_token_index,
            start_seq_index=0,
            end_seq_index=batch.tbo_split_seq_index,
            output_attn_backend=attn_backend_child_a,
        )
        child_b = cls.filter_batch(
            batch,
            start_token_index=tbo_split_token_index,
            end_token_index=batch.input_ids.shape[0],
            start_seq_index=batch.tbo_split_seq_index,
            end_seq_index=batch.batch_size,
            output_attn_backend=attn_backend_child_b,
        )

        assert batch.tbo_children is None
        batch.tbo_children = [child_a, child_b]
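Continuing the numeric example from the split helpers: with extend_seq_lens [5, 3, 4, 8] and tbo_split_seq_index = 3, the split token index is 12, so child A receives sequences 0-2 with tokens [0, 12) and child B receives sequence 3 with tokens [12, 20), each wired to its own attention-backend child.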
    @classmethod
    def filter_batch(
        cls,
        batch: ForwardBatch,
        *,
        start_token_index: int,
        end_token_index: int,
        start_seq_index: int,
        end_seq_index: int,
        output_attn_backend: AttentionBackend,
    ):
        from sglang.srt.managers.schedule_batch import global_server_args_dict

        num_tokens = batch.input_ids.shape[0]
        num_seqs = batch.batch_size

        output_dict = dict()

        for key in [
            "input_ids",
            "positions",
            "out_cache_loc",
        ]:
            old_value = getattr(batch, key)
            assert (
                old_value.shape[0] == num_tokens
            ), f"{key=} {old_value=} {num_tokens=} {batch=}"
            output_dict[key] = old_value[start_token_index:end_token_index]

        for key in [
            "req_pool_indices",
            "seq_lens",
            "seq_lens_cpu",
            "extend_seq_lens",
            "extend_prefix_lens",
            "extend_start_loc",
            "extend_prefix_lens_cpu",
            "extend_seq_lens_cpu",
            "extend_logprob_start_lens_cpu",
            "lora_paths",
        ]:
            old_value = getattr(batch, key)
            if old_value is None:
                continue
            assert (
                len(old_value) == num_seqs
            ), f"{key=} {old_value=} {num_seqs=} {batch=}"
            output_dict[key] = old_value[start_seq_index:end_seq_index]

        for key in [
            "forward_mode",
            "return_logprob",
            "req_to_token_pool",
            "token_to_kv_pool",
            "can_run_dp_cuda_graph",
            "global_forward_mode",
            "spec_info",
            "spec_algorithm",
            "capture_hidden_mode",
            "padded_static_len",
            "mrope_positions",  # only used by qwen2-vl, thus not care
        ]:
            output_dict[key] = getattr(batch, key)

        assert (
            _compute_extend_num_tokens(batch.input_ids, batch.forward_mode)
            == batch.extend_num_tokens
        ), f"{batch=}"
        extend_num_tokens = _compute_extend_num_tokens(
            output_dict["input_ids"], output_dict["forward_mode"]
        )

        # TODO improve, e.g. unify w/ `init_raw`
        if global_server_args_dict["moe_dense_tp_size"] == 1:
            sum_len = end_token_index - start_token_index
            gathered_buffer = torch.zeros(
                (sum_len, batch.gathered_buffer.shape[1]),
                dtype=batch.gathered_buffer.dtype,
                device=batch.gathered_buffer.device,
            )
        else:
            gathered_buffer = None

        output_dict.update(
            dict(
                batch_size=end_seq_index - start_seq_index,
                seq_lens_sum=(
                    output_dict["seq_lens_cpu"].sum()
                    if "seq_lens_cpu" in output_dict
                    else None
                ),
                extend_num_tokens=extend_num_tokens,
                attn_backend=output_attn_backend,
                tbo_split_seq_index=None,
                tbo_parent_token_range=(start_token_index, end_token_index),
                tbo_children=None,
                global_num_tokens_gpu=None,
                global_num_tokens_cpu=None,
                gathered_buffer=gathered_buffer,
                global_num_tokens_for_logprob_gpu=None,
                global_num_tokens_for_logprob_cpu=None,
                sampling_info=None,
                # For logits and logprobs post processing, thus we do not care
                temp_scaled_logprobs=False,
                temperature=None,
                top_p_normalized_logprobs=False,
                top_p=None,
                mm_inputs=None,
                num_token_non_padded=None,
            )
        )

        errors = []
        for field in dataclasses.fields(ForwardBatch):
            if getattr(batch, field.name) is not None and field.name not in output_dict:
                errors.append(
                    f"Field {field.name} has value, but is not yet supported "
                    f"(value={getattr(batch, field.name)} batch={batch})"
                )
        if len(errors) > 0:
            raise Exception(f"{len(errors)} errors happen:\n" + "\n\n".join(errors))

        return ForwardBatch(**output_dict)
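The method above sorts ForwardBatch fields into three buckets: per-token tensors (input_ids, positions, out_cache_loc) are sliced by the token range, per-sequence fields by the sequence range, and scalars or shared pools are copied through unchanged; the trailing dataclasses.fields sweep then fails loudly on any populated field the split does not yet know how to handle.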

def _compute_extend_num_tokens(input_ids, forward_mode: ForwardMode):
    if forward_mode.is_extend():
        return input_ids.shape[0]
    elif forward_mode.is_decode() or forward_mode.is_idle():
        return None
    raise NotImplementedError
# -------------------------------- Execution ---------------------------------------
def model_forward_maybe_tbo(
    layers,
    enable_tbo: bool,
    positions: torch.Tensor,
    forward_batch: ForwardBatch,
    hidden_states: torch.Tensor,
    residual: Optional[torch.Tensor],
    zero_allocator: BumpAllocator,
):
    inputs = dict(
        positions=positions,
        hidden_states=hidden_states,
        forward_batch=forward_batch,
        residual=residual,
        zero_allocator=zero_allocator,
    )
    operations_strategy = OperationsStrategy.init_new_tbo(
        layers, forward_batch.global_forward_mode
    )
    if enable_tbo:
        return _model_forward_tbo(inputs, operations_strategy)
    else:
        return _model_forward_non_tbo(inputs, operations_strategy)

def _model_forward_tbo(inputs, operations_strategy: OperationsStrategy):
    # The attn_tp_size!=1 case is not yet extracted to master
    assert get_attention_tp_size() == 1

    inputs_arr = _model_forward_tbo_split_inputs(**inputs)
    del inputs

    with configure_deep_gemm_num_sms(operations_strategy.deep_gemm_num_sms):
        outputs_arr = execute_overlapped_operations(
            inputs_arr=inputs_arr,
            operations_arr=[operations_strategy.operations] * 2,
            delta_stages=[0, operations_strategy.tbo_delta_stages],
        )

    return _model_forward_tbo_merge_outputs(*outputs_arr)
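Note the asymmetric delta_stages=[0, operations_strategy.tbo_delta_stages]: with the decode strategy's tbo_delta_stages=2, micro-batch B starts two yield-delimited stages behind micro-batch A, so A's communication phases line up against B's compute phases instead of colliding with B's identical communication.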

def _model_forward_non_tbo(inputs, operations_strategy: OperationsStrategy):
    outputs = execute_operations(inputs, operations_strategy.operations)
    return outputs["hidden_states"], outputs["residual"]

def _model_forward_tbo_split_inputs(
    hidden_states: torch.Tensor,
    residual: torch.Tensor,
    positions: torch.Tensor,
    forward_batch: ForwardBatch,
    zero_allocator: BumpAllocator,
) -> List[Dict]:
    return [
        dict(
            **_model_forward_filter_inputs(
                hidden_states=hidden_states,
                residual=residual,
                positions=positions,
                output_forward_batch=output_forward_batch,
                tbo_subbatch_index=tbo_subbatch_index,
            ),
            zero_allocator=zero_allocator,
        )
        for tbo_subbatch_index, output_forward_batch in enumerate(
            forward_batch.tbo_children
        )
    ]

def _model_forward_filter_inputs(
    hidden_states: torch.Tensor,
    residual: torch.Tensor,
    positions: torch.Tensor,
    output_forward_batch: ForwardBatch,
    tbo_subbatch_index: int,
) -> Dict:
    token_slice = slice(*output_forward_batch.tbo_parent_token_range)
    return dict(
        hidden_states=hidden_states[token_slice],
        residual=None if residual is None else residual[token_slice],
        positions=positions[token_slice],
        forward_batch=output_forward_batch,
        tbo_subbatch_index=tbo_subbatch_index,
    )

def _model_forward_tbo_merge_outputs(output_a, output_b):
    def _handle_key(name):
        value_a = output_a[name]
        value_b = output_b[name]
        assert (value_a is None) == (value_b is None)
        if value_a is None:
            return None
        return torch.concat([value_a, value_b], dim=0)

    return _handle_key("hidden_states"), _handle_key("residual")
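Because the children partition the parent's token dimension contiguously (child A owns [0, split), child B owns [split, n)), concatenating along dim 0 restores the parent ordering. A minimal shape check, with illustrative sizes:

import torch

hidden_a = torch.randn(12, 8)  # child A: tokens [0, 12)
hidden_b = torch.randn(8, 8)   # child B: tokens [12, 20)
merged = torch.concat([hidden_a, hidden_b], dim=0)
assert merged.shape == (20, 8)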
# -------------------------------- Utilities and wrappers ---------------------------------------

class MaybeTboDeepEPDispatcher:
    def __init__(self, **kwargs):
        num_inner_dispatchers = (
            2 if global_server_args_dict["enable_two_batch_overlap"] else 1
        )
        self._inners = [
            DeepEPDispatcher(**kwargs) for _ in range(num_inner_dispatchers)
        ]

    def _execute(self, name, tbo_subbatch_index: Optional[int] = None, **kwargs):
        return getattr(self._inners[tbo_subbatch_index or 0], name)(**kwargs)

    def dispatch(self, **kwargs):
        return self._execute("dispatch", **kwargs)

    def dispatch_a(self, **kwargs):
        return self._execute("dispatch_a", **kwargs)

    def dispatch_b(self, **kwargs):
        return self._execute("dispatch_b", **kwargs)

    def combine(self, **kwargs):
        return self._execute("combine", **kwargs)

    def combine_a(self, **kwargs):
        return self._execute("combine_a", **kwargs)

    def combine_b(self, **kwargs):
        return self._execute("combine_b", **kwargs)
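The wrapper keeps one DeepEPDispatcher per micro-batch when TBO is on, because each sub-batch needs its own communication buffers for its dispatch/combine to stay in flight while the other sub-batch computes; `tbo_subbatch_index or 0` maps sub-batch 0 or 1 (or None in the non-TBO path) to the right inner dispatcher.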
test/srt/test_two_batch_overlap.py
0 → 100644
View file @
0d477880
import os
import unittest
from types import SimpleNamespace

import requests

from sglang.srt.utils import kill_process_tree
from sglang.test.run_eval import run_eval
from sglang.test.test_utils import (
    DEFAULT_MLA_MODEL_NAME_FOR_TEST,
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    popen_launch_server,
)


class TestTwoBatchOverlap(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=[
                "--trust-remote-code",
                "--tp",
                "2",
                "--dp",
                "2",
                "--enable-dp-attention",
                "--enable-deepep-moe",
                "--deepep-mode",
                "normal",
                "--disable-cuda-graph",  # DeepEP normal does not support CUDA Graph
                "--enable-two-batch-overlap",
            ],
            env={"SGL_ENABLE_JIT_DEEPGEMM": "0", **os.environ},
        )

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    def test_generate_single_prompt(self):
        response = requests.post(
            self.base_url + "/generate",
            # we use an uncommon start to minimise the chance that the cache is hit by chance
            json={
                "text": "_ 1+1=2, 1+2=3, 1+3=4, 1+4=",
                "sampling_params": {"temperature": 0, "max_new_tokens": 8},
            },
        )
        print(f"{response.json()=}")
        self.assertEqual(response.json()["text"], "5, 1+5=6")

    def test_mmlu(self):
        args = SimpleNamespace(
            base_url=self.base_url,
            model=self.model,
            eval_name="mmlu",
            num_examples=64,
            num_threads=32,
        )
        metrics = run_eval(args)
        self.assertGreater(metrics["score"], 0.5)


if __name__ == "__main__":
    unittest.main()