Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f9a40871
Unverified
Commit
f9a40871
authored
Nov 11, 2025
by
Michael Goin
Committed by
GitHub
Nov 11, 2025
Browse files
Remove weight_scale.T special case for SM90 Block FP8 CUTLASS kernel (#28431)
Signed-off-by:
mgoin
<
mgoin64@gmail.com
>
parent
287bbbeb
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
36 additions
and
36 deletions
+36
-36
benchmarks/kernels/bench_block_fp8_gemm.py
benchmarks/kernels/bench_block_fp8_gemm.py
+29
-14
csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh
...8a8/cutlass/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh
+2
-1
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
...compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
+1
-1
vllm/model_executor/layers/quantization/fp8.py
vllm/model_executor/layers/quantization/fp8.py
+1
-1
vllm/model_executor/layers/quantization/utils/fp8_utils.py
vllm/model_executor/layers/quantization/utils/fp8_utils.py
+3
-19
No files found.
benchmarks/kernels/bench_block_fp8_gemm.py
View file @
f9a40871
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
os
# Disable DeepGEMM for this benchmark to use CUTLASS
os
.
environ
[
"VLLM_USE_DEEP_GEMM"
]
=
"0"
import
torch
from
vllm.model_executor.layers.quantization.utils.fp8_utils
import
(
apply_w8a8_block_fp8_linear
,
W8A8BlockFp8LinearOp
,
)
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
GroupShape
,
)
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
CUTLASS_BLOCK_FP8_SUPPORTED
,
...
...
@@ -39,13 +47,14 @@ def build_w8a8_block_fp8_runner(M, N, K, block_size, device, use_cutlass):
fp8_info
=
torch
.
finfo
(
torch
.
float8_e4m3fn
)
fp8_max
,
fp8_min
=
fp8_info
.
max
,
fp8_info
.
min
# Create random
FP8
tensor
s
# Create random
input
tensor
(bfloat16, will be quantized by W8A8BlockFp8LinearOp)
A_ref
=
(
torch
.
rand
(
M
,
K
,
dtype
=
torch
.
bfloat16
,
device
=
device
)
-
0.5
)
*
2
*
fp8_max
# Create quantized weight tensor
B_ref
=
(
torch
.
rand
(
N
,
K
,
dtype
=
torch
.
bfloat16
,
device
=
device
)
-
0.5
)
*
2
*
fp8_max
B
=
B_ref
.
clamp
(
min
=
fp8_min
,
max
=
fp8_max
).
to
(
torch
.
float8_e4m3fn
)
# Create scales
# Create
weight
scales
block_n
,
block_k
=
block_size
[
0
],
block_size
[
1
]
n_tiles
=
(
N
+
block_n
-
1
)
//
block_n
k_tiles
=
(
K
+
block_k
-
1
)
//
block_k
...
...
@@ -55,18 +64,24 @@ def build_w8a8_block_fp8_runner(M, N, K, block_size, device, use_cutlass):
*
factor_for_scale
)
#
SM90 CUTLASS requires row-major format for scales
if
use_cutlass
and
current_platform
.
is_device_capability
(
90
):
Bs
=
Bs
.
T
.
contiguous
()
#
Create W8A8BlockFp8LinearOp instance
weight_group_shape
=
GroupShape
(
block_n
,
block_k
)
act_quant_group_shape
=
GroupShape
(
1
,
block_k
)
# Per-token, per-group quantization
def
run
():
if
use_cutlass
:
return
apply_w8a8_block_fp8_linear
(
A_ref
,
B
,
block_size
,
Bs
,
cutlass_block_fp8_supported
=
True
linear_op
=
W8A8BlockFp8LinearOp
(
weight_group_shape
=
weight_group_shape
,
act_quant_group_shape
=
act_quant_group_shape
,
cutlass_block_fp8_supported
=
use_cutlass
,
use_aiter_and_is_supported
=
False
,
)
else
:
return
apply_w8a8_block_fp8_linear
(
A_ref
,
B
,
block_size
,
Bs
,
cutlass_block_fp8_supported
=
False
def
run
():
return
linear_op
.
apply
(
input
=
A_ref
,
weight
=
B
,
weight_scale
=
Bs
,
input_scale
=
None
,
bias
=
None
,
)
return
run
...
...
csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh
View file @
f9a40871
...
...
@@ -48,7 +48,8 @@ struct cutlass_3x_gemm_fp8_blockwise {
using
ElementBlockScale
=
float
;
using
ScaleConfig
=
cutlass
::
detail
::
Sm90BlockwiseScaleConfig
<
ScaleGranularityM
,
ScaleGranularityN
,
ScaleGranularityK
>
;
ScaleGranularityM
,
ScaleGranularityN
,
ScaleGranularityK
,
cute
::
GMMA
::
Major
::
MN
,
cute
::
GMMA
::
Major
::
K
>
;
using
LayoutSFA
=
decltype
(
ScaleConfig
::
deduce_layoutSFA
());
using
LayoutSFB
=
decltype
(
ScaleConfig
::
deduce_layoutSFB
());
...
...
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
View file @
f9a40871
...
...
@@ -173,7 +173,7 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
layer
.
input_scale
=
None
if
self
.
strategy
==
QuantizationStrategy
.
BLOCK
:
maybe_post_process_fp8_weight_block
(
layer
,
self
.
cutlass_block_fp8_supported
)
maybe_post_process_fp8_weight_block
(
layer
)
def
apply_weights
(
self
,
...
...
vllm/model_executor/layers/quantization/fp8.py
View file @
f9a40871
...
...
@@ -540,7 +540,7 @@ class Fp8LinearMethod(LinearMethodBase):
return
if
self
.
block_quant
:
maybe_post_process_fp8_weight_block
(
layer
,
self
.
cutlass_block_fp8_supported
)
maybe_post_process_fp8_weight_block
(
layer
)
def
apply
(
self
,
...
...
vllm/model_executor/layers/quantization/utils/fp8_utils.py
View file @
f9a40871
...
...
@@ -55,17 +55,13 @@ def cutlass_scaled_mm(
Bs
:
torch
.
Tensor
,
block_size
:
list
[
int
],
output_dtype
:
torch
.
dtype
=
torch
.
float16
,
is_hopper
:
bool
|
None
=
None
,
)
->
torch
.
Tensor
:
if
is_hopper
is
None
:
is_hopper
=
current_platform
.
is_device_capability
(
90
)
return
ops
.
cutlass_scaled_mm
(
A
,
B
.
T
,
out_dtype
=
output_dtype
,
scale_a
=
As
,
# SM90 block FP8 requires row-major scale_b, which we do ahead of time
scale_b
=
Bs
if
block_size
is
not
None
and
is_hopper
else
Bs
.
T
,
scale_b
=
Bs
.
T
,
)
...
...
@@ -130,7 +126,7 @@ def _padded_cutlass(
padded_x_scale
[
0
:
x_scale
.
shape
[
0
],
...].
copy_
(
x_scale
)
output
=
cutlass_scaled_mm
(
padded_qx
,
weight
,
padded_x_scale
,
weight_scale
,
block_size
,
output_dtype
,
True
padded_qx
,
weight
,
padded_x_scale
,
weight_scale
,
block_size
,
output_dtype
)
return
output
[
0
:
qx
.
shape
[
0
],
...]
...
...
@@ -303,7 +299,6 @@ class W8A8BlockFp8LinearOp:
weight_scale
,
list
(
self
.
weight_group_shape
),
input_2d
.
dtype
,
False
,
)
def
_run_aiter
(
...
...
@@ -1125,9 +1120,7 @@ def process_fp8_weight_block_strategy(
return
weight
,
weight_scale
def
maybe_post_process_fp8_weight_block
(
layer
:
torch
.
nn
.
Module
,
cutlass_block_fp8_supported
:
bool
):
def
maybe_post_process_fp8_weight_block
(
layer
:
torch
.
nn
.
Module
):
assert
layer
.
weight_block_size
is
not
None
from
vllm.utils.deep_gemm
import
(
...
...
@@ -1146,15 +1139,6 @@ def maybe_post_process_fp8_weight_block(
requant_weight_ue8m0_inplace
(
layer
.
weight
.
data
,
layer
.
weight_scale
.
data
,
block_sz
)
# SM90 Block FP8 CUTLASS requires row-major weight scales
elif
(
current_platform
.
is_device_capability
(
90
)
and
cutlass_block_fp8_supported
and
not
should_use_deepgemm
):
layer
.
weight_scale
=
torch
.
nn
.
Parameter
(
layer
.
weight_scale
.
data
.
T
.
contiguous
(),
requires_grad
=
False
)
def
expert_weight_is_col_major
(
x
:
torch
.
Tensor
)
->
bool
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment