Unverified Commit d583f131 authored by Raushan Turganbay, committed by GitHub

Quantized KV Cache (#30483)



* clean-up

* Update src/transformers/cache_utils.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/cache_utils.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/cache_utils.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* fixup

* Update tests/quantization/quanto_integration/test_quanto.py
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* Update src/transformers/generation/configuration_utils.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* more suggestions

* mapping if torch available

* run tests & add 'support_quantized' flag

* fix jamba test

* revert, will be fixed by another PR

* codestyle

* HQQ and versatile cache classes

* final update

* typo

* make tests happy

---------
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
parent e05baad8
......@@ -51,6 +51,9 @@ RUN python3 -m pip install --no-cache-dir gguf
# Some slow tests require bnb
RUN python3 -m pip install --no-cache-dir bitsandbytes
# Some tests require quanto
RUN python3 -m pip install --no-cache-dir quanto
# For `dinat` model
# The `XXX` part in `torchXXX` needs to match `PYTORCH` (to some extent)
RUN python3 -m pip install --no-cache-dir natten==0.15.1+torch220$CUDA -f https://shi-labs.com/natten/wheels
......
......@@ -174,6 +174,43 @@ An increasing sequence: one, two, three, four, five, six, seven, eight, nine, te
```
## KV Cache Quantization
The `generate()` method supports caching keys and values to enhance efficiency and avoid re-computation. However, the key and value
cache can occupy a large portion of memory, becoming a bottleneck for long-context generation, especially for Large Language Models.
Quantizing the cache when using `generate()` can significantly reduce memory requirements at the cost of speed.
KV Cache quantization in `transformers` is largely inspired by the paper [KIVI: A Tuning-Free Asymmetric 2bit Quantization for KV Cache](https://arxiv.org/abs/2402.02750) and currently supports `quanto` and `HQQ` as backends. For more information on the inner workings, see the paper.
To enable quantization of the key-value cache, pass `cache_implementation="quantized"` in the `generation_config`.
Quantization-related arguments should be passed to the `generation_config` either as a `dict` or as an instance of the [`QuantizedCacheConfig`] class.
The quantization backend to use is indicated in the [`QuantizedCacheConfig`]; the default is `quanto`.
<Tip warning={true}>
Cache quantization can be detrimental if the context length is short and there is enough GPU VRAM available to run without cache quantization.
</Tip>
```python
>>> import torch
>>> from transformers import AutoTokenizer, AutoModelForCausalLM
>>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
>>> model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf", torch_dtype=torch.float16).to("cuda:0")
>>> inputs = tokenizer("I like rock music because", return_tensors="pt").to(model.device)
>>> out = model.generate(**inputs, do_sample=False, max_new_tokens=20, cache_implementation="quantized", cache_config={"nbits": 4, "backend": "quanto"})
>>> print(tokenizer.batch_decode(out, skip_special_tokens=True)[0])
I like rock music because it's loud and energetic. It's a great way to express myself and rel
>>> out = model.generate(**inputs, do_sample=False, max_new_tokens=20)
>>> print(tokenizer.batch_decode(out, skip_special_tokens=True)[0])
I like rock music because it's loud and energetic. I like to listen to it when I'm feeling
```
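The cache arguments can also be passed as a [`QuantizedCacheConfig`] instance instead of a plain `dict`. The snippet below is a minimal sketch of that path, reusing the `model`, `tokenizer` and `inputs` from the example above; the argument values shown are the documented defaults, not tuned recommendations.
```python
>>> from transformers import QuantizedCacheConfig

>>> cache_config = QuantizedCacheConfig(backend="quanto", nbits=4, q_group_size=64, residual_length=128)
>>> out = model.generate(**inputs, do_sample=False, max_new_tokens=20, cache_implementation="quantized", cache_config=cache_config)
```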
## Watermarking
The `generate()` method supports watermarking the generated text by randomly marking a portion of tokens as "green".
......
......@@ -360,6 +360,12 @@ A [`Constraint`] can be used to force the generation to include specific tokens
[[autodoc]] Cache
- update
[[autodoc]] CacheConfig
- update
[[autodoc]] QuantizedCacheConfig
- validate
[[autodoc]] DynamicCache
- update
- get_seq_length
......@@ -367,6 +373,14 @@ A [`Constraint`] can be used to force the generation to include specific tokens
- to_legacy_cache
- from_legacy_cache
[[autodoc]] QuantizedCache
- update
- get_seq_length
[[autodoc]] QuantoQuantizedCache
[[autodoc]] HQQQuantizedCache
[[autodoc]] SinkCache
- update
- get_seq_length
......@@ -375,7 +389,7 @@ A [`Constraint`] can be used to force the generation to include specific tokens
[[autodoc]] StaticCache
- update
- get_seq_length
- reorder_cache
- reset
## Watermark Utils
......
......@@ -1182,7 +1182,17 @@ else:
_import_structure["activations"] = []
_import_structure["benchmark.benchmark"] = ["PyTorchBenchmark"]
_import_structure["benchmark.benchmark_args"] = ["PyTorchBenchmarkArguments"]
_import_structure["cache_utils"] = ["Cache", "DynamicCache", "SinkCache", "StaticCache"]
_import_structure["cache_utils"] = [
"Cache",
"CacheConfig",
"DynamicCache",
"HQQQuantizedCache",
"QuantizedCache",
"QuantizedCacheConfig",
"QuantoQuantizedCache",
"SinkCache",
"StaticCache",
]
_import_structure["data.datasets"] = [
"GlueDataset",
"GlueDataTrainingArguments",
......@@ -5792,7 +5802,17 @@ if TYPE_CHECKING:
# Benchmarks
from .benchmark.benchmark import PyTorchBenchmark
from .benchmark.benchmark_args import PyTorchBenchmarkArguments
from .cache_utils import Cache, DynamicCache, SinkCache, StaticCache
from .cache_utils import (
Cache,
CacheConfig,
DynamicCache,
HQQQuantizedCache,
QuantizedCache,
QuantizedCacheConfig,
QuantoQuantizedCache,
SinkCache,
StaticCache,
)
from .data.datasets import (
GlueDataset,
GlueDataTrainingArguments,
......
import copy
import json
import os
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
from .configuration_utils import PretrainedConfig
from .utils import logging
from .utils import is_hqq_available, is_quanto_available, logging
if is_quanto_available():
from quanto import QBitsTensor, qint2, qint4
if is_hqq_available():
from hqq.core.quantize import Quantizer as HQQQuantizer
logger = logging.get_logger(__name__)
......@@ -82,6 +91,201 @@ class Cache:
return None
@dataclass
class CacheConfig:
"""
Base class for cache configs
"""
cache_implementation: None
@classmethod
def from_dict(cls, config_dict, **kwargs):
"""
Constructs a CacheConfig instance from a dictionary of parameters.
Args:
config_dict (Dict[str, Any]): Dictionary containing configuration parameters.
**kwargs: Additional keyword arguments to override dictionary values.
Returns:
CacheConfig: Instance of CacheConfig constructed from the dictionary.
"""
config = cls(**config_dict)
to_remove = []
for key, value in kwargs.items():
if hasattr(config, key):
setattr(config, key, value)
to_remove.append(key)
for key in to_remove:
kwargs.pop(key, None)
return config
# Copied from transformers.utils.quantization_config.QuantizationConfigMixin.to_json_file
def to_json_file(self, json_file_path: Union[str, os.PathLike]):
"""
Save this instance to a JSON file.
Args:
json_file_path (`str` or `os.PathLike`):
Path to the JSON file in which this configuration instance's parameters will be saved.
"""
with open(json_file_path, "w", encoding="utf-8") as writer:
config_dict = self.to_dict()
json_string = json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
writer.write(json_string)
# Copied from transformers.utils.quantization_config.QuantizationConfigMixin.to_dict
def to_dict(self) -> Dict[str, Any]:
"""
Serializes this instance to a Python dictionary.
Returns:
`Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
"""
return copy.deepcopy(self.__dict__)
# Copied from transformers.utils.quantization_config.QuantizationConfigMixin.__iter__
def __iter__(self):
"""allows `dict(obj)` for situations where obj may be a dict or QuantizationConfigMixin"""
for attr, value in copy.deepcopy(self.__dict__).items():
yield attr, value
# Copied from transformers.utils.quantization_config.QuantizationConfigMixin.__repr__
def __repr__(self):
return f"{self.__class__.__name__} {self.to_json_string()}"
def to_json_string(self):
"""
Serializes this instance to a JSON formatted string.
Returns:
str: JSON formatted string representing the configuration instance.
"""
return json.dumps(self.__dict__, indent=2) + "\n"
# Copied from transformers.utils.quantization_config.QuantizationConfigMixin.update
def update(self, **kwargs):
"""
Updates attributes of this class instance with attributes from `kwargs` if they match existing attributes,
returning all the unused kwargs.
Args:
kwargs (`Dict[str, Any]`):
Dictionary of attributes to tentatively update this class.
Returns:
`Dict[str, Any]`: Dictionary containing all the key-value pairs that were not used to update the instance.
"""
to_remove = []
for key, value in kwargs.items():
if hasattr(self, key):
setattr(self, key, value)
to_remove.append(key)
# Remove all the attributes that were updated, without modifying the input dict
unused_kwargs = {key: value for key, value in kwargs.items() if key not in to_remove}
return unused_kwargs
@dataclass
class QuantizedCacheConfig(CacheConfig):
"""
Configuration class for quantized cache settings.
Attributes:
backend (`str`, *optional*, defaults to `"quanto"`):
Backend to use when performing quantization. Can be one of [`quanto`, `HQQ`].
nbits (`Optional[int]`, *optional*, defaults to 4):
Number of bits. Can be 2 or 4 for the `quanto` backend and one of [1, 2, 3, 4, 8] for the `HQQ` backend.
axis_key (`int`, *optional*, defaults to 0):
Axis over which to perform grouping for the key tensors. Can be [0, -1] for the `quanto` backend and [0, 1] for the `HQQ` backend.
axis_value (`int`, *optional*, defaults to 0):
Axis over which to perform grouping for the value tensors. Can be [0, -1] for the `quanto` backend and [0, 1] for the `HQQ` backend.
q_group_size (`Optional[int]`, *optional*, defaults to 64):
Size of the quantization group; should be a divisor of the model's hidden dimension.
residual_length (`Optional[int]`, *optional*, defaults to 128):
Length of the residual cache, which will always be stored in the original precision.
compute_dtype (`torch.dtype`, *optional*, defaults to `torch.float16`):
The default dtype used for computations in the model. Keys and Values will be cast to this dtype after dequantization.
device (`str`, *optional*, defaults to `"cpu"`):
Device on which to perform computations; should be the same as the model's device.
"""
def __init__(
self,
backend: str = "quanto",
nbits: Optional[int] = 4,
axis_key: Optional[int] = 0,
axis_value: Optional[int] = 0,
q_group_size: Optional[int] = 64,
residual_length: Optional[int] = 128,
compute_dtype: Optional[torch.dtype] = torch.float16,
device: Optional[str] = "cpu",
):
self.backend = backend
self.nbits = nbits
self.axis_key = axis_key
self.axis_value = axis_value
self.q_group_size = q_group_size
self.residual_length = residual_length
self.compute_dtype = compute_dtype
self.device = device
def validate(self):
"""Validates if the arguments passed are correct"""
incorrect_arg_msg = (
"Some of the keys in `cache_config` are defined incorrectly. `{key}` should be {correct_value}` "
"but found {found_value}"
)
# Check that the values are reasonable in general (nbits, axis)
# Later in QuantizedCache init we check if they are supported for that particular backend
if self.nbits not in [1, 2, 3, 4, 8]:
raise ValueError(
incorrect_arg_msg.format(
key="nbits",
correct_value="2 or 4 or 8",
found_value=self.nbits,
),
)
if self.q_group_size <= 0:
raise ValueError(
incorrect_arg_msg.format(
key="q_group_size",
correct_value="a positive integer",
found_value=self.q_group_size,
),
)
if self.residual_length < 0:
raise ValueError(
incorrect_arg_msg.format(
key="residual_length",
correct_value="a positive integer",
found_value=self.residual_length,
),
)
if self.axis_key not in [0, 1, -1]:
raise ValueError(
incorrect_arg_msg.format(
key="axis_key",
correct_value="`1` or `0`, `-1`",
found_value=self.axis_key,
),
)
if self.axis_value not in [0, 1, -1]:
raise ValueError(
incorrect_arg_msg.format(
key="axis_value",
correct_value="`1` or `0` or `-1`",
found_value=self.axis_value,
),
)
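An illustrative sketch (not part of the diff) of how `validate()` behaves with the checks above: in-range values pass silently, while out-of-range ones raise a `ValueError` built from the message template.
```python
from transformers import QuantizedCacheConfig

cfg = QuantizedCacheConfig(backend="quanto", nbits=4, q_group_size=64, residual_length=128)
cfg.validate()  # passes: nbits, group size, residual length and axes are all in range

try:
    QuantizedCacheConfig(nbits=60).validate()
except ValueError as err:
    print(err)  # explains that `nbits` must be one of [1, 2, 3, 4, 8]
```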
class DynamicCache(Cache):
"""
A cache that grows dynamically as more tokens are generated. This is the default for generative models.
......@@ -186,6 +390,168 @@ class DynamicCache(Cache):
return cache
class QuantizedCache(DynamicCache):
"""
A quantized cache similar to what is described in the [KIVI: A Tuning-Free Asymmetric 2bit Quantization for KV Cache paper](https://arxiv.org/abs/2402.02750).
It allows the model to generate longer sequence lengths without allocating too much memory for the Key and Value cache by applying quantization.
The cache has two types of storage: one for the original precision states and one for the quantized states. A `residual_length` is set as a maximum capacity for the
original precision cache. When the length goes beyond this maximum capacity, the original precision cache is quantized, moved into the quantized cache, and then emptied. The
quantization is done per-channel with a set `q_group_size` for both Keys and Values, in contrast to what was described in the paper.
It stores Keys and Values as a list of quantized tensors (tuples when metadata needs to be stored), one for each layer. Additionally, it stores the Key and
Value states in original precision as a list of tensors, one for each layer. The size of each tensor
is `[batch_size, num_heads, seq_len - residual_length, head_dim]`
"""
def __init__(self, cache_config: QuantizedCacheConfig) -> None:
self._quantized_key_cache: List[torch.Tensor] = []
self._quantized_value_cache: List[torch.Tensor] = []
self.nbits = cache_config.nbits
self.residual_length = cache_config.residual_length
self.q_group_size = cache_config.q_group_size
self.axis_key = cache_config.axis_key
self.axis_value = cache_config.axis_value
self.compute_dtype = cache_config.compute_dtype
self.device = cache_config.device
super().__init__()
def update(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
layer_idx: int,
cache_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
# Update the number of seen tokens
if layer_idx == 0:
self._seen_tokens += key_states.shape[-2]
if len(self.key_cache) <= layer_idx:
self._quantized_key_cache.append(self._quantize(key_states.contiguous(), axis=self.axis_key))
self._quantized_value_cache.append(self._quantize(value_states.contiguous(), axis=self.axis_value))
self.key_cache.append(torch.zeros(0, dtype=key_states.dtype, device=key_states.device))
self.value_cache.append(torch.zeros(0, dtype=key_states.dtype, device=key_states.device))
keys_to_return, values_to_return = key_states, value_states
else:
dequant_key = self._dequantize(self._quantized_key_cache[layer_idx])
dequant_value = self._dequantize(self._quantized_value_cache[layer_idx])
keys_to_return = [dequant_key, self.key_cache[layer_idx], key_states]
values_to_return = [dequant_value, self.value_cache[layer_idx], value_states]
keys_to_return = torch.cat(keys_to_return, dim=-2)
values_to_return = torch.cat(values_to_return, dim=-2)
if (
self.key_cache[layer_idx].dim() == 4
and self.key_cache[layer_idx].shape[-2] + 1 >= self.residual_length
):
self._quantized_key_cache[layer_idx] = self._quantize(keys_to_return.contiguous(), axis=self.axis_key)
self._quantized_value_cache[layer_idx] = self._quantize(
values_to_return.contiguous(), axis=self.axis_value
)
self.key_cache[layer_idx] = torch.zeros(0, dtype=key_states.dtype, device=key_states.device)
self.value_cache[layer_idx] = torch.zeros(0, dtype=key_states.dtype, device=key_states.device)
else:
self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
return keys_to_return, values_to_return
def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
"""Returns the sequence length of the cached states. A layer index can be optionally passed."""
if len(self.key_cache) <= layer_idx:
return 0
# since we cannot get the seq_length of each layer directly and rely on `_seen_tokens`, which is
# updated every time `layer_idx == 0`, this is a hack to get the actual seq_length for the given layer_idx.
# This part of the code otherwise fails when used to verify attn_weight shape in some models
return self._seen_tokens if layer_idx == 0 else self._seen_tokens - 1
def _quantize(self, tensor, axis):
"""Quantizes a key/value using a defined quantization method."""
raise NotImplementedError("Make sure to implement `_quantize` in a subclass.")
def _dequantize(self, q_tensor):
"""Dequantizes back the tensor that was quantized by `self._quantize()`"""
raise NotImplementedError("Make sure to implement `_dequantize` in a subclass.")
class QuantoQuantizedCache(QuantizedCache):
"""
Quantized Cache class that uses `quanto` as a backend to perform quantization. Current implementation supports `int2` and `int4` dtypes only.
Parameters:
cache_config (`QuantizedCacheConfig`):
A configuration containing all the arguments to be used by the quantizer, including axis, qtype and group size.
"""
def __init__(self, cache_config: CacheConfig) -> None:
super().__init__(cache_config)
if self.nbits not in [2, 4]:
raise ValueError(f"`nbits` for `quanto` backend has to be one of [`2`, `4`] but got {self.nbits}")
if self.axis_key not in [0, -1]:
raise ValueError(f"`axis_key` for `quanto` backend has to be one of [`0`, `-1`] but got {self.axis_key}")
if self.axis_value not in [0, -1]:
raise ValueError(
f"`axis_value` for `quanto` backend has to be one of [`0`, `-1`] but got {self.axis_value}"
)
self.qtype = qint4 if self.nbits == 4 else qint2
def _quantize(self, tensor, axis):
qtensor = QBitsTensor.quantize(tensor, axis=axis, qtype=self.qtype, group_size=self.q_group_size)
return qtensor
def _dequantize(self, qtensor):
return qtensor.dequantize()
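A hedged usage sketch (not part of the diff): with `quanto` installed, the cache can be constructed directly from a config and handed to `generate()` as `past_key_values`, instead of going through `cache_implementation="quantized"`. The model and prompt are the ones used in the documentation example in this PR.
```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, QuantizedCacheConfig, QuantoQuantizedCache

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf", torch_dtype=torch.float16).to("cuda:0")
inputs = tokenizer("I like rock music because", return_tensors="pt").to(model.device)

# build the quantized cache explicitly and let `generate()` fill it layer by layer
cache = QuantoQuantizedCache(cache_config=QuantizedCacheConfig(backend="quanto", nbits=4))
out = model.generate(**inputs, do_sample=False, max_new_tokens=20, past_key_values=cache)
```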
class HQQQuantizedCache(QuantizedCache):
"""
Quantized Cache class that uses `HQQ` as a backend to perform quantization. Current implementation supports `int2`, `int4`, `int8` dtypes.
Parameters:
cache_config (`QuantizedCacheConfig`):
A configuration containing all the arguments to be used by the quantizer, including axis, qtype and group size.
"""
def __init__(self, cache_config: CacheConfig) -> None:
super().__init__(cache_config)
if self.nbits not in [1, 2, 3, 4, 8]:
raise ValueError(
f"`nbits` for `HQQ` backend has to be one of [`1`, `2`, `3`, `4`, `8`] but got {self.nbits}"
)
if self.axis_key not in [0, 1]:
raise ValueError(f"`axis_key` for `HQQ` backend has to be one of [`0`, `1`] but got {self.axis_key}")
if self.axis_value not in [0, 1]:
raise ValueError(f"`axis_value` for `HQQ` backend has to be one of [`0`, `1`] but got {self.axis_value}")
self.quantizer = HQQQuantizer
def _quantize(self, tensor, axis):
qtensor, meta = self.quantizer.quantize(
tensor,
axis=axis,
device=self.device,
compute_dtype=self.compute_dtype,
nbits=self.nbits,
group_size=self.q_group_size,
)
meta["compute_dtype"] = self.compute_dtype
self.quantizer.cuda(qtensor, meta=meta, device=self.device) # Move to device and cast to dtype
return qtensor, meta
def _dequantize(self, qtensor):
quant_tensor, meta = qtensor
tensor = self.quantizer.dequantize(quant_tensor, meta)
return tensor
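Similarly for the HQQ backend, a sketch assuming `hqq` is installed: `nbits` may be any of 1, 2, 3, 4 or 8, the grouping axes must be 0 or 1, and the device and compute dtype should match the model.
```python
import torch
from transformers import HQQQuantizedCache, QuantizedCacheConfig

cfg = QuantizedCacheConfig(
    backend="HQQ",
    nbits=4,
    axis_key=1,
    axis_value=1,
    compute_dtype=torch.float16,
    device="cuda:0",  # should match the model's device
)
cache = HQQQuantizedCache(cache_config=cfg)  # raises ValueError if nbits or the axes are unsupported
```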
class SinkCache(Cache):
"""
A cache as described in the [Attention Sinks paper](https://arxiv.org/abs/2309.17453). It allows the model to
......
......@@ -31,6 +31,7 @@ from ..utils import (
download_url,
extract_commit_hash,
is_remote_url,
is_torch_available,
logging,
)
......@@ -41,6 +42,12 @@ if TYPE_CHECKING:
logger = logging.get_logger(__name__)
METADATA_FIELDS = ("_from_model_config", "_commit_hash", "_original_object_hash", "transformers_version")
NEEDS_CACHE_CONFIG = {}
if is_torch_available():
from ..cache_utils import QuantizedCacheConfig
NEEDS_CACHE_CONFIG["quantized"] = QuantizedCacheConfig
class GenerationMode(ExplicitEnum):
......@@ -299,6 +306,10 @@ class GenerationConfig(PushToHubMixin):
cache_implementation (`str`, *optional*, defaults to `None`):
Cache class that should be used when generating.
cache_config (`Union[CacheConfig, dict]`, *optional*, defaults to `None`):
Arguments used by the key-value cache class can be passed in `cache_config`. They can be passed as a `Dict`,
in which case they will be converted to their respective `CacheConfig` class internally, or as a `CacheConfig`
instance matching the indicated `cache_implementation`.
> Wild card
......@@ -382,6 +393,13 @@ class GenerationConfig(PushToHubMixin):
# Cache implementation
self.cache_implementation = kwargs.pop("cache_implementation", None)
self.cache_config = kwargs.pop("cache_config", None)
if self.cache_implementation is not None:
cache_config_class = NEEDS_CACHE_CONFIG[self.cache_implementation]
if self.cache_config is None:
self.cache_config = cache_config_class()
elif isinstance(self.cache_config, dict):
self.cache_config = cache_config_class.from_dict(self.cache_config)
# Prompt lookup decoding
self.prompt_lookup_num_tokens = kwargs.pop("prompt_lookup_num_tokens", None)
......@@ -638,13 +656,26 @@ class GenerationConfig(PushToHubMixin):
f"({self.num_beams})."
)
# check watermarking arguments
# 5. check `cache_config`
if self.cache_config is not None:
cache_class = NEEDS_CACHE_CONFIG.get(self.cache_implementation)
if cache_class is None:
raise ValueError(
"You provided a `cache_config` but the cache implementation you are using "
f"({self.cache_implementation}) does not require any config. Make sure to use the "
"correct cache implementation matching your cache config."
)
if not isinstance(self.cache_config, cache_class):
self.cache_config = cache_class.from_dict(self.cache_config)
self.cache_config.validate()
# 6. check watermarking arguments
if self.watermarking_config is not None:
if not isinstance(self.watermarking_config, WatermarkingConfig):
self.watermarking_config = WatermarkingConfig.from_dict(self.watermarking_config)
self.watermarking_config.validate()
# 5. check common issue: passing `generate` arguments inside the generation config
# 7. check common issue: passing `generate` arguments inside the generation config
generate_arguments = (
"logits_processor",
"stopping_criteria",
......
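To summarize the new `GenerationConfig` plumbing in this file, a small sketch based only on the code above: a `dict` passed as `cache_config` is converted to the class registered in `NEEDS_CACHE_CONFIG` for the chosen `cache_implementation`, and `GenerationConfig.validate()` then delegates to the cache config's own `validate()`.
```python
from transformers import GenerationConfig, QuantizedCacheConfig

# a dict is converted to the registered config class for the chosen implementation
gen_config = GenerationConfig(cache_implementation="quantized", cache_config={"backend": "quanto", "nbits": 4})
assert isinstance(gen_config.cache_config, QuantizedCacheConfig)

# out-of-range values are rejected when the generation config is validated
gen_config.cache_config.nbits = 60
try:
    gen_config.validate()
except ValueError as err:
    print(err)  # explains that `nbits` must be one of [1, 2, 3, 4, 8]
```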
......@@ -24,7 +24,15 @@ import torch
import torch.distributed as dist
from torch import nn
from ..cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
from ..cache_utils import (
Cache,
DynamicCache,
HQQQuantizedCache,
QuantizedCacheConfig,
QuantoQuantizedCache,
SlidingWindowCache,
StaticCache,
)
from ..integrations.deepspeed import is_deepspeed_zero3_enabled
from ..modeling_outputs import CausalLMOutputWithPast, Seq2SeqLMOutput
from ..models.auto import (
......@@ -34,7 +42,14 @@ from ..models.auto import (
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
MODEL_FOR_VISION_2_SEQ_MAPPING,
)
from ..utils import ModelOutput, is_accelerate_available, is_torchdynamo_compiling, logging
from ..utils import (
ModelOutput,
is_accelerate_available,
is_hqq_available,
is_quanto_available,
is_torchdynamo_compiling,
logging,
)
from .beam_constraints import DisjunctiveConstraint, PhrasalConstraint
from .beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
from .candidate_generator import (
......@@ -97,6 +112,7 @@ if is_accelerate_available():
from accelerate.hooks import AlignDevicesHook, add_hook_to_module
NEED_SETUP_CACHE_CLASSES_MAPPING = {"static": StaticCache, "sliding_window": SlidingWindowCache}
QUANT_BACKEND_CLASSES_MAPPING = {"quanto": QuantoQuantizedCache, "HQQ": HQQQuantizedCache}
@dataclass
......@@ -1658,20 +1674,43 @@ class GenerationMixin:
"Passing both `cache_implementation` (used to initialize certain caches) and `past_key_values` (a "
"Cache object) is unsupported. Please use only one of the two."
)
elif generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING:
if not self._supports_cache_class:
raise ValueError(
"This model does not support the `cache_implementation` argument. Please check the following "
"issue: https://github.com/huggingface/transformers/issues/28981."
elif generation_config.cache_implementation is not None:
if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING:
if generation_config.cache_implementation == "static" and not self._supports_static_cache:
raise ValueError(
"This model does not support `cache_implementation='static'`. Please check the following "
"issue: https://github.com/huggingface/transformers/issues/28981"
)
model_kwargs["past_key_values"] = self._get_cache(
generation_config.cache_implementation, batch_size, generation_config.max_length
)
if generation_config.cache_implementation == "static" and not self._supports_static_cache:
raise ValueError(
"This model does not support `cache_implementation='static'`. Please check the following "
"issue: https://github.com/huggingface/transformers/issues/28981"
elif generation_config.cache_implementation == "quantized":
if not self._supports_quantized_cache:
raise ValueError(
"This model does not support the quantized cache. If you want your model to support quantized "
"cache, please open an issue."
)
cache_config = (
generation_config.cache_config
if generation_config.cache_config is not None
else QuantizedCacheConfig()
)
model_kwargs["past_key_values"] = self._get_cache(
generation_config.cache_implementation, batch_size, generation_config.max_length
)
cache_class = QUANT_BACKEND_CLASSES_MAPPING[cache_config.backend]
if cache_config.backend == "quanto" and not is_quanto_available():
raise ImportError(
"You need to install `quanto` in order to use KV cache quantization with quanto backend. "
"Please install it via with `pip install quanto`"
)
elif cache_config.backend == "HQQ" and not is_hqq_available():
raise ImportError(
"You need to install `HQQ` in order to use KV cache quantization with HQQ backend. "
"Please install it via with `pip install hqq`"
)
model_kwargs["past_key_values"] = cache_class(cache_config)
self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)
# 7. determine generation mode
......
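Putting the dispatch above together, a minimal sketch mirroring the new branch (nothing assumed beyond this diff): the `backend` field of the cache config selects the concrete cache class from `QUANT_BACKEND_CLASSES_MAPPING`, provided the matching package (`quanto` or `hqq`) is installed.
```python
from transformers.cache_utils import HQQQuantizedCache, QuantizedCacheConfig, QuantoQuantizedCache

QUANT_BACKEND_CLASSES_MAPPING = {"quanto": QuantoQuantizedCache, "HQQ": HQQQuantizedCache}

cache_config = QuantizedCacheConfig(backend="quanto", nbits=4)
cache_class = QUANT_BACKEND_CLASSES_MAPPING[cache_config.backend]
past_key_values = cache_class(cache_config)  # becomes model_kwargs["past_key_values"]; needs `quanto` installed
```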
......@@ -1284,6 +1284,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
_supports_cache_class = False
_supports_static_cache = False
# Has support for a `QuantizedCache` instance as `past_key_values`
_supports_quantized_cache = False
@property
def dummy_inputs(self) -> Dict[str, torch.Tensor]:
"""
......
......@@ -712,6 +712,7 @@ class CoherePreTrainedModel(PreTrainedModel):
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_cache_class = True
_supports_quantized_cache = True
_supports_static_cache = True
def _init_weights(self, module):
......
......@@ -937,6 +937,7 @@ class DbrxPreTrainedModel(PreTrainedModel):
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_cache_class = True
_supports_quantized_cache = True
_supports_static_cache = True
def _init_weights(self, module: nn.Module):
......
......@@ -698,6 +698,7 @@ class GemmaPreTrainedModel(PreTrainedModel):
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_cache_class = True
_supports_quantized_cache = True
_supports_static_cache = True
def _init_weights(self, module):
......
......@@ -767,6 +767,7 @@ class LlamaPreTrainedModel(PreTrainedModel):
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_cache_class = True
_supports_quantized_cache = True
_supports_static_cache = True
def _init_weights(self, module):
......
......@@ -745,6 +745,7 @@ class OlmoPreTrainedModel(PreTrainedModel):
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_cache_class = True
_supports_quantized_cache = True
_supports_static_cache = True
def _init_weights(self, module):
......
......@@ -539,6 +539,8 @@ class RecurrentGemmaPreTrainedModel(PreTrainedModel):
_skip_keys_device_placement = ["cache"]
_supports_flash_attn_2 = False
_supports_sdpa = False # we can't compare with eager for now
_supports_cache_class = True
_supports_quantized_cache = True
def _init_weights(self, module):
std = math.sqrt(self.config.w_init_variance_scale / self.config.conv1d_width)
......
......@@ -832,6 +832,7 @@ class StableLmPreTrainedModel(PreTrainedModel):
_supports_flash_attn_2 = True
_supports_cache_class = True
_supports_sdpa = True
_supports_quantized_cache = True
def _init_weights(self, module):
std = self.config.initializer_range
......
......@@ -784,11 +784,11 @@ class WhisperGenerationMixin:
del generate_kwargs[key]
seek_outputs = super().generate(
segment_input,
generation_config,
logits_processor,
stopping_criteria,
prefix_allowed_tokens_fn,
synced_gpus,
generation_config=generation_config,
logits_processor=logits_processor,
stopping_criteria=stopping_criteria,
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
synced_gpus=synced_gpus,
decoder_input_ids=decoder_input_ids,
**generate_kwargs,
)
......
......@@ -23,6 +23,13 @@ class Cache(metaclass=DummyObject):
requires_backends(self, ["torch"])
class CacheConfig(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class DynamicCache(metaclass=DummyObject):
_backends = ["torch"]
......@@ -30,6 +37,34 @@ class DynamicCache(metaclass=DummyObject):
requires_backends(self, ["torch"])
class HQQQuantizedCache(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class QuantizedCache(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class QuantizedCacheConfig(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class QuantoQuantizedCache(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class SinkCache(metaclass=DummyObject):
_backends = ["torch"]
......
......@@ -27,6 +27,7 @@ from transformers import is_torch_available, pipeline, set_seed
from transformers.testing_utils import (
is_flaky,
require_accelerate,
require_quanto,
require_torch,
require_torch_multi_accelerator,
slow,
......@@ -55,7 +56,7 @@ if is_torch_available():
ImageGPTForCausalImageModeling,
SpeechEncoderDecoderModel,
)
from transformers.cache_utils import DynamicCache
from transformers.cache_utils import DynamicCache, QuantoQuantizedCache
from transformers.generation import (
BeamSampleDecoderOnlyOutput,
BeamSampleEncoderDecoderOutput,
......@@ -1654,6 +1655,39 @@ class GenerationTesterMixin:
)
)
@require_quanto
def test_generate_with_quant_cache(self):
for model_class in self.all_generative_model_classes:
if not model_class._supports_quantized_cache:
self.skipTest("This model does not support the quantized cache format")
config, input_ids, attention_mask = self._get_input_ids_and_config()
config.use_cache = True
config.is_decoder = True
model = model_class(config).to(torch_device).eval()
generation_kwargs = {
"max_new_tokens": 5,
"cache_implementation": "quantized",
# careful with the group size, it should be a divisor of the model's hidden size
"cache_config": {"backend": "quanto", "nbits": 2, "q_group_size": 8, "residual_length": 128},
"return_dict_in_generate": True, # Required to return `past_key_values`
}
results = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)
self.assertTrue(isinstance(results.past_key_values, QuantoQuantizedCache))
# passing past key values of a different type should raise an Error
with self.assertRaises(ValueError):
model.generate(
input_ids, attention_mask=attention_mask, past_key_values=DynamicCache(), **generation_kwargs
)
# setting incorrect cache_config args should raise an Error, e.g. nbits=60 does not make sense
generation_kwargs["cache_config"] = {"nbits": 60, "q_group_size": 8, "residual_length": 128}
with self.assertRaises(ValueError):
model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)
def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_sequences=1):
batch_size, seq_length = input_ids.shape
num_sequences_in_output = batch_size * num_return_sequences
......
......@@ -17,13 +17,22 @@ import tempfile
import unittest
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, QuantoConfig
from transformers.testing_utils import require_accelerate, require_quanto, require_torch_gpu, slow
from transformers.testing_utils import (
require_accelerate,
require_quanto,
require_read_token,
require_torch_gpu,
slow,
torch_device,
)
from transformers.utils import is_accelerate_available, is_quanto_available, is_torch_available
if is_torch_available():
import torch
from transformers import LlamaForCausalLM, LlamaTokenizer
if is_accelerate_available():
from accelerate import init_empty_weights
......@@ -429,3 +438,28 @@ class QuantoQuantizationActivationTest(unittest.TestCase):
with self.assertRaises(ValueError) as e:
AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m", quantization_config=quantization_config)
self.assertIn("We don't support quantizing the activations with transformers library", str(e.exception))
@require_torch_gpu
class QuantoKVCacheQuantizationTest(unittest.TestCase):
@slow
@require_read_token
def test_quantized_cache(self):
EXPECTED_TEXT_COMPLETION = [
"Simply put, the theory of relativity states that 1) the speed of light is the same for all observers, and 2) the laws of physics are the same for all observers.\nThe first part of the theory of relativity",
"My favorite all time favorite condiment is ketchup. I love it on everything. I love it on my eggs, my fries, my burgers, my hot dogs, my sandwiches, my chicken, my pizza, my sal",
]
prompts = [
"Simply put, the theory of relativity states that ",
"My favorite all time favorite condiment is ketchup.",
]
tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", pad_token="</s>", padding_side="left")
model = LlamaForCausalLM.from_pretrained(
"meta-llama/Llama-2-7b-hf", device_map="sequential", torch_dtype=torch.float16
)
inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(torch_device)
generated_ids = model.generate(**inputs, max_new_tokens=40, do_sample=False, cache_implementation="quantized")
text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
self.assertEqual(EXPECTED_TEXT_COMPLETION, text)