Unverified Commit d521ba57 authored by Raushan Turganbay's avatar Raushan Turganbay Committed by GitHub
Browse files

Quantized KV cache: update quanto (#31052)



* quanto latest version was refactored

* add error msg

* incorrect compare sign

* Update src/transformers/cache_utils.py
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>

---------
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>
parent a564d10a
import copy
import importlib.metadata
import json
import os
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
from packaging import version
from .configuration_utils import PretrainedConfig
from .utils import is_hqq_available, is_quanto_available, logging
if is_quanto_available():
from quanto import QBitsTensor, qint2, qint4
quanto_version = version.parse(importlib.metadata.version("quanto"))
if quanto_version >= version.parse("0.2.0"):
from quanto import AffineQuantizer, MaxOptimizer, qint2, qint4
if is_hqq_available():
from hqq.core.quantize import Quantizer as HQQQuantizer
......@@ -488,6 +492,13 @@ class QuantoQuantizedCache(QuantizedCache):
def __init__(self, cache_config: CacheConfig) -> None:
super().__init__(cache_config)
quanto_version = version.parse(importlib.metadata.version("quanto"))
if quanto_version < version.parse("0.2.0"):
raise ImportError(
f"You need quanto package version to be greater or equal than 0.2.0 to use `QuantoQuantizedCache`. Detected version {quanto_version}. "
f"Please upgrade quanto with `pip install -U quanto`"
)
if self.nbits not in [2, 4]:
raise ValueError(f"`nbits` for `quanto` backend has to be one of [`2`, `4`] but got {self.nbits}")
......@@ -500,9 +511,11 @@ class QuantoQuantizedCache(QuantizedCache):
)
self.qtype = qint4 if self.nbits == 4 else qint2
self.optimizer = MaxOptimizer() # hardcode as it's the only one for per-channel quantization
def _quantize(self, tensor, axis):
qtensor = QBitsTensor.quantize(tensor, axis=axis, qtype=self.qtype, group_size=self.q_group_size)
scale, zeropoint = self.optimizer(tensor, self.qtype.bits, axis, self.q_group_size)
qtensor = AffineQuantizer.apply(tensor, self.qtype, axis, self.q_group_size, scale, zeropoint)
return qtensor
def _dequantize(self, qtensor):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment