sglang · commit 2e7ab862 (unverified)
Authored by Xiaoyu Zhang on Jul 09, 2025; committed via GitHub on Jul 08, 2025
Fix illegal memory in trtllm allreduce fusion (#7864)
Parent: 51ae4030

Showing 3 changed files with 8 additions and 6 deletions.
python/sglang/srt/layers/communicator.py            +3 -1
python/sglang/srt/layers/flashinfer_comm_fusion.py  +3 -3
python/sglang/srt/layers/layernorm.py               +2 -2
python/sglang/srt/layers/communicator.py

@@ -402,12 +402,14 @@ class CommunicateWithAllReduceAndLayerNormFn:
             if hidden_states.shape[0] != 0:
                 hidden_states = layernorm(hidden_states)
         else:
+            # According to the discussion in https://github.com/flashinfer-ai/flashinfer/issues/1223#issuecomment-3047256465
+            # We set the max token num to 128 for allreduce fusion with min-latency case(use_oneshot=True).
             if (
                 _is_sm100_supported
                 and _is_flashinfer_available
                 and hasattr(layernorm, "forward_with_allreduce_fusion")
                 and global_server_args_dict["enable_flashinfer_allreduce_fusion"]
-                and hidden_states.shape[0] <= 1024
+                and hidden_states.shape[0] <= 128
             ):
                 hidden_states, residual = layernorm.forward_with_allreduce_fusion(
                     hidden_states, residual
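The tightened `<= 128` bound matches the workspace that the oneshot (min-latency) kernel runs against; batches larger than that workspace are what triggered the illegal memory access. Below is a minimal sketch of what the guarded branch computes, assuming standard fused residual-add + RMSNorm semantics; `MAX_ONESHOT_TOKENS`, `rmsnorm_with_residual_reference`, and `fits_oneshot_workspace` are illustrative names, not sglang APIs.

import torch

# Illustrative constant; the diff writes the literal 128 at the call site.
MAX_ONESHOT_TOKENS = 128


def rmsnorm_with_residual_reference(
    x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
):
    """Unfused reference for the fused kernel's math: add the residual,
    then apply RMSNorm. Returns (normalized_output, updated_residual)."""
    residual = x + residual
    variance = residual.float().pow(2).mean(-1, keepdim=True)
    out = (residual.float() * torch.rsqrt(variance + eps)).to(x.dtype) * weight
    return out, residual


def fits_oneshot_workspace(hidden_states: torch.Tensor) -> bool:
    # The guard from the diff: only batches that fit the oneshot workspace may
    # take the fused path; larger batches must use the unfused implementation.
    return hidden_states.shape[0] <= MAX_ONESHOT_TOKENS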
python/sglang/srt/layers/flashinfer_comm_fusion.py

@@ -92,7 +92,7 @@ _workspace_manager = FlashInferWorkspaceManager()


 def ensure_workspace_initialized(
-    max_token_num: int = 1024, hidden_dim: int = 4096, use_fp32_lamport: bool = False
+    max_token_num: int = 128, hidden_dim: int = 4096, use_fp32_lamport: bool = False
 ):
     """Ensure workspace is initialized"""
     if not is_flashinfer_available() or _flashinfer_comm is None:

@@ -119,12 +119,12 @@ def ensure_workspace_initialized(
     return _workspace_manager.initialized


-def flashinfer_allreduce_add_rmsnorm(
+def flashinfer_allreduce_residual_rmsnorm(
     input_tensor: torch.Tensor,
     residual: torch.Tensor,
     weight: torch.Tensor,
     eps: float = 1e-6,
-    max_token_num: int = 1024,
+    max_token_num: int = 128,
     use_oneshot: bool = True,
     trigger_completion_at_end: bool = False,
     fp32_acc: bool = False,
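The two defaults drop from 1024 to 128 together because the workspace that `ensure_workspace_initialized` allocates is sized from `max_token_num`; a caller passing more tokens than the workspace was sized for reads and writes out of bounds. A hedged usage sketch follows, assuming the helper returns an `(output, residual)` pair with `None` entries when the fused path is unavailable; that contract is inferred from the call site in `layernorm.py`, not stated in this hunk.

import torch

from sglang.srt.layers.flashinfer_comm_fusion import (
    ensure_workspace_initialized,
    flashinfer_allreduce_residual_rmsnorm,
)


def try_fused(x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor):
    # Size the workspace and the call with the same bound; a mismatch between
    # the two is exactly the illegal-memory bug this commit fixes.
    if not ensure_workspace_initialized(max_token_num=128, hidden_dim=x.shape[-1]):
        return None  # flashinfer unavailable; caller uses the unfused path
    out, new_residual = flashinfer_allreduce_residual_rmsnorm(
        input_tensor=x,
        residual=residual,
        weight=weight,
        eps=1e-6,
        max_token_num=128,  # must not exceed the workspace size
        use_oneshot=True,   # min-latency path the 128-token cap applies to
    )
    if out is None:
        return None  # fused kernel declined; fall back to unfused
    return out, new_residual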
python/sglang/srt/layers/layernorm.py

@@ -174,11 +174,11 @@ class RMSNorm(CustomOp):
         if residual is not None:
             from sglang.srt.distributed import get_tensor_model_parallel_world_size
             from sglang.srt.layers.flashinfer_comm_fusion import (
-                flashinfer_allreduce_add_rmsnorm,
+                flashinfer_allreduce_residual_rmsnorm,
             )

             if get_tensor_model_parallel_world_size() > 1:
-                fused_result = flashinfer_allreduce_add_rmsnorm(
+                fused_result = flashinfer_allreduce_residual_rmsnorm(
                     input_tensor=x,
                     residual=residual,
                     weight=self.weight,
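The rename from `flashinfer_allreduce_add_rmsnorm` to `flashinfer_allreduce_residual_rmsnorm` describes the fusion more precisely, since the residual is both an input and an output. For completeness, a sketch of how the renamed helper slots into `RMSNorm`'s forward path, paraphrased from the context lines above rather than copied from sglang; `self.variance_epsilon`, the `forward_native` fallback, and the `(None, None)`-on-failure contract are assumptions, not shown in this hunk.

def forward_with_fused_allreduce(self, x, residual):
    # Paraphrase of the call site above; not the exact sglang method.
    from sglang.srt.distributed import get_tensor_model_parallel_world_size
    from sglang.srt.layers.flashinfer_comm_fusion import (
        flashinfer_allreduce_residual_rmsnorm,
    )

    if residual is not None and get_tensor_model_parallel_world_size() > 1:
        fused_out, fused_residual = flashinfer_allreduce_residual_rmsnorm(
            input_tensor=x,
            residual=residual,
            weight=self.weight,
            eps=self.variance_epsilon,  # assumed attribute name
        )
        # Assumed contract: None signals the fused kernel was skipped.
        if fused_out is not None:
            return fused_out, fused_residual
    # Assumed fallback: the regular unfused RMSNorm implementation.
    return self.forward_native(x, residual)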