Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
5f7fab88
Unverified
Commit
5f7fab88
authored
Apr 16, 2026
by
vllmellm
Committed by
GitHub
Apr 16, 2026
Browse files
[ROCm][FEAT] Integrate aiter gemm w8a8 ptpc (#33773)
Signed-off-by:
vllmellm
<
vllm.ellm@embeddedllm.com
>
parent
343f6523
Changes
9
Show whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
279 additions
and
23 deletions
+279
-23
tests/utils.py
tests/utils.py
+1
-0
vllm/_aiter_ops.py
vllm/_aiter_ops.py
+97
-9
vllm/model_executor/kernels/linear/__init__.py
vllm/model_executor/kernels/linear/__init__.py
+9
-3
vllm/model_executor/kernels/linear/scaled_mm/aiter.py
vllm/model_executor/kernels/linear/scaled_mm/aiter.py
+160
-4
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
...compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
+4
-5
vllm/model_executor/layers/quantization/fbgemm_fp8.py
vllm/model_executor/layers/quantization/fbgemm_fp8.py
+2
-0
vllm/model_executor/layers/quantization/fp8.py
vllm/model_executor/layers/quantization/fp8.py
+2
-2
vllm/model_executor/layers/quantization/modelopt.py
vllm/model_executor/layers/quantization/modelopt.py
+2
-0
vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py
...cutor/layers/quantization/quark/schemes/quark_w8a8_fp8.py
+2
-0
No files found.
tests/utils.py
View file @
5f7fab88
...
...
@@ -1856,6 +1856,7 @@ class TestFP8Layer(torch.nn.Module):
out_dtype
=
out_dtype
,
force_kernel
=
force_kernel
,
)
self
.
kernel
.
process_weights_after_loading
(
self
)
def
is_quant_fp8_enabled
(
self
)
->
bool
:
return
self
.
kernel
.
quant_fp8
.
enabled
()
...
...
vllm/_aiter_ops.py
View file @
5f7fab88
...
...
@@ -3,6 +3,7 @@
import
functools
from
collections.abc
import
Callable
import
pandas
as
pd
import
torch
from
torch._ops
import
OpOverload
...
...
@@ -56,6 +57,29 @@ def is_aiter_found_and_supported() -> bool:
return
False
@
functools
.
cache
def
_load_gemm_tuned_configs
(
q_dtype_w
:
torch
.
dtype
,
csv_path
:
str
)
->
set
[
tuple
[
int
,
int
,
int
]]:
try
:
df
=
pd
.
read_csv
(
csv_path
).
drop_duplicates
()
df
=
df
[
df
[
"q_dtype_w"
]
==
str
(
q_dtype_w
)]
return
set
(
zip
(
df
[
"N"
].
astype
(
int
),
df
[
"K"
].
astype
(
int
),
df
[
"M"
].
astype
(
int
)))
except
Exception
:
return
set
()
def
_check_kernel_tuned
(
N
:
int
,
K
:
int
,
q_dtype_w
:
torch
.
dtype
,
csv_path
:
str
)
->
bool
:
configs
=
_load_gemm_tuned_configs
(
q_dtype_w
,
csv_path
)
l_m
=
(
[
1
,
2
,
4
]
+
list
(
range
(
8
,
513
,
8
))
+
[
1024
,
1536
]
+
[
2
**
i
for
i
in
range
(
11
,
19
)]
)
return
any
((
N
,
K
,
M
)
in
configs
for
M
in
l_m
)
def
if_aiter_supported
(
func
:
Callable
)
->
Callable
:
"""Decorator that only executes the function if
ROCm AITER package is supported and enabled on gfx9 archs.
...
...
@@ -468,7 +492,7 @@ def _rocm_aiter_mla_decode_fwd_fake(
pass
def
_rocm_aiter_
gemm_a8w8
_impl
(
def
_rocm_aiter_
w8a8_gemm
_impl
(
A
:
torch
.
Tensor
,
B
:
torch
.
Tensor
,
As
:
torch
.
Tensor
,
...
...
@@ -485,7 +509,7 @@ def _rocm_aiter_gemm_a8w8_impl(
return
gemm_a8w8_CK
(
A
,
B
,
As
,
Bs
,
bias
,
output_dtype
)
def
_rocm_aiter_
gemm_a8w8
_fake
(
def
_rocm_aiter_
w8a8_gemm
_fake
(
A
:
torch
.
Tensor
,
B
:
torch
.
Tensor
,
As
:
torch
.
Tensor
,
...
...
@@ -499,6 +523,35 @@ def _rocm_aiter_gemm_a8w8_fake(
return
Y
def
_rocm_aiter_preshuffled_per_token_w8a8_gemm_impl
(
A
:
torch
.
Tensor
,
B
:
torch
.
Tensor
,
As
:
torch
.
Tensor
,
Bs
:
torch
.
Tensor
,
bias
:
torch
.
Tensor
|
None
=
None
,
output_dtype
:
torch
.
dtype
=
torch
.
float16
,
)
->
torch
.
Tensor
:
from
aiter
import
gemm_a8w8_bpreshuffle
output
=
gemm_a8w8_bpreshuffle
(
A
,
B
,
As
,
Bs
,
None
,
output_dtype
)
if
bias
is
not
None
:
output
.
add_
(
bias
)
return
output
def
_rocm_aiter_preshuffled_per_token_w8a8_gemm_fake
(
A
:
torch
.
Tensor
,
B
:
torch
.
Tensor
,
As
:
torch
.
Tensor
,
Bs
:
torch
.
Tensor
,
bias
:
torch
.
Tensor
|
None
=
None
,
output_dtype
:
torch
.
dtype
=
torch
.
float16
,
)
->
torch
.
Tensor
:
m
=
A
.
shape
[
0
]
n
=
B
.
shape
[
0
]
return
torch
.
empty
(
m
,
n
,
dtype
=
output_dtype
,
device
=
A
.
device
)
def
_rocm_aiter_triton_gemm_a8w8_blockscale_impl
(
A
:
torch
.
Tensor
,
B
:
torch
.
Tensor
,
...
...
@@ -1313,11 +1366,15 @@ class rocm_aiter_ops:
)
direct_register_custom_op
(
op_name
=
"rocm_aiter_gemm_a8w8"
,
op_func
=
_rocm_aiter_gemm_a8w8_impl
,
mutates_args
=
[],
fake_impl
=
_rocm_aiter_gemm_a8w8_fake
,
dispatch_key
=
current_platform
.
dispatch_key
,
op_name
=
"rocm_aiter_w8a8_gemm"
,
op_func
=
_rocm_aiter_w8a8_gemm_impl
,
fake_impl
=
_rocm_aiter_w8a8_gemm_fake
,
)
direct_register_custom_op
(
op_name
=
"_rocm_aiter_preshuffled_per_token_w8a8_gemm"
,
op_func
=
_rocm_aiter_preshuffled_per_token_w8a8_gemm_impl
,
fake_impl
=
_rocm_aiter_preshuffled_per_token_w8a8_gemm_fake
,
)
direct_register_custom_op
(
...
...
@@ -1493,7 +1550,18 @@ class rocm_aiter_ops:
)
@
staticmethod
def
gemm_a8w8
(
def
w8a8_gemm
(
A
:
torch
.
Tensor
,
B
:
torch
.
Tensor
,
As
:
torch
.
Tensor
,
Bs
:
torch
.
Tensor
,
bias
:
torch
.
Tensor
|
None
=
None
,
output_dtype
:
torch
.
dtype
=
torch
.
float16
,
)
->
torch
.
Tensor
:
return
torch
.
ops
.
vllm
.
rocm_aiter_w8a8_gemm
(
A
,
B
,
As
,
Bs
,
bias
,
output_dtype
)
@
staticmethod
def
preshuffled_per_token_w8a8_gemm
(
A
:
torch
.
Tensor
,
B
:
torch
.
Tensor
,
As
:
torch
.
Tensor
,
...
...
@@ -1501,7 +1569,9 @@ class rocm_aiter_ops:
bias
:
torch
.
Tensor
|
None
=
None
,
output_dtype
:
torch
.
dtype
=
torch
.
float16
,
)
->
torch
.
Tensor
:
return
torch
.
ops
.
vllm
.
rocm_aiter_gemm_a8w8
(
A
,
B
,
As
,
Bs
,
bias
,
output_dtype
)
return
torch
.
ops
.
vllm
.
_rocm_aiter_preshuffled_per_token_w8a8_gemm
(
A
,
B
,
As
,
Bs
,
bias
,
output_dtype
)
@
staticmethod
def
triton_gemm_a8w8_blockscale
(
...
...
@@ -1920,6 +1990,24 @@ class rocm_aiter_ops:
(
8192
,
3584
),
]
@
staticmethod
def
is_shuffled_per_token_w8a8_gemm_tuned
(
N
:
int
,
K
:
int
,
q_dtype_w
:
torch
.
dtype
)
->
bool
:
import
aiter.ops.gemm_op_a8w8
as
aiter_gemm_a8w8_ops
csv_path
=
(
aiter_gemm_a8w8_ops
.
AITER_CONFIGS
.
AITER_CONFIG_GEMM_A8W8_BPRESHUFFLE_FILE
)
return
_check_kernel_tuned
(
N
,
K
,
q_dtype_w
,
csv_path
)
@
staticmethod
def
is_per_token_w8a8_gemm_tuned
(
N
:
int
,
K
:
int
,
q_dtype_w
:
torch
.
dtype
)
->
bool
:
import
aiter.ops.gemm_op_a8w8
as
aiter_gemm_a8w8_ops
csv_path
=
aiter_gemm_a8w8_ops
.
AITER_CONFIGS
.
AITER_CONFIG_GEMM_A8W8_FILE
return
_check_kernel_tuned
(
N
,
K
,
q_dtype_w
,
csv_path
)
@
staticmethod
def
shuffle_weight
(
tensor
:
torch
.
Tensor
,
layout
:
tuple
[
int
,
int
]
=
(
16
,
16
)
...
...
vllm/model_executor/kernels/linear/__init__.py
View file @
5f7fab88
...
...
@@ -106,6 +106,8 @@ from vllm.model_executor.kernels.linear.scaled_mm import (
from
vllm.model_executor.kernels.linear.scaled_mm.aiter
import
(
AiterFp8BlockScaledMMKernel
,
AiterInt8ScaledMMLinearKernel
,
AiterPerTokenFp8ScaledMMLinearKernel
,
AiterPreshuffledPerTokenFp8ScaledMMLinearKernel
,
)
from
vllm.model_executor.kernels.linear.scaled_mm.cpu
import
(
CPUInt8ScaledMMLinearKernel
,
...
...
@@ -165,6 +167,8 @@ _POSSIBLE_FP8_KERNELS: dict[PlatformEnum, list[type[FP8ScaledMMLinearKernel]]] =
ChannelWiseTorchFP8ScaledMMLinearKernel
,
],
PlatformEnum
.
ROCM
:
[
AiterPreshuffledPerTokenFp8ScaledMMLinearKernel
,
AiterPerTokenFp8ScaledMMLinearKernel
,
ROCmFP8ScaledMMLinearKernel
,
PerTensorTorchFP8ScaledMMLinearKernel
,
RowWiseTorchFP8ScaledMMLinearKernel
,
...
...
@@ -360,18 +364,18 @@ def choose_scaled_mm_linear_kernel(
def
init_fp8_linear_kernel
(
activation_quant_key
:
QuantKey
,
weight_quant_key
:
QuantKey
,
weight_shape
:
tuple
[
int
,
int
],
input_dtype
:
torch
.
dtype
,
out_dtype
:
torch
.
dtype
,
force_kernel
:
type
[
_KernelT
]
|
None
=
None
,
weight_shape
:
tuple
[
int
,
int
],
force_kernel
:
type
[
FP8ScaledMMLinearKernel
]
|
None
=
None
,
module_name
:
str
|
None
=
None
,
)
->
FP8ScaledMMLinearKernel
|
Fp8BlockScaledMMLinearKernel
:
scaled_mm_linear_kernel_config
=
FP8ScaledMMLinearLayerConfig
(
weight_quant_key
=
weight_quant_key
,
activation_quant_key
=
activation_quant_key
,
weight_shape
=
weight_shape
,
input_dtype
=
input_dtype
,
out_dtype
=
out_dtype
,
weight_shape
=
weight_shape
,
)
if
activation_quant_key
.
scale
.
group_shape
.
is_per_group
():
...
...
@@ -725,6 +729,8 @@ __all__ = [
"FP8ScaledMMLinearLayerConfig"
,
"Int8ScaledMMLinearLayerConfig"
,
"ScaledMMLinearLayerConfig"
,
"AiterPreshuffledPerTokenFp8ScaledMMLinearKernel"
,
"AiterPerTokenFp8ScaledMMLinearKernel"
,
"NvFp4LinearKernel"
,
"NvFp4LinearLayerConfig"
,
"AiterInt8ScaledMMLinearKernel"
,
...
...
vllm/model_executor/kernels/linear/scaled_mm/aiter.py
View file @
5f7fab88
...
...
@@ -5,18 +5,27 @@
import
torch
from
vllm
import
_custom_ops
as
ops
from
vllm._aiter_ops
import
rocm_aiter_ops
from
vllm._aiter_ops
import
(
rocm_aiter_ops
,
)
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
GroupShape
,
)
from
vllm.model_executor.utils
import
replace_parameter
from
vllm.platforms
import
current_platform
from
.BlockScaledMMLinearKernel
import
(
Fp8BlockScaledMMLinearKernel
,
FP8ScaledMMLinearLayerConfig
,
)
from
.cutlass
import
CutlassInt8ScaledMMLinearKernel
from
.ScaledMMLinearKernel
import
Int8ScaledMMLinearLayerConfig
from
.ScaledMMLinearKernel
import
(
FP8ScaledMMLinearKernel
,
FP8ScaledMMLinearLayerConfig
,
Int8ScaledMMLinearLayerConfig
,
)
logger
=
init_logger
(
__name__
)
class
AiterInt8ScaledMMLinearKernel
(
CutlassInt8ScaledMMLinearKernel
):
...
...
@@ -113,7 +122,154 @@ class AiterInt8ScaledMMLinearKernel(CutlassInt8ScaledMMLinearKernel):
# a to be [M, K]
# b to be [N, K]
# CutlassInt8ScaledMMLinearKernel prepare weight `w_q` in [K, N] format
return
rocm_aiter_ops
.
gemm_a8w8
(
x_q
,
w_q
.
t
(),
x_s
,
w_s
,
bias
,
out_dtype
)
return
rocm_aiter_ops
.
w8a8_gemm
(
x_q
,
w_q
.
t
(),
x_s
,
w_s
,
bias
,
out_dtype
)
class
AiterPreshuffledPerTokenFp8ScaledMMLinearKernel
(
FP8ScaledMMLinearKernel
):
@
classmethod
def
is_supported
(
cls
,
compute_capability
:
int
|
None
=
None
)
->
tuple
[
bool
,
str
|
None
]:
if
not
current_platform
.
is_rocm
():
return
False
,
"requires ROCm."
if
not
rocm_aiter_ops
.
is_linear_fp8_enabled
():
return
(
False
,
"requires setting `VLLM_ROCM_USE_AITER=1` "
"and `VLLM_ROCM_USE_AITER_LINEAR=1`. "
"`VLLM_ROCM_USE_AITER_LINEAR` default is True."
,
)
try
:
import
aiter
# noqa: F401
except
Exception
:
return
False
,
"requires aiter library to be installed."
return
True
,
None
@
classmethod
def
can_implement
(
cls
,
c
:
FP8ScaledMMLinearLayerConfig
)
->
tuple
[
bool
,
str
|
None
]:
is_ptpc
=
(
c
.
activation_quant_key
.
scale
.
group_shape
.
is_per_token
()
and
c
.
weight_quant_key
.
scale
.
group_shape
.
is_per_channel
()
)
if
c
.
weight_shape
is
None
:
return
False
,
"weight_shape is required for Aiter kernels"
N
,
K
=
c
.
weight_shape
fp8_dtype
=
current_platform
.
fp8_dtype
()
if
c
.
out_dtype
is
not
torch
.
bfloat16
:
return
False
,
"requires bfloat16 output dtype."
if
not
is_ptpc
:
return
(
False
,
"requires per token activation scales and per channel weight scales."
,
)
if
not
(
N
%
16
==
0
and
K
%
16
==
0
):
return
(
False
,
f
"requires N and K dimensions divisible by 16, received "
f
"N=
{
N
}
and K=
{
K
}
."
,
)
# Aiter's shuffled per-token Gemm performs better than torch only when its
# tuned.
if
not
rocm_aiter_ops
.
is_shuffled_per_token_w8a8_gemm_tuned
(
N
,
K
,
fp8_dtype
):
return
(
False
,
f
"requires a tuned configuration for N:
{
N
}
and K:
{
K
}
"
f
"and fp8 dtype
{
fp8_dtype
}
."
,
)
return
True
,
None
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
w_name
,
*
_
=
self
.
layer_param_names
w
,
*
_
=
self
.
_get_layer_params
(
layer
)
replace_parameter
(
layer
,
w_name
,
torch
.
nn
.
Parameter
(
rocm_aiter_ops
.
shuffle_weight
(
w
.
t
().
contiguous
()).
data
,
requires_grad
=
False
,
),
)
def
apply_scaled_mm
(
self
,
*
,
A
:
torch
.
Tensor
,
B
:
torch
.
Tensor
,
out_dtype
:
torch
.
dtype
,
As
:
torch
.
Tensor
,
Bs
:
torch
.
Tensor
,
bias
:
torch
.
Tensor
|
None
,
output_shape
:
list
,
)
->
torch
.
Tensor
:
return
rocm_aiter_ops
.
preshuffled_per_token_w8a8_gemm
(
A
,
B
,
As
,
Bs
,
bias
,
out_dtype
)
class
AiterPerTokenFp8ScaledMMLinearKernel
(
FP8ScaledMMLinearKernel
):
@
classmethod
def
is_supported
(
cls
,
compute_capability
:
int
|
None
=
None
)
->
tuple
[
bool
,
str
|
None
]:
return
AiterPreshuffledPerTokenFp8ScaledMMLinearKernel
.
is_supported
(
compute_capability
)
@
classmethod
def
can_implement
(
cls
,
c
:
FP8ScaledMMLinearLayerConfig
)
->
tuple
[
bool
,
str
|
None
]:
is_ptpc
=
(
c
.
activation_quant_key
.
scale
.
group_shape
.
is_per_token
()
and
c
.
weight_quant_key
.
scale
.
group_shape
.
is_per_channel
()
)
if
c
.
weight_shape
is
None
:
return
False
,
"weight_shape is required for Aiter kernels"
N
,
K
=
c
.
weight_shape
fp8_dtype
=
current_platform
.
fp8_dtype
()
if
not
is_ptpc
:
return
(
False
,
"requires per token activation scales and per channel weight scales."
,
)
# Aiter's per-token Gemm performs better than torch oonly when its
# tuned.
if
not
rocm_aiter_ops
.
is_per_token_w8a8_gemm_tuned
(
N
,
K
,
fp8_dtype
):
return
(
False
,
f
"requires a tuned configuration for N:
{
N
}
and K:
{
K
}
"
f
"and fp8 dtype
{
fp8_dtype
}
."
,
)
return
True
,
None
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
w_name
,
*
_
=
self
.
layer_param_names
w
,
*
_
=
self
.
_get_layer_params
(
layer
)
replace_parameter
(
layer
,
w_name
,
torch
.
nn
.
Parameter
(
w
.
t
(),
requires_grad
=
False
),
)
def
apply_scaled_mm
(
self
,
*
,
A
:
torch
.
Tensor
,
B
:
torch
.
Tensor
,
out_dtype
:
torch
.
dtype
,
As
:
torch
.
Tensor
,
Bs
:
torch
.
Tensor
,
bias
:
torch
.
Tensor
|
None
,
output_shape
:
list
,
)
->
torch
.
Tensor
:
return
rocm_aiter_ops
.
w8a8_gemm
(
A
,
B
,
As
,
Bs
,
bias
,
out_dtype
)
class
AiterFp8BlockScaledMMKernel
(
Fp8BlockScaledMMLinearKernel
):
...
...
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
View file @
5f7fab88
...
...
@@ -31,8 +31,8 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape
,
create_fp8_quant_key
,
kFp8DynamicTokenSym
,
kFp8StaticChannelSym
,
kFp8StaticTensorSym
,
kFp8StaticTokenSym
,
)
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
cutlass_block_fp8_supported
,
...
...
@@ -47,7 +47,7 @@ activation_quant_key_mapping = {
DYNAMIC_QUANT
:
kFp8DynamicTokenSym
,
}
weight_quant_key_mapping
=
{
QuantizationStrategy
.
CHANNEL
:
kFp8Static
Token
Sym
,
QuantizationStrategy
.
CHANNEL
:
kFp8Static
Channel
Sym
,
QuantizationStrategy
.
TENSOR
:
kFp8StaticTensorSym
,
}
logger
=
init_logger
(
__name__
)
...
...
@@ -67,7 +67,6 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
self
.
use_aiter_and_is_supported
=
rocm_aiter_ops
.
is_linear_fp8_enabled
()
assert
not
self
.
is_static_input_scheme
self
.
act_q_group_shape
=
GroupShape
(
1
,
self
.
weight_block_size
[
0
])
self
.
weight_quant_key
=
create_fp8_quant_key
(
static
=
True
,
group_shape
=
GroupShape
(
*
self
.
weight_block_size
)
)
...
...
@@ -76,7 +75,7 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
)
else
:
self
.
activation_quant_key
=
activation_quant_key_mapping
[
is_static_input_scheme
self
.
is_static_input_scheme
]
self
.
weight_quant_key
=
weight_quant_key_mapping
[
self
.
strategy
]
...
...
@@ -138,9 +137,9 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
self
.
fp8_linear
=
init_fp8_linear_kernel
(
activation_quant_key
=
self
.
activation_quant_key
,
weight_quant_key
=
self
.
weight_quant_key
,
weight_shape
=
layer
.
weight
.
shape
,
input_dtype
=
self
.
input_dtype
,
out_dtype
=
self
.
out_dtype
,
weight_shape
=
(
output_size_per_partition
,
input_size_per_partition
),
module_name
=
self
.
__class__
.
__name__
,
)
...
...
vllm/model_executor/layers/quantization/fbgemm_fp8.py
View file @
5f7fab88
...
...
@@ -175,6 +175,8 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
# Activations not quantized for marlin.
del
layer
.
input_scale_ub
self
.
fp8_linear
.
process_weights_after_loading
(
layer
)
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
...
...
vllm/model_executor/layers/quantization/fp8.py
View file @
5f7fab88
...
...
@@ -397,8 +397,6 @@ class Fp8LinearMethod(LinearMethodBase):
if
self
.
block_quant
:
assert
not
self
.
act_q_static
self
.
fp8_linear
.
process_weights_after_loading
(
layer
)
# If checkpoint not serialized fp8, quantize the weights.
else
:
# If checkpoint is fp8 per-tensor, handle that there are N scales for N
...
...
@@ -428,6 +426,8 @@ class Fp8LinearMethod(LinearMethodBase):
else
:
layer
.
input_scale
=
None
self
.
fp8_linear
.
process_weights_after_loading
(
layer
)
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
...
...
vllm/model_executor/layers/quantization/modelopt.py
View file @
5f7fab88
...
...
@@ -517,6 +517,7 @@ class ModelOptFp8LinearMethod(LinearMethodBase):
layer
.
weight
=
Parameter
(
weight
.
t
(),
requires_grad
=
False
)
layer
.
weight_scale
=
Parameter
(
max_w_scale
,
requires_grad
=
False
)
layer
.
input_scale
=
Parameter
(
layer
.
input_scale
.
max
(),
requires_grad
=
False
)
self
.
fp8_linear
.
process_weights_after_loading
(
layer
)
def
apply
(
self
,
...
...
@@ -597,6 +598,7 @@ class ModelOptFp8PcPtLinearMethod(LinearMethodBase):
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
layer
.
weight
=
Parameter
(
layer
.
weight
.
t
(),
requires_grad
=
False
)
layer
.
weight_scale
=
Parameter
(
layer
.
weight_scale
.
data
,
requires_grad
=
False
)
self
.
fp8_linear
.
process_weights_after_loading
(
layer
)
def
apply
(
self
,
...
...
vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py
View file @
5f7fab88
...
...
@@ -120,6 +120,8 @@ class QuarkW8A8Fp8(QuarkScheme):
if
self
.
is_static_input_scheme
:
layer
.
input_scale
=
Parameter
(
layer
.
input_scale
.
max
(),
requires_grad
=
False
)
self
.
fp8_linear
.
process_weights_after_loading
(
layer
)
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment