Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
5f7fab88
Unverified
Commit
5f7fab88
authored
Apr 16, 2026
by
vllmellm
Committed by
GitHub
Apr 16, 2026
Browse files
[ROCm][FEAT] Integrate aiter gemm w8a8 ptpc (#33773)
Signed-off-by:
vllmellm
<
vllm.ellm@embeddedllm.com
>
parent
343f6523
Changes
9
Show whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
279 additions
and
23 deletions
+279
-23
tests/utils.py
tests/utils.py
+1
-0
vllm/_aiter_ops.py
vllm/_aiter_ops.py
+97
-9
vllm/model_executor/kernels/linear/__init__.py
vllm/model_executor/kernels/linear/__init__.py
+9
-3
vllm/model_executor/kernels/linear/scaled_mm/aiter.py
vllm/model_executor/kernels/linear/scaled_mm/aiter.py
+160
-4
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
...compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
+4
-5
vllm/model_executor/layers/quantization/fbgemm_fp8.py
vllm/model_executor/layers/quantization/fbgemm_fp8.py
+2
-0
vllm/model_executor/layers/quantization/fp8.py
vllm/model_executor/layers/quantization/fp8.py
+2
-2
vllm/model_executor/layers/quantization/modelopt.py
vllm/model_executor/layers/quantization/modelopt.py
+2
-0
vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py
...cutor/layers/quantization/quark/schemes/quark_w8a8_fp8.py
+2
-0
No files found.
tests/utils.py
View file @
5f7fab88
...
@@ -1856,6 +1856,7 @@ class TestFP8Layer(torch.nn.Module):
...
@@ -1856,6 +1856,7 @@ class TestFP8Layer(torch.nn.Module):
out_dtype
=
out_dtype
,
out_dtype
=
out_dtype
,
force_kernel
=
force_kernel
,
force_kernel
=
force_kernel
,
)
)
self
.
kernel
.
process_weights_after_loading
(
self
)
def
is_quant_fp8_enabled
(
self
)
->
bool
:
def
is_quant_fp8_enabled
(
self
)
->
bool
:
return
self
.
kernel
.
quant_fp8
.
enabled
()
return
self
.
kernel
.
quant_fp8
.
enabled
()
...
...
vllm/_aiter_ops.py
View file @
5f7fab88
...
@@ -3,6 +3,7 @@
...
@@ -3,6 +3,7 @@
import
functools
import
functools
from
collections.abc
import
Callable
from
collections.abc
import
Callable
import
pandas
as
pd
import
torch
import
torch
from
torch._ops
import
OpOverload
from
torch._ops
import
OpOverload
...
@@ -56,6 +57,29 @@ def is_aiter_found_and_supported() -> bool:
...
@@ -56,6 +57,29 @@ def is_aiter_found_and_supported() -> bool:
return
False
return
False
@
functools
.
cache
def
_load_gemm_tuned_configs
(
q_dtype_w
:
torch
.
dtype
,
csv_path
:
str
)
->
set
[
tuple
[
int
,
int
,
int
]]:
try
:
df
=
pd
.
read_csv
(
csv_path
).
drop_duplicates
()
df
=
df
[
df
[
"q_dtype_w"
]
==
str
(
q_dtype_w
)]
return
set
(
zip
(
df
[
"N"
].
astype
(
int
),
df
[
"K"
].
astype
(
int
),
df
[
"M"
].
astype
(
int
)))
except
Exception
:
return
set
()
def
_check_kernel_tuned
(
N
:
int
,
K
:
int
,
q_dtype_w
:
torch
.
dtype
,
csv_path
:
str
)
->
bool
:
configs
=
_load_gemm_tuned_configs
(
q_dtype_w
,
csv_path
)
l_m
=
(
[
1
,
2
,
4
]
+
list
(
range
(
8
,
513
,
8
))
+
[
1024
,
1536
]
+
[
2
**
i
for
i
in
range
(
11
,
19
)]
)
return
any
((
N
,
K
,
M
)
in
configs
for
M
in
l_m
)
def
if_aiter_supported
(
func
:
Callable
)
->
Callable
:
def
if_aiter_supported
(
func
:
Callable
)
->
Callable
:
"""Decorator that only executes the function if
"""Decorator that only executes the function if
ROCm AITER package is supported and enabled on gfx9 archs.
ROCm AITER package is supported and enabled on gfx9 archs.
...
@@ -468,7 +492,7 @@ def _rocm_aiter_mla_decode_fwd_fake(
...
@@ -468,7 +492,7 @@ def _rocm_aiter_mla_decode_fwd_fake(
pass
pass
def
_rocm_aiter_
gemm_a8w8
_impl
(
def
_rocm_aiter_
w8a8_gemm
_impl
(
A
:
torch
.
Tensor
,
A
:
torch
.
Tensor
,
B
:
torch
.
Tensor
,
B
:
torch
.
Tensor
,
As
:
torch
.
Tensor
,
As
:
torch
.
Tensor
,
...
@@ -485,7 +509,7 @@ def _rocm_aiter_gemm_a8w8_impl(
...
@@ -485,7 +509,7 @@ def _rocm_aiter_gemm_a8w8_impl(
return
gemm_a8w8_CK
(
A
,
B
,
As
,
Bs
,
bias
,
output_dtype
)
return
gemm_a8w8_CK
(
A
,
B
,
As
,
Bs
,
bias
,
output_dtype
)
def
_rocm_aiter_
gemm_a8w8
_fake
(
def
_rocm_aiter_
w8a8_gemm
_fake
(
A
:
torch
.
Tensor
,
A
:
torch
.
Tensor
,
B
:
torch
.
Tensor
,
B
:
torch
.
Tensor
,
As
:
torch
.
Tensor
,
As
:
torch
.
Tensor
,
...
@@ -499,6 +523,35 @@ def _rocm_aiter_gemm_a8w8_fake(
...
@@ -499,6 +523,35 @@ def _rocm_aiter_gemm_a8w8_fake(
return
Y
return
Y
def
_rocm_aiter_preshuffled_per_token_w8a8_gemm_impl
(
A
:
torch
.
Tensor
,
B
:
torch
.
Tensor
,
As
:
torch
.
Tensor
,
Bs
:
torch
.
Tensor
,
bias
:
torch
.
Tensor
|
None
=
None
,
output_dtype
:
torch
.
dtype
=
torch
.
float16
,
)
->
torch
.
Tensor
:
from
aiter
import
gemm_a8w8_bpreshuffle
output
=
gemm_a8w8_bpreshuffle
(
A
,
B
,
As
,
Bs
,
None
,
output_dtype
)
if
bias
is
not
None
:
output
.
add_
(
bias
)
return
output
def
_rocm_aiter_preshuffled_per_token_w8a8_gemm_fake
(
A
:
torch
.
Tensor
,
B
:
torch
.
Tensor
,
As
:
torch
.
Tensor
,
Bs
:
torch
.
Tensor
,
bias
:
torch
.
Tensor
|
None
=
None
,
output_dtype
:
torch
.
dtype
=
torch
.
float16
,
)
->
torch
.
Tensor
:
m
=
A
.
shape
[
0
]
n
=
B
.
shape
[
0
]
return
torch
.
empty
(
m
,
n
,
dtype
=
output_dtype
,
device
=
A
.
device
)
def
_rocm_aiter_triton_gemm_a8w8_blockscale_impl
(
def
_rocm_aiter_triton_gemm_a8w8_blockscale_impl
(
A
:
torch
.
Tensor
,
A
:
torch
.
Tensor
,
B
:
torch
.
Tensor
,
B
:
torch
.
Tensor
,
...
@@ -1313,11 +1366,15 @@ class rocm_aiter_ops:
...
@@ -1313,11 +1366,15 @@ class rocm_aiter_ops:
)
)
direct_register_custom_op
(
direct_register_custom_op
(
op_name
=
"rocm_aiter_gemm_a8w8"
,
op_name
=
"rocm_aiter_w8a8_gemm"
,
op_func
=
_rocm_aiter_gemm_a8w8_impl
,
op_func
=
_rocm_aiter_w8a8_gemm_impl
,
mutates_args
=
[],
fake_impl
=
_rocm_aiter_w8a8_gemm_fake
,
fake_impl
=
_rocm_aiter_gemm_a8w8_fake
,
)
dispatch_key
=
current_platform
.
dispatch_key
,
direct_register_custom_op
(
op_name
=
"_rocm_aiter_preshuffled_per_token_w8a8_gemm"
,
op_func
=
_rocm_aiter_preshuffled_per_token_w8a8_gemm_impl
,
fake_impl
=
_rocm_aiter_preshuffled_per_token_w8a8_gemm_fake
,
)
)
direct_register_custom_op
(
direct_register_custom_op
(
...
@@ -1493,7 +1550,18 @@ class rocm_aiter_ops:
...
@@ -1493,7 +1550,18 @@ class rocm_aiter_ops:
)
)
@
staticmethod
@
staticmethod
def
gemm_a8w8
(
def
w8a8_gemm
(
A
:
torch
.
Tensor
,
B
:
torch
.
Tensor
,
As
:
torch
.
Tensor
,
Bs
:
torch
.
Tensor
,
bias
:
torch
.
Tensor
|
None
=
None
,
output_dtype
:
torch
.
dtype
=
torch
.
float16
,
)
->
torch
.
Tensor
:
return
torch
.
ops
.
vllm
.
rocm_aiter_w8a8_gemm
(
A
,
B
,
As
,
Bs
,
bias
,
output_dtype
)
@
staticmethod
def
preshuffled_per_token_w8a8_gemm
(
A
:
torch
.
Tensor
,
A
:
torch
.
Tensor
,
B
:
torch
.
Tensor
,
B
:
torch
.
Tensor
,
As
:
torch
.
Tensor
,
As
:
torch
.
Tensor
,
...
@@ -1501,7 +1569,9 @@ class rocm_aiter_ops:
...
@@ -1501,7 +1569,9 @@ class rocm_aiter_ops:
bias
:
torch
.
Tensor
|
None
=
None
,
bias
:
torch
.
Tensor
|
None
=
None
,
output_dtype
:
torch
.
dtype
=
torch
.
float16
,
output_dtype
:
torch
.
dtype
=
torch
.
float16
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
return
torch
.
ops
.
vllm
.
rocm_aiter_gemm_a8w8
(
A
,
B
,
As
,
Bs
,
bias
,
output_dtype
)
return
torch
.
ops
.
vllm
.
_rocm_aiter_preshuffled_per_token_w8a8_gemm
(
A
,
B
,
As
,
Bs
,
bias
,
output_dtype
)
@
staticmethod
@
staticmethod
def
triton_gemm_a8w8_blockscale
(
def
triton_gemm_a8w8_blockscale
(
...
@@ -1920,6 +1990,24 @@ class rocm_aiter_ops:
...
@@ -1920,6 +1990,24 @@ class rocm_aiter_ops:
(
8192
,
3584
),
(
8192
,
3584
),
]
]
@
staticmethod
def
is_shuffled_per_token_w8a8_gemm_tuned
(
N
:
int
,
K
:
int
,
q_dtype_w
:
torch
.
dtype
)
->
bool
:
import
aiter.ops.gemm_op_a8w8
as
aiter_gemm_a8w8_ops
csv_path
=
(
aiter_gemm_a8w8_ops
.
AITER_CONFIGS
.
AITER_CONFIG_GEMM_A8W8_BPRESHUFFLE_FILE
)
return
_check_kernel_tuned
(
N
,
K
,
q_dtype_w
,
csv_path
)
@
staticmethod
def
is_per_token_w8a8_gemm_tuned
(
N
:
int
,
K
:
int
,
q_dtype_w
:
torch
.
dtype
)
->
bool
:
import
aiter.ops.gemm_op_a8w8
as
aiter_gemm_a8w8_ops
csv_path
=
aiter_gemm_a8w8_ops
.
AITER_CONFIGS
.
AITER_CONFIG_GEMM_A8W8_FILE
return
_check_kernel_tuned
(
N
,
K
,
q_dtype_w
,
csv_path
)
@
staticmethod
@
staticmethod
def
shuffle_weight
(
def
shuffle_weight
(
tensor
:
torch
.
Tensor
,
layout
:
tuple
[
int
,
int
]
=
(
16
,
16
)
tensor
:
torch
.
Tensor
,
layout
:
tuple
[
int
,
int
]
=
(
16
,
16
)
...
...
vllm/model_executor/kernels/linear/__init__.py
View file @
5f7fab88
...
@@ -106,6 +106,8 @@ from vllm.model_executor.kernels.linear.scaled_mm import (
...
@@ -106,6 +106,8 @@ from vllm.model_executor.kernels.linear.scaled_mm import (
from
vllm.model_executor.kernels.linear.scaled_mm.aiter
import
(
from
vllm.model_executor.kernels.linear.scaled_mm.aiter
import
(
AiterFp8BlockScaledMMKernel
,
AiterFp8BlockScaledMMKernel
,
AiterInt8ScaledMMLinearKernel
,
AiterInt8ScaledMMLinearKernel
,
AiterPerTokenFp8ScaledMMLinearKernel
,
AiterPreshuffledPerTokenFp8ScaledMMLinearKernel
,
)
)
from
vllm.model_executor.kernels.linear.scaled_mm.cpu
import
(
from
vllm.model_executor.kernels.linear.scaled_mm.cpu
import
(
CPUInt8ScaledMMLinearKernel
,
CPUInt8ScaledMMLinearKernel
,
...
@@ -165,6 +167,8 @@ _POSSIBLE_FP8_KERNELS: dict[PlatformEnum, list[type[FP8ScaledMMLinearKernel]]] =
...
@@ -165,6 +167,8 @@ _POSSIBLE_FP8_KERNELS: dict[PlatformEnum, list[type[FP8ScaledMMLinearKernel]]] =
ChannelWiseTorchFP8ScaledMMLinearKernel
,
ChannelWiseTorchFP8ScaledMMLinearKernel
,
],
],
PlatformEnum
.
ROCM
:
[
PlatformEnum
.
ROCM
:
[
AiterPreshuffledPerTokenFp8ScaledMMLinearKernel
,
AiterPerTokenFp8ScaledMMLinearKernel
,
ROCmFP8ScaledMMLinearKernel
,
ROCmFP8ScaledMMLinearKernel
,
PerTensorTorchFP8ScaledMMLinearKernel
,
PerTensorTorchFP8ScaledMMLinearKernel
,
RowWiseTorchFP8ScaledMMLinearKernel
,
RowWiseTorchFP8ScaledMMLinearKernel
,
...
@@ -360,18 +364,18 @@ def choose_scaled_mm_linear_kernel(
...
@@ -360,18 +364,18 @@ def choose_scaled_mm_linear_kernel(
def
init_fp8_linear_kernel
(
def
init_fp8_linear_kernel
(
activation_quant_key
:
QuantKey
,
activation_quant_key
:
QuantKey
,
weight_quant_key
:
QuantKey
,
weight_quant_key
:
QuantKey
,
weight_shape
:
tuple
[
int
,
int
],
input_dtype
:
torch
.
dtype
,
input_dtype
:
torch
.
dtype
,
out_dtype
:
torch
.
dtype
,
out_dtype
:
torch
.
dtype
,
force_kernel
:
type
[
_KernelT
]
|
None
=
None
,
weight_shape
:
tuple
[
int
,
int
],
force_kernel
:
type
[
FP8ScaledMMLinearKernel
]
|
None
=
None
,
module_name
:
str
|
None
=
None
,
module_name
:
str
|
None
=
None
,
)
->
FP8ScaledMMLinearKernel
|
Fp8BlockScaledMMLinearKernel
:
)
->
FP8ScaledMMLinearKernel
|
Fp8BlockScaledMMLinearKernel
:
scaled_mm_linear_kernel_config
=
FP8ScaledMMLinearLayerConfig
(
scaled_mm_linear_kernel_config
=
FP8ScaledMMLinearLayerConfig
(
weight_quant_key
=
weight_quant_key
,
weight_quant_key
=
weight_quant_key
,
activation_quant_key
=
activation_quant_key
,
activation_quant_key
=
activation_quant_key
,
weight_shape
=
weight_shape
,
input_dtype
=
input_dtype
,
input_dtype
=
input_dtype
,
out_dtype
=
out_dtype
,
out_dtype
=
out_dtype
,
weight_shape
=
weight_shape
,
)
)
if
activation_quant_key
.
scale
.
group_shape
.
is_per_group
():
if
activation_quant_key
.
scale
.
group_shape
.
is_per_group
():
...
@@ -725,6 +729,8 @@ __all__ = [
...
@@ -725,6 +729,8 @@ __all__ = [
"FP8ScaledMMLinearLayerConfig"
,
"FP8ScaledMMLinearLayerConfig"
,
"Int8ScaledMMLinearLayerConfig"
,
"Int8ScaledMMLinearLayerConfig"
,
"ScaledMMLinearLayerConfig"
,
"ScaledMMLinearLayerConfig"
,
"AiterPreshuffledPerTokenFp8ScaledMMLinearKernel"
,
"AiterPerTokenFp8ScaledMMLinearKernel"
,
"NvFp4LinearKernel"
,
"NvFp4LinearKernel"
,
"NvFp4LinearLayerConfig"
,
"NvFp4LinearLayerConfig"
,
"AiterInt8ScaledMMLinearKernel"
,
"AiterInt8ScaledMMLinearKernel"
,
...
...
vllm/model_executor/kernels/linear/scaled_mm/aiter.py
View file @
5f7fab88
...
@@ -5,18 +5,27 @@
...
@@ -5,18 +5,27 @@
import
torch
import
torch
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm._aiter_ops
import
rocm_aiter_ops
from
vllm._aiter_ops
import
(
rocm_aiter_ops
,
)
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
GroupShape
,
GroupShape
,
)
)
from
vllm.model_executor.utils
import
replace_parameter
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
.BlockScaledMMLinearKernel
import
(
from
.BlockScaledMMLinearKernel
import
(
Fp8BlockScaledMMLinearKernel
,
Fp8BlockScaledMMLinearKernel
,
FP8ScaledMMLinearLayerConfig
,
)
)
from
.cutlass
import
CutlassInt8ScaledMMLinearKernel
from
.cutlass
import
CutlassInt8ScaledMMLinearKernel
from
.ScaledMMLinearKernel
import
Int8ScaledMMLinearLayerConfig
from
.ScaledMMLinearKernel
import
(
FP8ScaledMMLinearKernel
,
FP8ScaledMMLinearLayerConfig
,
Int8ScaledMMLinearLayerConfig
,
)
logger
=
init_logger
(
__name__
)
class
AiterInt8ScaledMMLinearKernel
(
CutlassInt8ScaledMMLinearKernel
):
class
AiterInt8ScaledMMLinearKernel
(
CutlassInt8ScaledMMLinearKernel
):
...
@@ -113,7 +122,154 @@ class AiterInt8ScaledMMLinearKernel(CutlassInt8ScaledMMLinearKernel):
...
@@ -113,7 +122,154 @@ class AiterInt8ScaledMMLinearKernel(CutlassInt8ScaledMMLinearKernel):
# a to be [M, K]
# a to be [M, K]
# b to be [N, K]
# b to be [N, K]
# CutlassInt8ScaledMMLinearKernel prepare weight `w_q` in [K, N] format
# CutlassInt8ScaledMMLinearKernel prepare weight `w_q` in [K, N] format
return
rocm_aiter_ops
.
gemm_a8w8
(
x_q
,
w_q
.
t
(),
x_s
,
w_s
,
bias
,
out_dtype
)
return
rocm_aiter_ops
.
w8a8_gemm
(
x_q
,
w_q
.
t
(),
x_s
,
w_s
,
bias
,
out_dtype
)
class
AiterPreshuffledPerTokenFp8ScaledMMLinearKernel
(
FP8ScaledMMLinearKernel
):
@
classmethod
def
is_supported
(
cls
,
compute_capability
:
int
|
None
=
None
)
->
tuple
[
bool
,
str
|
None
]:
if
not
current_platform
.
is_rocm
():
return
False
,
"requires ROCm."
if
not
rocm_aiter_ops
.
is_linear_fp8_enabled
():
return
(
False
,
"requires setting `VLLM_ROCM_USE_AITER=1` "
"and `VLLM_ROCM_USE_AITER_LINEAR=1`. "
"`VLLM_ROCM_USE_AITER_LINEAR` default is True."
,
)
try
:
import
aiter
# noqa: F401
except
Exception
:
return
False
,
"requires aiter library to be installed."
return
True
,
None
@
classmethod
def
can_implement
(
cls
,
c
:
FP8ScaledMMLinearLayerConfig
)
->
tuple
[
bool
,
str
|
None
]:
is_ptpc
=
(
c
.
activation_quant_key
.
scale
.
group_shape
.
is_per_token
()
and
c
.
weight_quant_key
.
scale
.
group_shape
.
is_per_channel
()
)
if
c
.
weight_shape
is
None
:
return
False
,
"weight_shape is required for Aiter kernels"
N
,
K
=
c
.
weight_shape
fp8_dtype
=
current_platform
.
fp8_dtype
()
if
c
.
out_dtype
is
not
torch
.
bfloat16
:
return
False
,
"requires bfloat16 output dtype."
if
not
is_ptpc
:
return
(
False
,
"requires per token activation scales and per channel weight scales."
,
)
if
not
(
N
%
16
==
0
and
K
%
16
==
0
):
return
(
False
,
f
"requires N and K dimensions divisible by 16, received "
f
"N=
{
N
}
and K=
{
K
}
."
,
)
# Aiter's shuffled per-token Gemm performs better than torch only when its
# tuned.
if
not
rocm_aiter_ops
.
is_shuffled_per_token_w8a8_gemm_tuned
(
N
,
K
,
fp8_dtype
):
return
(
False
,
f
"requires a tuned configuration for N:
{
N
}
and K:
{
K
}
"
f
"and fp8 dtype
{
fp8_dtype
}
."
,
)
return
True
,
None
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
w_name
,
*
_
=
self
.
layer_param_names
w
,
*
_
=
self
.
_get_layer_params
(
layer
)
replace_parameter
(
layer
,
w_name
,
torch
.
nn
.
Parameter
(
rocm_aiter_ops
.
shuffle_weight
(
w
.
t
().
contiguous
()).
data
,
requires_grad
=
False
,
),
)
def
apply_scaled_mm
(
self
,
*
,
A
:
torch
.
Tensor
,
B
:
torch
.
Tensor
,
out_dtype
:
torch
.
dtype
,
As
:
torch
.
Tensor
,
Bs
:
torch
.
Tensor
,
bias
:
torch
.
Tensor
|
None
,
output_shape
:
list
,
)
->
torch
.
Tensor
:
return
rocm_aiter_ops
.
preshuffled_per_token_w8a8_gemm
(
A
,
B
,
As
,
Bs
,
bias
,
out_dtype
)
class
AiterPerTokenFp8ScaledMMLinearKernel
(
FP8ScaledMMLinearKernel
):
@
classmethod
def
is_supported
(
cls
,
compute_capability
:
int
|
None
=
None
)
->
tuple
[
bool
,
str
|
None
]:
return
AiterPreshuffledPerTokenFp8ScaledMMLinearKernel
.
is_supported
(
compute_capability
)
@
classmethod
def
can_implement
(
cls
,
c
:
FP8ScaledMMLinearLayerConfig
)
->
tuple
[
bool
,
str
|
None
]:
is_ptpc
=
(
c
.
activation_quant_key
.
scale
.
group_shape
.
is_per_token
()
and
c
.
weight_quant_key
.
scale
.
group_shape
.
is_per_channel
()
)
if
c
.
weight_shape
is
None
:
return
False
,
"weight_shape is required for Aiter kernels"
N
,
K
=
c
.
weight_shape
fp8_dtype
=
current_platform
.
fp8_dtype
()
if
not
is_ptpc
:
return
(
False
,
"requires per token activation scales and per channel weight scales."
,
)
# Aiter's per-token Gemm performs better than torch oonly when its
# tuned.
if
not
rocm_aiter_ops
.
is_per_token_w8a8_gemm_tuned
(
N
,
K
,
fp8_dtype
):
return
(
False
,
f
"requires a tuned configuration for N:
{
N
}
and K:
{
K
}
"
f
"and fp8 dtype
{
fp8_dtype
}
."
,
)
return
True
,
None
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
w_name
,
*
_
=
self
.
layer_param_names
w
,
*
_
=
self
.
_get_layer_params
(
layer
)
replace_parameter
(
layer
,
w_name
,
torch
.
nn
.
Parameter
(
w
.
t
(),
requires_grad
=
False
),
)
def
apply_scaled_mm
(
self
,
*
,
A
:
torch
.
Tensor
,
B
:
torch
.
Tensor
,
out_dtype
:
torch
.
dtype
,
As
:
torch
.
Tensor
,
Bs
:
torch
.
Tensor
,
bias
:
torch
.
Tensor
|
None
,
output_shape
:
list
,
)
->
torch
.
Tensor
:
return
rocm_aiter_ops
.
w8a8_gemm
(
A
,
B
,
As
,
Bs
,
bias
,
out_dtype
)
class
AiterFp8BlockScaledMMKernel
(
Fp8BlockScaledMMLinearKernel
):
class
AiterFp8BlockScaledMMKernel
(
Fp8BlockScaledMMLinearKernel
):
...
...
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
View file @
5f7fab88
...
@@ -31,8 +31,8 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
...
@@ -31,8 +31,8 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape
,
GroupShape
,
create_fp8_quant_key
,
create_fp8_quant_key
,
kFp8DynamicTokenSym
,
kFp8DynamicTokenSym
,
kFp8StaticChannelSym
,
kFp8StaticTensorSym
,
kFp8StaticTensorSym
,
kFp8StaticTokenSym
,
)
)
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
cutlass_block_fp8_supported
,
cutlass_block_fp8_supported
,
...
@@ -47,7 +47,7 @@ activation_quant_key_mapping = {
...
@@ -47,7 +47,7 @@ activation_quant_key_mapping = {
DYNAMIC_QUANT
:
kFp8DynamicTokenSym
,
DYNAMIC_QUANT
:
kFp8DynamicTokenSym
,
}
}
weight_quant_key_mapping
=
{
weight_quant_key_mapping
=
{
QuantizationStrategy
.
CHANNEL
:
kFp8Static
Token
Sym
,
QuantizationStrategy
.
CHANNEL
:
kFp8Static
Channel
Sym
,
QuantizationStrategy
.
TENSOR
:
kFp8StaticTensorSym
,
QuantizationStrategy
.
TENSOR
:
kFp8StaticTensorSym
,
}
}
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -67,7 +67,6 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
...
@@ -67,7 +67,6 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
self
.
use_aiter_and_is_supported
=
rocm_aiter_ops
.
is_linear_fp8_enabled
()
self
.
use_aiter_and_is_supported
=
rocm_aiter_ops
.
is_linear_fp8_enabled
()
assert
not
self
.
is_static_input_scheme
assert
not
self
.
is_static_input_scheme
self
.
act_q_group_shape
=
GroupShape
(
1
,
self
.
weight_block_size
[
0
])
self
.
act_q_group_shape
=
GroupShape
(
1
,
self
.
weight_block_size
[
0
])
self
.
weight_quant_key
=
create_fp8_quant_key
(
self
.
weight_quant_key
=
create_fp8_quant_key
(
static
=
True
,
group_shape
=
GroupShape
(
*
self
.
weight_block_size
)
static
=
True
,
group_shape
=
GroupShape
(
*
self
.
weight_block_size
)
)
)
...
@@ -76,7 +75,7 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
...
@@ -76,7 +75,7 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
)
)
else
:
else
:
self
.
activation_quant_key
=
activation_quant_key_mapping
[
self
.
activation_quant_key
=
activation_quant_key_mapping
[
is_static_input_scheme
self
.
is_static_input_scheme
]
]
self
.
weight_quant_key
=
weight_quant_key_mapping
[
self
.
strategy
]
self
.
weight_quant_key
=
weight_quant_key_mapping
[
self
.
strategy
]
...
@@ -138,9 +137,9 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
...
@@ -138,9 +137,9 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
self
.
fp8_linear
=
init_fp8_linear_kernel
(
self
.
fp8_linear
=
init_fp8_linear_kernel
(
activation_quant_key
=
self
.
activation_quant_key
,
activation_quant_key
=
self
.
activation_quant_key
,
weight_quant_key
=
self
.
weight_quant_key
,
weight_quant_key
=
self
.
weight_quant_key
,
weight_shape
=
layer
.
weight
.
shape
,
input_dtype
=
self
.
input_dtype
,
input_dtype
=
self
.
input_dtype
,
out_dtype
=
self
.
out_dtype
,
out_dtype
=
self
.
out_dtype
,
weight_shape
=
(
output_size_per_partition
,
input_size_per_partition
),
module_name
=
self
.
__class__
.
__name__
,
module_name
=
self
.
__class__
.
__name__
,
)
)
...
...
vllm/model_executor/layers/quantization/fbgemm_fp8.py
View file @
5f7fab88
...
@@ -175,6 +175,8 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
...
@@ -175,6 +175,8 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
# Activations not quantized for marlin.
# Activations not quantized for marlin.
del
layer
.
input_scale_ub
del
layer
.
input_scale_ub
self
.
fp8_linear
.
process_weights_after_loading
(
layer
)
def
apply
(
def
apply
(
self
,
self
,
layer
:
torch
.
nn
.
Module
,
layer
:
torch
.
nn
.
Module
,
...
...
vllm/model_executor/layers/quantization/fp8.py
View file @
5f7fab88
...
@@ -397,8 +397,6 @@ class Fp8LinearMethod(LinearMethodBase):
...
@@ -397,8 +397,6 @@ class Fp8LinearMethod(LinearMethodBase):
if
self
.
block_quant
:
if
self
.
block_quant
:
assert
not
self
.
act_q_static
assert
not
self
.
act_q_static
self
.
fp8_linear
.
process_weights_after_loading
(
layer
)
# If checkpoint not serialized fp8, quantize the weights.
# If checkpoint not serialized fp8, quantize the weights.
else
:
else
:
# If checkpoint is fp8 per-tensor, handle that there are N scales for N
# If checkpoint is fp8 per-tensor, handle that there are N scales for N
...
@@ -428,6 +426,8 @@ class Fp8LinearMethod(LinearMethodBase):
...
@@ -428,6 +426,8 @@ class Fp8LinearMethod(LinearMethodBase):
else
:
else
:
layer
.
input_scale
=
None
layer
.
input_scale
=
None
self
.
fp8_linear
.
process_weights_after_loading
(
layer
)
def
apply
(
def
apply
(
self
,
self
,
layer
:
torch
.
nn
.
Module
,
layer
:
torch
.
nn
.
Module
,
...
...
vllm/model_executor/layers/quantization/modelopt.py
View file @
5f7fab88
...
@@ -517,6 +517,7 @@ class ModelOptFp8LinearMethod(LinearMethodBase):
...
@@ -517,6 +517,7 @@ class ModelOptFp8LinearMethod(LinearMethodBase):
layer
.
weight
=
Parameter
(
weight
.
t
(),
requires_grad
=
False
)
layer
.
weight
=
Parameter
(
weight
.
t
(),
requires_grad
=
False
)
layer
.
weight_scale
=
Parameter
(
max_w_scale
,
requires_grad
=
False
)
layer
.
weight_scale
=
Parameter
(
max_w_scale
,
requires_grad
=
False
)
layer
.
input_scale
=
Parameter
(
layer
.
input_scale
.
max
(),
requires_grad
=
False
)
layer
.
input_scale
=
Parameter
(
layer
.
input_scale
.
max
(),
requires_grad
=
False
)
self
.
fp8_linear
.
process_weights_after_loading
(
layer
)
def
apply
(
def
apply
(
self
,
self
,
...
@@ -597,6 +598,7 @@ class ModelOptFp8PcPtLinearMethod(LinearMethodBase):
...
@@ -597,6 +598,7 @@ class ModelOptFp8PcPtLinearMethod(LinearMethodBase):
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
layer
.
weight
=
Parameter
(
layer
.
weight
.
t
(),
requires_grad
=
False
)
layer
.
weight
=
Parameter
(
layer
.
weight
.
t
(),
requires_grad
=
False
)
layer
.
weight_scale
=
Parameter
(
layer
.
weight_scale
.
data
,
requires_grad
=
False
)
layer
.
weight_scale
=
Parameter
(
layer
.
weight_scale
.
data
,
requires_grad
=
False
)
self
.
fp8_linear
.
process_weights_after_loading
(
layer
)
def
apply
(
def
apply
(
self
,
self
,
...
...
vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py
View file @
5f7fab88
...
@@ -120,6 +120,8 @@ class QuarkW8A8Fp8(QuarkScheme):
...
@@ -120,6 +120,8 @@ class QuarkW8A8Fp8(QuarkScheme):
if
self
.
is_static_input_scheme
:
if
self
.
is_static_input_scheme
:
layer
.
input_scale
=
Parameter
(
layer
.
input_scale
.
max
(),
requires_grad
=
False
)
layer
.
input_scale
=
Parameter
(
layer
.
input_scale
.
max
(),
requires_grad
=
False
)
self
.
fp8_linear
.
process_weights_after_loading
(
layer
)
def
create_weights
(
def
create_weights
(
self
,
self
,
layer
:
torch
.
nn
.
Module
,
layer
:
torch
.
nn
.
Module
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment