zhaoyu6 / sglang · Commit 2e7ab862
"vscode:/vscode.git/clone" did not exist on "244f34c24e0a56433c2a0c996994d5271af41159"
Unverified commit 2e7ab862, authored Jul 09, 2025 by Xiaoyu Zhang, committed by GitHub Jul 08, 2025

Fix illegal memory in trtllm allreduce fusion (#7864)
Parent: 51ae4030

Showing 3 changed files with 8 additions and 6 deletions:
python/sglang/srt/layers/communicator.py (+3, -1)
python/sglang/srt/layers/flashinfer_comm_fusion.py (+3, -3)
python/sglang/srt/layers/layernorm.py (+2, -2)
python/sglang/srt/layers/communicator.py
@@ -402,12 +402,14 @@ class CommunicateWithAllReduceAndLayerNormFn:
             if hidden_states.shape[0] != 0:
                 hidden_states = layernorm(hidden_states)
         else:
+            # According to the discussion in https://github.com/flashinfer-ai/flashinfer/issues/1223#issuecomment-3047256465
+            # We set the max token num to 128 for allreduce fusion with min-latency case(use_oneshot=True).
             if (
                 _is_sm100_supported
                 and _is_flashinfer_available
                 and hasattr(layernorm, "forward_with_allreduce_fusion")
                 and global_server_args_dict["enable_flashinfer_allreduce_fusion"]
-                and hidden_states.shape[0] <= 1024
+                and hidden_states.shape[0] <= 128
             ):
                 hidden_states, residual = layernorm.forward_with_allreduce_fusion(
                     hidden_states, residual
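The hunk above tightens the gate on the one-shot fused path from 1024 tokens to 128. A minimal sketch of that dispatch rule, for context; the names `MAX_ONESHOT_TOKENS` and `should_use_fused_allreduce` are hypothetical, not sglang APIs:

```python
# Hypothetical sketch of the dispatch rule in this hunk: the one-shot
# fused allreduce + RMSNorm path is only taken when the token count fits
# the pre-allocated workspace, now capped at 128 tokens (assumption:
# mirrors the new cap in this commit).
MAX_ONESHOT_TOKENS = 128


def should_use_fused_allreduce(
    num_tokens: int,
    sm100_supported: bool,
    flashinfer_available: bool,
    fusion_enabled: bool,
) -> bool:
    """Return True when the one-shot fused allreduce path applies."""
    return (
        sm100_supported
        and flashinfer_available
        and fusion_enabled
        and num_tokens <= MAX_ONESHOT_TOKENS
    )


# Example: a 256-token batch now falls back to the unfused path,
# while a 64-token decode batch still takes the fused kernel.
assert not should_use_fused_allreduce(256, True, True, True)
assert should_use_fused_allreduce(64, True, True, True)
```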
python/sglang/srt/layers/flashinfer_comm_fusion.py
@@ -92,7 +92,7 @@ _workspace_manager = FlashInferWorkspaceManager()


 def ensure_workspace_initialized(
-    max_token_num: int = 1024, hidden_dim: int = 4096, use_fp32_lamport: bool = False
+    max_token_num: int = 128, hidden_dim: int = 4096, use_fp32_lamport: bool = False
 ):
     """Ensure workspace is initialized"""
     if not is_flashinfer_available() or _flashinfer_comm is None:

@@ -119,12 +119,12 @@ def ensure_workspace_initialized(
     return _workspace_manager.initialized


-def flashinfer_allreduce_add_rmsnorm(
+def flashinfer_allreduce_residual_rmsnorm(
     input_tensor: torch.Tensor,
     residual: torch.Tensor,
     weight: torch.Tensor,
     eps: float = 1e-6,
-    max_token_num: int = 1024,
+    max_token_num: int = 128,
     use_oneshot: bool = True,
     trigger_completion_at_end: bool = False,
     fp32_acc: bool = False,
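Both defaults above drop `max_token_num` from 1024 to 128 so the workspace that `ensure_workspace_initialized` allocates agrees with the caller-side cap in communicator.py. A toy sketch of why the two must agree; this `ToyWorkspace` is illustrative only, not sglang's `FlashInferWorkspaceManager`. A kernel that indexes a buffer sized for `max_token_num` tokens reads out of bounds when handed a larger batch, which matches the illegal-memory failure this commit guards against:

```python
import torch


class ToyWorkspace:
    """Hypothetical stand-in for a fusion workspace: a buffer sized for
    at most max_token_num tokens of hidden_dim each."""

    def __init__(self, max_token_num: int = 128, hidden_dim: int = 4096):
        self.max_token_num = max_token_num
        self.buffer = torch.empty(max_token_num, hidden_dim)

    def check(self, x: torch.Tensor) -> None:
        # A real fused kernel would index self.buffer directly; anything
        # past max_token_num rows would be an out-of-bounds access, so
        # reject oversized batches before launching it.
        if x.shape[0] > self.max_token_num:
            raise ValueError(
                f"batch of {x.shape[0]} tokens exceeds workspace "
                f"capacity {self.max_token_num}; take the unfused path"
            )


ws = ToyWorkspace()
ws.check(torch.randn(64, 4096))     # fits the workspace
# ws.check(torch.randn(256, 4096))  # would raise: too large for workspace
```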
python/sglang/srt/layers/layernorm.py
@@ -174,11 +174,11 @@ class RMSNorm(CustomOp):
         if residual is not None:
             from sglang.srt.distributed import get_tensor_model_parallel_world_size
             from sglang.srt.layers.flashinfer_comm_fusion import (
-                flashinfer_allreduce_add_rmsnorm,
+                flashinfer_allreduce_residual_rmsnorm,
             )

             if get_tensor_model_parallel_world_size() > 1:
-                fused_result = flashinfer_allreduce_add_rmsnorm(
+                fused_result = flashinfer_allreduce_residual_rmsnorm(
                     input_tensor=x,
                     residual=residual,
                     weight=self.weight,
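For reference, the unfused computation that `flashinfer_allreduce_residual_rmsnorm` fuses (after the tensor-parallel allreduce) is a residual add followed by RMSNorm. A plain-PyTorch sketch, assuming the usual add-then-normalize semantics and an `(output, updated_residual)` return ordering; this is a numerical reference, not the kernel implementation:

```python
import torch


def reference_residual_rmsnorm(
    x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
):
    """Single-rank reference for residual-add + RMSNorm. The fused kernel
    additionally performs the tensor-parallel allreduce on x before this step."""
    hidden = x + residual  # updated residual stream
    variance = hidden.pow(2).mean(-1, keepdim=True)
    normed = hidden * torch.rsqrt(variance + eps) * weight
    # Assumed ordering: (normalized output, updated residual).
    return normed, hidden


x = torch.randn(8, 4096)
res = torch.randn(8, 4096)
w = torch.ones(4096)
out, new_res = reference_residual_rmsnorm(x, res, w)
```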