Unverified commit c7cb5c33, authored by Kyle Sayers, committed by GitHub
Browse files

[Misc] GPTQ Activation Ordering (#8135)

parent f9b4a2d4
...@@ -21,6 +21,7 @@ compressed-tensors, nm-testing/Phi-3-mini-128k-instruct-FP8, main ...@@ -21,6 +21,7 @@ compressed-tensors, nm-testing/Phi-3-mini-128k-instruct-FP8, main
compressed-tensors, neuralmagic/Phi-3-medium-128k-instruct-quantized.w4a16, main compressed-tensors, neuralmagic/Phi-3-medium-128k-instruct-quantized.w4a16, main
compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-quantized, main compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-quantized, main
compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-channel-quantized, main compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-channel-quantized, main
compressed-tensors, nm-testing/TinyLlama-1.1B-Chat-v1.0-actorder-group, main
awq, casperhansen/mixtral-instruct-awq, main awq, casperhansen/mixtral-instruct-awq, main
awq_marlin, casperhansen/mixtral-instruct-awq, main awq_marlin, casperhansen/mixtral-instruct-awq, main
fp8, neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV, main fp8, neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV, main
......
...@@ -232,7 +232,8 @@ class CompressedTensorsConfig(QuantizationConfig): ...@@ -232,7 +232,8 @@ class CompressedTensorsConfig(QuantizationConfig):
return CompressedTensorsWNA16( return CompressedTensorsWNA16(
num_bits=weight_quant.num_bits, num_bits=weight_quant.num_bits,
strategy=weight_quant.strategy, strategy=weight_quant.strategy,
group_size=weight_quant.group_size) group_size=weight_quant.group_size,
actorder=weight_quant.actorder)
# Detect If Activation Quantization. # Detect If Activation Quantization.
# TODO @dsikka: clean-up conditions # TODO @dsikka: clean-up conditions
......
...@@ -5,14 +5,18 @@ import torch ...@@ -5,14 +5,18 @@ import torch
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme) CompressedTensorsScheme)
from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
ActivationOrdering)
from vllm.model_executor.layers.quantization.utils.marlin_utils import ( from vllm.model_executor.layers.quantization.utils.marlin_utils import (
apply_gptq_marlin_linear, marlin_make_empty_g_idx, marlin_make_workspace, apply_gptq_marlin_linear, marlin_make_empty_g_idx, marlin_make_workspace,
marlin_permute_scales, replace_tensor, verify_marlin_supported, marlin_permute_scales, marlin_repeat_scales_on_all_ranks,
marlin_sort_g_idx, replace_tensor, verify_marlin_supported,
verify_marlin_supports_shape) verify_marlin_supports_shape)
from vllm.model_executor.parameter import (BasevLLMParameter, from vllm.model_executor.parameter import (BasevLLMParameter,
ChannelQuantScaleParameter, ChannelQuantScaleParameter,
GroupQuantScaleParameter, GroupQuantScaleParameter,
PackedvLLMParameter) PackedvLLMParameter,
RowvLLMParameter)
from vllm.scalar_type import scalar_types from vllm.scalar_type import scalar_types
__all__ = ["CompressedTensorsWNA16"] __all__ = ["CompressedTensorsWNA16"]
...@@ -28,11 +32,13 @@ class CompressedTensorsWNA16(CompressedTensorsScheme): ...@@ -28,11 +32,13 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
def __init__(self, def __init__(self,
strategy: str, strategy: str,
num_bits: int, num_bits: int,
group_size: Optional[int] = None): group_size: Optional[int] = None,
actorder: Optional[ActivationOrdering] = None):
self.pack_factor = 32 // num_bits self.pack_factor = 32 // num_bits
self.strategy = strategy self.strategy = strategy
self.group_size = -1 if group_size is None else group_size self.group_size = -1 if group_size is None else group_size
self.has_g_idx = actorder == ActivationOrdering.GROUP
if self.group_size == -1 and self.strategy != "channel": if self.group_size == -1 and self.strategy != "channel":
raise ValueError("Marlin kernels require group quantization or " raise ValueError("Marlin kernels require group quantization or "
...@@ -64,12 +70,10 @@ class CompressedTensorsWNA16(CompressedTensorsScheme): ...@@ -64,12 +70,10 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
output_size_per_partition = sum(output_partition_sizes) output_size_per_partition = sum(output_partition_sizes)
# If group_size is -1, we are in channelwise case. # If group_size is -1, we are in channelwise case.
channelwise = (self.group_size == -1)
group_size = self.group_size if self.group_size != -1 else input_size group_size = self.group_size if self.group_size != -1 else input_size
row_parallel = (input_size != input_size_per_partition) row_parallel = (input_size != input_size_per_partition)
# In the case of channelwise quantization, we need to replicate the partition_scales = not marlin_repeat_scales_on_all_ranks(
# scales across all gpus. self.has_g_idx, self.group_size, row_parallel)
partition_scales = (row_parallel and not channelwise)
verify_marlin_supports_shape( verify_marlin_supports_shape(
output_size_per_partition=output_size_per_partition, output_size_per_partition=output_size_per_partition,
...@@ -123,6 +127,16 @@ class CompressedTensorsWNA16(CompressedTensorsScheme): ...@@ -123,6 +127,16 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
layer.register_parameter("weight_scale", weight_scale) layer.register_parameter("weight_scale", weight_scale)
layer.register_parameter("weight_shape", weight_shape) layer.register_parameter("weight_shape", weight_shape)
# group index (for activation reordering)
if self.has_g_idx:
weight_g_idx = RowvLLMParameter(data=torch.empty(
input_size_per_partition,
dtype=torch.int32,
),
input_dim=0,
weight_loader=weight_loader)
layer.register_parameter("weight_g_idx", weight_g_idx)
layer.input_size_per_partition = input_size_per_partition layer.input_size_per_partition = input_size_per_partition
layer.output_size_per_partition = output_size_per_partition layer.output_size_per_partition = output_size_per_partition
layer.input_size = input_size layer.input_size = input_size
...@@ -137,8 +151,13 @@ class CompressedTensorsWNA16(CompressedTensorsScheme): ...@@ -137,8 +151,13 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
layer.workspace = marlin_make_workspace( layer.workspace = marlin_make_workspace(
layer.output_size_per_partition, device) layer.output_size_per_partition, device)
# Act-order not supported in compressed-tensors yet, so set to empty. # Handle sorting for activation reordering if needed.
layer.g_idx = marlin_make_empty_g_idx(device) if self.has_g_idx:
g_idx, g_idx_sort_indices = marlin_sort_g_idx(layer.weight_g_idx)
layer.g_idx_sort_indices = g_idx_sort_indices
replace_tensor(layer, "weight_g_idx", g_idx)
else:
layer.weight_g_idx = marlin_make_empty_g_idx(device)
layer.g_idx_sort_indices = marlin_make_empty_g_idx(device) layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)
# No zero-point # No zero-point
...@@ -159,9 +178,11 @@ class CompressedTensorsWNA16(CompressedTensorsScheme): ...@@ -159,9 +178,11 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
replace_tensor(layer, "weight_packed", marlin_qweight) replace_tensor(layer, "weight_packed", marlin_qweight)
# Permute scales from compressed-tensors format to marlin format. # Permute scales from compressed-tensors format to marlin format.
# Scales are required on all partitions when activation reordering is used.
marlin_scales = marlin_permute_scales( marlin_scales = marlin_permute_scales(
layer.weight_scale, layer.weight_scale,
size_k=layer.input_size_per_partition, size_k=(layer.input_size
if self.has_g_idx else layer.input_size_per_partition),
size_n=layer.output_size_per_partition, size_n=layer.output_size_per_partition,
group_size=layer.group_size) group_size=layer.group_size)
replace_tensor(layer, "weight_scale", marlin_scales) replace_tensor(layer, "weight_scale", marlin_scales)
...@@ -174,7 +195,7 @@ class CompressedTensorsWNA16(CompressedTensorsScheme): ...@@ -174,7 +195,7 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
weight=layer.weight_packed, weight=layer.weight_packed,
weight_scale=layer.weight_scale, weight_scale=layer.weight_scale,
weight_zp=layer.weight_zp, weight_zp=layer.weight_zp,
g_idx=layer.g_idx, g_idx=layer.weight_g_idx,
g_idx_sort_indices=layer.g_idx_sort_indices, g_idx_sort_indices=layer.g_idx_sort_indices,
workspace=layer.workspace, workspace=layer.workspace,
wtype=self.quant_type, wtype=self.quant_type,
......
import re import re
from enum import Enum from enum import Enum
from typing import Any, Dict, Iterable, Optional from typing import Any, Dict, Iterable, Optional, Union
from pydantic import BaseModel, Field from pydantic import BaseModel, Field, field_validator
from torch.nn import Module from torch.nn import Module
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
...@@ -40,6 +40,19 @@ class QuantizationStrategy(str, Enum): ...@@ -40,6 +40,19 @@ class QuantizationStrategy(str, Enum):
TOKEN = "token" TOKEN = "token"
class ActivationOrdering(str, Enum):
    """
    Enum storing strategies for activation ordering.

    Group: reorder groups and weight.
    Weight: only reorder weight, not groups. Slightly lower latency and
    accuracy compared to group actorder.
    """

    # Members are also plain strings (str mixin), so they compare equal to
    # the lowercase names used in serialized quantization configs.
    GROUP = "group"
    WEIGHT = "weight"
class QuantizationArgs(BaseModel): class QuantizationArgs(BaseModel):
""" """
User facing arguments used to define a quantization config User facing arguments used to define a quantization config
...@@ -58,6 +71,8 @@ class QuantizationArgs(BaseModel): ...@@ -58,6 +71,8 @@ class QuantizationArgs(BaseModel):
observed with every sample. Defaults to False for static observed with every sample. Defaults to False for static
quantization. Note that enabling dynamic quantization quantization. Note that enabling dynamic quantization
will change the default observer to a memoryless one will change the default observer to a memoryless one
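:param actorder: whether to apply group quantization in decreasing order of
activation. Accepts an ActivationOrdering strategy ("group" or "weight") or
a bool, where True is treated as "group" and False as None. Defaults to
None for arbitrary ordering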
""" """
num_bits: int = 8 num_bits: int = 8
...@@ -67,6 +82,7 @@ class QuantizationArgs(BaseModel): ...@@ -67,6 +82,7 @@ class QuantizationArgs(BaseModel):
strategy: Optional[QuantizationStrategy] = None strategy: Optional[QuantizationStrategy] = None
block_structure: Optional[str] = None block_structure: Optional[str] = None
dynamic: bool = False dynamic: bool = False
actorder: Union[ActivationOrdering, bool, None] = None
observer: str = Field( observer: str = Field(
default="minmax", default="minmax",
description=("The class to use to compute the quantization param - " description=("The class to use to compute the quantization param - "
...@@ -79,6 +95,16 @@ class QuantizationArgs(BaseModel): ...@@ -79,6 +95,16 @@ class QuantizationArgs(BaseModel):
"Observers constructor excluding quantization range or symmetry"), "Observers constructor excluding quantization range or symmetry"),
) )
@field_validator("actorder", mode="before")
@classmethod
def validate_actorder(cls, value) -> Optional[ActivationOrdering]:
    """
    Normalize the raw ``actorder`` input before field validation.

    :param value: raw value supplied for ``actorder``
    :return: ``ActivationOrdering.GROUP`` for ``True``, ``None`` for
        ``False``, the matching ``ActivationOrdering`` member for a string
        (matched case-insensitively), otherwise ``value`` unchanged
    :raises ValueError: if a string does not name an ``ActivationOrdering``
        member
    """
    # bool is checked first: True is shorthand for group ordering and False
    # means no activation ordering.
    if isinstance(value, bool):
        return ActivationOrdering.GROUP if value else None

    # Strings are lower-cased so "GROUP"/"Group" resolve to the same member.
    if isinstance(value, str):
        return ActivationOrdering(value.lower())

    # Anything else (e.g. None) is passed through unchanged.
    return value
def is_activation_quantization_format(format: str) -> bool: def is_activation_quantization_format(format: str) -> bool:
_ACTIVATION_QUANTIZATION_FORMATS = [ _ACTIVATION_QUANTIZATION_FORMATS = [
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment