Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
2ed8b6b3
Unverified
Commit
2ed8b6b3
authored
Oct 16, 2025
by
Wentao Ye
Committed by
GitHub
Oct 16, 2025
Browse files
[Bug] Fix batch invariant test `has` to `is` (#27032)
Signed-off-by:
yewentao256
<
zhyanwentao@126.com
>
parent
013abde6
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
21 additions
and
74 deletions
+21
-74
tests/v1/generation/test_batch_invariance.py
tests/v1/generation/test_batch_invariance.py
+10
-32
tests/v1/generation/test_rms_norm_batch_invariant.py
tests/v1/generation/test_rms_norm_batch_invariant.py
+11
-42
No files found.
tests/v1/generation/test_batch_invariance.py
View file @
2ed8b6b3
...
@@ -10,6 +10,11 @@ import torch
...
@@ -10,6 +10,11 @@ import torch
from
vllm
import
LLM
,
SamplingParams
from
vllm
import
LLM
,
SamplingParams
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
hopper_only
=
pytest
.
mark
.
skipif
(
not
(
current_platform
.
is_cuda
()
and
current_platform
.
is_device_capability
(
90
)),
reason
=
"Requires CUDA and Hopper (SM90)"
,
)
@
pytest
.
fixture
(
autouse
=
True
)
@
pytest
.
fixture
(
autouse
=
True
)
def
enable_batch_invariant_mode
():
def
enable_batch_invariant_mode
():
...
@@ -66,10 +71,7 @@ def _random_prompt(min_words: int = 1024, max_words: int = 1024 * 2) -> str:
...
@@ -66,10 +71,7 @@ def _random_prompt(min_words: int = 1024, max_words: int = 1024 * 2) -> str:
return
base_prompt
return
base_prompt
@
pytest
.
mark
.
skipif
(
@
hopper_only
not
current_platform
.
has_device_capability
(
90
),
reason
=
"Batch invariance tests only supported on Hopper (SM90)"
,
)
@
pytest
.
mark
.
timeout
(
1000
)
@
pytest
.
mark
.
timeout
(
1000
)
def
test_v1_generation_is_deterministic_across_batch_sizes_with_needle
():
def
test_v1_generation_is_deterministic_across_batch_sizes_with_needle
():
"""
"""
...
@@ -214,14 +216,7 @@ def _extract_step_logprobs(request_output):
...
@@ -214,14 +216,7 @@ def _extract_step_logprobs(request_output):
return
None
,
None
return
None
,
None
@
pytest
.
mark
.
skipif
(
@
hopper_only
not
current_platform
.
has_device_capability
(
90
),
reason
=
"Batch invariance tests only supported on Hopper (SM90)"
,
)
@
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
(),
reason
=
"Requires CUDA to match production inference path."
,
)
@
pytest
.
mark
.
parametrize
(
"backend"
,
[
"FLASH_ATTN"
,
"FLASHINFER"
])
@
pytest
.
mark
.
parametrize
(
"backend"
,
[
"FLASH_ATTN"
,
"FLASHINFER"
])
@
pytest
.
mark
.
forked
@
pytest
.
mark
.
forked
def
test_logprobs_bitwise_batch_invariance_bs1_vs_bsN
(
backend
):
def
test_logprobs_bitwise_batch_invariance_bs1_vs_bsN
(
backend
):
...
@@ -436,10 +431,7 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(backend):
...
@@ -436,10 +431,7 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(backend):
pytest
.
fail
(
msg
)
pytest
.
fail
(
msg
)
@
pytest
.
mark
.
skipif
(
@
hopper_only
not
current_platform
.
has_device_capability
(
90
),
reason
=
"Batch invariance tests only supported on Hopper (SM90)"
,
)
def
test_simple_generation
():
def
test_simple_generation
():
"""
"""
Simple test that runs the model with a basic prompt and prints the output.
Simple test that runs the model with a basic prompt and prints the output.
...
@@ -485,14 +477,7 @@ def test_simple_generation():
...
@@ -485,14 +477,7 @@ def test_simple_generation():
llm
.
shutdown
()
llm
.
shutdown
()
@
pytest
.
mark
.
skipif
(
@
hopper_only
not
current_platform
.
has_device_capability
(
90
),
reason
=
"Batch invariance tests only supported on Hopper (SM90)"
,
)
@
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
(),
reason
=
"Requires CUDA to match production inference path."
,
)
@
pytest
.
mark
.
parametrize
(
"backend"
,
[
"FLASH_ATTN"
,
"FLASHINFER"
])
@
pytest
.
mark
.
parametrize
(
"backend"
,
[
"FLASH_ATTN"
,
"FLASHINFER"
])
@
pytest
.
mark
.
forked
@
pytest
.
mark
.
forked
def
test_logprobs_WITHOUT_batch_invariance_should_FAIL
(
backend
):
def
test_logprobs_WITHOUT_batch_invariance_should_FAIL
(
backend
):
...
@@ -707,14 +692,7 @@ def test_logprobs_WITHOUT_batch_invariance_should_FAIL(backend):
...
@@ -707,14 +692,7 @@ def test_logprobs_WITHOUT_batch_invariance_should_FAIL(backend):
os
.
environ
[
"VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT"
]
=
old_value
os
.
environ
[
"VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT"
]
=
old_value
@
pytest
.
mark
.
skipif
(
@
hopper_only
not
current_platform
.
has_device_capability
(
90
),
reason
=
"Batch invariance tests only supported on Hopper (SM90)"
,
)
@
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
(),
reason
=
"Requires CUDA to match production inference path."
,
)
@
pytest
.
mark
.
parametrize
(
"backend"
,
[
"FLASH_ATTN"
])
@
pytest
.
mark
.
parametrize
(
"backend"
,
[
"FLASH_ATTN"
])
@
pytest
.
mark
.
forked
@
pytest
.
mark
.
forked
def
test_decode_logprobs_match_prefill_logprobs
(
backend
):
def
test_decode_logprobs_match_prefill_logprobs
(
backend
):
...
...
tests/v1/generation/test_rms_norm_batch_invariant.py
View file @
2ed8b6b3
...
@@ -14,14 +14,13 @@ from vllm.model_executor.layers.batch_invariant import rms_norm as triton_rms_no
...
@@ -14,14 +14,13 @@ from vllm.model_executor.layers.batch_invariant import rms_norm as triton_rms_no
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
hopper_only
=
pytest
.
mark
.
skipif
(
@
pytest
.
mark
.
skipif
(
not
(
current_platform
.
is_cuda
()
and
current_platform
.
is_device_capability
(
90
)),
not
current_platform
.
has_device_capability
(
90
),
reason
=
"Requires CUDA and Hopper (SM90)"
,
reason
=
"Batch invariance tests only supported on Hopper (SM90)"
,
)
@
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
(),
reason
=
"Requires CUDA for RMS norm kernels"
)
)
@
hopper_only
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
4
,
16
,
64
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
4
,
16
,
64
])
@
pytest
.
mark
.
parametrize
(
"hidden_size"
,
[
512
,
2048
,
4096
,
8192
])
@
pytest
.
mark
.
parametrize
(
"hidden_size"
,
[
512
,
2048
,
4096
,
8192
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float16
,
torch
.
bfloat16
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float16
,
torch
.
bfloat16
])
...
@@ -70,13 +69,7 @@ def test_rms_norm_batch_invariant_vs_standard(
...
@@ -70,13 +69,7 @@ def test_rms_norm_batch_invariant_vs_standard(
)
)
@
pytest
.
mark
.
skipif
(
@
hopper_only
not
current_platform
.
has_device_capability
(
90
),
reason
=
"Batch invariance tests only supported on Hopper (SM90)"
,
)
@
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
(),
reason
=
"Requires CUDA for RMS norm kernels"
)
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
16
,
128
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
16
,
128
])
@
pytest
.
mark
.
parametrize
(
"seq_len"
,
[
1
,
32
,
512
])
@
pytest
.
mark
.
parametrize
(
"seq_len"
,
[
1
,
32
,
512
])
@
pytest
.
mark
.
parametrize
(
"hidden_size"
,
[
2048
,
4096
])
@
pytest
.
mark
.
parametrize
(
"hidden_size"
,
[
2048
,
4096
])
...
@@ -118,13 +111,7 @@ def test_rms_norm_3d_input(batch_size: int, seq_len: int, hidden_size: int):
...
@@ -118,13 +111,7 @@ def test_rms_norm_3d_input(batch_size: int, seq_len: int, hidden_size: int):
)
)
@
pytest
.
mark
.
skipif
(
@
hopper_only
not
current_platform
.
has_device_capability
(
90
),
reason
=
"Batch invariance tests only supported on Hopper (SM90)"
,
)
@
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
(),
reason
=
"Requires CUDA for RMS norm kernels"
)
def
test_rms_norm_numerical_stability
():
def
test_rms_norm_numerical_stability
():
"""
"""
Test RMS norm numerical stability with extreme values.
Test RMS norm numerical stability with extreme values.
...
@@ -184,13 +171,7 @@ def test_rms_norm_numerical_stability():
...
@@ -184,13 +171,7 @@ def test_rms_norm_numerical_stability():
)
)
@
pytest
.
mark
.
skipif
(
@
hopper_only
not
current_platform
.
has_device_capability
(
90
),
reason
=
"Batch invariance tests only supported on Hopper (SM90)"
,
)
@
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
(),
reason
=
"Requires CUDA for RMS norm kernels"
)
def
test_rms_norm_formula
():
def
test_rms_norm_formula
():
"""
"""
Test that RMS norm follows the correct mathematical formula.
Test that RMS norm follows the correct mathematical formula.
...
@@ -223,13 +204,7 @@ def test_rms_norm_formula():
...
@@ -223,13 +204,7 @@ def test_rms_norm_formula():
)
)
@
pytest
.
mark
.
skipif
(
@
hopper_only
not
current_platform
.
has_device_capability
(
90
),
reason
=
"Batch invariance tests only supported on Hopper (SM90)"
,
)
@
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
(),
reason
=
"Requires CUDA for RMS norm kernels"
)
@
pytest
.
mark
.
parametrize
(
"hidden_size"
,
[
128
,
1024
,
4096
,
16384
])
@
pytest
.
mark
.
parametrize
(
"hidden_size"
,
[
128
,
1024
,
4096
,
16384
])
def
test_rms_norm_different_hidden_sizes
(
hidden_size
:
int
):
def
test_rms_norm_different_hidden_sizes
(
hidden_size
:
int
):
"""
"""
...
@@ -267,13 +242,7 @@ def test_rms_norm_different_hidden_sizes(hidden_size: int):
...
@@ -267,13 +242,7 @@ def test_rms_norm_different_hidden_sizes(hidden_size: int):
)
)
@
pytest
.
mark
.
skipif
(
@
hopper_only
not
current_platform
.
has_device_capability
(
90
),
reason
=
"Batch invariance tests only supported on Hopper (SM90)"
,
)
@
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
(),
reason
=
"Requires CUDA for RMS norm kernels"
)
def
test_rms_norm_determinism
():
def
test_rms_norm_determinism
():
"""
"""
Test that batch-invariant RMS norm produces deterministic results.
Test that batch-invariant RMS norm produces deterministic results.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment