Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
45232a45
Unverified
Commit
45232a45
authored
Apr 19, 2026
by
TJian
Committed by
GitHub
Apr 19, 2026
Browse files
[FEAT] [Perf] [Gemma4] Fused Gemma4 Routing Function Triton (#39083)
Signed-off-by:
tjtanaa
<
tunjian.tan@embeddedllm.com
>
parent
03ce1c6e
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
180 additions
and
16 deletions
+180
-16
pyproject.toml
pyproject.toml
+1
-0
tests/kernels/moe/test_gemma4router.py
tests/kernels/moe/test_gemma4router.py
+57
-0
vllm/model_executor/models/gemma4.py
vllm/model_executor/models/gemma4.py
+122
-16
No files found.
pyproject.toml
View file @
45232a45
...
...
@@ -170,6 +170,7 @@ eles = "eles"
datas
=
"datas"
ser
=
"ser"
ure
=
"ure"
VALU
=
"VALU"
# Walsh-Hadamard Transform
wht
=
"wht"
WHT
=
"WHT"
...
...
tests/kernels/moe/test_gemma4router.py
0 → 100644
View file @
45232a45
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
pytest
import
torch
from
vllm.model_executor.models.gemma4
import
(
gemma4_fused_routing_kernel_triton
,
gemma4_routing_function_torch
,
)
def sort_by_id(w, ids):
    """Reorder routing weights and expert ids by ascending expert id.

    Two top-k routers may emit the same set of experts in a different order.
    Sorting both results by expert id puts them in a canonical order so they
    can be compared element-wise.

    Args:
        w: Routing weights, same shape as ``ids``.
        ids: Expert ids; sorted along the last dimension.

    Returns:
        ``(weights, ids)``, both permuted identically along the last dimension.
    """
    # One sort yields both the ordered ids and the permutation. Using dim=-1
    # for the gather as well keeps the helper valid for any tensor rank,
    # not only for 2-D [tokens, topk] inputs.
    sorted_ids, order = ids.sort(dim=-1)
    return w.gather(-1, order), sorted_ids
# Gemma4 MoE models have a context length of 250K tokens, so the largest
# num_tokens case covers a full-context batch; 1 and 2 cover the smallest batches.
@pytest.mark.parametrize("num_tokens", [1, 2, 2048, 250000])
@pytest.mark.parametrize("num_experts", [128])  # gemma4 moe experts
@pytest.mark.parametrize("topk", [8])  # gemma4 topk
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.half, torch.float32])
def test_gemma4_routing_kernel_triton(
    num_tokens: int,
    num_experts: int,
    topk: int,
    dtype: torch.dtype,
):
    """Check the fused Triton router against the PyTorch reference router.

    Both implementations receive the same random logits and per-expert scales.
    They must select the same experts for every token and produce routing
    weights that agree within ``atol = rtol = 1e-2``.
    """
    torch.manual_seed(0)
    gating = torch.randn(num_tokens, num_experts, dtype=dtype, device="cuda")
    scales = torch.rand(num_experts, dtype=torch.float32, device="cuda")

    ref_w, ref_ids = gemma4_routing_function_torch(gating, topk, scales)
    tri_w, tri_ids = gemma4_fused_routing_kernel_triton(gating, topk, scales)

    # Sort by expert id — to remove tie-breaking differences
    ref_ws, ref_is = sort_by_id(ref_w, ref_ids)
    tri_ws, tri_is = sort_by_id(tri_w, tri_ids)

    # Rows whose selected expert sets differ. The message is only built when
    # the assertion fails, so indexing bad_rows[0] is safe there.
    bad_rows = (ref_is != tri_is).any(dim=-1).nonzero(as_tuple=True)[0]
    assert bad_rows.numel() == 0, (
        f"{bad_rows.numel()} of {num_tokens} rows selected different experts "
        f"(E={num_experts}, K={topk}, dtype={dtype}); "
        f"first bad row {bad_rows[0].item()}: "
        f"ref_ids={ref_ids[bad_rows[0]].tolist()} "
        f"tri_ids={tri_ids[bad_rows[0]].tolist()}"
    )

    # assert_close reports the number of mismatches and the largest
    # absolute/relative error on failure, replacing the manual max-error print.
    torch.testing.assert_close(tri_ws, ref_ws, atol=1e-2, rtol=1e-2)
vllm/model_executor/models/gemma4.py
View file @
45232a45
...
...
@@ -57,7 +57,9 @@ from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader
,
maybe_remap_kv_scale_name
,
)
from
vllm.platforms
import
current_platform
from
vllm.sequence
import
IntermediateTensors
from
vllm.triton_utils
import
tl
,
triton
from
vllm.v1.attention.backends.utils
import
KVSharingFastPrefillMetadata
from
.interfaces
import
(
...
...
@@ -79,6 +81,120 @@ from .utils import (
logger
=
init_logger
(
__name__
)
@triton.jit
def _gemma4_routing_kernel(
    gating_ptr,
    per_expert_scale_ptr,
    topk_weights_ptr,
    topk_ids_ptr,
    E: tl.constexpr,
    K: tl.constexpr,
    BLOCK_E: tl.constexpr,
):
    """Fused Gemma4 routing: one program processes one row of gating logits.

    For row ``pid`` the kernel sorts the ``E`` logits in descending order,
    keeps the first ``K``, turns them into softmax weights renormalised over
    those ``K`` entries, multiplies each weight by the scale of its expert
    and writes ``K`` ids and ``K`` weights.

    Args:
        gating_ptr: Logits, indexed as ``pid * E + expert`` (row-major, dense).
        per_expert_scale_ptr: Per-expert scale, indexed by expert id.
        topk_weights_ptr: Output weights, indexed as ``pid * K + rank``.
        topk_ids_ptr: Output expert ids, indexed as ``pid * K + rank``.
        E: Number of experts per row.
        K: Number of experts to keep per row.
        BLOCK_E: Lane count of the row vector; lanes ``>= E`` are padding.

    NOTE(review): the padding handling below presumes ``K <= E``; with a
    larger ``K`` padded lanes would fall inside ``top_mask``. Confirm callers
    guarantee this.
    NOTE(review): ``pid * E`` and ``pid * K`` are presumably 32-bit offsets;
    confirm ``T * E`` stays below 2**31 for the intended workloads.
    """
    pid = tl.program_id(0)
    offs_e = tl.arange(0, BLOCK_E)
    valid = offs_e < E
    # Padded lanes read as -inf so they cannot influence the row maximum.
    logits = tl.load(
        gating_ptr + pid * E + offs_e,
        mask=valid,
        other=-float("inf"),
    ).to(tl.float32)
    # Row maximum, subtracted before exponentiation for numerical stability.
    max_l = tl.max(logits, axis=0)
    # Float32 → ascending-sortable bijection.
    # Sign bit clear: flip every bit, so the key is negative as int32 and a
    # larger logit gives a smaller key. Sign bit set: clear only the sign bit,
    # so the key is non-negative and grows with magnitude. Ascending signed
    # key order is therefore descending logit order.
    MIN32 = -2147483648
    logit_bits = logits.to(tl.int32, bitcast=True)
    sign_b = logit_bits >> 31
    key = tl.where(sign_b == 0, logit_bits ^ -1, logit_bits ^ MIN32)
    # Padded lanes get the largest int32 key so they sort to the end.
    key = tl.where(valid, key, 0x7FFFFFFF)
    # Pack (key, expert index) into one int64: key in the high word, index in
    # the low word. Sorting the packed value orders by key first, then by
    # index, and carries the expert id along with its logit.
    sk64 = key.to(tl.int64) & 0x00000000FFFFFFFF
    packed = (sk64 << 32) | offs_e.to(tl.int64)
    sorted_p = tl.sort(packed, descending=False)
    # Unpack every sorted lane at once; no per-k loop is needed.
    all_keys = ((sorted_p >> 32) & 0x00000000FFFFFFFF).to(tl.int32)
    all_ids = (sorted_p & 0x00000000FFFFFFFF).to(tl.int32)
    # Inverse bijection: recover original logit bits.
    # A negative key came from a sign-clear logit (undo the full flip);
    # otherwise restore the sign bit.
    sign_k = all_keys >> 31
    all_bits = tl.where(sign_k < 0, all_keys ^ -1, all_keys ^ MIN32)
    all_logits = all_bits.to(tl.float32, bitcast=True)
    # exp(x - max) for every lane, computed as exp2((x - max) * log2(e)).
    all_raw_exp = tl.math.exp2((all_logits - max_l) * 1.4426950408889634)
    # After the sort the first K lanes hold the top-K experts.
    top_mask = offs_e < K
    # Renormalising over the top-K only needs the sum of their exponentials:
    # the full softmax denominator would cancel in the ratio.
    renorm_raw = tl.sum(tl.where(top_mask, all_raw_exp, 0.0), axis=0)
    # Guard against division by zero when the sum is not positive.
    renorm_raw = tl.where(renorm_raw > 0.0, renorm_raw, 1.0)
    inv_renorm = 1.0 / renorm_raw
    # Gather the scale of each selected expert; lanes outside the top-K are
    # masked off and receive the neutral value 1.0.
    all_scales = tl.load(
        per_expert_scale_ptr + all_ids.to(tl.int64),
        mask=top_mask,
        other=1.0,
    ).to(tl.float32)
    # Final weights for every lane (only the top-K lanes are stored).
    all_weights = (all_raw_exp * inv_renorm * all_scales).to(tl.float32)
    # Two masked vector stores write the K ids and K weights of this row.
    base_off = pid * K + offs_e
    tl.store(topk_ids_ptr + base_off, all_ids, mask=top_mask)
    tl.store(topk_weights_ptr + base_off, all_weights, mask=top_mask)
def gemma4_fused_routing_kernel_triton(
    gating_output: torch.Tensor,
    topk: int,
    per_expert_scale: torch.Tensor,
    num_warps: int = 1,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Launch the fused Gemma4 routing kernel, one program per token row.

    Args:
        gating_output: Router logits of shape ``[T, E]``.
        topk: Number of experts to select per row; must not exceed ``E``.
        per_expert_scale: Per-expert scale with at least ``E`` elements,
            indexed by expert id.
        num_warps: Warp count forwarded to the Triton launch.

    Returns:
        ``(weights, ids)``: float32 weights and int32 expert ids, both of
        shape ``[T, topk]``, on the device of ``gating_output``.

    Raises:
        ValueError: If ``gating_output`` is not 2-D, ``topk`` exceeds the
            number of experts, or ``per_expert_scale`` has fewer than ``E``
            elements.
    """
    if gating_output.dim() != 2:
        raise ValueError(
            "gating_output must be 2-D [num_tokens, num_experts], got shape "
            f"{tuple(gating_output.shape)}"
        )
    T, E = gating_output.shape
    # The kernel stores the first `topk` sorted lanes unconditionally. With
    # topk > E those lanes are padding, which would emit out-of-range expert
    # ids and read the scale vector out of bounds.
    if topk > E:
        raise ValueError(f"topk ({topk}) must not exceed the number of experts ({E})")
    # The kernel indexes the scale vector with expert ids in [0, E).
    if per_expert_scale.numel() < E:
        raise ValueError(
            f"per_expert_scale has {per_expert_scale.numel()} elements, "
            f"expected at least {E}"
        )

    # The kernel addresses both inputs with dense row-major offsets.
    gating_output = gating_output.contiguous()
    per_expert_scale = per_expert_scale.contiguous()

    weights = torch.empty(T, topk, dtype=torch.float32, device=gating_output.device)
    ids = torch.empty(T, topk, dtype=torch.int32, device=gating_output.device)

    # The kernel's lane vector is sized to the next power of two >= E;
    # lanes beyond E are masked inside the kernel.
    BLOCK_E = triton.next_power_of_2(E)

    _gemma4_routing_kernel[(T,)](
        gating_output,
        per_expert_scale,
        weights,
        ids,
        E,
        topk,
        BLOCK_E,
        num_warps=num_warps,
    )
    return weights, ids
def gemma4_routing_function_torch(
    gating_output: torch.Tensor,
    topk: int,
    per_expert_scale: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Pure-PyTorch Gemma4 routing: top-k, renormalised softmax, expert scale.

    Args:
        gating_output: Router logits with experts on the last dimension.
        topk: Number of experts to select per token.
        per_expert_scale: Per-expert scale, indexed by expert id.

    Returns:
        ``(topk_weights, topk_ids)``: float32 weights and int32 expert ids with
        the last dimension of size ``topk``, ordered from the largest logit
        to the smallest.
    """
    _, topk_ids = torch.topk(gating_output, k=topk, dim=-1)
    router_probabilities = torch.nn.functional.softmax(gating_output, dim=-1)

    # Read the selected probabilities directly. Building a one-hot indicator
    # over [tokens, topk, experts] to mask and re-gather the same values costs
    # O(T * K * E) int64 memory without changing the result.
    topk_weights = router_probabilities.gather(-1, topk_ids)

    # Renormalise over the selected experts; keep the divisor at 1 when the
    # selected mass is not positive, to avoid dividing by zero.
    renorm_factor = torch.sum(topk_weights, dim=-1, keepdim=True)
    renorm_factor = torch.where(renorm_factor > 0.0, renorm_factor, 1.0)
    topk_weights = topk_weights / renorm_factor

    # Fold per_expert_scale into routing weights
    expert_scales = per_expert_scale[topk_ids].to(topk_weights.dtype)
    topk_weights = topk_weights * expert_scales
    return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
def
_get_text_config
(
config
):
"""Dereference text_config if config is a nested Gemma4Config.
...
...
@@ -216,22 +332,12 @@ class Gemma4MoE(nn.Module):
topk
:
int
,
renormalize
:
bool
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
_
,
topk_ids
=
torch
.
topk
(
gating_output
,
k
=
topk
,
dim
=-
1
)
router_probabilities
=
torch
.
nn
.
functional
.
softmax
(
gating_output
,
dim
=-
1
)
indicator
=
torch
.
nn
.
functional
.
one_hot
(
topk_ids
,
num_classes
=
gating_output
.
size
(
-
1
)
).
sum
(
dim
=-
2
)
gate_weights
=
indicator
*
router_probabilities
renorm_factor
=
torch
.
sum
(
gate_weights
,
dim
=-
1
,
keepdim
=
True
)
renorm_factor
=
torch
.
where
(
renorm_factor
>
0.0
,
renorm_factor
,
1.0
)
dispatch_weights
=
gate_weights
/
renorm_factor
topk_weights
=
dispatch_weights
.
gather
(
1
,
topk_ids
)
if
current_platform
.
is_cuda_alike
()
or
current_platform
.
is_xpu
():
return
gemma4_fused_routing_kernel_triton
(
gating_output
,
topk
,
per_expert_scale
)
# Fold per_expert_scale into routing weights
expert_scales
=
per_expert_scale
[
topk_ids
].
to
(
topk_weights
.
dtype
)
topk_weights
=
topk_weights
*
expert_scales
return
topk_weights
.
to
(
torch
.
float32
),
topk_ids
.
to
(
torch
.
int32
)
return
gemma4_routing_function_torch
(
gating_output
,
topk
,
per_expert_scale
)
# FusedMoE experts with custom Gemma4 routing
self
.
experts
=
FusedMoE
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment