Commit 500b93c8 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.5.3.post1' into v0.5.3.post1-dtk24.04.1

parents 99426767 38c4b7e8
......@@ -60,9 +60,8 @@ class BitsAndBytesConfig(QuantizationConfig):
target_modules = cls.get_from_keys(config, ["target_modules"])
return cls(adapter_name, target_modules)
def get_quant_method(
self,
layer: torch.nn.Module) -> Optional["BitsAndBytesLinearMethod"]:
def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["BitsAndBytesLinearMethod"]:
if isinstance(layer, LinearBase):
return BitsAndBytesLinearMethod(self)
return None
......
......@@ -5,25 +5,33 @@ from pydantic import BaseModel
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import ( # noqa: E501
QuantizationConfig)
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
W4A16SPARSE24_SUPPORTED_BITS, WNA16_SUPPORTED_BITS,
CompressedTensorsScheme, CompressedTensorsW4A16Sparse24,
CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8,
CompressedTensorsWNA16)
CompressedTensorsScheme, CompressedTensorsUnquantized,
CompressedTensorsW4A16Sparse24, CompressedTensorsW8A8Fp8,
CompressedTensorsW8A8Int8, CompressedTensorsWNA16)
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
CompressionFormat, QuantizationArgs, QuantizationStrategy,
QuantizationType, find_first_name_or_class_match)
QuantizationType, find_matched_target, is_activation_quantization_format,
should_ignore_layer)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.platforms import current_platform
class CompressedTensorsConfig(QuantizationConfig):
def __init__(self, layer_quant_details: Dict[str, Any], ignore: List[str],
quant_format: str):
def __init__(self,
target_scheme_map: Dict[str, Any],
ignore: List[str],
quant_format: str,
kv_cache_scheme: Optional[Dict[str, Any]] = None):
self.ignore = ignore
self.layer_quant_details = layer_quant_details
self.quant_format = quant_format
# Map from [target -> scheme]
self.target_scheme_map = target_scheme_map
self.kv_cache_scheme = kv_cache_scheme
def get_linear_method(self) -> "CompressedTensorsLinearMethod":
return CompressedTensorsLinearMethod(self)
......@@ -36,21 +44,28 @@ class CompressedTensorsConfig(QuantizationConfig):
@classmethod
def get_min_capability(cls) -> int:
return 75
return 70
def get_name(self) -> str:
return "compressed_tensors"
# TODO (@robertgshaw2-neuralmagic): do layer skipping through here
# rather than through create_weights to match other methods
def get_quant_method(
self, layer: torch.nn.Module
) -> Optional["CompressedTensorsLinearMethod"]:
self,
layer: torch.nn.Module,
prefix: str,
) -> Optional["QuantizeMethodBase"]:
from vllm.attention.layer import Attention # Avoid circular import
if isinstance(layer, LinearBase):
return CompressedTensorsLinearMethod(self)
if isinstance(layer, Attention):
return CompressedTensorsKVCacheMethod(self)
return None
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig":
layer_quant_details: Dict[str, Any] = dict()
target_scheme_map: Dict[str, Any] = dict()
ignore: List[str] = config.get("ignore", None)
quant_format: str = config.get("format", None)
......@@ -62,35 +77,37 @@ class CompressedTensorsConfig(QuantizationConfig):
# details follow the structure defined by the QuantizationArgs
# pydantic model, which is used to verify the structure of the
# quant_config and also store the details for later use.
for key, quant_config in config["config_groups"].items():
for _, quant_config in config["config_groups"].items():
targets = quant_config.get("targets")
for target in targets:
layer_quant_details[target] = {}
layer_quant_details[target][
target_scheme_map[target] = {}
target_scheme_map[target][
"weights"] = QuantizationArgs.parse_obj(
quant_config.get("weights"))
try:
layer_quant_details[target][
target_scheme_map[target][
"input_activations"] = QuantizationArgs.parse_obj(
quant_config.get("input_activations"))
except Exception:
layer_quant_details[target]["input_activations"] = None
target_scheme_map[target]["input_activations"] = None
return cls(layer_quant_details=layer_quant_details,
return cls(target_scheme_map=target_scheme_map,
ignore=ignore,
quant_format=quant_format)
quant_format=quant_format,
kv_cache_scheme=config.get("kv_cache_scheme"))
@classmethod
def get_config_filenames(cls) -> List[str]:
return []
def _check_gptq_and_marlin_can_run(self):
    """
    Raise if the current device is below compute capability 8.0.

    Kept alongside ``_check_scheme_supported`` because this name is still
    called from within this class.

    :raises RuntimeError: if the device capability is lower than 80
    """
    capability = current_platform.get_device_capability()
    # (major, minor) -> single integer, e.g. (8, 0) -> 80
    capability = capability[0] * 10 + capability[1]
    if capability < 80:
        # Adjacent literals (no commas) so the exception carries ONE
        # message string instead of a tuple of fragments.
        raise RuntimeError("The quantization config is not supported for "
                           "the current GPU. Minimum capability: 80. "
                           f"Current capability: {capability}.")

def _check_scheme_supported(self, min_capability: int):
    """
    Raise if the current device cannot run a scheme.

    :param min_capability: minimum capability required by the scheme,
        encoded as ``major * 10 + minor`` (e.g. 80, 89)
    :raises RuntimeError: if the device capability is lower than
        ``min_capability``
    """
    capability = current_platform.get_device_capability()
    # (major, minor) -> single integer, e.g. (8, 9) -> 89
    capability = capability[0] * 10 + capability[1]
    if capability < min_capability:
        # Adjacent literals (no commas) so the exception carries ONE
        # message string instead of a tuple of fragments.
        raise RuntimeError(
            "Quantization scheme is not supported for "
            f"the current GPU. Min capability: {min_capability}. "
            f"Current capability: {capability}.")
def _is_static_tensor_w8a8(self, weight_quant: BaseModel,
input_quant: BaseModel) -> bool:
......@@ -132,10 +149,11 @@ class CompressedTensorsConfig(QuantizationConfig):
# Confirm weight scheme is supported.
is_symmetric_weight = weight_quant.symmetric
is_static_weight = not weight_quant.dynamic
is_per_tensor_weight = (
weight_quant.strategy == QuantizationStrategy.TENSOR)
is_per_tensor_or_channel_weight = (weight_quant.strategy in [
QuantizationStrategy.TENSOR, QuantizationStrategy.CHANNEL
])
if not (is_symmetric_weight and is_static_weight
and is_per_tensor_weight):
and is_per_tensor_or_channel_weight):
return False
# Dynamic quantization is always supported if weights supported.
......@@ -164,11 +182,12 @@ class CompressedTensorsConfig(QuantizationConfig):
return (is_channel_group and input_quant_none and is_symmetric
and is_static)
def _get_schema(self, weight_quant: BaseModel,
input_quant: BaseModel) -> "CompressedTensorsScheme":
def _get_scheme_from_parts(
self, weight_quant: BaseModel,
input_quant: BaseModel) -> "CompressedTensorsScheme":
# Detect If Mixed Precision
if self._is_wNa16_group_channel(weight_quant, input_quant):
self._check_gptq_and_marlin_can_run()
if (self.quant_format == CompressionFormat.marlin_24.value
and weight_quant.num_bits in W4A16SPARSE24_SUPPORTED_BITS):
return CompressedTensorsW4A16Sparse24(
......@@ -182,11 +201,12 @@ class CompressedTensorsConfig(QuantizationConfig):
strategy=weight_quant.strategy,
group_size=weight_quant.group_size)
if (self.quant_format == CompressionFormat.int_quantized.value or
self.quant_format == CompressionFormat.float_quantized.value):
# Detect If Activation Quantization.
if is_activation_quantization_format(self.quant_format):
if self._is_fp8_w8a8(weight_quant, input_quant):
return CompressedTensorsW8A8Fp8(
input_dynamic=input_quant.dynamic)
strategy=weight_quant.strategy,
is_static_input_scheme=(not input_quant.dynamic))
if self._is_static_tensor_w8a8(weight_quant, input_quant):
return CompressedTensorsW8A8Int8(
......@@ -201,26 +221,53 @@ class CompressedTensorsConfig(QuantizationConfig):
raise NotImplementedError(
"No compressed-tensors compatible scheme was found.")
def get_scheme(self, layer: torch.nn.Module) -> "CompressedTensorsScheme":
def get_scheme(
self,
layer: torch.nn.Module,
layer_name: Optional[str] = None) -> "CompressedTensorsScheme":
"""
compressed-tensors supports non uniform in the following way:
ignore: List of layer_names or nn.Module names to be ignored.
targets of config_groups: There can be N config_groups which each
have a quantization scheme. Each config_group has a list of targets
which can be a full layer_name, a regex for a layer_name, or
an nn.Module name.
layer_type_name = find_first_name_or_class_match(
name="",
We first check whether a layer is in the ignore group and use
CompressedTensorsUnquantized (i.e. fp16/bf16) scheme for the layer
We then detect whether a layer_name is found in any target and
use the quantization scheme corresponding to the matched target
to select the CompressedTensorsScheme used for inference.
"""
# Check if the layer is skipped for quantization.
# TODO (@robertgshaw2): support module names
if should_ignore_layer(layer_name, ignore=self.ignore):
return CompressedTensorsUnquantized()
# Find the "target" in the compressed-tensors config
# that our layer conforms to.
# TODO (@robertgshaw): add compressed-tensors as dep
# so we do not have to re-write these functions
matched_target = find_matched_target(
layer_name=layer_name,
module=layer,
targets=self.layer_quant_details.keys(),
check_contains=True)
targets=self.target_scheme_map.keys())
# Find the quant_scheme
scheme = self.target_scheme_map[matched_target]
if layer_type_name is None:
raise ValueError(f"Could not matching target for layer {layer}")
return self._get_scheme_from_parts(
weight_quant=scheme["weights"],
input_quant=scheme["input_activations"])
layer_quant_details: Dict[str, Any] = self.layer_quant_details.get(
layer_type_name, None)
if layer_quant_details is None:
raise ValueError(
f"Could not find quantization details for {layer}.")
# Raise error if device does not support the scheme
# (e.g. fp8 needs ada lovelace)
self._check_scheme_supported(scheme.get_min_capability())
return self._get_schema(
weight_quant=layer_quant_details["weights"],
input_quant=layer_quant_details["input_activations"])
return scheme
class CompressedTensorsLinearMethod(LinearMethodBase):
......@@ -240,11 +287,11 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
Use the CompressedTensorsScheme associated with each layer to create
the necessary parameters for the layer. See LinearMethodBase for param
details
"""
weight_loader = extra_weight_attrs.get("weight_loader")
layer_name = extra_weight_attrs.get("prefix")
scheme = self.quantization_config.get_scheme(layer=layer)
scheme = self.quantization_config.get_scheme(layer, layer_name)
scheme.create_weights(
layer=layer,
input_size=input_size,
......@@ -271,3 +318,47 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
if scheme is None:
raise ValueError("A scheme must be defined for each layer")
return scheme.apply_weights(layer, x, bias=bias)
class CompressedTensorsKVCacheMethod(BaseKVCacheMethod):
"""
Supports loading kv-cache scaling factors from compressed-tensors
checkpoints.
"""
def __init__(self, quant_config: CompressedTensorsConfig):
self.validate_kv_cache_scheme(quant_config.kv_cache_scheme)
super().__init__(quant_config)
@staticmethod
def validate_kv_cache_scheme(kv_cache_scheme: Optional[Dict[str, Any]]):
    """
    Validator for the kv cache scheme. Useful for controlling the
    kv cache quantization schemes, that are being supported in vLLM

    The only accepted scheme is 8-bit float, per-tensor, symmetric.
    ``None`` (no kv cache scheme in the checkpoint) is accepted as-is.

    :param kv_cache_scheme: the compressed-tensors kv cache scheme
    :raises NotImplementedError: if any field of the scheme is unsupported
    """
    if kv_cache_scheme is None:
        return

    type_ = kv_cache_scheme.get("type")
    num_bits = kv_cache_scheme.get("num_bits")

    # Reject when EITHER field is wrong. The previous `and` let through
    # e.g. type=float/num_bits=4 or type=int/num_bits=8.
    if type_ != "float" or num_bits != 8:
        raise NotImplementedError(
            "Currently supported kv cache quantization is "
            "num_bits=8, type=float, however "
            f"received num_bits={num_bits}, type={type_}")

    strategy = kv_cache_scheme.get("strategy")
    if strategy != "tensor":
        raise NotImplementedError(
            "Only support per-tensor scaling factor "
            "for compressed-tensors KV cache. "
            f"Expected strategy: tensor, found strategy: {strategy}")

    is_symmetric = kv_cache_scheme.get("symmetric")
    if not is_symmetric:
        raise NotImplementedError(
            "Only support symmetric scaling factor "
            "for compressed-tensors KV cache. "
            f"However found symmetric: {is_symmetric}")
......@@ -12,6 +12,13 @@ class CompressedTensorsScheme(ABC):
of different quantization schemes supported by CompressedTensors.
"""
@abstractmethod
def get_min_capability(self) -> int:
    """
    Return the minimum device capability needed to run this scheme.

    Abstract: every concrete scheme has to provide its own value.
    """
    raise NotImplementedError
@abstractmethod
def create_weights(self, *args, **kwargs):
"""
......
......@@ -18,6 +18,10 @@ class CompressedTensorsUnquantized(CompressedTensorsScheme):
in a linear transformation.
"""
def get_min_capability(self) -> int:
    """Minimum device capability for the unquantized scheme."""
    # 70 == volta and up
    return 70
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
pass
......@@ -29,7 +33,6 @@ class CompressedTensorsUnquantized(CompressedTensorsScheme):
weight = Parameter(torch.empty(sum(output_partition_sizes),
input_size_per_partition,
device="cuda",
dtype=params_dtype),
requires_grad=False)
......
......@@ -29,6 +29,10 @@ class CompressedTensorsW4A16Sparse24(CompressedTensorsScheme):
raise ValueError(
"group_size must be given when using strategy group")
def get_min_capability(self) -> int:
    """Minimum device capability for the W4A16 sparse 2:4 scheme."""
    # 80 == ampere + up
    return 80
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
pass
......
from typing import Callable, List, Optional
import torch
from torch.nn import Parameter
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme)
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
QuantizationStrategy)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
apply_fp8_linear, create_per_tensor_scale_param, cutlass_fp8_supported,
apply_fp8_linear, create_per_channel_scale_param,
create_per_tensor_scale_param, cutlass_fp8_supported,
requantize_with_max_scale)
from vllm.model_executor.utils import set_weight_attrs
......@@ -14,39 +18,49 @@ __all__ = ["CompressedTensorsW8A8Fp8"]
class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
def __init__(self, input_dynamic: bool):
self.input_dynamic = input_dynamic
def __init__(self, strategy: str, is_static_input_scheme: bool):
self.strategy = strategy
self.is_static_input_scheme = is_static_input_scheme
self.cutlass_fp8_supported = cutlass_fp8_supported()
# W8A8-Fp8 kernels support only per-tensor and per-channel cases.
# So if we have a fused module (QKV, MLP) with per tensor scales (thus N
# scales being passed to the kernel), we requantize with a single scale.
def get_min_capability(self) -> int:
    """Minimum device capability for the W8A8 FP8 scheme."""
    # 89 == lovelace and up
    return 89
def process_weights_after_loading(self, layer) -> None:
# Dequant -> Quant with max scale.
max_w_scale, weight = requantize_with_max_scale(
weight=layer.weight,
weight_scale=layer.weight_scale,
logical_widths=layer.logical_widths,
)
# Update layer with new values.
layer.weight = torch.nn.Parameter(weight.t(), requires_grad=False)
layer.weight_scale = torch.nn.Parameter(max_w_scale,
requires_grad=False)
if self.input_dynamic:
layer.input_scale = None
# If per tensor, when we have a fused module (e.g. QKV) with per
# tensor scales (thus N scales being passed to the kernel),
# requantize so we can always run per tensor
if self.strategy == QuantizationStrategy.TENSOR:
max_w_scale, weight = requantize_with_max_scale(
weight=layer.weight,
weight_scale=layer.weight_scale,
logical_widths=layer.logical_widths,
)
layer.weight = Parameter(weight.t(), requires_grad=False)
layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
# If channelwise, scales are already lined up, so just transpose.
elif self.strategy == QuantizationStrategy.CHANNEL:
weight = layer.weight
layer.weight = Parameter(weight.t(), requires_grad=False)
else:
raise ValueError(f"Unknown quantization strategy {self.strategy}")
# INPUT SCALE
if self.is_static_input_scheme:
layer.input_scale = Parameter(layer.input_scale.max(),
requires_grad=False)
else:
layer.input_scale = torch.nn.Parameter(layer.input_scale.max(),
requires_grad=False)
layer.input_scale = None
def create_weights(self, layer: torch.nn.Module,
output_partition_sizes: List[int],
input_size_per_partition: int,
params_dtype: torch.dtype, weight_loader: Callable,
**kwargs):
del params_dtype
output_size_per_partition = sum(output_partition_sizes)
layer.logical_widths = output_partition_sizes
......@@ -63,12 +77,17 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
})
# WEIGHT SCALE
weight_scale = create_per_tensor_scale_param(
output_partition_sizes, weight_loader=weight_loader)
if self.strategy == QuantizationStrategy.CHANNEL:
weight_scale = create_per_channel_scale_param(
output_partition_sizes, weight_loader=weight_loader)
else:
assert self.strategy == QuantizationStrategy.TENSOR
weight_scale = create_per_tensor_scale_param(
output_partition_sizes, weight_loader=weight_loader)
layer.register_parameter("weight_scale", weight_scale)
# INPUT SCALE
if not self.input_dynamic:
if self.is_static_input_scheme:
input_scale = create_per_tensor_scale_param(
output_partition_sizes, weight_loader=weight_loader)
layer.register_parameter("input_scale", input_scale)
......@@ -84,4 +103,5 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
weight_scale=layer.weight_scale,
input_scale=layer.input_scale,
bias=bias,
cutlass_fp8_supported=self.cutlass_fp8_supported)
cutlass_fp8_supported=self.cutlass_fp8_supported,
use_per_token_if_dynamic=True)
......@@ -19,6 +19,10 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
self.strategy = strategy
self.is_static_input_scheme = is_static_input_scheme
def get_min_capability(self) -> int:
    """Minimum device capability for the W8A8 INT8 scheme."""
    # 75 == turing and up
    return 75
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# WEIGHT
# Cutlass kernels need transposed weight.
......
......@@ -7,8 +7,8 @@ from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
apply_marlin_linear, marlin_make_empty_g_idx, marlin_make_workspace,
marlin_permute_scales, replace_tensor, verify_marlin_supported,
apply_gptq_marlin_linear, marlin_make_empty_g_idx, marlin_make_workspace,
marlin_permute_scales, replace_tensor, verify_gptq_marlin_supported,
verify_marlin_supports_shape)
from vllm.model_executor.utils import set_weight_attrs
......@@ -38,9 +38,13 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
self.group_size = group_size
# Verify supported on platform.
verify_marlin_supported(num_bits=self.num_bits,
group_size=self.group_size,
is_sym=True)
verify_gptq_marlin_supported(num_bits=self.num_bits,
group_size=self.group_size,
is_sym=True)
def get_min_capability(self) -> int:
    """Minimum device capability for the WNA16 scheme."""
    # 80 == ampere and up
    return 80
def create_weights(self, layer: torch.nn.Module, input_size: int,
output_partition_sizes: List[int],
......@@ -131,6 +135,9 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
layer.g_idx = marlin_make_empty_g_idx(device)
layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)
# No zero-point
layer.weight_zp = marlin_make_empty_g_idx(device)
# Repack weights from compressed-tensors format to marlin format.
marlin_qweight = ops.gptq_marlin_repack(
layer.weight_packed.t().contiguous(),
......@@ -151,10 +158,11 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:
return apply_marlin_linear(
return apply_gptq_marlin_linear(
input=x,
weight=layer.weight_packed,
weight_scale=layer.weight_scale,
weight_zp=layer.weight_zp,
g_idx=layer.g_idx,
g_idx_sort_indices=layer.g_idx_sort_indices,
workspace=layer.workspace,
......
......@@ -9,6 +9,7 @@ from torch.nn import Module
class CompressionFormat(Enum):
dense = "dense"
sparse_bitmask = "sparse-bitmask"
naive_quantized = "naive-quantized"
float_quantized = "float-quantized"
int_quantized = "int-quantized"
pack_quantized = "pack-quantized"
......@@ -76,25 +77,115 @@ class QuantizationArgs(BaseModel):
)
def find_first_name_or_class_match(
name: str,
module: Module,
targets: Iterable[str],
check_contains: bool = False) -> Optional[str]:
def is_activation_quantization_format(format: str) -> bool:
    """
    Tell whether ``format`` is one of the compression formats treated as
    activation-quantization formats: naive-, int- or float-quantized.

    :param format: the compression format string from the config
    """
    return format in (
        CompressionFormat.naive_quantized.value,
        CompressionFormat.int_quantized.value,
        CompressionFormat.float_quantized.value,
    )
# fused_name: List[shard_name]
_FUSED_LAYER_NAME_MAPPING = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
"gate_up_proj": ["gate_proj", "up_proj"]
}
def should_ignore_layer(layer_name: Optional[str],
                        ignore: Iterable[str]) -> bool:
    """
    Decide whether a layer is covered by the ``ignore`` targets.

    Fused layers (see ``_FUSED_LAYER_NAME_MAPPING``) are not fused in the
    safetensors checkpoint, so the fused name is expanded into its shard
    names and every shard is checked; all shards must agree.

    :param layer_name: full layer name, e.g.
        ``model.layers.0.self_attn.qkv_proj``. ``None`` means "not ignored".
    :param ignore: targets to match against (see
        ``check_equal_or_regex_match``). ``None`` means nothing is ignored.
    :return: True if the layer (or all of its shards) is ignored
    :raises ValueError: if only some shards of a fused layer are ignored
    """
    if layer_name is None or ignore is None:
        return False

    # Materialise once: the targets are scanned once per shard, which would
    # silently exhaust a one-shot iterator.
    targets = tuple(ignore)

    # layer_name = model.layers.0.self_attn.qkv_proj
    # parent     = model.layers.0.self_attn
    # proj_name  = qkv_proj
    parent, dot, proj_name = layer_name.rpartition(".")

    shard_proj_names = _FUSED_LAYER_NAME_MAPPING.get(proj_name)

    # Unfused layers like down_proj and o_proj will match
    # the safetensors checkpoint already.
    if shard_proj_names is None:
        return check_equal_or_regex_match(layer_name=layer_name,
                                          targets=targets)

    # Convert fused_name --> [shard_names]. Only the LAST path component is
    # swapped, so an earlier component that happens to contain the fused
    # name is left untouched.
    ignored: Optional[bool] = None
    for shard_proj_name in shard_proj_names:
        shard_name = f"{parent}{dot}{shard_proj_name}"
        shard_ignored = check_equal_or_regex_match(layer_name=shard_name,
                                                   targets=targets)
        if ignored is None:
            # First shard sets the expectation for the others.
            ignored = shard_ignored
        elif shard_ignored != ignored:
            raise ValueError(f"Found a different quantization schemes for "
                             f"{shard_proj_names} in {layer_name}. vLLM "
                             "requires all to use the same scheme.")

    return bool(ignored)
def check_equal_or_regex_match(layer_name: str,
targets: Iterable[str]) -> bool:
"""
Helper function to map the quantization details listed in the config
for a given list of targets against each model layer. First uses the
layer name to try and find a match. If no name match is found, uses
the layer class name. Returns None otherwise.
Checks whether a layer_name is exactly equal or a regex match for
if target starts with 're:' to any target in list.
"""
for target in targets:
if _is_equal_or_regex_match(layer_name, target):
return True
return False
def find_matched_target(layer_name: Optional[str], module: Module,
targets: Iterable[str]) -> str:
"""
Helper function to look up which "target" in the compressed-tensors
config that a layer corresponds to.
Recall that a compressed-tensors configs has a concept of
config_groups, where each layer can be quantized with a different
scheme.
:param name: layer name
targets in each config_group will be a list of either layer names
(or regexes corresponding to layer names) or names of torch Modules.
First, we try to match the layer_name with a target
Second, we try to match the module's name with a target
:param layer_name: layer name
:param module: torch.nn.Module
:param targets: list of targets to match the layer against
:param check_contains: whether or not to do a substring match
"""
return _find_first_match(name, targets) or _find_first_match(
module.__class__.__name__, targets, check_contains)
if layer_name is None:
layer_name = ""
matched_target = (_find_first_match(layer_name, targets)
or _find_first_match(module.__class__.__name__, targets,
True))
if matched_target is None:
raise ValueError(f"Unable to find matching target for {module} in the "
"compressed-tensors config.")
return matched_target
def _find_first_match(value: str,
......@@ -111,13 +202,46 @@ def _find_first_match(value: str,
"""
for target in targets:
if target.startswith("re:"):
pattern = target[3:]
if re.match(pattern, value):
return target
elif check_contains:
if target.lower() in value.lower():
return target
elif target == value:
if _is_equal_or_regex_match(value,
target,
check_contains=check_contains):
return target
return None
def get_compressed_tensors_cache_scale(name: str) -> Optional[str]:
    """
    Check whether the param name matches the format for k/v cache scales
    in compressed-tensors. If this is the case, return its equivalent
    param name expected by vLLM

    Only names that actually END in ``.k_proj.output_scale`` or
    ``.v_proj.output_scale`` are remapped; anything else yields None
    instead of being handed back unchanged.

    :param name: param name
    :return: matching param name for KV cache scale in vLLM
    """
    suffix_to_vllm = (
        (".k_proj.output_scale", ".attn.k_scale"),
        (".v_proj.output_scale", ".attn.v_scale"),
    )
    for suffix, vllm_suffix in suffix_to_vllm:
        if name.endswith(suffix):
            # Swap only the trailing suffix, never an earlier occurrence.
            return name[:-len(suffix)] + vllm_suffix

    # If no matches, return None
    return None
def _is_equal_or_regex_match(value: str,
target: str,
check_contains: bool = False) -> bool:
"""
Checks whether a value is exactly equal or a regex match for target
if target starts with 're:'. If check_contains is set to True,
additionally checks if the target string is contained within the value.
"""
if target.startswith("re:"):
pattern = target[3:]
if re.match(pattern, value):
return True
elif check_contains:
if target.lower() in value.lower():
return True
elif target == value:
return True
return False
......@@ -69,9 +69,8 @@ class DeepSpeedFPConfig(QuantizationConfig):
"quantize_config.json",
]
def get_quant_method(
self,
layer: torch.nn.Module) -> Optional["DeepSpeedFPLinearMethod"]:
def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["DeepSpeedFPLinearMethod"]:
if isinstance(layer, LinearBase):
return DeepSpeedFPLinearMethod(self)
return None
......
from typing import Any, Dict, List, Optional
import torch
from torch.nn import Module
from torch.nn.parameter import Parameter
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
apply_fp8_linear, create_per_channel_scale_param)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
logger = init_logger(__name__)
# Note: this is a hack. We should update each model to register the
# stacked params and get it from there instead in a future PR.
# fused_name: List[shard_name]
_FUSED_LAYER_NAME_MAPPING = {
"qkv_proj": ["q_proj", "k_proj", "v_proj"],
"gate_up_proj": ["gate_proj", "up_proj"]
}
class FBGEMMFp8Config(QuantizationConfig):
    """Config class for FBGEMM Fp8."""

    def __init__(self, ignore_list: List[str], input_scale_ub: float):
        # A missing/empty ignore list is normalised to [] so the membership
        # tests in _is_layer_skipped always work.
        self.ignore_list = ignore_list if ignore_list else []
        self.input_scale_ub = input_scale_ub

        # For GPUs that lack FP8 hardware support, we can leverage the Marlin
        # kernel for fast weight-only FP8 quantization
        capability = current_platform.get_device_capability()
        capability = capability[0] * 10 + capability[1]
        self.use_marlin = capability < 89

    @classmethod
    def get_name(cls) -> str:
        return "fbgemm_fp8"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.bfloat16, torch.float16]

    @classmethod
    def get_min_capability(cls) -> int:
        return 80

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        return []

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "FBGEMMFp8Config":
        """Build the config from the ``modules_to_not_convert`` and
        ``activation_scale_ub`` entries of ``config``."""
        ignore_list = cls.get_from_keys(config, ["modules_to_not_convert"])
        input_scale_ub = cls.get_from_keys(config, ["activation_scale_ub"])
        return cls(ignore_list=ignore_list, input_scale_ub=input_scale_ub)

    def _is_layer_skipped(self, prefix: str) -> bool:
        """
        Tell whether the layer at ``prefix`` is listed in ``ignore_list``.

        Fused layers (see ``_FUSED_LAYER_NAME_MAPPING``) are expanded into
        their shard prefixes; all shards must agree.

        :raises ValueError: if only some shards of a fused layer are listed
        """
        # prefix: model.layers.0.self_attn.q_proj
        # proj_name: q_proj
        parent, dot, proj_name = prefix.rpartition(".")

        shard_proj_names = _FUSED_LAYER_NAME_MAPPING.get(proj_name)
        if shard_proj_names is None:
            return prefix in self.ignore_list

        is_skipped: Optional[bool] = None
        for shard_proj_name in shard_proj_names:
            # Swap only the last path component; str.replace would also
            # rewrite an earlier component containing the fused name.
            shard_prefix = f"{parent}{dot}{shard_proj_name}"
            is_shard_skipped = shard_prefix in self.ignore_list
            if is_skipped is None:
                is_skipped = is_shard_skipped
            elif is_shard_skipped != is_skipped:
                raise ValueError(
                    f"Detected some but not all shards of {prefix} "
                    "are quantized. All shards of fused layers "
                    "to have the same precision.")

        return bool(is_skipped)

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["QuantizeMethodBase"]:
        """Return the method for ``layer``: unquantized if the layer is
        skipped, FBGEMM FP8 for other linear layers, None for non-linear
        layers."""
        if isinstance(layer, LinearBase):
            if self._is_layer_skipped(prefix):
                return UnquantizedLinearMethod()
            return FBGEMMFp8LinearMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        return []
class FBGEMMFp8LinearMethod(LinearMethodBase):
    """
    Linear method for FBGEMM FP8 layers.

    Weights are held as float8_e4m3fn with a per-channel weight scale.
    Depending on ``quant_config.use_marlin`` the forward pass goes through
    the FP8 Marlin kernel or through ``apply_fp8_linear``.
    """

    def __init__(self, quant_config: FBGEMMFp8Config):
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the FP8 weight, its per-channel scale and the input
        scale upper bound on ``layer``, and record the partition sizes."""
        # The full (unpartitioned) sizes are not needed here.
        del input_size, output_size
        out_per_partition = sum(output_partition_sizes)

        # Bookkeeping read back later in this class.
        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = out_per_partition
        layer.orig_dtype = params_dtype

        # WEIGHT: (output, input) layout, float8_e4m3fn storage.
        fp8_weight = Parameter(
            torch.empty(out_per_partition,
                        input_size_per_partition,
                        dtype=torch.float8_e4m3fn),
            requires_grad=False,
        )
        layer.register_parameter("weight", fp8_weight)
        weight_attrs = {"input_dim": 1, "output_dim": 0}
        weight_attrs.update(extra_weight_attrs)
        set_weight_attrs(fp8_weight, weight_attrs)

        # WEIGHT SCALE
        per_channel_scale = create_per_channel_scale_param(
            output_partition_sizes, **extra_weight_attrs)
        layer.register_parameter("weight_scale", per_channel_scale)

        # INPUT SCALE UPPER BOUND
        ub_value = torch.tensor((self.quant_config.input_scale_ub),
                                dtype=torch.float32)
        layer.input_scale_ub = Parameter(ub_value, requires_grad=False)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Transpose the loaded weight and, on the Marlin path, repack the
        layer and drop the unused input scale bound."""
        loaded_weight = layer.weight
        layer.weight = Parameter(loaded_weight.t(), requires_grad=False)

        if not self.quant_config.use_marlin:
            return

        prepare_fp8_layer_for_marlin(layer)
        # Activations not quantized for marlin.
        del layer.input_scale_ub

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run the linear layer on ``x`` with the kernel chosen by
        ``quant_config.use_marlin``."""
        if not self.quant_config.use_marlin:
            return apply_fp8_linear(input=x,
                                    weight=layer.weight,
                                    weight_scale=layer.weight_scale,
                                    input_scale=None,
                                    input_scale_ub=layer.input_scale_ub,
                                    bias=bias,
                                    cutlass_fp8_supported=True,
                                    use_per_token_if_dynamic=True)

        return apply_fp8_marlin_linear(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            workspace=layer.workspace,
            size_n=layer.output_size_per_partition,
            size_k=layer.input_size_per_partition,
            bias=bias)
......@@ -11,6 +11,7 @@ from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
......@@ -66,8 +67,8 @@ class Fp8Config(QuantizationConfig):
return cls(is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
activation_scheme=activation_scheme)
def get_quant_method(
self, layer: torch.nn.Module) -> Optional["QuantizeMethodBase"]:
def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["QuantizeMethodBase"]:
from vllm.attention.layer import Attention # Avoid circular import
if isinstance(layer, LinearBase):
......@@ -214,7 +215,8 @@ class Fp8LinearMethod(LinearMethodBase):
weight_scale=layer.weight_scale,
input_scale=layer.input_scale,
bias=bias,
cutlass_fp8_supported=self.cutlass_fp8_supported)
cutlass_fp8_supported=self.cutlass_fp8_supported,
use_per_token_if_dynamic=False)
class Fp8MoEMethod(FusedMoEMethodBase):
......@@ -399,39 +401,10 @@ class Fp8MoEMethod(FusedMoEMethodBase):
topk_group=topk_group)
class Fp8KVCacheMethod(QuantizeMethodBase):
"""Supports loading kv-cache scaling factors from FP8 checkpoints.
class Fp8KVCacheMethod(BaseKVCacheMethod):
"""
Supports loading kv-cache scaling factors from FP8 checkpoints.
"""
def __init__(self, quant_config: Fp8Config):
self.quant_config = quant_config
def create_weights(self, layer: torch.nn.Module):
"""Create "weight" (aka kv_scale) for an attention layer.
Args:
layer: The layer that is using the QuantizeMethodBase factory.
"""
# Initialize the KV cache scale to 1.0 as the default value.
# If the kv_scale appears in the checkpoint, it will be
# overwritten when loading weights.
layer.kv_scale = Parameter(torch.tensor(1.0), requires_grad=False)
def apply(self, layer: torch.nn.Module) -> torch.Tensor:
raise RuntimeError("Fp8KVCacheMethod.apply should not be called.")
def process_weights_after_loading(self, layer: Module) -> None:
# If the kv-cache dtype is auto, we enforce the kv-scale to be 1.0
# regardless whether the kv-scale is available in the checkpoint.
if layer.kv_cache_dtype != "auto":
kv_scale = layer.kv_scale.to("cpu").tolist()
if not isinstance(kv_scale, float):
raise ValueError("Only support per-tensor scaling factor "
"for fp8 KV cache")
layer._kv_scale = kv_scale
if layer._kv_scale == 1.0 and "e5m2" not in layer.kv_cache_dtype:
print_warning_once(
"Using KV cache scaling factor 1.0 for fp8_e4m3. This may "
"cause accuracy issues. Please make sure kv-cache scaling "
"factor is available in the fp8 checkpoint.")
del layer.kv_scale
super().__init__(quant_config)
......@@ -69,8 +69,8 @@ class GPTQConfig(QuantizationConfig):
default=False)
return cls(weight_bits, group_size, desc_act, lm_head_quantized)
def get_quant_method(
self, layer: torch.nn.Module) -> Optional["GPTQLinearMethod"]:
def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["GPTQLinearMethod"]:
if (isinstance(layer, LinearBase) or
(isinstance(layer, ParallelLMHead) and self.lm_head_quantized)):
return GPTQLinearMethod(self)
......
......@@ -10,10 +10,10 @@ from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
apply_marlin_linear, check_marlin_supported, marlin_is_k_full,
apply_gptq_marlin_linear, check_gptq_marlin_supported, marlin_is_k_full,
marlin_make_empty_g_idx, marlin_make_workspace, marlin_permute_scales,
marlin_repeat_scales_on_all_ranks, marlin_sort_g_idx, replace_tensor,
verify_marlin_supported, verify_marlin_supports_shape)
verify_gptq_marlin_supported, verify_marlin_supports_shape)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
logger = init_logger(__name__)
......@@ -37,9 +37,9 @@ class GPTQMarlinConfig(QuantizationConfig):
self.lm_head_quantized = lm_head_quantized
# Verify supported on platform.
verify_marlin_supported(num_bits=self.weight_bits,
group_size=self.group_size,
is_sym=self.is_sym)
verify_gptq_marlin_supported(num_bits=self.weight_bits,
group_size=self.group_size,
is_sym=self.is_sym)
def __repr__(self) -> str:
return (f"GPTQMarlinConfig(weight_bits={self.weight_bits}, "
......@@ -77,7 +77,7 @@ class GPTQMarlinConfig(QuantizationConfig):
@classmethod
def override_quantization_method(cls, hf_quant_cfg,
user_quant) -> Optional[str]:
can_convert = cls.is_marlin_compatible(hf_quant_cfg)
can_convert = cls.is_gptq_marlin_compatible(hf_quant_cfg)
is_valid_user_quant = (user_quant is None or user_quant == "marlin")
......@@ -94,9 +94,8 @@ class GPTQMarlinConfig(QuantizationConfig):
" faster inference")
return None
def get_quant_method(
self,
layer: torch.nn.Module) -> Optional["GPTQMarlinLinearMethod"]:
def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["GPTQMarlinLinearMethod"]:
if (isinstance(layer, LinearBase) or
(isinstance(layer, ParallelLMHead) and self.lm_head_quantized)):
return GPTQMarlinLinearMethod(self)
......@@ -106,22 +105,27 @@ class GPTQMarlinConfig(QuantizationConfig):
return []
@classmethod
def is_marlin_compatible(cls, quant_config: Dict[str, Any]):
def is_gptq_marlin_compatible(cls, quant_config: Dict[str, Any]):
# Extract data from quant config.
quant_method = quant_config.get("quant_method", "").lower()
num_bits = quant_config.get("bits", None)
group_size = quant_config.get("group_size", None)
sym = quant_config.get("sym", None)
desc_act = quant_config.get("desc_act", None)
if quant_method != "gptq":
return False
# If we cannot find the info needed in the config, cannot convert.
if (num_bits is None or group_size is None or sym is None
or desc_act is None):
return False
return check_marlin_supported(num_bits=num_bits,
group_size=group_size,
is_sym=sym,
min_capability=cls.get_min_capability())
return check_gptq_marlin_supported(
num_bits=num_bits,
group_size=group_size,
is_sym=sym,
min_capability=cls.get_min_capability())
class GPTQMarlinLinearMethod(LinearMethodBase):
......@@ -279,6 +283,9 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
layer.g_idx = marlin_make_empty_g_idx(device)
layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)
# No zero-point
layer.zp = marlin_make_empty_g_idx(device)
# Repack weights from autogptq format to marlin format.
marlin_qweight = ops.gptq_marlin_repack(
layer.qweight,
......@@ -303,10 +310,11 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return apply_marlin_linear(
return apply_gptq_marlin_linear(
input=x,
weight=layer.qweight,
weight_scale=layer.scales,
weight_zp=layer.zp,
g_idx=layer.g_idx,
g_idx_sort_indices=layer.g_idx_sort_indices,
workspace=layer.workspace,
......
......@@ -109,9 +109,8 @@ class GPTQMarlin24Config(QuantizationConfig):
return None
def get_quant_method(
self,
layer: torch.nn.Module) -> Optional["GPTQMarlin24LinearMethod"]:
def get_quant_method(self, layer: torch.nn.Module,
                     prefix: str) -> Optional["GPTQMarlin24LinearMethod"]:
    """Return the Marlin 2:4 linear method for linear layers, else None.

    `prefix` is accepted for interface compatibility and is not used here.
    """
    handles_layer = isinstance(layer, LinearBase)
    return GPTQMarlin24LinearMethod(self) if handles_layer else None
......
import torch
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.utils import print_warning_once
class BaseKVCacheMethod(QuantizeMethodBase):
    """
    Quant method that adds `_k_scale` and `_v_scale` attributes to the
    Attention layer to support loading those scaling factors from checkpoints.
    The k/v_scale will be used to:
        - quantize k/v_cache entries before saving them to the cache
        - dequantize k/v_cache entries before fetching them from the cache

    :param quant_config: the appropriate QuantizationConfig
    """

    def __init__(self, quant_config: QuantizationConfig):
        self.quant_config = quant_config

    def create_weights(self, layer: torch.nn.Module):
        """
        Create "weight" (aka k_scale and v_scale) for an attention layer.

        :param layer: the layer that receives the `k_scale` and `v_scale`
            parameters
        """
        # Initialize the KV cache scales to -1.0, which is an invalid value.
        # If the k/v_scale appears in the checkpoint, it will be
        # overwritten when loading weights.
        layer.k_scale = torch.nn.Parameter(torch.tensor(-1.0),
                                           requires_grad=False)
        layer.v_scale = torch.nn.Parameter(torch.tensor(-1.0),
                                           requires_grad=False)

    def apply(self, layer: torch.nn.Module) -> torch.Tensor:
        """Always raises: this method only loads scales, it computes nothing.

        :raises RuntimeError: on every call
        """
        raise RuntimeError(
            f"{self.__class__.__name__}.apply should not be called.")

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """
        Resolve the loaded `k_scale`/`v_scale` parameters into the plain
        float attributes `_k_scale`/`_v_scale`, then delete the parameters.

        :param layer: a layer exposing `kv_cache_dtype`, `k_scale` and
            `v_scale`
        :raises ValueError: if a resolved scale is not a single float
        """
        # If the kv-cache dtype is auto, the checkpoint scales are ignored
        # and `_k_scale`/`_v_scale` are left untouched here.
        # NOTE(review): presumably the layer already defaults them to 1.0
        # -- confirm against the Attention layer.
        if layer.kv_cache_dtype != "auto":
            if layer.k_scale > 0.0 and layer.v_scale > 0.0:
                # We prefer to use separate k_scale and v_scale if present
                k_scale = layer.k_scale.to("cpu").tolist()
                v_scale = layer.v_scale.to("cpu").tolist()
            elif layer.k_scale < 0.0 and layer.v_scale < 0.0:
                # If no scales were loaded (both scales are invalid negative
                # values), use the default value of 1.0. These must be plain
                # floats: a Parameter here would fail the isinstance check
                # below and wrongly raise ValueError.
                k_scale = 1.0
                v_scale = 1.0
            else:
                # If we find a single kv_scale in the checkpoint, we remap
                # kv_scale to k_scale during weight loading, and duplicate
                # k_scale to v_scale here
                assert layer.k_scale > 0.0
                scale_to_duplicate = max(layer.k_scale, layer.v_scale)
                k_scale = scale_to_duplicate.to("cpu").tolist()
                v_scale = scale_to_duplicate.to("cpu").tolist()

            if not isinstance(k_scale, float) or not isinstance(
                    v_scale, float):
                raise ValueError("Only support per-tensor scaling factor "
                                 "for fp8 KV cache")

            # These are used in the final Attention.forward()
            layer._k_scale = k_scale
            layer._v_scale = v_scale
            if (layer._k_scale == 1.0 and layer._v_scale == 1.0
                    and "e5m2" not in layer.kv_cache_dtype):
                print_warning_once(
                    "Using KV cache scaling factor 1.0 for fp8_e4m3. This "
                    "may cause accuracy issues. Please make sure k/v_scale "
                    "scaling factors are available in the fp8 checkpoint.")

        # The temporary parameters are no longer needed once resolved.
        del layer.k_scale
        del layer.v_scale
......@@ -100,8 +100,8 @@ class MarlinConfig(QuantizationConfig):
return None
def get_quant_method(
self, layer: torch.nn.Module) -> Optional["MarlinLinearMethod"]:
def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["MarlinLinearMethod"]:
if (isinstance(layer, LinearBase) or
(isinstance(layer, ParallelLMHead) and self.lm_head_quantized)):
return MarlinLinearMethod(self)
......
......@@ -52,8 +52,8 @@ class SqueezeLLMConfig(QuantizationConfig):
weight_bits = cls.get_from_keys(config, ["wbits"])
return cls(weight_bits)
def get_quant_method(
self, layer: torch.nn.Module) -> Optional[QuantizeMethodBase]:
def get_quant_method(self, layer: torch.nn.Module,
                     prefix: str) -> Optional[QuantizeMethodBase]:
    """Return the SqueezeLLM linear method for linear layers, else None.

    `prefix` is accepted for interface compatibility and is not used here.
    """
    handles_layer = isinstance(layer, LinearBase)
    return SqueezeLLMLinearMethod(self) if handles_layer else None
......
from typing import List, Optional, Tuple
import numpy
import torch
from vllm import _custom_ops as ops
from vllm.platforms import current_platform
from .quant_utils import pack_cols, unpack_cols
GPTQ_MARLIN_TILE = 16
GPTQ_MARLIN_MIN_THREAD_N = 64
GPTQ_MARLIN_MIN_THREAD_K = 128
GPTQ_MARLIN_MAX_PARALLEL = 16
GPTQ_MARLIN_SUPPORTED_NUM_BITS = [4, 8]
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
GPTQ_MARLIN_SUPPORTED_SYM = [True]
GTPQ_MARLIN_UNSUPPORTED_GROUP_SIZE_ACT_ORDER = [-1]
def check_marlin_supported(num_bits: int, group_size: int, is_sym: bool,
                           min_capability: int) -> bool:
    """Report whether this device and quantization setup can use Marlin.

    Returns False as soon as any requirement fails: device capability
    (major * 10 + minor) below `min_capability`, or `num_bits`,
    `group_size` or `is_sym` outside the GPTQ Marlin supported lists.
    """
    cap_major, cap_minor = current_platform.get_device_capability()
    # If the capability of the device is too low, cannot convert.
    if cap_major * 10 + cap_minor < min_capability:
        return False
    if num_bits not in GPTQ_MARLIN_SUPPORTED_NUM_BITS:
        return False
    if group_size not in GPTQ_MARLIN_SUPPORTED_GROUP_SIZES:
        return False
    return is_sym in GPTQ_MARLIN_SUPPORTED_SYM
def verify_marlin_supported(num_bits: int, group_size: Optional[int],
                            is_sym: bool) -> None:
    """Raise if the quantization parameters are not supported by Marlin.

    Each argument is checked against the matching GPTQ_MARLIN_SUPPORTED_*
    list; a `group_size` of None is rejected as well.

    Raises:
        ValueError: naming the first unsupported parameter and its value.
    """
    if num_bits not in GPTQ_MARLIN_SUPPORTED_NUM_BITS:
        raise ValueError(
            f"Marlin does not support weight_bits = {num_bits}. "
            f"Only weight_bits = {GPTQ_MARLIN_SUPPORTED_NUM_BITS} "
            "are supported.")
    if (group_size is None
            or group_size not in GPTQ_MARLIN_SUPPORTED_GROUP_SIZES):
        raise ValueError(
            f"Marlin does not support group_size = {group_size}. "
            f"Only group_sizes = {GPTQ_MARLIN_SUPPORTED_GROUP_SIZES} "
            "are supported.")
    if is_sym not in GPTQ_MARLIN_SUPPORTED_SYM:
        # Interpolate the offending value; the message previously printed
        # the literal text "is_sym = is_sym".
        raise ValueError(
            f"Marlin does not support is_sym = {is_sym}. "
            f"Only sym = {GPTQ_MARLIN_SUPPORTED_SYM} are supported.")
MARLIN_SUPPORTED_NUM_BITS = [4, 8]
MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
def _check_marlin_supported(num_bits: int, group_size: int, is_sym: bool,
                            min_capability: Optional[int],
                            has_zp: bool) -> Tuple[bool, Optional[str]]:
    """Check Marlin support and explain the first failing requirement.

    The device capability check is skipped when `min_capability` is None.

    Returns:
        `(True, None)` when every check passes, otherwise `(False, reason)`
        where `reason` describes the first check that failed.
    """
    if min_capability is not None:
        cap_major, cap_minor = current_platform.get_device_capability()
        device_capability = cap_major * 10 + cap_minor
        if device_capability < min_capability:
            reason = ("Marlin does not support device_capability = {}"
                      ", the min_capability required is {}".format(
                          device_capability, min_capability))
            return (False, reason)

    if num_bits not in MARLIN_SUPPORTED_NUM_BITS:
        reason = ("Marlin does not support weight_bits = {}. "
                  "Only weight_bits = {} are supported.".format(
                      num_bits, MARLIN_SUPPORTED_NUM_BITS))
        return (False, reason)

    if group_size not in MARLIN_SUPPORTED_GROUP_SIZES:
        reason = ("Marlin does not support group_size = {}. Only "
                  "group_sizes = {} are supported.".format(
                      group_size, MARLIN_SUPPORTED_GROUP_SIZES))
        return (False, reason)

    # Asymmetric quantization is only usable when zero-points are present.
    if not (has_zp or is_sym):
        return (False,
                "Marlin without zero_points must have symmetric quantization")

    return True, None
def check_gptq_marlin_supported(num_bits: int, group_size: int, is_sym: bool,
                                min_capability: int) -> bool:
    """Return True when GPTQ weights (no zero-points) can run on Marlin."""
    supported, _reason = _check_marlin_supported(
        num_bits, group_size, is_sym, min_capability, has_zp=False)
    return supported
def check_awq_marlin_supported(num_bits: int, group_size: int, has_zp: bool,
                               min_capability: int) -> bool:
    """Return True when AWQ weights can run on Marlin.

    The shared check is always invoked with `is_sym=False`, so support
    depends on `has_zp` for the symmetry requirement.
    """
    supported, _reason = _check_marlin_supported(
        num_bits, group_size, False, min_capability, has_zp=has_zp)
    return supported
def verify_gptq_marlin_supported(num_bits: int, group_size: int,
                                 is_sym: bool) -> None:
    """Raise ValueError if the GPTQ parameters are unsupported by Marlin.

    The device capability is not checked (`min_capability=None`).
    """
    supported, reason = _check_marlin_supported(num_bits,
                                                group_size,
                                                is_sym,
                                                min_capability=None,
                                                has_zp=False)
    if supported:
        return
    assert reason is not None
    raise ValueError("GPTQ" + reason)
def verify_awq_marlin_supported(num_bits: int, group_size: int,
                                has_zp: bool) -> None:
    """Raise ValueError if the AWQ parameters are unsupported by Marlin.

    The device capability is not checked (`min_capability=None`) and the
    shared check is invoked with `is_sym=False`.
    """
    supported, reason = _check_marlin_supported(num_bits,
                                                group_size,
                                                False,
                                                min_capability=None,
                                                has_zp=has_zp)
    if supported:
        return
    assert reason is not None
    raise ValueError("AWQ" + reason)
def verify_marlin_supports_shape(output_size_per_partition: int,
......@@ -138,6 +176,51 @@ def marlin_permute_scales(s: torch.Tensor, size_k: int, size_n: int,
return s
def marlin_zero_points(zp: torch.Tensor, size_k: int, size_n: int,
                       num_bits: int) -> torch.Tensor:
    """Permute, interleave and pack zero-points for Marlin.

    Raises:
        Exception: if `num_bits` is neither 4 nor 8.
    """
    # Permute zero-points in a similar way to scales, but do not use the
    # "single" permutation, since zero-points are applied on every MMA
    scale_perm, _ = get_scale_perms()
    permuted = zp.reshape((-1, len(scale_perm)))[:, scale_perm]

    # Interleave column dim (for the dequantize code) and pack it to int32
    interleave_by_bits = {4: [0, 2, 4, 6, 1, 3, 5, 7], 8: [0, 2, 1, 3]}
    if num_bits not in interleave_by_bits:
        raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))
    interleave = numpy.array(interleave_by_bits[num_bits])

    interleaved = permuted.reshape((-1, len(interleave)))[:, interleave]
    flat = interleaved.ravel()
    by_column = flat.reshape((-1, size_n)).contiguous()
    return pack_cols(by_column, num_bits, size_k, size_n)
def awq_to_marlin_zero_points(q_zp_packed: torch.Tensor, size_k: int,
                              size_n: int, num_bits: int) -> torch.Tensor:
    """Convert packed AWQ zero-points to the packed Marlin layout.

    Raises:
        Exception: if `num_bits` is neither 4 nor 8.
    """
    # AWQ zero-points are quantized and packed on the column dim.
    # In addition, the values are permuted based on dequantizer.
    # Here we undo both of these, and then apply marlin permutation
    # and pack it back.
    unpacked = unpack_cols(q_zp_packed, num_bits, size_k, size_n)

    # Undo interleaving (use argsort(..) to get inverse perm)
    interleave_by_bits = {4: [0, 2, 4, 6, 1, 3, 5, 7], 8: [0, 2, 1, 3]}
    if num_bits not in interleave_by_bits:
        raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))
    undo_interleave = numpy.argsort(numpy.array(interleave_by_bits[num_bits]))

    restored = unpacked.reshape((-1, len(undo_interleave)))[:, undo_interleave]
    by_column = restored.ravel().reshape((-1, size_n)).contiguous()

    return marlin_zero_points(by_column, size_k, size_n, num_bits)
# Newly generated tensors need to replace existing tensors that are
# already registered as parameters by vLLM (and won't be freed)
def replace_tensor(layer: torch.nn.Module, name: str,
......@@ -149,23 +232,61 @@ def replace_tensor(layer: torch.nn.Module, name: str,
del new_t
def apply_marlin_linear(input: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
g_idx: torch.Tensor,
g_idx_sort_indices: torch.Tensor,
workspace: torch.Tensor,
num_bits: int,
output_size_per_partition: int,
input_size_per_partition: int,
is_k_full: bool,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
def apply_gptq_marlin_linear(
        input: torch.Tensor,
        weight: torch.Tensor,
        weight_scale: torch.Tensor,
        weight_zp: torch.Tensor,
        g_idx: torch.Tensor,
        g_idx_sort_indices: torch.Tensor,
        workspace: torch.Tensor,
        num_bits: int,
        output_size_per_partition: int,
        input_size_per_partition: int,
        is_k_full: bool,
        bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Run the GPTQ Marlin GEMM on `input` and restore its leading dims.

    The input is flattened to two dimensions for the kernel call, which is
    made with `has_zp=False`. `bias`, when given, is added in place. The
    result keeps the input's leading dimensions with a last dimension of
    `output_size_per_partition`.
    """
    flat_input = input.reshape(-1, input.shape[-1])
    final_shape = input.shape[:-1] + (output_size_per_partition, )

    result = ops.gptq_marlin_gemm(flat_input,
                                  weight,
                                  weight_scale,
                                  weight_zp,
                                  g_idx,
                                  g_idx_sort_indices,
                                  workspace,
                                  num_bits,
                                  size_m=flat_input.shape[0],
                                  size_n=output_size_per_partition,
                                  size_k=input_size_per_partition,
                                  is_k_full=is_k_full,
                                  has_zp=False)

    if bias is not None:
        result.add_(bias)  # In-place add

    return result.reshape(final_shape)
def apply_awq_marlin_linear(
input: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
weight_zp: torch.Tensor,
g_idx: torch.Tensor,
g_idx_sort_indices: torch.Tensor,
workspace: torch.Tensor,
num_bits: int,
output_size_per_partition: int,
input_size_per_partition: int,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
reshaped_x = input.reshape(-1, input.shape[-1])
out_shape = input.shape[:-1] + (output_size_per_partition, )
output = ops.gptq_marlin_gemm(reshaped_x,
weight,
weight_scale,
weight_zp,
g_idx,
g_idx_sort_indices,
workspace,
......@@ -173,7 +294,8 @@ def apply_marlin_linear(input: torch.Tensor,
size_m=reshaped_x.shape[0],
size_n=output_size_per_partition,
size_k=input_size_per_partition,
is_k_full=is_k_full)
is_k_full=True,
has_zp=True)
if bias is not None:
output.add_(bias) # In-place add
......
......@@ -76,8 +76,14 @@ def prepare_fp8_layer_for_marlin(layer: torch.nn.Module) -> None:
# WEIGHT SCALES
# Currently Marlin doesn't support per-tensor scales, so we
# expand it to channelwise
scales = layer.weight_scale.repeat(1, part_size_n).to(
layer.orig_dtype).to(device)
is_channelwise = (len(layer.weight_scale.shape) > 0
and layer.weight_scale.shape[0] == part_size_n)
if is_channelwise:
scales = layer.weight_scale
else:
scales = layer.weight_scale.repeat(1, part_size_n)
scales = scales.to(layer.orig_dtype).to(device)
# Permute scales
marlin_scales = marlin_permute_scales(s=scales,
size_k=part_size_k,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment