change / sglang · Commits · ebd1ed49

Tiny refactor communicator (#6646)

Unverified commit ebd1ed49, authored May 27, 2025 by fzyzcjy, committed by GitHub May 26, 2025
Parent: f77da699
Showing 1 changed file with 78 additions and 47 deletions.

python/sglang/srt/layers/communicator.py (+78, -47), view file @ ebd1ed49
@@ -37,10 +37,23 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch

 class ScatterMode(Enum):
+    """
+    Suppose we have TP=4, DP=2, enable-dp-attention, and the system handles seq a,b,c,d
+    Model input/output: [ab, ab, cd, cd] for four ranks respectively
+    SCATTERED: [a, b, c, d]
+    TP_ATTN_FULL: [ab, ab, cd, cd], i.e. all ranks inside a TP attn group have full data of the group
+    FULL: [abcd, abcd, abcd, abcd]
+    """
+
     SCATTERED = auto()
     TP_ATTN_FULL = auto()
     FULL = auto()

+    @staticmethod
+    def model_input_output():
+        """The scatter mode for model forward pass input and output data"""
+        return ScatterMode.TP_ATTN_FULL
+

 @dataclass
 class _LayerModeComputationContext:
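For illustration only (not part of the commit): a minimal sketch of the per-rank layouts the new docstring describes, assuming TP=4, DP=2 with DP attention enabled and sequences a, b, c, d, where ranks 0-1 form one attention TP group and ranks 2-3 the other.

    # Hypothetical illustration of the ScatterMode layouts for TP=4, DP=2 (enable-dp-attention).
    per_rank_data = {
        "SCATTERED":    ["a", "b", "c", "d"],              # each rank holds only its own shard
        "TP_ATTN_FULL": ["ab", "ab", "cd", "cd"],          # full data of the rank's attn TP group
        "FULL":         ["abcd", "abcd", "abcd", "abcd"],  # every rank holds the full batch
    }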
@@ -82,7 +95,7 @@ class LayerScatterModes:

     @classmethod
     def _compute_layer_input_mode(cls, context: _LayerModeComputationContext):
         if context.layer_id == 0:
-            return ScatterMode.TP_ATTN_FULL
+            return ScatterMode.model_input_output()
         return cls._compute_layer_output_mode(context.previous_layer())

     @classmethod
@@ -113,7 +126,7 @@ class LayerScatterModes:
     def _compute_layer_output_mode(cls, context: _LayerModeComputationContext):
         mlp_mode = cls._compute_mlp_mode(context)
         if context.layer_id == context.num_layers - 1:
-            return ScatterMode.TP_ATTN_FULL
+            return ScatterMode.model_input_output()
         if mlp_mode == ScatterMode.SCATTERED:
             return ScatterMode.SCATTERED
         if mlp_mode == ScatterMode.FULL:
@@ -136,30 +149,14 @@ class LayerCommunicator:
         self.input_layernorm = input_layernorm
         self.post_attention_layernorm = post_attention_layernorm

-        self.attn_tp_rank = get_attention_tp_rank()
-        self.attn_tp_size = get_attention_tp_size()
-        self.local_attn_dp_size = get_local_attention_dp_size()
-        self.tp_size = get_tensor_model_parallel_world_size()
-        self.process_group_sizes = {
-            ScatterMode.SCATTERED: 1,
-            ScatterMode.TP_ATTN_FULL: self.attn_tp_size,
-            ScatterMode.FULL: self.tp_size,
-        }
-        self._context = _Context(
-            process_group_sizes=self.process_group_sizes,
-            attn_tp_rank=self.attn_tp_rank,
-            attn_tp_size=self.attn_tp_size,
-            local_attn_dp_size=self.local_attn_dp_size,
-            tp_size=self.tp_size,
-        )
-        self._communicate_simple_fn = _CommunicateSimpleFn.get_fn(
+        self._context = CommunicateContext.init_new()
+        self._communicate_simple_fn = CommunicateSimpleFn.get_fn(
             input_mode=self.layer_scatter_modes.layer_input_mode,
             output_mode=self.layer_scatter_modes.attn_mode,
             context=self._context,
         )
         self._communicate_with_all_reduce_and_layer_norm_fn = (
-            _CommunicateWithAllReduceAndLayerNormFn.get_fn(
+            CommunicateWithAllReduceAndLayerNormFn.get_fn(
                 hidden_states_input_mode=self.layer_scatter_modes.attn_mode,
                 residual_input_mode=self.layer_scatter_modes.layer_input_mode,
                 hidden_states_output_mode=self.layer_scatter_modes.mlp_mode,
@@ -168,7 +165,7 @@ class LayerCommunicator:
             )
         )
         self._communicate_summable_tensor_pair_fn = (
-            _CommunicateSummableTensorPairFn.get_fn(
+            CommunicateSummableTensorPairFn.get_fn(
                 hidden_states_input_mode=self.layer_scatter_modes.mlp_mode,
                 residual_input_mode=self.layer_scatter_modes.middle_residual_mode,
                 output_mode=self.layer_scatter_modes.layer_output_mode,
@@ -228,7 +225,7 @@ class LayerCommunicator:

 @dataclass
-class _Context:
+class CommunicateContext:
     process_group_sizes: Dict["ScatterMode", int]
     attn_tp_rank: int
     attn_tp_size: int
@@ -238,21 +235,40 @@ class _Context:
     def is_same_group_size(self, a: "ScatterMode", b: "ScatterMode"):
         return self.process_group_sizes[a] == self.process_group_sizes[b]

+    @classmethod
+    def init_new(cls):
+        attn_tp_rank = get_attention_tp_rank()
+        attn_tp_size = get_attention_tp_size()
+        local_attn_dp_size = get_local_attention_dp_size()
+        tp_size = get_tensor_model_parallel_world_size()
+        process_group_sizes = {
+            ScatterMode.SCATTERED: 1,
+            ScatterMode.TP_ATTN_FULL: attn_tp_size,
+            ScatterMode.FULL: tp_size,
+        }
+        return cls(
+            process_group_sizes=process_group_sizes,
+            attn_tp_rank=attn_tp_rank,
+            attn_tp_size=attn_tp_size,
+            local_attn_dp_size=local_attn_dp_size,
+            tp_size=tp_size,
+        )
+

-class _CommunicateSimpleFn:
+class CommunicateSimpleFn:
     @staticmethod
     def get_fn(
         input_mode: ScatterMode,
         output_mode: ScatterMode,
-        context: _Context,
+        context: CommunicateContext,
     ):
         if context.is_same_group_size(input_mode, output_mode):
-            return _CommunicateSimpleFn._trivial
+            return CommunicateSimpleFn._trivial

         if (input_mode == ScatterMode.SCATTERED) and (
             output_mode == ScatterMode.TP_ATTN_FULL
         ):
-            return _CommunicateSimpleFn._scattered_to_tp_attn_full
+            return CommunicateSimpleFn._scattered_to_tp_attn_full

         raise NotImplementedError(f"{input_mode=} {output_mode=}")
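As a usage sketch (illustrative, not from the diff): constructing the context by hand shows how is_same_group_size drives the dispatch in get_fn. In the refactored code the constructor normally obtains the context via CommunicateContext.init_new(), which reads the sizes from the distributed state.

    # Illustrative only: with plain TP=4 and no DP attention, attn_tp_size == tp_size == 4,
    # so TP_ATTN_FULL and FULL share a process-group size and get_fn picks the no-op path.
    ctx = CommunicateContext(
        process_group_sizes={
            ScatterMode.SCATTERED: 1,
            ScatterMode.TP_ATTN_FULL: 4,
            ScatterMode.FULL: 4,
        },
        attn_tp_rank=0,
        attn_tp_size=4,
        local_attn_dp_size=1,
        tp_size=4,
    )
    assert ctx.is_same_group_size(ScatterMode.TP_ATTN_FULL, ScatterMode.FULL)
    fn = CommunicateSimpleFn.get_fn(
        input_mode=ScatterMode.TP_ATTN_FULL,
        output_mode=ScatterMode.FULL,
        context=ctx,
    )
    # fn is CommunicateSimpleFn._trivial, i.e. hidden_states pass through unchanged.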
@@ -260,7 +276,7 @@ class _CommunicateSimpleFn:
     def _trivial(
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
-        context: _Context,
+        context: CommunicateContext,
     ) -> torch.Tensor:
         return hidden_states
@@ -268,7 +284,7 @@ class _CommunicateSimpleFn:
     def _scattered_to_tp_attn_full(
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
-        context: _Context,
+        context: CommunicateContext,
     ) -> torch.Tensor:
         hidden_states, local_hidden_states = (
             forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
@@ -281,7 +297,7 @@ class _CommunicateSimpleFn:
         return hidden_states


-class _CommunicateWithAllReduceAndLayerNormFn:
+class CommunicateWithAllReduceAndLayerNormFn:
     """Besides communication, needs to
     1. All reduce in tp_attn_group on hidden_states
     2. Apply layer norm
@@ -293,7 +309,7 @@ class _CommunicateWithAllReduceAndLayerNormFn:
         residual_input_mode: ScatterMode,
         hidden_states_output_mode: ScatterMode,
         residual_output_mode: ScatterMode,
-        context: _Context,
+        context: CommunicateContext,
     ):
         if (
@@ -303,7 +319,7 @@ class _CommunicateWithAllReduceAndLayerNormFn:
             and context.is_same_group_size(residual_input_mode, residual_output_mode)
             and context.attn_tp_size == 1
         ):
-            return _CommunicateWithAllReduceAndLayerNormFn._simple
+            return CommunicateWithAllReduceAndLayerNormFn._simple

         if (
             (hidden_states_input_mode == ScatterMode.TP_ATTN_FULL)
@@ -311,7 +327,7 @@ class _CommunicateWithAllReduceAndLayerNormFn:
             and (hidden_states_output_mode == ScatterMode.FULL)
             and (residual_output_mode == ScatterMode.TP_ATTN_FULL)
         ):
-            return _CommunicateWithAllReduceAndLayerNormFn._gather_hidden_states
+            return CommunicateWithAllReduceAndLayerNormFn._gather_hidden_states

         if (
             (hidden_states_input_mode == ScatterMode.TP_ATTN_FULL)
@@ -322,7 +338,7 @@ class _CommunicateWithAllReduceAndLayerNormFn:
             and (residual_output_mode == ScatterMode.SCATTERED)
         ):
             return partial(
-                _CommunicateWithAllReduceAndLayerNormFn._scatter_hidden_states_and_residual,
+                CommunicateWithAllReduceAndLayerNormFn._scatter_hidden_states_and_residual,
                 residual_input_mode=residual_input_mode,
             )
@@ -336,7 +352,7 @@ class _CommunicateWithAllReduceAndLayerNormFn:
         residual: torch.Tensor,
         forward_batch: ForwardBatch,
         layernorm: torch.nn.Module,
-        context: _Context,
+        context: CommunicateContext,
     ):
         # TODO move these `if shape != 0` into LayerNorm itself
         if hidden_states.shape[0] != 0:
@@ -349,7 +365,7 @@ class _CommunicateWithAllReduceAndLayerNormFn:
         residual: torch.Tensor,
         forward_batch: ForwardBatch,
         layernorm: torch.nn.Module,
-        context: _Context,
+        context: CommunicateContext,
     ):
         if context.local_attn_dp_size != 1:
             if context.attn_tp_rank == 0:
@@ -373,7 +389,7 @@ class _CommunicateWithAllReduceAndLayerNormFn:
         residual: torch.Tensor,
         forward_batch: ForwardBatch,
         layernorm: torch.nn.Module,
-        context: _Context,
+        context: CommunicateContext,
         *,
         residual_input_mode,
     ):
@@ -387,35 +403,50 @@ class _CommunicateWithAllReduceAndLayerNormFn:
         return hidden_states, residual


-class _CommunicateSummableTensorPairFn:
+class CommunicateSummableTensorPairFn:
+    """It is allowed to make (hidden_states, residual) := (hidden_states + residual, None) if needed."""
+
+    @classmethod
+    def execute(
+        cls,
+        hidden_states_input_mode,
+        residual_input_mode,
+        output_mode,
+        context,
+        **kwargs,
+    ):
+        return cls.get_fn(
+            hidden_states_input_mode=hidden_states_input_mode,
+            residual_input_mode=residual_input_mode,
+            output_mode=output_mode,
+            context=context,
+        )(context=context, **kwargs)
+
     @staticmethod
     def get_fn(
         hidden_states_input_mode: ScatterMode,
         residual_input_mode: ScatterMode,
         output_mode: ScatterMode,
-        context: _Context,
+        context: CommunicateContext,
     ):
-        """It is allowed to make (hidden_states, residual) := (hidden_states + residual, None) if needed."""
         if context.is_same_group_size(
             hidden_states_input_mode, output_mode
         ) and context.is_same_group_size(residual_input_mode, output_mode):
-            return _CommunicateSummableTensorPairFn._trivial
+            return CommunicateSummableTensorPairFn._trivial

         if (
             (hidden_states_input_mode == ScatterMode.FULL)
             and (residual_input_mode == ScatterMode.TP_ATTN_FULL)
             and (output_mode == ScatterMode.TP_ATTN_FULL)
         ):
-            return _CommunicateSummableTensorPairFn._scatter_hidden_states
+            return CommunicateSummableTensorPairFn._scatter_hidden_states

         if (
             (hidden_states_input_mode == ScatterMode.SCATTERED)
             and (residual_input_mode == ScatterMode.SCATTERED)
             and (output_mode == ScatterMode.TP_ATTN_FULL)
         ):
-            return _CommunicateSummableTensorPairFn._gather
+            return CommunicateSummableTensorPairFn._gather

         raise NotImplementedError(
             f"{hidden_states_input_mode=} {residual_input_mode=} {output_mode=}"
@@ -426,7 +457,7 @@ class _CommunicateSummableTensorPairFn:
         hidden_states: torch.Tensor,
         residual: torch.Tensor,
         forward_batch: ForwardBatch,
-        context: _Context,
+        context: CommunicateContext,
     ):
         return hidden_states, residual
@@ -435,7 +466,7 @@ class _CommunicateSummableTensorPairFn:
         hidden_states: torch.Tensor,
         residual: torch.Tensor,
         forward_batch: ForwardBatch,
-        context: _Context,
+        context: CommunicateContext,
     ):
         # TODO(ch-wan): use reduce-scatter in MLP to avoid this scatter
         # important: forward_batch.gathered_buffer is used both after scatter and after gather.
@@ -452,7 +483,7 @@ class _CommunicateSummableTensorPairFn:
         hidden_states: torch.Tensor,
         residual: torch.Tensor,
         forward_batch: ForwardBatch,
-        context: _Context,
+        context: CommunicateContext,
     ):
         hidden_states += residual
         residual = None