Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
27d5ee3e
Unverified
Commit
27d5ee3e
authored
Mar 23, 2026
by
Kunshang Ji
Committed by
GitHub
Mar 23, 2026
Browse files
[FP8]add FP8 WoQ kernel abstraction. (#32929)
Signed-off-by:
Zhu, Zufang
<
zufang.zhu@intel.com
>
parent
35141a7e
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
177 additions
and
83 deletions
+177
-83
vllm/model_executor/kernels/linear/__init__.py
vllm/model_executor/kernels/linear/__init__.py
+4
-0
vllm/model_executor/kernels/linear/scaled_mm/__init__.py
vllm/model_executor/kernels/linear/scaled_mm/__init__.py
+4
-0
vllm/model_executor/kernels/linear/scaled_mm/marlin.py
vllm/model_executor/kernels/linear/scaled_mm/marlin.py
+120
-0
vllm/model_executor/layers/quantization/fbgemm_fp8.py
vllm/model_executor/layers/quantization/fbgemm_fp8.py
+0
-12
vllm/model_executor/layers/quantization/fp8.py
vllm/model_executor/layers/quantization/fp8.py
+49
-71
No files found.
vllm/model_executor/kernels/linear/__init__.py
View file @
27d5ee3e
...
@@ -72,6 +72,9 @@ from vllm.model_executor.kernels.linear.scaled_mm.cutlass import (
...
@@ -72,6 +72,9 @@ from vllm.model_executor.kernels.linear.scaled_mm.cutlass import (
from
vllm.model_executor.kernels.linear.scaled_mm.flashinfer
import
(
from
vllm.model_executor.kernels.linear.scaled_mm.flashinfer
import
(
FlashInferFP8ScaledMMLinearKernel
,
FlashInferFP8ScaledMMLinearKernel
,
)
)
from
vllm.model_executor.kernels.linear.scaled_mm.marlin
import
(
MarlinFP8ScaledMMLinearKernel
,
)
from
vllm.model_executor.kernels.linear.scaled_mm.pytorch
import
(
from
vllm.model_executor.kernels.linear.scaled_mm.pytorch
import
(
ChannelWiseTorchFP8ScaledMMLinearKernel
,
ChannelWiseTorchFP8ScaledMMLinearKernel
,
PerTensorTorchFP8ScaledMMLinearKernel
,
PerTensorTorchFP8ScaledMMLinearKernel
,
...
@@ -104,6 +107,7 @@ _POSSIBLE_INT8_KERNELS: dict[PlatformEnum, list[type[Int8ScaledMMLinearKernel]]]
...
@@ -104,6 +107,7 @@ _POSSIBLE_INT8_KERNELS: dict[PlatformEnum, list[type[Int8ScaledMMLinearKernel]]]
# in priority/performance order (when available)
# in priority/performance order (when available)
_POSSIBLE_FP8_KERNELS
:
dict
[
PlatformEnum
,
list
[
type
[
FP8ScaledMMLinearKernel
]]]
=
{
_POSSIBLE_FP8_KERNELS
:
dict
[
PlatformEnum
,
list
[
type
[
FP8ScaledMMLinearKernel
]]]
=
{
PlatformEnum
.
CUDA
:
[
PlatformEnum
.
CUDA
:
[
MarlinFP8ScaledMMLinearKernel
,
FlashInferFP8ScaledMMLinearKernel
,
FlashInferFP8ScaledMMLinearKernel
,
CutlassFP8ScaledMMLinearKernel
,
CutlassFP8ScaledMMLinearKernel
,
PerTensorTorchFP8ScaledMMLinearKernel
,
PerTensorTorchFP8ScaledMMLinearKernel
,
...
...
vllm/model_executor/kernels/linear/scaled_mm/__init__.py
View file @
27d5ee3e
...
@@ -14,6 +14,9 @@ from vllm.model_executor.kernels.linear.scaled_mm.cutlass import (
...
@@ -14,6 +14,9 @@ from vllm.model_executor.kernels.linear.scaled_mm.cutlass import (
from
vllm.model_executor.kernels.linear.scaled_mm.flashinfer
import
(
from
vllm.model_executor.kernels.linear.scaled_mm.flashinfer
import
(
FlashInferFP8ScaledMMLinearKernel
,
FlashInferFP8ScaledMMLinearKernel
,
)
)
from
vllm.model_executor.kernels.linear.scaled_mm.marlin
import
(
MarlinFP8ScaledMMLinearKernel
,
)
from
vllm.model_executor.kernels.linear.scaled_mm.pytorch
import
(
from
vllm.model_executor.kernels.linear.scaled_mm.pytorch
import
(
ChannelWiseTorchFP8ScaledMMLinearKernel
,
ChannelWiseTorchFP8ScaledMMLinearKernel
,
PerTensorTorchFP8ScaledMMLinearKernel
,
PerTensorTorchFP8ScaledMMLinearKernel
,
...
@@ -46,6 +49,7 @@ __all__ = [
...
@@ -46,6 +49,7 @@ __all__ = [
"CutlassFP8ScaledMMLinearKernel"
,
"CutlassFP8ScaledMMLinearKernel"
,
"CutlassInt8ScaledMMLinearKernel"
,
"CutlassInt8ScaledMMLinearKernel"
,
"FlashInferFP8ScaledMMLinearKernel"
,
"FlashInferFP8ScaledMMLinearKernel"
,
"MarlinFP8ScaledMMLinearKernel"
,
"ChannelWiseTorchFP8ScaledMMLinearKernel"
,
"ChannelWiseTorchFP8ScaledMMLinearKernel"
,
"PerTensorTorchFP8ScaledMMLinearKernel"
,
"PerTensorTorchFP8ScaledMMLinearKernel"
,
"RowWiseTorchFP8ScaledMMLinearKernel"
,
"RowWiseTorchFP8ScaledMMLinearKernel"
,
...
...
vllm/model_executor/kernels/linear/scaled_mm/marlin.py
0 → 100644
View file @
27d5ee3e
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
collections.abc
import
Sequence
import
torch
import
vllm.envs
as
envs
from
vllm.model_executor.layers.batch_invariant
import
vllm_is_batch_invariant
from
vllm.model_executor.layers.quantization.utils.fp8_utils
import
(
process_fp8_weight_block_strategy
,
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils_fp8
import
(
apply_fp8_marlin_linear
,
is_fp8_marlin_supported
,
prepare_fp8_layer_for_marlin
,
)
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
kFp8Static128BlockSym
,
)
from
vllm.model_executor.utils
import
replace_parameter
from
vllm.platforms
import
current_platform
from
.ScaledMMLinearKernel
import
(
FP8ScaledMMLinearKernel
,
FP8ScaledMMLinearLayerConfig
,
)
class
MarlinFP8ScaledMMLinearKernel
(
FP8ScaledMMLinearKernel
):
"""
FP8 Marlin kernel for GPUs that lack FP8 hardware support.
Leverages the Marlin kernel for fast weight-only FP8 quantization.
"""
@
classmethod
def
is_supported
(
cls
,
compute_capability
:
int
|
None
=
None
)
->
tuple
[
bool
,
str
|
None
]:
if
not
current_platform
.
is_cuda
():
return
False
,
"requires CUDA."
# Check if platform supports FP8 Marlin
if
not
is_fp8_marlin_supported
():
return
False
,
"FP8 Marlin requires compute capability 7.5 or higher"
if
vllm_is_batch_invariant
():
return
False
,
"FP8 Marlin not supported for batch invariant execution."
if
(
compute_capability
is
not
None
and
compute_capability
>=
89
and
not
envs
.
VLLM_TEST_FORCE_FP8_MARLIN
):
return
(
False
,
"To apply FP8 Marlin on high-capability GPUs, please set "
"VLLM_TEST_FORCE_FP8_MARLIN=1"
,
)
return
True
,
None
@
classmethod
def
can_implement
(
cls
,
c
:
FP8ScaledMMLinearLayerConfig
)
->
tuple
[
bool
,
str
|
None
]:
return
True
,
None
def
__init__
(
self
,
c
:
FP8ScaledMMLinearLayerConfig
,
layer_param_names
:
Sequence
[
str
]
)
->
None
:
super
().
__init__
(
c
,
layer_param_names
)
self
.
marlin_input_dtype
=
None
self
.
block_quant
=
self
.
config
.
weight_quant_key
in
{
kFp8Static128BlockSym
}
self
.
size_k_first
=
not
self
.
block_quant
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
if
self
.
block_quant
:
weight
,
weight_scale_inv
=
process_fp8_weight_block_strategy
(
layer
.
weight
,
layer
.
weight_scale_inv
)
# Update layer with new values
replace_parameter
(
layer
,
"weight"
,
weight
.
data
)
replace_parameter
(
layer
,
"weight_scale_inv"
,
weight_scale_inv
.
data
)
else
:
weight
=
layer
.
weight
.
t
()
replace_parameter
(
layer
,
"weight"
,
weight
.
data
)
layer
.
input_scale
=
None
prepare_fp8_layer_for_marlin
(
layer
,
self
.
size_k_first
,
input_dtype
=
self
.
marlin_input_dtype
)
del
layer
.
input_scale
def
apply_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
bias
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
:
if
self
.
block_quant
:
weight_scale
=
layer
.
weight_scale_inv
else
:
weight_scale
=
layer
.
weight_scale
return
apply_fp8_marlin_linear
(
input
=
x
,
weight
=
layer
.
weight
,
weight_scale
=
weight_scale
,
workspace
=
layer
.
workspace
,
size_n
=
layer
.
output_size_per_partition
,
size_k
=
layer
.
input_size_per_partition
,
input_dtype
=
self
.
marlin_input_dtype
,
bias
=
bias
,
)
def
apply_scaled_mm
(
self
,
*
,
A
:
torch
.
Tensor
,
B
:
torch
.
Tensor
,
out_dtype
:
torch
.
dtype
,
As
:
torch
.
Tensor
,
Bs
:
torch
.
Tensor
,
bias
:
torch
.
Tensor
|
None
,
output_shape
:
list
,
)
->
torch
.
Tensor
:
pass
vllm/model_executor/layers/quantization/fbgemm_fp8.py
View file @
27d5ee3e
...
@@ -22,7 +22,6 @@ from vllm.model_executor.layers.quantization.base_config import (
...
@@ -22,7 +22,6 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizeMethodBase
,
QuantizeMethodBase
,
)
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils_fp8
import
(
from
vllm.model_executor.layers.quantization.utils.marlin_utils_fp8
import
(
apply_fp8_marlin_linear
,
prepare_fp8_layer_for_marlin
,
prepare_fp8_layer_for_marlin
,
)
)
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
...
@@ -177,15 +176,4 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
...
@@ -177,15 +176,4 @@ class FBGEMMFp8LinearMethod(LinearMethodBase):
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
bias
:
torch
.
Tensor
|
None
=
None
,
bias
:
torch
.
Tensor
|
None
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
if
self
.
quant_config
.
use_marlin
:
return
apply_fp8_marlin_linear
(
input
=
x
,
weight
=
layer
.
weight
,
weight_scale
=
layer
.
weight_scale
,
workspace
=
layer
.
workspace
,
size_n
=
layer
.
output_size_per_partition
,
size_k
=
layer
.
input_size_per_partition
,
bias
=
bias
,
)
return
self
.
fp8_linear
.
apply_weights
(
layer
,
x
,
bias
)
return
self
.
fp8_linear
.
apply_weights
(
layer
,
x
,
bias
)
vllm/model_executor/layers/quantization/fp8.py
View file @
27d5ee3e
...
@@ -7,7 +7,6 @@ import torch
...
@@ -7,7 +7,6 @@ import torch
from
torch.nn
import
Module
from
torch.nn
import
Module
from
torch.utils._python_dispatch
import
TorchDispatchMode
from
torch.utils._python_dispatch
import
TorchDispatchMode
import
vllm.envs
as
envs
import
vllm.model_executor.layers.fused_moe.modular_kernel
as
mk
import
vllm.model_executor.layers.fused_moe.modular_kernel
as
mk
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm._aiter_ops
import
rocm_aiter_ops
from
vllm._aiter_ops
import
rocm_aiter_ops
...
@@ -16,6 +15,7 @@ from vllm.logger import init_logger
...
@@ -16,6 +15,7 @@ from vllm.logger import init_logger
from
vllm.model_executor.kernels.linear
import
(
from
vllm.model_executor.kernels.linear
import
(
init_fp8_linear_kernel
,
init_fp8_linear_kernel
,
)
)
from
vllm.model_executor.kernels.linear.scaled_mm
import
MarlinFP8ScaledMMLinearKernel
from
vllm.model_executor.layers.attention
import
Attention
from
vllm.model_executor.layers.attention
import
Attention
from
vllm.model_executor.layers.batch_invariant
import
(
from
vllm.model_executor.layers.batch_invariant
import
(
vllm_is_batch_invariant
,
vllm_is_batch_invariant
,
...
@@ -61,10 +61,6 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
...
@@ -61,10 +61,6 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
get_marlin_input_dtype
,
get_marlin_input_dtype
,
)
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils_fp8
import
(
apply_fp8_marlin_linear
,
prepare_fp8_layer_for_marlin
,
)
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
from
vllm.model_executor.layers.quantization.utils.quant_utils
import
(
GroupShape
,
GroupShape
,
is_layer_skipped
,
is_layer_skipped
,
...
@@ -280,15 +276,6 @@ class Fp8LinearMethod(LinearMethodBase):
...
@@ -280,15 +276,6 @@ class Fp8LinearMethod(LinearMethodBase):
# For GPUs that lack FP8 hardware support, we can leverage the Marlin
# For GPUs that lack FP8 hardware support, we can leverage the Marlin
# kernel for fast weight-only FP8 quantization
# kernel for fast weight-only FP8 quantization
self
.
marlin_input_dtype
=
None
self
.
marlin_input_dtype
=
None
self
.
use_marlin
=
(
not
current_platform
.
has_device_capability
(
89
)
or
envs
.
VLLM_TEST_FORCE_FP8_MARLIN
)
# Disable marlin for rocm
if
current_platform
.
is_rocm
()
or
current_platform
.
is_xpu
():
self
.
use_marlin
=
False
if
vllm_is_batch_invariant
():
self
.
use_marlin
=
False
self
.
use_aiter_and_is_supported
=
rocm_aiter_ops
.
is_linear_fp8_enabled
()
self
.
use_aiter_and_is_supported
=
rocm_aiter_ops
.
is_linear_fp8_enabled
()
self
.
use_deep_gemm
=
is_deep_gemm_supported
()
self
.
use_deep_gemm
=
is_deep_gemm_supported
()
...
@@ -297,16 +284,6 @@ class Fp8LinearMethod(LinearMethodBase):
...
@@ -297,16 +284,6 @@ class Fp8LinearMethod(LinearMethodBase):
self
.
block_quant
=
self
.
weight_block_size
is
not
None
self
.
block_quant
=
self
.
weight_block_size
is
not
None
self
.
act_q_static
=
self
.
quant_config
.
activation_scheme
==
"static"
self
.
act_q_static
=
self
.
quant_config
.
activation_scheme
==
"static"
if
self
.
block_quant
:
assert
not
self
.
act_q_static
assert
self
.
weight_block_size
is
not
None
self
.
w8a8_block_fp8_linear
=
W8A8BlockFp8LinearOp
(
weight_group_shape
=
GroupShape
(
*
self
.
weight_block_size
),
act_quant_group_shape
=
GroupShape
(
1
,
self
.
weight_block_size
[
0
]),
cutlass_block_fp8_supported
=
self
.
cutlass_block_fp8_supported
,
use_aiter_and_is_supported
=
self
.
use_aiter_and_is_supported
,
)
else
:
# Use per-token quantization for better perf if dynamic and cutlass
# Use per-token quantization for better perf if dynamic and cutlass
if
self
.
act_q_static
:
if
self
.
act_q_static
:
activation_quant_key
=
kFp8StaticTensorSym
activation_quant_key
=
kFp8StaticTensorSym
...
@@ -315,12 +292,28 @@ class Fp8LinearMethod(LinearMethodBase):
...
@@ -315,12 +292,28 @@ class Fp8LinearMethod(LinearMethodBase):
else
:
else
:
activation_quant_key
=
kFp8DynamicTensorSym
activation_quant_key
=
kFp8DynamicTensorSym
if
self
.
block_quant
:
weight_quant_key
=
kFp8Static128BlockSym
else
:
weight_quant_key
=
kFp8StaticTensorSym
self
.
fp8_linear
=
init_fp8_linear_kernel
(
self
.
fp8_linear
=
init_fp8_linear_kernel
(
activation_quant_key
=
activation_quant_key
,
activation_quant_key
=
activation_quant_key
,
weight_quant_key
=
kFp8StaticTensorSym
,
weight_quant_key
=
weight_quant_key
,
out_dtype
=
torch
.
get_default_dtype
(),
out_dtype
=
torch
.
get_default_dtype
(),
module_name
=
self
.
__class__
.
__name__
,
module_name
=
self
.
__class__
.
__name__
,
)
)
self
.
use_marlin
=
isinstance
(
self
.
fp8_linear
,
MarlinFP8ScaledMMLinearKernel
)
if
self
.
block_quant
and
not
self
.
use_marlin
:
assert
not
self
.
act_q_static
assert
self
.
weight_block_size
is
not
None
self
.
w8a8_block_fp8_linear
=
W8A8BlockFp8LinearOp
(
weight_group_shape
=
GroupShape
(
*
self
.
weight_block_size
),
act_quant_group_shape
=
GroupShape
(
1
,
self
.
weight_block_size
[
0
]),
cutlass_block_fp8_supported
=
self
.
cutlass_block_fp8_supported
,
use_aiter_and_is_supported
=
self
.
use_aiter_and_is_supported
,
)
def
create_weights
(
def
create_weights
(
self
,
self
,
...
@@ -387,12 +380,18 @@ class Fp8LinearMethod(LinearMethodBase):
...
@@ -387,12 +380,18 @@ class Fp8LinearMethod(LinearMethodBase):
layer
.
register_parameter
(
"input_scale"
,
scale
)
layer
.
register_parameter
(
"input_scale"
,
scale
)
def
process_weights_after_loading
(
self
,
layer
:
Module
)
->
None
:
def
process_weights_after_loading
(
self
,
layer
:
Module
)
->
None
:
size_k_first
=
True
if
self
.
use_marlin
:
# Only Marlin kernels support `marlin_input_dtype`; guard to avoid
# AttributeError if backend selection changes.
if
hasattr
(
self
.
fp8_linear
,
"marlin_input_dtype"
):
self
.
fp8_linear
.
marlin_input_dtype
=
self
.
marlin_input_dtype
self
.
fp8_linear
.
process_weights_after_loading
(
layer
)
return
input_scale
=
None
input_scale
=
None
# TODO(rob): refactor block quant into separate class.
# TODO(rob): refactor block quant into separate class.
if
self
.
block_quant
:
if
self
.
block_quant
:
assert
not
self
.
act_q_static
assert
not
self
.
act_q_static
size_k_first
=
False
weight
,
weight_scale_inv
=
process_fp8_weight_block_strategy
(
weight
,
weight_scale_inv
=
process_fp8_weight_block_strategy
(
layer
.
weight
,
layer
.
weight_scale_inv
layer
.
weight
,
layer
.
weight_scale_inv
...
@@ -411,7 +410,6 @@ class Fp8LinearMethod(LinearMethodBase):
...
@@ -411,7 +410,6 @@ class Fp8LinearMethod(LinearMethodBase):
# If using w8a8, torch._scaled_mm needs per tensor, so
# If using w8a8, torch._scaled_mm needs per tensor, so
# requantize the logical shards as a single weight.
# requantize the logical shards as a single weight.
if
not
self
.
use_marlin
:
weight
,
weight_scale
,
input_scale
=
process_fp8_weight_tensor_strategy
(
weight
,
weight_scale
,
input_scale
=
process_fp8_weight_tensor_strategy
(
weight
,
weight
,
weight_scale
,
weight_scale
,
...
@@ -432,14 +430,6 @@ class Fp8LinearMethod(LinearMethodBase):
...
@@ -432,14 +430,6 @@ class Fp8LinearMethod(LinearMethodBase):
else
:
else
:
layer
.
input_scale
=
None
layer
.
input_scale
=
None
if
self
.
use_marlin
:
prepare_fp8_layer_for_marlin
(
layer
,
size_k_first
,
input_dtype
=
self
.
marlin_input_dtype
)
# Activations not quantized for marlin.
del
layer
.
input_scale
return
if
self
.
block_quant
:
if
self
.
block_quant
:
maybe_post_process_fp8_weight_block
(
layer
)
maybe_post_process_fp8_weight_block
(
layer
)
...
@@ -486,21 +476,7 @@ class Fp8LinearMethod(LinearMethodBase):
...
@@ -486,21 +476,7 @@ class Fp8LinearMethod(LinearMethodBase):
return
torch
.
nn
.
functional
.
linear
(
x
,
weight_bf16
.
t
(),
bias
)
return
torch
.
nn
.
functional
.
linear
(
x
,
weight_bf16
.
t
(),
bias
)
if
self
.
use_marlin
:
if
self
.
use_marlin
:
if
self
.
block_quant
:
return
self
.
fp8_linear
.
apply_weights
(
layer
,
x
,
bias
)
weight_scale
=
layer
.
weight_scale_inv
else
:
weight_scale
=
layer
.
weight_scale
return
apply_fp8_marlin_linear
(
input
=
x
,
weight
=
layer
.
weight
,
weight_scale
=
weight_scale
,
workspace
=
layer
.
workspace
,
size_n
=
layer
.
output_size_per_partition
,
size_k
=
layer
.
input_size_per_partition
,
input_dtype
=
self
.
marlin_input_dtype
,
bias
=
bias
,
)
if
self
.
block_quant
:
if
self
.
block_quant
:
assert
self
.
weight_block_size
is
not
None
assert
self
.
weight_block_size
is
not
None
...
@@ -623,18 +599,20 @@ class Fp8OnlineLinearMethod(Fp8LinearMethod):
...
@@ -623,18 +599,20 @@ class Fp8OnlineLinearMethod(Fp8LinearMethod):
layer
.
input_scale
=
None
layer
.
input_scale
=
None
qweight
,
weight_scale
=
ops
.
scaled_fp8_quant
(
layer
.
weight
,
scale
=
None
)
qweight
,
weight_scale
=
ops
.
scaled_fp8_quant
(
layer
.
weight
,
scale
=
None
)
weight
=
qweight
.
t
()
# Update layer with new values.
# Update layer with new values.
replace_parameter
(
layer
,
"weight"
,
weight
.
data
)
replace_parameter
(
layer
,
"weight"
,
q
weight
.
data
)
replace_parameter
(
layer
,
"weight_scale"
,
weight_scale
.
data
)
replace_parameter
(
layer
,
"weight_scale"
,
weight_scale
.
data
)
if
self
.
use_marlin
:
if
self
.
use_marlin
:
size_k_first
=
True
# Only Marlin kernels support `marlin_input_dtype`; guard to avoid
prepare_fp8_layer_for_marlin
(
# AttributeError if backend selection changes.
layer
,
size_k_first
,
input_dtype
=
self
.
marlin_input_dtype
if
hasattr
(
self
.
fp8_linear
,
"marlin_input_dtype"
):
)
self
.
fp8_linear
.
marlin_input_dtype
=
self
.
marlin_input_dtype
# Activations not quantized for marlin.
self
.
fp8_linear
.
process_weights_after_loading
(
layer
)
else
:
weight
=
qweight
.
t
()
replace_parameter
(
layer
,
"weight"
,
weight
.
data
)
# Prevent duplicate processing (e.g., during weight reload)
# Prevent duplicate processing (e.g., during weight reload)
layer
.
_already_called_process_weights_after_loading
=
True
layer
.
_already_called_process_weights_after_loading
=
True
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment