sglang · Commits · 3c2274fb

Commit 3c2274fb (Unverified)
Authored by Cheng Wan on Jun 15, 2025 · committed via GitHub on Jun 15, 2025
Parent: d2679f51

Implement gather before attn (#6378)
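Judging from the diff alone, the change extends CommunicateWithAllReduceAndLayerNormFn so that a residual arriving in ScatterMode.SCATTERED (not only TP_ATTN_FULL) can take the gather-before-attention path. The handler is renamed to _gather_hidden_states_and_residual, receives the residual's input mode via functools.partial at dispatch time, and, when attn_tp_size > 1, all-gathers the scattered residual across the attention tensor-parallel group before the existing add/layer-norm logic runs.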
Showing 1 changed file with 23 additions and 5 deletions.
python/sglang/srt/layers/communicator.py (+23, -5)
@@ -226,13 +226,13 @@ class LayerCommunicator:
 
 @dataclass
 class CommunicateContext:
-    process_group_sizes: Dict["ScatterMode", int]
+    process_group_sizes: Dict[ScatterMode, int]
     attn_tp_rank: int
     attn_tp_size: int
     local_attn_dp_size: int
     tp_size: int
 
-    def is_same_group_size(self, a: "ScatterMode", b: "ScatterMode"):
+    def is_same_group_size(self, a: ScatterMode, b: ScatterMode):
         return self.process_group_sizes[a] == self.process_group_sizes[b]
 
     @classmethod
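For readers outside the codebase, a minimal runnable sketch of the dataclass above may help. Only the CommunicateContext fields and is_same_group_size are taken from the diff; the ScatterMode enum body and the example sizes are assumptions made for illustration (the real ScatterMode is presumably defined or imported elsewhere in the file).

# Minimal sketch, not sglang code: ScatterMode's members mirror the ones used in
# the diff, but its definition and the example sizes are assumptions.
from dataclasses import dataclass
from enum import Enum, auto
from typing import Dict


class ScatterMode(Enum):
    SCATTERED = auto()     # each rank holds only its own tokens (group size 1)
    TP_ATTN_FULL = auto()  # replicated across the attention TP group
    FULL = auto()          # replicated across the full TP group


@dataclass
class CommunicateContext:
    process_group_sizes: Dict[ScatterMode, int]
    attn_tp_rank: int
    attn_tp_size: int
    local_attn_dp_size: int
    tp_size: int

    def is_same_group_size(self, a: ScatterMode, b: ScatterMode) -> bool:
        # Two modes describe equivalent layouts when their process groups have equal size.
        return self.process_group_sizes[a] == self.process_group_sizes[b]


# Example: attn_tp_size=2 inside a tp_size=8 deployment (numbers are made up).
ctx = CommunicateContext(
    process_group_sizes={
        ScatterMode.SCATTERED: 1,
        ScatterMode.TP_ATTN_FULL: 2,
        ScatterMode.FULL: 8,
    },
    attn_tp_rank=0,
    attn_tp_size=2,
    local_attn_dp_size=4,
    tp_size=8,
)
print(ctx.is_same_group_size(ScatterMode.TP_ATTN_FULL, ScatterMode.FULL))  # False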
@@ -244,6 +244,7 @@ class CommunicateContext:
         process_group_sizes = {
             ScatterMode.SCATTERED: 1,
             ScatterMode.TP_ATTN_FULL: attn_tp_size,
+            # TODO: support --moe-dense-tp-size > 1
             ScatterMode.FULL: tp_size,
         }
         return cls(
@@ -323,11 +324,16 @@ class CommunicateWithAllReduceAndLayerNormFn:
 
         if (
             (hidden_states_input_mode == ScatterMode.TP_ATTN_FULL)
-            and (residual_input_mode == ScatterMode.TP_ATTN_FULL)
+            and (
+                residual_input_mode in [ScatterMode.SCATTERED, ScatterMode.TP_ATTN_FULL]
+            )
             and (hidden_states_output_mode == ScatterMode.FULL)
             and (residual_output_mode == ScatterMode.TP_ATTN_FULL)
         ):
-            return CommunicateWithAllReduceAndLayerNormFn._gather_hidden_states
+            return partial(
+                CommunicateWithAllReduceAndLayerNormFn._gather_hidden_states_and_residual,
+                residual_input_mode=residual_input_mode,
+            )
 
         if (
             (hidden_states_input_mode == ScatterMode.TP_ATTN_FULL)
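The return value switches from a bare function to functools.partial so callers can keep invoking every dispatch branch with the same argument list while residual_input_mode is fixed at selection time. A tiny sketch of that pattern, with a simplified, made-up handler signature:

# Sketch of the dispatch-with-partial pattern only; the handler below is a stand-in,
# not sglang's _gather_hidden_states_and_residual.
from functools import partial


def _gather_hidden_states_and_residual(hidden_states, residual, *, residual_input_mode):
    return f"gather hidden_states + residual (residual_input_mode={residual_input_mode})"


def get_fn(residual_input_mode):
    # Bind the mode now; callers later use the common (hidden_states, residual) signature.
    return partial(
        _gather_hidden_states_and_residual,
        residual_input_mode=residual_input_mode,
    )


fn = get_fn("SCATTERED")
print(fn("h", "r"))  # gather hidden_states + residual (residual_input_mode=SCATTERED)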
@@ -360,13 +366,25 @@ class CommunicateWithAllReduceAndLayerNormFn:
         return hidden_states, residual
 
     @staticmethod
-    def _gather_hidden_states(
+    def _gather_hidden_states_and_residual(
         hidden_states: torch.Tensor,
         residual: torch.Tensor,
         forward_batch: ForwardBatch,
         layernorm: torch.nn.Module,
         context: CommunicateContext,
+        *,
+        residual_input_mode,
     ):
+        if residual_input_mode == ScatterMode.SCATTERED and context.attn_tp_size > 1:
+            residual, local_residual = (
+                forward_batch.gathered_buffer[
+                    : forward_batch.input_ids.shape[0]
+                ].clone(),
+                residual,
+            )
+            attn_tp_all_gather(
+                list(residual.tensor_split(context.attn_tp_size)), local_residual
+            )
         if context.local_attn_dp_size != 1:
             if context.attn_tp_rank == 0:
                 hidden_states += residual
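The new branch gathers a scattered residual in place: the full-size destination is carved out of forward_batch.gathered_buffer, tensor_split turns it into one view per attention-TP rank, and the all-gather writes each rank's shard straight into its view. Below is a standalone sketch of that buffer/view pattern using plain torch.distributed.all_gather; it is not sglang's attn_tp_all_gather, and the shapes, backend, file name, and launch command are assumptions.

# Standalone sketch of the gather-into-views pattern (not sglang code).
# Launch with e.g.:  torchrun --nproc-per-node=2 gather_sketch.py
import torch
import torch.distributed as dist

dist.init_process_group(backend="gloo")  # backend chosen for a CPU-only demo
attn_tp_size = dist.get_world_size()
rank = dist.get_rank()

tokens_per_rank, hidden_size = 4, 8
# Each rank's local shard of the residual (values are just the rank id).
local_residual = torch.full((tokens_per_rank, hidden_size), float(rank))

# Full-size destination, playing the role of forward_batch.gathered_buffer[:num_tokens].
residual = torch.empty(tokens_per_rank * attn_tp_size, hidden_size)

# tensor_split yields contiguous views into `residual`, so all_gather fills the
# full buffer in place -- no extra concatenation or copy afterwards.
dist.all_gather(list(residual.tensor_split(attn_tp_size)), local_residual)

print(f"rank {rank}: gathered first column = {residual[:, 0].tolist()}")
dist.destroy_process_group()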