"src/libtorchaudio/sox/utils.cpp" did not exist on "d850ff61643c00b6517b00011f4d52e1bc3897d2"
Commit 202c6d6a authored by Casper Hansen

Automatically reset/increase cache

parent 8eb26eb2
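The diff below adds two guards to the fused attention forward pass: when the running position pointer passes max_seq_len, the KV cache is zeroed and the position reset; when an incoming prompt is longer than max_seq_len, the limit is doubled and the cache rebuilt. A minimal standalone sketch of that policy (a hypothetical helper, not part of the module itself):

def maybe_reset_or_grow(start_pos, seqlen, max_seq_len):
    """Illustrative sketch of the policy added in this commit.

    Returns (new_start_pos, new_max_seq_len, reset_cache).
    """
    if start_pos > max_seq_len:
        # Generation ran past the cache limit: reset position, rebuild cache.
        return 0, max_seq_len, True
    if seqlen > max_seq_len:
        # Prompt longer than the limit: double it, rebuild cache.
        return 0, max_seq_len * 2, True
    return start_pos, max_seq_len, False

print(maybe_reset_or_grow(start_pos=2049, seqlen=1, max_seq_len=2048))  # -> (0, 2048, True)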
 import os
 import math
 import torch
+import logging
 import torch.nn as nn
 import awq_inference_engine
 from torch.nn import functional as F
@@ -80,12 +81,40 @@ class QuantAttentionFused(nn.Module):
         self.start_pos = 0
         self.use_alibi = use_alibi
         self.cache_batch_size = int(os.getenv("AWQ_BATCH_SIZE", "1"))
+        self.max_seq_len = max_seq_len
+        self.attention_shapes = self._get_attention_shapes(attention_shapes, max_seq_len)
+        self._initialize_cache(dev)
+
+        if use_alibi:
+            alibi_slopes, alibi_bias = build_alibi_bias(self.n_heads, max_seq_len)
+            self.alibi_slopes = alibi_slopes.float().to(dev)
+            self.alibi_bias = alibi_bias.float().to(dev)
+            self.rotary_dim = 0
+            self.is_neox = False
+        else:
+            self.freqs_cis = precompute_freqs_cis(
+                hidden_size // n_heads,
+                max_seq_len * 2,
+            ).to(dev)
+            self.rotary_dim = self.head_dim
+            self.alibi_slopes = None
+            self.is_neox = True
+
+    def _initialize_cache(self, dev):
+        self.cache_v = (
+            torch.zeros(self.attention_shapes["cache_v"]).to(dev).half()
+        )
+        self.cache_k = (
+            torch.zeros(self.attention_shapes["cache_k"]).to(dev).half()
+        )
+
+    def _get_attention_shapes(self, attention_shapes, max_seq_len):
         if attention_shapes is not None:
-            self.attention_shapes = attention_shapes
+            attention_shapes = attention_shapes
         elif self.n_kv_heads == 0:
-            self.attention_shapes = {
+            attention_shapes = {
                 # following fastertransformer definition
                 "cache_v": (self.cache_batch_size, self.n_heads, max_seq_len, self.head_dim,),
                 # 8: pack 8 fp16 in FT, if fp32 then use 4
@@ -104,7 +133,7 @@ class QuantAttentionFused(nn.Module):
             }
         else:
-            self.attention_shapes = {
+            attention_shapes = {
                 # following fastertransformer definition
                 "cache_v": (self.cache_batch_size, self.n_kv_heads, max_seq_len, self.head_dim,),
                 # 8: pack 8 fp16 in FT, if fp32 then use 4
@@ -121,33 +150,12 @@ class QuantAttentionFused(nn.Module):
"single_xk_view": (self.n_kv_heads, self.head_dim),
"single_xv_view": (self.n_kv_heads, self.head_dim)
}
self.cache_v = (
torch.zeros(self.attention_shapes["cache_v"]).to(dev).half()
)
self.cache_k = (
torch.zeros(self.attention_shapes["cache_k"]).to(dev).half()
)
if use_alibi:
alibi_slopes, alibi_bias = build_alibi_bias(self.n_heads, max_seq_len)
self.alibi_slopes = alibi_slopes.float().to(dev)
self.alibi_bias = alibi_bias.float().to(dev)
self.rotary_dim = 0
self.is_neox = False
else:
self.freqs_cis = precompute_freqs_cis(
hidden_size // n_heads,
max_seq_len * 2,
).to(dev)
self.rotary_dim = self.head_dim
self.alibi_slopes = None
self.is_neox = True
return attention_shapes
def forward(
self,
hidden_states, past_key_value=None, attention_mask=None, position_ids=None, output_attentions=False, use_cache=False
hidden_states:torch.Tensor, past_key_value=None, attention_mask=None, position_ids=None, output_attentions=False, use_cache=False
):
bsz, seqlen, _ = hidden_states.shape
if bsz != self.cache_batch_size:
@@ -155,6 +163,18 @@ class QuantAttentionFused(nn.Module):
f"Batch size is incorrectly set - input batch size {bsz}, kv-cache batch size {self.cache_batch_size}. "
f"Use: AutoAWQForCausalLM.from_quantized(batch_size={bsz})"
)
if self.start_pos > self.max_seq_len:
logging.warning('You have exceeded max_new_tokens, resetting cache...')
self._initialize_cache(hidden_states.device)
self.start_pos = 0
elif seqlen > self.max_seq_len:
logging.warning('Sequence length > max_seq_len, increasing and resetting cache...')
self.max_seq_len *= 2
self._initialize_cache(hidden_states.device)
self.start_pos = 0
xqkv = self.qkv_proj(hidden_states)
xqkv = xqkv.view((bsz, seqlen) + self.attention_shapes["xqkv_view"])
......
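As a usage note, the batch-size check in forward() points users at AutoAWQForCausalLM.from_quantized(batch_size=...), and the module itself reads the AWQ_BATCH_SIZE environment variable in __init__. A rough sketch of loading with a matching batch size; the model path and the fuse_layers keyword are assumptions, not taken from this diff:

from awq import AutoAWQForCausalLM

# batch_size is the keyword named in the error message above; it should match
# the batch dimension of the inputs fed to the fused attention module.
model = AutoAWQForCausalLM.from_quantized(
    "path/to/awq-quantized-model",  # placeholder path (assumption)
    fuse_layers=True,               # assumed flag enabling the fused modules in this file
    batch_size=4,
)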