Commit 0392a823 authored by Casper Hansen

Fix edge case

parent 3bb4a9f6
@@ -5,6 +5,7 @@ import logging
 import torch.nn as nn
 import awq_inference_engine
 from torch.nn import functional as F
+from awq.utils.utils import compute_memory_used_pct
 
 try:
     import ft_inference_engine
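Note on the new import: compute_memory_used_pct is used below to decide whether the KV cache can safely be grown. A minimal sketch of what such a helper might look like, assuming it reports the share of the CUDA device's total memory currently allocated (the actual awq.utils.utils implementation may differ):

import torch

def compute_memory_used_pct(device) -> float:
    # Sketch only: percentage of the device's total memory that is in use.
    # Assumes a CUDA device; the real helper in awq.utils.utils may use
    # different statistics (e.g. reserved vs. allocated memory).
    used = torch.cuda.max_memory_allocated(device)
    total = torch.cuda.get_device_properties(device).total_memory
    return used / total * 100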
@@ -164,16 +165,22 @@ class QuantAttentionFused(nn.Module):
                 f"Use: AutoAWQForCausalLM.from_quantized(batch_size={bsz})"
             )
 
-        if self.start_pos > self.max_seq_len:
+        if self.start_pos > self.max_seq_len or self.start_pos + seqlen > self.max_seq_len:
             logging.warning('You have exceeded max_new_tokens, resetting cache...')
             self._initialize_cache(hidden_states.device)
             self.start_pos = 0
         elif seqlen > self.max_seq_len:
-            logging.warning('Sequence length > max_seq_len, increasing and resetting cache...')
-            self.max_seq_len *= 2
-            self._initialize_cache(hidden_states.device)
-            self.start_pos = 0
+            memory_used = compute_memory_used_pct(hidden_states.device)
+
+            if memory_used <= 80:
+                logging.warning('Input sequence length > max_seq_len, increasing and resetting cache...')
+                self.max_seq_len += self.max_seq_len
+                self.attention_shapes = self._get_attention_shapes(None, self.max_seq_len)
+                self._initialize_cache(hidden_states.device)
+                self.start_pos = 0
+            else:
+                logging.error('Input sequence length > max_seq_len, memory is filled, exiting...')
 
         xqkv = self.qkv_proj(hidden_states)
         xqkv = xqkv.view((bsz, seqlen) + self.attention_shapes["xqkv_view"])
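The edge case being fixed: during generation the write window [start_pos, start_pos + seqlen) can run past the end of the pre-allocated cache even when start_pos alone is still within bounds, so the reset condition must also account for the incoming sequence length. A standalone sketch of the decision logic added in this hunk, with return values standing in for the actual cache operations (_initialize_cache, _get_attention_shapes); names are illustrative:

def plan_cache_update(start_pos, seqlen, max_seq_len, memory_used_pct):
    # Decide how to handle the fused-attention KV cache before writing
    # seqlen new positions starting at start_pos. Illustrative only.
    if start_pos > max_seq_len or start_pos + seqlen > max_seq_len:
        # The new tokens would overflow the cache: re-initialize and restart
        # from position 0 (the added "or" clause catches the overflow case).
        return "reset"
    elif seqlen > max_seq_len:
        if memory_used_pct <= 80:
            # Enough headroom: double max_seq_len, rebuild the attention
            # shapes, and re-initialize the (now larger) cache.
            return "grow"
        else:
            # Not enough memory to grow the cache; the commit logs an error.
            return "error"
    return "keep"  # cache is already large enough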
@@ -199,9 +206,16 @@ class QuantAttentionFused(nn.Module):
             .permute(0, 2, 3, 1, 4)
             .contiguous()
         )
 
-        self.cache_v[:bsz, :, self.start_pos : self.start_pos + seqlen, :] = values_store
-        self.cache_k[:bsz, :, :, self.start_pos : self.start_pos + seqlen, :] = keys_store
+        try:
+            self.cache_v[:bsz, :, self.start_pos : self.start_pos + seqlen, :] = values_store
+            self.cache_k[:bsz, :, :, self.start_pos : self.start_pos + seqlen, :] = keys_store
+        except Exception as ex:
+            print(seqlen, self.max_seq_len)
+            print(self.cache_v.shape, self.cache_v[:bsz, :, self.start_pos : self.start_pos + seqlen, :].shape, values_store.shape)
+            print(self.cache_k.shape, self.cache_k[:bsz, :, :, self.start_pos : self.start_pos + seqlen, :].shape, keys_store.shape)
+            print(ex)
+            exit(0)
 
         if seqlen == 1:
             xv = self.cache_v[:bsz, :, : self.start_pos + seqlen, :].transpose(1, 2).contiguous()
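The try/except added here is a debugging trap: if the cache slice no longer matches the shape of values_store/keys_store (for example after max_seq_len was doubled and the cache re-initialized), it prints the relevant shapes and exits. A hedged sketch of an explicit check that captures the same failure mode, using hypothetical names rather than anything in the repository:

def check_cache_write(cache_v, values_store, start_pos, seqlen, max_seq_len):
    # Illustrative only: verify that writing seqlen positions at start_pos
    # fits the value cache, mirroring what the debug prints above report.
    end_pos = start_pos + seqlen
    if end_pos > max_seq_len:
        raise ValueError(
            f"KV cache overflow: writing positions [{start_pos}, {end_pos}) "
            f"with max_seq_len={max_seq_len}"
        )
    bsz = values_store.shape[0]
    target_shape = cache_v[:bsz, :, start_pos:end_pos, :].shape
    if target_shape != values_store.shape:
        raise ValueError(
            f"Cache slice shape {target_shape} does not match "
            f"values_store shape {values_store.shape}"
        )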