Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
245e4f2c
Unverified
Commit
245e4f2c
authored
Oct 18, 2025
by
Wentao Ye
Committed by
GitHub
Oct 18, 2025
Browse files
[Feature] Batch Invariant: Support DeepGEMM and Blackwell (#27127)
Signed-off-by:
yewentao256
<
zhyanwentao@126.com
>
parent
1d165d6d
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
71 additions
and
21 deletions
+71
-21
tests/v1/generation/test_batch_invariance.py
tests/v1/generation/test_batch_invariance.py
+8
-8
tests/v1/generation/test_rms_norm_batch_invariant.py
tests/v1/generation/test_rms_norm_batch_invariant.py
+9
-9
vllm/model_executor/layers/quantization/fp8.py
vllm/model_executor/layers/quantization/fp8.py
+54
-4
No files found.
tests/v1/generation/test_batch_invariance.py
View file @
245e4f2c
...
@@ -10,9 +10,9 @@ import torch
...
@@ -10,9 +10,9 @@ import torch
from
vllm
import
LLM
,
SamplingParams
from
vllm
import
LLM
,
SamplingParams
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
hopper_only
=
pytest
.
mark
.
skipif
(
skip_unsupported
=
pytest
.
mark
.
skipif
(
not
(
current_platform
.
is_cuda
()
and
current_platform
.
i
s_device_capability
(
90
)),
not
(
current_platform
.
is_cuda
()
and
current_platform
.
ha
s_device_capability
(
90
)),
reason
=
"Requires CUDA and Hopper (SM90)"
,
reason
=
"Requires CUDA and
>=
Hopper (SM90)"
,
)
)
...
@@ -74,7 +74,7 @@ def _random_prompt(min_words: int = 1024, max_words: int = 1024 * 2) -> str:
...
@@ -74,7 +74,7 @@ def _random_prompt(min_words: int = 1024, max_words: int = 1024 * 2) -> str:
return
base_prompt
return
base_prompt
@
hopper_only
@
skip_unsupported
@
pytest
.
mark
.
timeout
(
1000
)
@
pytest
.
mark
.
timeout
(
1000
)
def
test_v1_generation_is_deterministic_across_batch_sizes_with_needle
():
def
test_v1_generation_is_deterministic_across_batch_sizes_with_needle
():
"""
"""
...
@@ -219,7 +219,7 @@ def _extract_step_logprobs(request_output):
...
@@ -219,7 +219,7 @@ def _extract_step_logprobs(request_output):
return
None
,
None
return
None
,
None
@
hopper_only
@
skip_unsupported
@
pytest
.
mark
.
parametrize
(
"backend"
,
[
"FLASH_ATTN"
,
"FLASHINFER"
])
@
pytest
.
mark
.
parametrize
(
"backend"
,
[
"FLASH_ATTN"
,
"FLASHINFER"
])
@
pytest
.
mark
.
forked
@
pytest
.
mark
.
forked
def
test_logprobs_bitwise_batch_invariance_bs1_vs_bsN
(
backend
):
def
test_logprobs_bitwise_batch_invariance_bs1_vs_bsN
(
backend
):
...
@@ -434,7 +434,7 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(backend):
...
@@ -434,7 +434,7 @@ def test_logprobs_bitwise_batch_invariance_bs1_vs_bsN(backend):
pytest
.
fail
(
msg
)
pytest
.
fail
(
msg
)
@
hopper_only
@
skip_unsupported
def
test_simple_generation
():
def
test_simple_generation
():
"""
"""
Simple test that runs the model with a basic prompt and prints the output.
Simple test that runs the model with a basic prompt and prints the output.
...
@@ -480,7 +480,7 @@ def test_simple_generation():
...
@@ -480,7 +480,7 @@ def test_simple_generation():
llm
.
shutdown
()
llm
.
shutdown
()
@
hopper_only
@
skip_unsupported
@
pytest
.
mark
.
parametrize
(
"backend"
,
[
"FLASH_ATTN"
,
"FLASHINFER"
])
@
pytest
.
mark
.
parametrize
(
"backend"
,
[
"FLASH_ATTN"
,
"FLASHINFER"
])
@
pytest
.
mark
.
forked
@
pytest
.
mark
.
forked
def
test_logprobs_WITHOUT_batch_invariance_should_FAIL
(
backend
):
def
test_logprobs_WITHOUT_batch_invariance_should_FAIL
(
backend
):
...
@@ -707,7 +707,7 @@ def test_logprobs_WITHOUT_batch_invariance_should_FAIL(backend):
...
@@ -707,7 +707,7 @@ def test_logprobs_WITHOUT_batch_invariance_should_FAIL(backend):
os
.
environ
[
"VLLM_BATCH_INVARIANT"
]
=
old_value
os
.
environ
[
"VLLM_BATCH_INVARIANT"
]
=
old_value
@
hopper_only
@
skip_unsupported
@
pytest
.
mark
.
parametrize
(
"backend"
,
[
"FLASH_ATTN"
])
@
pytest
.
mark
.
parametrize
(
"backend"
,
[
"FLASH_ATTN"
])
@
pytest
.
mark
.
forked
@
pytest
.
mark
.
forked
def
test_decode_logprobs_match_prefill_logprobs
(
backend
):
def
test_decode_logprobs_match_prefill_logprobs
(
backend
):
...
...
tests/v1/generation/test_rms_norm_batch_invariant.py
View file @
245e4f2c
...
@@ -14,13 +14,13 @@ from vllm.model_executor.layers.batch_invariant import rms_norm as triton_rms_no
...
@@ -14,13 +14,13 @@ from vllm.model_executor.layers.batch_invariant import rms_norm as triton_rms_no
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
hopper_only
=
pytest
.
mark
.
skipif
(
skip_unsupported
=
pytest
.
mark
.
skipif
(
not
(
current_platform
.
is_cuda
()
and
current_platform
.
i
s_device_capability
(
90
)),
not
(
current_platform
.
is_cuda
()
and
current_platform
.
ha
s_device_capability
(
90
)),
reason
=
"Requires CUDA and Hopper (SM90)"
,
reason
=
"Requires CUDA and
>=
Hopper (SM90)"
,
)
)
@
hopper_only
@
skip_unsupported
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
4
,
16
,
64
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
4
,
16
,
64
])
@
pytest
.
mark
.
parametrize
(
"hidden_size"
,
[
512
,
2048
,
4096
,
8192
])
@
pytest
.
mark
.
parametrize
(
"hidden_size"
,
[
512
,
2048
,
4096
,
8192
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float16
,
torch
.
bfloat16
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float16
,
torch
.
bfloat16
])
...
@@ -69,7 +69,7 @@ def test_rms_norm_batch_invariant_vs_standard(
...
@@ -69,7 +69,7 @@ def test_rms_norm_batch_invariant_vs_standard(
)
)
@
hopper_only
@
skip_unsupported
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
16
,
128
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
16
,
128
])
@
pytest
.
mark
.
parametrize
(
"seq_len"
,
[
1
,
32
,
512
])
@
pytest
.
mark
.
parametrize
(
"seq_len"
,
[
1
,
32
,
512
])
@
pytest
.
mark
.
parametrize
(
"hidden_size"
,
[
2048
,
4096
])
@
pytest
.
mark
.
parametrize
(
"hidden_size"
,
[
2048
,
4096
])
...
@@ -111,7 +111,7 @@ def test_rms_norm_3d_input(batch_size: int, seq_len: int, hidden_size: int):
...
@@ -111,7 +111,7 @@ def test_rms_norm_3d_input(batch_size: int, seq_len: int, hidden_size: int):
)
)
@
hopper_only
@
skip_unsupported
def
test_rms_norm_numerical_stability
():
def
test_rms_norm_numerical_stability
():
"""
"""
Test RMS norm numerical stability with extreme values.
Test RMS norm numerical stability with extreme values.
...
@@ -171,7 +171,7 @@ def test_rms_norm_numerical_stability():
...
@@ -171,7 +171,7 @@ def test_rms_norm_numerical_stability():
)
)
@
hopper_only
@
skip_unsupported
def
test_rms_norm_formula
():
def
test_rms_norm_formula
():
"""
"""
Test that RMS norm follows the correct mathematical formula.
Test that RMS norm follows the correct mathematical formula.
...
@@ -204,7 +204,7 @@ def test_rms_norm_formula():
...
@@ -204,7 +204,7 @@ def test_rms_norm_formula():
)
)
@
hopper_only
@
skip_unsupported
@
pytest
.
mark
.
parametrize
(
"hidden_size"
,
[
128
,
1024
,
4096
,
16384
])
@
pytest
.
mark
.
parametrize
(
"hidden_size"
,
[
128
,
1024
,
4096
,
16384
])
def
test_rms_norm_different_hidden_sizes
(
hidden_size
:
int
):
def
test_rms_norm_different_hidden_sizes
(
hidden_size
:
int
):
"""
"""
...
@@ -242,7 +242,7 @@ def test_rms_norm_different_hidden_sizes(hidden_size: int):
...
@@ -242,7 +242,7 @@ def test_rms_norm_different_hidden_sizes(hidden_size: int):
)
)
@
hopper_only
@
skip_unsupported
def
test_rms_norm_determinism
():
def
test_rms_norm_determinism
():
"""
"""
Test that batch-invariant RMS norm produces deterministic results.
Test that batch-invariant RMS norm produces deterministic results.
...
...
vllm/model_executor/layers/quantization/fp8.py
View file @
245e4f2c
...
@@ -41,6 +41,7 @@ from vllm.model_executor.layers.quantization.base_config import (
...
@@ -41,6 +41,7 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig
,
QuantizationConfig
,
QuantizeMethodBase
,
QuantizeMethodBase
,
)
)
from
vllm.model_executor.layers.quantization.input_quant_fp8
import
QuantFP8
from
vllm.model_executor.layers.quantization.kv_cache
import
BaseKVCacheMethod
from
vllm.model_executor.layers.quantization.kv_cache
import
BaseKVCacheMethod
from
vllm.model_executor.layers.quantization.utils.flashinfer_utils
import
(
from
vllm.model_executor.layers.quantization.utils.flashinfer_utils
import
(
FlashinferMoeBackend
,
FlashinferMoeBackend
,
...
@@ -94,9 +95,11 @@ from vllm.platforms import current_platform
...
@@ -94,9 +95,11 @@ from vllm.platforms import current_platform
from
vllm.scalar_type
import
scalar_types
from
vllm.scalar_type
import
scalar_types
from
vllm.utils
import
has_deep_gemm
from
vllm.utils
import
has_deep_gemm
from
vllm.utils.deep_gemm
import
(
from
vllm.utils.deep_gemm
import
(
fp8_gemm_nt
,
get_col_major_tma_aligned_tensor
,
get_col_major_tma_aligned_tensor
,
is_deep_gemm_e8m0_used
,
is_deep_gemm_e8m0_used
,
is_deep_gemm_supported
,
is_deep_gemm_supported
,
should_use_deepgemm_for_fp8_linear
,
)
)
from
vllm.utils.flashinfer
import
has_flashinfer_moe
from
vllm.utils.flashinfer
import
has_flashinfer_moe
...
@@ -539,8 +542,34 @@ class Fp8LinearMethod(LinearMethodBase):
...
@@ -539,8 +542,34 @@ class Fp8LinearMethod(LinearMethodBase):
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
bias
:
torch
.
Tensor
|
None
=
None
,
bias
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
# If batch invariant mode is enabled, dequantize and use BF16 compute
# if batch invariant mode is enabled, prefer DeepGEMM FP8 path
# we will use BF16 dequant when DeepGEMM is not supported.
if
vllm_is_batch_invariant
():
if
vllm_is_batch_invariant
():
if
self
.
block_quant
and
should_use_deepgemm_for_fp8_linear
(
torch
.
bfloat16
,
layer
.
weight
,
None
):
# use group quant consistent with block size across K
assert
self
.
act_q_group_shape
is
not
None
q_input
,
input_scale
=
QuantFP8
(
False
,
self
.
act_q_group_shape
,
column_major_scales
=
True
,
)(
x
)
output_2d
=
torch
.
empty
(
(
q_input
.
shape
[
0
],
layer
.
weight
.
shape
[
0
]),
dtype
=
torch
.
bfloat16
,
device
=
q_input
.
device
,
)
fp8_gemm_nt
(
(
q_input
,
input_scale
),
(
layer
.
weight
,
layer
.
weight_scale
),
output_2d
,
)
if
bias
is
not
None
:
output_2d
=
output_2d
+
bias
return
output_2d
# Dequantize FP8 weights to BF16
# Dequantize FP8 weights to BF16
weight_fp8
=
layer
.
weight
.
to
(
torch
.
bfloat16
)
weight_fp8
=
layer
.
weight
.
to
(
torch
.
bfloat16
)
weight_scale
=
layer
.
weight_scale
.
to
(
torch
.
bfloat16
)
weight_scale
=
layer
.
weight_scale
.
to
(
torch
.
bfloat16
)
...
@@ -555,8 +584,29 @@ class Fp8LinearMethod(LinearMethodBase):
...
@@ -555,8 +584,29 @@ class Fp8LinearMethod(LinearMethodBase):
N
,
K
=
weight_fp8
.
shape
N
,
K
=
weight_fp8
.
shape
# Scale is stored transposed: [num_blocks_k, num_blocks_n]
# determine expected number of blocks along N and K
# We need to transpose it to [num_blocks_n, num_blocks_k] first
num_blocks_n
=
(
N
+
block_n
-
1
)
//
block_n
num_blocks_k
=
(
K
+
block_k
-
1
)
//
block_k
# scale layout may be [num_blocks_n, num_blocks_k]
# or [num_blocks_k, num_blocks_n] depending on backend
if
weight_scale
.
dim
()
!=
2
:
raise
RuntimeError
(
f
"FP8 block scale must be 2D, got
{
tuple
(
weight_scale
.
shape
)
}
"
)
scale_rows
,
scale_cols
=
weight_scale
.
shape
if
(
scale_rows
,
scale_cols
)
==
(
num_blocks_k
,
num_blocks_n
):
if
num_blocks_n
==
num_blocks_k
:
# ambiguous square case, warn and skip transpose
logger
.
warning
(
"Batch-invariant FP8: square block-scale %dx%d; "
"skipping transpose to avoid misorientation."
,
scale_rows
,
scale_cols
,
)
else
:
# clear KN -> transpose to NK
weight_scale
=
weight_scale
.
t
()
weight_scale
=
weight_scale
.
t
()
# Expand scale to match weight dimensions
# Expand scale to match weight dimensions
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment