Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
46794958
Unverified
Commit
46794958
authored
Apr 21, 2026
by
Jhao-Ting Chen
Committed by
GitHub
Apr 22, 2026
Browse files
test: add nan/inf clamp regression test for fused_topk_bias (#40553)
Signed-off-by:
Jhao-Ting Chen
<
jhaotingc@nvidia.com
>
parent
6ff8dea0
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
69 additions
and
0 deletions
+69
-0
tests/kernels/moe/test_fused_topk.py
tests/kernels/moe/test_fused_topk.py
+69
-0
No files found.
tests/kernels/moe/test_fused_topk.py
View file @
46794958
...
@@ -202,3 +202,72 @@ def test_fused_topk_nan_inf_clamp(
...
@@ -202,3 +202,72 @@ def test_fused_topk_nan_inf_clamp(
f
"Row
{
row
}
has non-finite weights
{
topk_weights
[
row
].
tolist
()
}
"
f
"Row
{
row
}
has non-finite weights
{
topk_weights
[
row
].
tolist
()
}
"
f
"(bad_value=
{
bad_value
}
, scoring_func=
{
scoring_func
}
)"
f
"(bad_value=
{
bad_value
}
, scoring_func=
{
scoring_func
}
)"
)
)
@pytest.mark.skipif(
    not current_platform.is_cuda(),
    reason="This test is skipped on non-CUDA platform.",
)
@pytest.mark.parametrize("num_experts", [6, 8, 16])
@pytest.mark.parametrize("topk", [3, 4])
@pytest.mark.parametrize("scoring_func", ["softmax", "sigmoid"])
@pytest.mark.parametrize("bad_value", [float("nan"), float("inf")])
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.half, torch.float32])
def test_fused_topk_bias_nan_inf_clamp(
    num_experts: int,
    topk: int,
    scoring_func: str,
    bad_value: float,
    dtype: torch.dtype,
):
    """Guard the bias routing path against NaN/Inf gating logits.

    Mirrors test_fused_topk_nan_inf_clamp, but goes through fused_topk_bias
    (i.e. with an e_score_correction_bias) so that the fix in
    topk_softmax_kernels.cu is exercised for this entry point too.

    Row 0 of the gating logits is left untouched and must agree with the
    torch reference. Every other row is overwritten with ``bad_value`` and
    must still yield ``topk`` distinct expert IDs and finite weights.
    """
    torch.manual_seed(0)

    device = "cuda"
    num_tokens, hidden_size = 4, 1024

    # Keep the randn calls in this order: the values drawn depend on it.
    hidden_states = torch.randn(
        (num_tokens, hidden_size), dtype=dtype, device=device
    )
    e_score_correction_bias = torch.randn(
        (num_experts,), dtype=torch.float32, device=device
    )
    gating_output = torch.randn(
        (num_tokens, num_experts), dtype=dtype, device=device
    )

    # Poison every row except the first one.
    gating_output[1:, :] = bad_value

    # Routing arguments shared by the kernel under test and the reference.
    routing_kwargs = dict(
        e_score_correction_bias=e_score_correction_bias,
        topk=topk,
        renormalize=False,
        scoring_func=scoring_func,
    )

    topk_weights, topk_ids = fused_topk_bias(
        hidden_states=hidden_states,
        gating_output=gating_output,
        **routing_kwargs,
    )

    # The clean row has to match the torch reference.
    clean = slice(None, 1)
    ref_weights, ref_ids = torch_topk(
        gating_output=gating_output[clean],
        **routing_kwargs,
    )
    torch.testing.assert_close(
        ref_weights.to(torch.float32), topk_weights[clean], atol=1e-2, rtol=1e-2
    )
    torch.testing.assert_close(
        ref_ids.to(torch.int32), topk_ids[clean], atol=0, rtol=0
    )

    # Poisoned rows: expert IDs must not repeat, and no NaN/Inf weight may
    # leak through to the downstream MoE kernels.
    row = 1
    while row < num_tokens:
        row_ids = topk_ids[row]
        distinct_ids = row_ids.unique().numel()
        assert distinct_ids == topk, (
            f"Row {row} has duplicate expert IDs {row_ids.tolist()} "
            f"(bad_value={bad_value}, scoring_func={scoring_func})"
        )
        assert torch.isfinite(topk_weights[row]).all(), (
            f"Row {row} has non-finite weights {topk_weights[row].tolist()} "
            f"(bad_value={bad_value}, scoring_func={scoring_func})"
        )
        row += 1
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment