Commit 8c80b3e0 authored by Casper Hansen

Move cache initialization back to init

parent c2cd6535
@@ -84,7 +84,8 @@ class QuantAttentionFused(nn.Module):
         self.cache_batch_size = int(os.getenv("AWQ_BATCH_SIZE", "1"))
         self.max_seq_len = max_seq_len
         self.attention_shapes = self._get_attention_shapes(attention_shapes, max_seq_len)
-        self._initialize_cache(dev)
+        self.cache_v = ( torch.zeros(self.attention_shapes["cache_v"]).to(dev).half() )
+        self.cache_k = ( torch.zeros(self.attention_shapes["cache_k"]).to(dev).half() )

         if use_alibi:
             alibi_slopes, alibi_bias = build_alibi_bias(self.n_heads, max_seq_len)
@@ -101,15 +102,6 @@ class QuantAttentionFused(nn.Module):
             self.alibi_slopes = None
             self.is_neox = True

-    def _initialize_cache(self, dev):
-        self.cache_v = (
-            torch.zeros(self.attention_shapes["cache_v"]).to(dev).half()
-        )
-        self.cache_k = (
-            torch.zeros(self.attention_shapes["cache_k"]).to(dev).half()
-        )
-
     def _get_attention_shapes(self, attention_shapes, max_seq_len):
         if attention_shapes is not None:
             attention_shapes = attention_shapes
......
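For context, a minimal runnable sketch of the layout this commit produces: the KV cache tensors are allocated directly in __init__ instead of through a separate _initialize_cache helper. The simplified constructor signature, the hard-coded cache shapes, and the class name below are illustrative assumptions, not the real QuantAttentionFused API.

# Sketch only: shows cache tensors allocated inline in __init__,
# mirroring the structure after this commit. Shapes and arguments
# are placeholders, not AutoAWQ's actual attention_shapes logic.
import os
import torch
import torch.nn as nn

class QuantAttentionFusedSketch(nn.Module):
    def __init__(self, n_heads, head_dim, max_seq_len, dev="cpu"):
        super().__init__()
        self.cache_batch_size = int(os.getenv("AWQ_BATCH_SIZE", "1"))
        self.max_seq_len = max_seq_len
        # Hypothetical stand-in for self._get_attention_shapes(...).
        self.attention_shapes = {
            "cache_v": (self.cache_batch_size, n_heads, max_seq_len, head_dim),
            "cache_k": (self.cache_batch_size, n_heads, max_seq_len, head_dim),
        }
        # Cache initialization now lives directly in __init__.
        self.cache_v = torch.zeros(self.attention_shapes["cache_v"]).to(dev).half()
        self.cache_k = torch.zeros(self.attention_shapes["cache_k"]).to(dev).half()

module = QuantAttentionFusedSketch(n_heads=8, head_dim=64, max_seq_len=128)
print(module.cache_k.shape, module.cache_k.dtype)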