zhaoyu6 / sglang

Unverified commit a167fd0b, authored Jul 24, 2025 by Xiaoyu Zhang, committed via GitHub on Jul 24, 2025.

[code style] Clean dead triton kernel code in fused_moe and useless vllm_ops import (#8310)
Parent: 2f86f3ad

Showing 3 changed files with 27 additions and 242 deletions:

- python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py (+25 / -224)
- python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py (+2 / -9)
- python/sglang/srt/layers/quantization/utils.py (+0 / -9)
python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py

@@ -53,9 +53,7 @@ elif _is_hip:
         from aiter import moe_sum
     except ImportError:
         raise ImportError("aiter is required when SGLANG_USE_AITER is set to True")
-else:
-    from vllm import _custom_ops as vllm_ops
-    from vllm._custom_ops import scaled_fp8_quant
 
 if _is_cuda or _is_hip:
     from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
@@ -63,9 +61,6 @@ if _is_cuda or _is_hip:
 logger = logging.getLogger(__name__)
 
 padding_size = 128 if bool(int(os.getenv("SGLANG_MOE_PADDING", "0"))) else 0
-enable_moe_align_block_size_triton = bool(
-    int(os.getenv("ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON", "0"))
-)
 
 
 @triton.jit
@@ -533,190 +528,6 @@ def fused_moe_kernel(
     tl.store(c_ptrs, accumulator, mask=c_mask)
 
 
-@triton.jit
-def moe_align_block_size_stage1(
-    topk_ids_ptr,
-    tokens_cnts_ptr,
-    num_experts: tl.constexpr,
-    numel: tl.constexpr,
-    tokens_per_thread: tl.constexpr,
-):
-    pid = tl.program_id(0)
-    start_idx = pid * tokens_per_thread
-    off_c = (pid + 1) * num_experts
-
-    for i in range(tokens_per_thread):
-        if start_idx + i < numel:
-            idx = tl.load(topk_ids_ptr + start_idx + i)
-            token_cnt = tl.load(tokens_cnts_ptr + off_c + idx)
-            tl.store(tokens_cnts_ptr + off_c + idx, token_cnt + 1)
-
-
-@triton.jit
-def moe_align_block_size_stage2(
-    tokens_cnts_ptr,
-    num_experts: tl.constexpr,
-):
-    pid = tl.program_id(0)
-
-    last_cnt = 0
-    for i in range(1, num_experts + 1):
-        token_cnt = tl.load(tokens_cnts_ptr + i * num_experts + pid)
-        last_cnt = last_cnt + token_cnt
-        tl.store(tokens_cnts_ptr + i * num_experts + pid, last_cnt)
-
-
-@triton.jit
-def moe_align_block_size_stage3(
-    total_tokens_post_pad_ptr,
-    tokens_cnts_ptr,
-    cumsum_ptr,
-    num_experts: tl.constexpr,
-    block_size: tl.constexpr,
-):
-    last_cumsum = 0
-    off_cnt = num_experts * num_experts
-    for i in range(1, num_experts + 1):
-        token_cnt = tl.load(tokens_cnts_ptr + off_cnt + i - 1)
-        last_cumsum = last_cumsum + tl.cdiv(token_cnt, block_size) * block_size
-        tl.store(cumsum_ptr + i, last_cumsum)
-    tl.store(total_tokens_post_pad_ptr, last_cumsum)
-
-
-@triton.jit
-def moe_align_block_size_stage4(
-    topk_ids_ptr,
-    sorted_token_ids_ptr,
-    expert_ids_ptr,
-    tokens_cnts_ptr,
-    cumsum_ptr,
-    num_experts: tl.constexpr,
-    block_size: tl.constexpr,
-    numel: tl.constexpr,
-    tokens_per_thread: tl.constexpr,
-):
-    pid = tl.program_id(0)
-    start_idx = tl.load(cumsum_ptr + pid)
-    end_idx = tl.load(cumsum_ptr + pid + 1)
-
-    for i in range(start_idx, end_idx, block_size):
-        tl.store(expert_ids_ptr + i // block_size, pid)
-
-    start_idx = pid * tokens_per_thread
-    off_t = pid * num_experts
-
-    for i in range(start_idx, tl.minimum(start_idx + tokens_per_thread, numel)):
-        expert_id = tl.load(topk_ids_ptr + i)
-        token_cnt = tl.load(tokens_cnts_ptr + off_t + expert_id)
-        rank_post_pad = token_cnt + tl.load(cumsum_ptr + expert_id)
-        tl.store(sorted_token_ids_ptr + rank_post_pad, i)
-        tl.store(tokens_cnts_ptr + off_t + expert_id, token_cnt + 1)
-
-
-def moe_align_block_size_triton(
-    topk_ids: torch.Tensor,
-    num_experts: int,
-    block_size: int,
-    sorted_token_ids: torch.Tensor,
-    expert_ids: torch.Tensor,
-    num_tokens_post_pad: torch.Tensor,
-) -> None:
-    numel = topk_ids.numel()
-    grid = (num_experts,)
-    tokens_cnts = torch.zeros(
-        (num_experts + 1, num_experts), dtype=torch.int32, device=topk_ids.device
-    )
-    cumsum = torch.zeros((num_experts + 1,), dtype=torch.int32, device=topk_ids.device)
-    tokens_per_thread = ceil_div(numel, num_experts)
-
-    moe_align_block_size_stage1[grid](
-        topk_ids,
-        tokens_cnts,
-        num_experts,
-        numel,
-        tokens_per_thread,
-    )
-    moe_align_block_size_stage2[grid](
-        tokens_cnts,
-        num_experts,
-    )
-    moe_align_block_size_stage3[(1,)](
-        num_tokens_post_pad,
-        tokens_cnts,
-        cumsum,
-        num_experts,
-        block_size,
-    )
-    moe_align_block_size_stage4[grid](
-        topk_ids,
-        sorted_token_ids,
-        expert_ids,
-        tokens_cnts,
-        cumsum,
-        num_experts,
-        block_size,
-        numel,
-        tokens_per_thread,
-    )
-
-
-@triton.jit
-def init_sorted_ids_and_cumsum_buffer_kernel(
-    sorted_ids_ptr,
-    cumsum_buffer_ptr,
-    max_num_tokens_padded,
-    topk_ids_numel,
-    num_experts: tl.constexpr,
-    BLOCK_SIZE: tl.constexpr,
-    ALIGNED_NUM_EXPERTS_P1: tl.constexpr,
-):
-    pid = tl.program_id(0)
-    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
-    sorted_ids_blocks = tl.cdiv(max_num_tokens_padded, BLOCK_SIZE)
-
-    if pid < sorted_ids_blocks:
-        mask = offsets < max_num_tokens_padded
-        tl.store(
-            sorted_ids_ptr + offsets,
-            tl.full((BLOCK_SIZE,), topk_ids_numel, dtype=tl.int32),
-            mask=mask,
-        )
-    elif pid == sorted_ids_blocks:
-        offset_e = tl.arange(0, ALIGNED_NUM_EXPERTS_P1)
-        mask_e = offset_e < num_experts + 1
-        tl.store(
-            cumsum_buffer_ptr + offset_e,
-            tl.zeros((ALIGNED_NUM_EXPERTS_P1,), dtype=tl.int32),
-            mask=mask_e,
-        )
-
-
-def init_sorted_ids_and_cumsum_buffer(
-    max_num_tokens_padded: int, topk_ids_numel: int, num_experts: int, device="cuda"
-):
-    sorted_ids = torch.empty((max_num_tokens_padded,), dtype=torch.int32, device=device)
-    cumsum_buffer = torch.empty((num_experts + 1,), dtype=torch.int32, device=device)
-    BLOCK_SIZE = 1024
-    sorted_ids_blocks = triton.cdiv(max_num_tokens_padded, BLOCK_SIZE)
-    grid = (sorted_ids_blocks + 1,)
-    init_sorted_ids_and_cumsum_buffer_kernel[grid](
-        sorted_ids,
-        cumsum_buffer,
-        max_num_tokens_padded,
-        topk_ids_numel,
-        num_experts,
-        BLOCK_SIZE,
-        next_power_of_2(num_experts + 1),
-    )
-    return sorted_ids, cumsum_buffer
-
-
 def moe_align_block_size(
     topk_ids: torch.Tensor, block_size: int, num_experts: int
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
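Note on the deletion above: stages 1–4 implemented a counting sort of topk_ids by expert, padding each expert's token count up to a multiple of block_size (stage 1 counts tokens per program, stage 2 builds per-expert prefix sums, stage 3 turns counts into padded cumulative offsets, stage 4 scatters token indices and block-to-expert ids). The sketch below is not from the repository; it merely restates those semantics in plain PyTorch, on the assumption that the surviving sgl_moe_align_block_size kernel from sgl-kernel produces the same three outputs.

import torch


def moe_align_block_size_reference(topk_ids, block_size, num_experts):
    # Pure-PyTorch restatement of the deleted stage1-4 pipeline (illustrative).
    numel = topk_ids.numel()
    flat = topk_ids.flatten().long()
    dev = topk_ids.device

    counts = torch.bincount(flat, minlength=num_experts)           # stages 1-2
    padded = (counts + block_size - 1) // block_size * block_size  # stage 3
    cumsum = torch.zeros(num_experts + 1, dtype=torch.long, device=dev)
    cumsum[1:] = torch.cumsum(padded, dim=0)

    # Padding slots hold topk_ids.numel(), matching the
    # sorted_ids.fill_(topk_ids.numel()) call in the next hunk.
    sorted_ids = torch.full((int(cumsum[-1]),), numel, dtype=torch.int32, device=dev)
    expert_ids = torch.repeat_interleave(
        torch.arange(num_experts, dtype=torch.int32, device=dev),
        padded // block_size,
    )
    slot = cumsum[:-1].clone()
    for i in range(numel):                                         # stage 4
        e = int(flat[i])
        sorted_ids[int(slot[e])] = i
        slot[e] += 1
    return sorted_ids, expert_ids, cumsum[-1:].to(torch.int32)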
@@ -766,42 +577,32 @@ def moe_align_block_size(
         (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
     )
     num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
-    if enable_moe_align_block_size_triton:
-        sorted_ids.fill_(topk_ids.numel())
-        moe_align_block_size_triton(
-            topk_ids,
-            num_experts,
-            block_size,
-            sorted_ids,
-            expert_ids,
-            num_tokens_post_pad,
-        )
-    else:
-        cumsum_buffer = torch.empty(
-            (num_experts + 1,), dtype=torch.int32, device=topk_ids.device
-        )
-        token_cnts_buffer = torch.empty(
-            (num_experts + 1) * num_experts,
-            dtype=torch.int32,
-            device=topk_ids.device,
-        )
-        # Threshold based on benchmark results
-        fuse_sorted_ids_padding = sorted_ids.shape[0] <= 4096
-        if not fuse_sorted_ids_padding:
-            sorted_ids.fill_(topk_ids.numel())
-
-        sgl_moe_align_block_size(
-            topk_ids,
-            num_experts,
-            block_size,
-            sorted_ids,
-            expert_ids,
-            num_tokens_post_pad,
-            token_cnts_buffer,
-            cumsum_buffer,
-            fuse_sorted_ids_padding,
-        )
+    cumsum_buffer = torch.empty(
+        (num_experts + 1,), dtype=torch.int32, device=topk_ids.device
+    )
+    token_cnts_buffer = torch.empty(
+        (num_experts + 1) * num_experts,
+        dtype=torch.int32,
+        device=topk_ids.device,
+    )
+    # Threshold based on benchmark results
+    fuse_sorted_ids_padding = sorted_ids.shape[0] <= 4096
+    if not fuse_sorted_ids_padding:
+        sorted_ids.fill_(topk_ids.numel())
+
+    sgl_moe_align_block_size(
+        topk_ids,
+        num_experts,
+        block_size,
+        sorted_ids,
+        expert_ids,
+        num_tokens_post_pad,
+        token_cnts_buffer,
+        cumsum_buffer,
+        fuse_sorted_ids_padding,
+    )
     return sorted_ids, expert_ids, num_tokens_post_pad
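The entry point itself is unchanged by this commit; only the dead Triton branch inside it is removed. A hedged usage sketch, assuming a CUDA (or HIP) build with sgl-kernel available and made-up routing values:

import torch

from sglang.srt.layers.moe.fused_moe_triton.fused_moe import moe_align_block_size

# Illustrative routing: 4 tokens, top-2 of 8 experts.
topk_ids = torch.tensor(
    [[0, 3], [3, 5], [2, 0], [7, 1]], dtype=torch.int32, device="cuda"
)

sorted_ids, expert_ids, num_tokens_post_pad = moe_align_block_size(
    topk_ids, block_size=16, num_experts=8
)
# sorted_ids: token indices grouped expert by expert, each expert padded to a
#   multiple of block_size (padding slots hold topk_ids.numel()).
# expert_ids: the expert that owns each block of sorted_ids.
# num_tokens_post_pad: total number of slots after padding.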
python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py

@@ -28,15 +28,6 @@ if TYPE_CHECKING:
     CompressedTensorsConfig,
 )
 
-_is_cuda = is_cuda()
-_is_npu = is_npu()
-_is_cpu_amx_available = cpu_has_amx_support()
-_is_cpu = is_cpu()
-_is_hip = is_hip()
-
-if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available) or _is_hip):
-    from vllm import _custom_ops as vllm_ops
-    from vllm._custom_ops import scaled_fp8_quant
 
 try:
     import vllm
@@ -568,6 +559,8 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
             requires_grad=False,
         )
 
+        from vllm import _custom_ops as vllm_ops
+
         marlin_w13_qweight = vllm_ops.gptq_marlin_moe_repack(
             layer.w13_weight_packed,
             layer.w13_g_idx_sort_indices,
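The two added lines defer the vllm import to the one place it is used, so importing this module no longer requires vllm at all. A minimal sketch of the pattern; the wrapper name below is a placeholder for illustration, not sglang's actual method:

def _marlin_repack(layer):
    # Deferred import: only the Marlin WNA16 path pays the vllm dependency.
    from vllm import _custom_ops as vllm_ops

    return vllm_ops.gptq_marlin_moe_repack(
        layer.w13_weight_packed,
        layer.w13_g_idx_sort_indices,
        # remaining repack arguments elided here, as in the collapsed diff above
    )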
python/sglang/srt/layers/quantization/utils.py

@@ -17,15 +17,6 @@ from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_hip, is_np
 if TYPE_CHECKING:
     from sglang.srt.layers.quantization.base_config import QuantizationConfig
 
-_is_cuda = is_cuda()
-_is_npu = is_npu()
-_is_cpu_amx_available = cpu_has_amx_support()
-_is_cpu = is_cpu()
-_is_hip = is_hip()
-
-if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available) or _is_hip):
-    from vllm._custom_ops import scaled_fp8_quant
-
 def is_layer_skipped(
     prefix: str,
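Since the commit message calls the removed import useless, this module should now import cleanly even where vllm is absent. A hedged smoke check (assumes an sglang install without vllm):

import importlib

# The guarded `from vllm._custom_ops import scaled_fp8_quant` fallback is
# gone, so this import must succeed on any platform without vllm.
mod = importlib.import_module("sglang.srt.layers.quantization.utils")
assert hasattr(mod, "is_layer_skipped")  # helper visible in the diff context
print("imported", mod.__name__, "without vllm")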