sglang commit 44e86480 (unverified)
fuse allreduce and residual_rmsnorm (#8731)

Authored by Xiaoyu Zhang on Aug 12, 2025; committed via GitHub on Aug 11, 2025.
Parent commit: 8c07fabd

8 changed files with 134 additions and 58 deletions (+134, -58):
  python/sglang/srt/layers/communicator.py                +1   -1
  python/sglang/srt/layers/flashinfer_comm_fusion.py      +3   -3
  python/sglang/srt/layers/linear.py                      +1   -0
  python/sglang/srt/layers/moe/fused_moe_triton/layer.py  +5   -1
  python/sglang/srt/models/deepseek_v2.py                 +48  -33
  python/sglang/srt/models/glm4_moe.py                    +9   -9
  python/sglang/srt/models/gpt_oss.py                     +66  -10
  python/sglang/srt/server_args.py                        +1   -1

python/sglang/srt/layers/communicator.py

@@ -441,7 +441,7 @@ class CommunicateWithAllReduceAndLayerNormFn:
             and _is_flashinfer_available
             and hasattr(layernorm, "forward_with_allreduce_fusion")
             and global_server_args_dict["enable_flashinfer_allreduce_fusion"]
-            and hidden_states.shape[0] <= 128
+            and hidden_states.shape[0] <= 2048
         ):
             hidden_states, residual = layernorm.forward_with_allreduce_fusion(
                 hidden_states, residual
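
The condition above gates the fused path: when SM100 and FlashInfer are available, the server flag is set, and the batch is small enough, the all-reduce of the MLP output is folded into the next residual add + RMSNorm via layernorm.forward_with_allreduce_fusion; this commit raises the token cap for that path from 128 to 2048. As orientation only, the following is the unfused reference that the fused kernel has to match numerically. It is a plain-PyTorch sketch with a made-up helper name, not sglang or FlashInfer code.

# Reference semantics only (not the FlashInfer kernel): all-reduce + residual add + RMSNorm.
import torch
import torch.distributed as dist


def reference_allreduce_add_rmsnorm(partial_out, residual, weight, eps=1e-6):
    """Unfused baseline that the fused allreduce + residual_rmsnorm path replaces."""
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(partial_out)        # sum partial MLP outputs across TP ranks
    residual = residual + partial_out       # update the residual stream
    var = residual.float().pow(2).mean(dim=-1, keepdim=True)
    normed = residual.float() * torch.rsqrt(var + eps)
    return (normed * weight.float()).to(residual.dtype), residual
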
python/sglang/srt/layers/flashinfer_comm_fusion.py

@@ -92,7 +92,7 @@ _workspace_manager = FlashInferWorkspaceManager()

 def ensure_workspace_initialized(
-    max_token_num: int = 128, hidden_dim: int = 4096, use_fp32_lamport: bool = False
+    max_token_num: int = 2048, hidden_dim: int = 4096, use_fp32_lamport: bool = False
 ):
     """Ensure workspace is initialized"""
     if not is_flashinfer_available() or _flashinfer_comm is None:

@@ -124,8 +124,8 @@ def flashinfer_allreduce_residual_rmsnorm(
     residual: torch.Tensor,
     weight: torch.Tensor,
     eps: float = 1e-6,
-    max_token_num: int = 128,
-    use_oneshot: bool = True,
+    max_token_num: int = 2048,
+    use_oneshot: Optional[bool] = None,
     trigger_completion_at_end: bool = False,
     fp32_acc: bool = False,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
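
Both the workspace initializer and the fused op move their max_token_num default from 128 to 2048 to match the new cap in communicator.py, and use_oneshot becomes Optional[bool] = None so the one-shot choice is no longer hard-coded. Rough sizing only (assuming bf16 activations and the file's default hidden_dim=4096), the per-rank message the workspace now has to cover grows from about 1 MiB to about 16 MiB:

# Back-of-envelope sizing, not sglang code: bf16 activations assumed.
bytes_per_elem = 2            # bf16
hidden_dim = 4096             # default in ensure_workspace_initialized
old_cap, new_cap = 128, 2048  # max_token_num before/after this commit

for cap in (old_cap, new_cap):
    mib = cap * hidden_dim * bytes_per_elem / 2**20
    print(f"max_token_num={cap}: {mib:.0f} MiB per rank")  # 1 MiB -> 16 MiB
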
python/sglang/srt/layers/linear.py

@@ -1294,6 +1294,7 @@ class RowParallelLinear(LinearBase):
         with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
             output_parallel = self.quant_method.apply(self, input_parallel, bias=bias_)
             sm.tag(output_parallel)

         if self.reduce_results and self.tp_size > 1 and not skip_all_reduce:
             output = tensor_model_parallel_all_reduce(output_parallel)
         else:
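
RowParallelLinear keeps its skip_all_reduce escape hatch; the model code below passes it whenever the decoder layer has decided the reduction will instead happen inside the next fused allreduce + residual_rmsnorm. A toy sketch of that deferred-reduction pattern (illustrative only, not the sglang class):

# Toy sketch of the deferred-reduction pattern; not sglang's RowParallelLinear.
import torch
import torch.nn as nn
import torch.distributed as dist


class ToyRowParallelLinear(nn.Module):
    """Each TP rank holds a shard of the weight and produces a partial output."""

    def __init__(self, in_per_rank, out_features):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_features, in_per_rank) * 0.02)

    def forward(self, x_shard, skip_all_reduce=False):
        partial = x_shard @ self.weight.t()
        if not skip_all_reduce and dist.is_available() and dist.is_initialized():
            dist.all_reduce(partial)  # eager reduction: the default path
        # With skip_all_reduce=True the caller gets the unreduced partial sum and
        # relies on a later fused allreduce + residual_rmsnorm to do the reduction.
        return partial
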
python/sglang/srt/layers/moe/fused_moe_triton/layer.py

@@ -847,10 +847,14 @@ class FusedMoE(torch.nn.Module):
             )
             sm.tag(final_hidden_states)

+        final_hidden_states = final_hidden_states[
+            ..., :origin_hidden_states_dim
+        ].contiguous()
+
         if self.reduce_results and (self.moe_tp_size > 1 or self.moe_ep_size > 1):
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)

-        return final_hidden_states[..., :origin_hidden_states_dim].contiguous()
+        return final_hidden_states

     @classmethod
     def make_expert_params_mapping(
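
FusedMoE.forward now slices the expert output back to origin_hidden_states_dim before the tensor-parallel all-reduce instead of after it. Values are unchanged, since slicing commutes with the elementwise sum performed by the all-reduce, but when the expert kernels pad the hidden dimension the collective now moves only the original columns. A quick single-process check of the equivalence (the shapes below are made up for the demo):

# Equivalence check with simulated ranks: reduce-then-slice == slice-then-reduce.
import torch

torch.manual_seed(0)
num_ranks, tokens, padded_dim, origin_dim = 4, 8, 5120, 4096
per_rank = [torch.randn(tokens, padded_dim) for _ in range(num_ranks)]

reduce_then_slice = sum(per_rank)[..., :origin_dim]
slice_then_reduce = sum(t[..., :origin_dim] for t in per_rank)

assert torch.allclose(reduce_then_slice, slice_then_reduce, atol=1e-5)
# The second form is what the commit does; with these toy shapes it all-reduces
# about 20% fewer elements.
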
python/sglang/srt/models/deepseek_v2.py

@@ -212,7 +212,7 @@ class DeepseekV2MLP(nn.Module):
         self,
         x,
         forward_batch=None,
-        can_fuse_mlp_allreduce: bool = False,
+        should_allreduce_fusion: bool = False,
         use_reduce_scatter: bool = False,
     ):
         if (self.tp_size == 1) and x.shape[0] == 0:

@@ -221,7 +221,7 @@ class DeepseekV2MLP(nn.Module):
         gate_up, _ = self.gate_up_proj(x)
         x = self.act_fn(gate_up)
         x, _ = self.down_proj(
-            x, skip_all_reduce=can_fuse_mlp_allreduce or use_reduce_scatter
+            x, skip_all_reduce=should_allreduce_fusion or use_reduce_scatter
         )
         return x

@@ -448,7 +448,7 @@ class DeepseekV2MoE(nn.Module):
         self,
         hidden_states: torch.Tensor,
         forward_batch: Optional[ForwardBatch] = None,
-        can_fuse_mlp_allreduce: bool = False,
+        should_allreduce_fusion: bool = False,
         use_reduce_scatter: bool = False,
     ) -> torch.Tensor:
         if not self._enable_deepep_moe:

@@ -459,11 +459,11 @@ class DeepseekV2MoE(nn.Module):
                 and hidden_states.shape[0] <= DUAL_STREAM_TOKEN_THRESHOLD
             ):
                 return self.forward_normal_dual_stream(
-                    hidden_states, can_fuse_mlp_allreduce, use_reduce_scatter
+                    hidden_states, should_allreduce_fusion, use_reduce_scatter
                 )
             else:
                 return self.forward_normal(
-                    hidden_states, can_fuse_mlp_allreduce, use_reduce_scatter
+                    hidden_states, should_allreduce_fusion, use_reduce_scatter
                 )
         else:
             return self.forward_deepep(hidden_states, forward_batch)

@@ -471,7 +471,7 @@ class DeepseekV2MoE(nn.Module):
     def forward_normal_dual_stream(
         self,
         hidden_states: torch.Tensor,
-        can_fuse_mlp_allreduce: bool = False,
+        should_allreduce_fusion: bool = False,
         use_reduce_scatter: bool = False,
     ) -> torch.Tensor:

@@ -500,20 +500,20 @@ class DeepseekV2MoE(nn.Module):
             torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
             final_hidden_states = final_hidden_states_out
             sm.tag(final_hidden_states)
-        if self.tp_size > 1 and not can_fuse_mlp_allreduce and not use_reduce_scatter:
+        if self.tp_size > 1 and not should_allreduce_fusion and not use_reduce_scatter:
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
         return final_hidden_states

     def forward_normal(
         self,
         hidden_states: torch.Tensor,
-        can_fuse_mlp_allreduce: bool = False,
+        should_allreduce_fusion: bool = False,
         use_reduce_scatter: bool = False,
     ) -> torch.Tensor:
         if hasattr(self, "shared_experts") and use_intel_amx_backend(
             self.shared_experts.gate_up_proj
         ):
-            return self.forward_cpu(hidden_states, can_fuse_mlp_allreduce)
+            return self.forward_cpu(hidden_states, should_allreduce_fusion)

         shared_output = self._forward_shared_experts(hidden_states)
         # router_logits: (num_tokens, n_experts)

@@ -537,12 +537,14 @@ class DeepseekV2MoE(nn.Module):
             torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
             final_hidden_states = final_hidden_states_out
             sm.tag(final_hidden_states)
-        if self.tp_size > 1 and not can_fuse_mlp_allreduce and not use_reduce_scatter:
+        if self.tp_size > 1 and not should_allreduce_fusion and not use_reduce_scatter:
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
         return final_hidden_states

     def forward_cpu(
-        self, hidden_states: torch.Tensor, can_fuse_mlp_allreduce: bool = False
+        self,
+        hidden_states: torch.Tensor,
+        should_allreduce_fusion: bool = False,
     ) -> torch.Tensor:
         # router_logits: (num_tokens, n_experts)
         router_logits = self.gate(hidden_states)

@@ -593,7 +595,7 @@ class DeepseekV2MoE(nn.Module):
             None,  # a2_scale
             True,  # is_vnni
         )

-        if self.tp_size > 1 and not can_fuse_mlp_allreduce:
+        if self.tp_size > 1 and not should_allreduce_fusion:
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)

         return final_hidden_states

@@ -1842,6 +1844,8 @@ class DeepseekV2DecoderLayer(nn.Module):
             allow_reduce_scatter=True,
         )

+        self._fuse_allreduce_lookup_table = self._build_fuse_allreduce_lookup_table()
+
     def _is_layer_sparse(self, layer_id: int, is_nextn: bool) -> bool:
         return is_nextn or (
             self.config.n_routed_experts is not None

@@ -1850,27 +1854,18 @@ class DeepseekV2DecoderLayer(nn.Module):
         )

     def _should_fuse_mlp_allreduce_with_next_layer(self, forward_batch) -> bool:
-        """Check if MLP allreduce can be fused with next layer's add_rmsnorm"""
-        if (
-            self.layer_id == self.config.num_hidden_layers - 1
-            or get_tensor_model_parallel_world_size() <= 1
-        ):
-            return False
-
-        if not global_server_args_dict.get("enable_flashinfer_allreduce_fusion", False):
-            return False
-
-        if not _is_sm100_supported or not _is_flashinfer_available:
-            return False
-
-        if hasattr(forward_batch, "input_ids") and (
-            forward_batch.input_ids.shape[0] == 0
-            or forward_batch.input_ids.shape[0] > 128
-        ):
-            return False
-
-        return True
+        """Check if MLP allreduce can be fused with next layer's residual_rmsnorm"""
+        batch_size = (
+            forward_batch.input_ids.shape[0]
+            if hasattr(forward_batch, "input_ids")
+            else 0
+        )
+
+        if batch_size > 128:
+            return False
+
+        return self._fuse_allreduce_lookup_table.get(batch_size, False)

     def forward(
         self,

@@ -1896,7 +1891,7 @@ class DeepseekV2DecoderLayer(nn.Module):
             hidden_states, residual, forward_batch
         )

-        can_fuse_mlp_allreduce = (
+        should_allreduce_fusion = (
             self._should_fuse_mlp_allreduce_with_next_layer(forward_batch)
             and not (self.enable_dp_attention and self.speculative_algorithm.is_eagle())
             and not self.is_nextn

@@ -1907,13 +1902,13 @@ class DeepseekV2DecoderLayer(nn.Module):
             forward_batch
         )

         hidden_states = self.mlp(
-            hidden_states, forward_batch, can_fuse_mlp_allreduce, use_reduce_scatter
+            hidden_states, forward_batch, should_allreduce_fusion, use_reduce_scatter
         )

-        if can_fuse_mlp_allreduce:
+        if should_allreduce_fusion:
             hidden_states._sglang_needs_allreduce_fusion = True

-        if not can_fuse_mlp_allreduce:
+        if not should_allreduce_fusion:
             hidden_states, residual = self.layer_communicator.postprocess_layer(
                 hidden_states, residual, forward_batch
             )

@@ -1990,6 +1985,26 @@ class DeepseekV2DecoderLayer(nn.Module):
         )

         return output

+    def _build_fuse_allreduce_lookup_table(self):
+        static_conditions_met = (
+            self.layer_id != self.config.num_hidden_layers - 1
+            and get_tensor_model_parallel_world_size() > 1
+            and global_server_args_dict.get("enable_flashinfer_allreduce_fusion", False)
+            and _is_sm100_supported
+            and _is_flashinfer_available
+        )
+
+        if not static_conditions_met:
+            return {}
+
+        lookup_table = {}
+        for batch_size in range(129):  # 0 to 128
+            is_last_layer = self.layer_id == self.config.num_hidden_layers - 1
+            should_fuse = batch_size > 0 and batch_size <= 128 and not is_last_layer
+            lookup_table[batch_size] = should_fuse
+
+        return lookup_table
+

 class DeepseekV2Model(nn.Module):
     fall_back_to_pt_during_load = False
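
The DeepseekV2 changes are mostly the rename from can_fuse_mlp_allreduce to should_allreduce_fusion, plus a cheaper gate: the static conditions (not the last layer, TP world size > 1, the --enable-flashinfer-allreduce-fusion flag, SM100 and FlashInfer availability) are baked into _fuse_allreduce_lookup_table once at construction, and the per-forward check becomes a dictionary lookup keyed by batch size. A standalone rendering of the same decision, with the static inputs passed in explicitly (toy harness, not the sglang classes; the sample values are invented):

# Standalone rendering of the lookup-table gating; toy harness, not sglang code.
def build_fuse_allreduce_lookup_table(layer_id, num_hidden_layers, tp_size,
                                      fusion_flag, sm100_ok, flashinfer_ok):
    static_ok = (
        layer_id != num_hidden_layers - 1   # last layer has no "next layer" to fuse into
        and tp_size > 1
        and fusion_flag
        and sm100_ok
        and flashinfer_ok
    )
    if not static_ok:
        return {}
    # batch sizes 1..128 fuse; 0 and anything > 128 (absent from the table) do not
    return {bs: bs > 0 for bs in range(129)}


table = build_fuse_allreduce_lookup_table(
    layer_id=3, num_hidden_layers=61, tp_size=8,
    fusion_flag=True, sm100_ok=True, flashinfer_ok=True,
)
assert table.get(64, False) is True      # small decode batch: fuse
assert table.get(0, False) is False      # empty batch: don't
assert table.get(4096, False) is False   # large prefill batch: falls back
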
python/sglang/srt/models/glm4_moe.py

@@ -154,13 +154,13 @@ class Glm4MoeMLP(nn.Module):
         )
         self.act_fn = SiluAndMul()

-    def forward(self, x, forward_batch=None, can_fuse_mlp_allreduce=False):
+    def forward(self, x, forward_batch=None, should_allreduce_fusion=False):
         if (self.tp_size == 1) and x.shape[0] == 0:
             return x

         gate_up, _ = self.gate_up_proj(x)
         x = self.act_fn(gate_up)
-        x, _ = self.down_proj(x, skip_all_reduce=can_fuse_mlp_allreduce)
+        x, _ = self.down_proj(x, skip_all_reduce=should_allreduce_fusion)
         return x

@@ -529,7 +529,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
     def forward_normal_dual_stream(
         self,
         hidden_states: torch.Tensor,
-        can_fuse_mlp_allreduce: bool = False,
+        should_allreduce_fusion: bool = False,
         use_reduce_scatter: bool = False,
     ) -> torch.Tensor:

@@ -553,7 +553,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
         if self.ep_size > 1:
             if (
                 self.tp_size > 1
-                and not can_fuse_mlp_allreduce
+                and not should_allreduce_fusion
                 and not use_reduce_scatter
             ):
                 final_hidden_states = tensor_model_parallel_all_reduce(

@@ -564,7 +564,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
                 final_hidden_states += shared_output
             if (
                 self.tp_size > 1
-                and not can_fuse_mlp_allreduce
+                and not should_allreduce_fusion
                 and not use_reduce_scatter
             ):
                 final_hidden_states = tensor_model_parallel_all_reduce(

@@ -575,13 +575,13 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
     def forward_normal(
         self,
         hidden_states: torch.Tensor,
-        can_fuse_mlp_allreduce: bool = False,
+        should_allreduce_fusion: bool = False,
         use_reduce_scatter: bool = False,
     ) -> torch.Tensor:
         if hasattr(self, "shared_experts") and use_intel_amx_backend(
             self.shared_experts.gate_up_proj
         ):
-            return self.forward_cpu(hidden_states, can_fuse_mlp_allreduce)
+            return self.forward_cpu(hidden_states, should_allreduce_fusion)

         shared_output = self._forward_shared_experts(hidden_states)
         # router_logits: (num_tokens, n_experts)

@@ -596,7 +596,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
             # fused in biased_grouped_topk so we can skip here
             final_hidden_states *= self.routed_scaling_factor
         if self.ep_size > 1:
-            if self.tp_size > 1 and not can_fuse_mlp_allreduce:
+            if self.tp_size > 1 and not should_allreduce_fusion:
                 final_hidden_states = tensor_model_parallel_all_reduce(
                     final_hidden_states
                 )

@@ -605,7 +605,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
         else:
             if shared_output is not None:
                 final_hidden_states += shared_output
-            if self.tp_size > 1 and not can_fuse_mlp_allreduce:
+            if self.tp_size > 1 and not should_allreduce_fusion:
                 final_hidden_states = tensor_model_parallel_all_reduce(
                     final_hidden_states
                 )
python/sglang/srt/models/gpt_oss.py

@@ -56,7 +56,7 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.fp8_utils import dequant_mxfp4
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
-from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
+from sglang.srt.layers.utils import PPMissingLayer, get_layer_id, is_sm100_supported
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,

@@ -64,7 +64,10 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.utils import add_prefix, make_layers
+from sglang.srt.utils import add_prefix, is_cuda, is_flashinfer_available, make_layers
+
+_is_flashinfer_available = is_flashinfer_available()
+_is_sm100_supported = is_cuda() and is_sm100_supported()


 class GptOssConfig(PretrainedConfig):

@@ -151,10 +154,13 @@ class GptOssSparseMoeBlock(nn.Module):
         )

     def forward(
-        self, hidden_states: torch.Tensor, forward_batch: Optional[ForwardBatch] = None
+        self,
+        hidden_states: torch.Tensor,
+        forward_batch: Optional[ForwardBatch] = None,
+        should_allreduce_fusion: bool = False,
     ) -> torch.Tensor:
         if not global_server_args_dict["moe_a2a_backend"].is_deepep():
-            return self.forward_normal(hidden_states)
+            return self.forward_normal(hidden_states, should_allreduce_fusion)
         else:
             raise Exception("forward_deepep branch not implemented yet")

@@ -165,7 +171,11 @@ class GptOssSparseMoeBlock(nn.Module):
             if name not in ["correction_bias"]
         ]

-    def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
+    def forward_normal(
+        self,
+        hidden_states: torch.Tensor,
+        should_allreduce_fusion: bool = False,
+    ) -> torch.Tensor:
         num_tokens, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)

@@ -179,7 +189,7 @@ class GptOssSparseMoeBlock(nn.Module):
         kwargs["topk_output"] = (self.top_k, router_logits)

         final_hidden_states = self.experts(**kwargs)

-        if self.tp_size > 1:
+        if self.tp_size > 1 and not should_allreduce_fusion:
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)

         ans = final_hidden_states.view(num_tokens, hidden_dim)

@@ -370,6 +380,7 @@ class GptOssDecoderLayer(nn.Module):
         # GptOss all layers are sparse and have no nextn now
         self.is_layer_sparse = True
+        self.is_nextn = False
         is_previous_layer_sparse = True

         self.layer_scatter_modes = LayerScatterModes.init_new(

@@ -402,6 +413,42 @@ class GptOssDecoderLayer(nn.Module):
             post_attention_layernorm=self.post_attention_layernorm,
         )

+        self._fuse_allreduce_lookup_table = self._build_fuse_allreduce_lookup_table()
+
+    def _should_fuse_mlp_allreduce_with_next_layer(self, forward_batch) -> bool:
+        """Check if MLP allreduce can be fused with next layer's residual_rmsnorm"""
+        batch_size = (
+            forward_batch.input_ids.shape[0]
+            if hasattr(forward_batch, "input_ids")
+            else 0
+        )
+
+        if batch_size > 128:
+            return False
+
+        return self._fuse_allreduce_lookup_table.get(batch_size, False)
+
+    def _build_fuse_allreduce_lookup_table(self):
+        static_conditions_met = (
+            self.layer_id != self.config.num_hidden_layers - 1
+            and get_tensor_model_parallel_world_size() > 1
+            and global_server_args_dict.get("enable_flashinfer_allreduce_fusion", False)
+            and _is_sm100_supported
+            and _is_flashinfer_available
+        )
+
+        if not static_conditions_met:
+            return {}
+
+        lookup_table = {}
+        for batch_size in range(129):  # 0 to 128
+            is_last_layer = self.layer_id == self.config.num_hidden_layers - 1
+            should_fuse = batch_size > 0 and batch_size <= 128 and not is_last_layer
+            lookup_table[batch_size] = should_fuse
+
+        return lookup_table
+
     def forward(
         self,
         positions: torch.Tensor,

@@ -424,12 +471,21 @@ class GptOssDecoderLayer(nn.Module):
             hidden_states, residual, forward_batch
         )

-        hidden_states = self.mlp(hidden_states, forward_batch)

-        hidden_states, residual = self.layer_communicator.postprocess_layer(
-            hidden_states, residual, forward_batch
-        )
+        should_allreduce_fusion = (
+            self._should_fuse_mlp_allreduce_with_next_layer(forward_batch)
+            and not self.is_nextn
+        )
+
+        hidden_states = self.mlp(hidden_states, forward_batch, should_allreduce_fusion)
+
+        if should_allreduce_fusion:
+            hidden_states._sglang_needs_allreduce_fusion = True
+
+        if not should_allreduce_fusion:
+            hidden_states, residual = self.layer_communicator.postprocess_layer(
+                hidden_states, residual, forward_batch
+            )

         return hidden_states, residual
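
GptOss decoder layers adopt the same protocol as DeepseekV2: decide should_allreduce_fusion once per forward, pass it down so forward_normal skips tensor_model_parallel_all_reduce, mark the returned tensor with _sglang_needs_allreduce_fusion, and skip postprocess_layer so the pending reduction can be folded into the next residual RMSNorm. The marker works because torch.Tensor objects accept ad-hoc Python attributes; a tiny demonstration of that mechanism (how the flag is consumed lives outside this diff):

# Minimal demonstration of tagging a tensor with a pending-fusion flag.
import torch

hidden_states = torch.randn(16, 2880)                  # toy shapes
hidden_states._sglang_needs_allreduce_fusion = True    # plain Python attribute on the tensor

# A downstream consumer can branch on the presence of the flag:
if getattr(hidden_states, "_sglang_needs_allreduce_fusion", False):
    pass  # take the fused allreduce + residual_rmsnorm path
else:
    pass  # ordinary all-reduce followed by separate add + RMSNorm
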
python/sglang/srt/server_args.py

@@ -1435,7 +1435,7 @@ class ServerArgs:
         parser.add_argument(
             "--enable-flashinfer-allreduce-fusion",
             action="store_true",
-            help="Enable FlashInfer allreduce fusion for Add_RMSNorm.",
+            help="Enable FlashInfer allreduce fusion with Residual RMSNorm.",
         )
         parser.add_argument(
             "--deepep-mode",
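
Only the help text changes here; the flag itself, read elsewhere as global_server_args_dict["enable_flashinfer_allreduce_fusion"], stays opt-in and is typically passed on the sglang.launch_server command line. A minimal argparse check that the option maps to the attribute name the code reads (standalone sketch, not the full ServerArgs parser):

# Standalone sketch: the store_true flag maps to enable_flashinfer_allreduce_fusion.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--enable-flashinfer-allreduce-fusion",
    action="store_true",
    help="Enable FlashInfer allreduce fusion with Residual RMSNorm.",
)

args = parser.parse_args(["--enable-flashinfer-allreduce-fusion"])
assert args.enable_flashinfer_allreduce_fusion is True
assert parser.parse_args([]).enable_flashinfer_allreduce_fusion is False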