Unverified Commit df909e83 authored by Casper, committed by GitHub

Reset cache on new generation (#178)

parent 5db86ec5
@@ -128,15 +128,6 @@ class QuantAttentionFused(nn.Module):
                 f"Batch size is incorrectly set - input batch size {bsz}, kv-cache batch size {self.cache_batch_size}. "
                 f"Use: AutoAWQForCausalLM.from_quantized(batch_size={bsz})"
             )
-        will_cache_be_exceeded = self.start_pos + seqlen > self.max_seq_len
-
-        # Reset and avoid retaining state when processing context
-        if will_cache_be_exceeded and seqlen > 1:
-            self.start_pos = self.cache.roll_kv_n_steps(self.start_pos, n=self.start_pos)
-        # Slowly roll out old tokens without performance hit if exceeded during decoding
-        elif will_cache_be_exceeded and seqlen == 1:
-            self.start_pos = self.cache.roll_kv_n_steps(self.start_pos, n=100)
 
         xqkv = self.qkv_proj(hidden_states)
         xqkv = xqkv.view((bsz, seqlen) + self.attention_shapes["xqkv_view"])
...
@@ -2,8 +2,8 @@ import torch
 import torch.nn as nn
 from typing import List
 from transformers.modeling_outputs import BaseModelOutputWithPast
-from awq.utils.fused_utils import prepare_attention_mask, prepare_input_ids
 from awq.modules.fused.block import MPTBlock, FalconDecoderLayer, LlamaLikeBlock
+from awq.utils.fused_utils import prepare_attention_mask, prepare_input_ids, prepare_cache
 
 class LlamaLikeModel(nn.Module):
     """
@@ -24,8 +24,10 @@ class LlamaLikeModel(nn.Module):
             input_ids,
             self.last_forward_num_tokens
         )
         _bsz, seqlen = input_ids.shape
+
+        prepare_cache(self.blocks, seqlen)
 
         h = self.embedding(input_ids)
 
         mask = prepare_attention_mask(
@@ -58,8 +60,10 @@ class MPTModel(nn.Module):
             input_ids,
             self.last_forward_num_tokens
         )
         _bsz, seqlen = input_ids.shape
+
+        prepare_cache(self.blocks, seqlen)
 
         h = self.wte(input_ids)
 
         mask = prepare_attention_mask(
@@ -92,8 +96,10 @@ class FalconModel(nn.Module):
             input_ids,
             self.last_forward_num_tokens
         )
         _bsz, seqlen = input_ids.shape
+
+        prepare_cache(self.blocks, seqlen)
 
         h = self.word_embeddings(input_ids)
 
         mask = prepare_attention_mask(
 import torch
 from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
 
+def prepare_cache(blocks, seqlen: int) -> int:
+    for block in blocks:
+        start_pos = block.attn.start_pos
+        will_cache_be_exceeded = start_pos + seqlen > block.attn.max_seq_len
+
+        # Reset and avoid retaining state when processing context
+        if seqlen > 1 and (will_cache_be_exceeded or seqlen > 1):
+            block.attn.start_pos = block.attn.cache.roll_kv_n_steps(start_pos, n=start_pos)
+        # Slowly roll out old tokens without performance hit if exceeded during decoding
+        elif seqlen == 1 and will_cache_be_exceeded:
+            block.attn.start_pos = block.attn.cache.roll_kv_n_steps(start_pos, n=100)
+
 def prepare_input_ids(input_ids: torch.Tensor, last_forward_num_tokens: int):
     # NOTE: from transformers 4.35.0, input_ids includes full context during decoding
     num_input_tokens = input_ids.shape[-1]
...
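
Note on the new prepare_cache helper above: the first branch's condition, seqlen > 1 and (will_cache_be_exceeded or seqlen > 1), reduces to seqlen > 1, so the per-block cache is rolled back by the full start_pos (effectively reset) whenever a multi-token context is processed, and rolled by 100 positions at a time if single-token decoding would overrun max_seq_len. The sketch below is a minimal, hypothetical rolling KV cache written only to illustrate the roll_kv_n_steps contract this diff relies on; the class name, tensor layout, and method body are assumptions for illustration, not AutoAWQ's actual cache implementation.

import torch

class RollingKVCacheSketch:
    """Hypothetical fixed-size KV cache illustrating the roll_kv_n_steps contract.

    Layout assumption: keys/values stored as [batch, max_seq_len, n_heads, head_dim].
    """

    def __init__(self, batch_size: int, max_seq_len: int, n_heads: int, head_dim: int):
        self.max_seq_len = max_seq_len
        self.k = torch.zeros(batch_size, max_seq_len, n_heads, head_dim)
        self.v = torch.zeros(batch_size, max_seq_len, n_heads, head_dim)

    def roll_kv_n_steps(self, start_pos: int, n: int) -> int:
        """Evict the oldest n cached positions and return the new start_pos.

        Shifts keys/values left by n along the sequence axis and zeroes the
        freed tail. With n == start_pos the cache is emptied (new start_pos 0);
        with n == 100 it frees 100 slots at a time during decoding.
        """
        n = min(n, start_pos)          # never roll further than what is filled
        if n <= 0:
            return start_pos
        self.k = torch.roll(self.k, shifts=-n, dims=1)
        self.v = torch.roll(self.v, shifts=-n, dims=1)
        self.k[:, -n:] = 0             # clear slots that wrapped around from the front
        self.v[:, -n:] = 0
        return start_pos - n

# Usage mirroring the two branches in prepare_cache:
cache = RollingKVCacheSketch(batch_size=1, max_seq_len=8, n_heads=2, head_dim=4)
start_pos = 8                                              # cache full
start_pos = cache.roll_kv_n_steps(start_pos, n=start_pos)  # reset before a new prompt
assert start_pos == 0
start_pos = cache.roll_kv_n_steps(6, n=100)                # decoding overrun: n clamped to 6
assert start_pos == 0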