Unverified commit 55d037e2, authored by Kunshang Ji, committed by GitHub
Browse files

[CT][FP8][Marlin] refactor CompressedTensorsW8A16Fp8 to use kernel abstraction (#38244)


Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
Signed-off-by: Kunshang Ji <jikunshang95@gmail.com>
parent ecbfbb8d
...@@ -177,6 +177,21 @@ _POSSIBLE_FP8_BLOCK_KERNELS: dict[ ...@@ -177,6 +177,21 @@ _POSSIBLE_FP8_BLOCK_KERNELS: dict[
], ],
} }
# Candidate kernels for the WFP8A16 path, keyed by platform. This is the
# table init_wfp8_a16_linear_kernel hands to choose_scaled_mm_linear_kernel.
# Only CUDA has an entry so far (Marlin); the remaining platforms are empty
# placeholders.
_POSSIBLE_WFP8A16_KERNELS: dict[PlatformEnum, list[type[FP8ScaledMMLinearKernel]]] = {
    PlatformEnum.CUDA: [
        MarlinFP8ScaledMMLinearKernel,
    ],
    PlatformEnum.ROCM: [
        # To be added
    ],
    PlatformEnum.CPU: [
        # To be added
    ],
    PlatformEnum.XPU: [
        # To be added
    ],
}
# in priority/performance order (when available) # in priority/performance order (when available)
_POSSIBLE_KERNELS: dict[PlatformEnum, list[type[MPLinearKernel]]] = { _POSSIBLE_KERNELS: dict[PlatformEnum, list[type[MPLinearKernel]]] = {
PlatformEnum.CUDA: [ PlatformEnum.CUDA: [
...@@ -463,6 +478,41 @@ def choose_mp_linear_kernel( ...@@ -463,6 +478,41 @@ def choose_mp_linear_kernel(
) )
def init_wfp8_a16_linear_kernel(
    weight_quant_key: QuantKey,
    activation_quant_key: QuantKey,
    weight_shape: tuple[int, int],
    input_dtype: torch.dtype,
    out_dtype: torch.dtype,
    force_kernel: type[FP8ScaledMMLinearKernel] | None = None,
    module_name: str | None = None,
) -> FP8ScaledMMLinearKernel:
    """Pick a WFP8A16 linear kernel class and return an instance of it.

    The quant keys, weight shape and dtypes are bundled into an
    ``FP8ScaledMMLinearLayerConfig``; the kernel class is then chosen by
    ``choose_scaled_mm_linear_kernel`` from ``_POSSIBLE_WFP8A16_KERNELS``.

    Args:
        weight_quant_key: Quantization key describing the weights.
        activation_quant_key: Quantization key describing the activations.
        weight_shape: Shape of the layer's weight.
        input_dtype: Dtype recorded in the config as the input dtype.
        out_dtype: Dtype recorded in the config as the output dtype.
        force_kernel: Forwarded unchanged to the kernel chooser.
        module_name: When truthy, the selection is logged once (global
            scope) together with this name; otherwise nothing is logged.

    Returns:
        The selected kernel class instantiated with the config and the
        layer parameter names it should operate on.
    """
    layer_config = FP8ScaledMMLinearLayerConfig(
        weight_quant_key=weight_quant_key,
        activation_quant_key=activation_quant_key,
        weight_shape=weight_shape,
        input_dtype=input_dtype,
        out_dtype=out_dtype,
    )

    kernel_cls = choose_scaled_mm_linear_kernel(
        layer_config,
        _POSSIBLE_WFP8A16_KERNELS,
        force_kernel=force_kernel,
    )

    # Logging is opt-in: callers that do not pass a module name stay silent.
    if module_name:
        logger.info_once(
            "Selected %s for %s", kernel_cls.__name__, module_name, scope="global"
        )

    # Names of the layer attributes handed to the kernel.
    param_names = ["weight", "weight_scale", "input_scale", "input_scale_ub"]
    return kernel_cls(layer_config, layer_param_names=param_names)
# Maps VLLM_NVFP4_GEMM_BACKEND env var values to kernel classes. # Maps VLLM_NVFP4_GEMM_BACKEND env var values to kernel classes.
_NVFP4_BACKEND_TO_KERNEL: dict[str, type[NvFp4LinearKernel]] = { _NVFP4_BACKEND_TO_KERNEL: dict[str, type[NvFp4LinearKernel]] = {
"flashinfer-cutlass": FlashInferCutlassNvFp4LinearKernel, "flashinfer-cutlass": FlashInferCutlassNvFp4LinearKernel,
...@@ -588,6 +638,7 @@ __all__ = [ ...@@ -588,6 +638,7 @@ __all__ = [
"init_nvfp4_linear_kernel", "init_nvfp4_linear_kernel",
"choose_mp_linear_kernel", "choose_mp_linear_kernel",
"register_linear_kernel", "register_linear_kernel",
"init_wfp8_a16_linear_kernel",
"FP8ScaledMMLinearKernel", "FP8ScaledMMLinearKernel",
"Int8ScaledMMLinearKernel", "Int8ScaledMMLinearKernel",
"ScaledMMLinearKernel", "ScaledMMLinearKernel",
......
...@@ -6,45 +6,49 @@ from collections.abc import Callable ...@@ -6,45 +6,49 @@ from collections.abc import Callable
import torch import torch
from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy
from vllm.config import get_current_vllm_config
from vllm.model_executor.kernels.linear import (
init_wfp8_a16_linear_kernel,
)
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme, CompressedTensorsScheme,
) )
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
STRATEGY_TO_PARAMETER_TYPE,
STRATEGY_TO_WEIGHT_QUANT_KEY,
)
from vllm.model_executor.layers.quantization.utils.fp8_utils import ( from vllm.model_executor.layers.quantization.utils.fp8_utils import (
create_fp8_scale_parameter, create_fp8_scale_parameter,
create_fp8_weight_parameter, create_fp8_weight_parameter,
process_fp8_weight_block_strategy,
validate_fp8_block_shape, validate_fp8_block_shape,
) )
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
apply_fp8_marlin_linear, kFp8DynamicTensorSym,
prepare_fp8_layer_for_marlin, kFp8StaticTensorSym,
) )
from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
convert_to_channelwise, convert_to_channelwise,
) )
from vllm.model_executor.parameter import ( from vllm.model_executor.parameter import PerTensorScaleParameter
BlockQuantScaleParameter,
ChannelQuantScaleParameter,
PerTensorScaleParameter,
)
from vllm.model_executor.utils import replace_parameter from vllm.model_executor.utils import replace_parameter
__all__ = ["CompressedTensorsW8A16Fp8"] __all__ = ["CompressedTensorsW8A16Fp8"]
strategy_to_parameter_type = {
QuantizationStrategy.BLOCK: BlockQuantScaleParameter,
QuantizationStrategy.CHANNEL: ChannelQuantScaleParameter,
QuantizationStrategy.TENSOR: PerTensorScaleParameter,
}
class CompressedTensorsW8A16Fp8(CompressedTensorsScheme): class CompressedTensorsW8A16Fp8(CompressedTensorsScheme):
def __init__(self, weight_quant: QuantizationArgs, is_static_input_scheme: bool): def __init__(self, weight_quant: QuantizationArgs, is_static_input_scheme: bool):
self.weight_quant = weight_quant self.weight_quant = weight_quant
self.strategy = weight_quant.strategy self.strategy = weight_quant.strategy
self.out_dtype = torch.get_default_dtype()
self.input_dtype = get_current_vllm_config().model_config.dtype
self.is_static_input_scheme = is_static_input_scheme self.is_static_input_scheme = is_static_input_scheme
self.weight_block_size = self.weight_quant.block_structure self.weight_block_size = self.weight_quant.block_structure
self.weight_quant_key = STRATEGY_TO_WEIGHT_QUANT_KEY[self.strategy]
self.activation_quant_key = (
kFp8StaticTensorSym if is_static_input_scheme else kFp8DynamicTensorSym
)
@classmethod @classmethod
def get_min_capability(cls) -> int: def get_min_capability(cls) -> int:
# turing and up # turing and up
...@@ -89,7 +93,7 @@ class CompressedTensorsW8A16Fp8(CompressedTensorsScheme): ...@@ -89,7 +93,7 @@ class CompressedTensorsW8A16Fp8(CompressedTensorsScheme):
# WEIGHT SCALE # WEIGHT SCALE
weight_scale = create_fp8_scale_parameter( weight_scale = create_fp8_scale_parameter(
strategy_to_parameter_type[self.strategy], STRATEGY_TO_PARAMETER_TYPE[self.strategy],
output_partition_sizes, output_partition_sizes,
input_size_per_partition, input_size_per_partition,
layer.weight_block_size, layer.weight_block_size,
...@@ -105,32 +109,36 @@ class CompressedTensorsW8A16Fp8(CompressedTensorsScheme): ...@@ -105,32 +109,36 @@ class CompressedTensorsW8A16Fp8(CompressedTensorsScheme):
) )
layer.register_parameter("input_scale", input_scale) layer.register_parameter("input_scale", input_scale)
self.linear_kernel = init_wfp8_a16_linear_kernel(
weight_quant_key=self.weight_quant_key,
activation_quant_key=self.activation_quant_key,
weight_shape=layer.weight.shape,
input_dtype=self.input_dtype,
out_dtype=self.out_dtype,
)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
weight = layer.weight
weight_scale = layer.weight_scale
size_k_first = True
# TODO(rob): refactor block quant into separate class.
if self.strategy == QuantizationStrategy.BLOCK: if self.strategy == QuantizationStrategy.BLOCK:
assert self.is_static_input_scheme is False assert self.is_static_input_scheme is False
size_k_first = False # MarlinFP8ScaledMMLinearKernel uses "weight_scale_inv" for block
weight, weight_scale = process_fp8_weight_block_strategy( # quant, while CT registers the scale as "weight_scale".
weight, weight_scale # Rename by deleting the old parameter and adding the new one so
) # that prepare_fp8_layer_for_marlin (which prefers "weight_scale"
# over "weight_scale_inv") picks up "weight_scale_inv" correctly.
weight_scale_data = layer.weight_scale.data
del layer._parameters["weight_scale"]
replace_parameter(layer, "weight_scale_inv", weight_scale_data)
else: else:
# Weights must be transposed for marlin
weight = weight.t()
if self.strategy == QuantizationStrategy.TENSOR: if self.strategy == QuantizationStrategy.TENSOR:
# If we have a fused module (QKV, MLP) with per tensor scales, # For fused modules with per-tensor scales, expand each scale
# we expand each scale to its shard's channels. # to its shard's channels.
weight_scale = convert_to_channelwise( replace_parameter(
weight_scale, layer.logical_widths layer,
"weight_scale",
convert_to_channelwise(layer.weight_scale, layer.logical_widths),
) )
# Update layer with new values self.linear_kernel.process_weights_after_loading(layer)
replace_parameter(layer, "weight", weight.data)
replace_parameter(layer, "weight_scale", weight_scale.data)
prepare_fp8_layer_for_marlin(layer, size_k_first=size_k_first)
def apply_weights( def apply_weights(
self, self,
...@@ -138,12 +146,4 @@ class CompressedTensorsW8A16Fp8(CompressedTensorsScheme): ...@@ -138,12 +146,4 @@ class CompressedTensorsW8A16Fp8(CompressedTensorsScheme):
x: torch.Tensor, x: torch.Tensor,
bias: torch.Tensor | None = None, bias: torch.Tensor | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
return apply_fp8_marlin_linear( return self.linear_kernel.apply_weights(layer, x, bias)
input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
workspace=layer.workspace,
size_n=layer.output_size_per_partition,
size_k=layer.input_size_per_partition,
bias=bias,
)
...@@ -16,6 +16,9 @@ from vllm.model_executor.kernels.linear import ( ...@@ -16,6 +16,9 @@ from vllm.model_executor.kernels.linear import (
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme, CompressedTensorsScheme,
) )
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
STRATEGY_TO_PARAMETER_TYPE,
)
from vllm.model_executor.layers.quantization.utils.fp8_utils import ( from vllm.model_executor.layers.quantization.utils.fp8_utils import (
create_fp8_input_scale, create_fp8_input_scale,
create_fp8_scale_parameter, create_fp8_scale_parameter,
...@@ -34,20 +37,9 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( ...@@ -34,20 +37,9 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
cutlass_block_fp8_supported, cutlass_block_fp8_supported,
) )
from vllm.model_executor.parameter import (
BlockQuantScaleParameter,
ChannelQuantScaleParameter,
PerTensorScaleParameter,
)
__all__ = ["CompressedTensorsW8A8Fp8"] __all__ = ["CompressedTensorsW8A8Fp8"]
strategy_to_parameter_type = {
QuantizationStrategy.BLOCK: BlockQuantScaleParameter,
QuantizationStrategy.CHANNEL: ChannelQuantScaleParameter,
QuantizationStrategy.TENSOR: PerTensorScaleParameter,
}
STATIC_QUANT = True STATIC_QUANT = True
DYNAMIC_QUANT = False DYNAMIC_QUANT = False
activation_quant_key_mapping = { activation_quant_key_mapping = {
...@@ -130,7 +122,7 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme): ...@@ -130,7 +122,7 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
# WEIGHT SCALE # WEIGHT SCALE
weight_scale = create_fp8_scale_parameter( weight_scale = create_fp8_scale_parameter(
strategy_to_parameter_type[self.strategy], STRATEGY_TO_PARAMETER_TYPE[self.strategy],
output_partition_sizes, output_partition_sizes,
input_size_per_partition, input_size_per_partition,
layer.weight_block_size, layer.weight_block_size,
......
...@@ -6,8 +6,36 @@ from types import MappingProxyType ...@@ -6,8 +6,36 @@ from types import MappingProxyType
import regex as re import regex as re
from compressed_tensors import CompressionFormat from compressed_tensors import CompressionFormat
from compressed_tensors.quantization import QuantizationStrategy
from torch.nn import Module from torch.nn import Module
from vllm.model_executor.layers.quantization.utils.quant_utils import (
kFp8Static128BlockSym,
kFp8StaticChannelSym,
kFp8StaticTensorSym,
)
from vllm.model_executor.parameter import (
BlockQuantScaleParameter,
ChannelQuantScaleParameter,
PerTensorScaleParameter,
)
# Maps a quantization strategy to the parameter class used for its
# weight-scale tensor. Shared by the compressed-tensors FP8 scheme classes
# (w8a16_fp8, w8a8_fp8, ...), which pass the looked-up class to
# create_fp8_scale_parameter.
STRATEGY_TO_PARAMETER_TYPE = {
    QuantizationStrategy.BLOCK: BlockQuantScaleParameter,
    QuantizationStrategy.CHANNEL: ChannelQuantScaleParameter,
    QuantizationStrategy.TENSOR: PerTensorScaleParameter,
}
# Maps a quantization strategy to the vLLM weight quant key used for
# kernel selection. Shared across compressed-tensors scheme classes.
# NOTE(review): BLOCK always resolves to the 128-block key, whatever block
# structure the checkpoint declares -- confirm other block shapes are
# rejected before this lookup is relied on.
STRATEGY_TO_WEIGHT_QUANT_KEY = {
    QuantizationStrategy.BLOCK: kFp8Static128BlockSym,
    QuantizationStrategy.CHANNEL: kFp8StaticChannelSym,
    QuantizationStrategy.TENSOR: kFp8StaticTensorSym,
}
def is_activation_quantization_format(format: str) -> bool: def is_activation_quantization_format(format: str) -> bool:
_ACTIVATION_QUANTIZATION_FORMATS = [ _ACTIVATION_QUANTIZATION_FORMATS = [
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment