"vscode:/vscode.git/clone" did not exist on "07c693909e3689f1ac045528fa578d43df1f7bd8"
Unverified Commit 55d037e2 authored by Kunshang Ji's avatar Kunshang Ji Committed by GitHub
Browse files

[CT][FP8][Marlin] refactor CompressedTensorsW8A16Fp8 to use kernel abstraction (#38244)


Signed-off-by: default avatarKunshang Ji <kunshang.ji@intel.com>
Signed-off-by: default avatarKunshang Ji <jikunshang95@gmail.com>
parent ecbfbb8d
......@@ -177,6 +177,21 @@ _POSSIBLE_FP8_BLOCK_KERNELS: dict[
],
}
_POSSIBLE_WFP8A16_KERNELS: dict[PlatformEnum, list[type[FP8ScaledMMLinearKernel]]] = {
PlatformEnum.CUDA: [
MarlinFP8ScaledMMLinearKernel,
],
PlatformEnum.ROCM: [
# To be added
],
PlatformEnum.CPU: [
# To be added
],
PlatformEnum.XPU: [
# To be added
],
}
# in priority/performance order (when available)
_POSSIBLE_KERNELS: dict[PlatformEnum, list[type[MPLinearKernel]]] = {
PlatformEnum.CUDA: [
......@@ -463,6 +478,41 @@ def choose_mp_linear_kernel(
)
def init_wfp8_a16_linear_kernel(
weight_quant_key: QuantKey,
activation_quant_key: QuantKey,
weight_shape: tuple[int, int],
input_dtype: torch.dtype,
out_dtype: torch.dtype,
force_kernel: type[FP8ScaledMMLinearKernel] | None = None,
module_name: str | None = None,
) -> FP8ScaledMMLinearKernel:
config = FP8ScaledMMLinearLayerConfig(
weight_quant_key=weight_quant_key,
activation_quant_key=activation_quant_key,
weight_shape=weight_shape,
input_dtype=input_dtype,
out_dtype=out_dtype,
)
kernel_type = choose_scaled_mm_linear_kernel(
config, _POSSIBLE_WFP8A16_KERNELS, force_kernel=force_kernel
)
if module_name:
logger.info_once(
"Selected %s for %s",
kernel_type.__name__,
module_name,
scope="global",
)
return kernel_type(
config,
layer_param_names=["weight", "weight_scale", "input_scale", "input_scale_ub"],
)
# Maps VLLM_NVFP4_GEMM_BACKEND env var values to kernel classes.
_NVFP4_BACKEND_TO_KERNEL: dict[str, type[NvFp4LinearKernel]] = {
"flashinfer-cutlass": FlashInferCutlassNvFp4LinearKernel,
......@@ -588,6 +638,7 @@ __all__ = [
"init_nvfp4_linear_kernel",
"choose_mp_linear_kernel",
"register_linear_kernel",
"init_wfp8_a16_linear_kernel",
"FP8ScaledMMLinearKernel",
"Int8ScaledMMLinearKernel",
"ScaledMMLinearKernel",
......
......@@ -6,45 +6,49 @@ from collections.abc import Callable
import torch
from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy
from vllm.config import get_current_vllm_config
from vllm.model_executor.kernels.linear import (
init_wfp8_a16_linear_kernel,
)
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme,
)
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
STRATEGY_TO_PARAMETER_TYPE,
STRATEGY_TO_WEIGHT_QUANT_KEY,
)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
create_fp8_scale_parameter,
create_fp8_weight_parameter,
process_fp8_weight_block_strategy,
validate_fp8_block_shape,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
apply_fp8_marlin_linear,
prepare_fp8_layer_for_marlin,
from vllm.model_executor.layers.quantization.utils.quant_utils import (
kFp8DynamicTensorSym,
kFp8StaticTensorSym,
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
convert_to_channelwise,
)
from vllm.model_executor.parameter import (
BlockQuantScaleParameter,
ChannelQuantScaleParameter,
PerTensorScaleParameter,
)
from vllm.model_executor.parameter import PerTensorScaleParameter
from vllm.model_executor.utils import replace_parameter
__all__ = ["CompressedTensorsW8A16Fp8"]
strategy_to_parameter_type = {
QuantizationStrategy.BLOCK: BlockQuantScaleParameter,
QuantizationStrategy.CHANNEL: ChannelQuantScaleParameter,
QuantizationStrategy.TENSOR: PerTensorScaleParameter,
}
class CompressedTensorsW8A16Fp8(CompressedTensorsScheme):
def __init__(self, weight_quant: QuantizationArgs, is_static_input_scheme: bool):
self.weight_quant = weight_quant
self.strategy = weight_quant.strategy
self.out_dtype = torch.get_default_dtype()
self.input_dtype = get_current_vllm_config().model_config.dtype
self.is_static_input_scheme = is_static_input_scheme
self.weight_block_size = self.weight_quant.block_structure
self.weight_quant_key = STRATEGY_TO_WEIGHT_QUANT_KEY[self.strategy]
self.activation_quant_key = (
kFp8StaticTensorSym if is_static_input_scheme else kFp8DynamicTensorSym
)
@classmethod
def get_min_capability(cls) -> int:
# turing and up
......@@ -89,7 +93,7 @@ class CompressedTensorsW8A16Fp8(CompressedTensorsScheme):
# WEIGHT SCALE
weight_scale = create_fp8_scale_parameter(
strategy_to_parameter_type[self.strategy],
STRATEGY_TO_PARAMETER_TYPE[self.strategy],
output_partition_sizes,
input_size_per_partition,
layer.weight_block_size,
......@@ -105,32 +109,36 @@ class CompressedTensorsW8A16Fp8(CompressedTensorsScheme):
)
layer.register_parameter("input_scale", input_scale)
self.linear_kernel = init_wfp8_a16_linear_kernel(
weight_quant_key=self.weight_quant_key,
activation_quant_key=self.activation_quant_key,
weight_shape=layer.weight.shape,
input_dtype=self.input_dtype,
out_dtype=self.out_dtype,
)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
weight = layer.weight
weight_scale = layer.weight_scale
size_k_first = True
# TODO(rob): refactor block quant into separate class.
if self.strategy == QuantizationStrategy.BLOCK:
assert self.is_static_input_scheme is False
size_k_first = False
weight, weight_scale = process_fp8_weight_block_strategy(
weight, weight_scale
)
# MarlinFP8ScaledMMLinearKernel uses "weight_scale_inv" for block
# quant, while CT registers the scale as "weight_scale".
# Rename by deleting the old parameter and adding the new one so
# that prepare_fp8_layer_for_marlin (which prefers "weight_scale"
# over "weight_scale_inv") picks up "weight_scale_inv" correctly.
weight_scale_data = layer.weight_scale.data
del layer._parameters["weight_scale"]
replace_parameter(layer, "weight_scale_inv", weight_scale_data)
else:
# Weights must be transposed for marlin
weight = weight.t()
if self.strategy == QuantizationStrategy.TENSOR:
# If we have a fused module (QKV, MLP) with per tensor scales,
# we expand each scale to its shard's channels.
weight_scale = convert_to_channelwise(
weight_scale, layer.logical_widths
# For fused modules with per-tensor scales, expand each scale
# to its shard's channels.
replace_parameter(
layer,
"weight_scale",
convert_to_channelwise(layer.weight_scale, layer.logical_widths),
)
# Update layer with new values
replace_parameter(layer, "weight", weight.data)
replace_parameter(layer, "weight_scale", weight_scale.data)
prepare_fp8_layer_for_marlin(layer, size_k_first=size_k_first)
self.linear_kernel.process_weights_after_loading(layer)
def apply_weights(
self,
......@@ -138,12 +146,4 @@ class CompressedTensorsW8A16Fp8(CompressedTensorsScheme):
x: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
return apply_fp8_marlin_linear(
input=x,
weight=layer.weight,
weight_scale=layer.weight_scale,
workspace=layer.workspace,
size_n=layer.output_size_per_partition,
size_k=layer.input_size_per_partition,
bias=bias,
)
return self.linear_kernel.apply_weights(layer, x, bias)
......@@ -16,6 +16,9 @@ from vllm.model_executor.kernels.linear import (
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme,
)
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
STRATEGY_TO_PARAMETER_TYPE,
)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
create_fp8_input_scale,
create_fp8_scale_parameter,
......@@ -34,20 +37,9 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
cutlass_block_fp8_supported,
)
from vllm.model_executor.parameter import (
BlockQuantScaleParameter,
ChannelQuantScaleParameter,
PerTensorScaleParameter,
)
__all__ = ["CompressedTensorsW8A8Fp8"]
strategy_to_parameter_type = {
QuantizationStrategy.BLOCK: BlockQuantScaleParameter,
QuantizationStrategy.CHANNEL: ChannelQuantScaleParameter,
QuantizationStrategy.TENSOR: PerTensorScaleParameter,
}
STATIC_QUANT = True
DYNAMIC_QUANT = False
activation_quant_key_mapping = {
......@@ -130,7 +122,7 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
# WEIGHT SCALE
weight_scale = create_fp8_scale_parameter(
strategy_to_parameter_type[self.strategy],
STRATEGY_TO_PARAMETER_TYPE[self.strategy],
output_partition_sizes,
input_size_per_partition,
layer.weight_block_size,
......
......@@ -6,8 +6,36 @@ from types import MappingProxyType
import regex as re
from compressed_tensors import CompressionFormat
from compressed_tensors.quantization import QuantizationStrategy
from torch.nn import Module
from vllm.model_executor.layers.quantization.utils.quant_utils import (
kFp8Static128BlockSym,
kFp8StaticChannelSym,
kFp8StaticTensorSym,
)
from vllm.model_executor.parameter import (
BlockQuantScaleParameter,
ChannelQuantScaleParameter,
PerTensorScaleParameter,
)
# Maps quantization strategy to the corresponding scale parameter type.
# Shared across compressed-tensor scheme classes (w8a16_fp8, w8a8_fp8, …).
STRATEGY_TO_PARAMETER_TYPE = {
QuantizationStrategy.BLOCK: BlockQuantScaleParameter,
QuantizationStrategy.CHANNEL: ChannelQuantScaleParameter,
QuantizationStrategy.TENSOR: PerTensorScaleParameter,
}
# Maps quantization strategy to the vLLM weight-quant key used for
# kernel selection. Shared across compressed-tensor scheme classes.
STRATEGY_TO_WEIGHT_QUANT_KEY = {
QuantizationStrategy.BLOCK: kFp8Static128BlockSym,
QuantizationStrategy.CHANNEL: kFp8StaticChannelSym,
QuantizationStrategy.TENSOR: kFp8StaticTensorSym,
}
def is_activation_quantization_format(format: str) -> bool:
_ACTIVATION_QUANTIZATION_FORMATS = [
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment