Unverified commit 9719202d authored by Joao Gante, committed by GitHub

Generate: fix `SinkCache` on Llama models (#30581)

parent 66abe139
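
For context (not part of the commit): `SinkCache` is used by passing an instance to `generate` via `past_key_values`. A minimal usage sketch with a Llama checkpoint, assuming a transformers version that includes this fix; the checkpoint name and generation settings are illustrative:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, SinkCache

# Illustrative checkpoint; any Llama-architecture model exercises the fixed code path.
model_id = "meta-llama/Llama-2-7b-chat-hf"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

inputs = tokenizer("Attention sinks are useful because", return_tensors="pt")

# Keep the first 4 "sink" tokens plus a sliding window of recent tokens,
# bounding KV-cache memory regardless of generation length.
past_key_values = SinkCache(window_length=256, num_sink_tokens=4)
outputs = model.generate(**inputs, do_sample=False, max_new_tokens=32, past_key_values=past_key_values)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```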
@@ -207,7 +207,9 @@ class SinkCache(Cache):
         self.value_cache: List[torch.Tensor] = []
         self.window_length = window_length
         self.num_sink_tokens = num_sink_tokens
-        self.cos_sin_cache = {}
+        self.cos_sin_rerotation_cache = {}
+        self._cos_cache = None
+        self._sin_cache = None
         self._seen_tokens = 0  # Used in `generate` to keep tally of how many tokens the cache has seen
 
     @staticmethod
@@ -225,7 +227,7 @@ class SinkCache(Cache):
     def _get_rerotation_cos_sin(
         self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        if key_states.shape[-2] not in self.cos_sin_cache:
+        if key_states.shape[-2] not in self.cos_sin_rerotation_cache:
             # Upcast to float32 temporarily for better accuracy
             cos = cos.to(torch.float32)
             sin = sin.to(torch.float32)
@@ -238,11 +240,11 @@ class SinkCache(Cache):
             rerotation_cos = original_cos * shifted_cos + original_sin * shifted_sin
             rerotation_sin = -original_sin * shifted_cos + original_cos * shifted_sin
-            self.cos_sin_cache[key_states.shape[-2]] = (
+            self.cos_sin_rerotation_cache[key_states.shape[-2]] = (
                 rerotation_cos.to(key_states.dtype).unsqueeze(0),
                 rerotation_sin.to(key_states.dtype).unsqueeze(0),
             )
-        return self.cos_sin_cache[key_states.shape[-2]]
+        return self.cos_sin_rerotation_cache[key_states.shape[-2]]
 
     def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
         """Returns the sequence length of the cached states. A layer index can be optionally passed."""
@@ -292,6 +294,21 @@ class SinkCache(Cache):
         if layer_idx == 0:
             self._seen_tokens += key_states.shape[-2]
 
+        # Update the sin/cos cache, which holds sin/cos values for all possible positions
+        if using_rope and layer_idx == 0:
+            # BC: some models still pass `sin`/`cos` with 2 dims. In those models, they are the full sin/cos. Remove
+            # after all RoPE models have a llama-like cache utilization.
+            if cos.dim() == 2:
+                self._cos_cache = cos
+                self._sin_cache = sin
+            else:
+                if self._cos_cache is None:
+                    self._cos_cache = cos[0, ...]
+                    self._sin_cache = sin[0, ...]
+                elif self._cos_cache.shape[0] < self.window_length:
+                    self._cos_cache = torch.cat([self._cos_cache, cos[0, ...]], dim=0)
+                    self._sin_cache = torch.cat([self._sin_cache, sin[0, ...]], dim=0)
+
         # [bsz, num_heads, seq_len, head_dim]
         if len(self.key_cache) <= layer_idx:
             # Empty cache
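
Review note, not part of the diff: under the llama-like convention the BC comment above refers to, `cos`/`sin` arrive per `update` call with a leading batch dimension and cover only the current step's positions, so the cache strips the batch dimension and concatenates along the sequence axis until `window_length` positions are stored. A toy trace of that accumulation, with hypothetical shapes and names:

```python
import torch

window_length, head_dim = 4, 2
_cos_cache = None  # mirrors SinkCache._cos_cache

# Simulate three update() calls: a 3-token prefill, then two 1-token decode steps.
for step_len in (3, 1, 1):
    cos = torch.rand(1, step_len, head_dim)  # [batch, seq_len, head_dim], hypothetical values
    if _cos_cache is None:
        _cos_cache = cos[0, ...]
    elif _cos_cache.shape[0] < window_length:
        _cos_cache = torch.cat([_cos_cache, cos[0, ...]], dim=0)
    print(tuple(_cos_cache.shape))  # (3, 2) -> (4, 2) -> (4, 2): growth stops at window_length
```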
@@ -312,7 +329,7 @@ class SinkCache(Cache):
             # On RoPE models, we need to recompute the Key rotation as the tokens are shifted
             if using_rope:
                 rerotation_cos, rerotation_sin = self._get_rerotation_cos_sin(
-                    key_states, cos[: self.window_length], sin[: self.window_length]
+                    key_states, self._cos_cache[: self.window_length], self._sin_cache[: self.window_length]
                 )
                 if partial_rotation_size is not None:
                     keys_to_keep, keys_pass = (
......
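
Aside on the math, not part of the diff: the `rerotation_cos`/`rerotation_sin` expressions in `_get_rerotation_cos_sin` are the angle-difference identities cos(b - a) = cos a cos b + sin a sin b and sin(b - a) = -sin a cos b + cos a sin b. They rotate a cached key from its original RoPE angle to its post-shift angle using only stored cos/sin values, which is why the fix only needs `_cos_cache`/`_sin_cache` to cover the window. A quick numerical check (variable names are illustrative):

```python
import torch

# theta_orig: angles the keys were rotated with when cached;
# theta_shifted: angles they should have after the window shift.
theta_orig = torch.rand(8)
theta_shifted = torch.rand(8)

original_cos, original_sin = torch.cos(theta_orig), torch.sin(theta_orig)
shifted_cos, shifted_sin = torch.cos(theta_shifted), torch.sin(theta_shifted)

# Same expressions as in _get_rerotation_cos_sin above.
rerotation_cos = original_cos * shifted_cos + original_sin * shifted_sin
rerotation_sin = -original_sin * shifted_cos + original_cos * shifted_sin

# They equal cos/sin of the angle delta, so one elementwise rotation re-positions the keys.
assert torch.allclose(rerotation_cos, torch.cos(theta_shifted - theta_orig), atol=1e-6)
assert torch.allclose(rerotation_sin, torch.sin(theta_shifted - theta_orig), atol=1e-6)
```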