Unverified commit d521ba57, authored by Raushan Turganbay, committed by GitHub
Browse files

Quantized KV cache: update quanto (#31052)



* quanto latest version was refactored

* add error msg

* incorrect compare sign

* Update src/transformers/cache_utils.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

---------
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
parent a564d10a
import copy import copy
import importlib.metadata
import json import json
import os import os
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union from typing import Any, Dict, List, Optional, Tuple, Union
import torch import torch
from packaging import version
from .configuration_utils import PretrainedConfig from .configuration_utils import PretrainedConfig
from .utils import is_hqq_available, is_quanto_available, logging from .utils import is_hqq_available, is_quanto_available, logging
if is_quanto_available(): if is_quanto_available():
from quanto import QBitsTensor, qint2, qint4 quanto_version = version.parse(importlib.metadata.version("quanto"))
if quanto_version >= version.parse("0.2.0"):
from quanto import AffineQuantizer, MaxOptimizer, qint2, qint4
if is_hqq_available(): if is_hqq_available():
from hqq.core.quantize import Quantizer as HQQQuantizer from hqq.core.quantize import Quantizer as HQQQuantizer
...@@ -488,6 +492,13 @@ class QuantoQuantizedCache(QuantizedCache): ...@@ -488,6 +492,13 @@ class QuantoQuantizedCache(QuantizedCache):
def __init__(self, cache_config: CacheConfig) -> None: def __init__(self, cache_config: CacheConfig) -> None:
super().__init__(cache_config) super().__init__(cache_config)
quanto_version = version.parse(importlib.metadata.version("quanto"))
if quanto_version < version.parse("0.2.0"):
raise ImportError(
f"You need quanto package version to be greater or equal than 0.2.0 to use `QuantoQuantizedCache`. Detected version {quanto_version}. "
f"Please upgrade quanto with `pip install -U quanto`"
)
if self.nbits not in [2, 4]: if self.nbits not in [2, 4]:
raise ValueError(f"`nbits` for `quanto` backend has to be one of [`2`, `4`] but got {self.nbits}") raise ValueError(f"`nbits` for `quanto` backend has to be one of [`2`, `4`] but got {self.nbits}")
...@@ -500,9 +511,11 @@ class QuantoQuantizedCache(QuantizedCache): ...@@ -500,9 +511,11 @@ class QuantoQuantizedCache(QuantizedCache):
) )
self.qtype = qint4 if self.nbits == 4 else qint2 self.qtype = qint4 if self.nbits == 4 else qint2
self.optimizer = MaxOptimizer() # hardcode as it's the only one for per-channel quantization
def _quantize(self, tensor, axis): def _quantize(self, tensor, axis):
qtensor = QBitsTensor.quantize(tensor, axis=axis, qtype=self.qtype, group_size=self.q_group_size) scale, zeropoint = self.optimizer(tensor, self.qtype.bits, axis, self.q_group_size)
qtensor = AffineQuantizer.apply(tensor, self.qtype, axis, self.q_group_size, scale, zeropoint)
return qtensor return qtensor
def _dequantize(self, qtensor): def _dequantize(self, qtensor):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment