Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
53d7f1f6
Unverified
Commit
53d7f1f6
authored
Nov 25, 2025
by
Xin Yang
Committed by
GitHub
Nov 26, 2025
Browse files
[Kernel] Use pre-allocated output buffer for triton kernel fused_experts (#29219)
Signed-off-by:
Xin Yang
<
xyangx@amazon.com
>
parent
c5ee4303
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
73 additions
and
11 deletions
+73
-11
vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py
...l_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py
+73
-11
No files found.
vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py
View file @
53d7f1f6
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
torch
import
vllm.model_executor.layers.fused_moe.modular_kernel
as
mk
...
...
@@ -12,6 +13,7 @@ from vllm.model_executor.layers.fused_moe.config import (
from
vllm.model_executor.layers.fused_moe.topk_weight_and_reduce
import
(
TopKWeightAndReduceNoOP
,
)
from
vllm.model_executor.layers.fused_moe.utils
import
_resize_cache
from
vllm.triton_utils
import
tl
,
triton
from
vllm.utils.import_utils
import
has_triton_kernels
...
...
@@ -88,14 +90,17 @@ def triton_kernel_moe_forward(
gating_output
,
topk
,
sm_first
=
not
renormalize
)
output
=
torch
.
empty_like
(
hidden_states
)
return
triton_kernel_fused_experts
(
None
,
output
,
hidden_states
,
w1
,
w2
,
routing_data
,
gather_idx
,
scatter_idx
,
topk
=
topk
,
activation
=
activation
,
quant_config
=
quant_config
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
...
...
@@ -113,6 +118,7 @@ def triton_kernel_fused_experts(
routing_data
,
# RoutingData
gather_indx
,
# GatherIndx
scatter_indx
,
# ScatterIndx
topk
:
int
,
activation
:
str
=
"silu"
,
quant_config
:
FusedMoEQuantConfig
|
None
=
None
,
swiglu_alpha
:
float
=
1.702
,
...
...
@@ -120,6 +126,7 @@ def triton_kernel_fused_experts(
apply_router_weight_on_input
:
bool
=
False
,
global_num_experts
:
int
=
-
1
,
expert_map
:
torch
.
Tensor
|
None
=
None
,
intermediate_cache
:
torch
.
Tensor
|
None
=
None
,
a1q_scale
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
:
if
quant_config
is
None
:
...
...
@@ -131,14 +138,30 @@ def triton_kernel_fused_experts(
assert
quant_config
.
w2_bias
is
None
or
quant_config
.
w2_bias
.
dtype
==
torch
.
float32
# Shape check, only check non-mxfp4
assert
hidden_states
.
ndim
==
2
assert
hidden_states
.
shape
[
-
1
]
==
w1
.
shape
[
-
2
]
assert
w2
.
shape
[
-
1
]
==
w1
.
shape
[
1
]
batch_dim
=
1
M
,
K
=
hidden_states
.
shape
[
-
2
:]
E
,
_
,
N
=
w1
.
shape
if
global_num_experts
==
-
1
:
global_num_experts
=
E
if
intermediate_cache
is
None
:
intermediate_cache
=
torch
.
empty
(
(
batch_dim
,
M
*
topk
,
N
//
2
),
device
=
hidden_states
.
device
,
dtype
=
hidden_states
.
dtype
,
)
# Add batch_dim to output buffer because matmul_ogs expects 3D output
intermediate_cache
=
_resize_cache
(
intermediate_cache
,
(
batch_dim
,
M
*
topk
,
N
//
2
)
)
output_tensor
=
_resize_cache
(
output_tensor
,
(
batch_dim
,
M
,
K
))
act
=
FusedActivation
(
FnSpecs
(
"swiglu"
,
triton_kernels
.
swiglu
.
swiglu_fn
,
(
"alpha"
,
"limit"
)),
(
swiglu_alpha
,
swiglu_limit
),
...
...
@@ -146,7 +169,7 @@ def triton_kernel_fused_experts(
)
gammas
=
routing_data
.
gate_scal
if
routing_data
else
None
intermediate_cache1
=
matmul_ogs
(
matmul_ogs
(
hidden_states
,
w1
,
quant_config
.
w1_bias
,
...
...
@@ -155,10 +178,11 @@ def triton_kernel_fused_experts(
precision_config
=
quant_config
.
w1_precision
,
gammas
=
gammas
if
apply_router_weight_on_input
else
None
,
fused_activation
=
act
,
y
=
intermediate_cache
,
)
intermediate_cache3
=
matmul_ogs
(
intermediate_cache
1
,
matmul_ogs
(
intermediate_cache
.
view
(
M
*
topk
,
N
//
2
)
,
w2
,
quant_config
.
w2_bias
,
routing_data
,
...
...
@@ -167,7 +191,8 @@ def triton_kernel_fused_experts(
gammas
=
None
if
apply_router_weight_on_input
else
gammas
,
y
=
output_tensor
,
)
return
intermediate_cache3
output_tensor
=
output_tensor
.
view
(
M
,
K
)
return
output_tensor
def
make_routing_data
(
...
...
@@ -221,6 +246,42 @@ class BaseOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
def supports_expert_map(self) -> bool:
    """Report whether an expert map is supported; always returns True."""
    return True
def moe_problem_size(
    self,
    a1: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_ids: torch.Tensor,
) -> tuple[int, int, int, int, int]:
    """
    Derive the MoE problem dimensions (E, M, N, K, topk) from the inputs.

    Arguments:
    - a1: The hidden states, input to the MoE layer. Must be 2-D; its
      leading dimension gives M and its trailing dimension gives K.
    - w1: The first set of expert weights. Must be 3-D; its first
      dimension gives E and its last dimension gives N.
    - w2: The second set of expert weights. Must be 3-D (only its rank
      is checked here).
    - topk_ids: The topk ids. Must be 2-D with the same leading dimension
      as a1; its second dimension gives topk.

    Note: extracting the problem shape from the weight and activation
    tensors is not obvious. It needs to be done this way specifically
    due to subtle issues with particular kernels, e.g. the int4 kernels
    divide the trailing dimension by two, so it's not "correct" to
    extract N or K from the trailing dimension of w1 or w2. Similarly,
    some kernels transpose the weights, so this needs to be kept in mind.

    Note: This implementation covers most cases. However, if experts
    require a specialized implementation, like MarlinExperts, they are free
    to override this function.
    """
    # Both weight tensors must be rank 3 before w1 is unpacked below.
    assert w1.ndim == 3
    assert w2.ndim == 3
    num_experts, _, inter_size = w1.shape

    # K is read before the rank check on a1, mirroring the check order
    # callers may already rely on.
    hidden_size = a1.size(-1)
    assert a1.ndim == 2

    num_tokens = a1.size(0)
    assert topk_ids.size(0) == num_tokens, f"{topk_ids.size(0)} != {a1.size(0)}"

    assert topk_ids.ndim == 2
    experts_per_token = topk_ids.size(1)

    return num_experts, num_tokens, inter_size, hidden_size, experts_per_token
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
    """
    Return a fresh TopKWeightAndReduceNoOP instance.

    Weight application and reduction happens in the fused_experts kernel.
    """
    return TopKWeightAndReduceNoOP()
...
...
@@ -263,8 +324,8 @@ class OAITritonExperts(BaseOAITritonExperts):
expert_tokens_meta
:
mk
.
ExpertTokensMetadata
|
None
,
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...],
tuple
[
int
,
...]]:
# workspace are allocated inside the kernel
workspace1
=
(
M
,
K
)
workspace2
=
(
0
,
0
)
workspace1
=
(
0
,
0
)
workspace2
=
(
M
*
topk
,
N
//
2
)
output
=
(
M
,
K
)
return
(
workspace1
,
workspace2
,
output
)
...
...
@@ -297,20 +358,21 @@ class OAITritonExperts(BaseOAITritonExperts):
topk_ids
,
topk_weights
,
local_num_experts
)
experts_output
=
triton_kernel_fused_experts
(
None
,
topk
=
topk_ids
.
size
(
1
)
triton_kernel_fused_experts
(
output
,
hidden_states
,
w1
,
w2
,
routing_data
,
gather_indx
,
scatter_indx
,
topk
=
topk
,
activation
=
activation
,
quant_config
=
self
.
quant_config
,
apply_router_weight_on_input
=
False
,
global_num_experts
=
local_num_experts
,
expert_map
=
None
,
# applied already
intermediate_cache
=
workspace2
,
a1q_scale
=
a1q_scale
,
)
output
.
copy_
(
experts_output
,
non_blocking
=
True
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment