sglang / Commits / 8e64140e

Commit 8e64140e (unverified), authored Jul 03, 2025 by Xiaoyu Zhang, committed via GitHub Jul 02, 2025
Parent: 82f021e2

[b200] support trt-llm allreduce fuse rms_norm_add kernel (#7621)

Showing 5 changed files with 253 additions and 2 deletions (+253, -2)
python/sglang/srt/layers/communicator.py            +18   -2
python/sglang/srt/layers/flashinfer_comm_fusion.py  +202  -0
python/sglang/srt/layers/layernorm.py               +26   -0
python/sglang/srt/managers/schedule_batch.py        +1    -0
python/sglang/srt/server_args.py                    +6    -0
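The commit replaces the separate tensor-parallel all-reduce followed by residual add and RMSNorm with FlashInfer's fused TRT-LLM kernel (pattern kARResidualRMSNorm) on SM100 (B200) GPUs, gated behind a new --enable-flashinfer-allreduce-fusion server flag. For reference, here is a minimal unfused sketch of what the fused pattern computes; torch.distributed.all_reduce stands in for sglang's tensor-parallel all-reduce, and the function name is illustrative, not part of the commit:

import torch
import torch.distributed as dist

def unfused_allreduce_add_rmsnorm(x, residual, weight, eps=1e-6):
    # 1) Sum the per-rank partial hidden states across tensor-parallel ranks.
    dist.all_reduce(x)
    # 2) Add the residual stream; this becomes the new residual output.
    residual_out = x + residual
    # 3) RMS-normalize the sum and scale by the learned weight (gamma).
    variance = residual_out.float().pow(2).mean(dim=-1, keepdim=True)
    norm_out = (residual_out.float() * torch.rsqrt(variance + eps)).to(x.dtype) * weight
    return norm_out, residual_out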
python/sglang/srt/layers/communicator.py

@@ -32,8 +32,13 @@ from sglang.srt.layers.dp_attention import (
    get_attention_tp_rank,
    get_attention_tp_size,
)
from sglang.srt.layers.utils import is_sm100_supported
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.utils import is_cuda, is_flashinfer_available

_is_flashinfer_available = is_flashinfer_available()
_is_sm100_supported = is_cuda() and is_sm100_supported()


class ScatterMode(Enum):
@@ -396,6 +401,17 @@ class CommunicateWithAllReduceAndLayerNormFn:
            dp_scatter(residual, hidden_states, forward_batch)
            if hidden_states.shape[0] != 0:
                hidden_states = layernorm(hidden_states)
        else:
            if (
                _is_sm100_supported
                and _is_flashinfer_available
                and hasattr(layernorm, "forward_with_allreduce_fusion")
                and global_server_args_dict["enable_flashinfer_allreduce_fusion"]
                and hidden_states.shape[0] <= 1024
            ):
                hidden_states, residual = layernorm.forward_with_allreduce_fusion(
                    hidden_states, residual
                )
            else:
                hidden_states = tensor_model_parallel_all_reduce(hidden_states)
                hidden_states, residual = layernorm(hidden_states, residual)
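Condensed, the gating above means the fused path is attempted only on an SM100-class GPU with flashinfer importable, when the RMSNorm layer exposes the fused forward, the server flag is enabled, and the batch holds at most 1024 tokens. A hypothetical helper (not part of the commit) restating the inline condition, using the same imports and module-level flags the diff introduces:

from sglang.srt.layers.utils import is_sm100_supported
from sglang.srt.utils import is_cuda, is_flashinfer_available

_is_flashinfer_available = is_flashinfer_available()
_is_sm100_supported = is_cuda() and is_sm100_supported()

def _use_flashinfer_fusion(layernorm, hidden_states, server_args) -> bool:
    """Hypothetical summary of the inline gating condition in the diff above."""
    return (
        _is_sm100_supported                                      # B200-class (SM100) GPU
        and _is_flashinfer_available                             # flashinfer is importable
        and hasattr(layernorm, "forward_with_allreduce_fusion")  # layer supports fusion
        and server_args["enable_flashinfer_allreduce_fusion"]    # server flag enabled
        and hidden_states.shape[0] <= 1024                       # at most 1024 tokens
    )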
python/sglang/srt/layers/flashinfer_comm_fusion.py (new file, mode 100644)

import logging
from typing import Tuple

import torch
import torch.distributed as dist

from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.utils import is_flashinfer_available

logger = logging.getLogger(__name__)

_flashinfer_comm = None
_workspace_manager = None

if is_flashinfer_available():
    try:
        import flashinfer.comm as comm

        _flashinfer_comm = comm
    except ImportError:
        logger.warning(
            "flashinfer.comm is not available, falling back to standard "
            "implementation"
        )


class FlashInferWorkspaceManager:
    def __init__(self):
        self.workspace_tensor = None
        self.ipc_handles = None
        self.world_size = None
        self.rank = None
        self.initialized = False

    def initialize(
        self,
        world_size: int,
        rank: int,
        max_token_num: int,
        hidden_dim: int,
        group=None,
        use_fp32_lamport: bool = False,
    ):
        """Initialize workspace"""
        if self.initialized and self.world_size == world_size:
            return

        if _flashinfer_comm is None:
            logger.warning(
                "FlashInfer comm not available, skipping workspace "
                "initialization"
            )
            return

        self.cleanup()

        self.ipc_handles, self.workspace_tensor = (
            comm.trtllm_create_ipc_workspace_for_all_reduce_fusion(
                rank,
                world_size,
                max_token_num,
                hidden_dim,
                group=group,
                use_fp32_lamport=use_fp32_lamport,
            )
        )

        self.world_size = world_size
        self.rank = rank
        self.initialized = True

        logger.info(
            f"FlashInfer workspace initialized for rank {rank}, "
            f"world_size {world_size}"
        )

    def cleanup(self):
        """Clean up workspace"""
        if self.initialized and self.ipc_handles is not None:
            try:
                _flashinfer_comm.trtllm_destroy_ipc_workspace_for_all_reduce(
                    self.ipc_handles, group=dist.group.WORLD
                )
            except Exception as e:
                logger.warning(f"Failed to cleanup FlashInfer workspace: {e}")
            finally:
                self.workspace_tensor = None
                self.ipc_handles = None
                self.initialized = False


_workspace_manager = FlashInferWorkspaceManager()


def ensure_workspace_initialized(
    max_token_num: int = 1024, hidden_dim: int = 4096, use_fp32_lamport: bool = False
):
    """Ensure workspace is initialized"""
    if not is_flashinfer_available() or _flashinfer_comm is None:
        return False

    world_size = get_tensor_model_parallel_world_size()
    if world_size <= 1:
        return False

    rank = dist.get_rank()

    if (
        not _workspace_manager.initialized
        or _workspace_manager.world_size != world_size
    ):
        _workspace_manager.initialize(
            world_size=world_size,
            rank=rank,
            max_token_num=max_token_num,
            hidden_dim=hidden_dim,
            use_fp32_lamport=use_fp32_lamport,
        )

    return _workspace_manager.initialized


def flashinfer_allreduce_add_rmsnorm(
    input_tensor: torch.Tensor,
    residual: torch.Tensor,
    weight: torch.Tensor,
    eps: float = 1e-6,
    max_token_num: int = 1024,
    use_oneshot: bool = True,
    trigger_completion_at_end: bool = False,
    fp32_acc: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Use FlashInfer's fused allreduce + residual + RMS norm operation

    Args:
        input_tensor: Input tensor that needs allreduce
        residual: Residual tensor
        weight: RMS norm weight
        eps: RMS norm epsilon
        max_token_num: Maximum token number
        use_oneshot: Whether to use oneshot mode
        trigger_completion_at_end: Whether to trigger completion at end
        fp32_acc: Whether to use fp32 precision

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: (norm_output, residual_output)
    """
    if not is_flashinfer_available() or _flashinfer_comm is None:
        logger.debug(
            "FlashInfer not available, falling back to standard "
            "implementation"
        )
        return None, None

    world_size = get_tensor_model_parallel_world_size()
    if world_size <= 1:
        logger.debug("Single GPU, no need for allreduce fusion")
        return None, None

    if not ensure_workspace_initialized(
        max_token_num=max_token_num,
        hidden_dim=input_tensor.shape[-1],
        use_fp32_lamport=(input_tensor.dtype == torch.float32),
    ):
        logger.debug("FlashInfer workspace not available")
        return None, None

    token_num, hidden_dim = input_tensor.shape

    residual_out = torch.empty_like(residual)
    norm_out = torch.empty_like(input_tensor)

    _flashinfer_comm.trtllm_allreduce_fusion(
        allreduce_in=input_tensor,
        world_size=world_size,
        world_rank=dist.get_rank(),
        token_num=token_num,
        hidden_dim=hidden_dim,
        workspace_ptrs=_workspace_manager.workspace_tensor,
        launch_with_pdl=True,
        use_oneshot=use_oneshot,
        trigger_completion_at_end=trigger_completion_at_end,
        fp32_acc=fp32_acc,
        pattern_code=(
            _flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNorm
        ),
        allreduce_out=None,
        residual_in=residual,
        residual_out=residual_out,
        norm_out=norm_out,
        quant_out=None,
        scale_out=None,
        rms_gamma=weight,
        rms_eps=eps,
        scale_factor=None,
        layout_code=None,
    )

    return norm_out, residual_out


def cleanup_flashinfer_workspace():
    global _workspace_manager
    if _workspace_manager is not None:
        _workspace_manager.cleanup()
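A hedged usage sketch of the new helper: the caller passes this rank's partial hidden states, the residual, and the RMSNorm weight, and keeps its unfused path when the function returns (None, None). The tensor names below are illustrative, not taken from the commit:

# Illustrative call site; `hidden_states`, `residual`, and `rmsnorm_weight`
# are assumed tensors, not names from the commit.
from sglang.srt.layers.flashinfer_comm_fusion import flashinfer_allreduce_add_rmsnorm

# hidden_states:  [token_num, hidden_dim] partial sums held by this TP rank
# residual:       [token_num, hidden_dim] residual stream
# rmsnorm_weight: [hidden_dim] RMSNorm gamma
norm_out, residual_out = flashinfer_allreduce_add_rmsnorm(
    input_tensor=hidden_states,
    residual=residual,
    weight=rmsnorm_weight,
    eps=1e-6,
)
if norm_out is None:
    # (None, None) signals the fused path is unavailable (single GPU, flashinfer
    # missing, or workspace not initialized); callers keep the unfused path.
    ...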
python/sglang/srt/layers/layernorm.py

@@ -163,6 +163,32 @@ class RMSNorm(CustomOp):
        else:
            return self.forward_native(x, residual)

    def forward_with_allreduce_fusion(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Forward method with allreduce fusion, prioritizing flashinfer fused operations
        """
        if residual is not None:
            from sglang.srt.distributed import get_tensor_model_parallel_world_size
            from sglang.srt.layers.flashinfer_comm_fusion import (
                flashinfer_allreduce_add_rmsnorm,
            )

            if get_tensor_model_parallel_world_size() > 1:
                fused_result = flashinfer_allreduce_add_rmsnorm(
                    input_tensor=x,
                    residual=residual,
                    weight=self.weight,
                    eps=self.variance_epsilon,
                )
                if fused_result[0] is not None:
                    return fused_result

        return self.forward(x, residual)


class GemmaRMSNorm(CustomOp):
    def __init__(
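Because forward_with_allreduce_fusion silently falls back to self.forward(x, residual), the fused and unfused paths should agree on the same inputs. A hypothetical parity check (not part of the commit), assuming a multi-GPU tensor-parallel run is already initialized, `rmsnorm` is an sglang RMSNorm instance, the all-reduce import follows communicator.py's usage, and the bf16 tolerances are illustrative:

import torch

from sglang.srt.distributed import tensor_model_parallel_all_reduce

x = torch.randn(8, 4096, dtype=torch.bfloat16, device="cuda")
residual = torch.randn_like(x)

# Fused: allreduce + residual add + RMSNorm in one kernel (when eligible).
fused_out, fused_res = rmsnorm.forward_with_allreduce_fusion(x.clone(), residual.clone())

# Reference: explicit all-reduce followed by the native add + RMSNorm.
ref = tensor_model_parallel_all_reduce(x.clone())
ref_out, ref_res = rmsnorm.forward_native(ref, residual.clone())

torch.testing.assert_close(fused_out, ref_out, rtol=2e-2, atol=2e-2)
torch.testing.assert_close(fused_res, ref_res, rtol=2e-2, atol=2e-2)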
python/sglang/srt/managers/schedule_batch.py

@@ -85,6 +85,7 @@ GLOBAL_SERVER_ARGS_KEYS = [
    "deepep_mode",
    "enable_ep_moe",
    "enable_flashinfer_moe",
    "enable_flashinfer_allreduce_fusion",
    "moe_dense_tp_size",
    "ep_dispatch_algorithm",
    "deepep_config",
python/sglang/srt/server_args.py

@@ -157,6 +157,7 @@ class ServerArgs:
    enable_ep_moe: bool = False
    enable_deepep_moe: bool = False
    enable_flashinfer_moe: bool = False
    enable_flashinfer_allreduce_fusion: bool = False
    deepep_mode: Optional[Literal["auto", "normal", "low_latency"]] = "auto"
    ep_num_redundant_experts: int = 0
    ep_dispatch_algorithm: Optional[Literal["static", "dynamic", "fake"]] = None

@@ -1206,6 +1207,11 @@ class ServerArgs:
        action="store_true",
        help="Enable FlashInfer CUTLASS MoE backend for modelopt_fp4 quant on Blackwell. Supports MoE-EP with --enable-ep-moe",
    )
    parser.add_argument(
        "--enable-flashinfer-allreduce-fusion",
        action="store_true",
        help="Enable FlashInfer allreduce fusion for Add_RMSNorm.",
    )
    parser.add_argument(
        "--enable-deepep-moe",
        action="store_true",
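A hedged round-trip sketch (not in the commit) showing the new flag landing on ServerArgs and, via the GLOBAL_SERVER_ARGS_KEYS entry above, becoming visible to the communicator through global_server_args_dict. "<your-model>" is a placeholder; equivalently, pass --enable-flashinfer-allreduce-fusion when launching the sglang server:

import argparse

from sglang.srt.server_args import ServerArgs

parser = argparse.ArgumentParser()
ServerArgs.add_cli_args(parser)
args = parser.parse_args(
    ["--model-path", "<your-model>", "--enable-flashinfer-allreduce-fusion"]
)
server_args = ServerArgs.from_cli_args(args)

# The dataclass field added in this commit defaults to False and is now set.
assert server_args.enable_flashinfer_allreduce_fusion is True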