Unverified Commit 11e2375f authored by Michael Goin's avatar Michael Goin Committed by GitHub
Browse files

[Refactor] Move MXFP8 GEMM management into MxFp8LinearKernel (#39205)


Signed-off-by: default avatarmgoin <mgoin64@gmail.com>
parent fc645f1a
...@@ -58,6 +58,19 @@ from vllm.model_executor.kernels.linear.mixed_precision.xpu import ( ...@@ -58,6 +58,19 @@ from vllm.model_executor.kernels.linear.mixed_precision.xpu import (
XPUW4A8IntLinearKernel, XPUW4A8IntLinearKernel,
XPUwNa16LinearKernel, XPUwNa16LinearKernel,
) )
from vllm.model_executor.kernels.linear.mxfp8 import (
Mxfp8LinearKernel,
Mxfp8LinearLayerConfig,
)
from vllm.model_executor.kernels.linear.mxfp8.emulation import (
EmulationMxfp8LinearKernel,
)
from vllm.model_executor.kernels.linear.mxfp8.flashinfer import (
FlashInferCutlassMxfp8LinearKernel,
)
from vllm.model_executor.kernels.linear.mxfp8.marlin import (
MarlinMxfp8LinearKernel,
)
from vllm.model_executor.kernels.linear.nvfp4 import ( from vllm.model_executor.kernels.linear.nvfp4 import (
NvFp4LinearKernel, NvFp4LinearKernel,
NvFp4LinearLayerConfig, NvFp4LinearLayerConfig,
...@@ -221,6 +234,17 @@ _POSSIBLE_KERNELS: dict[PlatformEnum, list[type[MPLinearKernel]]] = { ...@@ -221,6 +234,17 @@ _POSSIBLE_KERNELS: dict[PlatformEnum, list[type[MPLinearKernel]]] = {
} }
# in priority/performance order (when available) # in priority/performance order (when available)
_POSSIBLE_MXFP8_KERNELS: dict[PlatformEnum, list[type[Mxfp8LinearKernel]]] = {
PlatformEnum.CUDA: [
FlashInferCutlassMxfp8LinearKernel,
MarlinMxfp8LinearKernel,
EmulationMxfp8LinearKernel,
],
PlatformEnum.ROCM: [
EmulationMxfp8LinearKernel,
],
}
_POSSIBLE_NVFP4_KERNELS: dict[PlatformEnum, list[type[NvFp4LinearKernel]]] = { _POSSIBLE_NVFP4_KERNELS: dict[PlatformEnum, list[type[NvFp4LinearKernel]]] = {
PlatformEnum.CUDA: [ PlatformEnum.CUDA: [
FlashInferCutlassNvFp4LinearKernel, FlashInferCutlassNvFp4LinearKernel,
...@@ -482,6 +506,41 @@ def choose_mp_linear_kernel( ...@@ -482,6 +506,41 @@ def choose_mp_linear_kernel(
) )
def init_mxfp8_linear_kernel() -> Mxfp8LinearKernel:
"""Select and instantiate the best MXFP8 linear kernel for the
current platform."""
config = Mxfp8LinearLayerConfig()
platform = current_platform._enum
possible = _POSSIBLE_MXFP8_KERNELS.get(platform, [])
failure_reasons = []
for kernel_cls in possible:
if kernel_cls.__name__ in envs.VLLM_DISABLED_KERNELS:
failure_reasons.append(
f" {kernel_cls.__name__} disabled by environment variable"
)
continue
is_supported, reason = kernel_cls.is_supported()
if not is_supported:
failure_reasons.append(f"{kernel_cls.__name__}: {reason}")
continue
can_implement, reason = kernel_cls.can_implement(config)
if not can_implement:
failure_reasons.append(f"{kernel_cls.__name__}: {reason}")
continue
logger.info_once("Using %s for MXFP8 GEMM", kernel_cls.__name__)
return kernel_cls(config)
raise ValueError(
"Failed to find a kernel that can implement the "
"MXFP8 linear layer. Reasons: \n" + "\n".join(failure_reasons)
)
def init_wfp8_a16_linear_kernel( def init_wfp8_a16_linear_kernel(
weight_quant_key: QuantKey, weight_quant_key: QuantKey,
activation_quant_key: QuantKey, activation_quant_key: QuantKey,
...@@ -628,6 +687,10 @@ def register_linear_kernel( ...@@ -628,6 +687,10 @@ def register_linear_kernel(
if platform not in _POSSIBLE_FP8_KERNELS: if platform not in _POSSIBLE_FP8_KERNELS:
_POSSIBLE_FP8_KERNELS[platform] = [] _POSSIBLE_FP8_KERNELS[platform] = []
_POSSIBLE_FP8_KERNELS[platform].append(kernel_class) _POSSIBLE_FP8_KERNELS[platform].append(kernel_class)
elif kernel_type == "mxfp8":
if platform not in _POSSIBLE_MXFP8_KERNELS:
_POSSIBLE_MXFP8_KERNELS[platform] = []
_POSSIBLE_MXFP8_KERNELS[platform].append(kernel_class)
elif kernel_type == "nvfp4": elif kernel_type == "nvfp4":
if platform not in _POSSIBLE_NVFP4_KERNELS: if platform not in _POSSIBLE_NVFP4_KERNELS:
_POSSIBLE_NVFP4_KERNELS[platform] = [] _POSSIBLE_NVFP4_KERNELS[platform] = []
...@@ -674,6 +737,12 @@ __all__ = [ ...@@ -674,6 +737,12 @@ __all__ = [
"TritonW4A16LinearKernel", "TritonW4A16LinearKernel",
"XPUW4A8IntLinearKernel", "XPUW4A8IntLinearKernel",
"XPUwNa16LinearKernel", "XPUwNa16LinearKernel",
"init_mxfp8_linear_kernel",
"Mxfp8LinearKernel",
"Mxfp8LinearLayerConfig",
"FlashInferCutlassMxfp8LinearKernel",
"MarlinMxfp8LinearKernel",
"EmulationMxfp8LinearKernel",
"CutlassNvFp4LinearKernel", "CutlassNvFp4LinearKernel",
"EmulationNvFp4LinearKernel", "EmulationNvFp4LinearKernel",
"FbgemmNvFp4LinearKernel", "FbgemmNvFp4LinearKernel",
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from dataclasses import dataclass
import torch
@dataclass
class Mxfp8LinearLayerConfig:
"""Configuration for an MXFP8 linear layer.
All MXFP8 layers share the same structure: FP8-E4M3 weights with
uint8 (E8M0) per-block scales at block size 32.
"""
pass
class Mxfp8LinearKernel(ABC):
"""Base class for MXFP8 quantized linear kernels.
Each subclass implements a specific GEMM backend (FlashInfer CUTLASS,
Marlin, emulation).
"""
def __init__(self, c: Mxfp8LinearLayerConfig) -> None:
assert self.can_implement(c)[0]
assert self.is_supported()[0]
self.config = c
@classmethod
@abstractmethod
def is_supported(
cls, compute_capability: int | None = None
) -> tuple[bool, str | None]:
raise NotImplementedError
@classmethod
@abstractmethod
def can_implement(cls, c: Mxfp8LinearLayerConfig) -> tuple[bool, str | None]:
raise NotImplementedError
@abstractmethod
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
raise NotImplementedError
@abstractmethod
def apply_weights(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
raise NotImplementedError
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.model_executor.kernels.linear.mxfp8.Mxfp8LinearKernel import (
Mxfp8LinearKernel,
Mxfp8LinearLayerConfig,
)
__all__ = [
"Mxfp8LinearKernel",
"Mxfp8LinearLayerConfig",
]
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from torch.nn.parameter import Parameter
from vllm.model_executor.layers.quantization.utils.mxfp8_utils import (
MXFP8_BLOCK_SIZE,
MXFP8_SCALE_DTYPE,
dequant_mxfp8_to_bf16,
)
from .Mxfp8LinearKernel import Mxfp8LinearKernel, Mxfp8LinearLayerConfig
class EmulationMxfp8LinearKernel(Mxfp8LinearKernel):
"""Software emulation fallback for MXFP8 (dequant to BF16)."""
@classmethod
def is_supported(
cls, compute_capability: int | None = None
) -> tuple[bool, str | None]:
return True, None
@classmethod
def can_implement(cls, c: Mxfp8LinearLayerConfig) -> tuple[bool, str | None]:
return True, None
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
weight = layer.weight.data # [N, K]
N, K = weight.shape
scale_k = K // MXFP8_BLOCK_SIZE
weight_scale = layer.weight_scale.data[:N, :scale_k].contiguous()
layer.weight = Parameter(weight.contiguous(), requires_grad=False)
layer.weight_scale = Parameter(weight_scale, requires_grad=False)
def apply_weights(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
weight_scale = layer.weight_scale
if weight_scale.dtype != MXFP8_SCALE_DTYPE:
raise ValueError(
f"Emulation backend requires {MXFP8_SCALE_DTYPE} "
f"weight_scale dtype, got {weight_scale.dtype}."
)
if weight_scale.ndim != 2:
raise ValueError(
f"Emulation backend requires 2D weight_scale, "
f"got {weight_scale.ndim}D. "
f"Ensure process_weights_after_loading was called."
)
weight_bf16 = dequant_mxfp8_to_bf16(layer.weight, weight_scale)
output = torch.nn.functional.linear(x, weight_bf16, bias)
return output.to(x.dtype)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from torch.nn.parameter import Parameter
from vllm.model_executor.layers.quantization.utils.mxfp8_utils import (
MXFP8_BLOCK_SIZE,
mxfp8_e4m3_quantize,
swizzle_mxfp8_scale,
)
from vllm.platforms import current_platform
from vllm.utils import flashinfer as vllm_flashinfer
from .Mxfp8LinearKernel import Mxfp8LinearKernel, Mxfp8LinearLayerConfig
class FlashInferCutlassMxfp8LinearKernel(Mxfp8LinearKernel):
"""MXFP8 W8A8 GEMM via FlashInfer CUTLASS (SM100+)."""
@classmethod
def is_supported(
cls, compute_capability: int | None = None
) -> tuple[bool, str | None]:
if current_platform.has_device_capability(100):
return True, None
return False, "requires >=sm_100 (Blackwell)"
@classmethod
def can_implement(cls, c: Mxfp8LinearLayerConfig) -> tuple[bool, str | None]:
return True, None
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
weight = layer.weight.data # [N, K]
N, K = weight.shape
scale_k = K // MXFP8_BLOCK_SIZE
weight_scale_2d = layer.weight_scale.data[:N, :scale_k].contiguous()
weight_scale_swizzled = swizzle_mxfp8_scale(weight_scale_2d, M=N, K=K)
layer.weight = Parameter(weight.contiguous(), requires_grad=False)
layer.weight_scale = Parameter(
weight_scale_swizzled.contiguous(), requires_grad=False
)
def apply_weights(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
weight = layer.weight
weight_scale = layer.weight_scale
out_dtype = x.dtype
N, K = weight.shape
input_shape = x.shape
input_2d = x.view(-1, K)
M_orig = input_2d.shape[0]
min_dim = 128
assert min_dim <= K, (
f"mm_mxfp8 requires K >= {min_dim}, got K={K}. "
f"in_features is too small for mm_mxfp8."
)
assert K % MXFP8_BLOCK_SIZE == 0, (
f"mm_mxfp8 requires K to be divisible by {MXFP8_BLOCK_SIZE}, got K={K}."
)
assert min_dim <= N, (
f"mm_mxfp8 requires N >= {min_dim}, got N={N}. "
f"out_features is too small for mm_mxfp8."
)
M_padded = ((M_orig + min_dim - 1) // min_dim) * min_dim
if M_padded != M_orig:
pad_rows = M_padded - M_orig
input_2d = torch.nn.functional.pad(input_2d, (0, 0, 0, pad_rows))
input_mxfp8, input_scale = mxfp8_e4m3_quantize(
input_2d, is_sf_swizzled_layout=True
)
if not weight.is_contiguous():
weight = weight.contiguous()
output = vllm_flashinfer.mm_mxfp8(
input_mxfp8,
weight.t(),
input_scale,
weight_scale,
out_dtype=out_dtype,
backend="cutlass",
)
if M_padded != M_orig:
output = output[:M_orig, :]
if bias is not None:
output = output + bias
output_shape = (*input_shape[:-1], N)
return output.view(output_shape)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from .Mxfp8LinearKernel import Mxfp8LinearKernel, Mxfp8LinearLayerConfig
class MarlinMxfp8LinearKernel(Mxfp8LinearKernel):
"""MXFP8 W8A16 GEMM via Marlin (SM80+)."""
@classmethod
def is_supported(
cls, compute_capability: int | None = None
) -> tuple[bool, str | None]:
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
is_fp8_marlin_supported,
)
if is_fp8_marlin_supported():
return True, None
return False, "Marlin FP8 not available"
@classmethod
def can_implement(cls, c: Mxfp8LinearLayerConfig) -> tuple[bool, str | None]:
return True, None
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
prepare_mxfp8_layer_for_marlin,
)
prepare_mxfp8_layer_for_marlin(layer)
def apply_weights(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
apply_mxfp8_marlin_linear,
)
return apply_mxfp8_marlin_linear(
input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
workspace=layer.workspace,
size_n=layer.output_size_per_partition,
size_k=layer.input_size_per_partition,
bias=bias,
)
...@@ -517,10 +517,10 @@ class Fp8OnlineLinearMethod(Fp8LinearMethod): ...@@ -517,10 +517,10 @@ class Fp8OnlineLinearMethod(Fp8LinearMethod):
# TODO: remove this check once the following RFC is resolved. # TODO: remove this check once the following RFC is resolved.
# https://github.com/vllm-project/vllm/issues/33314 # https://github.com/vllm-project/vllm/issues/33314
# This check is required because Mxfp8OnlineLinearMethod inherits from # Subclasses (e.g. Mxfp8OnlineLinearMethod) only need the weight
# Fp8OnlineLinearMethod but only calls super().create_weights(), so we must # registration above and manage their own kernel, so skip fp8_linear
# skip the fp8_linear kernel creation. # kernel creation for them.
if hasattr(self, "mxfp8_linear"): if type(self) is not Fp8OnlineLinearMethod:
return return
self.fp8_linear = init_fp8_linear_kernel( self.fp8_linear = init_fp8_linear_kernel(
......
...@@ -12,6 +12,7 @@ from vllm.config import get_current_vllm_config ...@@ -12,6 +12,7 @@ from vllm.config import get_current_vllm_config
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.kernels.linear import ( from vllm.model_executor.kernels.linear import (
init_fp8_linear_kernel, init_fp8_linear_kernel,
init_mxfp8_linear_kernel,
init_nvfp4_linear_kernel, init_nvfp4_linear_kernel,
) )
from vllm.model_executor.layers.attention import Attention, MLAAttention from vllm.model_executor.layers.attention import Attention, MLAAttention
...@@ -70,7 +71,6 @@ from vllm.model_executor.layers.quantization.utils.mxfp8_utils import ( ...@@ -70,7 +71,6 @@ from vllm.model_executor.layers.quantization.utils.mxfp8_utils import (
MXFP8_BLOCK_SIZE, MXFP8_BLOCK_SIZE,
MXFP8_SCALE_DTYPE, MXFP8_SCALE_DTYPE,
MXFP8_VALUE_DTYPE, MXFP8_VALUE_DTYPE,
Mxfp8LinearOp,
mxfp8_e4m3_quantize, mxfp8_e4m3_quantize,
) )
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
...@@ -1576,7 +1576,7 @@ class ModelOptMxFp8LinearMethod(LinearMethodBase): ...@@ -1576,7 +1576,7 @@ class ModelOptMxFp8LinearMethod(LinearMethodBase):
"Dynamic quantization is not supported." "Dynamic quantization is not supported."
) )
self.mxfp8_linear_op = Mxfp8LinearOp() self.kernel = init_mxfp8_linear_kernel()
def create_weights( def create_weights(
self, self,
...@@ -1658,7 +1658,7 @@ class ModelOptMxFp8LinearMethod(LinearMethodBase): ...@@ -1658,7 +1658,7 @@ class ModelOptMxFp8LinearMethod(LinearMethodBase):
f" got {layer.weight_scale.dtype}" f" got {layer.weight_scale.dtype}"
) )
self.mxfp8_linear_op.process_weights(layer) self.kernel.process_weights_after_loading(layer)
def apply( def apply(
self, self,
...@@ -1666,16 +1666,7 @@ class ModelOptMxFp8LinearMethod(LinearMethodBase): ...@@ -1666,16 +1666,7 @@ class ModelOptMxFp8LinearMethod(LinearMethodBase):
x: torch.Tensor, x: torch.Tensor,
bias: torch.Tensor | None = None, bias: torch.Tensor | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
return self.mxfp8_linear_op.apply( return self.kernel.apply_weights(layer, x, bias)
input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
out_dtype=x.dtype,
bias=bias,
workspace=getattr(layer, "workspace", None),
size_n=layer.output_size_per_partition,
size_k=layer.input_size_per_partition,
)
class ModelOptMxFp8FusedMoE(FusedMoEMethodBase): class ModelOptMxFp8FusedMoE(FusedMoEMethodBase):
......
...@@ -9,6 +9,7 @@ import torch ...@@ -9,6 +9,7 @@ import torch
from torch.nn import Module from torch.nn import Module
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.kernels.linear import init_mxfp8_linear_kernel
from vllm.model_executor.layers.attention import Attention from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.fused_moe import ( from vllm.model_executor.layers.fused_moe import (
FusedMoE, FusedMoE,
...@@ -34,7 +35,6 @@ from vllm.model_executor.layers.quantization.fp8 import ( ...@@ -34,7 +35,6 @@ from vllm.model_executor.layers.quantization.fp8 import (
) )
from vllm.model_executor.layers.quantization.utils.mxfp8_utils import ( from vllm.model_executor.layers.quantization.utils.mxfp8_utils import (
MXFP8_BLOCK_SIZE, MXFP8_BLOCK_SIZE,
Mxfp8LinearOp,
mxfp8_e4m3_quantize, mxfp8_e4m3_quantize,
) )
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
...@@ -126,8 +126,7 @@ class Mxfp8OnlineLinearMethod(Fp8OnlineLinearMethod): ...@@ -126,8 +126,7 @@ class Mxfp8OnlineLinearMethod(Fp8OnlineLinearMethod):
def __init__(self, quant_config: "Mxfp8Config"): def __init__(self, quant_config: "Mxfp8Config"):
self.quant_config = quant_config self.quant_config = quant_config
self.out_dtype = torch.get_default_dtype() self.kernel = init_mxfp8_linear_kernel()
self.mxfp8_linear = Mxfp8LinearOp()
def create_weights( def create_weights(
self, self,
...@@ -166,7 +165,7 @@ class Mxfp8OnlineLinearMethod(Fp8OnlineLinearMethod): ...@@ -166,7 +165,7 @@ class Mxfp8OnlineLinearMethod(Fp8OnlineLinearMethod):
replace_parameter(layer, "weight", weight_fp8.data) replace_parameter(layer, "weight", weight_fp8.data)
replace_parameter(layer, "weight_scale", weight_scale.data) replace_parameter(layer, "weight_scale", weight_scale.data)
self.mxfp8_linear.process_weights(layer) self.kernel.process_weights_after_loading(layer)
layer._already_called_process_weights_after_loading = True layer._already_called_process_weights_after_loading = True
...@@ -176,16 +175,7 @@ class Mxfp8OnlineLinearMethod(Fp8OnlineLinearMethod): ...@@ -176,16 +175,7 @@ class Mxfp8OnlineLinearMethod(Fp8OnlineLinearMethod):
x: torch.Tensor, x: torch.Tensor,
bias: torch.Tensor | None = None, bias: torch.Tensor | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
return self.mxfp8_linear.apply( return self.kernel.apply_weights(layer, x, bias)
input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
out_dtype=self.out_dtype,
bias=bias,
workspace=getattr(layer, "workspace", None),
size_n=layer.output_size_per_partition,
size_k=layer.input_size_per_partition,
)
class Mxfp8OnlineMoEMethod(Fp8OnlineMoEMethod): class Mxfp8OnlineMoEMethod(Fp8OnlineMoEMethod):
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from enum import Enum
import torch import torch
from torch.nn.parameter import Parameter
from vllm.logger import init_logger
from vllm.utils import flashinfer as vllm_flashinfer
from vllm.utils.torch_utils import direct_register_custom_op from vllm.utils.torch_utils import direct_register_custom_op
logger = init_logger(__name__)
class Mxfp8LinearBackend(Enum):
EMULATION = "emulation"
FLASHINFER_CUTLASS = "flashinfer-cutlass"
MARLIN = "marlin"
# MXFP8 constants # MXFP8 constants
MXFP8_VALUE_DTYPE = torch.float8_e4m3fn MXFP8_VALUE_DTYPE = torch.float8_e4m3fn
MXFP8_SCALE_DTYPE = torch.uint8 MXFP8_SCALE_DTYPE = torch.uint8
MXFP8_BLOCK_SIZE = 32 MXFP8_BLOCK_SIZE = 32
def select_mxfp8_linear_backend() -> Mxfp8LinearBackend:
"""Select the best MXFP8 linear backend for the current device.
- SM100+ (Blackwell): FLASHINFER_CUTLASS (native MXFP8 W8A8 GEMM)
- SM80+ (Ampere/Ada): MARLIN (MXFP8 W8A16 GEMM)
- Otherwise: EMULATION (dequant to BF16 fallback)
"""
from vllm.platforms import current_platform
if current_platform.has_device_capability(100):
return Mxfp8LinearBackend.FLASHINFER_CUTLASS
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
is_fp8_marlin_supported,
)
if is_fp8_marlin_supported():
return Mxfp8LinearBackend.MARLIN
return Mxfp8LinearBackend.EMULATION
def swizzle_mxfp8_scale(sf: torch.Tensor, M: int, K: int) -> torch.Tensor: def swizzle_mxfp8_scale(sf: torch.Tensor, M: int, K: int) -> torch.Tensor:
"""Swizzle MXFP8 scales from row-major 2D to F8_128x4 layout.""" """Swizzle MXFP8 scales from row-major 2D to F8_128x4 layout."""
scaling_vector_size = MXFP8_BLOCK_SIZE # 32 for MXFP8 scaling_vector_size = MXFP8_BLOCK_SIZE # 32 for MXFP8
...@@ -209,194 +173,3 @@ def xpu_mxfp8_quantize( ...@@ -209,194 +173,3 @@ def xpu_mxfp8_quantize(
x: torch.Tensor, dtype: torch.dtype | None = None x: torch.Tensor, dtype: torch.dtype | None = None
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
return torch.ops.vllm.xpu_mxfp8_quantize(x, dtype) return torch.ops.vllm.xpu_mxfp8_quantize(x, dtype)
class Mxfp8LinearOp:
def __init__(self):
self.backend = select_mxfp8_linear_backend()
logger.info_once("Using %s backend for MXFP8 GEMM", self.backend)
def process_weights(self, layer: torch.nn.Module) -> None:
"""Process MXFP8 weights after loading into backend-specific format."""
if self.backend == Mxfp8LinearBackend.MARLIN:
self._process_weights_marlin(layer)
elif self.backend == Mxfp8LinearBackend.FLASHINFER_CUTLASS:
self._process_weights_flashinfer_cutlass(layer)
else:
self._process_weights_emulation(layer)
def _process_weights_emulation(self, layer: torch.nn.Module) -> None:
"""Keep scales as 2D uint8 for dequant-to-BF16 emulation."""
weight = layer.weight.data # [N, K]
N, K = weight.shape
scale_k = K // MXFP8_BLOCK_SIZE
weight_scale = layer.weight_scale.data[:N, :scale_k].contiguous()
layer.weight = Parameter(weight.contiguous(), requires_grad=False)
layer.weight_scale = Parameter(weight_scale, requires_grad=False)
def _process_weights_flashinfer_cutlass(self, layer: torch.nn.Module) -> None:
"""Swizzle scales to F8_128x4 layout for flashinfer CUTLASS."""
weight = layer.weight.data # [N, K]
N, K = weight.shape
scale_k = K // MXFP8_BLOCK_SIZE
weight_scale_2d = layer.weight_scale.data[:N, :scale_k].contiguous()
weight_scale_swizzled = swizzle_mxfp8_scale(weight_scale_2d, M=N, K=K)
layer.weight = Parameter(weight.contiguous(), requires_grad=False)
layer.weight_scale = Parameter(
weight_scale_swizzled.contiguous(), requires_grad=False
)
def _process_weights_marlin(self, layer: torch.nn.Module) -> None:
"""Repack MXFP8 weights and scales into Marlin kernel format."""
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
prepare_mxfp8_layer_for_marlin,
)
prepare_mxfp8_layer_for_marlin(layer)
def _apply_emulation(
self,
input: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
out_dtype: torch.dtype,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
if weight_scale.dtype != MXFP8_SCALE_DTYPE:
raise ValueError(
f"TORCH backend requires {MXFP8_SCALE_DTYPE} weight_scale dtype, "
f"got {weight_scale.dtype}."
)
if weight_scale.ndim != 2:
raise ValueError(
f"TORCH backend requires 2D weight_scale, got {weight_scale.ndim}D. "
f"Ensure process_weights_after_loading was called."
)
weight_bf16 = dequant_mxfp8_to_bf16(weight, weight_scale)
output = torch.nn.functional.linear(input, weight_bf16, bias)
return output.to(out_dtype)
def _apply_flashinfer_cutlass(
self,
input: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
out_dtype: torch.dtype,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
N, K = weight.shape
input_shape = input.shape
input_2d = input.view(-1, K)
M_orig = input_2d.shape[0]
# Minimum dimension size for F8_128x4 block scaling layout
min_dim = 128
assert min_dim <= K, (
f"mm_mxfp8 requires K >= {min_dim}, got K={K}. "
f"in_features is too small for mm_mxfp8."
)
assert K % MXFP8_BLOCK_SIZE == 0, (
f"mm_mxfp8 requires K to be divisible by {MXFP8_BLOCK_SIZE}, got K={K}."
)
assert min_dim <= N, (
f"mm_mxfp8 requires N >= {min_dim}, got N={N}. "
f"out_features is too small for mm_mxfp8."
)
M_padded = ((M_orig + min_dim - 1) // min_dim) * min_dim
if M_padded != M_orig:
pad_rows = M_padded - M_orig
input_2d = torch.nn.functional.pad(input_2d, (0, 0, 0, pad_rows))
input_mxfp8, input_scale = mxfp8_e4m3_quantize(
input_2d,
is_sf_swizzled_layout=True, # Swizzled for best accuracy
)
if not weight.is_contiguous():
weight = weight.contiguous()
output = vllm_flashinfer.mm_mxfp8(
input_mxfp8,
weight.t(),
input_scale,
weight_scale,
out_dtype=out_dtype,
backend="cutlass",
)
if M_padded != M_orig:
output = output[:M_orig, :]
if bias is not None:
output = output + bias
output_shape = (*input_shape[:-1], N)
return output.view(output_shape)
def _apply_marlin(
self,
input: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
out_dtype: torch.dtype,
bias: torch.Tensor | None = None,
*,
workspace: torch.Tensor,
size_n: int,
size_k: int,
) -> torch.Tensor:
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
apply_mxfp8_marlin_linear,
)
return apply_mxfp8_marlin_linear(
input=input,
weight=weight,
weight_scale=weight_scale,
workspace=workspace,
size_n=size_n,
size_k=size_k,
bias=bias,
)
def apply(
self,
input: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
out_dtype: torch.dtype,
bias: torch.Tensor | None = None,
*,
workspace: torch.Tensor | None = None,
size_n: int = 0,
size_k: int = 0,
) -> torch.Tensor:
if self.backend == Mxfp8LinearBackend.EMULATION:
return self._apply_emulation(input, weight, weight_scale, out_dtype, bias)
if self.backend == Mxfp8LinearBackend.MARLIN:
assert workspace is not None
return self._apply_marlin(
input,
weight,
weight_scale,
out_dtype,
bias,
workspace=workspace,
size_n=size_n,
size_k=size_k,
)
assert self.backend == Mxfp8LinearBackend.FLASHINFER_CUTLASS
return self._apply_flashinfer_cutlass(
input, weight, weight_scale, out_dtype, bias
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment