Unverified Commit 394ff869 authored by Yan Ma's avatar Yan Ma Committed by GitHub
Browse files

[XPU][CT] support per-channel quantization in xpu fp8 linear method (#38316)


Signed-off-by: default avatarYan Ma <yan.ma@intel.com>
parent df1e30e7
......@@ -204,7 +204,7 @@ _POSSIBLE_WFP8A16_KERNELS: dict[PlatformEnum, list[type[FP8ScaledMMLinearKernel]
# To be added
],
PlatformEnum.XPU: [
# To be added
XPUFP8ScaledMMLinearKernel,
],
}
......
......@@ -9,6 +9,11 @@ from vllm.model_executor.kernels.linear import ( # noqa: E501
FP8ScaledMMLinearKernel,
FP8ScaledMMLinearLayerConfig,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
kFp8StaticChannelSym,
kFp8StaticTensorSym,
)
from vllm.model_executor.utils import replace_parameter
from vllm.platforms import current_platform
......@@ -23,6 +28,11 @@ class XPUFP8ScaledMMLinearKernel(FP8ScaledMMLinearKernel):
@classmethod
def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
if c.weight_quant_key not in {kFp8StaticChannelSym, kFp8StaticTensorSym}:
return (
False,
"XPUFP8ScaledMM only support per-channel and per-tensor quantization",
)
if c.weight_quant_key.dtype not in {torch.float8_e5m2, torch.float8_e4m3fn}:
return False, "XPUFP8ScaledMM only support FP8 weight dtype"
return True, None
......@@ -35,6 +45,9 @@ class XPUFP8ScaledMMLinearKernel(FP8ScaledMMLinearKernel):
self.config = c
self.layer_param_names = layer_param_names
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
replace_parameter(layer, "weight", layer.weight.data.t())
def apply_weights(
self,
layer: torch.nn.Module,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment