change / sglang · Commits

Commit a167fd0b (Unverified)
Authored Jul 24, 2025 by Xiaoyu Zhang, committed by GitHub on Jul 24, 2025
[code style] Clean dead triton kernel code in fused_moe and useless vllm_ops import (#8310)
Parent: 2f86f3ad

Showing 3 changed files with 27 additions and 242 deletions (+27, -242):
- python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py (+25, -224)
- python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py (+2, -9)
- python/sglang/srt/layers/quantization/utils.py (+0, -9)
python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py (view file @ a167fd0b)
@@ -53,9 +53,7 @@ elif _is_hip:
```python
            from aiter import moe_sum
        except ImportError:
            raise ImportError(
                "aiter is required when SGLANG_USE_AITER is set to True"
            )
    else:
        from vllm import _custom_ops as vllm_ops
        from vllm._custom_ops import scaled_fp8_quant

if _is_cuda or _is_hip:
    from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
```
@@ -63,9 +61,6 @@ if _is_cuda or _is_hip:
```python
logger = logging.getLogger(__name__)

padding_size = 128 if bool(int(os.getenv("SGLANG_MOE_PADDING", "0"))) else 0

enable_moe_align_block_size_triton = bool(
    int(os.getenv("ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON", "0"))
)


@triton.jit
```
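Given this hunk's counts (nine lines reduced to six) and the removal of the enable_moe_align_block_size_triton branch later in this file, the three-line ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON flag is presumably what this hunk drops, leaving only SGLANG_MOE_PADDING. Both flags use the same parsing idiom, restated below as a tiny self-contained sketch (the helper name is illustrative, not from the repository): read the variable as a string, parse it as an integer, coerce to bool, so any non-zero integer enables the feature and a non-integer value raises ValueError.

```python
import os


def env_flag(name: str, default: str = "0") -> bool:
    # "0" (or unset) -> False; any non-zero integer string -> True;
    # a non-integer value raises ValueError, same as bool(int(os.getenv(...))).
    return bool(int(os.getenv(name, default)))


if __name__ == "__main__":
    print(env_flag("SGLANG_MOE_PADDING"), env_flag("ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON"))
```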
@@ -533,190 +528,6 @@ def fused_moe_kernel(
```python
    tl.store(c_ptrs, accumulator, mask=c_mask)


@triton.jit
def moe_align_block_size_stage1(
    topk_ids_ptr,
    tokens_cnts_ptr,
    num_experts: tl.constexpr,
    numel: tl.constexpr,
    tokens_per_thread: tl.constexpr,
):
    pid = tl.program_id(0)
    start_idx = pid * tokens_per_thread
    off_c = (pid + 1) * num_experts
    for i in range(tokens_per_thread):
        if start_idx + i < numel:
            idx = tl.load(topk_ids_ptr + start_idx + i)
            token_cnt = tl.load(tokens_cnts_ptr + off_c + idx)
            tl.store(tokens_cnts_ptr + off_c + idx, token_cnt + 1)


@triton.jit
def moe_align_block_size_stage2(
    tokens_cnts_ptr,
    num_experts: tl.constexpr,
):
    pid = tl.program_id(0)
    last_cnt = 0
    for i in range(1, num_experts + 1):
        token_cnt = tl.load(tokens_cnts_ptr + i * num_experts + pid)
        last_cnt = last_cnt + token_cnt
        tl.store(tokens_cnts_ptr + i * num_experts + pid, last_cnt)


@triton.jit
def moe_align_block_size_stage3(
    total_tokens_post_pad_ptr,
    tokens_cnts_ptr,
    cumsum_ptr,
    num_experts: tl.constexpr,
    block_size: tl.constexpr,
):
    last_cumsum = 0
    off_cnt = num_experts * num_experts
    for i in range(1, num_experts + 1):
        token_cnt = tl.load(tokens_cnts_ptr + off_cnt + i - 1)
        last_cumsum = last_cumsum + tl.cdiv(token_cnt, block_size) * block_size
        tl.store(cumsum_ptr + i, last_cumsum)
    tl.store(total_tokens_post_pad_ptr, last_cumsum)


@triton.jit
def moe_align_block_size_stage4(
    topk_ids_ptr,
    sorted_token_ids_ptr,
    expert_ids_ptr,
    tokens_cnts_ptr,
    cumsum_ptr,
    num_experts: tl.constexpr,
    block_size: tl.constexpr,
    numel: tl.constexpr,
    tokens_per_thread: tl.constexpr,
):
    pid = tl.program_id(0)
    start_idx = tl.load(cumsum_ptr + pid)
    end_idx = tl.load(cumsum_ptr + pid + 1)

    for i in range(start_idx, end_idx, block_size):
        tl.store(expert_ids_ptr + i // block_size, pid)

    start_idx = pid * tokens_per_thread
    off_t = pid * num_experts

    for i in range(start_idx, tl.minimum(start_idx + tokens_per_thread, numel)):
        expert_id = tl.load(topk_ids_ptr + i)
        token_cnt = tl.load(tokens_cnts_ptr + off_t + expert_id)
        rank_post_pad = token_cnt + tl.load(cumsum_ptr + expert_id)
        tl.store(sorted_token_ids_ptr + rank_post_pad, i)
        tl.store(tokens_cnts_ptr + off_t + expert_id, token_cnt + 1)


def moe_align_block_size_triton(
    topk_ids: torch.Tensor,
    num_experts: int,
    block_size: int,
    sorted_token_ids: torch.Tensor,
    expert_ids: torch.Tensor,
    num_tokens_post_pad: torch.Tensor,
) -> None:
    numel = topk_ids.numel()
    grid = (num_experts,)
    tokens_cnts = torch.zeros(
        (num_experts + 1, num_experts), dtype=torch.int32, device=topk_ids.device
    )
    cumsum = torch.zeros((num_experts + 1,), dtype=torch.int32, device=topk_ids.device)
    tokens_per_thread = ceil_div(numel, num_experts)

    moe_align_block_size_stage1[grid](
        topk_ids,
        tokens_cnts,
        num_experts,
        numel,
        tokens_per_thread,
    )
    moe_align_block_size_stage2[grid](
        tokens_cnts,
        num_experts,
    )
    moe_align_block_size_stage3[(1,)](
        num_tokens_post_pad,
        tokens_cnts,
        cumsum,
        num_experts,
        block_size,
    )
    moe_align_block_size_stage4[grid](
        topk_ids,
        sorted_token_ids,
        expert_ids,
        tokens_cnts,
        cumsum,
        num_experts,
        block_size,
        numel,
        tokens_per_thread,
    )


@triton.jit
def init_sorted_ids_and_cumsum_buffer_kernel(
    sorted_ids_ptr,
    cumsum_buffer_ptr,
    max_num_tokens_padded,
    topk_ids_numel,
    num_experts: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    ALIGNED_NUM_EXPERTS_P1: tl.constexpr,
):
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    sorted_ids_blocks = tl.cdiv(max_num_tokens_padded, BLOCK_SIZE)

    if pid < sorted_ids_blocks:
        mask = offsets < max_num_tokens_padded
        tl.store(
            sorted_ids_ptr + offsets,
            tl.full((BLOCK_SIZE,), topk_ids_numel, dtype=tl.int32),
            mask=mask,
        )
    elif pid == sorted_ids_blocks:
        offset_e = tl.arange(0, ALIGNED_NUM_EXPERTS_P1)
        mask_e = offset_e < num_experts + 1
        tl.store(
            cumsum_buffer_ptr + offset_e,
            tl.zeros((ALIGNED_NUM_EXPERTS_P1,), dtype=tl.int32),
            mask=mask_e,
        )


def init_sorted_ids_and_cumsum_buffer(
    max_num_tokens_padded: int, topk_ids_numel: int, num_experts: int, device="cuda"
):
    sorted_ids = torch.empty((max_num_tokens_padded,), dtype=torch.int32, device=device)
    cumsum_buffer = torch.empty((num_experts + 1,), dtype=torch.int32, device=device)
    BLOCK_SIZE = 1024
    sorted_ids_blocks = triton.cdiv(max_num_tokens_padded, BLOCK_SIZE)
    grid = (sorted_ids_blocks + 1,)
    init_sorted_ids_and_cumsum_buffer_kernel[grid](
        sorted_ids,
        cumsum_buffer,
        max_num_tokens_padded,
        topk_ids_numel,
        num_experts,
        BLOCK_SIZE,
        next_power_of_2(num_experts + 1),
    )
    return sorted_ids, cumsum_buffer


def moe_align_block_size(
    topk_ids: torch.Tensor, block_size: int, num_experts: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
```
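Per the hunk header above (190 lines reduced to 6), the moe_align_block_size_stage1 through stage4 kernels, the moe_align_block_size_triton launcher, and the init_sorted_ids_and_cumsum_buffer helpers shown in this hunk are the dead Triton path this commit deletes, which accounts for most of this file's -224. As a reading aid only, here is a rough pure-PyTorch sketch of the layout that pipeline computed; it is not code from the repository, the ordering of token indices inside each expert's segment is not guaranteed to match the kernels, and the int32 dtypes and the numel fill value for padding slots are taken from the surrounding code in this diff.

```python
import torch


def moe_align_block_size_reference(topk_ids: torch.Tensor, block_size: int, num_experts: int):
    """Illustrative restatement of what the removed four-stage kernels produced."""
    flat = topk_ids.reshape(-1).to(torch.int64)
    numel = flat.numel()

    # Stage 1/2: per-expert token counts, then pad each count up to a multiple of block_size.
    counts = torch.bincount(flat, minlength=num_experts)
    padded = ((counts + block_size - 1) // block_size) * block_size

    # Stage 3: prefix sum of padded counts -> per-expert start offsets and the padded total.
    cumsum = torch.zeros(num_experts + 1, dtype=torch.int64)
    cumsum[1:] = torch.cumsum(padded, dim=0)
    num_tokens_post_pad = int(cumsum[-1])

    # Stage 4: scatter assignment indices into their expert's segment; padding slots keep
    # `numel`, matching the sorted_ids.fill_(topk_ids.numel()) call in moe_align_block_size.
    sorted_token_ids = torch.full((num_tokens_post_pad,), numel, dtype=torch.int32)
    expert_ids = torch.empty(num_tokens_post_pad // block_size, dtype=torch.int32)
    fill = cumsum[:-1].tolist()
    for i, e in enumerate(flat.tolist()):
        sorted_token_ids[fill[e]] = i
        fill[e] += 1
    blk = 0
    for e in range(num_experts):
        for _ in range(int(padded[e]) // block_size):
            expert_ids[blk] = e  # one owning expert per block of the padded layout
            blk += 1
    return sorted_token_ids, expert_ids, torch.tensor([num_tokens_post_pad], dtype=torch.int32)
```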
...
...
@@ -766,17 +577,7 @@ def moe_align_block_size(
```python
        (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
    )
    num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)

    if enable_moe_align_block_size_triton:
        sorted_ids.fill_(topk_ids.numel())
        moe_align_block_size_triton(
            topk_ids,
            num_experts,
            block_size,
            sorted_ids,
            expert_ids,
            num_tokens_post_pad,
        )
    else:
        cumsum_buffer = torch.empty(
            (num_experts + 1,), dtype=torch.int32, device=topk_ids.device
        )
```
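The hunk counts here (17 lines down to 7) suggest that the enable_moe_align_block_size_triton branch shown above is what gets removed, so after this commit moe_align_block_size presumably always takes the sgl_kernel-backed else branch. A minimal usage sketch of the public wrapper, based only on the signature visible in this diff (the CUDA device, the example shapes, and the result-variable names are assumptions):

```python
import torch

# Module path inferred from the file being edited in this diff.
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import moe_align_block_size

# 6 tokens routed to their top-2 experts out of 4 -> 12 (token, expert) assignments.
topk_ids = torch.randint(0, 4, (6, 2), dtype=torch.int32, device="cuda")

sorted_token_ids, expert_ids, num_tokens_post_pad = moe_align_block_size(
    topk_ids, block_size=16, num_experts=4
)
# sorted_token_ids groups assignment indices by expert, padded per expert to block_size;
# expert_ids holds one expert id per block; num_tokens_post_pad is the padded total.
print(sorted_token_ids.shape, expert_ids.shape, num_tokens_post_pad)
```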
python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py (view file @ a167fd0b)
@@ -28,15 +28,6 @@ if TYPE_CHECKING:
```python
    CompressedTensorsConfig,
)

_is_cuda = is_cuda()
_is_npu = is_npu()
_is_cpu_amx_available = cpu_has_amx_support()
_is_cpu = is_cpu()
_is_hip = is_hip()

if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available) or _is_hip):
    from vllm import _custom_ops as vllm_ops
    from vllm._custom_ops import scaled_fp8_quant

try:
    import vllm
```
@@ -568,6 +559,8 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
```python
                requires_grad=False,
            )

            from vllm import _custom_ops as vllm_ops

            marlin_w13_qweight = vllm_ops.gptq_marlin_moe_repack(
                layer.w13_weight_packed,
                layer.w13_g_idx_sort_indices,
```
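The two added lines in this file are the vllm_ops import reappearing inside the method that actually calls gptq_marlin_moe_repack rather than at module scope, so vllm only needs to be importable when the Marlin WNA16 MoE repack path runs. A small, generic sketch of that deferred-import pattern (the wrapper names are illustrative, not from the repository; the call is forwarded unchanged because only the first two arguments are visible in this hunk):

```python
import importlib


def _require_vllm_ops():
    # Deferred import: vllm is needed only if this repack path is actually reached.
    try:
        return importlib.import_module("vllm._custom_ops")
    except ImportError as exc:
        raise ImportError("vllm is required for the Marlin WNA16 MoE repack path") from exc


def marlin_moe_repack(*args, **kwargs):
    # Thin illustrative wrapper: resolve the optional backend lazily, then forward the call.
    vllm_ops = _require_vllm_ops()
    return vllm_ops.gptq_marlin_moe_repack(*args, **kwargs)
```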
python/sglang/srt/layers/quantization/utils.py (view file @ a167fd0b)
@@ -17,15 +17,6 @@ from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_hip, is_np
```python
if TYPE_CHECKING:
    from sglang.srt.layers.quantization.base_config import QuantizationConfig

_is_cuda = is_cuda()
_is_npu = is_npu()
_is_cpu_amx_available = cpu_has_amx_support()
_is_cpu = is_cpu()
_is_hip = is_hip()

if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available) or _is_hip):
    from vllm._custom_ops import scaled_fp8_quant


def is_layer_skipped(
    prefix: str,
```
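Here the change is pure deletion (+0, -9): the platform-gated scaled_fp8_quant import presumably had no remaining user in utils.py. Below is a hedged helper, not from the repository, that mechanically checks a module for imported-but-unreferenced names, which is the kind of situation this hunk cleans up (names used only inside string annotations would be missed, so treat the output as a hint, not a verdict).

```python
import ast
import pathlib


def unused_imports(path: str) -> list[str]:
    """Return top-level names that are imported but never referenced in the module."""
    tree = ast.parse(pathlib.Path(path).read_text())
    imported, used = set(), set()
    for node in ast.walk(tree):
        if isinstance(node, (ast.Import, ast.ImportFrom)):
            for alias in node.names:
                # `import a.b` binds `a`; `from x import y as z` binds `z`.
                imported.add(alias.asname or alias.name.split(".")[0])
        elif isinstance(node, ast.Name):
            used.add(node.id)
    return sorted(imported - used)


if __name__ == "__main__":
    print(unused_imports("python/sglang/srt/layers/quantization/utils.py"))
```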