Commit b9e12416 authored by zhuwenwen's avatar zhuwenwen
Browse files

merge v0.4.3

parents e5d707db e9d3aa04
from enum import IntEnum
import torch
import torch.nn as nn
from vllm.model_executor.pooling_metadata import (PoolingMetadata,
PoolingTensors)
from vllm.sequence import EmbeddingSequenceGroupOutput, PoolerOutput
class PoolingType(IntEnum):
"""Enumeration for different types of pooling methods."""
LAST = 0
class Pooler(nn.Module):
"""A layer that pools specific information from hidden states.
This layer does the following:
1. Extracts specific tokens or aggregates data based on pooling method.
2. Normalizes output if specified.
3. Returns structured results as `PoolerOutput`.
Attributes:
pooling_type: The type of pooling to use (LAST, AVERAGE, MAX).
normalize: Whether to normalize the pooled data.
"""
def __init__(self, pooling_type: PoolingType, normalize: bool):
super().__init__()
self.pooling_type = pooling_type
self.normalize = normalize
def forward(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> PoolerOutput:
"""Pools specific information from hidden states based on metadata."""
prompt_lens = PoolingTensors.from_pooling_metadata(
pooling_metadata, hidden_states.device).prompt_lens
if self.pooling_type == PoolingType.LAST:
last_token_flat_indices = torch.cumsum(prompt_lens, dim=0) - 1
pooled_data = hidden_states[last_token_flat_indices]
else:
raise ValueError(f"Invalid pooling type: {self.pooling_type}")
if self.normalize:
pooled_data = nn.functional.normalize(pooled_data, p=2, dim=1)
pooled_outputs = [
EmbeddingSequenceGroupOutput(data.tolist()) for data in pooled_data
]
return PoolerOutput(outputs=pooled_outputs)
...@@ -4,21 +4,32 @@ from vllm.model_executor.layers.quantization.aqlm import AQLMConfig ...@@ -4,21 +4,32 @@ from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
from vllm.model_executor.layers.quantization.awq import AWQConfig from vllm.model_executor.layers.quantization.awq import AWQConfig
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501
CompressedTensorsConfig)
from vllm.model_executor.layers.quantization.deepspeedfp import (
DeepSpeedFPConfig)
from vllm.model_executor.layers.quantization.fp8 import Fp8Config from vllm.model_executor.layers.quantization.fp8 import Fp8Config
from vllm.model_executor.layers.quantization.gptq import GPTQConfig from vllm.model_executor.layers.quantization.gptq import GPTQConfig
from vllm.model_executor.layers.quantization.gptq_marlin import ( from vllm.model_executor.layers.quantization.gptq_marlin import (
GPTQMarlinConfig) GPTQMarlinConfig)
from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
GPTQMarlin24Config)
from vllm.model_executor.layers.quantization.marlin import MarlinConfig from vllm.model_executor.layers.quantization.marlin import MarlinConfig
from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig
QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = { QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
"aqlm": AQLMConfig, "aqlm": AQLMConfig,
"awq": AWQConfig, "awq": AWQConfig,
"deepspeedfp": DeepSpeedFPConfig,
"fp8": Fp8Config, "fp8": Fp8Config,
# The order of gptq methods is important for config.py iteration over
# override_quantization_method(..)
"marlin": MarlinConfig,
"gptq_marlin_24": GPTQMarlin24Config,
"gptq_marlin": GPTQMarlinConfig,
"gptq": GPTQConfig, "gptq": GPTQConfig,
"squeezellm": SqueezeLLMConfig, "squeezellm": SqueezeLLMConfig,
"gptq_marlin": GPTQMarlinConfig, "sparseml": CompressedTensorsConfig,
"marlin": MarlinConfig,
} }
......
...@@ -192,7 +192,7 @@ class AQLMConfig(QuantizationConfig): ...@@ -192,7 +192,7 @@ class AQLMConfig(QuantizationConfig):
@classmethod @classmethod
def get_min_capability(cls) -> int: def get_min_capability(cls) -> int:
return 70 return 60
@classmethod @classmethod
def get_config_filenames(cls) -> List[str]: def get_config_filenames(cls) -> List[str]:
......
...@@ -66,6 +66,17 @@ class QuantizationConfig(ABC): ...@@ -66,6 +66,17 @@ class QuantizationConfig(ABC):
"""Create a config class from the model's quantization config.""" """Create a config class from the model's quantization config."""
raise NotImplementedError raise NotImplementedError
@classmethod
def override_quantization_method(cls, hf_quant_cfg,
user_quant) -> Optional[str]:
"""
Detects if this quantization method can support a given checkpoint
format by overriding the user specified quantization method --
this method should only be overwritten by subclasses in exceptional
circumstances
"""
return None
@staticmethod @staticmethod
def get_from_keys(config: Dict[str, Any], keys: List[str]) -> Any: def get_from_keys(config: Dict[str, Any], keys: List[str]) -> Any:
"""Get a value from the model's quantization config.""" """Get a value from the model's quantization config."""
......
from typing import Any, Dict, List, Optional
import torch
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import ( # noqa: E501
QuantizationConfig)
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme, CompressedTensorsW8A8StaticTensor)
class CompressedTensorsConfig(QuantizationConfig):
def __init__(self, layer_quant_details: Dict[str, Any], ignore: List[str]):
self.ignore = ignore
self.layer_quant_details = layer_quant_details
def get_linear_method(self) -> "CompressedTensorsLinearMethod":
return CompressedTensorsLinearMethod(self)
def get_scaled_act_names(self) -> List[str]:
return []
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
return [torch.float16]
# Need to figure it out
def get_min_capability(self) -> int:
return 60
def get_name(self) -> str:
return "compressed_tensors"
def get_quant_method(
self, layer: torch.nn.Module
) -> Optional["CompressedTensorsLinearMethod"]:
if isinstance(layer, LinearBase):
return CompressedTensorsLinearMethod(self)
return None
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig":
layer_quant_details: Dict[str, Any] = dict()
ignore: List[str] = config.get("ignore", None)
for key, quant_config in config["config_groups"].items():
targets = quant_config.get("targets")
for target in targets:
layer_quant_details[target] = {}
layer_quant_details[target]["weight"] = quant_config.get(
"weights")
layer_quant_details[target]["input"] = quant_config.get(
"input_activations")
return cls(layer_quant_details=layer_quant_details, ignore=ignore)
@classmethod
def get_config_filenames(cls) -> List[str]:
return []
def _get_schema(self, weight_quant: Dict, input_quant: Dict):
# TODO: Refactor as additional cases are supported
weight_bit = weight_quant.get("num_bits")
input_bit = input_quant.get("num_bits")
weight_strategy = weight_quant.get("strategy")
input_strategy = input_quant.get("strategy")
weight_symmetric = weight_quant.get("symmetric")
input_symmetric = input_quant.get("symmetric")
is_8_bits = weight_bit == input_bit == 8
is_tensor = weight_strategy == input_strategy == "tensor"
is_symmetric = weight_symmetric and input_symmetric
if is_8_bits and is_tensor and is_symmetric and \
torch.cuda.is_available():
# CompressedTensorsW8A8StaticTensor only supports CUDA path for
# now.
return CompressedTensorsW8A8StaticTensor()
raise NotImplementedError(
"Scheme not supported. Only CUDA, 8-bit static symmtetric "
"per tensor quantization is currently supported")
def get_scheme(self, layer: torch.nn.Module) -> "CompressedTensorsScheme":
# TODO: update with matching function from `compressed_tensors`
layer_type_name = None
layer_name_class = type(layer).__name__.lower()
for target in self.layer_quant_details:
if target.lower() in layer_name_class:
layer_type_name = target
break
if layer_type_name is None:
raise ValueError(f"Could not matching target for layer {layer}")
layer_quant_details: Dict[str, Any] = self.layer_quant_details.get(
layer_type_name, None)
if layer_quant_details is None:
raise ValueError(
f"Could not find quantization details for {layer}.")
return self._get_schema(weight_quant=layer_quant_details["weight"],
input_quant=layer_quant_details["input"])
class CompressedTensorsLinearMethod(LinearMethodBase):
def __init__(self, quantization_config: CompressedTensorsConfig):
self.quantization_config = quantization_config
def create_weights(self, layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int], input_size: int,
output_size: int, params_dtype: torch.dtype,
**extra_weight_attrs):
"""
Use the CompressedTensorsScheme associated with each layer to create
the necessary parameters for the layer.
"""
weight_loader = extra_weight_attrs.get("weight_loader")
scheme = self.quantization_config.get_scheme(layer=layer)
scheme.create_weights(
layer=layer,
input_size_per_partition=input_size_per_partition,
output_partition_sizes=output_partition_sizes,
output_size=output_size,
params_dtype=params_dtype,
weight_loader=weight_loader)
layer.scheme = scheme
def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None):
"""
Use the output of create_weights and the CompressedTensorsScheme
associated with the layer to apply the forward pass with the
layer input.
"""
if bias is not None:
raise ValueError("bias is not supported for this linear method")
scheme = layer.scheme
if scheme is None:
raise ValueError("A scheme must be defined for each layer")
return scheme.apply_weights(layer, x)
from .compressed_tensors_scheme import CompressedTensorsScheme # noqa: F401
from .compressed_tensors_unquantized import ( # noqa: F401
CompressedTensorsUnquantized)
from .compressed_tensors_w8a8_statictensor import ( # noqa: F401, E501
CompressedTensorsW8A8StaticTensor)
from abc import ABC, abstractmethod
import torch
__all__ = ["CompressedTensorsScheme"]
class CompressedTensorsScheme(ABC):
"""
Abstract class used to describe the weight creation and forward pass
of different quantization schemes supported by CompressedTensors.
"""
@abstractmethod
def create_weights(self, *args, **kwargs):
"""
Weight creation for the particular scheme. Inputs to this function
"""
raise NotImplementedError
@abstractmethod
def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor):
"""
Run the forward pass for the particular scheme. This is where
scheme-specific dequant/quant steps/kernels should be applied.
:param layer: toch.nn.Module with the registered weights and
other parameters relevant to the particular scheme.
:param x: input to the layer
"""
raise NotImplementedError
from typing import Callable, List
import torch
import torch.nn.functional as F
from torch.nn import Parameter
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme)
from vllm.model_executor.utils import set_weight_attrs
__all__ = ["CompressedTensorsUnquantized"]
class CompressedTensorsUnquantized(CompressedTensorsScheme):
"""
Implements the scheme for all layers which are ignored
in the CompressedTensors config. The input and loaded weight are used
in a linear transformation.
"""
def create_weights(self, layer: torch.nn.Module,
output_partition_sizes: List[int],
input_size_per_partition: int,
params_dtype: torch.dtype, weight_loader: Callable,
**kwargs):
weight = Parameter(torch.empty(sum(output_partition_sizes),
input_size_per_partition,
device="cuda",
dtype=params_dtype),
requires_grad=False)
set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
layer.register_parameter("weight", weight)
set_weight_attrs(weight, {"weight_loader": weight_loader})
def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor):
weight = layer.weight
return F.linear(x, weight)
from typing import Callable, List, Tuple, Union
import torch
from torch.nn import Parameter
from vllm import _custom_ops as custom_ops
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme)
from vllm.model_executor.utils import set_weight_attrs
__all__ = ["CompressedTensorsW8A8StaticTensor"]
class CompressedTensorsW8A8StaticTensor(CompressedTensorsScheme):
def _shard_id_as_int(self, shard_id: Union[str, int]) -> int:
if isinstance(shard_id, int):
return shard_id
assert isinstance(shard_id, str)
qkv_idxs = {"q": 0, "k": 1, "v": 2}
assert shard_id in qkv_idxs
return qkv_idxs[shard_id]
def scales_shard_splitter(
self, param: torch.Tensor, loaded_weight: torch.Tensor,
shard_id: Union[str, int],
logical_widths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
shard_id = self._shard_id_as_int(shard_id)
offset = sum(logical_widths[:shard_id])
size = logical_widths[shard_id]
# update loaded weight with copies for broadcast.
loaded_weight = loaded_weight.repeat(size)
return param[offset:offset + size], loaded_weight
def create_weights(self, layer: torch.nn.Module,
output_partition_sizes: List[int],
input_size_per_partition: int,
params_dtype: torch.dtype, weight_loader: Callable,
**kwargs):
# TODO: remove zero_point parameters once the configs given remove them
# Note on input/weight scales and zero_points
#
# When the scales have a single value, it is required that they be
# on the CPU for 2 reasons,
# 1. Performance:
# When the scales (input_scale/weight_scales) have only a single
# value, we perform a scalar broadcast of that value during the
# quant/dequant operations. The "quant" and the "gemm+dequant"
# kernels accept the Scalar by-value. These tensors are allocated
# on the CPU in order to avoid the GPU-to-CPU copy when passing
# by-value.
#
# 2. CUDA Graphs:
# CUDA Graphs don't support GPU-to-CPU copy operations during
# stream capture.
#
# TODO: zero-points are not supported yet. But we expect a similar
# pattern.
is_tensor_partitioned = len(output_partition_sizes) != 1
weight_scale_dim = sum(
output_partition_sizes) if is_tensor_partitioned else 1
weight_scale_device = "cpu" if weight_scale_dim == 1 else "cuda"
input_scale = Parameter(torch.empty(1,
device="cpu",
dtype=torch.float32),
requires_grad=False)
input_zero_point = Parameter(torch.empty(1,
device="cpu",
dtype=torch.int8),
requires_grad=False)
weight_scale = Parameter(torch.empty(weight_scale_dim,
device=weight_scale_device,
dtype=torch.float32),
requires_grad=False)
weight_zero_point = Parameter(torch.empty(1,
device="cpu",
dtype=torch.int8),
requires_grad=False)
weight = Parameter(torch.empty(sum(output_partition_sizes),
input_size_per_partition,
dtype=torch.int8),
requires_grad=False)
layer.register_parameter("weight", weight)
set_weight_attrs(weight, {
"weight_loader": weight_loader,
"input_dim": 1,
"output_dim": 0,
})
layer.register_parameter("input_scale", input_scale)
set_weight_attrs(input_scale, {
"weight_loader": weight_loader,
"ignore_warning": True,
})
layer.register_parameter("input_zero_point", input_zero_point)
set_weight_attrs(input_zero_point, {
"weight_loader": weight_loader,
"ignore_warning": True,
})
layer.register_parameter("weight_scale", weight_scale)
set_weight_attrs(
weight_scale, {
"weight_loader": weight_loader,
"shard_splitter": self.scales_shard_splitter,
"logical_widths": output_partition_sizes,
"ignore_warning": True,
})
layer.register_parameter("weight_zero_point", weight_zero_point)
set_weight_attrs(weight_zero_point, {
"weight_loader": weight_loader,
"ignore_warning": True
})
def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor):
weight = layer.weight
weight_scale = layer.weight_scale
act_scale = layer.input_scale
# Input quantize
x_q = custom_ops.static_scaled_int8_quant(x, act_scale[0].item())
return custom_ops.cutlass_scaled_mm_dq(x_q, weight.t(), act_scale,
weight_scale, x.dtype)
from typing import Any, Dict, List, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.utils import set_weight_attrs
class DeepSpeedFPConfig(QuantizationConfig):
"""Config for DeepSpeed FP quantizer. It supports fp6 and fp8.
Args:
weight_bits: the target quantization bits, 6 or 8.
group_size: group size for quantizaiton, default to 128.
"""
def __init__(
self,
weight_bits: int = 8,
group_size: int = 512,
) -> None:
self.weight_bits = weight_bits
self.group_size = group_size
self.valid_types = [torch.bfloat16, torch.float16]
if self.weight_bits not in (6, 8):
raise ValueError(
"Currently, only 6-bit or 8-bit weight quantization are "
f"supported for DeepSpeed FP quantizaiton, but got "
f"{self.weight_bits} bits.")
def __repr__(self) -> str:
return (f"DeepSpeedFPConfig(weight_bits={self.weight_bits}), "
f"group_size={self.group_size}")
@classmethod
def get_name(cls) -> str:
return "DeepSpeedFP"
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "DeepSpeedFPConfig":
weight_bits = cls.get_from_keys(config, ["bits"])
group_size = cls.get_from_keys(config, ["group_size"])
return cls(weight_bits=weight_bits, group_size=group_size)
def get_linear_method(self) -> "DeepSpeedFPLinearMethod":
return DeepSpeedFPLinearMethod(self)
def get_scaled_act_names(self) -> List[str]:
return []
@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
return [torch.half, torch.bfloat16]
@classmethod
# Need to figure it out
def get_min_capability(cls) -> int:
return 60
@staticmethod
def get_config_filenames() -> List[str]:
return [
"quant_config.json",
"quantize_config.json",
]
def get_quant_method(
self,
layer: torch.nn.Module) -> Optional["DeepSpeedFPLinearMethod"]:
if isinstance(layer, LinearBase):
return DeepSpeedFPLinearMethod(self)
return None
class DeepSpeedFPLinearMethod(LinearMethodBase):
"""Linear method for DeepSpeedFP quantizer.
Args:
quant_config: the DeepSpeedFP quantization config.
"""
def __init__(self, quant_config: DeepSpeedFPConfig):
self.quant_config = quant_config
self.weight = None
def create_weights(self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
weight_loader=None,
**extra_weight_attrs):
del output_size
del input_size
output_size_per_partition = sum(output_partition_sizes)
weight = DeepSpeedFPParameter(
torch.Size((output_size_per_partition, input_size_per_partition)),
params_dtype=params_dtype,
quant_config=self.quant_config,
)
set_weight_attrs(weight, {
"input_dim": 1,
"output_dim": 0,
})
layer.register_parameter("weight", weight)
def quant_weight_loader(param, loaded_weight, *args, **kwargs):
# Calls the original weight loader (if any), quantizes the result,
# and then loads the quantized parameter.
if weight_loader is not None:
orig_param_data = param.data
param.data = param.ds_dequantize()
weight_loader(param, loaded_weight, *args, **kwargs)
param.data, loaded_weight = orig_param_data, param.data
param.ds_quantize_(loaded_weight.cuda())
extra_weight_attrs["weight_loader"] = quant_weight_loader
set_weight_attrs(weight, extra_weight_attrs)
def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
weight = layer.weight
y = weight.ds_dequantize()
return F.linear(x, y, bias)
class DeepSpeedFPParameter(nn.Parameter):
"""
DeepSpeedFP quantized parameter class that implements fp8/fp6
quantization deepspeed. Weights are stored in quantized form on
GPUs, and can be dequantized on-the-fly when needed by the model.
"""
def __new__(cls, orig_shape: torch.Size, params_dtype: torch.dtype,
quant_config: DeepSpeedFPConfig):
try:
import deepspeed
if deepspeed.__version__ < "0.14.2":
raise ImportError("deepspeed version is wrong. Please "
"install deepspeed>=0.14.2.")
from deepspeed.ops.fp_quantizer import FP_Quantize
except ImportError as err:
raise ImportError("Please install deepspeed>=0.14.2 via "
"`pip install deepspeed>=0.14.2` to use "
"deepspeedfp quantizer.") from err
data = torch.empty((
orig_shape.numel() // quant_config.group_size,
quant_config.group_size * quant_config.weight_bits // 8 + 4,
),
dtype=torch.int8)
self = torch.Tensor._make_subclass(cls, data, data.requires_grad)
self.orig_shape = orig_shape
self.quant_config = quant_config
self.fp_quantizer = FP_Quantize(group_size=quant_config.group_size)
self.fp_quantizer.orig_shape = orig_shape
self.fp_quantizer.orig_dtype = params_dtype
return self
def ds_quantize_(self, tensor: torch.Tensor):
assert tensor.device.type == "cuda" and tensor.dtype != torch.int8
return self.data.copy_(
self.fp_quantizer.quantize(
tensor.data,
q_bits=self.quant_config.weight_bits,
))
def ds_dequantize(self, fp_out=None) -> torch.Tensor:
"""
Return a tensor containing the dequantized weights of this parameter.
"""
assert self.data.device.type == "cuda" and self.data.dtype == torch.int8
return self.fp_quantizer.dequantize(
self.data, fp_out=fp_out, q_bits=self.quant_config.weight_bits)
def ds_selective_dequantize(self, indices, fp_out=None) -> torch.Tensor:
"""
Return a tensor where only the weights at `indices` are dequantized
(to save HBM -> SRAM bandwidth).
"""
assert self.data.device.type == "cuda" and self.data.dtype == torch.int8
return self.fp_quantizer.selective_dequantize(
self.data,
indices,
fp_out=fp_out,
q_bits=self.quant_config.weight_bits)
...@@ -8,8 +8,9 @@ from vllm import _custom_ops as ops ...@@ -8,8 +8,9 @@ from vllm import _custom_ops as ops
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.utils import print_warning_once
ACTIVATION_SCHEMES = ["static", "dynamic"] ACTIVATION_SCHEMES = ["static", "dynamic"]
...@@ -58,9 +59,13 @@ class Fp8Config(QuantizationConfig): ...@@ -58,9 +59,13 @@ class Fp8Config(QuantizationConfig):
activation_scheme=activation_scheme) activation_scheme=activation_scheme)
def get_quant_method( def get_quant_method(
self, layer: torch.nn.Module) -> Optional["Fp8LinearMethod"]: self, layer: torch.nn.Module) -> Optional["QuantizeMethodBase"]:
from vllm.attention.layer import Attention # Avoid circular import
if isinstance(layer, LinearBase): if isinstance(layer, LinearBase):
return Fp8LinearMethod(self) return Fp8LinearMethod(self)
if isinstance(layer, Attention):
return Fp8KVCacheMethod(self)
return None return None
def get_scaled_act_names(self) -> List[str]: def get_scaled_act_names(self) -> List[str]:
...@@ -231,9 +236,14 @@ class Fp8LinearMethod(LinearMethodBase): ...@@ -231,9 +236,14 @@ class Fp8LinearMethod(LinearMethodBase):
# ops.scaled_fp8_quant supports both dynamic and static quant. # ops.scaled_fp8_quant supports both dynamic and static quant.
# If dynamic, layer.act_scale is None and x_scale computed from x. # If dynamic, layer.act_scale is None and x_scale computed from x.
# If static, layer.act_scale is scalar and x_scale set to act_scale. # If static, layer.act_scale is scalar and x_scale set to act_scale.
qinput, x_scale = ops.scaled_fp8_quant(x, layer.act_scale) qinput, x_scale = ops.scaled_fp8_quant(x,
layer.act_scale,
# Fused GEMM_DQ batch_dim_padding=17)
# Fused GEMM_DQ -- note we padded the input above because
# torch._scaled_mm is more performant for matrices with
# batch dimension > 16. Note that this could change
# in the future.
output, _ = torch._scaled_mm( output, _ = torch._scaled_mm(
qinput, qinput,
layer.weight, layer.weight,
...@@ -243,7 +253,45 @@ class Fp8LinearMethod(LinearMethodBase): ...@@ -243,7 +253,45 @@ class Fp8LinearMethod(LinearMethodBase):
bias=bias, bias=bias,
) )
return output return torch.narrow(output, 0, 0, x.shape[0])
class Fp8KVCacheMethod(QuantizeMethodBase):
"""Supports loading kv-cache scaling factors from FP8 checkpoints.
"""
def __init__(self, quant_config: Fp8Config):
self.quant_config = quant_config
def create_weights(self, layer: torch.nn.Module):
"""Create "weight" (aka kv_scale) for an attention layer.
Args:
layer: The layer that is using the QuantizeMethodBase factory.
"""
# Initialize the KV cache scale to 1.0 as the default value.
# If the kv_scale appears in the checkpoint, it will be
# overwritten when loading weights.
layer.kv_scale = Parameter(torch.tensor(1.0), requires_grad=False)
def apply(self, layer: torch.nn.Module) -> torch.Tensor:
raise RuntimeError("Fp8KVCacheMethod.apply should not be called.")
def process_weights_after_loading(self, layer: Module) -> None:
# If the kv-cache dtype is auto, we enforce the kv-scale to be 1.0
# regardless whether the kv-scale is available in the checkpoint.
if layer.kv_cache_dtype != "auto":
kv_scale = layer.kv_scale.to("cpu").tolist()
if not isinstance(kv_scale, float):
raise ValueError("Only support per-tensor scaling factor "
"for fp8 KV cache")
layer._kv_scale = kv_scale
if layer._kv_scale == 1.0 and "e5m2" not in layer.kv_cache_dtype:
print_warning_once(
"Using KV cache scaling factor 1.0 for fp8_e4m3. This may "
"cause accuracy issues. Please make sure kv-cache scaling "
"factor is available in the fp8 checkpoint.")
del layer.kv_scale
def all_close_1d(x: torch.Tensor) -> bool: def all_close_1d(x: torch.Tensor) -> bool:
......
...@@ -6,11 +6,14 @@ import torch ...@@ -6,11 +6,14 @@ import torch
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
set_weight_attrs) set_weight_attrs)
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
logger = init_logger(__name__)
GPTQ_MARLIN_TILE = 16 GPTQ_MARLIN_TILE = 16
GPTQ_MARLIN_MIN_THREAD_N = 64 GPTQ_MARLIN_MIN_THREAD_N = 64
GPTQ_MARLIN_MIN_THREAD_K = 128 GPTQ_MARLIN_MIN_THREAD_K = 128
...@@ -99,7 +102,7 @@ class GPTQMarlinConfig(QuantizationConfig): ...@@ -99,7 +102,7 @@ class GPTQMarlinConfig(QuantizationConfig):
@classmethod @classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]: def get_supported_act_dtypes(cls) -> List[torch.dtype]:
return [torch.half] return [torch.half, torch.bfloat16]
@classmethod @classmethod
def get_min_capability(cls) -> int: def get_min_capability(cls) -> int:
...@@ -117,6 +120,26 @@ class GPTQMarlinConfig(QuantizationConfig): ...@@ -117,6 +120,26 @@ class GPTQMarlinConfig(QuantizationConfig):
is_sym = cls.get_from_keys(config, ["sym"]) is_sym = cls.get_from_keys(config, ["sym"])
return cls(weight_bits, group_size, desc_act, is_sym) return cls(weight_bits, group_size, desc_act, is_sym)
@classmethod
def override_quantization_method(cls, hf_quant_cfg,
user_quant) -> Optional[str]:
can_convert = cls.is_marlin_compatible(hf_quant_cfg)
is_valid_user_quant = (user_quant is None or user_quant == "marlin")
if can_convert and is_valid_user_quant:
msg = ("The model is convertible to {} during runtime."
" Using {} kernel.".format(cls.get_name(), cls.get_name()))
logger.info(msg)
return cls.get_name()
if can_convert and user_quant == "gptq":
logger.info("Detected that the model can run with gptq_marlin"
", however you specified quantization=gptq explicitly,"
" so forcing gptq. Use quantization=gptq_marlin for"
" faster inference")
return None
def get_quant_method( def get_quant_method(
self, self,
layer: torch.nn.Module) -> Optional["GPTQMarlinLinearMethod"]: layer: torch.nn.Module) -> Optional["GPTQMarlinLinearMethod"]:
...@@ -186,9 +209,9 @@ class GPTQMarlinLinearMethod(LinearMethodBase): ...@@ -186,9 +209,9 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
group_size = input_size group_size = input_size
# Validate dtype # Validate dtype
if params_dtype != torch.float16: if params_dtype not in [torch.float16, torch.bfloat16]:
raise ValueError( raise ValueError(f"The params dtype must be float16 "
f"The params dtype must be float16, but got {params_dtype}") f"or bfloat16, but got {params_dtype}")
# Validate output_size_per_partition # Validate output_size_per_partition
output_size_per_partition = sum(output_partition_sizes) output_size_per_partition = sum(output_partition_sizes)
...@@ -275,14 +298,10 @@ class GPTQMarlinLinearMethod(LinearMethodBase): ...@@ -275,14 +298,10 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
}, },
) )
g_idx_sort_indices = Parameter( g_idx_sort_indices = torch.empty(
torch.empty( g_idx.shape,
g_idx.shape, dtype=torch.int32,
dtype=torch.int32,
),
requires_grad=False,
) )
set_weight_attrs(g_idx_sort_indices, extra_weight_attrs)
# Scales # Scales
scales = Parameter( scales = Parameter(
...@@ -333,9 +352,9 @@ class GPTQMarlinLinearMethod(LinearMethodBase): ...@@ -333,9 +352,9 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
layer.register_parameter("qweight", qweight) layer.register_parameter("qweight", qweight)
layer.register_parameter("g_idx", g_idx) layer.register_parameter("g_idx", g_idx)
layer.register_parameter("g_idx_sort_indices", g_idx_sort_indices)
layer.register_parameter("scales", scales) layer.register_parameter("scales", scales)
layer.register_parameter("qzeros", qzeros) layer.register_parameter("qzeros", qzeros)
layer.g_idx_sort_indices = g_idx_sort_indices
layer.workspace = workspace layer.workspace = workspace
layer.input_size_per_partition = input_size_per_partition layer.input_size_per_partition = input_size_per_partition
layer.output_size_per_partition = output_size_per_partition layer.output_size_per_partition = output_size_per_partition
......
from typing import Any, Dict, List, Optional
import torch
from torch.nn.parameter import Parameter
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.utils import set_weight_attrs
logger = init_logger(__name__)
GPTQ_MARLIN_24_TILE = 16
GPTQ_MARLIN_24_MIN_THREAD_N = 128
GPTQ_MARLIN_24_MIN_THREAD_K = 128
GPTQ_MARLIN_24_MAX_PARALLEL = 64
GPTQ_MARLIN_24_SUPPORTED_NUM_BITS = [4, 8]
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128]
GPTQ_MARLIN_24_SUPPORTED_SYM = [True]
class GPTQMarlin24Config(QuantizationConfig):
"""Config class for Marlin24.
"""
def __init__(
self,
weight_bits: int,
group_size: int,
) -> None:
self.weight_bits = weight_bits
self.group_size = group_size
# Verify
if self.weight_bits not in GPTQ_MARLIN_24_SUPPORTED_NUM_BITS:
raise ValueError(
f"Marlin_24 does not support weight_bits = {self.weight_bits}. "
f"Only weight_bits = {GPTQ_MARLIN_24_SUPPORTED_NUM_BITS} "
"are supported.")
if self.group_size not in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES:
raise ValueError(
f"Marlin_24 does not support group_size = {self.group_size}. "
f"Only group_sizes = {GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES} "
"are supported.")
# 4 Bits packed into 32 bit datatype.
self.pack_factor = 32 // self.weight_bits
# Tile size used by marlin kernels.
self.tile_size = 16
# Min out_features dim
self.min_n_threads = GPTQ_MARLIN_24_MIN_THREAD_N
# Min in_features dim
self.min_k_threads = GPTQ_MARLIN_24_MIN_THREAD_K
# Max parallel problems to solve at once (improves large
# batch performance)
self.max_parallel = GPTQ_MARLIN_24_MAX_PARALLEL
# Permutation length used by the marlin kernels.
self.perm_len = 1024
def __repr__(self) -> str:
return "Marlin24Config(weight_bits={}, group_size={})".format(
self.weight_bits, self.group_size)
@classmethod
def get_name(cls) -> str:
return "gptq_marlin_24"
@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
return [torch.half]
@classmethod
# Need to figure it out
def get_min_capability(cls) -> int:
return 80
@classmethod
def get_config_filenames(cls) -> List[str]:
return ["quantize_config.json"]
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "GPTQMarlin24Config":
weight_bits = cls.get_from_keys(config, ["bits"])
group_size = cls.get_from_keys(config, ["group_size"])
return cls(weight_bits, group_size)
@classmethod
def override_quantization_method(cls, hf_quant_cfg,
user_quant) -> Optional[str]:
is_marlin_24_format = (
hf_quant_cfg.get("checkpoint_format") == "marlin_24")
is_valid_user_quant = (user_quant is None or user_quant == "gptq"
or user_quant == "gptq_marlin_24")
if is_marlin_24_format and is_valid_user_quant:
msg = ("The model is serialized in {} format. "
"Using {} kernel.".format(cls.get_name(), cls.get_name()))
logger.info(msg)
return cls.get_name()
return None
def get_quant_method(
self,
layer: torch.nn.Module) -> Optional["GPTQMarlin24LinearMethod"]:
if isinstance(layer, LinearBase):
return GPTQMarlin24LinearMethod(self)
return None
def get_scaled_act_names(self) -> List[str]:
return []
class GPTQMarlin24LinearMethod(LinearMethodBase):
"""Linear method for Marlin24.
Args:
quant_config: The Marlin24 quantization config.
"""
def __init__(self, quant_config: GPTQMarlin24Config):
self.quant_config = quant_config
def create_weights(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
del output_size # Unused.
if params_dtype != torch.float16:
raise ValueError(
f"The params dtype must be float16, but got {params_dtype}")
# Validate output_size_per_partition
output_size_per_partition = sum(output_partition_sizes)
if output_size_per_partition % self.quant_config.min_n_threads != 0:
raise ValueError(
f"Weight output_size_per_partition = "
f"{output_size_per_partition} is not divisible by "
f"min_n_threads = {self.quant_config.min_n_threads}.")
if output_size_per_partition % self.quant_config.pack_factor != 0:
raise ValueError(
f"Weight output_size_per_partition = "
f"{output_size_per_partition} is not divisible by "
f"pack_factor = {self.quant_config.pack_factor}.")
# Validate input_size_per_partition
if input_size_per_partition % self.quant_config.min_k_threads != 0:
raise ValueError(
f"Weight input_size_per_partition = "
f"{input_size_per_partition} is not divisible by "
f"min_k_threads = {self.quant_config.min_k_threads}.")
if (self.quant_config.group_size != -1 and
input_size_per_partition % self.quant_config.group_size != 0):
raise ValueError(f"Weight input_size_per_partition = "
f"{input_size_per_partition} is not divisible by "
f"group_size = {self.quant_config.group_size}.")
# Check that we have at least 4 tiles horizontally in the shard
num_tiles_per_perm = self.quant_config.perm_len // (
self.quant_config.tile_size**2)
if output_size_per_partition % num_tiles_per_perm != 0:
raise ValueError(
"Each permutation group must reside on the same gpu")
# Quantized 4Bit weights packed into Int32.
qweight = Parameter(
torch.empty(
input_size_per_partition // self.quant_config.tile_size // 2,
output_size_per_partition * self.quant_config.tile_size //
self.quant_config.pack_factor,
device="cuda",
dtype=torch.int32,
),
requires_grad=False,
)
set_weight_attrs(
qweight,
{
"input_dim": 0,
"output_dim": 1,
"packed_dim": 1,
"pack_factor": self.quant_config.pack_factor,
"marlin_tile_size": self.quant_config.tile_size,
},
)
# Meta
meta = Parameter(
torch.empty(
input_size_per_partition // 8 // 2 // 2,
output_size_per_partition * 2,
device="cuda",
dtype=torch.int16,
),
requires_grad=False,
)
set_weight_attrs(
meta,
{
"input_dim": 0,
"packed_dim": 1,
"pack_factor": 1,
"output_dim": 1,
"marlin_tile_size": 2,
},
)
# Determine if channelwise or not
input_groups = (1 if self.quant_config.group_size == -1 else
input_size_per_partition //
self.quant_config.group_size)
scales = Parameter(
torch.empty(
input_groups,
output_size_per_partition,
device="cuda",
dtype=params_dtype,
),
requires_grad=False,
)
set_weight_attrs(
scales,
{
"input_dim": None if input_groups == 1 else 0,
"output_dim": 1,
},
)
# Allocate workspace (Used for internal locking mechanism)
max_workspace_size = (
output_size_per_partition //
self.quant_config.min_n_threads) * self.quant_config.max_parallel
workspace = Parameter(torch.zeros(max_workspace_size,
device="cuda",
dtype=torch.int),
requires_grad=False)
layer.register_parameter("B_24", qweight)
set_weight_attrs(qweight, extra_weight_attrs)
layer.register_parameter("B_meta", meta)
set_weight_attrs(meta, extra_weight_attrs)
layer.register_parameter("s", scales)
set_weight_attrs(scales, extra_weight_attrs)
layer.register_parameter("workspace", workspace)
set_weight_attrs(workspace, extra_weight_attrs)
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
qweight = layer.B_24
meta = layer.B_meta
scales = layer.s
workspace = layer.workspace
x_2d = x.view(-1, x.shape[-1])
size_m = x_2d.shape[0]
size_k = x_2d.shape[1]
size_n = scales.shape[1]
output_2d = ops.gptq_marlin_24_gemm(x_2d, qweight, meta, scales,
workspace,
self.quant_config.weight_bits,
size_m, size_n, size_k)
output = output_2d.view(x.shape[:-1] + (output_2d.shape[1], ))
if bias is not None:
output.add_(bias) # In-place add
return output
...@@ -4,11 +4,14 @@ import torch ...@@ -4,11 +4,14 @@ import torch
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
logger = init_logger(__name__)
class MarlinConfig(QuantizationConfig): class MarlinConfig(QuantizationConfig):
"""Config class for Marlin. """Config class for Marlin.
...@@ -72,6 +75,25 @@ class MarlinConfig(QuantizationConfig): ...@@ -72,6 +75,25 @@ class MarlinConfig(QuantizationConfig):
group_size = cls.get_from_keys(config, ["group_size"]) group_size = cls.get_from_keys(config, ["group_size"])
return cls(group_size) return cls(group_size)
@classmethod
def override_quantization_method(cls, hf_quant_cfg,
user_quant) -> Optional[str]:
# compat: autogptq >=0.8.0 use checkpoint_format: str
# compat: autogptq <=0.7.1 is_marlin_format: bool
is_marlin_format = (hf_quant_cfg.get("checkpoint_format") == "marlin"
or hf_quant_cfg.get("is_marlin_format", False))
is_valid_user_quant = (user_quant is None or user_quant == "gptq"
or user_quant == "marlin")
if is_marlin_format and is_valid_user_quant:
msg = ("The model is serialized in {} format. Using {} kernel.".
format(cls.get_name(), cls.get_name()))
logger.info(msg)
return cls.get_name()
return None
def get_quant_method( def get_quant_method(
self, layer: torch.nn.Module) -> Optional["MarlinLinearMethod"]: self, layer: torch.nn.Module) -> Optional["MarlinLinearMethod"]:
if isinstance(layer, LinearBase): if isinstance(layer, LinearBase):
......
#
# Modified by Roberto Lopez Castro (roberto.lopez.castro@udc.es).
#
import torch
# This is PyTorch implementation of main part of reorder_meta()
# function, from tools/util/include/cutlass/util/host_reorder.h file
# of CUTLASS source tree. Furthermore, CUTLASS template for sparse
# GEMM decides upon layout of this matrix, and at the moment for the
# sparse GEMM executed on tensor cores, this is layout described by
# ColumnMajorInterleaved<2> data structure, in
# include/cutlass/layout/matrix.h of CUTLASS source tree. The
# reordering of meta matrix into meta_reordered matrix calculated
# according to these segments of CUTLASS code is re-implemented here.
# Note that this calculation produces offsets for scattering metadata
# matrix elements into reordered metadata matrix elements (or,
# equivalently, for gathering reordered metadata matrix element back
# into metadata matrix elements).
def _calculate_meta_reordering_scatter_offsets(m, meta_ncols, meta_dtype,
device):
dst_rows = torch.arange(0, m, device=device)[:, None].repeat(1, meta_ncols)
dst_cols = torch.arange(0, meta_ncols, device=device).repeat(m, 1)
# Reorder the rows, then swizzle the 2x2 blocks.
group_x = 64
group_y = 32 if meta_dtype.itemsize == 2 else 16
dst_rows = (dst_rows // group_x * group_x + (dst_rows % 2) * 2 +
(dst_rows % 8) // 4 + ((dst_rows % group_y) % 4) // 2 * 32 +
((dst_rows % group_x) // 8) * 4)
topright = ((dst_rows % 2 == 0) & (dst_cols % 2 == 1)).to(torch.int8)
bottomleft = ((dst_rows % 2 == 1) & (dst_cols % 2 == 0)).to(torch.int8)
dst_rows += topright - bottomleft
dst_cols -= topright - bottomleft
# Assumed that meta tensor is to be stored in CUTLASS
# InterleavedColumnMajor layout, and reverse engineered
# corresponding code to store values into this tensor.
interleave = 2
cols_maj = dst_cols // interleave
cols_min = dst_cols % interleave
return (cols_maj * m * interleave + dst_rows * interleave +
cols_min).view(-1)
# This function converts dense matrix into sparse semi-structured
# representation, producing "compressed" matrix, in the layout used by
# CUTLASS backend, and corresponding metadata matrix.
def sparse_semi_structured_from_dense_cutlass(dense):
if dense.dim() != 2:
raise RuntimeError(
f"Expected 2-dimensional dense tensor, got {dense.dim()}-dimensional tensor" # noqa: E501
)
m, k = dense.shape
device = dense.device
meta_dtype = torch.int8
if dense.dtype == torch.int8:
meta_dtype = torch.int32
elif dense.dtype in [torch.half, torch.bfloat16, torch.float, torch.int32]:
meta_dtype = torch.int16
else:
raise RuntimeError(f"Invalid datatype {dense.dtype} of dense matrix")
quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4
if quadbits_per_meta_elem not in (4, 8):
raise RuntimeError(
"Invalid number of elements per meta element calculated")
if meta_dtype == torch.int32:
if m % 16 != 0:
raise RuntimeError(
f"Number of rows of dense matrix {m} must be divisible by 16")
else:
if m % 32 != 0:
raise RuntimeError(
f"Number of rows of dense matrix {m} must be divisible by 32")
if k % (4 * quadbits_per_meta_elem) != 0:
raise RuntimeError(
f"Number of columns of dense matrix {k} must be divisible by {4 * quadbits_per_meta_elem}" # noqa: E501
)
if dense.dtype != torch.float:
ksparse = 4
dense_4 = dense.view(-1, k // ksparse, ksparse)
m0, m1, m2, m3 = (dense_4 != 0).unbind(-1)
else:
ksparse = 2
dense_2 = dense.view(-1, k // ksparse, ksparse)
m0, m2 = m1, m3 = (dense_2 != 0).unbind(-1)
meta_ncols = k // (ksparse * quadbits_per_meta_elem)
# Encoding quadruples of True/False values as follows:
# [True, True, False, False] -> 0b0100
# [True, False, True, False] -> 0b1000
# [False, True, True, False] -> 0b1001
# [True, False, False, True ] -> 0b1100
# [False, True, False, True ] -> 0b1101
# [False, False, True, True ] -> 0b1110
# Thus, lower two bits in the encoding are index of the True value
# at the lowest index in the quadruple, and the higher two bits in
# the encoding are index of the other True value in the quadruple.
# In case there are less than two True values, than False value or
# values at some index or indices are considered True for the
# encoding. In case there are more than two True values, then the
# excess True value(s) at some indices are considered False for
# the encoding. The exact encodings used for these cases are as
# follows:
# [False, False, False, False] -> 0b1110
# [False, False, False, True ] -> 0b1110
# [False, False, True, False] -> 0b1110
# [False, True, False, False] -> 0b1001
# [False, True, True, True ] -> 0b1101
# [True, False, False, False] -> 0b1000
# [True, False, True, True ] -> 0b1100
# [True, True, False, True ] -> 0b0100
# [True, True, True, False] -> 0b0100
# [True, True, True, True ] -> 0b0100
# These particular encodings are chosen, with the help of Espresso
# logic minimizer software, for the purpose of minimization of
# corresponding Boolean functions, that translate non-zero flags
# into encoding bits. Note also possible choices for the first
# and last of these encodings were limited only to (0b0100,
# 0b1110), in order to produce valid encodings for 1:2 sparsity
# case.
expr0 = m0 & m1
expr1 = ~m0 & m1
expr2 = ~m0 & ~m1
bit0 = expr1
bit1 = expr2
bit2 = expr0 | expr2 | m3
bit3 = expr1 | ~m1
idxs0 = bit0 | (bit1.to(torch.int64) << 1)
idxs1 = bit2 | (bit3.to(torch.int64) << 1)
if dense.dtype != torch.float:
sparse0 = dense_4.gather(
-1, idxs0.unsqueeze(-1)) # type: ignore[possibly-undefined]
sparse1 = dense_4.gather(-1, idxs1.unsqueeze(-1))
sparse = torch.stack((sparse0, sparse1), dim=-1).view(m, k // 2)
else:
sparse = dense_2.gather(-1,
idxs0.unsqueeze(-1) // 2).view(
m,
k // 2) # type: ignore[possibly-undefined]
meta_4 = idxs0 | (idxs1 << 2)
meta_n = meta_4.view(
(-1, meta_ncols, quadbits_per_meta_elem)).to(meta_dtype)
if quadbits_per_meta_elem == 4:
meta = (meta_n[:, :, 0]
| (meta_n[:, :, 1] << 4)
| (meta_n[:, :, 2] << 8)
| (meta_n[:, :, 3] << 12))
elif quadbits_per_meta_elem == 8:
meta = (meta_n[:, :, 0]
| (meta_n[:, :, 1] << 4)
| (meta_n[:, :, 2] << 8)
| (meta_n[:, :, 3] << 12)
| (meta_n[:, :, 4] << 16)
| (meta_n[:, :, 5] << 20)
| (meta_n[:, :, 6] << 24)
| (meta_n[:, :, 7] << 28))
# Reorder meta tensor elements.
meta_reordered = meta.new_empty(
(m * meta_ncols, )) # type: ignore[possibly-undefined]
meta_offsets = _calculate_meta_reordering_scatter_offsets(
m, meta_ncols, meta_dtype, device)
meta_reordered.scatter_(0, meta_offsets, meta.view(-1))
return (sparse, meta_reordered.view(m, meta_ncols))
# This function performs reverse of the function above - it
# reconstructs dense matrix from a pair of "compressed" matrix, given
# in the layout used by CUTLASS backend, and accompanying metadata
# matrix.
def sparse_semi_structured_to_dense_cutlass(sparse, meta_reordered):
if sparse.dim() != 2:
raise RuntimeError(
f"Expected 2-dimensional sparse tensor, got {sparse.dim()}-dimensional tensor" # noqa: E501
)
m, k = sparse.shape
device = sparse.device
if meta_reordered.dim() != 2:
raise RuntimeError(
f"Expected 2-dimensional meta tensor, got {meta_reordered.dim()}-dimensional tensor" # noqa: E501
)
if meta_reordered.device != device:
raise RuntimeError(
f"Expected meta matrix to be on {device} device, got matrix on {meta_reordered.device} device" # noqa: E501
)
meta_dtype = meta_reordered.dtype
if meta_dtype not in (torch.int16, torch.int32):
raise RuntimeError(f"Invalid datatype {meta_dtype} of meta matrix")
quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4
ksparse = 4 if sparse.dtype != torch.float else 2
meta_nrows, meta_ncols = meta_reordered.shape
if meta_nrows != m:
raise RuntimeError(
f"Number of rows of meta matrix {meta_nrows} must be equal to number of columns of spase matrix {m}" # noqa: E501
)
if meta_ncols * ksparse * quadbits_per_meta_elem != 2 * k:
raise RuntimeError(
f"Number of columns of sparse matrix {k} different from the {meta_ncols * ksparse * quadbits_per_meta_elem // 2}, " # noqa: E501
"expected according to the number of columns of meta matrix")
# Undo meta tensor elements reordering.
meta_offsets = _calculate_meta_reordering_scatter_offsets(
m, meta_ncols, meta_dtype, device)
meta = torch.gather(meta_reordered.view(-1), 0,
meta_offsets).view(m, meta_ncols)
# Unpack sparse tensor back to original dense tensor, using
# information provided by meta tensor. Note that torch.float
# datatype is handled pretty much the same as
# torch.half/torch.bfloat16, as metadata for a pair of torch.float
# value is encoded as if underlying 8 bytes contain four
# torch.half/torch.bfloat16 values, where either first two or last
# two are zeros.
meta_2 = torch.empty(
(m, meta_ncols, 2 * quadbits_per_meta_elem),
dtype=meta_dtype,
device=device,
)
if quadbits_per_meta_elem == 4:
meta_2[:, :, 0] = meta & 0b11
meta_2[:, :, 1] = (meta >> 2) & 0b11
meta_2[:, :, 2] = (meta >> 4) & 0b11
meta_2[:, :, 3] = (meta >> 6) & 0b11
meta_2[:, :, 4] = (meta >> 8) & 0b11
meta_2[:, :, 5] = (meta >> 10) & 0b11
meta_2[:, :, 6] = (meta >> 12) & 0b11
meta_2[:, :, 7] = (meta >> 14) & 0b11
elif quadbits_per_meta_elem == 8:
meta_2[:, :, 0] = meta & 0b11
meta_2[:, :, 1] = (meta >> 2) & 0b11
meta_2[:, :, 2] = (meta >> 4) & 0b11
meta_2[:, :, 3] = (meta >> 6) & 0b11
meta_2[:, :, 4] = (meta >> 8) & 0b11
meta_2[:, :, 5] = (meta >> 10) & 0b11
meta_2[:, :, 6] = (meta >> 12) & 0b11
meta_2[:, :, 7] = (meta >> 14) & 0b11
meta_2[:, :, 8] = (meta >> 16) & 0b11
meta_2[:, :, 9] = (meta >> 18) & 0b11
meta_2[:, :, 10] = (meta >> 20) & 0b11
meta_2[:, :, 11] = (meta >> 22) & 0b11
meta_2[:, :, 12] = (meta >> 24) & 0b11
meta_2[:, :, 13] = (meta >> 26) & 0b11
meta_2[:, :, 14] = (meta >> 28) & 0b11
meta_2[:, :, 15] = (meta >> 30) & 0b11
dense_offsets = meta_2.view(-1) + (
torch.arange(0, 2 * m * k // ksparse, device=device) * 4).view(
-1, 1).repeat(1, 2).view(-1)
dense = torch.zeros((m * 2 * k, ), dtype=sparse.dtype, device=device)
if sparse.dtype != torch.float:
# dense.scatter_(0, dense_offsets, sparse.view(-1))
dense.scatter_(0, dense_offsets, sparse.reshape(-1))
else:
dense.view(torch.half).scatter_(0, dense_offsets,
sparse.view(torch.half).view(-1))
return dense.view(m, 2 * k)
def mask_creator(tensor):
"""
Class for creating N:M sparsity masks.
Masks will be created using the N:M ratio, where for every block of
M weights, N will be pruned based on ranked weight value. Each mask
will correspond to the given tensor.
:param N: The number of weights in a group to keep
:param M: The size of a weight group
"""
N = 2
M = 4
mask = None
# for i, tensor in enumerate(tensors):
if tensor.numel() % M != 0:
raise ValueError(
f"Tensor of size {tensor.shape} can't be evenly divided into "
f"{M} groups")
num_groups = tensor.numel() // M
# N:M sparsity for linear layers
tensor_temp = tensor.detach().abs().reshape(num_groups, M)
index = torch.argsort(tensor_temp, dim=1)[:, :int(M - N)]
w_b = torch.ones(tensor_temp.shape, device=tensor_temp.device)
mask = w_b.scatter_(dim=1, index=index, value=0).reshape(tensor.shape)
return mask
"""This file is used for /tests and /benchmarks"""
import numpy
import torch
# Precompute permutations for Marlin24 weight and scale shuffling # noqa: E501
#
# Marlin works on [16*2,64] tiles. The goal of the permutations is to reorder the weight data so that it is compatible noqa: # noqa: E501
# with the tensor-core format that is described here:
# https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type # noqa: E501
#
# As a result of this reordering, the vector loads inside the kernel will get the data as it is needed for tensor-core # noqa: E501
# (without the need to use ldmatrix instructions) # noqa: E501
def get_perms_24(num_bits):
perm_list = []
for i in range(32):
perm1 = []
col = i // 4
col_o = col // 2
for block in [0, 1]:
for row in [
2 * (i % 4),
2 * (i % 4) + 1,
2 * (i % 4 + 4),
2 * (i % 4 + 4) + 1,
]:
perm1.append(16 * row + col_o * 256 + 8 * (col % 2) +
4 * block)
for j in range(4):
perm_list.extend([p + 1 * j for p in perm1])
perm = numpy.array(perm_list)
if num_bits == 4:
interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
elif num_bits == 8:
interleave = numpy.array([0, 2, 1, 3])
else:
raise ValueError("num_bits must be 4 or 8, got {}".format(num_bits))
perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel()
perm = torch.from_numpy(perm)
scale_perm = []
for i in range(8):
scale_perm.extend([i * 8 + j for j in [0, 4, 1, 5, 2, 6, 3, 7]])
scale_perm_single = []
for i in range(8):
scale_perm_single.extend([8 * i + j for j in [0, 1, 2, 3, 4, 5, 6, 7]])
return perm, scale_perm, scale_perm_single
marlin_24_perm = {}
marlin_24_scale_perm = {}
marlin_24_scale_perm_single = {}
for num_bits in [4, 8]:
perm_24, scale_perm_24, scale_perm_single_24 = get_perms_24(num_bits)
marlin_24_perm[num_bits] = perm_24
marlin_24_scale_perm[num_bits] = scale_perm_24
marlin_24_scale_perm_single[num_bits] = scale_perm_single_24
"""This file is used for /tests and /benchmarks"""
import numpy
import torch
# Precompute permutations for Marlin weight and scale shuffling # noqa: E501
#
# Marlin works on [16,64] tiles. The goal of the permutations is to reorder the weight data so that it is compatible noqa: # noqa: E501
# with the tensor-core format that is described here:
# https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type # noqa: E501
#
# As a result of this reordering, the vector loads inside the kernel will get the data as it is needed for tensor-core # noqa: E501
# (without the need to use ldmatrix instructions) # noqa: E501
def get_perms(num_bits):
perm_list = []
for i in range(32):
perm1 = []
col = i // 4
for block in [0, 1]:
for row in [
2 * (i % 4),
2 * (i % 4) + 1,
2 * (i % 4 + 4),
2 * (i % 4 + 4) + 1,
]:
perm1.append(16 * row + col + 8 * block)
for j in range(4):
perm_list.extend([p + 256 * j for p in perm1])
perm = numpy.array(perm_list)
if num_bits == 4:
interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
elif num_bits == 8:
interleave = numpy.array([0, 2, 1, 3])
else:
raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))
perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel()
perm = torch.from_numpy(perm)
scale_perm = []
for i in range(8):
scale_perm.extend([i + 8 * j for j in range(8)])
scale_perm_single = []
for i in range(4):
scale_perm_single.extend(
[2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
return perm, scale_perm, scale_perm_single
marlin_perm = {}
marlin_scale_perm = {}
marlin_scale_perm_single = {}
for num_bits in [4, 8]:
perm, scale_perm, scale_perm_single = get_perms(num_bits)
marlin_perm[num_bits] = perm
marlin_scale_perm[num_bits] = scale_perm
marlin_scale_perm_single[num_bits] = scale_perm_single
"""This file is used for /tests and /benchmarks"""
import random
import numpy
import torch
from vllm.model_executor.layers.quantization.utils.format_24 import (
mask_creator, sparse_semi_structured_from_dense_cutlass)
from vllm.model_executor.layers.quantization.utils.marlin_24_perms import (
marlin_24_perm, marlin_24_scale_perm, marlin_24_scale_perm_single)
from vllm.model_executor.layers.quantization.utils.marlin_perms import (
marlin_perm, marlin_scale_perm, marlin_scale_perm_single)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
get_pack_factor, quantize_weights, sort_weights)
__cuda_arch = torch.cuda.get_device_capability()
MARLIN_TILE = 16
def is_marlin_supported():
return __cuda_arch[0] >= 8
def marlin_permute_weights(q_w, size_k, size_n, perm, tile=MARLIN_TILE):
assert q_w.shape == (size_k, size_n)
assert size_k % tile == 0, f"size_k = {size_k}, tile = {tile}"
assert size_n % tile == 0, f"size_k = {size_n}, tile = {tile}"
# Permute weights to 16x64 marlin tiles
q_w = q_w.reshape((size_k // tile, tile, size_n // tile, tile))
q_w = q_w.permute((0, 2, 1, 3))
q_w = q_w.reshape((size_k // tile, size_n * tile))
q_w = q_w.reshape((-1, perm.numel()))[:, perm].reshape(q_w.shape)
return q_w
def marlin_weights(q_w, size_k, size_n, num_bits, perm):
# Permute
q_w = marlin_permute_weights(q_w, size_k, size_n, perm)
# Pack
pack_factor = get_pack_factor(num_bits)
orig_device = q_w.device
q_w = q_w.cpu().numpy().astype(numpy.uint32)
q_packed = numpy.zeros((q_w.shape[0], q_w.shape[1] // pack_factor),
dtype=numpy.uint32)
for i in range(pack_factor):
q_packed |= q_w[:, i::pack_factor] << num_bits * i
q_packed = torch.from_numpy(q_packed.astype(numpy.int32)).to(orig_device)
return q_packed
def marlin_permute_scales(s, size_k, size_n, group_size, scale_perm,
scale_perm_single):
if group_size < size_k and group_size != -1:
s = s.reshape((-1, len(scale_perm)))[:, scale_perm]
else:
s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single]
s = s.reshape((-1, size_n)).contiguous()
return s
def marlin_quantize(
w: torch.Tensor,
num_bits: int,
group_size: int,
act_order: bool,
):
size_k, size_n = w.shape
# Normalize group_size
if group_size == -1:
group_size = size_k
assert group_size <= size_k
# Quantize (and apply act_order if provided)
w_ref, q_w, s, g_idx, rand_perm = quantize_weights(w, num_bits, group_size,
act_order)
# For act_order, sort the "weights" and "g_idx" so that group ids are
# increasing
sort_indices = torch.empty(0, dtype=torch.int, device=w.device)
if act_order:
q_w, g_idx, sort_indices = sort_weights(q_w, g_idx)
# Reformat to marlin
marlin_q_w = marlin_weights(q_w, size_k, size_n, num_bits,
marlin_perm[num_bits])
marlin_s = marlin_permute_scales(s, size_k, size_n, group_size,
marlin_scale_perm[num_bits],
marlin_scale_perm_single[num_bits])
# Create result
res_list = [w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, rand_perm]
for i in range(len(res_list)):
res_list[i] = res_list[i].to(w.device)
return res_list
def inject_24(w, size_k, size_n):
assert w.shape == (size_k, size_n)
mask = mask_creator(w.t()).t().cuda().bool()
return (mask * w).contiguous(), mask.contiguous()
def check_24(w, num_rows_to_sample=50, _verbose=False):
BLOCK_SIZE = 4
MAX_NON_ZEROS = 2
w = w.t().contiguous()
print("check_24: w.shape = {}".format(w.shape))
num_rows, num_cols = w.shape
sampled_row_idxs = random.choices(range(num_rows), k=num_rows_to_sample)
if _verbose:
print(f"Sampled row idxs = {sampled_row_idxs}")
total_segments = 0
non_24_segments = 0
for i in sampled_row_idxs:
for j in range(0, num_cols - BLOCK_SIZE, BLOCK_SIZE):
total_segments += 1
block = w[i, j:j + BLOCK_SIZE]
num_nonzero = torch.count_nonzero(block)
if num_nonzero > MAX_NON_ZEROS:
print("i = {} j = {} block = {}".format(i, j, block))
non_24_segments += 1
print(f"{non_24_segments} / {total_segments} do not have 2:4 structure.")
def compress_quantized_24_weight(q_24, size_k, size_n, num_bits):
assert q_24.shape == (size_k, size_n)
# Remove zp to normalize over 0
max_q_val = (1 << num_bits) - 1
zp = (max_q_val + 1) // 2
q_24_no_zp = q_24 - zp
# Compress
q_24_no_zp = q_24_no_zp.t().contiguous()
q_24_no_zp_comp, meta = sparse_semi_structured_from_dense_cutlass(
q_24_no_zp)
q_24_no_zp_comp = q_24_no_zp_comp.t().contiguous()
# Restore zp
q_24_comp = q_24_no_zp_comp + zp
# Resize meta to its actual shape (without moving any data)
meta = meta.resize_(meta.shape[1] // 2, meta.shape[0] * 2)
return q_24_comp, meta
def marlin_24_quantize(
w: torch.Tensor,
num_bits: int,
group_size: int,
):
size_k, size_n = w.shape
# Normalize group_size
if group_size == -1:
group_size = size_k
assert group_size <= size_k
# Inject 2:4 sparsity
w_24, mask_24 = inject_24(w, size_k, size_n)
# Quantize
w_24_ref, q_w_24, s, g_idx, rand_perm = quantize_weights(w_24,
num_bits,
group_size,
act_order=False)
# Compress quantized weight
q_w_24_comp, meta = compress_quantized_24_weight(q_w_24, size_k, size_n,
num_bits)
size_k_comp = size_k // 2
# Reformat to marlin
marlin_24_q_w_comp = marlin_weights(q_w_24_comp, size_k_comp, size_n,
num_bits, marlin_24_perm[num_bits])
marlin_24_s = marlin_permute_scales(s, size_k, size_n, group_size,
marlin_24_scale_perm[num_bits],
marlin_24_scale_perm_single[num_bits])
# Create result
res_list = [w_24_ref, marlin_24_q_w_comp, meta, marlin_24_s]
for i in range(len(res_list)):
res_list[i] = res_list[i].to(w.device)
return res_list
def compute_max_diff(output, output_ref):
return torch.mean(torch.abs(output - output_ref)) / torch.mean(
torch.abs(output_ref))
class MarlinWorkspace:
def __init__(self, out_features, min_thread_n, max_parallel):
assert (out_features % min_thread_n == 0), (
"out_features = {} is undivisible by min_thread_n = {}".format(
out_features, min_thread_n))
max_workspace_size = ((out_features // min_thread_n) * max_parallel)
self.scratch = torch.zeros(max_workspace_size,
dtype=torch.int,
device="cuda")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment