sglang / Commits / eff4eb3f

Unverified commit eff4eb3f, authored Aug 15, 2025 by Trevor Morris; committed by GitHub, Aug 15, 2025
Add fp4 quantize before all-gather for Flashinfer cutlass MoE DP (max throughput) (#7667)
Parent: 87dab548

Showing 16 changed files with 360 additions and 52 deletions (+360 −52)
python/sglang/srt/distributed/device_communicators/pynccl.py           +68  -18
python/sglang/srt/distributed/device_communicators/pynccl_wrapper.py   +52   -0
python/sglang/srt/distributed/parallel_state.py                        +81   -0
python/sglang/srt/layers/communicator.py                                +9   -2
python/sglang/srt/layers/dp_attention.py                               +22   -3
python/sglang/srt/layers/logits_processor.py                            +5   -1
python/sglang/srt/layers/moe/__init__.py                                +2   -0
python/sglang/srt/layers/moe/fused_moe_triton/layer.py                  +2   -3
python/sglang/srt/layers/moe/topk.py                                    +7   -0
python/sglang/srt/layers/moe/utils.py                                  +23   -0
python/sglang/srt/layers/quantization/modelopt_quant.py                +39   -8
python/sglang/srt/managers/schedule_batch.py                            +1   -0
python/sglang/srt/model_executor/forward_batch_info.py                  +1   -1
python/sglang/srt/models/deepseek_v2.py                                +36  -15
python/sglang/srt/operations.py                                         +6   -1
python/sglang/srt/server_args.py                                        +6   -0
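The change targets DP-attention plus expert-parallel deployments, where the MoE layer must all-gather every rank's activations before the fused FlashInfer CUTLASS kernel runs on the global batch. Moving NVFP4 quantization in front of the all-gather means the collective carries packed 4-bit values plus uint8 block scales instead of BF16 activations, and the MoE output is returned with a matching reduce_scatterv. A back-of-envelope sketch of the per-token all-gather payload (the //2 packing and //16 block-scale factors come from the diff below; the hidden size is an illustrative assumption):

    H = 7168                       # hidden size, illustrative (e.g. DeepSeek-V3)
    bf16_bytes = 2 * H             # unquantized BF16 activations: 14336 B/token
    fp4_bytes = H // 2 + H // 16   # packed fp4 + block scales: 3584 + 448 = 4032 B/token
    print(f"all-gather payload shrinks ~{bf16_bytes / fp4_bytes:.1f}x")  # ~3.6x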
python/sglang/srt/distributed/device_communicators/pynccl.py

@@ -148,7 +148,11 @@ class PyNcclCommunicator:
         )
 
     def all_gather(
-        self, output_tensor: torch.Tensor, input_tensor: torch.Tensor, stream=None
+        self,
+        output_tensor: torch.Tensor,
+        input_tensor: torch.Tensor,
+        stream=None,
+        sizes: Optional[list[int]] = None,
     ):
         if self.disabled:
             return
@@ -161,14 +165,33 @@ class PyNcclCommunicator:
         )
         if stream is None:
             stream = self.stream
-        self.nccl.ncclAllGather(
-            buffer_type(input_tensor.data_ptr()),
-            buffer_type(output_tensor.data_ptr()),
-            input_tensor.numel(),
-            ncclDataTypeEnum.from_torch(input_tensor.dtype),
-            self.comm,
-            cudaStream_t(stream.cuda_stream),
-        )
+        if sizes is not None:
+            split_offset = 0
+            self.nccl.ncclGroupStart()
+            for root, split_size in enumerate(sizes):
+                dst_slice = output_tensor[split_offset : split_offset + split_size]
+                self.nccl.ncclBroadcast(
+                    buffer_type(input_tensor.data_ptr()),
+                    buffer_type(dst_slice.data_ptr()),
+                    dst_slice.numel(),
+                    ncclDataTypeEnum.from_torch(input_tensor.dtype),
+                    root,
+                    self.comm,
+                    cudaStream_t(stream.cuda_stream),
+                )
+                split_offset += split_size
+            self.nccl.ncclGroupEnd()
+        else:
+            self.nccl.ncclAllGather(
+                buffer_type(input_tensor.data_ptr()),
+                buffer_type(output_tensor.data_ptr()),
+                input_tensor.numel(),
+                ncclDataTypeEnum.from_torch(input_tensor.dtype),
+                self.comm,
+                cudaStream_t(stream.cuda_stream),
+            )
 
     def reduce_scatter(
         self,
@@ -176,6 +199,7 @@ class PyNcclCommunicator:
         input_tensor: torch.Tensor,
         op: ReduceOp = ReduceOp.SUM,
         stream=None,
+        sizes: Optional[list[int]] = None,
     ):
         if self.disabled:
             return
@@ -188,15 +212,35 @@ class PyNcclCommunicator:
         )
         if stream is None:
             stream = self.stream
-        self.nccl.ncclReduceScatter(
-            buffer_type(input_tensor.data_ptr()),
-            buffer_type(output_tensor.data_ptr()),
-            output_tensor.numel(),
-            ncclDataTypeEnum.from_torch(input_tensor.dtype),
-            ncclRedOpTypeEnum.from_torch(op),
-            self.comm,
-            cudaStream_t(stream.cuda_stream),
-        )
+        if sizes is not None:
+            split_offset = 0
+            self.nccl.ncclGroupStart()
+            for root, split_size in enumerate(sizes):
+                chunk = input_tensor[split_offset : split_offset + split_size, ...]
+                self.nccl.ncclReduce(
+                    buffer_type(chunk.data_ptr()),
+                    buffer_type(output_tensor.data_ptr()),
+                    chunk.numel(),
+                    ncclDataTypeEnum.from_torch(input_tensor.dtype),
+                    ncclRedOpTypeEnum.from_torch(op),
+                    root,
+                    self.comm,
+                    cudaStream_t(stream.cuda_stream),
+                )
+                split_offset += split_size
+            self.nccl.ncclGroupEnd()
+        else:
+            self.nccl.ncclReduceScatter(
+                buffer_type(input_tensor.data_ptr()),
+                buffer_type(output_tensor.data_ptr()),
+                output_tensor.numel(),
+                ncclDataTypeEnum.from_torch(input_tensor.dtype),
+                ncclRedOpTypeEnum.from_torch(op),
+                self.comm,
+                cudaStream_t(stream.cuda_stream),
+            )
 
     def send(self, tensor: torch.Tensor, dst: int, stream=None):
         if self.disabled:
@@ -266,6 +310,12 @@ class PyNcclCommunicator:
     def deregister_comm_window(self, window):
         return self.nccl.ncclCommWindowDeregister(self.comm, window)
 
+    def group_start(self):
+        self.nccl.ncclGroupStart()
+
+    def group_end(self):
+        self.nccl.ncclGroupEnd()
+
     @contextmanager
     def change_state(
         self, enable: Optional[bool] = None, stream: Optional[torch.cuda.Stream] = None
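A minimal CPU sketch (not part of the commit) of the semantics the sizes-aware all_gather above implements: each rank contributes a different number of rows, and the output is assembled from per-rank broadcasts at running offsets. Plain tensor copies stand in for the grouped ncclBroadcast calls.

    import torch

    def all_gather_v(inputs: list, sizes: list) -> torch.Tensor:
        # inputs[r] plays the role of rank r's input_tensor; sizes[r] is its row count.
        assert all(inp.shape[0] == s for inp, s in zip(inputs, sizes))
        output = torch.empty((sum(sizes),) + inputs[0].shape[1:], dtype=inputs[0].dtype)
        split_offset = 0
        # Mirrors the ncclGroupStart/ncclBroadcast loop: slice rank `root`'s rows
        # into the shared output at its offset.
        for root, split_size in enumerate(sizes):
            output[split_offset : split_offset + split_size] = inputs[root]
            split_offset += split_size
        return output

    ranks = [torch.full((n, 4), float(r)) for r, n in enumerate([3, 1, 2])]
    gathered = all_gather_v(ranks, sizes=[3, 1, 2])
    assert gathered.shape == (6, 4)  # rows 0-2 from rank 0, row 3 from rank 1, rows 4-5 from rank 2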
python/sglang/srt/distributed/device_communicators/pynccl_wrapper.py

@@ -206,6 +206,26 @@ class NCCLLibrary:
                 cudaStream_t,
             ],
         ),
+        # ncclResult_t ncclReduce(
+        #     const void* sendbuff, void* recvbuff, size_t count,
+        #     ncclDataType_t datatype, ncclRedOp_t op, int root,
+        #     ncclComm_t comm, cudaStream_t stream);
+        # note that cudaStream_t is a pointer type, so the last argument
+        # is a pointer
+        Function(
+            "ncclReduce",
+            ncclResult_t,
+            [
+                buffer_type,
+                buffer_type,
+                ctypes.c_size_t,
+                ncclDataType_t,
+                ncclRedOp_t,
+                ctypes.c_int,
+                ncclComm_t,
+                cudaStream_t,
+            ],
+        ),
         # ncclResult_t ncclReduceScatter(
         #     const void* sendbuff, void* recvbuff, size_t count,
         #     ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm,
@@ -278,6 +298,10 @@ class NCCLLibrary:
         # it is better not to call it at all.
         # ncclResult_t ncclCommDestroy(ncclComm_t comm);
         Function("ncclCommDestroy", ncclResult_t, [ncclComm_t]),
+        # ncclResult_t ncclGroupStart();
+        Function("ncclGroupStart", ncclResult_t, []),
+        # ncclResult_t ncclGroupEnd();
+        Function("ncclGroupEnd", ncclResult_t, []),
     ]
 
     exported_functions_symm_mem = [
@@ -400,6 +424,28 @@ class NCCLLibrary:
             )
         )
 
+    def ncclReduce(
+        self,
+        sendbuff: buffer_type,
+        recvbuff: buffer_type,
+        count: int,
+        datatype: int,
+        op: int,
+        root: int,
+        comm: ncclComm_t,
+        stream: cudaStream_t,
+    ) -> None:
+        # `datatype` actually should be `ncclDataType_t`
+        # and `op` should be `ncclRedOp_t`
+        # both are aliases of `ctypes.c_int`
+        # when we pass int to a function, it will be converted to `ctypes.c_int`
+        # by ctypes automatically
+        self.NCCL_CHECK(
+            self._funcs["ncclReduce"](
+                sendbuff, recvbuff, count, datatype, op, root, comm, stream
+            )
+        )
+
     def ncclReduceScatter(
         self,
         sendbuff: buffer_type,
@@ -499,6 +545,12 @@ class NCCLLibrary:
     def ncclCommWindowDeregister(self, comm: ncclComm_t, window: ncclWindow_t) -> None:
         self.NCCL_CHECK(self._funcs["ncclCommWindowDeregister"](comm, window))
 
+    def ncclGroupStart(self) -> None:
+        self.NCCL_CHECK(self._funcs["ncclGroupStart"]())
+
+    def ncclGroupEnd(self) -> None:
+        self.NCCL_CHECK(self._funcs["ncclGroupEnd"]())
+
 
 __all__ = [
     "NCCLLibrary",
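A standalone sketch (assuming a POSIX libc is available; not the NCCL bindings themselves) of the ctypes pattern NCCLLibrary follows: declare (name, restype, argtypes) entries, resolve each symbol once, then dispatch through a dict. Demonstrated against libc's strlen instead of libnccl so it runs anywhere with glibc.

    import ctypes
    import ctypes.util
    from dataclasses import dataclass

    @dataclass
    class Function:
        name: str
        restype: object
        argtypes: list

    exported = [Function("strlen", ctypes.c_size_t, [ctypes.c_char_p])]

    lib = ctypes.CDLL(ctypes.util.find_library("c"))
    funcs = {}
    for f in exported:
        cfunc = getattr(lib, f.name)   # resolve the symbol once
        cfunc.restype = f.restype      # declare the C signature
        cfunc.argtypes = f.argtypes
        funcs[f.name] = cfunc

    assert funcs["strlen"](b"nccl") == 4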
python/sglang/srt/distributed/parallel_state.py

@@ -583,6 +583,39 @@ class GroupCoordinator:
         torch.distributed.reduce_scatter(output, input_list, group=self.device_group)
         return output
 
+    def reduce_scatterv(
+        self,
+        input_: torch.Tensor,
+        output: Optional[torch.Tensor] = None,
+        sizes: Optional[List[int]] = None,
+    ) -> torch.Tensor:
+        world_size = self.world_size
+        pynccl_comm = self.pynccl_comm
+        with pynccl_comm.change_state(enable=True, stream=torch.cuda.current_stream()):
+            assert (
+                pynccl_comm is not None and not pynccl_comm.disabled
+            ), "pynccl is required for reduce_scatterv"
+            if sizes is not None:
+                assert len(sizes) == world_size
+                assert input_.shape[0] == sum(sizes)
+                chunk_size = sizes[self.rank_in_group]
+            else:
+                assert input_.shape[0] % world_size == 0
+                chunk_size = input_.shape[0] // world_size
+            output_shape = (chunk_size,) + input_.shape[1:]
+            if output is None:
+                output = torch.empty(
+                    output_shape, dtype=input_.dtype, device=input_.device
+                )
+            else:
+                assert output.shape == output_shape
+            pynccl_comm.reduce_scatter(output, input_, sizes=sizes)
+        return output
+
     def _all_gather_into_tensor(self, output: torch.Tensor, input: torch.Tensor):
         pynccl_comm = self.pynccl_comm
         if pynccl_comm is not None and not pynccl_comm.disabled:
@@ -673,6 +706,54 @@ class GroupCoordinator:
         )
         return output_tensor
 
+    def all_gatherv(
+        self,
+        input_: Union[torch.Tensor, List[torch.Tensor]],
+        sizes: Optional[List[int]] = None,
+    ) -> Union[torch.Tensor, List[torch.Tensor]]:
+        """
+        Supports varying sizes per rank and input tensor list.
+        `sizes`: a list of len(world_size) with the number of items per rank to gather.
+        """
+        world_size = self.world_size
+        pynccl_comm = self.pynccl_comm
+        with pynccl_comm.change_state(enable=True, stream=torch.cuda.current_stream()):
+            assert (
+                pynccl_comm is not None and not pynccl_comm.disabled
+            ), "pynccl is required for all_gatherv"
+
+            def _all_gather_single(
+                input_: torch.Tensor, sizes: Optional[List[int]] = None
+            ):
+                input_size = input_.size()
+                if sizes is not None:
+                    assert len(sizes) == world_size
+                    assert input_.shape[0] == sizes[self.rank_in_group]
+                    output_size = (sum(sizes),) + input_size[1:]
+                    # 'sizes' is not needed if all inputs in the same group have the same shape
+                    if all(s == sizes[0] for s in sizes):
+                        sizes = None
+                else:
+                    output_size = (input_size[0] * world_size,) + input_size[1:]
+                # Allocate output tensor.
+                output_tensor = torch.empty(
+                    output_size, dtype=input_.dtype, device=input_.device
+                )
+                pynccl_comm.all_gather(output_tensor, input_, sizes=sizes)
+                return output_tensor
+
+            if isinstance(input_, torch.Tensor):
+                return _all_gather_single(input_, sizes)
+
+            output_list = []
+            pynccl_comm.group_start()
+            for inp in input_:
+                output_list.append(_all_gather_single(inp, sizes=sizes))
+            pynccl_comm.group_end()
+        return output_list
+
     def gather(
         self, input_: torch.Tensor, dst: int = 0, dim: int = -1
     ) -> Optional[torch.Tensor]:
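A minimal CPU sketch (not the GroupCoordinator API) of the shape contract between all_gatherv and reduce_scatterv above when per-rank token counts are uneven, zero included:

    import torch

    sizes = [5, 2, 0, 3]                    # tokens per DP rank; zeros are legal
    hidden = 8

    # all_gatherv: every rank ends up with the concatenation of all ranks' rows.
    per_rank = [torch.randn(n, hidden) for n in sizes]
    gathered = torch.cat(per_rank, dim=0)   # shape (sum(sizes), hidden) on each rank
    assert gathered.shape == (sum(sizes), hidden)

    # reduce_scatterv: after each rank computes a full (sum(sizes), hidden) result,
    # rank r keeps only its sizes[r] rows (summed across ranks in the real op).
    rank = 1
    offset = sum(sizes[:rank])
    kept = gathered[offset : offset + sizes[rank]]
    assert kept.shape == (sizes[rank], hidden)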
python/sglang/srt/layers/communicator.py

@@ -35,7 +35,10 @@ from sglang.srt.layers.dp_attention import (
     get_global_dp_buffer,
     get_local_dp_buffer,
 )
-from sglang.srt.layers.moe import get_moe_a2a_backend
+from sglang.srt.layers.moe import (
+    get_moe_a2a_backend,
+    should_use_flashinfer_cutlass_moe_fp4_allgather,
+)
 from sglang.srt.layers.utils import is_sm100_supported
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
@@ -112,7 +115,11 @@ class LayerScatterModes:
         if context.is_layer_sparse:
             return (
                 ScatterMode.SCATTERED
-                if not get_moe_a2a_backend().is_none()
+                if (
+                    # Token dispatch/combine will be handled outside of LayerCommunicator for these modes.
+                    not get_moe_a2a_backend().is_none()
+                    or should_use_flashinfer_cutlass_moe_fp4_allgather()
+                )
                 else ScatterMode.FULL
             )
         else:
python/sglang/srt/layers/dp_attention.py

@@ -72,6 +72,7 @@ class _DpGatheredBufferWrapper:
     _device: torch.device
     _global_dp_buffer_len: int
     _local_dp_buffer_len: int
+    _global_num_tokens: Optional[List[int]]
 
     @classmethod
     def set_metadata(cls, hidden_size: int, dtype: torch.dtype, device: torch.device):
@@ -80,9 +81,15 @@ class _DpGatheredBufferWrapper:
         cls._device = device
 
     @classmethod
-    def set_dp_buffer_len(cls, global_dp_buffer_len: int, local_dp_buffer_len: int):
+    def set_dp_buffer_len(
+        cls,
+        global_dp_buffer_len: int,
+        local_dp_buffer_len: int,
+        global_num_tokens: Optional[List[int]] = None,
+    ):
         cls._global_dp_buffer_len = global_dp_buffer_len
         cls._local_dp_buffer_len = local_dp_buffer_len
+        cls._global_num_tokens = global_num_tokens
 
     @classmethod
     def get_global_dp_buffer(cls) -> torch.Tensor:
@@ -108,10 +115,18 @@ class _DpGatheredBufferWrapper:
     def get_local_dp_buffer_len(cls) -> int:
         return cls._local_dp_buffer_len
 
+    @classmethod
+    def get_dp_global_num_tokens(cls) -> List[int]:
+        return cls._global_num_tokens
+
 
-def set_dp_buffer_len(global_dp_buffer_len: int, local_dp_buffer_len: int):
+def set_dp_buffer_len(
+    global_dp_buffer_len: int,
+    local_dp_buffer_len: int,
+    global_num_tokens: Optional[List[int]] = None,
+):
     _DpGatheredBufferWrapper.set_dp_buffer_len(
-        global_dp_buffer_len, local_dp_buffer_len
+        global_dp_buffer_len, local_dp_buffer_len, global_num_tokens
     )
@@ -131,6 +146,10 @@ def get_local_dp_buffer_len() -> int:
     return _DpGatheredBufferWrapper.get_local_dp_buffer_len()
 
 
+def get_dp_global_num_tokens() -> List[int]:
+    return _DpGatheredBufferWrapper.get_dp_global_num_tokens()
+
+
 def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size):
     if not enable_dp_attention:
         return tp_rank, tp_size, 0
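A standalone sketch (not the sglang class) of the process-global registry pattern _DpGatheredBufferWrapper follows: classmethods stash per-forward metadata on class attributes so distant modules can read it without threading it through every call signature.

    from typing import List, Optional

    class _MetadataRegistry:
        _global_num_tokens: Optional[List[int]] = None

        @classmethod
        def set(cls, global_num_tokens: Optional[List[int]]) -> None:
            cls._global_num_tokens = global_num_tokens

        @classmethod
        def get(cls) -> Optional[List[int]]:
            return cls._global_num_tokens

    # Scheduler side: record how many tokens each DP rank holds this step.
    _MetadataRegistry.set([5, 2, 0, 3])
    # MoE side (a different module entirely): read it back as all_gatherv sizes.
    assert _MetadataRegistry.get() == [5, 2, 0, 3]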
python/sglang/srt/layers/logits_processor.py

@@ -191,7 +191,11 @@ class LogitsMetadata:
         else:
             self.global_dp_buffer_len = self.global_dp_buffer_len
 
-        set_dp_buffer_len(self.global_dp_buffer_len, self.dp_local_num_tokens)
+        set_dp_buffer_len(
+            self.global_dp_buffer_len,
+            self.dp_local_num_tokens,
+            self.global_num_tokens_for_logprob_cpu,
+        )
 
 
 class LogitsProcessor(nn.Module):
python/sglang/srt/layers/moe/__init__.py

@@ -10,6 +10,7 @@ from sglang.srt.layers.moe.utils import (
     get_tbo_token_distribution_threshold,
     initialize_moe_config,
     is_tbo_enabled,
+    should_use_flashinfer_cutlass_moe_fp4_allgather,
     should_use_flashinfer_trtllm_moe,
 )
@@ -23,6 +24,7 @@ __all__ = [
     "get_moe_runner_backend",
     "get_deepep_mode",
     "should_use_flashinfer_trtllm_moe",
+    "should_use_flashinfer_cutlass_moe_fp4_allgather",
     "is_tbo_enabled",
    "get_tbo_token_distribution_threshold",
     "get_deepep_config",
python/sglang/srt/layers/moe/fused_moe_triton/layer.py

@@ -28,6 +28,7 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
+from sglang.srt.layers.quantization.modelopt_quant import ModelOptNvFp4FusedMoEMethod
 from sglang.srt.layers.quantization.unquant import UnquantizedFusedMoEMethod
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_loader.weight_utils import narrow_padded_param_and_loaded_weight
@@ -621,9 +622,7 @@ class FusedMoE(torch.nn.Module):
         if "ModelOpt" in self.quant_method.__class__.__name__:
             # Determine per-tensor weight scale patterns based on variant
-            is_fp4_variant = (
-                "ModelOptNvFp4FusedMoEMethod" in self.quant_method.__class__.__name__
-            )
+            is_fp4_variant = isinstance(self.quant_method, ModelOptNvFp4FusedMoEMethod)
 
             # FP4 uses "weight_scale_2" for per-tensor, FP8 uses "weight_scale" for per-tensor
             per_tensor_conditions = (
python/sglang/srt/layers/moe/topk.py

@@ -327,6 +327,13 @@ class TopK(CustomOp):
             expert_location_dispatch_info=expert_location_dispatch_info,
         )
 
+    def empty_topk_output(self, device: torch.device) -> TopKOutput:
+        topk = self.topk_config.top_k - self.topk_config.num_fused_shared_experts
+        topk_weights = torch.empty((0, topk), dtype=torch.float32, device=device)
+        topk_idx = torch.full((0, topk), -1, dtype=torch.int32, device=device)
+        router_logits = torch.empty((0, topk), dtype=torch.float32, device=device)
+        return StandardTopKOutput(topk_weights, topk_idx, router_logits)
+
 
 # ------------------------------- TopK implementation -------------------------------------
python/sglang/srt/layers/moe/utils.py

@@ -7,6 +7,11 @@ from typing import TYPE_CHECKING, Optional
 from packaging import version as pkg_version
 
+from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size
+from sglang.srt.layers.dp_attention import (
+    get_attention_dp_size,
+    is_dp_attention_enabled,
+)
 from sglang.srt.utils import logger
 
 if TYPE_CHECKING:
@@ -99,6 +104,7 @@ DEEPEP_MODE: Optional[DeepEPMode] = None
 IS_TBO_ENABLED: Optional[bool] = None
 TBO_TOKEN_DISTRIBUTION_THRESHOLD: Optional[float] = None
 DEEPEP_CONFIG: Optional[str] = None
+DISABLE_FLASHINFER_CUTLASS_MOE_FP4_ALLGATHER: Optional[bool] = None
 
 
 def initialize_moe_config(server_args: ServerArgs):
@@ -108,6 +114,7 @@ def initialize_moe_config(server_args: ServerArgs):
     global DEEPEP_CONFIG
     global IS_TBO_ENABLED
     global TBO_TOKEN_DISTRIBUTION_THRESHOLD
+    global DISABLE_FLASHINFER_CUTLASS_MOE_FP4_ALLGATHER
 
     MOE_A2A_BACKEND = MoeA2ABackend(server_args.moe_a2a_backend)
     MOE_RUNNER_BACKEND = MoeRunnerBackend(server_args.moe_runner_backend)
@@ -115,6 +122,9 @@ def initialize_moe_config(server_args: ServerArgs):
     DEEPEP_CONFIG = server_args.deepep_config or ""
     IS_TBO_ENABLED = server_args.enable_two_batch_overlap
     TBO_TOKEN_DISTRIBUTION_THRESHOLD = server_args.tbo_token_distribution_threshold
+    DISABLE_FLASHINFER_CUTLASS_MOE_FP4_ALLGATHER = (
+        server_args.disable_flashinfer_cutlass_moe_fp4_allgather
+    )
 
 
 def get_moe_a2a_backend() -> MoeA2ABackend:
@@ -175,3 +185,16 @@ def should_use_flashinfer_trtllm_moe():
         >= pkg_version.parse("0.2.9rc1")
     )
     return result
+
+
+@lru_cache(maxsize=1)
+def should_use_flashinfer_cutlass_moe_fp4_allgather():
+    """
+    Perform FP4 quantize before all-gather for flashinfer cutlass moe to reduce communication cost for high-throughput serving.
+    """
+    return (
+        not DISABLE_FLASHINFER_CUTLASS_MOE_FP4_ALLGATHER
+        and get_moe_runner_backend().is_flashinfer_cutlass()
+        and is_dp_attention_enabled()
+        and get_moe_expert_parallel_world_size() == get_attention_dp_size()
+    )
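A standalone sketch (not the module above) of the lru_cache(maxsize=1) behavior the new helper relies on: the gating decision is computed once from the module-level flags and then frozen for the life of the process, so initialize_moe_config() must have run before the first query.

    from functools import lru_cache

    FLAG = None

    @lru_cache(maxsize=1)
    def should_enable() -> bool:
        return bool(FLAG)

    FLAG = True          # set by initialize_moe_config(server_args) in the real code
    assert should_enable() is True

    FLAG = False         # later mutation is invisible: the cached result sticks
    assert should_enable() is True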
python/sglang/srt/layers/quantization/modelopt_quant.py

@@ -7,7 +7,12 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional
 import torch
 from torch.nn.parameter import Parameter
 
-from sglang.srt.layers.moe import should_use_flashinfer_trtllm_moe
+from sglang.srt.distributed import get_tp_group
+from sglang.srt.layers.dp_attention import get_dp_global_num_tokens, get_local_dp_buffer
+from sglang.srt.layers.moe import (
+    should_use_flashinfer_cutlass_moe_fp4_allgather,
+    should_use_flashinfer_trtllm_moe,
+)
 from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams, CutlassMoEType
 from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter
 from sglang.srt.layers.quantization.base_config import (
@@ -1176,16 +1181,37 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         ), "apply_router_weight_on_input is not supported for Flashinfer"
         # TRTLLM Cutlass moe takes in activations in BF16/Half/nvfp4 precision
         # and fp4 quantized weights loaded from the checkpoint
+        topk_weights, topk_ids = topk_output.topk_weights, topk_output.topk_ids
+        output_dtype = x.dtype
+        x_sf = None
+        if should_use_flashinfer_cutlass_moe_fp4_allgather():
+            from flashinfer import fp4_quantize, nvfp4_block_scale_interleave
+
+            # Quantize before comm, swizzle after.
+            if x.shape[0] > 0:
+                x, x_sf = fp4_quantize(
+                    x, layer.w13_input_scale_quant, is_sf_swizzled_layout=False
+                )
+            else:
+                x_col = x.shape[1]
+                x = torch.zeros(0, x_col // 2, dtype=torch.uint8, device=x.device)
+                x_sf = torch.zeros(0, x_col // 16, dtype=torch.uint8, device=x.device)
+            topk_weights, topk_ids, x, x_sf = get_tp_group().all_gatherv(
+                [topk_weights, topk_ids, x, x_sf], sizes=get_dp_global_num_tokens()
+            )
+            x_sf = nvfp4_block_scale_interleave(x_sf)
+
         output = flashinfer_cutlass_fused_moe(
-            x,
-            topk_ids.to(torch.int),
-            topk_weights,
-            layer.w13_weight.view(torch.long),
-            layer.w2_weight.view(torch.long),
-            x.dtype,
+            input=x,
+            token_selected_experts=topk_ids.to(torch.int),
+            token_final_scales=topk_weights,
+            fc1_expert_weights=layer.w13_weight.view(torch.long),
+            fc2_expert_weights=layer.w2_weight.view(torch.long),
+            output_dtype=output_dtype,
+            input_sf=x_sf,
             quant_scales=[
                 layer.w13_input_scale_quant,
                 layer.w13_blockscale_swizzled.view(torch.int32),
@@ -1202,6 +1228,11 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         )[0]
         if moe_runner_config.routed_scaling_factor is not None:
             output *= moe_runner_config.routed_scaling_factor
+        if should_use_flashinfer_cutlass_moe_fp4_allgather():
+            output, global_output = get_local_dp_buffer(), output
+            get_tp_group().reduce_scatterv(
+                global_output, output=output, sizes=get_dp_global_num_tokens()
+            )
         return output
 
         from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4
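A minimal CPU simulation (not the FlashInfer API; shapes faked with zeros) of the dataflow the hunk above implements: quantize locally, all-gather the packed payload, swizzle scales, run the fused MoE on the global batch, then reduce_scatterv each rank's rows back. The real fp4_quantize packs two 4-bit values per uint8 and emits one uint8 scale per 16-element block, which is what the (n, hidden // 2) and (n, hidden // 16) shapes below mimic.

    import torch

    def fake_fp4_quantize(x: torch.Tensor):
        n, h = x.shape
        packed = torch.zeros(n, h // 2, dtype=torch.uint8)   # 2 fp4 values per byte
        scales = torch.zeros(n, h // 16, dtype=torch.uint8)  # 1 scale per 16-elem block
        return packed, scales

    sizes = [4, 0, 2]      # tokens per DP rank; zero-token ranks still participate
    hidden = 64
    local = [torch.randn(n, hidden, dtype=torch.bfloat16) for n in sizes]

    # 1. Each rank quantizes BEFORE the collective (0-row inputs quantize to 0-row outputs).
    quantized = [fake_fp4_quantize(x) for x in local]

    # 2. all_gatherv concatenates the packed payloads; every rank sees the global batch.
    x_global = torch.cat([q for q, _ in quantized])
    sf_global = torch.cat([s for _, s in quantized])
    assert x_global.shape == (sum(sizes), hidden // 2)
    assert sf_global.shape == (sum(sizes), hidden // 16)

    # 3. In the real code: nvfp4_block_scale_interleave(sf_global), the fused MoE
    #    runs on x_global, and reduce_scatterv hands rank r back its sizes[r] rows.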
python/sglang/srt/managers/schedule_batch.py

@@ -84,6 +84,7 @@ GLOBAL_SERVER_ARGS_KEYS = [
     "chunked_prefill_size",
     "device",
     "disable_chunked_prefix_cache",
+    "disable_flashinfer_cutlass_moe_fp4_allgather",
     "disable_radix_cache",
     "enable_dp_lm_head",
     "enable_flashinfer_allreduce_fusion",
python/sglang/srt/model_executor/forward_batch_info.py

@@ -649,7 +649,7 @@ class ForwardBatch:
             num_tokens = global_num_tokens[0]
 
         self.global_dp_buffer_len = buffer_len
-        set_dp_buffer_len(buffer_len, num_tokens)
+        set_dp_buffer_len(buffer_len, num_tokens, global_num_tokens)
 
         bs = self.batch_size
python/sglang/srt/models/deepseek_v2.py

@@ -60,7 +60,11 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.moe import get_deepep_mode, get_moe_a2a_backend
+from sglang.srt.layers.moe import (
+    get_deepep_mode,
+    get_moe_a2a_backend,
+    should_use_flashinfer_cutlass_moe_fp4_allgather,
+)
 from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, get_moe_impl_class
 from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
 from sglang.srt.layers.moe.topk import TopK
@@ -343,7 +347,7 @@ class DeepseekV2MoE(nn.Module):
         self.shared_experts_weight_block_size = None
         if config.n_shared_experts is not None and self.num_fused_shared_experts == 0:
             intermediate_size = config.moe_intermediate_size * config.n_shared_experts
-            # disable tp for shared experts when enable deepep moe
+            # disable tp for shared experts when enable deepep moe, or with fp4 allgather
             self.shared_experts = DeepseekV2MLP(
                 hidden_size=config.hidden_size,
                 intermediate_size=intermediate_size,
@@ -354,6 +358,7 @@ class DeepseekV2MoE(nn.Module):
                 **(
                     dict(tp_rank=0, tp_size=1)
                     if get_moe_a2a_backend().is_deepep()
+                    or should_use_flashinfer_cutlass_moe_fp4_allgather()
                     else {}
                 ),
             )
@@ -433,14 +438,19 @@ class DeepseekV2MoE(nn.Module):
             if (
                 self.alt_stream is not None
                 and self.num_fused_shared_experts == 0
+                and hidden_states.shape[0] > 0
                 and hidden_states.shape[0] <= DUAL_STREAM_TOKEN_THRESHOLD
             ):
                 return self.forward_normal_dual_stream(
-                    hidden_states, should_allreduce_fusion, use_reduce_scatter
+                    hidden_states,
+                    should_allreduce_fusion,
+                    use_reduce_scatter,
                 )
             else:
                 return self.forward_normal(
-                    hidden_states, should_allreduce_fusion, use_reduce_scatter
+                    hidden_states,
+                    should_allreduce_fusion,
+                    use_reduce_scatter,
                 )
         else:
             return self.forward_deepep(hidden_states, forward_batch)
@@ -471,7 +481,12 @@ class DeepseekV2MoE(nn.Module):
             torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
             final_hidden_states = final_hidden_states_out
             sm.tag(final_hidden_states)
-        if self.tp_size > 1 and not should_allreduce_fusion and not use_reduce_scatter:
+        if (
+            self.tp_size > 1
+            and not should_allreduce_fusion
+            and not use_reduce_scatter
+            and not should_use_flashinfer_cutlass_moe_fp4_allgather()
+        ):
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
         return final_hidden_states
@@ -486,10 +501,14 @@ class DeepseekV2MoE(nn.Module):
         ):
             return self.forward_cpu(hidden_states, should_allreduce_fusion)
 
-        shared_output = self._forward_shared_experts(hidden_states)
-        # router_logits: (num_tokens, n_experts)
-        router_logits = self.gate(hidden_states)
-        topk_output = self.topk(hidden_states, router_logits)
+        if hidden_states.shape[0] > 0:
+            shared_output = self._forward_shared_experts(hidden_states)
+            # router_logits: (num_tokens, n_experts)
+            router_logits = self.gate(hidden_states)
+            topk_output = self.topk(hidden_states, router_logits)
+        else:
+            shared_output = None
+            topk_output = self.topk.empty_topk_output(hidden_states.device)
         final_hidden_states = self.experts(hidden_states, topk_output)
         if not _is_cuda and not _use_aiter:
@@ -501,7 +520,12 @@ class DeepseekV2MoE(nn.Module):
             torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
             final_hidden_states = final_hidden_states_out
             sm.tag(final_hidden_states)
-        if self.tp_size > 1 and not should_allreduce_fusion and not use_reduce_scatter:
+        if (
+            self.tp_size > 1
+            and not should_allreduce_fusion
+            and not use_reduce_scatter
+            and not should_use_flashinfer_cutlass_moe_fp4_allgather()
+        ):
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
         return final_hidden_states
@@ -580,11 +604,8 @@ class DeepseekV2MoE(nn.Module):
                 ),
             )
         else:
-            topk_idx = torch.full(
-                (0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
-            )
-            topk_weights = torch.empty(
-                (0, self.top_k), dtype=torch.float32, device=hidden_states.device
-            )
+            topk_weights, topk_idx, _ = self.topk.empty_topk_output(
+                hidden_states.device
+            )
         final_hidden_states = self.experts(
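A standalone sketch of why the zero-token guards above matter: with DP attention, a rank can hold no tokens for a step, yet it must still enter the collective with correctly shaped (0, ...) tensors rather than skipping the MoE entirely. The helper below mirrors the shape contract of TopK.empty_topk_output (simplified; not the sglang class).

    import torch

    top_k = 8

    def empty_topk_output(device: torch.device):
        # 0-row tensors keep the dtype/shape contract of real routing output.
        topk_weights = torch.empty((0, top_k), dtype=torch.float32, device=device)
        topk_idx = torch.full((0, top_k), -1, dtype=torch.int32, device=device)
        return topk_weights, topk_idx

    hidden_states = torch.empty(0, 64)  # this rank drew no tokens this step
    if hidden_states.shape[0] > 0:
        topk_weights = topk_idx = None  # real routing would run here
    else:
        topk_weights, topk_idx = empty_topk_output(hidden_states.device)

    # 0-row tensors are still valid collective participants: concatenating with a
    # busy rank's routing output (5 tokens here) behaves as expected.
    busy_rank_weights = torch.randn(5, top_k)
    assert torch.cat([topk_weights, busy_rank_weights]).shape == (5, top_k)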
python/sglang/srt/operations.py

@@ -84,6 +84,7 @@ class _StageExecutor:
         forward_batch: ForwardBatch = inputs["forward_batch"]
         self._global_dp_buffer_len = forward_batch.global_dp_buffer_len
         self._local_dp_buffer_len = forward_batch.input_ids.shape[0]
+        self._global_num_tokens = forward_batch.global_num_tokens_cpu
 
     def next(self):
         assert not self.done
@@ -91,7 +92,11 @@ class _StageExecutor:
         stage = self._stages[self._index]
 
         if self._global_dp_buffer_len is not None:
-            set_dp_buffer_len(self._global_dp_buffer_len, self._local_dp_buffer_len)
+            set_dp_buffer_len(
+                self._global_dp_buffer_len,
+                self._local_dp_buffer_len,
+                self._global_num_tokens,
+            )
 
         with _annotate_region(debug_name=f"{self._debug_name}{self._index}"):
             for op in stage:
python/sglang/srt/server_args.py

@@ -230,6 +230,7 @@ class ServerArgs:
     enable_cudagraph_gc: bool = False
     enable_nccl_nvls: bool = False
     enable_symm_mem: bool = False
+    disable_flashinfer_cutlass_moe_fp4_allgather: bool = False
     enable_tokenizer_batch_encode: bool = False
     disable_outlines_disk_cache: bool = False
     disable_custom_all_reduce: bool = False
@@ -1714,6 +1715,11 @@ class ServerArgs:
             action="store_true",
             help="Enable NCCL symmetric memory for fast collectives.",
         )
+        parser.add_argument(
+            "--disable-flashinfer-cutlass-moe-fp4-allgather",
+            action="store_true",
+            help="Disables quantize before all-gather for flashinfer cutlass moe.",
+        )
         parser.add_argument(
             "--enable-tokenizer-batch-encode",
             action="store_true",
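Since the FP4 all-gather path is on by default whenever its preconditions hold (FlashInfer CUTLASS MoE runner, DP attention enabled, expert-parallel size equal to attention DP size), a deployment only needs the new flag to opt out. An illustrative launch line (exact flags depend on your setup):

    python -m sglang.launch_server --model-path <model> --disable-flashinfer-cutlass-moe-fp4-allgather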