Unverified Commit fd0e3772 authored by Michael Goin's avatar Michael Goin Committed by GitHub
Browse files

Support FP8 block quant for CompressedTensorsW8A16Fp8 (#33280)


Signed-off-by: default avatarmgoin <mgoin64@gmail.com>
parent f857a03f
...@@ -651,7 +651,7 @@ class CompressedTensorsConfig(QuantizationConfig): ...@@ -651,7 +651,7 @@ class CompressedTensorsConfig(QuantizationConfig):
# note: input_quant will be present for converted models; # note: input_quant will be present for converted models;
# will be ignored during inference post loading # will be ignored during inference post loading
return CompressedTensorsW8A16Fp8( return CompressedTensorsW8A16Fp8(
strategy=weight_quant.strategy, weight_quant=weight_quant,
is_static_input_scheme=not input_quant.dynamic, is_static_input_scheme=not input_quant.dynamic,
) )
...@@ -659,7 +659,7 @@ class CompressedTensorsConfig(QuantizationConfig): ...@@ -659,7 +659,7 @@ class CompressedTensorsConfig(QuantizationConfig):
if self._is_fp8_w8a16(weight_quant, input_quant): if self._is_fp8_w8a16(weight_quant, input_quant):
is_static_input_scheme = input_quant and not input_quant.dynamic is_static_input_scheme = input_quant and not input_quant.dynamic
return CompressedTensorsW8A16Fp8( return CompressedTensorsW8A16Fp8(
strategy=weight_quant.strategy, weight_quant=weight_quant,
is_static_input_scheme=is_static_input_scheme, is_static_input_scheme=is_static_input_scheme,
) )
......
...@@ -4,11 +4,17 @@ ...@@ -4,11 +4,17 @@
from collections.abc import Callable from collections.abc import Callable
import torch import torch
from compressed_tensors.quantization import QuantizationStrategy from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme, CompressedTensorsScheme,
) )
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
create_fp8_scale_parameter,
create_fp8_weight_parameter,
process_fp8_weight_block_strategy,
validate_fp8_block_shape,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
apply_fp8_marlin_linear, apply_fp8_marlin_linear,
prepare_fp8_layer_for_marlin, prepare_fp8_layer_for_marlin,
...@@ -17,57 +23,40 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( ...@@ -17,57 +23,40 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
convert_to_channelwise, convert_to_channelwise,
) )
from vllm.model_executor.parameter import ( from vllm.model_executor.parameter import (
BlockQuantScaleParameter,
ChannelQuantScaleParameter, ChannelQuantScaleParameter,
ModelWeightParameter,
PerTensorScaleParameter, PerTensorScaleParameter,
) )
from vllm.model_executor.utils import replace_parameter
__all__ = ["CompressedTensorsW8A16Fp8"] __all__ = ["CompressedTensorsW8A16Fp8"]
SUPPORTED_STRATEGIES = [QuantizationStrategy.CHANNEL, QuantizationStrategy.TENSOR] strategy_to_parameter_type = {
QuantizationStrategy.BLOCK: BlockQuantScaleParameter,
QuantizationStrategy.CHANNEL: ChannelQuantScaleParameter,
QuantizationStrategy.TENSOR: PerTensorScaleParameter,
}
class CompressedTensorsW8A16Fp8(CompressedTensorsScheme): class CompressedTensorsW8A16Fp8(CompressedTensorsScheme):
def __init__(self, strategy: str, is_static_input_scheme: bool): def __init__(self, weight_quant: QuantizationArgs, is_static_input_scheme: bool):
self.strategy = strategy self.weight_quant = weight_quant
self.strategy = weight_quant.strategy
self.is_static_input_scheme = is_static_input_scheme self.is_static_input_scheme = is_static_input_scheme
self.weight_block_size = self.weight_quant.block_structure
@classmethod @classmethod
def get_min_capability(cls) -> int: def get_min_capability(cls) -> int:
# turing and up # turing and up
return 75 return 75
# W8A8-Fp8 kernels support only per-tensor and per-channel cases.
# So if we have a fused module (QKV, MLP) with per tensor scales,
# we expand each scale to its shard's channels.
def process_weights_after_loading(self, layer) -> None:
if self.strategy == QuantizationStrategy.TENSOR:
ws_channelwise = convert_to_channelwise(
layer.weight_scale, layer.logical_widths
)
layer.weight_scale = torch.nn.Parameter(ws_channelwise, requires_grad=False)
else:
# required by torch.compile to be torch.nn.Parameter
layer.weight_scale = torch.nn.Parameter(
layer.weight_scale.data, requires_grad=False
)
# Weights must be transposed for marlin
layer.weight = torch.nn.Parameter(layer.weight.t(), requires_grad=False)
if self.is_static_input_scheme:
# required by torch.compile to be torch.nn.Parameter
layer.input_scale = torch.nn.Parameter(
layer.input_scale.data, requires_grad=False
)
prepare_fp8_layer_for_marlin(layer)
def create_weights( def create_weights(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
input_size: int,
output_partition_sizes: list[int],
input_size_per_partition: int, input_size_per_partition: int,
output_partition_sizes: list[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype, params_dtype: torch.dtype,
weight_loader: Callable, weight_loader: Callable,
**kwargs, **kwargs,
...@@ -79,38 +68,33 @@ class CompressedTensorsW8A16Fp8(CompressedTensorsScheme): ...@@ -79,38 +68,33 @@ class CompressedTensorsW8A16Fp8(CompressedTensorsScheme):
layer.orig_dtype = params_dtype layer.orig_dtype = params_dtype
layer.weight_block_size = None layer.weight_block_size = None
# WEIGHT if self.strategy == QuantizationStrategy.BLOCK:
weight = ModelWeightParameter( assert self.weight_block_size is not None
data=torch.empty( layer.weight_block_size = self.weight_block_size
output_size_per_partition, # Validate block quantization shapes
validate_fp8_block_shape(
layer,
input_size,
output_size,
input_size_per_partition, input_size_per_partition,
dtype=torch.float8_e4m3fn, output_partition_sizes,
), self.weight_block_size,
input_dim=1, )
output_dim=0,
weight_loader=weight_loader, # WEIGHT
weight = create_fp8_weight_parameter(
output_size_per_partition, input_size_per_partition, weight_loader
) )
layer.register_parameter("weight", weight) layer.register_parameter("weight", weight)
# WEIGHT SCALE # WEIGHT SCALE
if self.strategy == QuantizationStrategy.CHANNEL: weight_scale = create_fp8_scale_parameter(
weight_scale = ChannelQuantScaleParameter( strategy_to_parameter_type[self.strategy],
data=torch.empty((sum(output_partition_sizes), 1), dtype=torch.float32), output_partition_sizes,
output_dim=0, input_size_per_partition,
weight_loader=weight_loader, layer.weight_block_size,
) weight_loader,
elif self.strategy == QuantizationStrategy.TENSOR:
weight_scale = PerTensorScaleParameter(
data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
weight_loader=weight_loader,
)
else:
raise ValueError(
f"Unsupported weight strategy={self.strategy}, "
f"supported strategies are {SUPPORTED_STRATEGIES}"
) )
weight_scale[:] = torch.finfo(torch.float32).min
layer.register_parameter("weight_scale", weight_scale) layer.register_parameter("weight_scale", weight_scale)
# INPUT SCALE (to deal with converted checkpoints) # INPUT SCALE (to deal with converted checkpoints)
...@@ -121,6 +105,33 @@ class CompressedTensorsW8A16Fp8(CompressedTensorsScheme): ...@@ -121,6 +105,33 @@ class CompressedTensorsW8A16Fp8(CompressedTensorsScheme):
) )
layer.register_parameter("input_scale", input_scale) layer.register_parameter("input_scale", input_scale)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
weight = layer.weight
weight_scale = layer.weight_scale
size_k_first = True
# TODO(rob): refactor block quant into separate class.
if self.strategy == QuantizationStrategy.BLOCK:
assert self.is_static_input_scheme is False
size_k_first = False
weight, weight_scale = process_fp8_weight_block_strategy(
weight, weight_scale
)
else:
# Weights must be transposed for marlin
weight = weight.t()
if self.strategy == QuantizationStrategy.TENSOR:
# If we have a fused module (QKV, MLP) with per tensor scales,
# we expand each scale to its shard's channels.
weight_scale = convert_to_channelwise(
weight_scale, layer.logical_widths
)
# Update layer with new values
replace_parameter(layer, "weight", weight.data)
replace_parameter(layer, "weight_scale", weight_scale.data)
prepare_fp8_layer_for_marlin(layer, size_k_first=size_k_first)
def apply_weights( def apply_weights(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
......
...@@ -400,7 +400,6 @@ class Fp8LinearMethod(LinearMethodBase): ...@@ -400,7 +400,6 @@ class Fp8LinearMethod(LinearMethodBase):
None, None,
weight_loader, weight_loader,
) )
set_weight_attrs(scale, {"scale_type": "weight_scale"})
layer.register_parameter("weight_scale", scale) layer.register_parameter("weight_scale", scale)
else: else:
assert not self.act_q_static assert not self.act_q_static
...@@ -412,7 +411,6 @@ class Fp8LinearMethod(LinearMethodBase): ...@@ -412,7 +411,6 @@ class Fp8LinearMethod(LinearMethodBase):
self.weight_block_size, self.weight_block_size,
weight_loader, weight_loader,
) )
set_weight_attrs(scale, {"scale_type": "weight_scale"})
# The weight_scale_inv name is intentional for deepseekv3 # The weight_scale_inv name is intentional for deepseekv3
layer.register_parameter("weight_scale_inv", scale) layer.register_parameter("weight_scale_inv", scale)
......
...@@ -29,7 +29,7 @@ from vllm.model_executor.parameter import ( ...@@ -29,7 +29,7 @@ from vllm.model_executor.parameter import (
ChannelQuantScaleParameter, ChannelQuantScaleParameter,
PerTensorScaleParameter, PerTensorScaleParameter,
) )
from vllm.model_executor.utils import replace_parameter from vllm.model_executor.utils import replace_parameter, set_weight_attrs
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
from vllm.utils.deep_gemm import ( from vllm.utils.deep_gemm import (
...@@ -1520,6 +1520,7 @@ def create_fp8_scale_parameter( ...@@ -1520,6 +1520,7 @@ def create_fp8_scale_parameter(
raise ValueError(f"Unknown parameter type: {parameter_type}") raise ValueError(f"Unknown parameter type: {parameter_type}")
scale[:] = torch.finfo(torch.float32).min scale[:] = torch.finfo(torch.float32).min
set_weight_attrs(scale, {"scale_type": "weight_scale"})
return scale return scale
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment