Commit 500b93c8 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.5.3.post1' into v0.5.3.post1-dtk24.04.1

parents 99426767 38c4b7e8
...@@ -60,9 +60,8 @@ class BitsAndBytesConfig(QuantizationConfig): ...@@ -60,9 +60,8 @@ class BitsAndBytesConfig(QuantizationConfig):
target_modules = cls.get_from_keys(config, ["target_modules"]) target_modules = cls.get_from_keys(config, ["target_modules"])
return cls(adapter_name, target_modules) return cls(adapter_name, target_modules)
def get_quant_method( def get_quant_method(self, layer: torch.nn.Module,
self, prefix: str) -> Optional["BitsAndBytesLinearMethod"]:
layer: torch.nn.Module) -> Optional["BitsAndBytesLinearMethod"]:
if isinstance(layer, LinearBase): if isinstance(layer, LinearBase):
return BitsAndBytesLinearMethod(self) return BitsAndBytesLinearMethod(self)
return None return None
......
...@@ -5,25 +5,33 @@ from pydantic import BaseModel ...@@ -5,25 +5,33 @@ from pydantic import BaseModel
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import ( # noqa: E501 from vllm.model_executor.layers.quantization.base_config import ( # noqa: E501
QuantizationConfig) QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
W4A16SPARSE24_SUPPORTED_BITS, WNA16_SUPPORTED_BITS, W4A16SPARSE24_SUPPORTED_BITS, WNA16_SUPPORTED_BITS,
CompressedTensorsScheme, CompressedTensorsW4A16Sparse24, CompressedTensorsScheme, CompressedTensorsUnquantized,
CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8, CompressedTensorsW4A16Sparse24, CompressedTensorsW8A8Fp8,
CompressedTensorsWNA16) CompressedTensorsW8A8Int8, CompressedTensorsWNA16)
from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
CompressionFormat, QuantizationArgs, QuantizationStrategy, CompressionFormat, QuantizationArgs, QuantizationStrategy,
QuantizationType, find_first_name_or_class_match) QuantizationType, find_matched_target, is_activation_quantization_format,
should_ignore_layer)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.platforms import current_platform from vllm.platforms import current_platform
class CompressedTensorsConfig(QuantizationConfig): class CompressedTensorsConfig(QuantizationConfig):
def __init__(self, layer_quant_details: Dict[str, Any], ignore: List[str], def __init__(self,
quant_format: str): target_scheme_map: Dict[str, Any],
ignore: List[str],
quant_format: str,
kv_cache_scheme: Optional[Dict[str, Any]] = None):
self.ignore = ignore self.ignore = ignore
self.layer_quant_details = layer_quant_details
self.quant_format = quant_format self.quant_format = quant_format
# Map from [target -> scheme]
self.target_scheme_map = target_scheme_map
self.kv_cache_scheme = kv_cache_scheme
def get_linear_method(self) -> "CompressedTensorsLinearMethod": def get_linear_method(self) -> "CompressedTensorsLinearMethod":
return CompressedTensorsLinearMethod(self) return CompressedTensorsLinearMethod(self)
...@@ -36,21 +44,28 @@ class CompressedTensorsConfig(QuantizationConfig): ...@@ -36,21 +44,28 @@ class CompressedTensorsConfig(QuantizationConfig):
@classmethod @classmethod
def get_min_capability(cls) -> int: def get_min_capability(cls) -> int:
return 75 return 70
def get_name(self) -> str: def get_name(self) -> str:
return "compressed_tensors" return "compressed_tensors"
# TODO (@robertgshaw2-neuralmagic): do layer skipping though here
# rather than though create_weights to match other methods
def get_quant_method( def get_quant_method(
self, layer: torch.nn.Module self,
) -> Optional["CompressedTensorsLinearMethod"]: layer: torch.nn.Module,
prefix: str,
) -> Optional["QuantizeMethodBase"]:
from vllm.attention.layer import Attention # Avoid circular import
if isinstance(layer, LinearBase): if isinstance(layer, LinearBase):
return CompressedTensorsLinearMethod(self) return CompressedTensorsLinearMethod(self)
if isinstance(layer, Attention):
return CompressedTensorsKVCacheMethod(self)
return None return None
@classmethod @classmethod
def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig": def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig":
layer_quant_details: Dict[str, Any] = dict() target_scheme_map: Dict[str, Any] = dict()
ignore: List[str] = config.get("ignore", None) ignore: List[str] = config.get("ignore", None)
quant_format: str = config.get("format", None) quant_format: str = config.get("format", None)
...@@ -62,35 +77,37 @@ class CompressedTensorsConfig(QuantizationConfig): ...@@ -62,35 +77,37 @@ class CompressedTensorsConfig(QuantizationConfig):
# details follow the structure defined by the QuantizationArgs # details follow the structure defined by the QuantizationArgs
# pydantic model, which is used to verify the structure of the # pydantic model, which is used to verify the structure of the
# quant_config and also store the details for later use. # quant_config and also store the details for later use.
for key, quant_config in config["config_groups"].items(): for _, quant_config in config["config_groups"].items():
targets = quant_config.get("targets") targets = quant_config.get("targets")
for target in targets: for target in targets:
layer_quant_details[target] = {} target_scheme_map[target] = {}
layer_quant_details[target][ target_scheme_map[target][
"weights"] = QuantizationArgs.parse_obj( "weights"] = QuantizationArgs.parse_obj(
quant_config.get("weights")) quant_config.get("weights"))
try: try:
layer_quant_details[target][ target_scheme_map[target][
"input_activations"] = QuantizationArgs.parse_obj( "input_activations"] = QuantizationArgs.parse_obj(
quant_config.get("input_activations")) quant_config.get("input_activations"))
except Exception: except Exception:
layer_quant_details[target]["input_activations"] = None target_scheme_map[target]["input_activations"] = None
return cls(layer_quant_details=layer_quant_details, return cls(target_scheme_map=target_scheme_map,
ignore=ignore, ignore=ignore,
quant_format=quant_format) quant_format=quant_format,
kv_cache_scheme=config.get("kv_cache_scheme"))
@classmethod @classmethod
def get_config_filenames(cls) -> List[str]: def get_config_filenames(cls) -> List[str]:
return [] return []
def _check_gptq_and_marlin_can_run(self): def _check_scheme_supported(self, min_capability: int):
capability = current_platform.get_device_capability() capability = current_platform.get_device_capability()
capability = capability[0] * 10 + capability[1] capability = capability[0] * 10 + capability[1]
if capability < 80: if capability < min_capability:
raise RuntimeError("The quantization config is not supported for ", raise RuntimeError(
"the current GPU. Minimum capability: 80. ", "Quantization scheme is not supported for ",
f"Current capability: {capability}.") f"the current GPU. Min capability: {min_capability}. ",
f"Current capability: {capability}.")
def _is_static_tensor_w8a8(self, weight_quant: BaseModel, def _is_static_tensor_w8a8(self, weight_quant: BaseModel,
input_quant: BaseModel) -> bool: input_quant: BaseModel) -> bool:
...@@ -132,10 +149,11 @@ class CompressedTensorsConfig(QuantizationConfig): ...@@ -132,10 +149,11 @@ class CompressedTensorsConfig(QuantizationConfig):
# Confirm weight scheme is supported. # Confirm weight scheme is supported.
is_symmetric_weight = weight_quant.symmetric is_symmetric_weight = weight_quant.symmetric
is_static_weight = not weight_quant.dynamic is_static_weight = not weight_quant.dynamic
is_per_tensor_weight = ( is_per_tensor_or_channel_weight = (weight_quant.strategy in [
weight_quant.strategy == QuantizationStrategy.TENSOR) QuantizationStrategy.TENSOR, QuantizationStrategy.CHANNEL
])
if not (is_symmetric_weight and is_static_weight if not (is_symmetric_weight and is_static_weight
and is_per_tensor_weight): and is_per_tensor_or_channel_weight):
return False return False
# Dynamic quantization is always supported if weights supported. # Dynamic quantization is always supported if weights supported.
...@@ -164,11 +182,12 @@ class CompressedTensorsConfig(QuantizationConfig): ...@@ -164,11 +182,12 @@ class CompressedTensorsConfig(QuantizationConfig):
return (is_channel_group and input_quant_none and is_symmetric return (is_channel_group and input_quant_none and is_symmetric
and is_static) and is_static)
def _get_schema(self, weight_quant: BaseModel, def _get_scheme_from_parts(
input_quant: BaseModel) -> "CompressedTensorsScheme": self, weight_quant: BaseModel,
input_quant: BaseModel) -> "CompressedTensorsScheme":
# Detect If Mixed Precision
if self._is_wNa16_group_channel(weight_quant, input_quant): if self._is_wNa16_group_channel(weight_quant, input_quant):
self._check_gptq_and_marlin_can_run()
if (self.quant_format == CompressionFormat.marlin_24.value if (self.quant_format == CompressionFormat.marlin_24.value
and weight_quant.num_bits in W4A16SPARSE24_SUPPORTED_BITS): and weight_quant.num_bits in W4A16SPARSE24_SUPPORTED_BITS):
return CompressedTensorsW4A16Sparse24( return CompressedTensorsW4A16Sparse24(
...@@ -182,11 +201,12 @@ class CompressedTensorsConfig(QuantizationConfig): ...@@ -182,11 +201,12 @@ class CompressedTensorsConfig(QuantizationConfig):
strategy=weight_quant.strategy, strategy=weight_quant.strategy,
group_size=weight_quant.group_size) group_size=weight_quant.group_size)
if (self.quant_format == CompressionFormat.int_quantized.value or # Detect If Activation Quantization.
self.quant_format == CompressionFormat.float_quantized.value): if is_activation_quantization_format(self.quant_format):
if self._is_fp8_w8a8(weight_quant, input_quant): if self._is_fp8_w8a8(weight_quant, input_quant):
return CompressedTensorsW8A8Fp8( return CompressedTensorsW8A8Fp8(
input_dynamic=input_quant.dynamic) strategy=weight_quant.strategy,
is_static_input_scheme=(not input_quant.dynamic))
if self._is_static_tensor_w8a8(weight_quant, input_quant): if self._is_static_tensor_w8a8(weight_quant, input_quant):
return CompressedTensorsW8A8Int8( return CompressedTensorsW8A8Int8(
...@@ -201,26 +221,53 @@ class CompressedTensorsConfig(QuantizationConfig): ...@@ -201,26 +221,53 @@ class CompressedTensorsConfig(QuantizationConfig):
raise NotImplementedError( raise NotImplementedError(
"No compressed-tensors compatible scheme was found.") "No compressed-tensors compatible scheme was found.")
def get_scheme(self, layer: torch.nn.Module) -> "CompressedTensorsScheme": def get_scheme(
self,
layer: torch.nn.Module,
layer_name: Optional[str] = None) -> "CompressedTensorsScheme":
"""
compressed-tensors supports non-uniform quantization in the following way:
ignore: List of layer_names or nn.Module names to be ignored.
targets of config_groups: There can be N config_groups which each
have a quantization scheme. Each config_group has a list of targets
which can be a full layer_name, a regex for a layer_name, or
an nn.Module name.
layer_type_name = find_first_name_or_class_match( We first check whether a layer is in the ignore group and use
name="", CompressedTensorsUnquantized (i.e. fp16/bf16) scheme for the layer
We then detect whether a layer_name is found in any target and
use the quantization scheme corresponding to the matched target
to select the CompressedTensorsScheme used for inference.
"""
# Check if the layer is skipped for quantization.
# TODO (@robertgshaw2): support module names
if should_ignore_layer(layer_name, ignore=self.ignore):
return CompressedTensorsUnquantized()
# Find the "target" in the compressed-tensors config
# that our layer conforms to.
# TODO (@robertgshaw): add compressed-tensors as dep
# so we do not have to re-write these functions
matched_target = find_matched_target(
layer_name=layer_name,
module=layer, module=layer,
targets=self.layer_quant_details.keys(), targets=self.target_scheme_map.keys())
check_contains=True)
# Find the quant_scheme
scheme = self.target_scheme_map[matched_target]
if layer_type_name is None: return self._get_scheme_from_parts(
raise ValueError(f"Could not matching target for layer {layer}") weight_quant=scheme["weights"],
input_quant=scheme["input_activations"])
layer_quant_details: Dict[str, Any] = self.layer_quant_details.get( # Raise error if device does not support the scheme
layer_type_name, None) # (e.g. fp8 needs ada lovelace)
if layer_quant_details is None: self._check_scheme_supported(scheme.get_min_capability())
raise ValueError(
f"Could not find quantization details for {layer}.")
return self._get_schema( return scheme
weight_quant=layer_quant_details["weights"],
input_quant=layer_quant_details["input_activations"])
class CompressedTensorsLinearMethod(LinearMethodBase): class CompressedTensorsLinearMethod(LinearMethodBase):
...@@ -240,11 +287,11 @@ class CompressedTensorsLinearMethod(LinearMethodBase): ...@@ -240,11 +287,11 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
Use the CompressedTensorsScheme associated with each layer to create Use the CompressedTensorsScheme associated with each layer to create
the necessary parameters for the layer. See LinearMethodBase for param the necessary parameters for the layer. See LinearMethodBase for param
details details
""" """
weight_loader = extra_weight_attrs.get("weight_loader") weight_loader = extra_weight_attrs.get("weight_loader")
layer_name = extra_weight_attrs.get("prefix")
scheme = self.quantization_config.get_scheme(layer=layer) scheme = self.quantization_config.get_scheme(layer, layer_name)
scheme.create_weights( scheme.create_weights(
layer=layer, layer=layer,
input_size=input_size, input_size=input_size,
...@@ -271,3 +318,47 @@ class CompressedTensorsLinearMethod(LinearMethodBase): ...@@ -271,3 +318,47 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
if scheme is None: if scheme is None:
raise ValueError("A scheme must be defined for each layer") raise ValueError("A scheme must be defined for each layer")
return scheme.apply_weights(layer, x, bias=bias) return scheme.apply_weights(layer, x, bias=bias)
class CompressedTensorsKVCacheMethod(BaseKVCacheMethod):
    """
    Supports loading kv-cache scaling factors from compressed-tensors
    checkpoints.
    """

    def __init__(self, quant_config: CompressedTensorsConfig):
        # Reject unsupported kv-cache schemes before the base class is
        # initialised with the config.
        self.validate_kv_cache_scheme(quant_config.kv_cache_scheme)
        super().__init__(quant_config)

    @staticmethod
    def validate_kv_cache_scheme(kv_cache_scheme: Optional[Dict[str, Any]]):
        """
        Validator for the kv cache scheme. Useful for controlling the
        kv cache quantization schemes, that are being supported in vLLM

        A scheme of None is accepted without further checks. Otherwise the
        scheme must declare type="float", num_bits=8, strategy="tensor"
        and a truthy "symmetric" entry.

        :param kv_cache_scheme: the compressed-tensors kv cache scheme
        :raises NotImplementedError: if any of the checks above fails
        """
        if kv_cache_scheme is None:
            return

        type_ = kv_cache_scheme.get("type")
        num_bits = kv_cache_scheme.get("num_bits")

        # Both conditions are required, so the scheme is unsupported as soon
        # as EITHER of them is violated. (Using `and` here would let e.g.
        # type="int"/num_bits=8 or type="float"/num_bits=4 slip through.)
        if type_ != "float" or num_bits != 8:
            raise NotImplementedError(
                "Currently supported kv cache quantization is "
                "num_bits=8, type=float, however "
                f"received num_bits={num_bits}, type={type_}")

        strategy = kv_cache_scheme.get("strategy")
        if strategy != "tensor":
            raise NotImplementedError(
                "Only support per-tensor scaling factor "
                "for compressed-tensors KV cache. "
                f"Expected strategy: tensor, found strategy: {strategy}")

        is_symmetric = kv_cache_scheme.get("symmetric")
        if not is_symmetric:
            raise NotImplementedError(
                "Only support symmetric scaling factor "
                "for compressed-tensors KV cache. "
                f"However found symmetric: {is_symmetric}")
...@@ -12,6 +12,13 @@ class CompressedTensorsScheme(ABC): ...@@ -12,6 +12,13 @@ class CompressedTensorsScheme(ABC):
of different quantization schemes supported by CompressedTensors. of different quantization schemes supported by CompressedTensors.
""" """
@abstractmethod
def get_min_capability(self) -> int:
    """
    Get minimum device capability.

    Concrete schemes in this code base return values such as 70, 75, 80
    or 89. NOTE(review): the value appears to use the same
    ``major * 10 + minor`` encoding that the config computes for the
    current device before comparing - confirm against the caller.
    """
    raise NotImplementedError
@abstractmethod @abstractmethod
def create_weights(self, *args, **kwargs): def create_weights(self, *args, **kwargs):
""" """
......
...@@ -18,6 +18,10 @@ class CompressedTensorsUnquantized(CompressedTensorsScheme): ...@@ -18,6 +18,10 @@ class CompressedTensorsUnquantized(CompressedTensorsScheme):
in a linear transformation. in a linear transformation.
""" """
def get_min_capability(self) -> int:
    # Minimum device capability for the unquantized scheme.
    # volta and up
    return 70
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
pass pass
...@@ -29,7 +33,6 @@ class CompressedTensorsUnquantized(CompressedTensorsScheme): ...@@ -29,7 +33,6 @@ class CompressedTensorsUnquantized(CompressedTensorsScheme):
weight = Parameter(torch.empty(sum(output_partition_sizes), weight = Parameter(torch.empty(sum(output_partition_sizes),
input_size_per_partition, input_size_per_partition,
device="cuda",
dtype=params_dtype), dtype=params_dtype),
requires_grad=False) requires_grad=False)
......
...@@ -29,6 +29,10 @@ class CompressedTensorsW4A16Sparse24(CompressedTensorsScheme): ...@@ -29,6 +29,10 @@ class CompressedTensorsW4A16Sparse24(CompressedTensorsScheme):
raise ValueError( raise ValueError(
"group_size must be given when using strategy group") "group_size must be given when using strategy group")
def get_min_capability(self) -> int:
    # Minimum device capability for the W4A16 2:4-sparse scheme.
    # ampere + up
    return 80
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
pass pass
......
from typing import Callable, List, Optional from typing import Callable, List, Optional
import torch import torch
from torch.nn import Parameter
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme) CompressedTensorsScheme)
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
QuantizationStrategy)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
apply_fp8_linear, create_per_tensor_scale_param, cutlass_fp8_supported, apply_fp8_linear, create_per_channel_scale_param,
create_per_tensor_scale_param, cutlass_fp8_supported,
requantize_with_max_scale) requantize_with_max_scale)
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
...@@ -14,39 +18,49 @@ __all__ = ["CompressedTensorsW8A8Fp8"] ...@@ -14,39 +18,49 @@ __all__ = ["CompressedTensorsW8A8Fp8"]
class CompressedTensorsW8A8Fp8(CompressedTensorsScheme): class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
def __init__(self, input_dynamic: bool): def __init__(self, strategy: str, is_static_input_scheme: bool):
self.input_dynamic = input_dynamic self.strategy = strategy
self.is_static_input_scheme = is_static_input_scheme
self.cutlass_fp8_supported = cutlass_fp8_supported() self.cutlass_fp8_supported = cutlass_fp8_supported()
# W8A8-Fp8 kernels support only per-tensor and per-channel cases. def get_min_capability(self) -> int:
# So if we have a fused module (QKV, MLP) with per tensor scales (thus N # lovelace and up
# scales being passed to the kernel), we requantize with a single scale. return 89
def process_weights_after_loading(self, layer) -> None: def process_weights_after_loading(self, layer) -> None:
# Dequant -> Quant with max scale. # If per tensor, when we have a fused module (e.g. QKV) with per
max_w_scale, weight = requantize_with_max_scale( # tensor scales (thus N scales being passed to the kernel),
weight=layer.weight, # requantize so we can always run per tensor
weight_scale=layer.weight_scale, if self.strategy == QuantizationStrategy.TENSOR:
logical_widths=layer.logical_widths, max_w_scale, weight = requantize_with_max_scale(
) weight=layer.weight,
weight_scale=layer.weight_scale,
# Update layer with new values. logical_widths=layer.logical_widths,
layer.weight = torch.nn.Parameter(weight.t(), requires_grad=False) )
layer.weight_scale = torch.nn.Parameter(max_w_scale,
requires_grad=False) layer.weight = Parameter(weight.t(), requires_grad=False)
if self.input_dynamic: layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
layer.input_scale = None
# If channelwise, scales are already lined up, so just transpose.
elif self.strategy == QuantizationStrategy.CHANNEL:
weight = layer.weight
layer.weight = Parameter(weight.t(), requires_grad=False)
else:
raise ValueError(f"Unknown quantization strategy {self.strategy}")
# INPUT SCALE
if self.is_static_input_scheme:
layer.input_scale = Parameter(layer.input_scale.max(),
requires_grad=False)
else: else:
layer.input_scale = torch.nn.Parameter(layer.input_scale.max(), layer.input_scale = None
requires_grad=False)
def create_weights(self, layer: torch.nn.Module, def create_weights(self, layer: torch.nn.Module,
output_partition_sizes: List[int], output_partition_sizes: List[int],
input_size_per_partition: int, input_size_per_partition: int,
params_dtype: torch.dtype, weight_loader: Callable, params_dtype: torch.dtype, weight_loader: Callable,
**kwargs): **kwargs):
del params_dtype
output_size_per_partition = sum(output_partition_sizes) output_size_per_partition = sum(output_partition_sizes)
layer.logical_widths = output_partition_sizes layer.logical_widths = output_partition_sizes
...@@ -63,12 +77,17 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme): ...@@ -63,12 +77,17 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
}) })
# WEIGHT SCALE # WEIGHT SCALE
weight_scale = create_per_tensor_scale_param( if self.strategy == QuantizationStrategy.CHANNEL:
output_partition_sizes, weight_loader=weight_loader) weight_scale = create_per_channel_scale_param(
output_partition_sizes, weight_loader=weight_loader)
else:
assert self.strategy == QuantizationStrategy.TENSOR
weight_scale = create_per_tensor_scale_param(
output_partition_sizes, weight_loader=weight_loader)
layer.register_parameter("weight_scale", weight_scale) layer.register_parameter("weight_scale", weight_scale)
# INPUT SCALE # INPUT SCALE
if not self.input_dynamic: if self.is_static_input_scheme:
input_scale = create_per_tensor_scale_param( input_scale = create_per_tensor_scale_param(
output_partition_sizes, weight_loader=weight_loader) output_partition_sizes, weight_loader=weight_loader)
layer.register_parameter("input_scale", input_scale) layer.register_parameter("input_scale", input_scale)
...@@ -84,4 +103,5 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme): ...@@ -84,4 +103,5 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
weight_scale=layer.weight_scale, weight_scale=layer.weight_scale,
input_scale=layer.input_scale, input_scale=layer.input_scale,
bias=bias, bias=bias,
cutlass_fp8_supported=self.cutlass_fp8_supported) cutlass_fp8_supported=self.cutlass_fp8_supported,
use_per_token_if_dynamic=True)
...@@ -19,6 +19,10 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme): ...@@ -19,6 +19,10 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
self.strategy = strategy self.strategy = strategy
self.is_static_input_scheme = is_static_input_scheme self.is_static_input_scheme = is_static_input_scheme
def get_min_capability(self) -> int:
    # Minimum device capability for the W8A8 int8 scheme.
    # turing and up
    return 75
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# WEIGHT # WEIGHT
# Cutlass kernels need transposed weight. # Cutlass kernels need transposed weight.
......
...@@ -7,8 +7,8 @@ from vllm import _custom_ops as ops ...@@ -7,8 +7,8 @@ from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme) CompressedTensorsScheme)
from vllm.model_executor.layers.quantization.utils.marlin_utils import ( from vllm.model_executor.layers.quantization.utils.marlin_utils import (
apply_marlin_linear, marlin_make_empty_g_idx, marlin_make_workspace, apply_gptq_marlin_linear, marlin_make_empty_g_idx, marlin_make_workspace,
marlin_permute_scales, replace_tensor, verify_marlin_supported, marlin_permute_scales, replace_tensor, verify_gptq_marlin_supported,
verify_marlin_supports_shape) verify_marlin_supports_shape)
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
...@@ -38,9 +38,13 @@ class CompressedTensorsWNA16(CompressedTensorsScheme): ...@@ -38,9 +38,13 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
self.group_size = group_size self.group_size = group_size
# Verify supported on platform. # Verify supported on platform.
verify_marlin_supported(num_bits=self.num_bits, verify_gptq_marlin_supported(num_bits=self.num_bits,
group_size=self.group_size, group_size=self.group_size,
is_sym=True) is_sym=True)
def get_min_capability(self) -> int:
    # Minimum device capability for the WNA16 scheme.
    # ampere and up
    return 80
def create_weights(self, layer: torch.nn.Module, input_size: int, def create_weights(self, layer: torch.nn.Module, input_size: int,
output_partition_sizes: List[int], output_partition_sizes: List[int],
...@@ -131,6 +135,9 @@ class CompressedTensorsWNA16(CompressedTensorsScheme): ...@@ -131,6 +135,9 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
layer.g_idx = marlin_make_empty_g_idx(device) layer.g_idx = marlin_make_empty_g_idx(device)
layer.g_idx_sort_indices = marlin_make_empty_g_idx(device) layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)
# No zero-point
layer.weight_zp = marlin_make_empty_g_idx(device)
# Repack weights from compressed-tensors format to marlin format. # Repack weights from compressed-tensors format to marlin format.
marlin_qweight = ops.gptq_marlin_repack( marlin_qweight = ops.gptq_marlin_repack(
layer.weight_packed.t().contiguous(), layer.weight_packed.t().contiguous(),
...@@ -151,10 +158,11 @@ class CompressedTensorsWNA16(CompressedTensorsScheme): ...@@ -151,10 +158,11 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor, def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor: bias: Optional[torch.Tensor]) -> torch.Tensor:
return apply_marlin_linear( return apply_gptq_marlin_linear(
input=x, input=x,
weight=layer.weight_packed, weight=layer.weight_packed,
weight_scale=layer.weight_scale, weight_scale=layer.weight_scale,
weight_zp=layer.weight_zp,
g_idx=layer.g_idx, g_idx=layer.g_idx,
g_idx_sort_indices=layer.g_idx_sort_indices, g_idx_sort_indices=layer.g_idx_sort_indices,
workspace=layer.workspace, workspace=layer.workspace,
......
...@@ -9,6 +9,7 @@ from torch.nn import Module ...@@ -9,6 +9,7 @@ from torch.nn import Module
class CompressionFormat(Enum): class CompressionFormat(Enum):
dense = "dense" dense = "dense"
sparse_bitmask = "sparse-bitmask" sparse_bitmask = "sparse-bitmask"
naive_quantized = "naive-quantized"
float_quantized = "float-quantized" float_quantized = "float-quantized"
int_quantized = "int-quantized" int_quantized = "int-quantized"
pack_quantized = "pack-quantized" pack_quantized = "pack-quantized"
...@@ -76,25 +77,115 @@ class QuantizationArgs(BaseModel): ...@@ -76,25 +77,115 @@ class QuantizationArgs(BaseModel):
) )
def find_first_name_or_class_match( def is_activation_quantization_format(format: str) -> bool:
name: str, _ACTIVATION_QUANTIZATION_FORMATS = [
module: Module, CompressionFormat.naive_quantized.value,
targets: Iterable[str], CompressionFormat.int_quantized.value,
check_contains: bool = False) -> Optional[str]: CompressionFormat.float_quantized.value
]
return format in _ACTIVATION_QUANTIZATION_FORMATS
# Maps the name of a fused projection to the names of the separate shards
# it is built from; should_ignore_layer uses it to expand a fused layer
# name into its per-shard names.
# fused_name: List[shard_name]
_FUSED_LAYER_NAME_MAPPING = {
    "qkv_proj": ["q_proj", "k_proj", "v_proj"],
    "gate_up_proj": ["gate_proj", "up_proj"]
}
def should_ignore_layer(layer_name: Optional[str],
                        ignore: Iterable[str]) -> bool:
    """
    Decide whether a layer is listed in the ``ignore`` targets.

    Fused layers (see ``_FUSED_LAYER_NAME_MAPPING``) are expanded into their
    shard names first, because the ignore list is matched against the
    unfused names; all shards of a fused layer must agree.

    :param layer_name: full dotted layer name, e.g.
        ``model.layers.0.self_attn.qkv_proj``; ``None`` is never ignored
    :param ignore: exact names or ``re:``-prefixed patterns to ignore;
        ``None`` is treated as an empty list
    :return: True if the layer (or every shard of a fused layer) matches
    :raises ValueError: if only some shards of a fused layer match
    """
    if layer_name is None:
        return False

    # The targets are scanned once per shard, so materialise them up front:
    # a one-shot iterable would otherwise be exhausted after the first
    # shard, and a missing ignore list (None) would raise TypeError.
    ignore = [] if ignore is None else list(ignore)

    # layer_name = model.layers.0.self_attn.qkv_proj
    # proj_name = qkv_proj
    proj_name = layer_name.split(".")[-1]

    # Unfused layers like down_proj and o_proj will match
    # the safetensors checkpoint already.
    if proj_name not in _FUSED_LAYER_NAME_MAPPING:
        return check_equal_or_regex_match(layer_name=layer_name,
                                          targets=ignore)

    # Fused layers like gate_up_proj or qkv_proj will not be fused
    # in the safetensors checkpoint. So, we convert the name
    # from the fused version to unfused + check to make sure that
    # each shard of the fused layer has the same scheme.
    shard_proj_names = _FUSED_LAYER_NAME_MAPPING[proj_name]

    # Swap only the trailing component. str.replace would also rewrite any
    # earlier occurrence of the projection name inside the dotted path.
    parent_prefix = layer_name[:len(layer_name) - len(proj_name)]
    shard_names = [
        parent_prefix + shard_proj_name
        for shard_proj_name in shard_proj_names
    ]

    # Layer should be ignored if shards are ignored.
    shard_decisions = [
        check_equal_or_regex_match(layer_name=shard_name, targets=ignore)
        for shard_name in shard_names
    ]
    first_decision = shard_decisions[0]
    if any(decision != first_decision for decision in shard_decisions[1:]):
        raise ValueError(f"Found a different quantization schemes for "
                         f"{shard_proj_names} in {layer_name}. vLLM "
                         "requires all to use the same scheme.")

    return first_decision
def check_equal_or_regex_match(layer_name: str,
targets: Iterable[str]) -> bool:
""" """
Helper function to map the quantization details listed in the config Checks whether a layer_name is exactly equal or a regex match for
for a given list of targets against each model layer. First uses the if target starts with 're:' to any target in list.
layer name to try and find a match. If no name match is found, uses """
the layer class name. Returns None otherwise. for target in targets:
if _is_equal_or_regex_match(layer_name, target):
return True
return False
def find_matched_target(layer_name: Optional[str], module: Module,
targets: Iterable[str]) -> str:
"""
Helper function to look up which "target" in the compressed-tensors
config that a layer corresponds to.
Recall that a compressed-tensors config has a concept of
config_groups, where each layer can be quantized with a different
scheme.
:param name: layer name targets in each config_group will be a list of either layer names
(or regexes corresponding to layer names) or names of torch Modules.
First, we try to match the layer_name with a target
Second, we try to match the module's name with a target
:param layer_name: layer name
:param module: torch.nn.Module :param module: torch.nn.Module
:param targets: list of targets to match the layer against :param targets: list of targets to match the layer against
:param check_contains: whether or not to do a substring match
""" """
return _find_first_match(name, targets) or _find_first_match( if layer_name is None:
module.__class__.__name__, targets, check_contains) layer_name = ""
matched_target = (_find_first_match(layer_name, targets)
or _find_first_match(module.__class__.__name__, targets,
True))
if matched_target is None:
raise ValueError(f"Unable to find matching target for {module} in the "
"compressed-tensors config.")
return matched_target
def _find_first_match(value: str, def _find_first_match(value: str,
...@@ -111,13 +202,46 @@ def _find_first_match(value: str, ...@@ -111,13 +202,46 @@ def _find_first_match(value: str,
""" """
for target in targets: for target in targets:
if target.startswith("re:"): if _is_equal_or_regex_match(value,
pattern = target[3:] target,
if re.match(pattern, value): check_contains=check_contains):
return target
elif check_contains:
if target.lower() in value.lower():
return target
elif target == value:
return target return target
return None return None
def get_compressed_tensors_cache_scale(name: str) -> Optional[str]:
    """
    Check whether the param name matches the format for k/v cache scales
    in compressed-tensors. If this is the case, return its equivalent
    param name expected by vLLM

    :param name: param name
    :return: matching param name for KV cache scale in vLLM, or None if
        the name is not a k/v cache scale
    """
    # (checkpoint suffix, vLLM suffix) pairs. Matching on the complete
    # suffix guarantees that a returned name has actually been remapped:
    # a name that merely contains ".k_proj" somewhere and happens to end
    # with ".output_scale" is not a cache scale and yields None.
    suffix_remapping = (
        (".k_proj.output_scale", ".attn.k_scale"),
        (".v_proj.output_scale", ".attn.v_scale"),
    )
    for checkpoint_suffix, vllm_suffix in suffix_remapping:
        if name.endswith(checkpoint_suffix):
            # Rewrite only the trailing suffix, never an earlier occurrence.
            return name[:-len(checkpoint_suffix)] + vllm_suffix

    # If no matches, return None
    return None
def _is_equal_or_regex_match(value: str,
target: str,
check_contains: bool = False) -> bool:
"""
Checks whether a value is exactly equal or a regex match for target
if target starts with 're:'. If check_contains is set to True,
additionally checks if the target string is contained within the value.
"""
if target.startswith("re:"):
pattern = target[3:]
if re.match(pattern, value):
return True
elif check_contains:
if target.lower() in value.lower():
return True
elif target == value:
return True
return False
...@@ -69,9 +69,8 @@ class DeepSpeedFPConfig(QuantizationConfig): ...@@ -69,9 +69,8 @@ class DeepSpeedFPConfig(QuantizationConfig):
"quantize_config.json", "quantize_config.json",
] ]
def get_quant_method( def get_quant_method(self, layer: torch.nn.Module,
self, prefix: str) -> Optional["DeepSpeedFPLinearMethod"]:
layer: torch.nn.Module) -> Optional["DeepSpeedFPLinearMethod"]:
if isinstance(layer, LinearBase): if isinstance(layer, LinearBase):
return DeepSpeedFPLinearMethod(self) return DeepSpeedFPLinearMethod(self)
return None return None
......
from typing import Any, Dict, List, Optional
import torch
from torch.nn import Module
from torch.nn.parameter import Parameter
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
apply_fp8_linear, create_per_channel_scale_param)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
logger = init_logger(__name__)
# Maps the name of a fused projection layer to the names of the individual
# shard projections it is built from (fused_name: List[shard_name]).
# FBGEMMFp8Config._is_layer_skipped uses it to expand a fused layer prefix
# into per-shard prefixes before checking them against the ignore list.
# Note: this is a hack. We should update each model to register the
# stacked params and get it from there instead in a future PR.
_FUSED_LAYER_NAME_MAPPING = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
"gate_up_proj": ["gate_proj", "up_proj"]
}
class FBGEMMFp8Config(QuantizationConfig):
    """Config class for FBGEMM Fp8."""

    def __init__(self, ignore_list: List[str], input_scale_ub: float):
        """
        :param ignore_list: layer prefixes that must stay unquantized; a
            falsy value (None / empty) is stored as an empty list
        :param input_scale_ub: upper bound for the activation scale
        """
        self.ignore_list = ignore_list if ignore_list else []
        self.input_scale_ub = input_scale_ub

        # For GPUs that lack FP8 hardware support, we can leverage the Marlin
        # kernel for fast weight-only FP8 quantization
        capability = current_platform.get_device_capability()
        capability = capability[0] * 10 + capability[1]
        self.use_marlin = capability < 89

    @classmethod
    def get_name(cls) -> str:
        return "fbgemm_fp8"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.bfloat16, torch.float16]

    @classmethod
    def get_min_capability(cls) -> int:
        return 80

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        return []

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "FBGEMMFp8Config":
        """Build the config from the "modules_to_not_convert" and
        "activation_scale_ub" entries of a quantization config dict."""
        ignore_list = cls.get_from_keys(config, ["modules_to_not_convert"])
        input_scale_ub = cls.get_from_keys(config, ["activation_scale_ub"])
        return cls(ignore_list=ignore_list, input_scale_ub=input_scale_ub)

    def _is_layer_skipped(self, prefix: str) -> bool:
        """
        Tell whether the layer at `prefix` is listed in the ignore list.

        For fused layers (see _FUSED_LAYER_NAME_MAPPING) every shard is
        looked up individually and all shards must agree.

        :param prefix: dotted layer name
        :raises ValueError: if only some shards of a fused layer are in
            the ignore list
        """
        # prefix: model.layers.0.self_attn.q_proj
        # proj_name: q_proj
        proj_name = prefix.split(".")[-1]
        if proj_name in _FUSED_LAYER_NAME_MAPPING:
            # Swap only the trailing component. str.replace would also
            # rewrite an earlier path component that happens to contain
            # the projection name.
            parent_prefix = prefix[:len(prefix) - len(proj_name)]
            shard_prefixes = [
                parent_prefix + shard_proj_name
                for shard_proj_name in _FUSED_LAYER_NAME_MAPPING[proj_name]
            ]

            is_skipped = None
            for shard_prefix in shard_prefixes:
                is_shard_skipped = shard_prefix in self.ignore_list

                if is_skipped is None:
                    is_skipped = is_shard_skipped
                elif is_shard_skipped != is_skipped:
                    raise ValueError(
                        f"Detected some but not all shards of {prefix} "
                        "are quantized. All shards of fused layers "
                        "to have the same precision.")
        else:
            is_skipped = prefix in self.ignore_list

        assert is_skipped is not None
        return is_skipped

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["QuantizeMethodBase"]:
        """Return the quantization method for `layer`: unquantized for
        ignored linear layers, FBGEMM fp8 for other linear layers, and
        None for anything that is not a LinearBase."""
        if isinstance(layer, LinearBase):
            if self._is_layer_skipped(prefix):
                return UnquantizedLinearMethod()
            return FBGEMMFp8LinearMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        return []
class FBGEMMFp8LinearMethod(LinearMethodBase):
    """Linear method for FBGEMM fp8 checkpoints: fp8 (e4m3) weights with a
    per-channel weight scale and an upper bound on the input scale."""

    def __init__(self, quant_config: FBGEMMFp8Config):
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register `weight`, `weight_scale` and `input_scale_ub` on `layer`
        and record the partition sizes and original dtype on it."""
        # The full (unpartitioned) sizes are not needed here.
        del input_size, output_size
        total_output_size = sum(output_partition_sizes)

        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = total_output_size
        layer.orig_dtype = params_dtype

        # WEIGHT: allocated as (output, input) in float8_e4m3fn.
        fp8_weight = Parameter(torch.empty(total_output_size,
                                           input_size_per_partition,
                                           dtype=torch.float8_e4m3fn),
                               requires_grad=False)
        layer.register_parameter("weight", fp8_weight)
        weight_attrs = {
            "input_dim": 1,
            "output_dim": 0,
            **extra_weight_attrs,
        }
        set_weight_attrs(fp8_weight, weight_attrs)

        # WEIGHT SCALE
        per_channel_scale = create_per_channel_scale_param(
            output_partition_sizes, **extra_weight_attrs)
        layer.register_parameter("weight_scale", per_channel_scale)

        # INPUT SCALE UPPER BOUND
        upper_bound = torch.tensor((self.quant_config.input_scale_ub),
                                   dtype=torch.float32)
        layer.input_scale_ub = torch.nn.Parameter(upper_bound,
                                                  requires_grad=False)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Transpose the loaded weight and, on the Marlin path, repack the
        layer for the Marlin kernel."""
        loaded_weight = layer.weight
        layer.weight = Parameter(loaded_weight.t(), requires_grad=False)

        if not self.quant_config.use_marlin:
            return

        prepare_fp8_layer_for_marlin(layer)
        # Activations not quantized for marlin.
        del layer.input_scale_ub

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run the linear layer on `x`, through the Marlin fp8 kernel when
        the config selects it and through `apply_fp8_linear` otherwise."""
        if self.quant_config.use_marlin:
            result = apply_fp8_marlin_linear(
                input=x,
                weight=layer.weight,
                weight_scale=layer.weight_scale,
                workspace=layer.workspace,
                size_n=layer.output_size_per_partition,
                size_k=layer.input_size_per_partition,
                bias=bias)
        else:
            result = apply_fp8_linear(input=x,
                                      weight=layer.weight,
                                      weight_scale=layer.weight_scale,
                                      input_scale=None,
                                      input_scale_ub=layer.input_scale_ub,
                                      bias=bias,
                                      cutlass_fp8_supported=True,
                                      use_per_token_if_dynamic=True)
        return result
...@@ -11,6 +11,7 @@ from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase, ...@@ -11,6 +11,7 @@ from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase) QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin) apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
...@@ -66,8 +67,8 @@ class Fp8Config(QuantizationConfig): ...@@ -66,8 +67,8 @@ class Fp8Config(QuantizationConfig):
return cls(is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized, return cls(is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
activation_scheme=activation_scheme) activation_scheme=activation_scheme)
def get_quant_method( def get_quant_method(self, layer: torch.nn.Module,
self, layer: torch.nn.Module) -> Optional["QuantizeMethodBase"]: prefix: str) -> Optional["QuantizeMethodBase"]:
from vllm.attention.layer import Attention # Avoid circular import from vllm.attention.layer import Attention # Avoid circular import
if isinstance(layer, LinearBase): if isinstance(layer, LinearBase):
...@@ -214,7 +215,8 @@ class Fp8LinearMethod(LinearMethodBase): ...@@ -214,7 +215,8 @@ class Fp8LinearMethod(LinearMethodBase):
weight_scale=layer.weight_scale, weight_scale=layer.weight_scale,
input_scale=layer.input_scale, input_scale=layer.input_scale,
bias=bias, bias=bias,
cutlass_fp8_supported=self.cutlass_fp8_supported) cutlass_fp8_supported=self.cutlass_fp8_supported,
use_per_token_if_dynamic=False)
class Fp8MoEMethod(FusedMoEMethodBase): class Fp8MoEMethod(FusedMoEMethodBase):
...@@ -399,39 +401,10 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -399,39 +401,10 @@ class Fp8MoEMethod(FusedMoEMethodBase):
topk_group=topk_group) topk_group=topk_group)
class Fp8KVCacheMethod(QuantizeMethodBase): class Fp8KVCacheMethod(BaseKVCacheMethod):
"""Supports loading kv-cache scaling factors from FP8 checkpoints. """
Supports loading kv-cache scaling factors from FP8 checkpoints.
""" """
def __init__(self, quant_config: Fp8Config): def __init__(self, quant_config: Fp8Config):
self.quant_config = quant_config super().__init__(quant_config)
def create_weights(self, layer: torch.nn.Module):
"""Create "weight" (aka kv_scale) for an attention layer.
Args:
layer: The layer that is using the QuantizeMethodBase factory.
"""
# Initialize the KV cache scale to 1.0 as the default value.
# If the kv_scale appears in the checkpoint, it will be
# overwritten when loading weights.
layer.kv_scale = Parameter(torch.tensor(1.0), requires_grad=False)
def apply(self, layer: torch.nn.Module) -> torch.Tensor:
raise RuntimeError("Fp8KVCacheMethod.apply should not be called.")
def process_weights_after_loading(self, layer: Module) -> None:
# If the kv-cache dtype is auto, we enforce the kv-scale to be 1.0
# regardless whether the kv-scale is available in the checkpoint.
if layer.kv_cache_dtype != "auto":
kv_scale = layer.kv_scale.to("cpu").tolist()
if not isinstance(kv_scale, float):
raise ValueError("Only support per-tensor scaling factor "
"for fp8 KV cache")
layer._kv_scale = kv_scale
if layer._kv_scale == 1.0 and "e5m2" not in layer.kv_cache_dtype:
print_warning_once(
"Using KV cache scaling factor 1.0 for fp8_e4m3. This may "
"cause accuracy issues. Please make sure kv-cache scaling "
"factor is available in the fp8 checkpoint.")
del layer.kv_scale
...@@ -69,8 +69,8 @@ class GPTQConfig(QuantizationConfig): ...@@ -69,8 +69,8 @@ class GPTQConfig(QuantizationConfig):
default=False) default=False)
return cls(weight_bits, group_size, desc_act, lm_head_quantized) return cls(weight_bits, group_size, desc_act, lm_head_quantized)
def get_quant_method( def get_quant_method(self, layer: torch.nn.Module,
self, layer: torch.nn.Module) -> Optional["GPTQLinearMethod"]: prefix: str) -> Optional["GPTQLinearMethod"]:
if (isinstance(layer, LinearBase) or if (isinstance(layer, LinearBase) or
(isinstance(layer, ParallelLMHead) and self.lm_head_quantized)): (isinstance(layer, ParallelLMHead) and self.lm_head_quantized)):
return GPTQLinearMethod(self) return GPTQLinearMethod(self)
......
...@@ -10,10 +10,10 @@ from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, ...@@ -10,10 +10,10 @@ from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
from vllm.model_executor.layers.quantization.utils.marlin_utils import ( from vllm.model_executor.layers.quantization.utils.marlin_utils import (
apply_marlin_linear, check_marlin_supported, marlin_is_k_full, apply_gptq_marlin_linear, check_gptq_marlin_supported, marlin_is_k_full,
marlin_make_empty_g_idx, marlin_make_workspace, marlin_permute_scales, marlin_make_empty_g_idx, marlin_make_workspace, marlin_permute_scales,
marlin_repeat_scales_on_all_ranks, marlin_sort_g_idx, replace_tensor, marlin_repeat_scales_on_all_ranks, marlin_sort_g_idx, replace_tensor,
verify_marlin_supported, verify_marlin_supports_shape) verify_gptq_marlin_supported, verify_marlin_supports_shape)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -37,9 +37,9 @@ class GPTQMarlinConfig(QuantizationConfig): ...@@ -37,9 +37,9 @@ class GPTQMarlinConfig(QuantizationConfig):
self.lm_head_quantized = lm_head_quantized self.lm_head_quantized = lm_head_quantized
# Verify supported on platform. # Verify supported on platform.
verify_marlin_supported(num_bits=self.weight_bits, verify_gptq_marlin_supported(num_bits=self.weight_bits,
group_size=self.group_size, group_size=self.group_size,
is_sym=self.is_sym) is_sym=self.is_sym)
def __repr__(self) -> str: def __repr__(self) -> str:
return (f"GPTQMarlinConfig(weight_bits={self.weight_bits}, " return (f"GPTQMarlinConfig(weight_bits={self.weight_bits}, "
...@@ -77,7 +77,7 @@ class GPTQMarlinConfig(QuantizationConfig): ...@@ -77,7 +77,7 @@ class GPTQMarlinConfig(QuantizationConfig):
@classmethod @classmethod
def override_quantization_method(cls, hf_quant_cfg, def override_quantization_method(cls, hf_quant_cfg,
user_quant) -> Optional[str]: user_quant) -> Optional[str]:
can_convert = cls.is_marlin_compatible(hf_quant_cfg) can_convert = cls.is_gptq_marlin_compatible(hf_quant_cfg)
is_valid_user_quant = (user_quant is None or user_quant == "marlin") is_valid_user_quant = (user_quant is None or user_quant == "marlin")
...@@ -94,9 +94,8 @@ class GPTQMarlinConfig(QuantizationConfig): ...@@ -94,9 +94,8 @@ class GPTQMarlinConfig(QuantizationConfig):
" faster inference") " faster inference")
return None return None
def get_quant_method( def get_quant_method(self, layer: torch.nn.Module,
self, prefix: str) -> Optional["GPTQMarlinLinearMethod"]:
layer: torch.nn.Module) -> Optional["GPTQMarlinLinearMethod"]:
if (isinstance(layer, LinearBase) or if (isinstance(layer, LinearBase) or
(isinstance(layer, ParallelLMHead) and self.lm_head_quantized)): (isinstance(layer, ParallelLMHead) and self.lm_head_quantized)):
return GPTQMarlinLinearMethod(self) return GPTQMarlinLinearMethod(self)
...@@ -106,22 +105,27 @@ class GPTQMarlinConfig(QuantizationConfig): ...@@ -106,22 +105,27 @@ class GPTQMarlinConfig(QuantizationConfig):
return [] return []
@classmethod @classmethod
def is_marlin_compatible(cls, quant_config: Dict[str, Any]): def is_gptq_marlin_compatible(cls, quant_config: Dict[str, Any]):
# Extract data from quant config. # Extract data from quant config.
quant_method = quant_config.get("quant_method", "").lower()
num_bits = quant_config.get("bits", None) num_bits = quant_config.get("bits", None)
group_size = quant_config.get("group_size", None) group_size = quant_config.get("group_size", None)
sym = quant_config.get("sym", None) sym = quant_config.get("sym", None)
desc_act = quant_config.get("desc_act", None) desc_act = quant_config.get("desc_act", None)
if quant_method != "gptq":
return False
# If we cannot find the info needed in the config, cannot convert. # If we cannot find the info needed in the config, cannot convert.
if (num_bits is None or group_size is None or sym is None if (num_bits is None or group_size is None or sym is None
or desc_act is None): or desc_act is None):
return False return False
return check_marlin_supported(num_bits=num_bits, return check_gptq_marlin_supported(
group_size=group_size, num_bits=num_bits,
is_sym=sym, group_size=group_size,
min_capability=cls.get_min_capability()) is_sym=sym,
min_capability=cls.get_min_capability())
class GPTQMarlinLinearMethod(LinearMethodBase): class GPTQMarlinLinearMethod(LinearMethodBase):
...@@ -279,6 +283,9 @@ class GPTQMarlinLinearMethod(LinearMethodBase): ...@@ -279,6 +283,9 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
layer.g_idx = marlin_make_empty_g_idx(device) layer.g_idx = marlin_make_empty_g_idx(device)
layer.g_idx_sort_indices = marlin_make_empty_g_idx(device) layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)
# No zero-point
layer.zp = marlin_make_empty_g_idx(device)
# Repack weights from autogptq format to marlin format. # Repack weights from autogptq format to marlin format.
marlin_qweight = ops.gptq_marlin_repack( marlin_qweight = ops.gptq_marlin_repack(
layer.qweight, layer.qweight,
...@@ -303,10 +310,11 @@ class GPTQMarlinLinearMethod(LinearMethodBase): ...@@ -303,10 +310,11 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
x: torch.Tensor, x: torch.Tensor,
bias: Optional[torch.Tensor] = None, bias: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
return apply_marlin_linear( return apply_gptq_marlin_linear(
input=x, input=x,
weight=layer.qweight, weight=layer.qweight,
weight_scale=layer.scales, weight_scale=layer.scales,
weight_zp=layer.zp,
g_idx=layer.g_idx, g_idx=layer.g_idx,
g_idx_sort_indices=layer.g_idx_sort_indices, g_idx_sort_indices=layer.g_idx_sort_indices,
workspace=layer.workspace, workspace=layer.workspace,
......
...@@ -109,9 +109,8 @@ class GPTQMarlin24Config(QuantizationConfig): ...@@ -109,9 +109,8 @@ class GPTQMarlin24Config(QuantizationConfig):
return None return None
def get_quant_method( def get_quant_method(self, layer: torch.nn.Module,
self, prefix: str) -> Optional["GPTQMarlin24LinearMethod"]:
layer: torch.nn.Module) -> Optional["GPTQMarlin24LinearMethod"]:
if isinstance(layer, LinearBase): if isinstance(layer, LinearBase):
return GPTQMarlin24LinearMethod(self) return GPTQMarlin24LinearMethod(self)
return None return None
......
import torch
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.utils import print_warning_once
class BaseKVCacheMethod(QuantizeMethodBase):
    """
    Quant method that adds `_k_scale` and `_v_scale` attributes to the
    Attention layer to support loading those scaling factors from checkpoints.
    The k/v_scale will be used to:
        - quantize k/v_cache entries before saving them to the cache
        - dequantize k/v_cache entries before fetching them from the cache

    :param quant_config: the appropriate QuantizationConfig
    """

    def __init__(self, quant_config: QuantizationConfig):
        self.quant_config = quant_config

    def create_weights(self, layer: torch.nn.Module):
        """
        Create "weight" (aka k_scale and v_scale) for an attention layer.
        """
        # Initialize the KV cache scales to -1.0, which is an invalid value.
        # If the k/v_scale appears in the checkpoint, it will be
        # overwritten when loading weights.
        layer.k_scale = torch.nn.Parameter(torch.tensor(-1.0),
                                           requires_grad=False)
        layer.v_scale = torch.nn.Parameter(torch.tensor(-1.0),
                                           requires_grad=False)

    def apply(self, layer: torch.nn.Module) -> torch.Tensor:
        raise RuntimeError(
            f"{self.__class__.__name__}.apply should not be called.")

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """
        Resolve the loaded `k_scale`/`v_scale` parameters into plain float
        `_k_scale`/`_v_scale` attributes, then delete the parameters.

        :raises ValueError: if a resolved scale is not a single float
            (i.e. the checkpoint scale is not per-tensor)
        """
        # When kv_cache_dtype is "auto" the loaded scales are ignored and
        # `_k_scale`/`_v_scale` are not written here.
        # NOTE(review): presumably the Attention layer keeps its own 1.0
        # default in that case -- confirm against the layer definition.
        if layer.kv_cache_dtype != "auto":
            if layer.k_scale > 0.0 and layer.v_scale > 0.0:
                # We prefer to use separate k_scale and v_scale if present
                k_scale = layer.k_scale.to("cpu").tolist()
                v_scale = layer.v_scale.to("cpu").tolist()
            elif layer.k_scale < 0.0 and layer.v_scale < 0.0:
                # If no scales were loaded (both scales are invalid negative
                # values), use the default value of 1.0. These must be
                # plain floats: a Parameter here would fail the float
                # check below and make scale-less checkpoints unloadable.
                k_scale = 1.0
                v_scale = 1.0
            else:
                # If we find a single kv_scale in the checkpoint, we remap
                # kv_scale to k_scale during weight loading, and duplicate
                # k_scale to v_scale here
                assert layer.k_scale > 0.0
                scale_to_duplicate = max(layer.k_scale, layer.v_scale)
                k_scale = scale_to_duplicate.to("cpu").tolist()
                v_scale = scale_to_duplicate.to("cpu").tolist()

            # tolist() on a multi-element tensor yields a list, not a float.
            if not isinstance(k_scale, float) or not isinstance(
                    v_scale, float):
                raise ValueError("Only support per-tensor scaling factor "
                                 "for fp8 KV cache")

            # These are used in the final Attention.forward()
            layer._k_scale = k_scale
            layer._v_scale = v_scale
            if (layer._k_scale == 1.0 and layer._v_scale == 1.0
                    and "e5m2" not in layer.kv_cache_dtype):
                print_warning_once(
                    "Using KV cache scaling factor 1.0 for fp8_e4m3. This "
                    "may cause accuracy issues. Please make sure k/v_scale "
                    "scaling factors are available in the fp8 checkpoint.")

        # The temporary parameters are no longer needed once resolved.
        del layer.k_scale
        del layer.v_scale
...@@ -100,8 +100,8 @@ class MarlinConfig(QuantizationConfig): ...@@ -100,8 +100,8 @@ class MarlinConfig(QuantizationConfig):
return None return None
def get_quant_method( def get_quant_method(self, layer: torch.nn.Module,
self, layer: torch.nn.Module) -> Optional["MarlinLinearMethod"]: prefix: str) -> Optional["MarlinLinearMethod"]:
if (isinstance(layer, LinearBase) or if (isinstance(layer, LinearBase) or
(isinstance(layer, ParallelLMHead) and self.lm_head_quantized)): (isinstance(layer, ParallelLMHead) and self.lm_head_quantized)):
return MarlinLinearMethod(self) return MarlinLinearMethod(self)
......
...@@ -52,8 +52,8 @@ class SqueezeLLMConfig(QuantizationConfig): ...@@ -52,8 +52,8 @@ class SqueezeLLMConfig(QuantizationConfig):
weight_bits = cls.get_from_keys(config, ["wbits"]) weight_bits = cls.get_from_keys(config, ["wbits"])
return cls(weight_bits) return cls(weight_bits)
def get_quant_method( def get_quant_method(self, layer: torch.nn.Module,
self, layer: torch.nn.Module) -> Optional[QuantizeMethodBase]: prefix: str) -> Optional[QuantizeMethodBase]:
if isinstance(layer, LinearBase): if isinstance(layer, LinearBase):
return SqueezeLLMLinearMethod(self) return SqueezeLLMLinearMethod(self)
return None return None
......
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
import numpy
import torch import torch
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.platforms import current_platform from vllm.platforms import current_platform
from .quant_utils import pack_cols, unpack_cols
GPTQ_MARLIN_TILE = 16 GPTQ_MARLIN_TILE = 16
GPTQ_MARLIN_MIN_THREAD_N = 64 GPTQ_MARLIN_MIN_THREAD_N = 64
GPTQ_MARLIN_MIN_THREAD_K = 128 GPTQ_MARLIN_MIN_THREAD_K = 128
GPTQ_MARLIN_MAX_PARALLEL = 16 GPTQ_MARLIN_MAX_PARALLEL = 16
GPTQ_MARLIN_SUPPORTED_NUM_BITS = [4, 8] MARLIN_SUPPORTED_NUM_BITS = [4, 8]
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
GPTQ_MARLIN_SUPPORTED_SYM = [True]
GTPQ_MARLIN_UNSUPPORTED_GROUP_SIZE_ACT_ORDER = [-1]
def _check_marlin_supported(num_bits: int, group_size: int, is_sym: bool,
min_capability: Optional[int],
def check_marlin_supported(num_bits: int, group_size: int, is_sym: bool, has_zp: bool) -> Tuple[bool, Optional[str]]:
min_capability: int) -> bool: if min_capability is not None:
major, minor = current_platform.get_device_capability()
# If the capability of the device is too low, cannot convert. device_capability = major * 10 + minor
major, minor = current_platform.get_device_capability() if device_capability < min_capability:
device_capability = major * 10 + minor return (False, "Marlin does not support device_capability = {}"
if device_capability < min_capability: ", the min_capability required is {}".format(
return False device_capability, min_capability))
return (device_capability >= min_capability if num_bits not in MARLIN_SUPPORTED_NUM_BITS:
and num_bits in GPTQ_MARLIN_SUPPORTED_NUM_BITS return (False, "Marlin does not support weight_bits = {}. "
and group_size in GPTQ_MARLIN_SUPPORTED_GROUP_SIZES "Only weight_bits = {} are supported.".format(
and is_sym in GPTQ_MARLIN_SUPPORTED_SYM) num_bits, MARLIN_SUPPORTED_NUM_BITS))
if group_size not in MARLIN_SUPPORTED_GROUP_SIZES:
def verify_marlin_supported(num_bits: int, group_size: Optional[int], return (False, "Marlin does not support group_size = {}. Only "
is_sym: bool) -> None: "group_sizes = {} are supported.".format(
group_size, MARLIN_SUPPORTED_GROUP_SIZES))
if num_bits not in GPTQ_MARLIN_SUPPORTED_NUM_BITS:
raise ValueError( if not has_zp and not is_sym:
f"Marlin does not support weight_bits = {num_bits}. " return (False,
f"Only weight_bits = {GPTQ_MARLIN_SUPPORTED_NUM_BITS} " "Marlin without zero_points must have symmetric quantization")
"are supported.")
if (group_size is None return True, None
or group_size not in GPTQ_MARLIN_SUPPORTED_GROUP_SIZES):
raise ValueError(
f"Marlin does not support group_size = {group_size}. " def check_gptq_marlin_supported(num_bits: int, group_size: int, is_sym: bool,
f"Only group_sizes = {GPTQ_MARLIN_SUPPORTED_GROUP_SIZES} " min_capability: int) -> bool:
"are supported.") cond, _ = _check_marlin_supported(num_bits,
if is_sym not in GPTQ_MARLIN_SUPPORTED_SYM: group_size,
raise ValueError( is_sym,
f"Marlin does not support is_sym = is_sym. " min_capability,
f"Only sym = {GPTQ_MARLIN_SUPPORTED_SYM} are supported.") has_zp=False)
return cond
def check_awq_marlin_supported(num_bits: int, group_size: int, has_zp: bool,
                               min_capability: int) -> bool:
    """Return whether Marlin supports this AWQ configuration.

    AWQ is always checked as non-symmetric (is_sym=False); the failure
    reason reported by the shared checker is discarded.
    """
    is_supported, _reason = _check_marlin_supported(num_bits,
                                                    group_size,
                                                    False,
                                                    min_capability,
                                                    has_zp=has_zp)
    return is_supported
def verify_gptq_marlin_supported(num_bits: int, group_size: int,
                                 is_sym: bool) -> None:
    """Raise if Marlin cannot run this GPTQ configuration.

    No device capability check is performed (min_capability=None) and the
    configuration is checked without zero-points.

    :raises ValueError: with the checker's message prefixed by "GPTQ"
    """
    supported, reason = _check_marlin_supported(num_bits,
                                                group_size,
                                                is_sym,
                                                min_capability=None,
                                                has_zp=False)
    if supported:
        return
    assert reason is not None
    raise ValueError("GPTQ" + reason)
def verify_awq_marlin_supported(num_bits: int, group_size: int,
                                has_zp: bool) -> None:
    """Raise if Marlin cannot run this AWQ configuration.

    No device capability check is performed (min_capability=None) and the
    configuration is always checked as non-symmetric.

    :raises ValueError: with the checker's message prefixed by "AWQ"
    """
    supported, reason = _check_marlin_supported(num_bits,
                                                group_size,
                                                False,
                                                min_capability=None,
                                                has_zp=has_zp)
    if supported:
        return
    assert reason is not None
    raise ValueError("AWQ" + reason)
def verify_marlin_supports_shape(output_size_per_partition: int, def verify_marlin_supports_shape(output_size_per_partition: int,
...@@ -138,6 +176,51 @@ def marlin_permute_scales(s: torch.Tensor, size_k: int, size_n: int, ...@@ -138,6 +176,51 @@ def marlin_permute_scales(s: torch.Tensor, size_k: int, size_n: int,
return s return s
def marlin_zero_points(zp: torch.Tensor, size_k: int, size_n: int,
                       num_bits: int) -> torch.Tensor:
    """Permute, interleave and pack zero-points for the Marlin kernel.

    :param zp: zero-point tensor to convert
    :param size_k: forwarded to pack_cols
    :param size_n: row width used for the final reshape before packing
    :param num_bits: 4 or 8; selects the column interleave order
    :raises Exception: if num_bits is neither 4 nor 8
    """
    # Permute zero-points in a similar way to scales, but do not use the
    # "single" permutation, since zero-points are applied on every MMA
    perm, _ = get_scale_perms()
    permuted = zp.reshape((-1, len(perm)))[:, perm]

    # Interleave column dim (for the dequantize code) and pack it to int32
    if num_bits == 4:
        column_order = [0, 2, 4, 6, 1, 3, 5, 7]
    elif num_bits == 8:
        column_order = [0, 2, 1, 3]
    else:
        raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))
    interleave = numpy.array(column_order)

    interleaved = permuted.reshape((-1, len(interleave)))[:, interleave]
    rows = interleaved.ravel().reshape((-1, size_n)).contiguous()
    return pack_cols(rows, num_bits, size_k, size_n)
def awq_to_marlin_zero_points(q_zp_packed: torch.Tensor, size_k: int,
                              size_n: int, num_bits: int) -> torch.Tensor:
    """Convert packed AWQ zero-points into the Marlin zero-point layout.

    :param q_zp_packed: AWQ zero-points, packed on the column dim
    :param size_k: forwarded to unpack_cols / marlin_zero_points
    :param size_n: row width of the unpacked zero-points
    :param num_bits: 4 or 8; selects which interleave is undone
    :raises Exception: if num_bits is neither 4 nor 8
    """
    # AWQ zero-points are quantized and packed on the column dim.
    # In addition, the values are permuted based on dequantizer.
    # Here we undo both of these, and then apply marlin permutation
    # and pack it back.
    unpacked = unpack_cols(q_zp_packed, num_bits, size_k, size_n)

    # Undo interleaving (use argsort(..) to get inverse perm)
    if num_bits == 4:
        awq_order = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
    elif num_bits == 8:
        awq_order = numpy.array([0, 2, 1, 3])
    else:
        raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))
    undo_interleave = numpy.argsort(awq_order)

    restored = unpacked.reshape(
        (-1, len(undo_interleave)))[:, undo_interleave].ravel()
    restored = restored.reshape((-1, size_n)).contiguous()

    return marlin_zero_points(restored, size_k, size_n, num_bits)
# Newly generated tensors need to replace existing tensors that are # Newly generated tensors need to replace existing tensors that are
# already registered as parameters by vLLM (and won't be freed) # already registered as parameters by vLLM (and won't be freed)
def replace_tensor(layer: torch.nn.Module, name: str, def replace_tensor(layer: torch.nn.Module, name: str,
...@@ -149,23 +232,61 @@ def replace_tensor(layer: torch.nn.Module, name: str, ...@@ -149,23 +232,61 @@ def replace_tensor(layer: torch.nn.Module, name: str,
del new_t del new_t
def apply_marlin_linear(input: torch.Tensor, def apply_gptq_marlin_linear(
weight: torch.Tensor, input: torch.Tensor,
weight_scale: torch.Tensor, weight: torch.Tensor,
g_idx: torch.Tensor, weight_scale: torch.Tensor,
g_idx_sort_indices: torch.Tensor, weight_zp: torch.Tensor,
workspace: torch.Tensor, g_idx: torch.Tensor,
num_bits: int, g_idx_sort_indices: torch.Tensor,
output_size_per_partition: int, workspace: torch.Tensor,
input_size_per_partition: int, num_bits: int,
is_k_full: bool, output_size_per_partition: int,
bias: Optional[torch.Tensor] = None) -> torch.Tensor: input_size_per_partition: int,
is_k_full: bool,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
reshaped_x = input.reshape(-1, input.shape[-1])
out_shape = input.shape[:-1] + (output_size_per_partition, )
output = ops.gptq_marlin_gemm(reshaped_x,
weight,
weight_scale,
weight_zp,
g_idx,
g_idx_sort_indices,
workspace,
num_bits,
size_m=reshaped_x.shape[0],
size_n=output_size_per_partition,
size_k=input_size_per_partition,
is_k_full=is_k_full,
has_zp=False)
if bias is not None:
output.add_(bias) # In-place add
return output.reshape(out_shape)
def apply_awq_marlin_linear(
input: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
weight_zp: torch.Tensor,
g_idx: torch.Tensor,
g_idx_sort_indices: torch.Tensor,
workspace: torch.Tensor,
num_bits: int,
output_size_per_partition: int,
input_size_per_partition: int,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
reshaped_x = input.reshape(-1, input.shape[-1]) reshaped_x = input.reshape(-1, input.shape[-1])
out_shape = input.shape[:-1] + (output_size_per_partition, ) out_shape = input.shape[:-1] + (output_size_per_partition, )
output = ops.gptq_marlin_gemm(reshaped_x, output = ops.gptq_marlin_gemm(reshaped_x,
weight, weight,
weight_scale, weight_scale,
weight_zp,
g_idx, g_idx,
g_idx_sort_indices, g_idx_sort_indices,
workspace, workspace,
...@@ -173,7 +294,8 @@ def apply_marlin_linear(input: torch.Tensor, ...@@ -173,7 +294,8 @@ def apply_marlin_linear(input: torch.Tensor,
size_m=reshaped_x.shape[0], size_m=reshaped_x.shape[0],
size_n=output_size_per_partition, size_n=output_size_per_partition,
size_k=input_size_per_partition, size_k=input_size_per_partition,
is_k_full=is_k_full) is_k_full=True,
has_zp=True)
if bias is not None: if bias is not None:
output.add_(bias) # In-place add output.add_(bias) # In-place add
......
...@@ -76,8 +76,14 @@ def prepare_fp8_layer_for_marlin(layer: torch.nn.Module) -> None: ...@@ -76,8 +76,14 @@ def prepare_fp8_layer_for_marlin(layer: torch.nn.Module) -> None:
# WEIGHT SCALES # WEIGHT SCALES
# Currently Marlin doesn't support per-tensor scales, so we # Currently Marlin doesn't support per-tensor scales, so we
# expand it to channelwise # expand it to channelwise
scales = layer.weight_scale.repeat(1, part_size_n).to( is_channelwise = (len(layer.weight_scale.shape) > 0
layer.orig_dtype).to(device) and layer.weight_scale.shape[0] == part_size_n)
if is_channelwise:
scales = layer.weight_scale
else:
scales = layer.weight_scale.repeat(1, part_size_n)
scales = scales.to(layer.orig_dtype).to(device)
# Permute scales # Permute scales
marlin_scales = marlin_permute_scales(s=scales, marlin_scales = marlin_permute_scales(s=scales,
size_k=part_size_k, size_k=part_size_k,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment