Unverified Commit 148117ea authored by vllmellm's avatar vllmellm Committed by GitHub
Browse files

[Refactor] Make FP8 Linear Ops use kernel abstraction (#27814)


Signed-off-by: default avatarvllmellm <vllm.ellm@embeddedllm.com>
parent e9c83cdc
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.platforms import current_platform
from vllm.utils.flashinfer import flashinfer_scaled_fp8_mm, has_flashinfer
from .ScaledMMLinearKernel import (
FP8ScaledMMLinearKernel,
FP8ScaledMMLinearLayerConfig,
)
class FlashInferFP8ScaledMMLinearKernel(FP8ScaledMMLinearKernel):
@classmethod
def is_supported(
cls, compute_capability: int | None = None
) -> tuple[bool, str | None]:
if not current_platform.is_cuda():
return False, "requires CUDA."
if not has_flashinfer():
return False, "requires FlashInfer to be installed."
if compute_capability is not None and compute_capability < 100:
return False, "requires compute capability 100 and above."
return True, None
@classmethod
def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
per_tensor_activation_scales = (
c.activation_quant_key.scale.group_shape.is_per_tensor()
)
per_tensor_weight_scales = c.weight_quant_key.scale.group_shape.is_per_tensor()
if not (per_tensor_activation_scales and per_tensor_weight_scales):
return False, "requires per tensor activation and weight scales."
return True, None
def apply_scaled_mm(
self,
*,
A: torch.Tensor,
B: torch.Tensor,
out_dtype: torch.dtype,
As: torch.Tensor,
Bs: torch.Tensor,
bias: torch.Tensor | None,
output_shape: list,
) -> torch.Tensor:
return flashinfer_scaled_fp8_mm(
A, B, out_dtype=out_dtype, scale_a=As, scale_b=Bs, bias=bias
)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from packaging import version
from vllm.config import CompilationMode, get_current_vllm_config
from vllm.platforms import current_platform
from .ScaledMMLinearKernel import (
FP8ScaledMMLinearKernel,
FP8ScaledMMLinearLayerConfig,
)
class TorchFP8ScaledMMLinearKernel(FP8ScaledMMLinearKernel):
"""
Base class for FP8 linear kernels using Torch.
Each subclass represents a kernel variant for
specific device capabilities and torch versions.
"""
@classmethod
def is_supported(
cls, compute_capability: int | None = None
) -> tuple[bool, str | None]:
if not (current_platform.is_cuda_alike() or current_platform.is_cpu()):
return False, "requires ROCm, CUDA or CPU."
if compute_capability is not None and compute_capability < 89:
return False, "requires compute capability 89 and above."
return True, None
def get_output_padding(self) -> int | None:
# Note: we pad the input because torch._scaled_mm is more performant
# for matrices with batch dimension > 16.
# This could change in the future.
# We also don't pad when using torch.compile,
# as it breaks with dynamic shapes.
#
# The perf gain is still relevant as of 16/1/2026
# torch version == 2.9.0. More details in the link below:
# https://github.com/vllm-project/vllm/issues/32269
vllm_config = get_current_vllm_config().compilation_config
pad_output = vllm_config.mode < CompilationMode.VLLM_COMPILE
return 17 if pad_output else None
class PerTensorTorchFP8ScaledMMLinearKernel(TorchFP8ScaledMMLinearKernel):
@classmethod
def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
per_tensor_activation_scales = (
c.activation_quant_key.scale.group_shape.is_per_tensor()
)
per_tensor_weight_scales = c.weight_quant_key.scale.group_shape.is_per_tensor()
if not (per_tensor_activation_scales and per_tensor_weight_scales):
return False, "requires per tensor activation and weight scales."
return True, None
def apply_scaled_mm(
self,
*,
A: torch.Tensor,
B: torch.Tensor,
out_dtype: torch.dtype,
As: torch.Tensor,
Bs: torch.Tensor,
bias: torch.Tensor | None,
output_shape: list,
) -> torch.Tensor:
output = torch._scaled_mm(
A, B, out_dtype=out_dtype, scale_a=As, scale_b=Bs, bias=bias
)
# A fix for discrepancy in scaled_mm which returns tuple
# for torch < 2.5 and a single value in torch >= 2.5
if type(output) is tuple and len(output) == 2:
output = output[0]
return torch.narrow(output, 0, 0, output_shape[0]).view(*output_shape)
class RowWiseTorchFP8ScaledMMLinearKernel(TorchFP8ScaledMMLinearKernel):
@classmethod
def is_supported(
cls, compute_capability: int | None = None
) -> tuple[bool, str | None]:
if not current_platform.is_rocm():
return False, "requires ROCm."
from vllm.platforms.rocm import on_mi3xx
if not on_mi3xx():
return False, "requires MI3xx."
if compute_capability is not None and compute_capability < 94:
return False, "requires compute capability 94 and above."
if not version.parse(torch.__version__) >= version.parse("2.7"):
return False, "requires pytorch version >=2.7."
return True, None
@classmethod
def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
per_tensor_activation_scales = (
c.activation_quant_key.scale.group_shape.is_per_tensor()
)
per_tensor_weight_scales = c.weight_quant_key.scale.group_shape.is_per_tensor()
if c.out_dtype == torch.float16:
# hipblaslt rowwise _scaled_mm only supports BFloat16
return False, "supports BFloat16 output data type only."
if per_tensor_activation_scales or per_tensor_weight_scales:
return False, "cannot be used with per tensor activation and weight scales."
return True, None
def apply_scaled_mm(
self,
*,
A: torch.Tensor,
B: torch.Tensor,
out_dtype: torch.dtype,
As: torch.Tensor,
Bs: torch.Tensor,
bias: torch.Tensor | None,
output_shape: list,
) -> torch.Tensor:
# Note:
# For now it has only been validated on ROCm platform.
# fp8 rowwise scaling in torch._scaled_mm is introduced in
# https://github.com/pytorch/pytorch/pull/144432 using
# hipBLASLt and ROCm 6.3, which only exists in torch 2.7 and above.
#
# For CUDA platform please validate if the torch._scaled_mm supports
# rowwise scaled GEMM before using it
# Fused GEMM_DQ Rowwise GEMM
output = torch._scaled_mm(
A,
B,
out_dtype=out_dtype,
scale_a=As,
scale_b=Bs.t(),
bias=bias,
)
return torch.narrow(output, 0, 0, output_shape[0]).view(*output_shape)
class ChannelWiseTorchFP8ScaledMMLinearKernel(TorchFP8ScaledMMLinearKernel):
@classmethod
def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
per_tensor_activation_scales = (
c.activation_quant_key.scale.group_shape.is_per_tensor()
)
per_tensor_weight_scales = c.weight_quant_key.scale.group_shape.is_per_tensor()
if per_tensor_activation_scales and per_tensor_weight_scales:
return False, "cannot be used with per tensor activation and weight scales."
return True, None
def apply_scaled_mm(
self,
*,
A: torch.Tensor,
B: torch.Tensor,
out_dtype: torch.dtype,
As: torch.Tensor,
Bs: torch.Tensor,
bias: torch.Tensor | None,
output_shape: list,
) -> torch.Tensor:
# Use unfused DQ due to limitations with scaled_mm
# Symmetric quantized GEMM by definition computes the following:
# C = (s_x * X) (s_w * W) + bias
# This is equivalent to dequantizing the weights and activations
# before applying a GEMM.
#
# In order to compute quantized operands, a quantized kernel
# will rewrite the above like so:
# C = s_w * s_x * (X * W) + bias
#
# For the scaled_mm fallback case, we break this down, since it
# does not support s_w being a vector.
# Input scaling factors are no longer optional in _scaled_mm starting
# from pytorch 2.5. Allocating a dummy tensor to pass as scales
dummy_tensor = torch.ones(1, dtype=torch.float32, device=A.device)
# GEMM
# This computes C = (X * W).
# Output in fp32 to allow subsequent ops to happen in-place
output = torch._scaled_mm(
A,
B,
scale_a=dummy_tensor,
scale_b=dummy_tensor,
out_dtype=torch.float32,
)
# A fix for discrepancy in scaled_mm which returns tuple
# for torch < 2.5 and a single value in torch >= 2.5
if type(output) is tuple and len(output) == 2:
output = output[0]
# Unpad (undo num_token_padding)
output = torch.narrow(output, 0, 0, output_shape[0])
x_scale = torch.narrow(As, 0, 0, output_shape[0])
# DQ
# C = sw * sx * (X * W) + bias
output = output * x_scale * Bs.t()
if bias is not None:
output = output + bias
return output.to(out_dtype).view(*output_shape)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.platforms import current_platform
from vllm.utils.platform_utils import get_cu_count
from vllm.utils.torch_utils import direct_register_custom_op
from .ScaledMMLinearKernel import (
FP8ScaledMMLinearKernel,
FP8ScaledMMLinearLayerConfig,
)
def rocm_per_tensor_float_w8a8_scaled_mm_impl(
A: torch.Tensor,
B: torch.Tensor,
out_dtype: torch.dtype,
As: torch.Tensor,
Bs: torch.Tensor,
bias: torch.Tensor,
) -> torch.Tensor:
if (
A.shape[0] == 1
and B.shape[1] % 16 == 0
and ((bias is None) or (bias.dtype == out_dtype))
):
output = ops.wvSplitKQ(
B.t(),
A,
out_dtype,
As,
Bs,
get_cu_count(),
bias,
)
# Fallback
else:
output = torch._scaled_mm(
A,
B,
out_dtype=out_dtype,
scale_a=As,
scale_b=Bs,
bias=bias,
)
return output
def rocm_per_tensor_float_w8a8_scaled_mm_fake(
A: torch.Tensor,
B: torch.Tensor,
out_dtype: torch.dtype,
As: torch.Tensor,
Bs: torch.Tensor,
bias: torch.Tensor,
) -> torch.Tensor:
return A.new_empty((*A.shape[:-1], B.shape[1]), dtype=out_dtype)
if current_platform.is_rocm():
direct_register_custom_op(
op_name="rocm_per_tensor_float_w8a8_scaled_mm_impl",
op_func=rocm_per_tensor_float_w8a8_scaled_mm_impl,
fake_impl=rocm_per_tensor_float_w8a8_scaled_mm_fake,
)
class ROCmFP8ScaledMMLinearKernel(FP8ScaledMMLinearKernel):
@classmethod
def is_supported(
cls, compute_capability: int | None = None
) -> tuple[bool, str | None]:
if not current_platform.is_rocm():
return False, "requires ROCm."
from vllm.platforms.rocm import on_mi3xx
if not on_mi3xx():
return False, "requires MI3xx."
if not envs.VLLM_ROCM_USE_SKINNY_GEMM:
return False, "requires VLLM_ROCM_USE_SKINNY_GEMM to be enabled."
return True, None
@classmethod
def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
per_tensor_activation_scales = (
c.activation_quant_key.scale.group_shape.is_per_tensor()
)
per_tensor_weight_scales = c.weight_quant_key.scale.group_shape.is_per_tensor()
if not (per_tensor_activation_scales and per_tensor_weight_scales):
return False, "requires per tensor activation and weight scales."
return True, None
def apply_scaled_mm(
self,
*,
A: torch.Tensor,
B: torch.Tensor,
out_dtype: torch.dtype,
As: torch.Tensor,
Bs: torch.Tensor,
bias: torch.Tensor | None,
output_shape: list,
) -> torch.Tensor:
output = torch.ops.vllm.rocm_per_tensor_float_w8a8_scaled_mm_impl(
A, B, out_dtype, As, Bs, bias
)
return torch.narrow(output, 0, 0, A.shape[0]).view(*output_shape)
...@@ -14,30 +14,35 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( ...@@ -14,30 +14,35 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
from .ScaledMMLinearKernel import ScaledMMLinearKernel, ScaledMMLinearLayerConfig from .cutlass import CutlassInt8ScaledMMLinearKernel
from .ScaledMMLinearKernel import (
Int8ScaledMMLinearLayerConfig,
)
class TritonScaledMMLinearKernel(ScaledMMLinearKernel): class TritonInt8ScaledMMLinearKernel(CutlassInt8ScaledMMLinearKernel):
@classmethod @classmethod
def is_supported( def is_supported(
cls, compute_capability: int | None = None cls, compute_capability: int | None = None
) -> tuple[bool, str | None]: ) -> tuple[bool, str | None]:
if current_platform.is_cuda_alike(): if current_platform.is_cuda_alike():
return True, None return True, None
return False, "Requires ROCm or CUDA." return False, "requires ROCm or CUDA."
@classmethod @classmethod
def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: def can_implement(cls, c: Int8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
if not c.input_symmetric: if not c.input_symmetric:
return False, "Only symmetric input is supported." return False, "supports symmetric input only."
return True, None return True, None
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
weight = getattr(layer, self.w_q_name) w_q, _, i_s, _, _ = self._get_layer_params(layer)
w_q_name, w_s_name, i_s_name, i_zp_name, azp_adj_name = self.layer_param_names
replace_parameter( replace_parameter(
layer, layer,
self.w_q_name, w_q_name,
torch.nn.Parameter(weight.t().data, requires_grad=False), torch.nn.Parameter(w_q.t().data, requires_grad=False),
) )
# WEIGHT SCALE # WEIGHT SCALE
...@@ -45,29 +50,29 @@ class TritonScaledMMLinearKernel(ScaledMMLinearKernel): ...@@ -45,29 +50,29 @@ class TritonScaledMMLinearKernel(ScaledMMLinearKernel):
# If we have a fused module (QKV, MLP) with per tensor scales (thus N # If we have a fused module (QKV, MLP) with per tensor scales (thus N
# scales being passed to the kernel), convert to the per-channel case. # scales being passed to the kernel), convert to the per-channel case.
is_fused_module = len(layer.logical_widths) > 1 is_fused_module = len(layer.logical_widths) > 1
weight_scale = getattr(layer, self.w_s_name) weight_scale = getattr(layer, w_s_name)
if is_fused_module and not self.config.is_channelwise: if is_fused_module and not self.config.is_channelwise:
weight_scale = convert_to_channelwise(weight_scale, layer.logical_widths) weight_scale = convert_to_channelwise(weight_scale, layer.logical_widths)
replace_parameter( replace_parameter(
layer, layer,
self.w_s_name, w_s_name,
torch.nn.Parameter(weight_scale.data, requires_grad=False), torch.nn.Parameter(weight_scale.data, requires_grad=False),
) )
# INPUT SCALE # INPUT SCALE
if self.config.is_static_input_scheme: if self.config.is_static_input_scheme:
input_scale = getattr(layer, self.i_s_name) assert i_s is not None
replace_parameter( replace_parameter(
layer, layer,
self.i_s_name, i_s_name,
torch.nn.Parameter(input_scale.max(), requires_grad=False), torch.nn.Parameter(i_s.max(), requires_grad=False),
) )
setattr(layer, self.i_zp_name, None) setattr(layer, i_zp_name, None)
else: else:
setattr(layer, self.i_s_name, None) setattr(layer, i_s_name, None)
setattr(layer, self.i_zp_name, None) setattr(layer, i_zp_name, None)
setattr(layer, self.azp_adj_name, None) setattr(layer, azp_adj_name, None)
def apply_weights( def apply_weights(
self, self,
...@@ -75,7 +80,7 @@ class TritonScaledMMLinearKernel(ScaledMMLinearKernel): ...@@ -75,7 +80,7 @@ class TritonScaledMMLinearKernel(ScaledMMLinearKernel):
x: torch.Tensor, x: torch.Tensor,
bias: torch.Tensor | None = None, bias: torch.Tensor | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
w_q, w_s, i_s, i_zp, azp_adj = self._get_weight_params(layer) w_q, w_s, i_s, i_zp, _ = self._get_layer_params(layer)
x_q, x_s, x_zp = ops.scaled_int8_quant( x_q, x_s, x_zp = ops.scaled_int8_quant(
x.contiguous(), i_s, i_zp, symmetric=True x.contiguous(), i_s, i_zp, symmetric=True
......
...@@ -49,6 +49,9 @@ from vllm.model_executor.layers.quantization.base_config import ( ...@@ -49,6 +49,9 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizationConfig,
QuantizeMethodBase, QuantizeMethodBase,
) )
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
init_fp8_linear_kernel,
)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import (
build_flashinfer_fp4_cutlass_moe_prepare_finalize, build_flashinfer_fp4_cutlass_moe_prepare_finalize,
...@@ -78,10 +81,12 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( ...@@ -78,10 +81,12 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape, GroupShape,
cutlass_fp4_supported, cutlass_fp4_supported,
is_layer_skipped, is_layer_skipped,
kFp8DynamicTokenSym,
kFp8StaticTensorSym,
kFp8StaticTokenSym,
swizzle_blockscale, swizzle_blockscale,
) )
from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
Fp8LinearOp,
cutlass_block_fp8_supported, cutlass_block_fp8_supported,
requantize_with_max_scale, requantize_with_max_scale,
) )
...@@ -438,8 +443,11 @@ class ModelOptFp8LinearMethod(LinearMethodBase): ...@@ -438,8 +443,11 @@ class ModelOptFp8LinearMethod(LinearMethodBase):
def __init__(self, quant_config: ModelOptFp8Config) -> None: def __init__(self, quant_config: ModelOptFp8Config) -> None:
self.quant_config = quant_config self.quant_config = quant_config
self.fp8_linear = Fp8LinearOp( self.fp8_linear = init_fp8_linear_kernel(
act_quant_static=True, act_quant_group_shape=GroupShape.PER_TENSOR activation_quant_key=kFp8StaticTensorSym,
weight_quant_key=kFp8StaticTensorSym,
out_dtype=torch.get_default_dtype(),
module_name=self.__class__.__name__,
) )
def create_weights( def create_weights(
...@@ -507,13 +515,7 @@ class ModelOptFp8LinearMethod(LinearMethodBase): ...@@ -507,13 +515,7 @@ class ModelOptFp8LinearMethod(LinearMethodBase):
x: torch.Tensor, x: torch.Tensor,
bias: torch.Tensor | None = None, bias: torch.Tensor | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
return self.fp8_linear.apply( return self.fp8_linear.apply_weights(layer, x, bias)
input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
input_scale=layer.input_scale,
bias=bias,
)
class ModelOptFp8PcPtLinearMethod(LinearMethodBase): class ModelOptFp8PcPtLinearMethod(LinearMethodBase):
...@@ -527,8 +529,11 @@ class ModelOptFp8PcPtLinearMethod(LinearMethodBase): ...@@ -527,8 +529,11 @@ class ModelOptFp8PcPtLinearMethod(LinearMethodBase):
def __init__(self, quant_config: ModelOptFp8Config) -> None: def __init__(self, quant_config: ModelOptFp8Config) -> None:
self.quant_config = quant_config self.quant_config = quant_config
self.fp8_linear = Fp8LinearOp( self.fp8_linear = init_fp8_linear_kernel(
act_quant_static=False, act_quant_group_shape=GroupShape.PER_TOKEN activation_quant_key=kFp8DynamicTokenSym,
weight_quant_key=kFp8StaticTokenSym,
out_dtype=torch.get_default_dtype(),
module_name=self.__class__.__name__,
) )
def create_weights( def create_weights(
...@@ -585,13 +590,7 @@ class ModelOptFp8PcPtLinearMethod(LinearMethodBase): ...@@ -585,13 +590,7 @@ class ModelOptFp8PcPtLinearMethod(LinearMethodBase):
x: torch.Tensor, x: torch.Tensor,
bias: torch.Tensor | None = None, bias: torch.Tensor | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
return self.fp8_linear.apply( return self.fp8_linear.apply_weights(layer, x, bias)
input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
input_scale=None,
bias=bias,
)
class ModelOptFp8PbWoLinearMethod(LinearMethodBase): class ModelOptFp8PbWoLinearMethod(LinearMethodBase):
......
...@@ -17,11 +17,13 @@ from vllm.model_executor.layers.quantization.fp8 import ( ...@@ -17,11 +17,13 @@ from vllm.model_executor.layers.quantization.fp8 import (
Fp8KVCacheMethod, Fp8KVCacheMethod,
Fp8LinearMethod, Fp8LinearMethod,
) )
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
init_fp8_linear_kernel,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
is_layer_skipped, is_layer_skipped,
kFp8DynamicTokenSym,
) )
from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp
from vllm.platforms import current_platform from vllm.platforms import current_platform
ACTIVATION_SCHEMES = ["static", "dynamic"] ACTIVATION_SCHEMES = ["static", "dynamic"]
...@@ -97,9 +99,11 @@ class PTPCFp8LinearMethod(Fp8LinearMethod): ...@@ -97,9 +99,11 @@ class PTPCFp8LinearMethod(Fp8LinearMethod):
) )
super().__init__(quant_config=quant_config) super().__init__(quant_config=quant_config)
# Force weight quantization # Force weight quantization
self.quant_config.is_checkpoint_fp8_serialized = False self.fp8_linear = init_fp8_linear_kernel(
self.fp8_linear = Fp8LinearOp( activation_quant_key=kFp8DynamicTokenSym,
act_quant_static=False, act_quant_group_shape=GroupShape.PER_TOKEN weight_quant_key=kFp8DynamicTokenSym,
out_dtype=torch.get_default_dtype(),
module_name=self.__class__.__name__,
) )
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
...@@ -130,11 +134,4 @@ class PTPCFp8LinearMethod(Fp8LinearMethod): ...@@ -130,11 +134,4 @@ class PTPCFp8LinearMethod(Fp8LinearMethod):
x: torch.Tensor, x: torch.Tensor,
bias: torch.Tensor | None = None, bias: torch.Tensor | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
return self.fp8_linear.apply( return self.fp8_linear.apply_weights(layer, x, bias)
input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
input_scale=None,
input_scale_ub=None,
bias=bias,
)
...@@ -7,10 +7,18 @@ from typing import Any, cast ...@@ -7,10 +7,18 @@ from typing import Any, cast
import torch import torch
from torch.nn import Parameter from torch.nn import Parameter
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
init_fp8_linear_kernel,
)
from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
kFp8DynamicTokenSym,
kFp8StaticTensorSym,
kFp8StaticTokenSym,
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
Fp8LinearOp,
normalize_e4m3fn_to_e4m3fnuz, normalize_e4m3fn_to_e4m3fnuz,
requantize_with_max_scale, requantize_with_max_scale,
) )
...@@ -23,6 +31,8 @@ from vllm.platforms import current_platform ...@@ -23,6 +31,8 @@ from vllm.platforms import current_platform
__all__ = ["QuarkW8A8Fp8"] __all__ = ["QuarkW8A8Fp8"]
logger = init_logger(__name__)
class QuarkW8A8Fp8(QuarkScheme): class QuarkW8A8Fp8(QuarkScheme):
def __init__( def __init__(
...@@ -35,15 +45,16 @@ class QuarkW8A8Fp8(QuarkScheme): ...@@ -35,15 +45,16 @@ class QuarkW8A8Fp8(QuarkScheme):
self.is_static_input_scheme = not cast(bool, input_config.get("is_dynamic")) self.is_static_input_scheme = not cast(bool, input_config.get("is_dynamic"))
self.input_qscheme = cast(str, input_config.get("qscheme")) self.input_qscheme = cast(str, input_config.get("qscheme"))
per_token = ( per_token_activation = (
not self.is_static_input_scheme and self.input_qscheme == "per_channel" not self.is_static_input_scheme and self.input_qscheme == "per_channel"
) )
self.act_quant_group_shape = ( per_token_weight = self.weight_qscheme == "per_channel"
GroupShape.PER_TOKEN if per_token else GroupShape.PER_TENSOR
self.activation_quant_key = (
kFp8DynamicTokenSym if per_token_activation else kFp8StaticTensorSym
) )
self.fp8_linear = Fp8LinearOp( self.weight_quant_key = (
act_quant_static=self.is_static_input_scheme, kFp8StaticTokenSym if per_token_weight else kFp8StaticTensorSym
act_quant_group_shape=self.act_quant_group_shape,
) )
self.out_dtype = torch.get_default_dtype() self.out_dtype = torch.get_default_dtype()
...@@ -94,7 +105,7 @@ class QuarkW8A8Fp8(QuarkScheme): ...@@ -94,7 +105,7 @@ class QuarkW8A8Fp8(QuarkScheme):
layer.input_scale = Parameter(input_scale, requires_grad=False) layer.input_scale = Parameter(input_scale, requires_grad=False)
else: else:
weight_scale = layer.weight_scale.data weight_scale = layer.weight_scale.data
if self.act_quant_group_shape == GroupShape.PER_TOKEN: if self.activation_quant_key.scale.group_shape == GroupShape.PER_TOKEN:
weight_scale = weight_scale.view(-1, 1) weight_scale = weight_scale.view(-1, 1)
layer.weight = Parameter(weight.t(), requires_grad=False) layer.weight = Parameter(weight.t(), requires_grad=False)
# required by torch.compile to be torch.nn.Parameter # required by torch.compile to be torch.nn.Parameter
...@@ -106,8 +117,6 @@ class QuarkW8A8Fp8(QuarkScheme): ...@@ -106,8 +117,6 @@ class QuarkW8A8Fp8(QuarkScheme):
# INPUT SCALE # INPUT SCALE
if self.is_static_input_scheme: if self.is_static_input_scheme:
layer.input_scale = Parameter(layer.input_scale.max(), requires_grad=False) layer.input_scale = Parameter(layer.input_scale.max(), requires_grad=False)
else:
layer.input_scale = None
def create_weights( def create_weights(
self, self,
...@@ -163,17 +172,17 @@ class QuarkW8A8Fp8(QuarkScheme): ...@@ -163,17 +172,17 @@ class QuarkW8A8Fp8(QuarkScheme):
input_scale[:] = torch.finfo(torch.float32).min input_scale[:] = torch.finfo(torch.float32).min
layer.register_parameter("input_scale", input_scale) layer.register_parameter("input_scale", input_scale)
self.fp8_linear = init_fp8_linear_kernel(
activation_quant_key=self.activation_quant_key,
weight_quant_key=self.weight_quant_key,
out_dtype=torch.get_default_dtype(),
module_name=self.__class__.__name__,
)
def apply_weights( def apply_weights(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
x: torch.Tensor, x: torch.Tensor,
bias: torch.Tensor | None = None, bias: torch.Tensor | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
return self.fp8_linear.apply( return self.fp8_linear.apply_weights(layer, x, bias)
input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
out_dtype=self.out_dtype,
input_scale=layer.input_scale,
bias=bias,
)
...@@ -7,8 +7,7 @@ import torch ...@@ -7,8 +7,7 @@ import torch
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
ScaledMMLinearLayerConfig, init_int8_linear_kernel,
choose_scaled_mm_linear_kernel,
) )
from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme
from vllm.model_executor.parameter import ( from vllm.model_executor.parameter import (
...@@ -22,8 +21,6 @@ logger = init_logger(__name__) ...@@ -22,8 +21,6 @@ logger = init_logger(__name__)
class QuarkW8A8Int8(QuarkScheme): class QuarkW8A8Int8(QuarkScheme):
_kernel_backends_being_used: set[str] = set()
def __init__( def __init__(
self, self,
qscheme: str, qscheme: str,
...@@ -50,18 +47,13 @@ class QuarkW8A8Int8(QuarkScheme): ...@@ -50,18 +47,13 @@ class QuarkW8A8Int8(QuarkScheme):
): ):
layer.logical_widths = output_partition_sizes layer.logical_widths = output_partition_sizes
scaled_mm_linear_kernel_config = ScaledMMLinearLayerConfig( self.kernel = init_int8_linear_kernel(
is_channelwise=(self.qscheme == "per_channel"), is_channelwise=(self.qscheme == "per_channel"),
is_static_input_scheme=(self.is_static_input_scheme is True), is_static_input_scheme=(self.is_static_input_scheme is True),
input_symmetric=(self.input_symmetric is True), input_symmetric=(self.input_symmetric is True),
module_name=self.__class__.__name__,
) )
kernel_type = choose_scaled_mm_linear_kernel(scaled_mm_linear_kernel_config)
if kernel_type.__name__ not in self._kernel_backends_being_used:
logger.info("Using %s for QuarkW8A8Int8", kernel_type.__name__)
self._kernel_backends_being_used.add(kernel_type.__name__)
# WEIGHT # WEIGHT
weight = ModelWeightParameter( weight = ModelWeightParameter(
data=torch.empty( data=torch.empty(
...@@ -102,25 +94,21 @@ class QuarkW8A8Int8(QuarkScheme): ...@@ -102,25 +94,21 @@ class QuarkW8A8Int8(QuarkScheme):
layer.register_parameter("weight_zero_point", weight_zero_point) layer.register_parameter("weight_zero_point", weight_zero_point)
# INPUT SCALE # INPUT SCALE
input_zero_point = None
input_scale = None
if self.is_static_input_scheme: if self.is_static_input_scheme:
input_scale = BasevLLMParameter( input_scale = BasevLLMParameter(
data=torch.empty(1, dtype=torch.float32), weight_loader=weight_loader data=torch.empty(1, dtype=torch.float32), weight_loader=weight_loader
) )
layer.register_parameter("input_scale", input_scale)
input_zero_point = BasevLLMParameter( input_zero_point = BasevLLMParameter(
data=torch.empty(1, dtype=torch.int8), weight_loader=weight_loader data=torch.empty(1, dtype=torch.int8), weight_loader=weight_loader
) )
layer.register_parameter("input_zero_point", input_zero_point)
self.kernel = kernel_type( layer.register_parameter("input_scale", input_scale)
c=scaled_mm_linear_kernel_config, layer.register_parameter("input_zero_point", input_zero_point)
w_q_param_name="weight", if not hasattr(layer, "azp_adj"):
w_s_param_name="weight_scale", layer.register_parameter("azp_adj", None)
i_s_param_name="input_scale",
i_zp_param_name="input_zero_point",
azp_adj_param_name="azp_adj",
)
# Checkpoints are serialized in quark format, which is # Checkpoints are serialized in quark format, which is
# different from the format the kernel may want. Handle repacking here. # different from the format the kernel may want. Handle repacking here.
......
...@@ -123,6 +123,9 @@ kFp8StaticTensorSym = QuantKey(FP8_DTYPE, kStaticTensorScale, symmetric=True) ...@@ -123,6 +123,9 @@ kFp8StaticTensorSym = QuantKey(FP8_DTYPE, kStaticTensorScale, symmetric=True)
kDynamicTensorScale = ScaleDesc(torch.float32, False, GroupShape.PER_TENSOR) kDynamicTensorScale = ScaleDesc(torch.float32, False, GroupShape.PER_TENSOR)
kFp8DynamicTensorSym = QuantKey(FP8_DTYPE, kDynamicTensorScale, symmetric=True) kFp8DynamicTensorSym = QuantKey(FP8_DTYPE, kDynamicTensorScale, symmetric=True)
kStaticTokenScale = ScaleDesc(torch.float32, True, GroupShape.PER_TOKEN)
kFp8StaticTokenSym = QuantKey(FP8_DTYPE, kStaticTokenScale, symmetric=True)
kDynamicTokenScale = ScaleDesc(torch.float32, False, GroupShape.PER_TOKEN) kDynamicTokenScale = ScaleDesc(torch.float32, False, GroupShape.PER_TOKEN)
kFp8DynamicTokenSym = QuantKey(FP8_DTYPE, kDynamicTokenScale, symmetric=True) kFp8DynamicTokenSym = QuantKey(FP8_DTYPE, kDynamicTokenScale, symmetric=True)
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
import torch import torch
from packaging import version
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm import envs
from vllm.config import CompilationMode, get_current_vllm_config
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.flashinfer import flashinfer_scaled_fp8_mm, has_flashinfer
from vllm.utils.platform_utils import get_cu_count
from vllm.utils.torch_utils import direct_register_custom_op
# Input scaling factors are no longer optional in _scaled_mm starting
# from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
TORCH_DEVICE_IDENTITY = None
# The condition to determine if it is on a platform that supports
# torch._scaled_mm rowwise feature.
# The condition is determined once as the operations
# are time-consuming.
USE_ROWWISE_TORCH_SCALED_MM = (
current_platform.is_rocm()
and version.parse(torch.__version__) >= version.parse("2.7")
and current_platform.has_device_capability(94)
)
def sparse_cutlass_supported() -> bool: def sparse_cutlass_supported() -> bool:
...@@ -140,361 +117,6 @@ def requantize_with_max_scale( ...@@ -140,361 +117,6 @@ def requantize_with_max_scale(
return max_w_scale, weight return max_w_scale, weight
def maybe_create_device_identity():
# Allocate dummy ones tensor for torch._scaled_mm
global TORCH_DEVICE_IDENTITY
if TORCH_DEVICE_IDENTITY is None:
TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32)
def cutlass_w8a8_scaled_mm(
*,
qinput: torch.Tensor,
weight: torch.Tensor,
out_dtype: torch.dtype,
scale_a: torch.Tensor,
scale_b: torch.Tensor,
bias: torch.Tensor,
output_shape: list,
**kwargs,
) -> torch.Tensor:
# Fused GEMM_DQ
output = ops.cutlass_scaled_mm(
qinput, weight, out_dtype=out_dtype, scale_a=scale_a, scale_b=scale_b, bias=bias
)
return output.view(*output_shape)
def flashinfer_w8a8_scaled_mm(
*,
qinput: torch.Tensor,
weight: torch.Tensor,
out_dtype: torch.dtype,
scale_a: torch.Tensor,
scale_b: torch.Tensor,
bias: torch.Tensor,
output_shape: list,
**kwargs,
) -> torch.Tensor:
return flashinfer_scaled_fp8_mm(
qinput, weight, out_dtype=out_dtype, scale_a=scale_a, scale_b=scale_b, bias=bias
)
def rocm_per_tensor_w8a8_scaled_mm_impl(
qinput: torch.Tensor,
weight: torch.Tensor,
out_dtype: torch.dtype,
scale_a: torch.Tensor,
scale_b: torch.Tensor,
bias: torch.Tensor,
) -> torch.Tensor:
from vllm.platforms.rocm import on_mi3xx
if (
envs.VLLM_ROCM_USE_SKINNY_GEMM
and on_mi3xx()
and qinput.shape[0] == 1
and qinput.shape[1] % 16 == 0
and ((bias is None) or (bias.dtype == out_dtype))
):
output = ops.wvSplitKQ(
weight.t(),
qinput,
out_dtype,
scale_a,
scale_b,
get_cu_count(),
bias,
)
else:
output = torch._scaled_mm(
qinput,
weight,
out_dtype=out_dtype,
scale_a=scale_a,
scale_b=scale_b,
bias=bias,
)
return output
def rocm_per_tensor_w8a8_scaled_mm_fake(
qinput: torch.Tensor,
weight: torch.Tensor,
out_dtype: torch.dtype,
scale_a: torch.Tensor,
scale_b: torch.Tensor,
bias: torch.Tensor,
) -> torch.Tensor:
return qinput.new_empty((*qinput.shape[:-1], weight.shape[1]), dtype=out_dtype)
def rocm_per_tensor_w8a8_scaled_mm(
*,
qinput: torch.Tensor,
weight: torch.Tensor,
out_dtype: torch.dtype,
scale_a: torch.Tensor,
scale_b: torch.Tensor,
bias: torch.Tensor,
output_shape: list,
) -> torch.Tensor:
output = torch.ops.vllm.rocm_per_tensor_w8a8_scaled_mm_impl(
qinput, weight, out_dtype, scale_a, scale_b, bias
)
return torch.narrow(output, 0, 0, qinput.shape[0]).view(*output_shape)
direct_register_custom_op(
op_name="rocm_per_tensor_w8a8_scaled_mm_impl",
op_func=rocm_per_tensor_w8a8_scaled_mm_impl,
fake_impl=rocm_per_tensor_w8a8_scaled_mm_fake,
)
def torch_per_tensor_w8a8_scaled_mm(
*,
qinput: torch.Tensor,
weight: torch.Tensor,
out_dtype: torch.dtype,
scale_a: torch.Tensor,
scale_b: torch.Tensor,
bias: torch.Tensor,
output_shape: list,
) -> torch.Tensor:
output = torch._scaled_mm(
qinput, weight, out_dtype=out_dtype, scale_a=scale_a, scale_b=scale_b, bias=bias
)
# A fix for discrepancy in scaled_mm which returns tuple
# for torch < 2.5 and a single value in torch >= 2.5
if type(output) is tuple and len(output) == 2:
output = output[0]
return torch.narrow(output, 0, 0, qinput.shape[0]).view(*output_shape)
def torch_per_token_w8a8_scaled_mm(
*,
qinput: torch.Tensor,
weight: torch.Tensor,
out_dtype: torch.dtype,
scale_a: torch.Tensor,
scale_b: torch.Tensor,
bias: torch.Tensor,
output_shape: list,
**kwargs,
) -> torch.Tensor:
# Note: Callers of this function should check USE_ROWWISE_TORCH_SCALED_MM
# when using it.
# For now it has only been validated on ROCm platform.
# fp8 rowwise scaling in torch._scaled_mm is introduced in
# https://github.com/pytorch/pytorch/pull/144432 using
# hipBLASLt and ROCm 6.3, which only exists in torch 2.7 and above.
#
# For CUDA platform please validate if the torch._scaled_mm supports
# rowwise scaled GEMM before using it
# Fused GEMM_DQ Rowwise GEMM
output = torch._scaled_mm(
qinput,
weight,
out_dtype=out_dtype,
scale_a=scale_a,
scale_b=scale_b.t(),
bias=bias,
)
output = torch.narrow(output, 0, 0, qinput.shape[0])
output = output.view(*output_shape)
return output
def torch_channelwise_w8a8_scaled_mm(
*,
qinput: torch.Tensor,
weight: torch.Tensor,
out_dtype: torch.dtype,
scale_a: torch.Tensor,
scale_b: torch.Tensor,
bias: torch.Tensor,
output_shape: list,
**kwargs,
) -> torch.Tensor:
# Use unfused DQ due to limitations with scaled_mm
# Symmetric quantized GEMM by definition computes the following:
# C = (s_x * X) (s_w * W) + bias
# This is equivalent to dequantizing the weights and activations
# before applying a GEMM.
#
# In order to compute quantized operands, a quantized kernel
# will rewrite the above like so:
# C = s_w * s_x * (X * W) + bias
#
# For the scaled_mm fallback case, we break this down, since it
# does not support s_w being a vector.
# GEMM
# This computes C = (X * W).
# Output in fp32 to allow subsequent ops to happen in-place
output = torch._scaled_mm(
qinput,
weight,
scale_a=TORCH_DEVICE_IDENTITY,
scale_b=TORCH_DEVICE_IDENTITY,
out_dtype=torch.float32,
)
# A fix for discrepancy in scaled_mm which returns tuple
# for torch < 2.5 and a single value in torch >= 2.5
if type(output) is tuple and len(output) == 2:
output = output[0]
# Unpad (undo num_token_padding)
output = torch.narrow(output, 0, 0, qinput.shape[0])
x_scale = torch.narrow(scale_a, 0, 0, qinput.shape[0])
# DQ
# C = sw * sx * (X * W) + bias
output = output * x_scale * scale_b.t()
if bias is not None:
output = output + bias
return output.to(out_dtype).view(*output_shape)
def dispatch_w8a8_scaled_mm(
preferred_backend: str, per_tensor_weights: bool, per_tensor_activations: bool
) -> Callable[..., torch.Tensor]:
if per_tensor_weights and per_tensor_activations:
if preferred_backend == "rocm":
return rocm_per_tensor_w8a8_scaled_mm
if preferred_backend == "flashinfer":
return flashinfer_w8a8_scaled_mm
if preferred_backend == "cutlass":
return cutlass_w8a8_scaled_mm
return torch_per_tensor_w8a8_scaled_mm
# cutlass_scaled_mm supports per tensor/channel W and per tensor/token A
if preferred_backend == "cutlass" or preferred_backend == "flashinfer":
return cutlass_w8a8_scaled_mm
# If torch.scaled_mm supports per-channel (weights) per-token (inputs)
if (
not per_tensor_weights
and not per_tensor_activations
and USE_ROWWISE_TORCH_SCALED_MM
):
return torch_per_token_w8a8_scaled_mm
# Normally, torch.scaled_mm supports per tensor weights + activations only
# so fallback to naive if per channel or per token
return torch_channelwise_w8a8_scaled_mm
# TODO(luka): follow similar pattern for marlin and block-fp8-linear
# https://github.com/vllm-project/vllm/issues/14397
class Fp8LinearOp:
"""
This class executes a FP8 linear layer using cutlass if supported and
torch.scaled_mm otherwise.
It needs to be a class instead of a method so that config can be read
in the __init__ method, as reading config is not allowed inside forward.
"""
def __init__(
self,
act_quant_static: bool,
act_quant_group_shape: GroupShape = GroupShape.PER_TENSOR,
pad_output: bool | None = None,
):
if current_platform.is_rocm():
self.preferred_backend = "rocm"
elif current_platform.is_cuda() and cutlass_fp8_supported():
if has_flashinfer() and current_platform.has_device_capability(100):
self.preferred_backend = "flashinfer"
else:
self.preferred_backend = "cutlass"
else:
self.preferred_backend = "torch"
# Note: we pad the input because torch._scaled_mm is more performant
# for matrices with batch dimension > 16.
# This could change in the future.
# We also don't pad when using torch.compile,
# as it breaks with dynamic shapes.
if pad_output is None:
config = get_current_vllm_config().compilation_config
pad_output = (
config.mode < CompilationMode.VLLM_COMPILE
and self.preferred_backend == "torch"
)
self.output_padding = 17 if pad_output else None
self.act_quant_static = act_quant_static
self.act_quant_group_shape = act_quant_group_shape
self.quant_fp8 = QuantFP8(
static=act_quant_static,
group_shape=act_quant_group_shape,
num_token_padding=self.output_padding,
)
def apply(
self,
input: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
out_dtype: torch.dtype | None = None,
input_scale: torch.Tensor | None = None,
input_scale_ub: torch.Tensor | None = None,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
# ops.scaled_fp8_quant supports both dynamic and static quant.
# If dynamic, layer.input_scale is None and x_scale computed from x.
# If static, layer.input_scale is scalar and x_scale is input_scale.
# View input as 2D matrix for fp8 methods
input_2d = input.view(-1, input.shape[-1])
output_shape = [*input.shape[:-1], weight.shape[1]]
if out_dtype is None:
out_dtype = input.dtype
# If input not quantized
# TODO(luka) remove this path if not used anymore
if input.dtype != current_platform.fp8_dtype():
qinput, x_scale = self.quant_fp8(
input_2d,
input_scale,
input_scale_ub,
)
else:
qinput, x_scale = input_2d, input_scale
# Must have dim() conditions
# In per-token quant scenario, when the number of token is 1,
# the scale will only have 1 elements.
# Without checking the dim(),
# we cannot distingushes between per-tensor and per-token quant.
# Example:
# When the number of token is 1, per-token scale is [[1]]
# When per-tensor scale is [1] or ().
per_tensor_weights = weight_scale.numel() == 1
per_tensor_activations = (x_scale.numel() == 1) and x_scale.dim() < 2
# TODO(luka) do this dispatch during init (after ScaledMM refactor)
w8a8_scaled_mm_func = dispatch_w8a8_scaled_mm(
self.preferred_backend, per_tensor_weights, per_tensor_activations
)
return w8a8_scaled_mm_func(
qinput=qinput,
weight=weight,
out_dtype=out_dtype,
scale_a=x_scale,
scale_b=weight_scale,
bias=bias,
output_shape=output_shape,
)
def normalize_e4m3fn_to_e4m3fnuz( def normalize_e4m3fn_to_e4m3fnuz(
weight: torch.Tensor, weight: torch.Tensor,
weight_scale: torch.Tensor, weight_scale: torch.Tensor,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment