Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
35d801f1
Unverified
Commit
35d801f1
authored
Nov 10, 2025
by
Wentao Ye
Committed by
GitHub
Nov 11, 2025
Browse files
[Feature] Refactor batch invariant fp8 DeepGEMM (#27606)
Signed-off-by:
yewentao256
<
zhyanwentao@126.com
>
parent
0bf29fad
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
11 additions
and
87 deletions
+11
-87
vllm/model_executor/layers/quantization/fp8.py
vllm/model_executor/layers/quantization/fp8.py
+11
-87
No files found.
vllm/model_executor/layers/quantization/fp8.py
View file @
35d801f1
...
@@ -43,7 +43,6 @@ from vllm.model_executor.layers.quantization.base_config import (
...
@@ -43,7 +43,6 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig
,
QuantizationConfig
,
QuantizeMethodBase
,
QuantizeMethodBase
,
)
)
from
vllm.model_executor.layers.quantization.input_quant_fp8
import
QuantFP8
from
vllm.model_executor.layers.quantization.kv_cache
import
BaseKVCacheMethod
from
vllm.model_executor.layers.quantization.kv_cache
import
BaseKVCacheMethod
from
vllm.model_executor.layers.quantization.utils.flashinfer_utils
import
(
from
vllm.model_executor.layers.quantization.utils.flashinfer_utils
import
(
FlashinferMoeBackend
,
FlashinferMoeBackend
,
...
@@ -95,11 +94,9 @@ from vllm.model_executor.utils import set_weight_attrs
...
@@ -95,11 +94,9 @@ from vllm.model_executor.utils import set_weight_attrs
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.scalar_type
import
scalar_types
from
vllm.scalar_type
import
scalar_types
from
vllm.utils.deep_gemm
import
(
from
vllm.utils.deep_gemm
import
(
fp8_gemm_nt
,
get_col_major_tma_aligned_tensor
,
get_col_major_tma_aligned_tensor
,
is_deep_gemm_e8m0_used
,
is_deep_gemm_e8m0_used
,
is_deep_gemm_supported
,
is_deep_gemm_supported
,
should_use_deepgemm_for_fp8_linear
,
)
)
from
vllm.utils.flashinfer
import
has_flashinfer_moe
from
vllm.utils.flashinfer
import
has_flashinfer_moe
from
vllm.utils.import_utils
import
has_deep_gemm
from
vllm.utils.import_utils
import
has_deep_gemm
...
@@ -554,83 +551,19 @@ class Fp8LinearMethod(LinearMethodBase):
...
@@ -554,83 +551,19 @@ class Fp8LinearMethod(LinearMethodBase):
# if batch invariant mode is enabled, prefer DeepGEMM FP8 path
# if batch invariant mode is enabled, prefer DeepGEMM FP8 path
# we will use BF16 dequant when DeepGEMM is not supported.
# we will use BF16 dequant when DeepGEMM is not supported.
if
vllm_is_batch_invariant
():
if
vllm_is_batch_invariant
():
# Call is_deep_gemm_supported() ahead of time for torch.compile
# dynamo has trouble tracing through
if
self
.
block_quant
and
should_use_deepgemm_for_fp8_linear
(
torch
.
bfloat16
,
layer
.
weight
,
self
.
use_deep_gemm
):
# use group quant consistent with block size across K
assert
self
.
act_q_group_shape
is
not
None
q_input
,
input_scale
=
QuantFP8
(
False
,
self
.
act_q_group_shape
,
column_major_scales
=
True
,
)(
x
)
output_2d
=
torch
.
empty
(
(
q_input
.
shape
[
0
],
layer
.
weight
.
shape
[
0
]),
dtype
=
torch
.
bfloat16
,
device
=
q_input
.
device
,
)
fp8_gemm_nt
(
(
q_input
,
input_scale
),
(
layer
.
weight
,
layer
.
weight_scale
),
output_2d
,
)
if
bias
is
not
None
:
output_2d
=
output_2d
+
bias
return
output_2d
# Dequantize FP8 weights to BF16
weight_fp8
=
layer
.
weight
.
to
(
torch
.
bfloat16
)
weight_scale
=
layer
.
weight_scale
.
to
(
torch
.
bfloat16
)
# Handle different quantization granularities
if
self
.
block_quant
:
if
self
.
block_quant
:
# Block-wise quantization:
# - Weight is NOT transposed, shape is [N, K] (output_size, input_size)
# - Scale has shape [num_blocks_k, num_blocks_n] (TRANSPOSED!)
assert
self
.
weight_block_size
is
not
None
assert
self
.
weight_block_size
is
not
None
block_n
,
block_k
=
self
.
weight_block_size
# Note: order is [N, K]
return
self
.
w8a8_block_fp8_linear
.
apply
(
input
=
x
,
N
,
K
=
weight_fp8
.
shape
weight
=
layer
.
weight
,
weight_scale
=
layer
.
weight_scale
,
# determine expected number of blocks along N and K
input_scale
=
layer
.
input_scale
,
num_blocks_n
=
(
N
+
block_n
-
1
)
//
block_n
bias
=
bias
,
num_blocks_k
=
(
K
+
block_k
-
1
)
//
block_k
# scale layout may be [num_blocks_n, num_blocks_k]
# or [num_blocks_k, num_blocks_n] depending on backend
if
weight_scale
.
dim
()
!=
2
:
raise
RuntimeError
(
f
"FP8 block scale must be 2D, got
{
tuple
(
weight_scale
.
shape
)
}
"
)
scale_rows
,
scale_cols
=
weight_scale
.
shape
if
(
scale_rows
,
scale_cols
)
==
(
num_blocks_k
,
num_blocks_n
):
if
num_blocks_n
==
num_blocks_k
:
# ambiguous square case, warn and skip transpose
logger
.
warning
(
"Batch-invariant FP8: square block-scale %dx%d; "
"skipping transpose to avoid misorientation."
,
scale_rows
,
scale_cols
,
)
)
else
:
else
:
# clear KN -> transpose to NK
# per-tensor/channel: dequant to BF16 and run GEMM
weight_scale
=
weight_scale
.
t
()
weight_fp8
=
layer
.
weight
.
to
(
torch
.
bfloat16
)
weight_scale
=
layer
.
weight_scale
.
to
(
torch
.
bfloat16
)
# Expand scale to match weight dimensions
# scale_expanded should have shape [N, K]
scale_expanded
=
weight_scale
.
repeat_interleave
(
block_n
,
dim
=
0
).
repeat_interleave
(
block_k
,
dim
=
1
)
# Trim to exact weight size (in case of padding)
scale_expanded
=
scale_expanded
[:
N
,
:
K
]
weight_bf16
=
weight_fp8
*
scale_expanded
else
:
# Per-tensor quantization: weight IS transposed to [K, N]
# scale should be scalar or [1] or per-output-channel [N]
if
weight_scale
.
numel
()
==
1
:
if
weight_scale
.
numel
()
==
1
:
# Per-tensor: simple scalar multiplication
# Per-tensor: simple scalar multiplication
weight_bf16
=
weight_fp8
*
weight_scale
weight_bf16
=
weight_fp8
*
weight_scale
...
@@ -649,16 +582,7 @@ class Fp8LinearMethod(LinearMethodBase):
...
@@ -649,16 +582,7 @@ class Fp8LinearMethod(LinearMethodBase):
else
:
else
:
# Fallback
# Fallback
weight_bf16
=
weight_fp8
*
weight_scale
weight_bf16
=
weight_fp8
*
weight_scale
return
torch
.
nn
.
functional
.
linear
(
x
,
weight_bf16
.
t
(),
bias
)
# For block quant, weight is [N, K], for per-tensor it's [K, N]
# F.linear expects weight to be [N, K], so:
if
self
.
block_quant
:
# Already in correct shape [N, K]
output
=
torch
.
nn
.
functional
.
linear
(
x
,
weight_bf16
,
bias
)
else
:
# Need to transpose back: [K, N] -> [N, K]
output
=
torch
.
nn
.
functional
.
linear
(
x
,
weight_bf16
.
t
(),
bias
)
return
output
if
self
.
use_marlin
:
if
self
.
use_marlin
:
return
apply_fp8_marlin_linear
(
return
apply_fp8_marlin_linear
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment