Commit c2cd6535 authored by Casper Hansen

Roll cache efficiently and correctly

parent 0392a823
@@ -166,21 +166,14 @@ class QuantAttentionFused(nn.Module):
         )
         if self.start_pos > self.max_seq_len or self.start_pos + seqlen > self.max_seq_len:
-            logging.warning('You have exceeded max_new_tokens, resetting cache...')
-            self._initialize_cache(hidden_states.device)
+            # Roll cache to the left
+            roll_len = self.start_pos
+            self.cache_v = torch.roll(self.cache_v, shifts=-roll_len, dims=2)
+            self.cache_k = torch.roll(self.cache_k, shifts=-roll_len, dims=3)
+            # Zero out the new part
+            self.cache_v[:, :, -roll_len:, :] = 0
+            self.cache_k[:, :, :, -roll_len:, :] = 0
             self.start_pos = 0
-        elif seqlen > self.max_seq_len:
-            memory_used = compute_memory_used_pct(hidden_states.device)
-            if memory_used <= 80:
-                logging.warning('Input sequence length > max_seq_len, increasing and resetting cache...')
-                self.max_seq_len += self.max_seq_len
-                self.attention_shapes = self._get_attention_shapes(None, self.max_seq_len)
-                self._initialize_cache(hidden_states.device)
-                self.start_pos = 0
-            else:
-                logging.error('Input sequence length > max_seq_len, memory is filled, exiting...')
         xqkv = self.qkv_proj(hidden_states)
         xqkv = xqkv.view((bsz, seqlen) + self.attention_shapes["xqkv_view"])
...
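
Below is a minimal sketch of the roll-and-zero pattern used in this change, shown on a toy value cache. The helper name `roll_cache_left` and the (batch, n_heads, seq_len, head_dim) shape are illustrative assumptions, not part of the module's API; in the diff, `roll_len` is set to `self.start_pos`, and the same pattern is applied to `cache_k` along its own sequence dimension (dims=3).

```python
import torch

def roll_cache_left(cache_v: torch.Tensor, roll_len: int) -> torch.Tensor:
    # Shift entries toward index 0 along the sequence dimension; the oldest
    # roll_len positions wrap around to the tail of the cache.
    cache_v = torch.roll(cache_v, shifts=-roll_len, dims=2)
    # Zero the wrapped tail so those slots can be reused as free cache space.
    cache_v[:, :, -roll_len:, :] = 0
    return cache_v

# Toy cache: batch=1, n_heads=1, seq_len=6, head_dim=1, filled with 1..6.
cache_v = torch.arange(1.0, 7.0).view(1, 1, 6, 1)
print(cache_v.flatten().tolist())                      # [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
print(roll_cache_left(cache_v, 2).flatten().tolist())  # [3.0, 4.0, 5.0, 6.0, 0.0, 0.0]
```

Rolling in place and zeroing the freed tail reuses the existing cache tensors instead of reallocating them, which is what the removed `self._initialize_cache(...)` path did.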