Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
96590097
Commit
96590097
authored
Jan 20, 2026
by
lixh6
Committed by
wanglong3
Jan 21, 2026
Browse files
feat: support fp8-blockwise matmul impl.
parent
43155293
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
41 additions
and
7 deletions
+41
-7
vllm/model_executor/layers/quantization/utils/fp8_utils.py
vllm/model_executor/layers/quantization/utils/fp8_utils.py
+41
-7
No files found.
vllm/model_executor/layers/quantization/utils/fp8_utils.py
View file @
96590097
...
...
@@ -26,6 +26,8 @@ from vllm.triton_utils import tl, triton
from
vllm.utils
import
direct_register_custom_op
from
vllm.utils.deep_gemm
import
(
fp8_gemm_nt
,
is_deep_gemm_e8m0_used
,
should_use_deepgemm_for_fp8_linear
)
from
lmslim
import
quant_ops
from
lmslim.quantize.quant_ops
import
BlockSize
logger
=
init_logger
(
__name__
)
...
...
@@ -98,7 +100,7 @@ if current_platform.is_rocm():
def
dispatch_w8a8_blockscale_func
(
use_cutlass
:
bool
,
use_aiter_and_is_supported
:
bool
use_cutlass
:
bool
,
use_aiter_and_is_supported
:
bool
,
use_blaslt
:
bool
)
->
Callable
[[
torch
.
Tensor
,
torch
.
Tensor
,
...
...
@@ -111,6 +113,8 @@ def dispatch_w8a8_blockscale_func(
return
cutlass_scaled_mm
if
(
use_aiter_and_is_supported
):
return
torch
.
ops
.
vllm
.
rocm_aiter_gemm_w8a8_blockscale
if
use_blaslt
:
return
hipblaslt_w8a8_block_fp8_matmul
return
w8a8_block_fp8_matmul
...
...
@@ -129,6 +133,10 @@ def apply_w8a8_block_fp8_linear(
assert
input_scale
is
None
# View input as 2D matrix for fp8 methods
input_2d
=
input
.
view
(
-
1
,
input
.
shape
[
-
1
])
output_shape
=
[]
if
envs
.
VLLM_W8A8_BACKEND
==
3
:
output_shape
=
[
*
input
.
shape
[:
-
1
],
weight
.
shape
[
-
1
]]
else
:
output_shape
=
[
*
input
.
shape
[:
-
1
],
weight
.
shape
[
0
]]
output_dtype
=
input
.
dtype
...
...
@@ -136,7 +144,6 @@ def apply_w8a8_block_fp8_linear(
input_2d
=
input
.
view
(
-
1
,
input
.
shape
[
-
1
])
output_shape
=
[
*
input
.
shape
[:
-
1
],
weight
.
shape
[
0
]]
q_input
,
x_scale
=
per_token_group_quant_fp8
(
input_2d
,
block_size
[
1
],
...
...
@@ -149,9 +156,8 @@ def apply_w8a8_block_fp8_linear(
if
bias
is
not
None
:
output
+=
bias
return
output
.
to
(
dtype
=
output_dtype
).
view
(
*
output_shape
)
w8a8_blockscale_func
=
dispatch_w8a8_blockscale_func
(
cutlass_block_fp8_supported
,
use_aiter_and_is_supported
)
cutlass_block_fp8_supported
,
use_aiter_and_is_supported
,
envs
.
VLLM_W8A8_BACKEND
==
3
)
if
cutlass_block_fp8_supported
:
num_pad
=
0
if
current_platform
.
is_device_capability
(
90
):
...
...
@@ -195,6 +201,10 @@ def apply_w8a8_block_fp8_linear_fake(
cutlass_block_fp8_supported
:
bool
=
CUTLASS_BLOCK_FP8_SUPPORTED
,
use_aiter_and_is_supported
:
bool
=
False
,
)
->
torch
.
Tensor
:
output_shape
=
[]
if
envs
.
VLLM_W8A8_BACKEND
==
3
:
output_shape
=
[
*
input
.
shape
[:
-
1
],
weight
.
shape
[
-
1
]]
else
:
output_shape
=
[
*
input
.
shape
[:
-
1
],
weight
.
shape
[
0
]]
return
torch
.
empty
(
output_shape
,
dtype
=
input
.
dtype
,
device
=
input
.
device
)
...
...
@@ -581,6 +591,26 @@ def get_w8a8_block_fp8_configs(N: int, K: int, block_n: int,
)
return
None
def hipblaslt_w8a8_block_fp8_matmul(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    block_size: list[int],
    output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
    """Run a block-wise scaled W8A8 matmul through the hipBLASLt backend.

    Thin wrapper around ``quant_ops.hipblaslt_w8a8_blockwise_gemm`` that
    derives the GEMM dimensions from the operands and maps ``block_size``
    onto the ``BlockSize`` enum expected by that kernel.

    Args:
        A: 2-D left operand of shape ``(M, K)``.
        B: 2-D right operand of shape ``(K, N)``; its first dimension must
            equal ``A.shape[1]``.
        As: Scale tensor forwarded with ``A``.
            NOTE(review): presumably per-token-group activation scales --
            confirm against the lmslim kernel contract.
        Bs: Scale tensor forwarded with ``B``.
            NOTE(review): presumably per-block weight scales -- confirm.
        block_size: Quantization block shape. Only ``block_size[0]`` is
            inspected: ``64`` selects ``BlockSize.block_64x64``, anything
            else selects ``BlockSize.block_128x128`` (with a warning when
            the value is not ``128``).
        output_dtype: Dtype requested from the kernel for the result.

    Returns:
        The second element of the tuple returned by
        ``quant_ops.hipblaslt_w8a8_blockwise_gemm``.
    """
    assert A.shape[1] == B.shape[0]
    m, k = A.shape
    _, n = B.shape

    if block_size[0] == 64:
        enum_block_size = BlockSize.block_64x64
    else:
        if block_size[0] != 128:
            # Route the diagnostic through the module logger (rather than
            # print) so it honours the configured log level and handlers.
            logger.warning(
                "Unsupported block_size: %s. "
                "Falling back to BlockSize.block_128x128", block_size)
        enum_block_size = BlockSize.block_128x128

    # 'NN' is passed through unchanged; presumably it marks both operands as
    # non-transposed -- TODO confirm against the lmslim API. The trailing
    # None argument is likewise forwarded as in the original call.
    _, d = quant_ops.hipblaslt_w8a8_blockwise_gemm(
        A, B, As, Bs, m, n, k, 'NN', output_dtype, enum_block_size, None)
    return d
def
w8a8_block_fp8_matmul
(
A
:
torch
.
Tensor
,
...
...
@@ -898,6 +928,10 @@ def process_fp8_weight_block_strategy(
weight
,
weight_scale
,
_
=
normalize_e4m3fn_to_e4m3fnuz
(
weight
=
weight
,
weight_scale
=
weight_scale
)
if
envs
.
VLLM_W8A8_BACKEND
==
3
:
weight
=
weight
.
T
.
contiguous
()
weight_scale
=
weight_scale
.
T
.
contiguous
()
else
:
weight
=
_maybe_pad_fp8_weight
(
weight
)
return
weight
,
weight_scale
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment