Unverified commit 115ac94d authored by Arthur, committed by GitHub

[`Core generation`] Adds support for static KV cache (#27931)


Co-authored-by: fxmarty <9808326+fxmarty@users.noreply.github.com>
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
parent 4b236aed
...@@ -373,3 +373,7 @@ A [`Constraint`] can be used to force the generation to include specific tokens
- update
- get_seq_length
- reorder_cache
[[autodoc]] StaticCache
- update
- get_seq_length
\ No newline at end of file
...@@ -1337,7 +1337,7 @@ else:
_import_structure["activations"] = []
_import_structure["benchmark.benchmark"] = ["PyTorchBenchmark"]
_import_structure["benchmark.benchmark_args"] = ["PyTorchBenchmarkArguments"]
-_import_structure["cache_utils"] = ["Cache", "DynamicCache", "SinkCache"]
_import_structure["cache_utils"] = ["Cache", "DynamicCache", "SinkCache", "StaticCache"]
_import_structure["data.datasets"] = [
"GlueDataset",
"GlueDataTrainingArguments",
...@@ -6073,7 +6073,7 @@ if TYPE_CHECKING:
# Benchmarks
from .benchmark.benchmark import PyTorchBenchmark
from .benchmark.benchmark_args import PyTorchBenchmarkArguments
-from .cache_utils import Cache, DynamicCache, SinkCache
from .cache_utils import Cache, DynamicCache, SinkCache, StaticCache
from .data.datasets import (
GlueDataset,
GlueDataTrainingArguments,
...
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple
import torch
from .configuration_utils import PretrainedConfig
@dataclass
class Cache:
"""
Base, abstract class for all caches. The actual data structure is specific to each subclass.
...@@ -320,3 +324,91 @@ class SinkCache(Cache):
self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
device = self.value_cache[layer_idx].device
self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))
class StaticCache(Cache):
"""
Static Cache class to be used with `torch.compile(model)`.
Parameters:
config (`PretrainedConfig`):
The configuration file defining the `max_position_embeddings`, `hidden_size` and `num_attention_heads`
required to initialize the static cache.
max_batch_size (`int`):
The maximum batch size with which the model will be used.
max_cache_len (`int`):
The maximum sequence length with which the model will be used.
device (`torch.device`):
The device on which the cache should be initialized. Should be the same as the layer.
dtype (*optional*, defaults to `torch.float32`):
The default `dtype` to use when initializing the layer.
"""
def __init__(
self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, device, dtype=torch.float32
) -> None:
super().__init__()
self.max_batch_size = max_batch_size
self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
self.head_dim = config.hidden_size // config.num_attention_heads
self.num_heads = config.num_attention_heads
self.dtype = config.torch_dtype if config.torch_dtype is not None else dtype
cache_shape = (max_batch_size, self.num_heads, self.max_cache_len, self.head_dim)
self.key_cache: torch.Tensor = torch.zeros(cache_shape, dtype=self.dtype, device=device)
self.value_cache: torch.Tensor = torch.zeros(cache_shape, dtype=self.dtype, device=device)
self.seen_tokens = 0
def update(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
layer_idx: int,
cache_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
It is VERY important to index using a tensor, otherwise you introduce a copy to the device.
Parameters:
key_states (`torch.Tensor`):
The new key states to cache.
value_states (`torch.Tensor`):
The new value states to cache.
layer_idx (`int`):
The index of the layer to cache the states for. Kept for backward compatibility.
cache_kwargs (`Dict[str, Any]`, *optional*):
Additional arguments for the cache subclass. The `StaticCache` needs the `position_ids` entry
to know which positions of the cache it should overwrite.
Return:
A tuple containing the updated key and value states.
"""
new_cache_positions = cache_kwargs.get("position_ids")
k_out = self.key_cache
v_out = self.value_cache
k_out[:, :, new_cache_positions] = key_states
v_out[:, :, new_cache_positions] = value_states
self.seen_tokens += key_states.shape[-2]
return k_out, v_out
def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
"""Returns the sequence length of the cached states that were seen by the model. `layer_idx` kept for BC"""
return self.seen_tokens
def get_max_length(self) -> Optional[int]:
"""Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length."""
return self.max_cache_len
def reorder_cache(self, beam_idx: torch.LongTensor):
"""Reorders the cache for beam search, given the selected beam indices."""
device = self.key_cache.device
self.key_cache = self.key_cache.index_select(0, beam_idx.to(device))
device = self.value_cache.device
self.value_cache = self.value_cache.index_select(0, beam_idx.to(device))
def to_legacy_cache(self):
"""Dummy function for BC. We have to keep it because otherwise the call in the forward of models will break it"""
return None
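For reference, here is a minimal usage sketch of the new `StaticCache` (toy configuration sizes chosen for illustration, not taken from this commit):

```python
import torch

from transformers import LlamaConfig, StaticCache

# Toy-sized config; StaticCache only reads max_position_embeddings, hidden_size,
# num_attention_heads and torch_dtype from it.
config = LlamaConfig(hidden_size=64, num_attention_heads=4, num_hidden_layers=2, max_position_embeddings=128)
cache = StaticCache(config, max_batch_size=1, max_cache_len=32, device="cpu", dtype=torch.float32)

# Pretend one attention layer produced keys/values for a 5-token prefill.
head_dim = config.hidden_size // config.num_attention_heads
key_states = torch.randn(1, config.num_attention_heads, 5, head_dim)
value_states = torch.randn(1, config.num_attention_heads, 5, head_dim)

# `update` writes in place at the positions passed via `cache_kwargs["position_ids"]`.
positions = torch.arange(5)
k, v = cache.update(key_states, value_states, layer_idx=0, cache_kwargs={"position_ids": positions})
print(k.shape, cache.get_seq_length(), cache.get_max_length())  # torch.Size([1, 4, 32, 16]) 5 32
```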
...@@ -250,6 +250,11 @@ class GenerationConfig(PushToHubMixin):
reduce by 1
- `"constant"`: `num_assistant_tokens` stays unchanged during generation
> Parameters specific to the caching mechanism:
cache_implementation (`str`, *optional*, defaults to `None`):
Cache class that should be used when generating.
> Wild card
generation_kwargs:
...@@ -321,6 +326,9 @@ class GenerationConfig(PushToHubMixin):
self.num_assistant_tokens = kwargs.pop("num_assistant_tokens", 5)
self.num_assistant_tokens_schedule = kwargs.pop("num_assistant_tokens_schedule", "heuristic")
# Cache implementation
self.cache_implementation = kwargs.pop("cache_implementation", None)
# Prompt lookup decoding
self.prompt_lookup_num_tokens = kwargs.pop("prompt_lookup_num_tokens", None)
...
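For context, opting into the new setting from user code looks roughly like the integration tests added at the end of this commit (the checkpoint name is only illustrative):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("NousResearch/Llama-2-7b-chat-hf")
model = AutoModelForCausalLM.from_pretrained("NousResearch/Llama-2-7b-chat-hf")

# Opt into the static KV cache; generate() will then call model._setup_cache(StaticCache, ...)
model.generation_config.cache_implementation = "static"

inputs = tokenizer("The best color is", return_tensors="pt")
out = model.generate(**inputs, do_sample=False, max_new_tokens=10)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```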
...@@ -24,7 +24,7 @@ import torch
import torch.distributed as dist
from torch import nn
-from ..cache_utils import Cache, DynamicCache
from ..cache_utils import Cache, DynamicCache, StaticCache
from ..integrations.deepspeed import is_deepspeed_zero3_enabled
from ..modeling_outputs import CausalLMOutputWithPast, Seq2SeqLMOutput
from ..models.auto import (
...@@ -92,6 +92,10 @@ logger = logging.get_logger(__name__)
if is_accelerate_available():
from accelerate.hooks import AlignDevicesHook, add_hook_to_module
NEED_SETUP_CACHE_CLASSES_MAPPING = {
"static": StaticCache,
}
@dataclass
class GenerateDecoderOnlyOutput(ModelOutput):
...@@ -1398,6 +1402,19 @@ class GenerationMixin:
"(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
)
generation_config.max_length = generation_config.max_new_tokens + input_ids_length
# if we don't pass `past_key_values` and a cache_implementation is specified
if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING and not model_kwargs.get(
"past_key_values", False
):
cache_cls = NEED_SETUP_CACHE_CLASSES_MAPPING[generation_config.cache_implementation]
if not callable(getattr(self, "_setup_cache", None)):
raise ValueError(
"The `generation_config` defines a `cache_implementation` that is not compatible with this model."
" Make sure it has a `_setup_cache` function."
)
self._setup_cache(cache_cls, max_batch_size=batch_size, max_cache_len=generation_config.max_length)
self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)
# 7. determine generation mode
...
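The diff only shows the call site; the model-side `_setup_cache` hook itself is not part of this excerpt. A hypothetical sketch of what such a hook could look like, consistent with the `getattr(self, "past_key_value", ...)` lookups in the attention modules below:

```python
# Hypothetical sketch -- not code from this commit. It only illustrates the contract
# implied above: generate() calls self._setup_cache(cache_cls, max_batch_size=...,
# max_cache_len=...) and each attention layer later fetches its cache through
# getattr(self, "past_key_value", past_key_value).
def _setup_cache(self, cache_cls, max_batch_size, max_cache_len=None):
    for layer in self.model.layers:
        layer.self_attn.past_key_value = cache_cls(
            self.config, max_batch_size, max_cache_len, device=self.device, dtype=self.dtype
        )
```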
...@@ -63,7 +63,7 @@ class OpenLlamaRMSNorm(nn.Module):
return self.weight * hidden_states.to(input_dtype)
-# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->OpenLlama
# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->OpenLlama
class OpenLlamaRotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__()
...@@ -154,7 +154,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)
-# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
...
...@@ -88,7 +88,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)
-# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
...@@ -130,7 +130,7 @@ def _get_unpad_data(attention_mask):
)
-# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Falcon
# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Falcon
class FalconRotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__()
...
...@@ -527,7 +527,7 @@ def attention_mask_func(attention_scores, ltor_mask):
class GPTNeoXRotaryEmbedding(nn.Module):
-# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.__init__
# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding.__init__
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__()
...@@ -617,7 +617,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)
-# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
...
...@@ -235,7 +235,7 @@ class GPTNeoXJapaneseAttention(nn.Module):
# Copied from transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXRotaryEmbedding with GPTNeoXRotaryEmbedding->RotaryEmbedding
class RotaryEmbedding(nn.Module):
-# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.__init__
# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding.__init__
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__()
...
...@@ -513,7 +513,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)
-# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
...
...@@ -88,7 +88,8 @@ class MistralRMSNorm(nn.Module):
return self.weight * hidden_states.to(input_dtype)
-# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Mistral
# copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Mistral
# TODO @Arthur no longer copied from LLama after static cache
class MistralRotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__()
...@@ -133,7 +134,8 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)
-# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
# copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
# TODO @Arthur no longer copied from LLama after static cache
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
...@@ -612,7 +614,8 @@ class MistralFlashAttention2(MistralAttention):
)
-# Copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Mistral
# copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Mistral
# TODO @Arthur no longer copied from LLama after static cache
class MistralSdpaAttention(MistralAttention):
"""
Mistral attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
...@@ -656,28 +659,34 @@ class MistralSdpaAttention(MistralAttention):
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2]
past_key_value = getattr(self, "past_key_value", past_key_value)
if past_key_value is not None:
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) # add what was seen
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
past_seen_tokens = kv_seq_len - key_states.shape[-2]
new_cache_positions = torch.arange(past_seen_tokens, past_seen_tokens + q_len, device=key_states.device)
if past_key_value is not None:
-cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
cache_kwargs = {"sin": sin, "cos": cos, "position_ids": new_cache_positions} # Specific to RoPE models
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
-if attention_mask is not None:
-if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
-raise ValueError(
-f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
-)
if (
attention_mask is not None and not torch.all(attention_mask[..., 0] == 1) and q_len != 1
): # user defined causal mask
causal_mask = attention_mask[:, :, past_seen_tokens : past_seen_tokens + q_len, : key_states.shape[-2]]
# this one liner is equivalent to the pad_unpad function
causal_mask.mul_(~torch.eq(causal_mask, causal_mask.min()).all(dim=-1)[..., None])
else:
causal_mask = None
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
# Reference: https://github.com/pytorch/pytorch/issues/112577.
-if query_states.device.type == "cuda" and attention_mask is not None:
if query_states.device.type == "cuda" and causal_mask is not None:
query_states = query_states.contiguous()
key_states = key_states.contiguous()
value_states = value_states.contiguous()
...@@ -686,14 +695,13 @@ class MistralSdpaAttention(MistralAttention):
query_states,
key_states,
value_states,
-attn_mask=attention_mask,
attn_mask=causal_mask,
dropout_p=self.attention_dropout if self.training else 0.0,
-# The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
-is_causal=self.is_causal and attention_mask is None and q_len > 1,
is_causal=causal_mask is None and q_len > 1,
)
attn_output = attn_output.transpose(1, 2).contiguous()
-attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = attn_output.view(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
...
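To make the "pad_unpad" one-liner above concrete, here is a small, self-contained illustration (toy tensors, not from this commit) of how fully masked rows are neutralized before being handed to SDPA:

```python
import torch

# Build a 3x3 causal float mask: 0.0 means "attend", a large negative value means "mask".
min_val = torch.finfo(torch.float32).min
causal_mask = torch.triu(torch.full((1, 1, 3, 3), min_val), diagonal=1)
causal_mask[0, 0, 0, :] = min_val  # pretend query row 0 corresponds to pure left-padding

# Rows that are entirely `min_val` would feed SDPA an all-masked row (NaN outputs);
# multiplying by the boolean "row is not fully masked" zeroes those rows instead,
# which is what the pad/unpad handling would otherwise achieve.
causal_mask.mul_(~torch.eq(causal_mask, causal_mask.min()).all(dim=-1)[..., None])
print(causal_mask[0, 0])  # row 0 is now all zeros; rows 1-2 keep the causal pattern
```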
...@@ -181,7 +181,7 @@ class MixtralRMSNorm(nn.Module):
return self.weight * hidden_states.to(input_dtype)
-# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Mixtral
# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Mixtral
class MixtralRotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__()
...@@ -226,7 +226,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)
-# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
...@@ -692,7 +692,7 @@ class MixtralFlashAttention2(MixtralAttention):
)
-# Copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Mixtral
# Copied from transformers.models.mistral.modeling_mistral.MistralSdpaAttention with Mistral->Mixtral
class MixtralSdpaAttention(MixtralAttention):
"""
Mixtral attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
...@@ -736,28 +736,34 @@ class MixtralSdpaAttention(MixtralAttention):
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2]
past_key_value = getattr(self, "past_key_value", past_key_value)
if past_key_value is not None:
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) # add what was seen
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
past_seen_tokens = kv_seq_len - key_states.shape[-2]
new_cache_positions = torch.arange(past_seen_tokens, past_seen_tokens + q_len, device=key_states.device)
if past_key_value is not None:
-cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
cache_kwargs = {"sin": sin, "cos": cos, "position_ids": new_cache_positions} # Specific to RoPE models
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
-if attention_mask is not None:
-if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
-raise ValueError(
-f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
-)
if (
attention_mask is not None and not torch.all(attention_mask[..., 0] == 1) and q_len != 1
): # user defined causal mask
causal_mask = attention_mask[:, :, past_seen_tokens : past_seen_tokens + q_len, : key_states.shape[-2]]
# this one liner is equivalent to the pad_unpad function
causal_mask.mul_(~torch.eq(causal_mask, causal_mask.min()).all(dim=-1)[..., None])
else:
causal_mask = None
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
# Reference: https://github.com/pytorch/pytorch/issues/112577.
-if query_states.device.type == "cuda" and attention_mask is not None:
if query_states.device.type == "cuda" and causal_mask is not None:
query_states = query_states.contiguous()
key_states = key_states.contiguous()
value_states = value_states.contiguous()
...@@ -766,14 +772,13 @@ class MixtralSdpaAttention(MixtralAttention):
query_states,
key_states,
value_states,
-attn_mask=attention_mask,
attn_mask=causal_mask,
dropout_p=self.attention_dropout if self.training else 0.0,
-# The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
-is_causal=self.is_causal and attention_mask is None and q_len > 1,
is_causal=causal_mask is None and q_len > 1,
)
attn_output = attn_output.transpose(1, 2).contiguous()
-attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = attn_output.view(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
...
...@@ -40,7 +40,7 @@ logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "PersimmonConfig"
-# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Persimmon
# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Persimmon
class PersimmonRotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__()
...@@ -132,7 +132,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)
-# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
...@@ -864,6 +864,12 @@ class PersimmonForCausalLM(PersimmonPreTrainedModel):
if past_key_values:
position_ids = position_ids[:, -input_ids.shape[1] :]
if past_key_value := getattr(self.model.layers[0].self_attn, "past_key_value", None):
# generation with static cache
seen_tokens = past_key_value.get_seq_length()
input_ids = input_ids[:, seen_tokens:]
position_ids = position_ids[:, seen_tokens:]
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds}
...
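As a toy illustration (hypothetical tensors, not from this commit) of the slicing added to `prepare_inputs_for_generation` above: once the static cache has already seen some tokens, only the not-yet-processed suffix is fed back to the model:

```python
import torch

# Hypothetical inputs after a 4-token prefill; in the real code `seen_tokens`
# comes from past_key_value.get_seq_length().
input_ids = torch.tensor([[10, 11, 12, 13, 14]])
position_ids = torch.arange(5).unsqueeze(0)
seen_tokens = 4

input_ids = input_ids[:, seen_tokens:]        # tensor([[14]])
position_ids = position_ids[:, seen_tokens:]  # tensor([[4]])
```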
...@@ -78,7 +78,7 @@ def _get_unpad_data(attention_mask):
)
-# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Phi
# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Phi
class PhiRotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__()
...@@ -170,7 +170,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)
-# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
...@@ -1125,6 +1125,12 @@ class PhiForCausalLM(PhiPreTrainedModel):
if past_key_values:
position_ids = position_ids[:, -input_ids.shape[1] :]
if past_key_value := getattr(self.model.layers[0].self_attn, "past_key_value", None):
# generation with static cache
seen_tokens = past_key_value.get_seq_length()
input_ids = input_ids[:, seen_tokens:]
position_ids = position_ids[:, seen_tokens:]
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds}
...
...@@ -95,7 +95,7 @@ class Qwen2RMSNorm(nn.Module):
return self.weight * hidden_states.to(input_dtype)
-# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Qwen2
# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Qwen2
class Qwen2RotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__()
...@@ -140,7 +140,7 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)
-# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
...@@ -625,7 +625,7 @@ class Qwen2FlashAttention2(Qwen2Attention):
)
-# Copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Qwen2
# Copied from transformers.models.mistral.modeling_mistral.MistralSdpaAttention with Mistral->Qwen2
class Qwen2SdpaAttention(Qwen2Attention):
"""
Qwen2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
...@@ -669,28 +669,34 @@ class Qwen2SdpaAttention(Qwen2Attention):
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2]
past_key_value = getattr(self, "past_key_value", past_key_value)
if past_key_value is not None:
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) # add what was seen
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
past_seen_tokens = kv_seq_len - key_states.shape[-2]
new_cache_positions = torch.arange(past_seen_tokens, past_seen_tokens + q_len, device=key_states.device)
if past_key_value is not None:
-cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
cache_kwargs = {"sin": sin, "cos": cos, "position_ids": new_cache_positions} # Specific to RoPE models
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
-if attention_mask is not None:
-if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
-raise ValueError(
-f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
-)
if (
attention_mask is not None and not torch.all(attention_mask[..., 0] == 1) and q_len != 1
): # user defined causal mask
causal_mask = attention_mask[:, :, past_seen_tokens : past_seen_tokens + q_len, : key_states.shape[-2]]
# this one liner is equivalent to the pad_unpad function
causal_mask.mul_(~torch.eq(causal_mask, causal_mask.min()).all(dim=-1)[..., None])
else:
causal_mask = None
# SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
# Reference: https://github.com/pytorch/pytorch/issues/112577.
-if query_states.device.type == "cuda" and attention_mask is not None:
if query_states.device.type == "cuda" and causal_mask is not None:
query_states = query_states.contiguous()
key_states = key_states.contiguous()
value_states = value_states.contiguous()
...@@ -699,14 +705,13 @@ class Qwen2SdpaAttention(Qwen2Attention):
query_states,
key_states,
value_states,
-attn_mask=attention_mask,
attn_mask=causal_mask,
dropout_p=self.attention_dropout if self.training else 0.0,
-# The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
-is_causal=self.is_causal and attention_mask is None and q_len > 1,
is_causal=causal_mask is None and q_len > 1,
)
attn_output = attn_output.transpose(1, 2).contiguous()
-attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = attn_output.view(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
...
...@@ -37,6 +37,13 @@ class SinkCache(metaclass=DummyObject):
requires_backends(self, ["torch"])
class StaticCache(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class GlueDataset(metaclass=DummyObject):
_backends = ["torch"]
...
...@@ -362,6 +362,7 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
pass
@parameterized.expand([("linear",), ("dynamic",)])
@unittest.skip("TODO @gante fix this for Llama")
def test_model_rope_scaling(self, scaling_type):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
short_input = ids_tensor([1, 10], config.vocab_size)
...@@ -507,9 +508,19 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
inputs = tokenizer(texts, return_tensors="pt", padding=True).to(torch_device)
res_eager = model_eager.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
res_sdpa = model_sdpa.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
-self.assertTrue(torch.allclose(res_eager, res_sdpa))
with self.subTest(f"{padding_side}"):
torch.testing.assert_close(
res_eager,
res_sdpa,
msg=f"\n{tokenizer.batch_decode(res_eager)} \nvs\n{tokenizer.batch_decode(res_sdpa)}",
)
@unittest.skip("TODO @gante fix this for Llama")
@parameterized.expand([(1, False), (1, True), (4, False)])
def test_new_cache_format(self, num_beams, do_sample):
pass
@require_torch
...
...@@ -15,14 +15,29 @@
import unittest
from parameterized import parameterized
from transformers import set_seed
-from transformers.testing_utils import is_torch_available, require_auto_gptq, require_torch, require_torch_gpu, slow
from transformers.testing_utils import (
is_torch_available,
require_auto_gptq,
require_torch,
require_torch_gpu,
slow,
torch_device,
)
if is_torch_available():
import torch
-from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache, LlamaForCausalLM, SinkCache
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
DynamicCache,
LlamaForCausalLM,
SinkCache,
)
@require_torch
...@@ -229,3 +244,100 @@ class CacheIntegrationTest(unittest.TestCase):
"was visiting the historic district of Honolulu. Here,"
)
self.assertTrue(decoded[0].endswith(last_output))
@require_torch_gpu
@parameterized.expand(["eager", "sdpa", "flash_attention_2"])
def test_static_cache_greedy_sampling_pad_left(self, attn_implementation):
EXPECTED_GENERATION = [
"The best color is the one that complements the subject you are photograph",
"We should not undermind the issues at hand.\nWe should not undermind the issues",
]
tokenizer = AutoTokenizer.from_pretrained(
"NousResearch/Llama-2-7b-chat-hf", padding_side="left", pad_token="<s>"
)
model = AutoModelForCausalLM.from_pretrained(
"NousResearch/Llama-2-7b-chat-hf",
torch_dtype=torch.bfloat16,
attn_implementation=attn_implementation,
).to(torch_device)
inputs = tokenizer(
["The best color is", "We should not undermind the issues at hand"], padding=True, return_tensors="pt"
).to(model.device)
set_seed(0)
gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10)
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
with self.subTest(f"{attn_implementation}, dynamic"):
self.assertListEqual(decoded, EXPECTED_GENERATION)
set_seed(0)
model.generation_config.cache_implementation = "static"
gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10)
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
with self.subTest(f"{attn_implementation}, static, eager"):
self.assertListEqual(decoded, EXPECTED_GENERATION)
set_seed(0)
model.forward = torch.compile(model.forward)
gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10)
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
with self.subTest(f"{attn_implementation}, static, compiled"):
self.assertListEqual(decoded, EXPECTED_GENERATION)
@require_torch_gpu
@parameterized.expand(["eager", "sdpa", "flash_attention_2"])
def test_static_cache_greedy_sampling_pad_right(self, attn_implementation):
EXPECTED_GENERATION = [
"The best color is\n\n\n\n\n\n\n\n\n\n",
"We should not undermind the issues at hand, but address them head on.\nI think",
]
tokenizer = AutoTokenizer.from_pretrained(
"NousResearch/Llama-2-7b-chat-hf", padding_side="left", pad_token="<s>"
)
model = AutoModelForCausalLM.from_pretrained(
"NousResearch/Llama-2-7b-chat-hf",
torch_dtype=torch.bfloat16,
attn_implementation=attn_implementation,
).to("cuda:1")
inputs = tokenizer(
["The best color is", "We should not undermind the issues at hand"], padding=True, return_tensors="pt"
).to(model.device)
set_seed(0)
gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10)
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
with self.subTest(f"{attn_implementation}, dynamic"):
self.assertListEqual(decoded, EXPECTED_GENERATION)
set_seed(0)
model.generation_config.cache_implementation = "static"
gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10)
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
with self.subTest(f"{attn_implementation}, static, eager"):
self.assertListEqual(decoded, EXPECTED_GENERATION)
set_seed(0)
model._forward = model.forward
compiled_forward = torch.compile(model.forward)
def compiled(func, input_ids, **kwargs):
return func(input_ids, **kwargs)
def call(input_ids, **kwargs):
if input_ids.shape[-1] == 1:
return compiled(compiled_forward, input_ids, **kwargs)
return model._forward(input_ids, **kwargs)
model.forward = call
gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10)
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
with self.subTest(f"{attn_implementation}, static, compiled"):
self.assertListEqual(decoded, EXPECTED_GENERATION)
@unittest.skip("TODO @gante static cache's does not support beam search yet")
def test_static_cache_beam_search(self):
pass