Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
eff4eb3f
Unverified
Commit
eff4eb3f
authored
Aug 15, 2025
by
Trevor Morris
Committed by
GitHub
Aug 15, 2025
Browse files
Add fp4 quantize before all-gather for Flashinfer cutlass MoE DP (max throughput) (#7667)
parent
87dab548
Changes
16
Hide whitespace changes
Inline
Side-by-side
Showing
16 changed files
with
360 additions
and
52 deletions
+360
-52
python/sglang/srt/distributed/device_communicators/pynccl.py
python/sglang/srt/distributed/device_communicators/pynccl.py
+68
-18
python/sglang/srt/distributed/device_communicators/pynccl_wrapper.py
...ng/srt/distributed/device_communicators/pynccl_wrapper.py
+52
-0
python/sglang/srt/distributed/parallel_state.py
python/sglang/srt/distributed/parallel_state.py
+81
-0
python/sglang/srt/layers/communicator.py
python/sglang/srt/layers/communicator.py
+9
-2
python/sglang/srt/layers/dp_attention.py
python/sglang/srt/layers/dp_attention.py
+22
-3
python/sglang/srt/layers/logits_processor.py
python/sglang/srt/layers/logits_processor.py
+5
-1
python/sglang/srt/layers/moe/__init__.py
python/sglang/srt/layers/moe/__init__.py
+2
-0
python/sglang/srt/layers/moe/fused_moe_triton/layer.py
python/sglang/srt/layers/moe/fused_moe_triton/layer.py
+2
-3
python/sglang/srt/layers/moe/topk.py
python/sglang/srt/layers/moe/topk.py
+7
-0
python/sglang/srt/layers/moe/utils.py
python/sglang/srt/layers/moe/utils.py
+23
-0
python/sglang/srt/layers/quantization/modelopt_quant.py
python/sglang/srt/layers/quantization/modelopt_quant.py
+39
-8
python/sglang/srt/managers/schedule_batch.py
python/sglang/srt/managers/schedule_batch.py
+1
-0
python/sglang/srt/model_executor/forward_batch_info.py
python/sglang/srt/model_executor/forward_batch_info.py
+1
-1
python/sglang/srt/models/deepseek_v2.py
python/sglang/srt/models/deepseek_v2.py
+36
-15
python/sglang/srt/operations.py
python/sglang/srt/operations.py
+6
-1
python/sglang/srt/server_args.py
python/sglang/srt/server_args.py
+6
-0
No files found.
python/sglang/srt/distributed/device_communicators/pynccl.py
View file @
eff4eb3f
...
@@ -148,7 +148,11 @@ class PyNcclCommunicator:
...
@@ -148,7 +148,11 @@ class PyNcclCommunicator:
)
)
def
all_gather
(
def
all_gather
(
self
,
output_tensor
:
torch
.
Tensor
,
input_tensor
:
torch
.
Tensor
,
stream
=
None
self
,
output_tensor
:
torch
.
Tensor
,
input_tensor
:
torch
.
Tensor
,
stream
=
None
,
sizes
:
Optional
[
list
[
int
]]
=
None
,
):
):
if
self
.
disabled
:
if
self
.
disabled
:
return
return
...
@@ -161,14 +165,33 @@ class PyNcclCommunicator:
...
@@ -161,14 +165,33 @@ class PyNcclCommunicator:
)
)
if
stream
is
None
:
if
stream
is
None
:
stream
=
self
.
stream
stream
=
self
.
stream
self
.
nccl
.
ncclAllGather
(
buffer_type
(
input_tensor
.
data_ptr
()),
if
sizes
is
not
None
:
buffer_type
(
output_tensor
.
data_ptr
()),
split_offset
=
0
input_tensor
.
numel
(),
ncclDataTypeEnum
.
from_torch
(
input_tensor
.
dtype
),
self
.
nccl
.
ncclGroupStart
()
self
.
comm
,
for
root
,
split_size
in
enumerate
(
sizes
):
cudaStream_t
(
stream
.
cuda_stream
),
dst_slice
=
output_tensor
[
split_offset
:
split_offset
+
split_size
]
)
self
.
nccl
.
ncclBroadcast
(
buffer_type
(
input_tensor
.
data_ptr
()),
buffer_type
(
dst_slice
.
data_ptr
()),
dst_slice
.
numel
(),
ncclDataTypeEnum
.
from_torch
(
input_tensor
.
dtype
),
root
,
self
.
comm
,
cudaStream_t
(
stream
.
cuda_stream
),
)
split_offset
+=
split_size
self
.
nccl
.
ncclGroupEnd
()
else
:
self
.
nccl
.
ncclAllGather
(
buffer_type
(
input_tensor
.
data_ptr
()),
buffer_type
(
output_tensor
.
data_ptr
()),
input_tensor
.
numel
(),
ncclDataTypeEnum
.
from_torch
(
input_tensor
.
dtype
),
self
.
comm
,
cudaStream_t
(
stream
.
cuda_stream
),
)
def
reduce_scatter
(
def
reduce_scatter
(
self
,
self
,
...
@@ -176,6 +199,7 @@ class PyNcclCommunicator:
...
@@ -176,6 +199,7 @@ class PyNcclCommunicator:
input_tensor
:
torch
.
Tensor
,
input_tensor
:
torch
.
Tensor
,
op
:
ReduceOp
=
ReduceOp
.
SUM
,
op
:
ReduceOp
=
ReduceOp
.
SUM
,
stream
=
None
,
stream
=
None
,
sizes
:
Optional
[
list
[
int
]]
=
None
,
):
):
if
self
.
disabled
:
if
self
.
disabled
:
return
return
...
@@ -188,15 +212,35 @@ class PyNcclCommunicator:
...
@@ -188,15 +212,35 @@ class PyNcclCommunicator:
)
)
if
stream
is
None
:
if
stream
is
None
:
stream
=
self
.
stream
stream
=
self
.
stream
self
.
nccl
.
ncclReduceScatter
(
buffer_type
(
input_tensor
.
data_ptr
()),
if
sizes
is
not
None
:
buffer_type
(
output_tensor
.
data_ptr
()),
split_offset
=
0
output_tensor
.
numel
(),
self
.
nccl
.
ncclGroupStart
()
ncclDataTypeEnum
.
from_torch
(
input_tensor
.
dtype
),
for
root
,
split_size
in
enumerate
(
sizes
):
ncclRedOpTypeEnum
.
from_torch
(
op
),
chunk
=
input_tensor
[
split_offset
:
split_offset
+
split_size
,
...]
self
.
comm
,
cudaStream_t
(
stream
.
cuda_stream
),
self
.
nccl
.
ncclReduce
(
)
buffer_type
(
chunk
.
data_ptr
()),
buffer_type
(
output_tensor
.
data_ptr
()),
chunk
.
numel
(),
ncclDataTypeEnum
.
from_torch
(
input_tensor
.
dtype
),
ncclRedOpTypeEnum
.
from_torch
(
op
),
root
,
self
.
comm
,
cudaStream_t
(
stream
.
cuda_stream
),
)
split_offset
+=
split_size
self
.
nccl
.
ncclGroupEnd
()
else
:
self
.
nccl
.
ncclReduceScatter
(
buffer_type
(
input_tensor
.
data_ptr
()),
buffer_type
(
output_tensor
.
data_ptr
()),
output_tensor
.
numel
(),
ncclDataTypeEnum
.
from_torch
(
input_tensor
.
dtype
),
ncclRedOpTypeEnum
.
from_torch
(
op
),
self
.
comm
,
cudaStream_t
(
stream
.
cuda_stream
),
)
def
send
(
self
,
tensor
:
torch
.
Tensor
,
dst
:
int
,
stream
=
None
):
def
send
(
self
,
tensor
:
torch
.
Tensor
,
dst
:
int
,
stream
=
None
):
if
self
.
disabled
:
if
self
.
disabled
:
...
@@ -266,6 +310,12 @@ class PyNcclCommunicator:
...
@@ -266,6 +310,12 @@ class PyNcclCommunicator:
def
deregister_comm_window
(
self
,
window
):
def
deregister_comm_window
(
self
,
window
):
return
self
.
nccl
.
ncclCommWindowDeregister
(
self
.
comm
,
window
)
return
self
.
nccl
.
ncclCommWindowDeregister
(
self
.
comm
,
window
)
def
group_start
(
self
):
self
.
nccl
.
ncclGroupStart
()
def
group_end
(
self
):
self
.
nccl
.
ncclGroupEnd
()
@
contextmanager
@
contextmanager
def
change_state
(
def
change_state
(
self
,
enable
:
Optional
[
bool
]
=
None
,
stream
:
Optional
[
torch
.
cuda
.
Stream
]
=
None
self
,
enable
:
Optional
[
bool
]
=
None
,
stream
:
Optional
[
torch
.
cuda
.
Stream
]
=
None
...
...
python/sglang/srt/distributed/device_communicators/pynccl_wrapper.py
View file @
eff4eb3f
...
@@ -206,6 +206,26 @@ class NCCLLibrary:
...
@@ -206,6 +206,26 @@ class NCCLLibrary:
cudaStream_t
,
cudaStream_t
,
],
],
),
),
# ncclResult_t ncclReduce(
# const void* sendbuff, void* recvbuff, size_t count,
# ncclDataType_t datatype, ncclRedOp_t op, int root,
# ncclComm_t comm, cudaStream_t stream);
# note that cudaStream_t is a pointer type, so the last argument
# is a pointer
Function
(
"ncclReduce"
,
ncclResult_t
,
[
buffer_type
,
buffer_type
,
ctypes
.
c_size_t
,
ncclDataType_t
,
ncclRedOp_t
,
ctypes
.
c_int
,
ncclComm_t
,
cudaStream_t
,
],
),
# ncclResult_t ncclReduceScatter(
# ncclResult_t ncclReduceScatter(
# const void* sendbuff, void* recvbuff, size_t count,
# const void* sendbuff, void* recvbuff, size_t count,
# ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm,
# ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm,
...
@@ -278,6 +298,10 @@ class NCCLLibrary:
...
@@ -278,6 +298,10 @@ class NCCLLibrary:
# it is better not to call it at all.
# it is better not to call it at all.
# ncclResult_t ncclCommDestroy(ncclComm_t comm);
# ncclResult_t ncclCommDestroy(ncclComm_t comm);
Function
(
"ncclCommDestroy"
,
ncclResult_t
,
[
ncclComm_t
]),
Function
(
"ncclCommDestroy"
,
ncclResult_t
,
[
ncclComm_t
]),
# ncclResult_t ncclGroupStart();
Function
(
"ncclGroupStart"
,
ncclResult_t
,
[]),
# ncclResult_t ncclGroupEnd();
Function
(
"ncclGroupEnd"
,
ncclResult_t
,
[]),
]
]
exported_functions_symm_mem
=
[
exported_functions_symm_mem
=
[
...
@@ -400,6 +424,28 @@ class NCCLLibrary:
...
@@ -400,6 +424,28 @@ class NCCLLibrary:
)
)
)
)
def
ncclReduce
(
self
,
sendbuff
:
buffer_type
,
recvbuff
:
buffer_type
,
count
:
int
,
datatype
:
int
,
op
:
int
,
root
:
int
,
comm
:
ncclComm_t
,
stream
:
cudaStream_t
,
)
->
None
:
# `datatype` actually should be `ncclDataType_t`
# and `op` should be `ncclRedOp_t`
# both are aliases of `ctypes.c_int`
# when we pass int to a function, it will be converted to `ctypes.c_int`
# by ctypes automatically
self
.
NCCL_CHECK
(
self
.
_funcs
[
"ncclReduce"
](
sendbuff
,
recvbuff
,
count
,
datatype
,
op
,
root
,
comm
,
stream
)
)
def
ncclReduceScatter
(
def
ncclReduceScatter
(
self
,
self
,
sendbuff
:
buffer_type
,
sendbuff
:
buffer_type
,
...
@@ -499,6 +545,12 @@ class NCCLLibrary:
...
@@ -499,6 +545,12 @@ class NCCLLibrary:
def
ncclCommWindowDeregister
(
self
,
comm
:
ncclComm_t
,
window
:
ncclWindow_t
)
->
None
:
def
ncclCommWindowDeregister
(
self
,
comm
:
ncclComm_t
,
window
:
ncclWindow_t
)
->
None
:
self
.
NCCL_CHECK
(
self
.
_funcs
[
"ncclCommWindowDeregister"
](
comm
,
window
))
self
.
NCCL_CHECK
(
self
.
_funcs
[
"ncclCommWindowDeregister"
](
comm
,
window
))
def
ncclGroupStart
(
self
)
->
None
:
self
.
NCCL_CHECK
(
self
.
_funcs
[
"ncclGroupStart"
]())
def
ncclGroupEnd
(
self
)
->
None
:
self
.
NCCL_CHECK
(
self
.
_funcs
[
"ncclGroupEnd"
]())
__all__
=
[
__all__
=
[
"NCCLLibrary"
,
"NCCLLibrary"
,
...
...
python/sglang/srt/distributed/parallel_state.py
View file @
eff4eb3f
...
@@ -583,6 +583,39 @@ class GroupCoordinator:
...
@@ -583,6 +583,39 @@ class GroupCoordinator:
torch
.
distributed
.
reduce_scatter
(
output
,
input_list
,
group
=
self
.
device_group
)
torch
.
distributed
.
reduce_scatter
(
output
,
input_list
,
group
=
self
.
device_group
)
return
output
return
output
def
reduce_scatterv
(
self
,
input_
:
torch
.
Tensor
,
output
:
Optional
[
torch
.
Tensor
]
=
None
,
sizes
:
Optional
[
List
[
int
]]
=
None
,
)
->
torch
.
Tensor
:
world_size
=
self
.
world_size
pynccl_comm
=
self
.
pynccl_comm
with
pynccl_comm
.
change_state
(
enable
=
True
,
stream
=
torch
.
cuda
.
current_stream
()):
assert
(
pynccl_comm
is
not
None
and
not
pynccl_comm
.
disabled
),
"pynccl is required for reduce_scatterv"
if
sizes
is
not
None
:
assert
len
(
sizes
)
==
world_size
assert
input_
.
shape
[
0
]
==
sum
(
sizes
)
chunk_size
=
sizes
[
self
.
rank_in_group
]
else
:
assert
input_
.
shape
[
0
]
%
world_size
==
0
chunk_size
=
input_
.
shape
[
0
]
//
world_size
output_shape
=
(
chunk_size
,)
+
input_
.
shape
[
1
:]
if
output
is
None
:
output
=
torch
.
empty
(
output_shape
,
dtype
=
input_
.
dtype
,
device
=
input_
.
device
)
else
:
assert
output
.
shape
==
output_shape
pynccl_comm
.
reduce_scatter
(
output
,
input_
,
sizes
=
sizes
)
return
output
def
_all_gather_into_tensor
(
self
,
output
:
torch
.
Tensor
,
input
:
torch
.
Tensor
):
def
_all_gather_into_tensor
(
self
,
output
:
torch
.
Tensor
,
input
:
torch
.
Tensor
):
pynccl_comm
=
self
.
pynccl_comm
pynccl_comm
=
self
.
pynccl_comm
if
pynccl_comm
is
not
None
and
not
pynccl_comm
.
disabled
:
if
pynccl_comm
is
not
None
and
not
pynccl_comm
.
disabled
:
...
@@ -673,6 +706,54 @@ class GroupCoordinator:
...
@@ -673,6 +706,54 @@ class GroupCoordinator:
)
)
return
output_tensor
return
output_tensor
def
all_gatherv
(
self
,
input_
:
Union
[
torch
.
Tensor
,
List
[
torch
.
Tensor
]],
sizes
:
Optional
[
List
[
int
]]
=
None
,
)
->
Union
[
torch
.
Tensor
,
List
[
torch
.
Tensor
]]:
"""
Supports varying sizes per rank and input tensor list.
`sizes`: a list of len(world_size) with the number of items per rank to gather.
"""
world_size
=
self
.
world_size
pynccl_comm
=
self
.
pynccl_comm
with
pynccl_comm
.
change_state
(
enable
=
True
,
stream
=
torch
.
cuda
.
current_stream
()):
assert
(
pynccl_comm
is
not
None
and
not
pynccl_comm
.
disabled
),
"pynccl is required for all_gatherv"
def
_all_gather_single
(
input_
:
torch
.
Tensor
,
sizes
:
Optional
[
List
[
int
]]
=
None
):
input_size
=
input_
.
size
()
if
sizes
is
not
None
:
assert
len
(
sizes
)
==
world_size
assert
input_
.
shape
[
0
]
==
sizes
[
self
.
rank_in_group
]
output_size
=
(
sum
(
sizes
),)
+
input_size
[
1
:]
# 'sizes' is not needed if all inputs in the same group have the same shape
if
all
(
s
==
sizes
[
0
]
for
s
in
sizes
):
sizes
=
None
else
:
output_size
=
(
input_size
[
0
]
*
world_size
,)
+
input_size
[
1
:]
# Allocate output tensor.
output_tensor
=
torch
.
empty
(
output_size
,
dtype
=
input_
.
dtype
,
device
=
input_
.
device
)
pynccl_comm
.
all_gather
(
output_tensor
,
input_
,
sizes
=
sizes
)
return
output_tensor
if
isinstance
(
input_
,
torch
.
Tensor
):
return
_all_gather_single
(
input_
,
sizes
)
output_list
=
[]
pynccl_comm
.
group_start
()
for
inp
in
input_
:
output_list
.
append
(
_all_gather_single
(
inp
,
sizes
=
sizes
))
pynccl_comm
.
group_end
()
return
output_list
def
gather
(
def
gather
(
self
,
input_
:
torch
.
Tensor
,
dst
:
int
=
0
,
dim
:
int
=
-
1
self
,
input_
:
torch
.
Tensor
,
dst
:
int
=
0
,
dim
:
int
=
-
1
)
->
Optional
[
torch
.
Tensor
]:
)
->
Optional
[
torch
.
Tensor
]:
...
...
python/sglang/srt/layers/communicator.py
View file @
eff4eb3f
...
@@ -35,7 +35,10 @@ from sglang.srt.layers.dp_attention import (
...
@@ -35,7 +35,10 @@ from sglang.srt.layers.dp_attention import (
get_global_dp_buffer
,
get_global_dp_buffer
,
get_local_dp_buffer
,
get_local_dp_buffer
,
)
)
from
sglang.srt.layers.moe
import
get_moe_a2a_backend
from
sglang.srt.layers.moe
import
(
get_moe_a2a_backend
,
should_use_flashinfer_cutlass_moe_fp4_allgather
,
)
from
sglang.srt.layers.utils
import
is_sm100_supported
from
sglang.srt.layers.utils
import
is_sm100_supported
from
sglang.srt.managers.schedule_batch
import
global_server_args_dict
from
sglang.srt.managers.schedule_batch
import
global_server_args_dict
from
sglang.srt.model_executor.forward_batch_info
import
ForwardBatch
from
sglang.srt.model_executor.forward_batch_info
import
ForwardBatch
...
@@ -112,7 +115,11 @@ class LayerScatterModes:
...
@@ -112,7 +115,11 @@ class LayerScatterModes:
if
context
.
is_layer_sparse
:
if
context
.
is_layer_sparse
:
return
(
return
(
ScatterMode
.
SCATTERED
ScatterMode
.
SCATTERED
if
not
get_moe_a2a_backend
().
is_none
()
if
(
# Token dispatch/combine will be handled outside of LayerCommunicator for these modes.
not
get_moe_a2a_backend
().
is_none
()
or
should_use_flashinfer_cutlass_moe_fp4_allgather
()
)
else
ScatterMode
.
FULL
else
ScatterMode
.
FULL
)
)
else
:
else
:
...
...
python/sglang/srt/layers/dp_attention.py
View file @
eff4eb3f
...
@@ -72,6 +72,7 @@ class _DpGatheredBufferWrapper:
...
@@ -72,6 +72,7 @@ class _DpGatheredBufferWrapper:
_device
:
torch
.
device
_device
:
torch
.
device
_global_dp_buffer_len
:
int
_global_dp_buffer_len
:
int
_local_dp_buffer_len
:
int
_local_dp_buffer_len
:
int
_global_num_tokens
:
Optional
[
List
[
int
]]
@
classmethod
@
classmethod
def
set_metadata
(
cls
,
hidden_size
:
int
,
dtype
:
torch
.
dtype
,
device
:
torch
.
device
):
def
set_metadata
(
cls
,
hidden_size
:
int
,
dtype
:
torch
.
dtype
,
device
:
torch
.
device
):
...
@@ -80,9 +81,15 @@ class _DpGatheredBufferWrapper:
...
@@ -80,9 +81,15 @@ class _DpGatheredBufferWrapper:
cls
.
_device
=
device
cls
.
_device
=
device
@
classmethod
@
classmethod
def
set_dp_buffer_len
(
cls
,
global_dp_buffer_len
:
int
,
local_dp_buffer_len
:
int
):
def
set_dp_buffer_len
(
cls
,
global_dp_buffer_len
:
int
,
local_dp_buffer_len
:
int
,
global_num_tokens
:
Optional
[
List
[
int
]]
=
None
,
):
cls
.
_global_dp_buffer_len
=
global_dp_buffer_len
cls
.
_global_dp_buffer_len
=
global_dp_buffer_len
cls
.
_local_dp_buffer_len
=
local_dp_buffer_len
cls
.
_local_dp_buffer_len
=
local_dp_buffer_len
cls
.
_global_num_tokens
=
global_num_tokens
@
classmethod
@
classmethod
def
get_global_dp_buffer
(
cls
)
->
torch
.
Tensor
:
def
get_global_dp_buffer
(
cls
)
->
torch
.
Tensor
:
...
@@ -108,10 +115,18 @@ class _DpGatheredBufferWrapper:
...
@@ -108,10 +115,18 @@ class _DpGatheredBufferWrapper:
def
get_local_dp_buffer_len
(
cls
)
->
int
:
def
get_local_dp_buffer_len
(
cls
)
->
int
:
return
cls
.
_local_dp_buffer_len
return
cls
.
_local_dp_buffer_len
@
classmethod
def
get_dp_global_num_tokens
(
cls
)
->
List
[
int
]:
return
cls
.
_global_num_tokens
def
set_dp_buffer_len
(
global_dp_buffer_len
:
int
,
local_dp_buffer_len
:
int
):
def
set_dp_buffer_len
(
global_dp_buffer_len
:
int
,
local_dp_buffer_len
:
int
,
global_num_tokens
:
Optional
[
List
[
int
]]
=
None
,
):
_DpGatheredBufferWrapper
.
set_dp_buffer_len
(
_DpGatheredBufferWrapper
.
set_dp_buffer_len
(
global_dp_buffer_len
,
local_dp_buffer_len
global_dp_buffer_len
,
local_dp_buffer_len
,
global_num_tokens
)
)
...
@@ -131,6 +146,10 @@ def get_local_dp_buffer_len() -> int:
...
@@ -131,6 +146,10 @@ def get_local_dp_buffer_len() -> int:
return
_DpGatheredBufferWrapper
.
get_local_dp_buffer_len
()
return
_DpGatheredBufferWrapper
.
get_local_dp_buffer_len
()
def
get_dp_global_num_tokens
()
->
List
[
int
]:
return
_DpGatheredBufferWrapper
.
get_dp_global_num_tokens
()
def
compute_dp_attention_world_info
(
enable_dp_attention
,
tp_rank
,
tp_size
,
dp_size
):
def
compute_dp_attention_world_info
(
enable_dp_attention
,
tp_rank
,
tp_size
,
dp_size
):
if
not
enable_dp_attention
:
if
not
enable_dp_attention
:
return
tp_rank
,
tp_size
,
0
return
tp_rank
,
tp_size
,
0
...
...
python/sglang/srt/layers/logits_processor.py
View file @
eff4eb3f
...
@@ -191,7 +191,11 @@ class LogitsMetadata:
...
@@ -191,7 +191,11 @@ class LogitsMetadata:
else
:
else
:
self
.
global_dp_buffer_len
=
self
.
global_dp_buffer_len
self
.
global_dp_buffer_len
=
self
.
global_dp_buffer_len
set_dp_buffer_len
(
self
.
global_dp_buffer_len
,
self
.
dp_local_num_tokens
)
set_dp_buffer_len
(
self
.
global_dp_buffer_len
,
self
.
dp_local_num_tokens
,
self
.
global_num_tokens_for_logprob_cpu
,
)
class
LogitsProcessor
(
nn
.
Module
):
class
LogitsProcessor
(
nn
.
Module
):
...
...
python/sglang/srt/layers/moe/__init__.py
View file @
eff4eb3f
...
@@ -10,6 +10,7 @@ from sglang.srt.layers.moe.utils import (
...
@@ -10,6 +10,7 @@ from sglang.srt.layers.moe.utils import (
get_tbo_token_distribution_threshold
,
get_tbo_token_distribution_threshold
,
initialize_moe_config
,
initialize_moe_config
,
is_tbo_enabled
,
is_tbo_enabled
,
should_use_flashinfer_cutlass_moe_fp4_allgather
,
should_use_flashinfer_trtllm_moe
,
should_use_flashinfer_trtllm_moe
,
)
)
...
@@ -23,6 +24,7 @@ __all__ = [
...
@@ -23,6 +24,7 @@ __all__ = [
"get_moe_runner_backend"
,
"get_moe_runner_backend"
,
"get_deepep_mode"
,
"get_deepep_mode"
,
"should_use_flashinfer_trtllm_moe"
,
"should_use_flashinfer_trtllm_moe"
,
"should_use_flashinfer_cutlass_moe_fp4_allgather"
,
"is_tbo_enabled"
,
"is_tbo_enabled"
,
"get_tbo_token_distribution_threshold"
,
"get_tbo_token_distribution_threshold"
,
"get_deepep_config"
,
"get_deepep_config"
,
...
...
python/sglang/srt/layers/moe/fused_moe_triton/layer.py
View file @
eff4eb3f
...
@@ -28,6 +28,7 @@ from sglang.srt.layers.quantization.base_config import (
...
@@ -28,6 +28,7 @@ from sglang.srt.layers.quantization.base_config import (
QuantizationConfig
,
QuantizationConfig
,
QuantizeMethodBase
,
QuantizeMethodBase
,
)
)
from
sglang.srt.layers.quantization.modelopt_quant
import
ModelOptNvFp4FusedMoEMethod
from
sglang.srt.layers.quantization.unquant
import
UnquantizedFusedMoEMethod
from
sglang.srt.layers.quantization.unquant
import
UnquantizedFusedMoEMethod
from
sglang.srt.managers.schedule_batch
import
global_server_args_dict
from
sglang.srt.managers.schedule_batch
import
global_server_args_dict
from
sglang.srt.model_loader.weight_utils
import
narrow_padded_param_and_loaded_weight
from
sglang.srt.model_loader.weight_utils
import
narrow_padded_param_and_loaded_weight
...
@@ -621,9 +622,7 @@ class FusedMoE(torch.nn.Module):
...
@@ -621,9 +622,7 @@ class FusedMoE(torch.nn.Module):
if
"ModelOpt"
in
self
.
quant_method
.
__class__
.
__name__
:
if
"ModelOpt"
in
self
.
quant_method
.
__class__
.
__name__
:
# Determine per-tensor weight scale patterns based on variant
# Determine per-tensor weight scale patterns based on variant
is_fp4_variant
=
(
is_fp4_variant
=
isinstance
(
self
.
quant_method
,
ModelOptNvFp4FusedMoEMethod
)
"ModelOptNvFp4FusedMoEMethod"
in
self
.
quant_method
.
__class__
.
__name__
)
# FP4 uses "weight_scale_2" for per-tensor, FP8 uses "weight_scale" for per-tensor
# FP4 uses "weight_scale_2" for per-tensor, FP8 uses "weight_scale" for per-tensor
per_tensor_conditions
=
(
per_tensor_conditions
=
(
...
...
python/sglang/srt/layers/moe/topk.py
View file @
eff4eb3f
...
@@ -327,6 +327,13 @@ class TopK(CustomOp):
...
@@ -327,6 +327,13 @@ class TopK(CustomOp):
expert_location_dispatch_info
=
expert_location_dispatch_info
,
expert_location_dispatch_info
=
expert_location_dispatch_info
,
)
)
def
empty_topk_output
(
self
,
device
:
torch
.
device
)
->
TopKOutput
:
topk
=
self
.
topk_config
.
top_k
-
self
.
topk_config
.
num_fused_shared_experts
topk_weights
=
torch
.
empty
((
0
,
topk
),
dtype
=
torch
.
float32
,
device
=
device
)
topk_idx
=
torch
.
full
((
0
,
topk
),
-
1
,
dtype
=
torch
.
int32
,
device
=
device
)
router_logits
=
torch
.
empty
((
0
,
topk
),
dtype
=
torch
.
float32
,
device
=
device
)
return
StandardTopKOutput
(
topk_weights
,
topk_idx
,
router_logits
)
# ------------------------------- TopK implementation -------------------------------------
# ------------------------------- TopK implementation -------------------------------------
...
...
python/sglang/srt/layers/moe/utils.py
View file @
eff4eb3f
...
@@ -7,6 +7,11 @@ from typing import TYPE_CHECKING, Optional
...
@@ -7,6 +7,11 @@ from typing import TYPE_CHECKING, Optional
from
packaging
import
version
as
pkg_version
from
packaging
import
version
as
pkg_version
from
sglang.srt.distributed.parallel_state
import
get_moe_expert_parallel_world_size
from
sglang.srt.layers.dp_attention
import
(
get_attention_dp_size
,
is_dp_attention_enabled
,
)
from
sglang.srt.utils
import
logger
from
sglang.srt.utils
import
logger
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
...
@@ -99,6 +104,7 @@ DEEPEP_MODE: Optional[DeepEPMode] = None
...
@@ -99,6 +104,7 @@ DEEPEP_MODE: Optional[DeepEPMode] = None
IS_TBO_ENABLED
:
Optional
[
bool
]
=
None
IS_TBO_ENABLED
:
Optional
[
bool
]
=
None
TBO_TOKEN_DISTRIBUTION_THRESHOLD
:
Optional
[
float
]
=
None
TBO_TOKEN_DISTRIBUTION_THRESHOLD
:
Optional
[
float
]
=
None
DEEPEP_CONFIG
:
Optional
[
str
]
=
None
DEEPEP_CONFIG
:
Optional
[
str
]
=
None
DISABLE_FLASHINFER_CUTLASS_MOE_FP4_ALLGATHER
:
Optional
[
bool
]
=
None
def
initialize_moe_config
(
server_args
:
ServerArgs
):
def
initialize_moe_config
(
server_args
:
ServerArgs
):
...
@@ -108,6 +114,7 @@ def initialize_moe_config(server_args: ServerArgs):
...
@@ -108,6 +114,7 @@ def initialize_moe_config(server_args: ServerArgs):
global
DEEPEP_CONFIG
global
DEEPEP_CONFIG
global
IS_TBO_ENABLED
global
IS_TBO_ENABLED
global
TBO_TOKEN_DISTRIBUTION_THRESHOLD
global
TBO_TOKEN_DISTRIBUTION_THRESHOLD
global
DISABLE_FLASHINFER_CUTLASS_MOE_FP4_ALLGATHER
MOE_A2A_BACKEND
=
MoeA2ABackend
(
server_args
.
moe_a2a_backend
)
MOE_A2A_BACKEND
=
MoeA2ABackend
(
server_args
.
moe_a2a_backend
)
MOE_RUNNER_BACKEND
=
MoeRunnerBackend
(
server_args
.
moe_runner_backend
)
MOE_RUNNER_BACKEND
=
MoeRunnerBackend
(
server_args
.
moe_runner_backend
)
...
@@ -115,6 +122,9 @@ def initialize_moe_config(server_args: ServerArgs):
...
@@ -115,6 +122,9 @@ def initialize_moe_config(server_args: ServerArgs):
DEEPEP_CONFIG
=
server_args
.
deepep_config
or
""
DEEPEP_CONFIG
=
server_args
.
deepep_config
or
""
IS_TBO_ENABLED
=
server_args
.
enable_two_batch_overlap
IS_TBO_ENABLED
=
server_args
.
enable_two_batch_overlap
TBO_TOKEN_DISTRIBUTION_THRESHOLD
=
server_args
.
tbo_token_distribution_threshold
TBO_TOKEN_DISTRIBUTION_THRESHOLD
=
server_args
.
tbo_token_distribution_threshold
DISABLE_FLASHINFER_CUTLASS_MOE_FP4_ALLGATHER
=
(
server_args
.
disable_flashinfer_cutlass_moe_fp4_allgather
)
def
get_moe_a2a_backend
()
->
MoeA2ABackend
:
def
get_moe_a2a_backend
()
->
MoeA2ABackend
:
...
@@ -175,3 +185,16 @@ def should_use_flashinfer_trtllm_moe():
...
@@ -175,3 +185,16 @@ def should_use_flashinfer_trtllm_moe():
>=
pkg_version
.
parse
(
"0.2.9rc1"
)
>=
pkg_version
.
parse
(
"0.2.9rc1"
)
)
)
return
result
return
result
@
lru_cache
(
maxsize
=
1
)
def
should_use_flashinfer_cutlass_moe_fp4_allgather
():
"""
Perform FP4 quantize before all-gather for flashinfer cutlass moe to reduce communication cost for high-throughput serving.
"""
return
(
not
DISABLE_FLASHINFER_CUTLASS_MOE_FP4_ALLGATHER
and
get_moe_runner_backend
().
is_flashinfer_cutlass
()
and
is_dp_attention_enabled
()
and
get_moe_expert_parallel_world_size
()
==
get_attention_dp_size
()
)
python/sglang/srt/layers/quantization/modelopt_quant.py
View file @
eff4eb3f
...
@@ -7,7 +7,12 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional
...
@@ -7,7 +7,12 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional
import
torch
import
torch
from
torch.nn.parameter
import
Parameter
from
torch.nn.parameter
import
Parameter
from
sglang.srt.layers.moe
import
should_use_flashinfer_trtllm_moe
from
sglang.srt.distributed
import
get_tp_group
from
sglang.srt.layers.dp_attention
import
get_dp_global_num_tokens
,
get_local_dp_buffer
from
sglang.srt.layers.moe
import
(
should_use_flashinfer_cutlass_moe_fp4_allgather
,
should_use_flashinfer_trtllm_moe
,
)
from
sglang.srt.layers.moe.cutlass_moe_params
import
CutlassMoEParams
,
CutlassMoEType
from
sglang.srt.layers.moe.cutlass_moe_params
import
CutlassMoEParams
,
CutlassMoEType
from
sglang.srt.layers.parameter
import
ModelWeightParameter
,
PerTensorScaleParameter
from
sglang.srt.layers.parameter
import
ModelWeightParameter
,
PerTensorScaleParameter
from
sglang.srt.layers.quantization.base_config
import
(
from
sglang.srt.layers.quantization.base_config
import
(
...
@@ -1176,16 +1181,37 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
...
@@ -1176,16 +1181,37 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
),
"apply_router_weight_on_input is not supported for Flashinfer"
),
"apply_router_weight_on_input is not supported for Flashinfer"
# TRTLLM Cutlass moe takes in activations in BF16/Half/nvfp4 precision
# TRTLLM Cutlass moe takes in activations in BF16/Half/nvfp4 precision
# and fp4 quantized weights loaded from the checkpoint
# and fp4 quantized weights loaded from the checkpoint
topk_weights
,
topk_ids
=
topk_output
.
topk_weights
,
topk_output
.
topk_ids
topk_weights
,
topk_ids
=
topk_output
.
topk_weights
,
topk_output
.
topk_ids
output_dtype
=
x
.
dtype
x_sf
=
None
if
should_use_flashinfer_cutlass_moe_fp4_allgather
():
from
flashinfer
import
fp4_quantize
,
nvfp4_block_scale_interleave
# Quantize before comm, swizzle after.
if
x
.
shape
[
0
]
>
0
:
x
,
x_sf
=
fp4_quantize
(
x
,
layer
.
w13_input_scale_quant
,
is_sf_swizzled_layout
=
False
)
else
:
x_col
=
x
.
shape
[
1
]
x
=
torch
.
zeros
(
0
,
x_col
//
2
,
dtype
=
torch
.
uint8
,
device
=
x
.
device
)
x_sf
=
torch
.
zeros
(
0
,
x_col
//
16
,
dtype
=
torch
.
uint8
,
device
=
x
.
device
)
topk_weights
,
topk_ids
,
x
,
x_sf
=
get_tp_group
().
all_gatherv
(
[
topk_weights
,
topk_ids
,
x
,
x_sf
],
sizes
=
get_dp_global_num_tokens
()
)
x_sf
=
nvfp4_block_scale_interleave
(
x_sf
)
output
=
flashinfer_cutlass_fused_moe
(
output
=
flashinfer_cutlass_fused_moe
(
x
,
input
=
x
,
topk_ids
.
to
(
torch
.
int
),
token_selected_experts
=
topk_ids
.
to
(
torch
.
int
),
topk_weights
,
token_final_scales
=
topk_weights
,
layer
.
w13_weight
.
view
(
torch
.
long
),
fc1_expert_weights
=
layer
.
w13_weight
.
view
(
torch
.
long
),
layer
.
w2_weight
.
view
(
torch
.
long
),
fc2_expert_weights
=
layer
.
w2_weight
.
view
(
torch
.
long
),
x
.
dtype
,
output_dtype
=
output_dtype
,
input_sf
=
x_sf
,
quant_scales
=
[
quant_scales
=
[
layer
.
w13_input_scale_quant
,
layer
.
w13_input_scale_quant
,
layer
.
w13_blockscale_swizzled
.
view
(
torch
.
int32
),
layer
.
w13_blockscale_swizzled
.
view
(
torch
.
int32
),
...
@@ -1202,6 +1228,11 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
...
@@ -1202,6 +1228,11 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
)[
0
]
)[
0
]
if
moe_runner_config
.
routed_scaling_factor
is
not
None
:
if
moe_runner_config
.
routed_scaling_factor
is
not
None
:
output
*=
moe_runner_config
.
routed_scaling_factor
output
*=
moe_runner_config
.
routed_scaling_factor
if
should_use_flashinfer_cutlass_moe_fp4_allgather
():
output
,
global_output
=
get_local_dp_buffer
(),
output
get_tp_group
().
reduce_scatterv
(
global_output
,
output
=
output
,
sizes
=
get_dp_global_num_tokens
()
)
return
output
return
output
from
sglang.srt.layers.moe.cutlass_moe
import
cutlass_moe_fp4
from
sglang.srt.layers.moe.cutlass_moe
import
cutlass_moe_fp4
...
...
python/sglang/srt/managers/schedule_batch.py
View file @
eff4eb3f
...
@@ -84,6 +84,7 @@ GLOBAL_SERVER_ARGS_KEYS = [
...
@@ -84,6 +84,7 @@ GLOBAL_SERVER_ARGS_KEYS = [
"chunked_prefill_size"
,
"chunked_prefill_size"
,
"device"
,
"device"
,
"disable_chunked_prefix_cache"
,
"disable_chunked_prefix_cache"
,
"disable_flashinfer_cutlass_moe_fp4_allgather"
,
"disable_radix_cache"
,
"disable_radix_cache"
,
"enable_dp_lm_head"
,
"enable_dp_lm_head"
,
"enable_flashinfer_allreduce_fusion"
,
"enable_flashinfer_allreduce_fusion"
,
...
...
python/sglang/srt/model_executor/forward_batch_info.py
View file @
eff4eb3f
...
@@ -649,7 +649,7 @@ class ForwardBatch:
...
@@ -649,7 +649,7 @@ class ForwardBatch:
num_tokens
=
global_num_tokens
[
0
]
num_tokens
=
global_num_tokens
[
0
]
self
.
global_dp_buffer_len
=
buffer_len
self
.
global_dp_buffer_len
=
buffer_len
set_dp_buffer_len
(
buffer_len
,
num_tokens
)
set_dp_buffer_len
(
buffer_len
,
num_tokens
,
global_
num_tokens
)
bs
=
self
.
batch_size
bs
=
self
.
batch_size
...
...
python/sglang/srt/models/deepseek_v2.py
View file @
eff4eb3f
...
@@ -60,7 +60,11 @@ from sglang.srt.layers.linear import (
...
@@ -60,7 +60,11 @@ from sglang.srt.layers.linear import (
RowParallelLinear
,
RowParallelLinear
,
)
)
from
sglang.srt.layers.logits_processor
import
LogitsProcessor
from
sglang.srt.layers.logits_processor
import
LogitsProcessor
from
sglang.srt.layers.moe
import
get_deepep_mode
,
get_moe_a2a_backend
from
sglang.srt.layers.moe
import
(
get_deepep_mode
,
get_moe_a2a_backend
,
should_use_flashinfer_cutlass_moe_fp4_allgather
,
)
from
sglang.srt.layers.moe.ep_moe.layer
import
DeepEPMoE
,
get_moe_impl_class
from
sglang.srt.layers.moe.ep_moe.layer
import
DeepEPMoE
,
get_moe_impl_class
from
sglang.srt.layers.moe.fused_moe_triton.layer
import
FusedMoE
from
sglang.srt.layers.moe.fused_moe_triton.layer
import
FusedMoE
from
sglang.srt.layers.moe.topk
import
TopK
from
sglang.srt.layers.moe.topk
import
TopK
...
@@ -343,7 +347,7 @@ class DeepseekV2MoE(nn.Module):
...
@@ -343,7 +347,7 @@ class DeepseekV2MoE(nn.Module):
self
.
shared_experts_weight_block_size
=
None
self
.
shared_experts_weight_block_size
=
None
if
config
.
n_shared_experts
is
not
None
and
self
.
num_fused_shared_experts
==
0
:
if
config
.
n_shared_experts
is
not
None
and
self
.
num_fused_shared_experts
==
0
:
intermediate_size
=
config
.
moe_intermediate_size
*
config
.
n_shared_experts
intermediate_size
=
config
.
moe_intermediate_size
*
config
.
n_shared_experts
# disable tp for shared experts when enable deepep moe
# disable tp for shared experts when enable deepep moe
, or with fp4 allgather
self
.
shared_experts
=
DeepseekV2MLP
(
self
.
shared_experts
=
DeepseekV2MLP
(
hidden_size
=
config
.
hidden_size
,
hidden_size
=
config
.
hidden_size
,
intermediate_size
=
intermediate_size
,
intermediate_size
=
intermediate_size
,
...
@@ -354,6 +358,7 @@ class DeepseekV2MoE(nn.Module):
...
@@ -354,6 +358,7 @@ class DeepseekV2MoE(nn.Module):
**
(
**
(
dict
(
tp_rank
=
0
,
tp_size
=
1
)
dict
(
tp_rank
=
0
,
tp_size
=
1
)
if
get_moe_a2a_backend
().
is_deepep
()
if
get_moe_a2a_backend
().
is_deepep
()
or
should_use_flashinfer_cutlass_moe_fp4_allgather
()
else
{}
else
{}
),
),
)
)
...
@@ -433,14 +438,19 @@ class DeepseekV2MoE(nn.Module):
...
@@ -433,14 +438,19 @@ class DeepseekV2MoE(nn.Module):
if
(
if
(
self
.
alt_stream
is
not
None
self
.
alt_stream
is
not
None
and
self
.
num_fused_shared_experts
==
0
and
self
.
num_fused_shared_experts
==
0
and
hidden_states
.
shape
[
0
]
>
0
and
hidden_states
.
shape
[
0
]
<=
DUAL_STREAM_TOKEN_THRESHOLD
and
hidden_states
.
shape
[
0
]
<=
DUAL_STREAM_TOKEN_THRESHOLD
):
):
return
self
.
forward_normal_dual_stream
(
return
self
.
forward_normal_dual_stream
(
hidden_states
,
should_allreduce_fusion
,
use_reduce_scatter
hidden_states
,
should_allreduce_fusion
,
use_reduce_scatter
,
)
)
else
:
else
:
return
self
.
forward_normal
(
return
self
.
forward_normal
(
hidden_states
,
should_allreduce_fusion
,
use_reduce_scatter
hidden_states
,
should_allreduce_fusion
,
use_reduce_scatter
,
)
)
else
:
else
:
return
self
.
forward_deepep
(
hidden_states
,
forward_batch
)
return
self
.
forward_deepep
(
hidden_states
,
forward_batch
)
...
@@ -471,7 +481,12 @@ class DeepseekV2MoE(nn.Module):
...
@@ -471,7 +481,12 @@ class DeepseekV2MoE(nn.Module):
torch
.
add
(
final_hidden_states
,
shared_output
,
out
=
final_hidden_states_out
)
torch
.
add
(
final_hidden_states
,
shared_output
,
out
=
final_hidden_states_out
)
final_hidden_states
=
final_hidden_states_out
final_hidden_states
=
final_hidden_states_out
sm
.
tag
(
final_hidden_states
)
sm
.
tag
(
final_hidden_states
)
if
self
.
tp_size
>
1
and
not
should_allreduce_fusion
and
not
use_reduce_scatter
:
if
(
self
.
tp_size
>
1
and
not
should_allreduce_fusion
and
not
use_reduce_scatter
and
not
should_use_flashinfer_cutlass_moe_fp4_allgather
()
):
final_hidden_states
=
tensor_model_parallel_all_reduce
(
final_hidden_states
)
final_hidden_states
=
tensor_model_parallel_all_reduce
(
final_hidden_states
)
return
final_hidden_states
return
final_hidden_states
...
@@ -486,10 +501,14 @@ class DeepseekV2MoE(nn.Module):
...
@@ -486,10 +501,14 @@ class DeepseekV2MoE(nn.Module):
):
):
return
self
.
forward_cpu
(
hidden_states
,
should_allreduce_fusion
)
return
self
.
forward_cpu
(
hidden_states
,
should_allreduce_fusion
)
shared_output
=
self
.
_forward_shared_experts
(
hidden_states
)
if
hidden_states
.
shape
[
0
]
>
0
:
# router_logits: (num_tokens, n_experts)
shared_output
=
self
.
_forward_shared_experts
(
hidden_states
)
router_logits
=
self
.
gate
(
hidden_states
)
# router_logits: (num_tokens, n_experts)
topk_output
=
self
.
topk
(
hidden_states
,
router_logits
)
router_logits
=
self
.
gate
(
hidden_states
)
topk_output
=
self
.
topk
(
hidden_states
,
router_logits
)
else
:
shared_output
=
None
topk_output
=
self
.
topk
.
empty_topk_output
(
hidden_states
.
device
)
final_hidden_states
=
self
.
experts
(
hidden_states
,
topk_output
)
final_hidden_states
=
self
.
experts
(
hidden_states
,
topk_output
)
if
not
_is_cuda
and
not
_use_aiter
:
if
not
_is_cuda
and
not
_use_aiter
:
...
@@ -501,7 +520,12 @@ class DeepseekV2MoE(nn.Module):
...
@@ -501,7 +520,12 @@ class DeepseekV2MoE(nn.Module):
torch
.
add
(
final_hidden_states
,
shared_output
,
out
=
final_hidden_states_out
)
torch
.
add
(
final_hidden_states
,
shared_output
,
out
=
final_hidden_states_out
)
final_hidden_states
=
final_hidden_states_out
final_hidden_states
=
final_hidden_states_out
sm
.
tag
(
final_hidden_states
)
sm
.
tag
(
final_hidden_states
)
if
self
.
tp_size
>
1
and
not
should_allreduce_fusion
and
not
use_reduce_scatter
:
if
(
self
.
tp_size
>
1
and
not
should_allreduce_fusion
and
not
use_reduce_scatter
and
not
should_use_flashinfer_cutlass_moe_fp4_allgather
()
):
final_hidden_states
=
tensor_model_parallel_all_reduce
(
final_hidden_states
)
final_hidden_states
=
tensor_model_parallel_all_reduce
(
final_hidden_states
)
return
final_hidden_states
return
final_hidden_states
...
@@ -580,11 +604,8 @@ class DeepseekV2MoE(nn.Module):
...
@@ -580,11 +604,8 @@ class DeepseekV2MoE(nn.Module):
),
),
)
)
else
:
else
:
topk_idx
=
torch
.
full
(
topk_weights
,
topk_idx
,
_
=
self
.
topk
.
empty_topk_output
(
(
0
,
self
.
top_k
),
-
1
,
dtype
=
torch
.
int
,
device
=
hidden_states
.
device
hidden_states
.
device
)
topk_weights
=
torch
.
empty
(
(
0
,
self
.
top_k
),
dtype
=
torch
.
float32
,
device
=
hidden_states
.
device
)
)
final_hidden_states
=
self
.
experts
(
final_hidden_states
=
self
.
experts
(
...
...
python/sglang/srt/operations.py
View file @
eff4eb3f
...
@@ -84,6 +84,7 @@ class _StageExecutor:
...
@@ -84,6 +84,7 @@ class _StageExecutor:
forward_batch
:
ForwardBatch
=
inputs
[
"forward_batch"
]
forward_batch
:
ForwardBatch
=
inputs
[
"forward_batch"
]
self
.
_global_dp_buffer_len
=
forward_batch
.
global_dp_buffer_len
self
.
_global_dp_buffer_len
=
forward_batch
.
global_dp_buffer_len
self
.
_local_dp_buffer_len
=
forward_batch
.
input_ids
.
shape
[
0
]
self
.
_local_dp_buffer_len
=
forward_batch
.
input_ids
.
shape
[
0
]
self
.
_global_num_tokens
=
forward_batch
.
global_num_tokens_cpu
def
next
(
self
):
def
next
(
self
):
assert
not
self
.
done
assert
not
self
.
done
...
@@ -91,7 +92,11 @@ class _StageExecutor:
...
@@ -91,7 +92,11 @@ class _StageExecutor:
stage
=
self
.
_stages
[
self
.
_index
]
stage
=
self
.
_stages
[
self
.
_index
]
if
self
.
_global_dp_buffer_len
is
not
None
:
if
self
.
_global_dp_buffer_len
is
not
None
:
set_dp_buffer_len
(
self
.
_global_dp_buffer_len
,
self
.
_local_dp_buffer_len
)
set_dp_buffer_len
(
self
.
_global_dp_buffer_len
,
self
.
_local_dp_buffer_len
,
self
.
_global_num_tokens
,
)
with
_annotate_region
(
debug_name
=
f
"
{
self
.
_debug_name
}{
self
.
_index
}
"
):
with
_annotate_region
(
debug_name
=
f
"
{
self
.
_debug_name
}{
self
.
_index
}
"
):
for
op
in
stage
:
for
op
in
stage
:
...
...
python/sglang/srt/server_args.py
View file @
eff4eb3f
...
@@ -230,6 +230,7 @@ class ServerArgs:
...
@@ -230,6 +230,7 @@ class ServerArgs:
enable_cudagraph_gc
:
bool
=
False
enable_cudagraph_gc
:
bool
=
False
enable_nccl_nvls
:
bool
=
False
enable_nccl_nvls
:
bool
=
False
enable_symm_mem
:
bool
=
False
enable_symm_mem
:
bool
=
False
disable_flashinfer_cutlass_moe_fp4_allgather
:
bool
=
False
enable_tokenizer_batch_encode
:
bool
=
False
enable_tokenizer_batch_encode
:
bool
=
False
disable_outlines_disk_cache
:
bool
=
False
disable_outlines_disk_cache
:
bool
=
False
disable_custom_all_reduce
:
bool
=
False
disable_custom_all_reduce
:
bool
=
False
...
@@ -1714,6 +1715,11 @@ class ServerArgs:
...
@@ -1714,6 +1715,11 @@ class ServerArgs:
action
=
"store_true"
,
action
=
"store_true"
,
help
=
"Enable NCCL symmetric memory for fast collectives."
,
help
=
"Enable NCCL symmetric memory for fast collectives."
,
)
)
parser
.
add_argument
(
"--disable-flashinfer-cutlass-moe-fp4-allgather"
,
action
=
"store_true"
,
help
=
"Disables quantize before all-gather for flashinfer cutlass moe."
,
)
parser
.
add_argument
(
parser
.
add_argument
(
"--enable-tokenizer-batch-encode"
,
"--enable-tokenizer-batch-encode"
,
action
=
"store_true"
,
action
=
"store_true"
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment