Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
900f4720
Commit
900f4720
authored
Jan 17, 2026
by
wanglong3
Browse files
feat: Support w8a8-fp8 GEMM backend.
parent
9d16d5aa
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
49 additions
and
6 deletions
+49
-6
vllm/model_executor/layers/quantization/fp8.py
vllm/model_executor/layers/quantization/fp8.py
+6
-2
vllm/model_executor/layers/quantization/utils/fp8_utils.py
vllm/model_executor/layers/quantization/utils/fp8_utils.py
+43
-4
No files found.
vllm/model_executor/layers/quantization/fp8.py
View file @
900f4720
...
@@ -331,8 +331,12 @@ class Fp8LinearMethod(LinearMethodBase):
...
@@ -331,8 +331,12 @@ class Fp8LinearMethod(LinearMethodBase):
weight
=
layer
.
weight
.
data
weight
=
layer
.
weight
.
data
weight_scale_inv
=
layer
.
weight_scale_inv
.
data
weight_scale_inv
=
layer
.
weight_scale_inv
.
data
weight
=
self
.
_maybe_pad_weight
(
weight
)
if
envs
.
VLLM_W8A8_BACKEND
==
3
:
weight
=
weight
.
T
.
contiguous
()
weight_scale_inv
=
weight_scale_inv
.
T
.
contiguous
()
else
:
weight
=
self
.
_maybe_pad_weight
(
weight
)
# Torch.compile cannot use Parameter subclasses.
# Torch.compile cannot use Parameter subclasses.
layer
.
weight
=
Parameter
(
weight
,
requires_grad
=
False
)
layer
.
weight
=
Parameter
(
weight
,
requires_grad
=
False
)
layer
.
weight_scale_inv
=
Parameter
(
weight_scale_inv
,
layer
.
weight_scale_inv
=
Parameter
(
weight_scale_inv
,
...
...
vllm/model_executor/layers/quantization/utils/fp8_utils.py
View file @
900f4720
...
@@ -6,6 +6,8 @@ import functools
...
@@ -6,6 +6,8 @@ import functools
import
json
import
json
import
os
import
os
from
typing
import
Any
,
Callable
,
Optional
,
Union
,
List
from
typing
import
Any
,
Callable
,
Optional
,
Union
,
List
from
lmslim
import
quant_ops
from
lmslim.quantize.quant_ops
import
BlockSize
import
torch
import
torch
...
@@ -83,7 +85,7 @@ if current_platform.is_rocm():
...
@@ -83,7 +85,7 @@ if current_platform.is_rocm():
def
dispatch_w8a8_blockscale_func
(
def
dispatch_w8a8_blockscale_func
(
use_cutlass
:
bool
,
use_aiter_and_is_supported
:
bool
use_cutlass
:
bool
,
use_aiter_and_is_supported
:
bool
,
use_blaslt
:
bool
)
->
Callable
[[
)
->
Callable
[[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
...
@@ -96,6 +98,9 @@ def dispatch_w8a8_blockscale_func(
...
@@ -96,6 +98,9 @@ def dispatch_w8a8_blockscale_func(
return
cutlass_scaled_mm
return
cutlass_scaled_mm
if
(
use_aiter_and_is_supported
):
if
(
use_aiter_and_is_supported
):
return
torch
.
ops
.
vllm
.
rocm_aiter_gemm_w8a8_blockscale
return
torch
.
ops
.
vllm
.
rocm_aiter_gemm_w8a8_blockscale
if
use_blaslt
:
return
hipblaslt_w8a8_block_fp8_matmul
return
w8a8_block_fp8_matmul
return
w8a8_block_fp8_matmul
...
@@ -127,7 +132,11 @@ def apply_w8a8_block_fp8_linear(
...
@@ -127,7 +132,11 @@ def apply_w8a8_block_fp8_linear(
assert
input_scale
is
None
assert
input_scale
is
None
# View input as 2D matrix for fp8 methods
# View input as 2D matrix for fp8 methods
input_2d
=
input
.
view
(
-
1
,
input
.
shape
[
-
1
])
input_2d
=
input
.
view
(
-
1
,
input
.
shape
[
-
1
])
output_shape
=
[
*
input
.
shape
[:
-
1
],
weight
.
shape
[
0
]]
output_shape
=
[]
if
envs
.
VLLM_W8A8_BACKEND
==
3
:
output_shape
=
[
*
input
.
shape
[:
-
1
],
weight
.
shape
[
-
1
]]
else
:
output_shape
=
[
*
input
.
shape
[:
-
1
],
weight
.
shape
[
0
]]
output_dtype
=
input
.
dtype
output_dtype
=
input
.
dtype
if
should_use_deepgemm
(
output_dtype
,
weight
):
if
should_use_deepgemm
(
output_dtype
,
weight
):
...
@@ -166,9 +175,12 @@ def apply_w8a8_block_fp8_linear(
...
@@ -166,9 +175,12 @@ def apply_w8a8_block_fp8_linear(
weight
.
shape
[
0
]
%
128
==
0
and
weight
.
shape
[
1
]
%
128
==
0
)
weight
.
shape
[
0
]
%
128
==
0
and
weight
.
shape
[
1
]
%
128
==
0
)
else
:
else
:
use_cutlass
=
False
use_cutlass
=
False
use_blaslt
=
False
if
envs
.
VLLM_W8A8_BACKEND
==
3
:
use_blaslt
=
True
w8a8_blockscale_func
=
dispatch_w8a8_blockscale_func
(
w8a8_blockscale_func
=
dispatch_w8a8_blockscale_func
(
use_cutlass
,
use_aiter_and_is_supported
)
use_cutlass
,
use_aiter_and_is_supported
,
use_blaslt
)
if
use_cutlass
:
if
use_cutlass
:
q_input
,
x_scale
=
per_token_group_quant_fp8
(
q_input
,
x_scale
=
per_token_group_quant_fp8
(
input_2d
,
block_size
[
1
],
column_major_scales
=
use_cutlass
)
input_2d
,
block_size
[
1
],
column_major_scales
=
use_cutlass
)
...
@@ -197,7 +209,11 @@ def apply_w8a8_block_fp8_linear_fake(
...
@@ -197,7 +209,11 @@ def apply_w8a8_block_fp8_linear_fake(
cutlass_block_fp8_supported
:
bool
=
CUTLASS_BLOCK_FP8_SUPPORTED
,
cutlass_block_fp8_supported
:
bool
=
CUTLASS_BLOCK_FP8_SUPPORTED
,
use_aiter_and_is_supported
:
bool
=
False
,
use_aiter_and_is_supported
:
bool
=
False
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
output_shape
=
[
*
input
.
shape
[:
-
1
],
weight
.
shape
[
0
]]
output_shape
=
[]
if
envs
.
VLLM_W8A8_BACKEND
==
3
:
output_shape
=
[
*
input
.
shape
[:
-
1
],
weight
.
shape
[
-
1
]]
else
:
output_shape
=
[
*
input
.
shape
[:
-
1
],
weight
.
shape
[
0
]]
return
torch
.
empty
(
output_shape
,
dtype
=
input
.
dtype
,
device
=
input
.
device
)
return
torch
.
empty
(
output_shape
,
dtype
=
input
.
dtype
,
device
=
input
.
device
)
...
@@ -566,6 +582,29 @@ def get_w8a8_block_fp8_configs(N: int, K: int, block_n: int,
...
@@ -566,6 +582,29 @@ def get_w8a8_block_fp8_configs(N: int, K: int, block_n: int,
return
None
return
None
def hipblaslt_w8a8_block_fp8_matmul(
    A: torch.Tensor,
    B: torch.Tensor,
    As: torch.Tensor,
    Bs: torch.Tensor,
    block_size: list[int],
    output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
    """Run a block-scaled w8a8 GEMM through ``quant_ops``' hipBLASLt kernel.

    The problem sizes are read directly from the operands: ``(m, k)`` from
    ``A.shape`` and ``n`` from the second dimension of ``B.shape``.
    NOTE(review): this implies ``B`` is laid out as ``(K, N)`` (i.e. already
    transposed by the caller) -- confirm against the weight-loading path.

    Args:
        A: 2-D left-hand operand.
        B: 2-D right-hand operand; only its second dimension is read here.
        As: Scales accompanying ``A``, forwarded unchanged to the kernel.
        Bs: Scales accompanying ``B``, forwarded unchanged to the kernel.
        block_size: Quantization block shape. Only ``block_size[0]`` is
            inspected; 64 and 128 are recognised, and any other value falls
            back to the 128x128 block with a warning.
        output_dtype: dtype requested from the kernel for the result.

    Returns:
        The second element of the tuple returned by
        ``quant_ops.hipblaslt_w8a8_blockwise_gemm``.
    """
    m, k = A.shape
    _, n = B.shape

    if block_size[0] == 64:
        enum_block_size = BlockSize.block_64x64
    elif block_size[0] == 128:
        enum_block_size = BlockSize.block_128x128
    else:
        # Use the warnings machinery rather than print(): this function sits
        # on the per-call GEMM path, and warnings are de-duplicated per call
        # site by default and can be filtered by the application.
        import warnings

        warnings.warn(
            f"[WARN] Unsupported block_size: {block_size}. "
            "Falling back to BlockSize.block_128x128",
            RuntimeWarning,
            stacklevel=2,
        )
        enum_block_size = BlockSize.block_128x128

    # 'NN' is passed through as the kernel's layout argument; the trailing
    # None fills the kernel's last positional parameter, as in the original.
    _, d = quant_ops.hipblaslt_w8a8_blockwise_gemm(
        A, B, As, Bs, m, n, k, 'NN', output_dtype, enum_block_size, None)
    return d
def
w8a8_block_fp8_matmul
(
def
w8a8_block_fp8_matmul
(
A
:
torch
.
Tensor
,
A
:
torch
.
Tensor
,
B
:
torch
.
Tensor
,
B
:
torch
.
Tensor
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment