sglang · Commits · commit 32cd7070 (unverified)

Support TP in attention for two batch overlap (#6634)

Authored May 27, 2025 by fzyzcjy; committed by GitHub May 26, 2025
Parent: ebd1ed49
Showing 4 changed files with 104 additions and 8 deletions (+104 −8):
python/sglang/srt/layers/communicator.py   +19 −0
python/sglang/srt/models/deepseek_v2.py    +3 −0
python/sglang/srt/models/qwen2_moe.py      +6 −1
python/sglang/srt/two_batch_overlap.py     +76 −7
python/sglang/srt/layers/communicator.py
```diff
@@ -448,6 +448,13 @@ class CommunicateSummableTensorPairFn:
         ):
             return CommunicateSummableTensorPairFn._gather
+        if (
+            (hidden_states_input_mode == ScatterMode.TP_ATTN_FULL)
+            and (residual_input_mode == ScatterMode.TP_ATTN_FULL)
+            and (output_mode == ScatterMode.SCATTERED)
+        ):
+            return CommunicateSummableTensorPairFn._scatter
+
         raise NotImplementedError(
             f"{hidden_states_input_mode=} {residual_input_mode=} {output_mode=}"
         )
```
```diff
@@ -496,3 +503,15 @@ class CommunicateSummableTensorPairFn:
             local_hidden_states,
         )
         return hidden_states, residual
+
+    @staticmethod
+    def _scatter(
+        hidden_states: torch.Tensor,
+        residual: torch.Tensor,
+        forward_batch: ForwardBatch,
+        context: CommunicateContext,
+    ):
+        assert residual is None, "not yet handled residual!=None"
+        tensor_list = list(hidden_states.tensor_split(context.attn_tp_size))
+        hidden_states = tensor_list[context.attn_tp_rank]
+        return hidden_states, residual
```
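The new `_scatter` branch converts activations that are full within the attention-TP group (`TP_ATTN_FULL`) into per-rank shards (`SCATTERED`) without any collective: each rank just slices out the shard it owns. A standalone sketch of the `tensor_split` semantics it relies on (toy sizes, and a loop standing in for real ranks; not part of the commit):

```python
import torch

attn_tp_size = 4
# Replicated activations: [num_tokens, hidden_dim], identical on every rank.
hidden_states = torch.arange(8 * 3, dtype=torch.float32).reshape(8, 3)

# tensor_split partitions dim 0 into attn_tp_size nearly equal shards,
# matching the layout an all-gather over the same group would have produced.
shards = list(hidden_states.tensor_split(attn_tp_size))
for attn_tp_rank in range(attn_tp_size):
    local = shards[attn_tp_rank]  # what _scatter returns on this rank
    print(attn_tp_rank, local.shape)  # torch.Size([2, 3]) on each rank
```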
python/sglang/srt/models/deepseek_v2.py
```diff
@@ -1613,6 +1613,9 @@ class DeepseekV2Model(nn.Module):
             forward_batch=forward_batch,
             hidden_states=hidden_states,
             residual=residual,
+            input_data_scatter_mode=self.layers[
+                normal_num_layers - 1
+            ].layer_scatter_modes.layer_output_mode,
             zero_allocator=zero_allocator,
         )
```
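Why the last layer's `layer_output_mode`? The TBO segment consumes whatever layout the preceding non-TBO layer emitted, so the model reads the input scatter mode off that layer instead of hard-coding one. A toy illustration of the lookup (flattened attributes and made-up mode values; the real code goes through `layer_scatter_modes` on actual decoder layers):

```python
from dataclasses import dataclass
from enum import Enum, auto

class ScatterMode(Enum):
    FULL = auto()
    TP_ATTN_FULL = auto()
    SCATTERED = auto()

@dataclass
class Layer:
    layer_input_mode: ScatterMode
    layer_output_mode: ScatterMode

# Hypothetical two-layer prefix running before the TBO segment.
layers = [
    Layer(ScatterMode.FULL, ScatterMode.TP_ATTN_FULL),
    Layer(ScatterMode.TP_ATTN_FULL, ScatterMode.SCATTERED),
]
normal_num_layers = len(layers)

# The data entering the TBO segment has the layout the last layer produced.
input_data_scatter_mode = layers[normal_num_layers - 1].layer_output_mode
assert input_data_scatter_mode is ScatterMode.SCATTERED
```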
python/sglang/srt/models/qwen2_moe.py
```diff
@@ -32,7 +32,11 @@ from sglang.srt.distributed import (
     tensor_model_parallel_all_reduce,
 )
 from sglang.srt.layers.activation import SiluAndMul
-from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
+from sglang.srt.layers.communicator import (
+    LayerCommunicator,
+    LayerScatterModes,
+    ScatterMode,
+)
 from sglang.srt.layers.dp_attention import (
     attn_tp_all_gather,
     attn_tp_reduce_scatter,
```
```diff
@@ -447,6 +451,7 @@ class Qwen2MoeModel(nn.Module):
         hidden_states, residual = model_forward_maybe_tbo(
             layers=self.layers,
             enable_tbo=True,
+            input_data_scatter_mode=ScatterMode.model_input_output(),
             positions=positions,
             forward_batch=forward_batch,
             hidden_states=hidden_states,
```
python/sglang/srt/two_batch_overlap.py
```diff
@@ -5,6 +5,12 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Sequence
 import torch
 
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
+from sglang.srt.layers.communicator import (
+    CommunicateContext,
+    CommunicateSimpleFn,
+    CommunicateSummableTensorPairFn,
+    ScatterMode,
+)
 from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
 from sglang.srt.layers.quantization.deep_gemm import configure_deep_gemm_num_sms
```
```diff
@@ -355,6 +361,7 @@ def model_forward_maybe_tbo(
     positions: torch.Tensor,
     forward_batch: ForwardBatch,
     hidden_states: torch.Tensor,
+    input_data_scatter_mode: ScatterMode,
     residual: Optional[torch.Tensor],
     zero_allocator: Optional[BumpAllocator] = None,
 ):
```
```diff
@@ -365,20 +372,32 @@ def model_forward_maybe_tbo(
         residual=residual,
         **(dict(zero_allocator=zero_allocator) if zero_allocator is not None else {}),
     )
+    layer_input_scatter_mode = layers[0].layer_scatter_modes.layer_input_mode
     operations_strategy = OperationsStrategy.init_new_tbo(
         layers, forward_batch.global_forward_mode
     )
     if enable_tbo:
-        return _model_forward_tbo(inputs, operations_strategy)
+        return _model_forward_tbo(
+            inputs=inputs,
+            operations_strategy=operations_strategy,
+            input_data_scatter_mode=input_data_scatter_mode,
+            layer_input_scatter_mode=layer_input_scatter_mode,
+        )
     else:
         return _model_forward_non_tbo(inputs, operations_strategy)
 
 
-def _model_forward_tbo(inputs, operations_strategy: OperationsStrategy):
-    # The attn_tp_size!=1 case is not yet extracted to master
-    assert get_attention_tp_size() == 1
-    inputs_arr = _model_forward_tbo_split_inputs(**inputs)
+def _model_forward_tbo(
+    inputs,
+    operations_strategy: OperationsStrategy,
+    input_data_scatter_mode: ScatterMode,
+    layer_input_scatter_mode: ScatterMode,
+):
+    inputs_arr = _model_forward_tbo_split_inputs(
+        **inputs,
+        input_data_scatter_mode=input_data_scatter_mode,
+        layer_input_scatter_mode=layer_input_scatter_mode,
+    )
     del inputs
     with configure_deep_gemm_num_sms(operations_strategy.deep_gemm_num_sms):
```
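Two changes land in this hunk: the `assert get_attention_tp_size() == 1` guard (with its "not yet extracted to master" comment) is removed, since the new scatter-mode plumbing handles attention-TP sizes above 1, and `zero_allocator` is forwarded with a conditional-kwargs idiom so layer stacks whose forward signature never declared the parameter keep working. A toy demonstration of that idiom (hypothetical functions, not sglang APIs):

```python
def qwen_style_layer(hidden_states):  # declares no zero_allocator
    return hidden_states

def deepseek_style_layer(hidden_states, zero_allocator):  # requires one
    return hidden_states

def call(layer_fn, hidden_states, zero_allocator=None):
    # Inject the kwarg only when an allocator exists; otherwise the call
    # stays compatible with layers that never declared the parameter.
    return layer_fn(
        hidden_states,
        **(dict(zero_allocator=zero_allocator) if zero_allocator is not None else {}),
    )

call(qwen_style_layer, 1.0)                               # ok: kwarg omitted
call(deepseek_style_layer, 1.0, zero_allocator=object())  # ok: kwarg passed
```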
```diff
@@ -401,7 +420,57 @@ def _model_forward_tbo_split_inputs(
     residual: torch.Tensor,
     positions: torch.Tensor,
     forward_batch: ForwardBatch,
-    zero_allocator: Optional[BumpAllocator] = None,
+    zero_allocator: Optional[BumpAllocator],
+    input_data_scatter_mode: ScatterMode,
+    layer_input_scatter_mode: ScatterMode,
 ) -> List[Dict]:
+    tbo_splitter_scatter_mode = ScatterMode.TP_ATTN_FULL
+    context = CommunicateContext.init_new()
+
+    hidden_states, residual = CommunicateSummableTensorPairFn.execute(
+        hidden_states_input_mode=input_data_scatter_mode,
+        residual_input_mode=input_data_scatter_mode,
+        output_mode=tbo_splitter_scatter_mode,
+        hidden_states=hidden_states,
+        residual=residual,
+        forward_batch=forward_batch,
+        context=context,
+    )
+
+    inputs_arr = _model_forward_tbo_split_inputs_raw(
+        hidden_states=hidden_states,
+        residual=residual,
+        positions=positions,
+        forward_batch=forward_batch,
+        zero_allocator=zero_allocator,
+    )
+
+    def _post_transform(hidden_states, residual, forward_batch, **kwargs):
+        hidden_states, residual = CommunicateSummableTensorPairFn.execute(
+            hidden_states_input_mode=tbo_splitter_scatter_mode,
+            residual_input_mode=tbo_splitter_scatter_mode,
+            output_mode=layer_input_scatter_mode,
+            hidden_states=hidden_states,
+            residual=residual,
+            forward_batch=forward_batch,
+            context=context,
+        )
+        return dict(
+            hidden_states=hidden_states,
+            residual=residual,
+            forward_batch=forward_batch,
+            **kwargs,
+        )
+
+    return [_post_transform(**inputs) for inputs in inputs_arr]
+
+
+def _model_forward_tbo_split_inputs_raw(
+    hidden_states: torch.Tensor,
+    residual: torch.Tensor,
+    positions: torch.Tensor,
+    forward_batch: ForwardBatch,
+    zero_allocator: Optional[BumpAllocator],
+) -> List[Dict]:
     return [
         dict(
```
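Taken together, the new split path runs in three steps: communicate the inputs into `TP_ATTN_FULL` so the token dimension is whole, split them into two micro-batches via the pre-existing raw splitter, then communicate each micro-batch into the first layer's input mode inside `_post_transform`. A minimal single-process sketch of that shape flow (token-dim split only; the real code also carries residual, positions, forward_batch, and allocator state through the split):

```python
from typing import Dict, List
import torch

def split_inputs_sketch(
    hidden_states: torch.Tensor,  # assumed already in TP_ATTN_FULL layout
    split_index: int,
) -> List[Dict]:
    # Step 1 (real code): CommunicateSummableTensorPairFn.execute converts
    # input_data_scatter_mode -> TP_ATTN_FULL so the token dim is splittable.
    # Step 2: split into the two overlapping micro-batches.
    chunks = [hidden_states[:split_index], hidden_states[split_index:]]
    # Step 3 (real code, _post_transform): convert each chunk from
    # TP_ATTN_FULL -> layer_input_scatter_mode before the first layer runs.
    return [dict(hidden_states=chunk) for chunk in chunks]

inputs_arr = split_inputs_sketch(torch.randn(10, 16), split_index=5)
print([d["hidden_states"].shape for d in inputs_arr])  # [(5, 16), (5, 16)]
```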