Unverified Commit d583f131 authored by Raushan Turganbay, committed by GitHub

Quantized KV Cache (#30483)



* clean-up

* Update src/transformers/cache_utils.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/cache_utils.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/cache_utils.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* fixup

* Update tests/quantization/quanto_integration/test_quanto.py
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* Update src/transformers/generation/configuration_utils.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* more suggestions

* mapping if torch available

* run tests & add 'support_quantized' flag

* fix jamba test

* revert, will be fixed by another PR

* codestyle

* HQQ and versatile cache classes

* final update

* typo

* make tests happy

---------
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
parent e05baad8
...@@ -51,6 +51,9 @@ RUN python3 -m pip install --no-cache-dir gguf
# Some slow tests require bnb
RUN python3 -m pip install --no-cache-dir bitsandbytes
# Some tests require quanto
RUN python3 -m pip install --no-cache-dir quanto
# For `dinat` model
# The `XXX` part in `torchXXX` needs to match `PYTORCH` (to some extent)
RUN python3 -m pip install --no-cache-dir natten==0.15.1+torch220$CUDA -f https://shi-labs.com/natten/wheels
......
...@@ -174,6 +174,43 @@ An increasing sequence: one, two, three, four, five, six, seven, eight, nine, te
```
## KV Cache Quantization
The `generate()` method supports caching keys and values to improve efficiency and avoid re-computation. However, the key and value
cache can occupy a large portion of memory, becoming a bottleneck for long-context generation, especially for Large Language Models.
Quantizing the cache when using `generate()` can significantly reduce memory requirements at the cost of speed.
KV Cache quantization in `transformers` is largely inspired by the paper [KIVI: A Tuning-Free Asymmetric 2bit Quantization for KV Cache](https://arxiv.org/abs/2402.02750) and currently supports `quanto` and `HQQ` as backends. For more information on the inner workings, see the paper.
To enable quantization of the key-value cache, one needs to indicate `cache_implementation="quantized"` in the `generation_config`.
Quantization-related arguments should be passed to the `generation_config` either as a `dict` or as an instance of the [`QuantizedCacheConfig`] class.
The quantization backend to use is indicated in the [`QuantizedCacheConfig`]; the default is `quanto`.
<Tip warning={true}>
Cache quantization can be detrimental if the context length is short and there is enough GPU VRAM available to run without cache quantization.
</Tip>
```python
>>> import torch
>>> from transformers import AutoTokenizer, AutoModelForCausalLM
>>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
>>> model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf", torch_dtype=torch.float16).to("cuda:0")
>>> inputs = tokenizer("I like rock music because", return_tensors="pt").to(model.device)
>>> out = model.generate(**inputs, do_sample=False, max_new_tokens=20, cache_implementation="quantized", cache_config={"nbits": 4, "backend": "quanto"})
>>> print(tokenizer.batch_decode(out, skip_special_tokens=True)[0])
I like rock music because it's loud and energetic. It's a great way to express myself and rel
>>> out = model.generate(**inputs, do_sample=False, max_new_tokens=20)
>>> print(tokenizer.batch_decode(out, skip_special_tokens=True)[0])
I like rock music because it's loud and energetic. I like to listen to it when I'm feeling
```
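The cache settings can also be passed as a [`QuantizedCacheConfig`] instance instead of a plain `dict`. A minimal sketch, assuming the `quanto` backend is installed and reusing `model` and `inputs` from the snippet above (the generated text is omitted since it depends on the model and cache settings):

```python
>>> from transformers import QuantizedCacheConfig

>>> cache_config = QuantizedCacheConfig(backend="quanto", nbits=4, q_group_size=64, residual_length=128)
>>> out = model.generate(
...     **inputs,
...     do_sample=False,
...     max_new_tokens=20,
...     cache_implementation="quantized",
...     cache_config=cache_config,  # a config object can stand in for the dict used above
... )
```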
## Watermarking
The `generate()` supports watermarking the generated text by randomly marking a portion of tokens as "green".
......
...@@ -360,6 +360,12 @@ A [`Constraint`] can be used to force the generation to include specific tokens
[[autodoc]] Cache
- update
[[autodoc]] CacheConfig
- update
[[autodoc]] QuantizedCacheConfig
- validate
[[autodoc]] DynamicCache
- update
- get_seq_length
...@@ -367,6 +373,14 @@ A [`Constraint`] can be used to force the generation to include specific tokens
- to_legacy_cache
- from_legacy_cache
[[autodoc]] QuantizedCache
- update
- get_seq_length
[[autodoc]] QuantoQuantizedCache
[[autodoc]] HQQQuantizedCache
[[autodoc]] SinkCache
- update
- get_seq_length
...@@ -375,7 +389,7 @@ A [`Constraint`] can be used to force the generation to include specific tokens
[[autodoc]] StaticCache
- update
- get_seq_length
- reorder_cache - reset
## Watermark Utils
......
...@@ -1182,7 +1182,17 @@ else:
_import_structure["activations"] = []
_import_structure["benchmark.benchmark"] = ["PyTorchBenchmark"]
_import_structure["benchmark.benchmark_args"] = ["PyTorchBenchmarkArguments"]
_import_structure["cache_utils"] = ["Cache", "DynamicCache", "SinkCache", "StaticCache"] _import_structure["cache_utils"] = [
"Cache",
"CacheConfig",
"DynamicCache",
"HQQQuantizedCache",
"QuantizedCache",
"QuantizedCacheConfig",
"QuantoQuantizedCache",
"SinkCache",
"StaticCache",
]
_import_structure["data.datasets"] = [ _import_structure["data.datasets"] = [
"GlueDataset", "GlueDataset",
"GlueDataTrainingArguments", "GlueDataTrainingArguments",
...@@ -5792,7 +5802,17 @@ if TYPE_CHECKING: ...@@ -5792,7 +5802,17 @@ if TYPE_CHECKING:
# Benchmarks # Benchmarks
from .benchmark.benchmark import PyTorchBenchmark from .benchmark.benchmark import PyTorchBenchmark
from .benchmark.benchmark_args import PyTorchBenchmarkArguments from .benchmark.benchmark_args import PyTorchBenchmarkArguments
from .cache_utils import Cache, DynamicCache, SinkCache, StaticCache from .cache_utils import (
Cache,
CacheConfig,
DynamicCache,
HQQQuantizedCache,
QuantizedCache,
QuantizedCacheConfig,
QuantoQuantizedCache,
SinkCache,
StaticCache,
)
from .data.datasets import (
GlueDataset,
GlueDataTrainingArguments,
......
import copy
import json
import os
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple from typing import Any, Dict, List, Optional, Tuple, Union
import torch
from .configuration_utils import PretrainedConfig
from .utils import logging from .utils import is_hqq_available, is_quanto_available, logging
if is_quanto_available():
from quanto import QBitsTensor, qint2, qint4
if is_hqq_available():
from hqq.core.quantize import Quantizer as HQQQuantizer
logger = logging.get_logger(__name__)
...@@ -82,6 +91,201 @@ class Cache:
return None
@dataclass
class CacheConfig:
"""
Base class for cache configs
"""
cache_implementation: None
@classmethod
def from_dict(cls, config_dict, **kwargs):
"""
Constructs a CacheConfig instance from a dictionary of parameters.
Args:
config_dict (Dict[str, Any]): Dictionary containing configuration parameters.
**kwargs: Additional keyword arguments to override dictionary values.
Returns:
CacheConfig: Instance of CacheConfig constructed from the dictionary.
"""
config = cls(**config_dict)
to_remove = []
for key, value in kwargs.items():
if hasattr(config, key):
setattr(config, key, value)
to_remove.append(key)
for key in to_remove:
kwargs.pop(key, None)
return config
# Copied from transformers.utils.quantization_config.QuantizationConfigMixin.to_json_file
def to_json_file(self, json_file_path: Union[str, os.PathLike]):
"""
Save this instance to a JSON file.
Args:
json_file_path (`str` or `os.PathLike`):
Path to the JSON file in which this configuration instance's parameters will be saved.
use_diff (`bool`, *optional*, defaults to `True`):
If set to `True`, only the difference between the config instance and the default
`QuantizationConfig()` is serialized to JSON file.
"""
with open(json_file_path, "w", encoding="utf-8") as writer:
config_dict = self.to_dict()
json_string = json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
writer.write(json_string)
# Copied from transformers.utils.quantization_config.QuantizationConfigMixin.to_dict
def to_dict(self) -> Dict[str, Any]:
"""
Serializes this instance to a Python dictionary. Returns:
`Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
"""
return copy.deepcopy(self.__dict__)
# Copied from transformers.utils.quantization_config.QuantizationConfigMixin.__iter__
def __iter__(self):
"""allows `dict(obj)` for situations where obj may be a dict or QuantizationConfigMixin"""
for attr, value in copy.deepcopy(self.__dict__).items():
yield attr, value
# Copied from transformers.utils.quantization_config.QuantizationConfigMixin.__repr__
def __repr__(self):
return f"{self.__class__.__name__} {self.to_json_string()}"
def to_json_string(self):
"""
Serializes this instance to a JSON formatted string.
Returns:
str: JSON formatted string representing the configuration instance.
"""
return json.dumps(self.__dict__, indent=2) + "\n"
# Copied from transformers.utils.quantization_config.QuantizationConfigMixin.update
def update(self, **kwargs):
"""
Updates attributes of this class instance with attributes from `kwargs` if they match existing attributes,
returning all the unused kwargs.
Args:
kwargs (`Dict[str, Any]`):
Dictionary of attributes to tentatively update this class.
Returns:
`Dict[str, Any]`: Dictionary containing all the key-value pairs that were not used to update the instance.
"""
to_remove = []
for key, value in kwargs.items():
if hasattr(self, key):
setattr(self, key, value)
to_remove.append(key)
# Remove all the attributes that were updated, without modifying the input dict
unused_kwargs = {key: value for key, value in kwargs.items() if key not in to_remove}
return unused_kwargs
@dataclass
class QuantizedCacheConfig(CacheConfig):
"""
Configuration class for quantized cache settings.
Attributes:
backend (`str`, *optional*, defaults to `"quanto"`):
Backend to use when performing quantization. Can be one of [`quanto`, `HQQ`].
nbits (`Optional[int]`, *optional*, defaults to 4):
Number of bits, can be 2 or 4 for the `quanto` backend and one of [1, 2, 3, 4, 8] for the `HQQ` backend.
axis_key (`int`, *optional*, defaults to 0):
Axis over which to perform grouping for the key tensors. Can be [0, -1] for `quanto` backend and [0, 1] for `HQQ` backend.
axis_value (`int`, *optional*, defaults to 0):
Axis over which to perform grouping for the value tensors. Can be [0, -1] for `quanto` backend and [0, 1] for `HQQ` backend.
q_group_size (`Optional[int]`, *optional*, defaults to 64):
Size of the quantization group, should be a divisor of the model's hidden dimension.
residual_length (`Optional[int]`, *optional*, defaults to 128):
Length of the residual cache which will always be stored in original precision.
compute_dtype (`torch.dtype`, *optional*, defaults to `torch.float16`):
The default dtype used for computations in the model. Keys and values will be cast to this dtype after dequantization.
device (`str`, *optional*, defaults to `"cpu"`):
Device on which to perform computations, should be the same as the model's device.
"""
def __init__(
self,
backend: str = "quanto",
nbits: Optional[int] = 4,
axis_key: Optional[int] = 0,
axis_value: Optional[int] = 0,
q_group_size: Optional[int] = 64,
residual_length: Optional[int] = 128,
compute_dtype: Optional[torch.dtype] = torch.float16,
device: Optional[str] = "cpu",
):
self.backend = backend
self.nbits = nbits
self.axis_key = axis_key
self.axis_value = axis_value
self.q_group_size = q_group_size
self.residual_length = residual_length
self.compute_dtype = compute_dtype
self.device = device
def validate(self):
"""Validates if the arguments passed are correct"""
incorrect_arg_msg = (
"Some of the keys in `cache_config` are defined incorrectly. `{key}` should be {correct_value}` "
"but found {found_value}"
)
# Check that the values are reasonable in general (nbits, axis)
# Later in QuantizedCache init we check if they are supported for that particular backend
if self.nbits not in [1, 2, 3, 4, 8]:
raise ValueError(
incorrect_arg_msg.format(
key="nbits",
correct_value="2 or 4 or 8",
found_value=self.nbits,
),
)
if self.q_group_size <= 0:
raise ValueError(
incorrect_arg_msg.format(
key="q_group_size",
correct_value="a positive integer",
found_value=self.q_group_size,
),
)
if self.residual_length < 0:
raise ValueError(
incorrect_arg_msg.format(
key="residual_length",
correct_value="a positive integer",
found_value=self.residual_length,
),
)
if self.axis_key not in [0, 1, -1]:
raise ValueError(
incorrect_arg_msg.format(
key="axis_key",
correct_value="`1` or `0`, `-1`",
found_value=self.axis_key,
),
)
if self.axis_value not in [0, 1, -1]:
raise ValueError(
incorrect_arg_msg.format(
key="axis_value",
correct_value="`1` or `0` or `-1`",
found_value=self.axis_value,
),
)
class DynamicCache(Cache):
"""
A cache that grows dynamically as more tokens are generated. This is the default for generative models.
...@@ -186,6 +390,168 @@ class DynamicCache(Cache):
return cache
class QuantizedCache(DynamicCache):
"""
A quantized cache similar to what is described in the [KIVI: A Tuning-Free Asymmetric 2bit Quantization for KV Cache paper](https://arxiv.org/abs/2402.02750).
It allows the model to generate longer sequences without allocating too much memory for the key and value cache, by applying quantization.
The cache has two types of storage, one for original precision and one for the quantized cache. A `residual length` is set as the maximum capacity of the
original precision cache. When the length exceeds this capacity, the original precision cache is quantized and moved into the quantized cache, then emptied. The
quantization is done per-channel with a set `q_group_size` for both keys and values, in contrast to what was described in the paper.
It stores keys and values as a list of quantized tensors (tuples in case we need to store metadata), one for each layer. Additionally, it stores the keys and
values in original precision as a list of tensors, one for each layer. The size of each tensor
is `[batch_size, num_heads, seq_len - residual_length, head_dim]`.
"""
def __init__(self, cache_config: QuantizedCacheConfig) -> None:
self._quantized_key_cache: List[torch.Tensor] = []
self._quantized_value_cache: List[torch.Tensor] = []
self.nbits = cache_config.nbits
self.residual_length = cache_config.residual_length
self.q_group_size = cache_config.q_group_size
self.axis_key = cache_config.axis_key
self.axis_value = cache_config.axis_value
self.compute_dtype = cache_config.compute_dtype
self.device = cache_config.device
super().__init__()
def update(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
layer_idx: int,
cache_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
# Update the number of seen tokens
if layer_idx == 0:
self._seen_tokens += key_states.shape[-2]
if len(self.key_cache) <= layer_idx:
self._quantized_key_cache.append(self._quantize(key_states.contiguous(), axis=self.axis_key))
self._quantized_value_cache.append(self._quantize(value_states.contiguous(), axis=self.axis_value))
self.key_cache.append(torch.zeros(0, dtype=key_states.dtype, device=key_states.device))
self.value_cache.append(torch.zeros(0, dtype=key_states.dtype, device=key_states.device))
keys_to_return, values_to_return = key_states, value_states
else:
dequant_key = self._dequantize(self._quantized_key_cache[layer_idx])
dequant_value = self._dequantize(self._quantized_value_cache[layer_idx])
keys_to_return = [dequant_key, self.key_cache[layer_idx], key_states]
values_to_return = [dequant_value, self.value_cache[layer_idx], value_states]
keys_to_return = torch.cat(keys_to_return, dim=-2)
values_to_return = torch.cat(values_to_return, dim=-2)
if (
self.key_cache[layer_idx].dim() == 4
and self.key_cache[layer_idx].shape[-2] + 1 >= self.residual_length
):
self._quantized_key_cache[layer_idx] = self._quantize(keys_to_return.contiguous(), axis=self.axis_key)
self._quantized_value_cache[layer_idx] = self._quantize(
values_to_return.contiguous(), axis=self.axis_value
)
self.key_cache[layer_idx] = torch.zeros(0, dtype=key_states.dtype, device=key_states.device)
self.value_cache[layer_idx] = torch.zeros(0, dtype=key_states.dtype, device=key_states.device)
else:
self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
return keys_to_return, values_to_return
def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
"""Returns the sequence length of the cached states. A layer index can be optionally passed."""
if len(self.key_cache) <= layer_idx:
return 0
# since we cannot get the seq_length of each layer directly and rely on `_seen_tokens` which is
# updated every "layer_idx" == 0, this is a hack to get the actual seq_length for the given layer_idx
# this part of code otherwise fails when used to verify attn_weight shape in some models
return self._seen_tokens if layer_idx == 0 else self._seen_tokens - 1
def _quantize(self, tensor, axis):
"""Quantizes a key/value using a defined quantization method."""
raise NotImplementedError("Make sure to implement `_quantize` in a subclass.")
def _dequantize(self, q_tensor):
"""Dequantizes back the tensor that was quantized by `self._quantize()`"""
raise NotImplementedError("Make sure to implement `_dequantize` in a subclass.")
class QuantoQuantizedCache(QuantizedCache):
"""
Quantized Cache class that uses `quanto` as a backend to perform quantization. Current implementation supports `int2` and `int4` dtypes only.
Parameters:
cache_config (`QuantizedCacheConfig`,):
A configuration containing all the arguments to be used by the quantizer, including axis, qtype and group size.
"""
def __init__(self, cache_config: CacheConfig) -> None:
super().__init__(cache_config)
if self.nbits not in [2, 4]:
raise ValueError(f"`nbits` for `quanto` backend has to be one of [`2`, `4`] but got {self.nbits}")
if self.axis_key not in [0, -1]:
raise ValueError(f"`axis_key` for `quanto` backend has to be one of [`0`, `-1`] but got {self.axis_key}")
if self.axis_value not in [0, -1]:
raise ValueError(
f"`axis_value` for `quanto` backend has to be one of [`0`, `-1`] but got {self.axis_value}"
)
self.qtype = qint4 if self.nbits == 4 else qint2
def _quantize(self, tensor, axis):
qtensor = QBitsTensor.quantize(tensor, axis=axis, qtype=self.qtype, group_size=self.q_group_size)
return qtensor
def _dequantize(self, qtensor):
return qtensor.dequantize()
class HQQQuantizedCache(QuantizedCache):
"""
Quantized Cache class that uses `HQQ` as a backend to perform quantization. Current implementation supports `int2`, `int4`, `int8` dtypes.
Parameters:
cache_config (`QuantizedCacheConfig`,):
A configuration containing all the arguments to be used by the quantizer, including axis, qtype and group size.
"""
def __init__(self, cache_config: CacheConfig) -> None:
super().__init__(cache_config)
if self.nbits not in [1, 2, 3, 4, 8]:
raise ValueError(
f"`nbits` for `HQQ` backend has to be one of [`1`, `2`, `3`, `4`, `8`] but got {self.nbits}"
)
if self.axis_key not in [0, 1]:
raise ValueError(f"`axis_key` for `HQQ` backend has to be one of [`0`, `1`] but got {self.axis_key}")
if self.axis_value not in [0, 1]:
raise ValueError(f"`axis_value` for `HQQ` backend has to be one of [`0`, `1`] but got {self.axis_value}")
self.quantizer = HQQQuantizer
def _quantize(self, tensor, axis):
qtensor, meta = self.quantizer.quantize(
tensor,
axis=axis,
device=self.device,
compute_dtype=self.compute_dtype,
nbits=self.nbits,
group_size=self.q_group_size,
)
meta["compute_dtype"] = self.compute_dtype
self.quantizer.cuda(qtensor, meta=meta, device=self.device) # Move to device and cast to dtype
return qtensor, meta
def _dequantize(self, qtensor):
quant_tensor, meta = qtensor
tensor = self.quantizer.dequantize(quant_tensor, meta)
return tensor
class SinkCache(Cache):
"""
A cache that as described in the [Attention Sinks paper](https://arxiv.org/abs/2309.17453). It allows the model to
......
...@@ -31,6 +31,7 @@ from ..utils import (
download_url,
extract_commit_hash,
is_remote_url,
is_torch_available,
logging,
)
...@@ -41,6 +42,12 @@ if TYPE_CHECKING:
logger = logging.get_logger(__name__)
METADATA_FIELDS = ("_from_model_config", "_commit_hash", "_original_object_hash", "transformers_version")
NEEDS_CACHE_CONFIG = {}
if is_torch_available():
from ..cache_utils import QuantizedCacheConfig
NEEDS_CACHE_CONFIG["quantized"] = QuantizedCacheConfig
class GenerationMode(ExplicitEnum):
...@@ -299,6 +306,10 @@ class GenerationConfig(PushToHubMixin):
cache_implementation (`str`, *optional*, default to `None`):
Cache class that should be used when generating.
cache_config (`Union[CacheConfig, dict]`, *optional*, default to `None`):
Arguments used in the key-value cache class can be passed in `cache_config`. Can be passed as a `Dict` and
it will be converted to its respective `CacheConfig` internally.
Otherwise can be passed as a `CacheConfig` class matching the indicated `cache_implementation`.
> Wild card
...@@ -382,6 +393,13 @@ class GenerationConfig(PushToHubMixin):
# Cache implementation
self.cache_implementation = kwargs.pop("cache_implementation", None)
self.cache_config = kwargs.pop("cache_config", None)
if self.cache_implementation is not None:
cache_config_class = NEEDS_CACHE_CONFIG[self.cache_implementation]
if self.cache_config is None:
self.cache_config = cache_config_class()
elif isinstance(self.cache_config, dict):
self.cache_config = cache_config_class.from_dict(self.cache_config)
# Prompt lookup decoding
self.prompt_lookup_num_tokens = kwargs.pop("prompt_lookup_num_tokens", None)
...@@ -638,13 +656,26 @@ class GenerationConfig(PushToHubMixin):
f"({self.num_beams})."
)
# check watermarking arguments # 5. check `cache_config`
if self.cache_config is not None:
cache_class = NEEDS_CACHE_CONFIG.get(self.cache_implementation)
if cache_class is None:
raise ValueError(
"You provided a `cache_config` but the cache implementation you are using "
f"({self.cache_implementation}) does not require any config. Make sure to use the "
"correct cache implementation matching your cache config."
)
if not isinstance(self.cache_config, cache_class):
self.cache_config = cache_class.from_dict(self.cache_config)
self.cache_config.validate()
# 6. check watermarking arguments
if self.watermarking_config is not None:
if not isinstance(self.watermarking_config, WatermarkingConfig):
self.watermarking_config = WatermarkingConfig.from_dict(self.watermarking_config)
self.watermarking_config.validate()
# 5. check common issue: passing `generate` arguments inside the generation config # 7. check common issue: passing `generate` arguments inside the generation config
generate_arguments = (
"logits_processor",
"stopping_criteria",
......
...@@ -24,7 +24,15 @@ import torch
import torch.distributed as dist
from torch import nn
from ..cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache from ..cache_utils import (
Cache,
DynamicCache,
HQQQuantizedCache,
QuantizedCacheConfig,
QuantoQuantizedCache,
SlidingWindowCache,
StaticCache,
)
from ..integrations.deepspeed import is_deepspeed_zero3_enabled
from ..modeling_outputs import CausalLMOutputWithPast, Seq2SeqLMOutput
from ..models.auto import (
...@@ -34,7 +42,14 @@ from ..models.auto import (
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
MODEL_FOR_VISION_2_SEQ_MAPPING,
)
from ..utils import ModelOutput, is_accelerate_available, is_torchdynamo_compiling, logging from ..utils import (
ModelOutput,
is_accelerate_available,
is_hqq_available,
is_quanto_available,
is_torchdynamo_compiling,
logging,
)
from .beam_constraints import DisjunctiveConstraint, PhrasalConstraint
from .beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
from .candidate_generator import (
...@@ -97,6 +112,7 @@ if is_accelerate_available():
from accelerate.hooks import AlignDevicesHook, add_hook_to_module
NEED_SETUP_CACHE_CLASSES_MAPPING = {"static": StaticCache, "sliding_window": SlidingWindowCache}
QUANT_BACKEND_CLASSES_MAPPING = {"quanto": QuantoQuantizedCache, "HQQ": HQQQuantizedCache}
@dataclass
...@@ -1658,20 +1674,43 @@ class GenerationMixin:
"Passing both `cache_implementation` (used to initialize certain caches) and `past_key_values` (a "
"Cache object) is unsupported. Please use only one of the two."
)
elif generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING: elif generation_config.cache_implementation is not None:
if not self._supports_cache_class: if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING:
raise ValueError( if generation_config.cache_implementation == "static" and not self._supports_static_cache:
"This model does not support the `cache_implementation` argument. Please check the following " raise ValueError(
"issue: https://github.com/huggingface/transformers/issues/28981." "This model does not support `cache_implementation='static'`. Please check the following "
"issue: https://github.com/huggingface/transformers/issues/28981"
)
model_kwargs["past_key_values"] = self._get_cache(
generation_config.cache_implementation, batch_size, generation_config.max_length
) )
if generation_config.cache_implementation == "static" and not self._supports_static_cache: elif generation_config.cache_implementation == "quantized":
raise ValueError( if not self._supports_quantized_cache:
"This model does not support `cache_implementation='static'`. Please check the following " raise ValueError(
"issue: https://github.com/huggingface/transformers/issues/28981" "This model does not support the quantized cache. If you want your model to support quantized "
"cache, please open an issue."
)
cache_config = (
generation_config.cache_config
if generation_config.cache_config is not None
else QuantizedCacheConfig()
) )
model_kwargs["past_key_values"] = self._get_cache( cache_class = QUANT_BACKEND_CLASSES_MAPPING[cache_config.backend]
generation_config.cache_implementation, batch_size, generation_config.max_length
) if cache_config.backend == "quanto" and not is_quanto_available():
raise ImportError(
"You need to install `quanto` in order to use KV cache quantization with quanto backend. "
"Please install it via with `pip install quanto`"
)
elif cache_config.backend == "HQQ" and not is_hqq_available():
raise ImportError(
"You need to install `HQQ` in order to use KV cache quantization with HQQ backend. "
"Please install it via with `pip install hqq`"
)
model_kwargs["past_key_values"] = cache_class(cache_config)
self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)
# 7. determine generation mode
......
...@@ -1284,6 +1284,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
_supports_cache_class = False
_supports_static_cache = False
# Has support for a `QuantoQuantizedCache` instance as `past_key_values`
_supports_quantized_cache = False
@property
def dummy_inputs(self) -> Dict[str, torch.Tensor]:
"""
......
...@@ -712,6 +712,7 @@ class CoherePreTrainedModel(PreTrainedModel):
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_cache_class = True
_supports_quantized_cache = True
_supports_static_cache = True
def _init_weights(self, module):
......
...@@ -937,6 +937,7 @@ class DbrxPreTrainedModel(PreTrainedModel):
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_cache_class = True
_supports_quantized_cache = True
_supports_static_cache = True
def _init_weights(self, module: nn.Module):
......
...@@ -698,6 +698,7 @@ class GemmaPreTrainedModel(PreTrainedModel):
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_cache_class = True
_supports_quantized_cache = True
_supports_static_cache = True
def _init_weights(self, module):
......
...@@ -767,6 +767,7 @@ class LlamaPreTrainedModel(PreTrainedModel):
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_cache_class = True
_supports_quantized_cache = True
_supports_static_cache = True
def _init_weights(self, module):
......
...@@ -745,6 +745,7 @@ class OlmoPreTrainedModel(PreTrainedModel):
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_cache_class = True
_supports_quantized_cache = True
_supports_static_cache = True
def _init_weights(self, module):
......
...@@ -539,6 +539,8 @@ class RecurrentGemmaPreTrainedModel(PreTrainedModel):
_skip_keys_device_placement = ["cache"]
_supports_flash_attn_2 = False
_supports_sdpa = False # we can't compare with eager for now
_supports_cache_class = True
_supports_quantized_cache = True
def _init_weights(self, module):
std = math.sqrt(self.config.w_init_variance_scale / self.config.conv1d_width)
......
...@@ -832,6 +832,7 @@ class StableLmPreTrainedModel(PreTrainedModel):
_supports_flash_attn_2 = True
_supports_cache_class = True
_supports_sdpa = True
_supports_quantized_cache = True
def _init_weights(self, module):
std = self.config.initializer_range
......
...@@ -784,11 +784,11 @@ class WhisperGenerationMixin:
del generate_kwargs[key]
seek_outputs = super().generate(
segment_input,
generation_config, generation_config=generation_config,
logits_processor, logits_processor=logits_processor,
stopping_criteria, stopping_criteria=stopping_criteria,
prefix_allowed_tokens_fn, prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
synced_gpus, synced_gpus=synced_gpus,
decoder_input_ids=decoder_input_ids,
**generate_kwargs,
)
......
...@@ -23,6 +23,13 @@ class Cache(metaclass=DummyObject):
requires_backends(self, ["torch"])
class CacheConfig(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class DynamicCache(metaclass=DummyObject):
_backends = ["torch"]
...@@ -30,6 +37,34 @@ class DynamicCache(metaclass=DummyObject):
requires_backends(self, ["torch"])
class HQQQuantizedCache(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class QuantizedCache(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class QuantizedCacheConfig(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class QuantoQuantizedCache(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class SinkCache(metaclass=DummyObject):
_backends = ["torch"]
......
...@@ -27,6 +27,7 @@ from transformers import is_torch_available, pipeline, set_seed
from transformers.testing_utils import (
is_flaky,
require_accelerate,
require_quanto,
require_torch,
require_torch_multi_accelerator,
slow,
...@@ -55,7 +56,7 @@ if is_torch_available():
ImageGPTForCausalImageModeling,
SpeechEncoderDecoderModel,
)
from transformers.cache_utils import DynamicCache from transformers.cache_utils import DynamicCache, QuantoQuantizedCache
from transformers.generation import (
BeamSampleDecoderOnlyOutput,
BeamSampleEncoderDecoderOutput,
...@@ -1654,6 +1655,39 @@ class GenerationTesterMixin:
)
)
@require_quanto
def test_generate_with_quant_cache(self):
for model_class in self.all_generative_model_classes:
if not model_class._supports_quantized_cache:
self.skipTest("This model does not support the quantized cache format")
config, input_ids, attention_mask = self._get_input_ids_and_config()
config.use_cache = True
config.is_decoder = True
model = model_class(config).to(torch_device).eval()
generation_kwargs = {
"max_new_tokens": 5,
"cache_implementation": "quantized",
# careful with group size, should be divisor of model's hidden size
"cache_config": {"backend": "quanto", "nbits": 2, "q_group_size": 8, "residual_length": 128},
"return_dict_in_generate": True, # Required to return `past_key_values`
}
results = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)
self.assertTrue(isinstance(results.past_key_values, QuantoQuantizedCache))
# passing past key values of different type should raise Error
with self.assertRaises(ValueError):
model.generate(
input_ids, attention_mask=attention_mask, past_key_values=DynamicCache(), **generation_kwargs
)
# setting incorrect cache_config args should raise an Error, i.e. nbits=60 does not make sense
generation_kwargs["cache_config"] = {"nbits": 60, "q_group_size": 8, "residual_length": 128}
with self.assertRaises(ValueError):
model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)
def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_sequences=1):
batch_size, seq_length = input_ids.shape
num_sequences_in_output = batch_size * num_return_sequences
......
...@@ -17,13 +17,22 @@ import tempfile
import unittest
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, QuantoConfig
from transformers.testing_utils import require_accelerate, require_quanto, require_torch_gpu, slow from transformers.testing_utils import (
require_accelerate,
require_quanto,
require_read_token,
require_torch_gpu,
slow,
torch_device,
)
from transformers.utils import is_accelerate_available, is_quanto_available, is_torch_available
if is_torch_available():
import torch
from transformers import LlamaForCausalLM, LlamaTokenizer
if is_accelerate_available():
from accelerate import init_empty_weights
...@@ -429,3 +438,28 @@ class QuantoQuantizationActivationTest(unittest.TestCase):
with self.assertRaises(ValueError) as e:
AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m", quantization_config=quantization_config)
self.assertIn("We don't support quantizing the activations with transformers library", str(e.exception))
@require_torch_gpu
class QuantoKVCacheQuantizationTest(unittest.TestCase):
@slow
@require_read_token
def test_quantized_cache(self):
EXPECTED_TEXT_COMPLETION = [
"Simply put, the theory of relativity states that 1) the speed of light is the same for all observers, and 2) the laws of physics are the same for all observers.\nThe first part of the theory of relativity",
"My favorite all time favorite condiment is ketchup. I love it on everything. I love it on my eggs, my fries, my burgers, my hot dogs, my sandwiches, my chicken, my pizza, my sal",
]
prompts = [
"Simply put, the theory of relativity states that ",
"My favorite all time favorite condiment is ketchup.",
]
tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", pad_token="</s>", padding_side="left")
model = LlamaForCausalLM.from_pretrained(
"meta-llama/Llama-2-7b-hf", device_map="sequential", torch_dtype=torch.float16
)
inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(torch_device)
generated_ids = model.generate(**inputs, max_new_tokens=40, do_sample=False, cache_implementation="quantized")
text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
self.assertEqual(EXPECTED_TEXT_COMPLETION, text)