Unverified commit 1b19bd75, authored by zofia, committed by GitHub
Browse files

[MXFP8] [XPU] add a new compressed tensor schema and add a xpu mxfp8 gemm kernel (#38707)


Signed-off-by: Zhu, Zufang <zufang.zhu@intel.com>
parent 200a727e
...@@ -71,6 +71,9 @@ from vllm.model_executor.kernels.linear.mxfp8.flashinfer import ( ...@@ -71,6 +71,9 @@ from vllm.model_executor.kernels.linear.mxfp8.flashinfer import (
from vllm.model_executor.kernels.linear.mxfp8.marlin import ( from vllm.model_executor.kernels.linear.mxfp8.marlin import (
MarlinMxfp8LinearKernel, MarlinMxfp8LinearKernel,
) )
from vllm.model_executor.kernels.linear.mxfp8.xpu import (
XPUMxFp8LinearKernel,
)
from vllm.model_executor.kernels.linear.nvfp4 import ( from vllm.model_executor.kernels.linear.nvfp4 import (
NvFp4LinearKernel, NvFp4LinearKernel,
NvFp4LinearLayerConfig, NvFp4LinearLayerConfig,
...@@ -243,6 +246,10 @@ _POSSIBLE_MXFP8_KERNELS: dict[PlatformEnum, list[type[Mxfp8LinearKernel]]] = { ...@@ -243,6 +246,10 @@ _POSSIBLE_MXFP8_KERNELS: dict[PlatformEnum, list[type[Mxfp8LinearKernel]]] = {
PlatformEnum.ROCM: [ PlatformEnum.ROCM: [
EmulationMxfp8LinearKernel, EmulationMxfp8LinearKernel,
], ],
PlatformEnum.XPU: [
XPUMxFp8LinearKernel,
EmulationMxfp8LinearKernel,
],
} }
_POSSIBLE_NVFP4_KERNELS: dict[PlatformEnum, list[type[NvFp4LinearKernel]]] = { _POSSIBLE_NVFP4_KERNELS: dict[PlatformEnum, list[type[NvFp4LinearKernel]]] = {
...@@ -742,6 +749,7 @@ __all__ = [ ...@@ -742,6 +749,7 @@ __all__ = [
"Mxfp8LinearLayerConfig", "Mxfp8LinearLayerConfig",
"FlashInferCutlassMxfp8LinearKernel", "FlashInferCutlassMxfp8LinearKernel",
"MarlinMxfp8LinearKernel", "MarlinMxfp8LinearKernel",
"XPUMxFp8LinearKernel",
"EmulationMxfp8LinearKernel", "EmulationMxfp8LinearKernel",
"CutlassNvFp4LinearKernel", "CutlassNvFp4LinearKernel",
"EmulationNvFp4LinearKernel", "EmulationNvFp4LinearKernel",
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.model_executor.layers.quantization.utils.mxfp8_utils import (
xpu_mxfp8_quantize as quant_mxfp8,
)
from vllm.model_executor.utils import replace_parameter
from vllm.platforms import current_platform
from .Mxfp8LinearKernel import Mxfp8LinearKernel, Mxfp8LinearLayerConfig
class XPUMxFp8LinearKernel(Mxfp8LinearKernel):
    """MXFP8 W8A8 linear kernel that dispatches to the XPU ``fp8_gemm`` op."""

    @classmethod
    def is_supported(
        cls, compute_capability: int | None = None
    ) -> tuple[bool, str | None]:
        """Report whether this kernel can run on the current platform.

        ``compute_capability`` is accepted but not consulted here.

        Returns:
            ``(True, None)`` when ``current_platform.is_xpu()`` is true,
            otherwise ``(False, reason)``.
        """
        if current_platform.is_xpu():
            return True, None
        return False, "XPUMxFp8 only support on XPU"

    @classmethod
    def can_implement(cls, c: Mxfp8LinearLayerConfig) -> tuple[bool, str | None]:
        """Accept every layer config; ``c`` is not inspected."""
        return True, None

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Rewrite ``layer.weight`` and ``layer.weight_scale`` after loading.

        The scale is reinterpreted as ``torch.float8_e8m0fnu``, transposed and
        made contiguous. The weight is replaced by its transpose as returned
        by ``.t()`` (no ``.contiguous()`` call is applied to it).
        """
        scale_e8m0 = layer.weight_scale.view(torch.float8_e8m0fnu)
        scale_transposed = scale_e8m0.t().contiguous()
        weight_transposed = layer.weight.t()

        replace_parameter(layer, "weight", weight_transposed)
        replace_parameter(layer, "weight_scale", scale_transposed.data)

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Quantize ``x`` and run the XPU FP8 GEMM against the layer weights.

        Args:
            layer: Module providing ``weight`` and ``weight_scale``.
            x: Input activations; their dtype is passed to the GEMM as the
                requested output dtype.
            bias: Optional bias forwarded unchanged to the GEMM op.

        Returns:
            The result of ``torch.ops._xpu_C.fp8_gemm``.
        """
        # Capture the caller's dtype before quantizing the activations.
        result_dtype = x.dtype
        quantized_x, activation_scale = quant_mxfp8(x)

        gemm = torch.ops._xpu_C.fp8_gemm
        return gemm(
            quantized_x,
            layer.weight,
            result_dtype,
            activation_scale,
            layer.weight_scale,
            bias,
        )
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment