Commit e46703d8 authored by Casper Hansen

Use new WindowedCache

parent 66b2e233
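This commit swaps the hand-managed `cache_v`/`cache_k` tensors for a `WindowedCache` object. The cache class itself is not part of this diff; the sketch below is only a rough reconstruction of the interface its call sites imply, with method bodies mirroring the inline cache code being removed further down. The real class lives in `awq/modules/fused/cache.py` and may differ in detail.

```python
import torch

class WindowedCache:
    """Sketch of the interface implied by the call sites in this diff.
    Not the actual awq/modules/fused/cache.py implementation."""

    def __init__(self, cache_v_shape, cache_k_shape, dev):
        # fp16 caches shaped per get_attention_shapes()
        self.v = torch.zeros(cache_v_shape).to(dev).half()
        self.k = torch.zeros(cache_k_shape).to(dev).half()

    def to(self, x):
        # move/cast both caches to match the query tensor
        self.v = self.v.to(x)
        self.k = self.k.to(x)

    def update_kv(self, values_store, keys_store, batch_size, start_pos, seqlen):
        # write the new keys/values into the window starting at start_pos
        self.v[:batch_size, :, start_pos : start_pos + seqlen, :] = values_store
        self.k[:batch_size, :, :, start_pos : start_pos + seqlen, :] = keys_store

    def get_kv(self, batch_size, start_pos, seqlen, head_dim):
        # return cached values/keys up to start_pos + seqlen, unpacking the
        # FT-style packed key layout back to (bsz, seq, heads, head_dim)
        xv = self.v[:batch_size, :, : start_pos + seqlen, :].transpose(1, 2).contiguous()
        xk = self.k[:batch_size, :, :, : start_pos + seqlen, :].transpose(2, 3).contiguous()
        xk = xk.reshape(xk.shape[:-2] + (head_dim,)).transpose(1, 2).contiguous()
        return xv, xk

    def roll_kv(self, roll_len, start_pos):
        # shift the window left by roll_len, zero the freed tail,
        # and return the adjusted write position (assumed behaviour)
        self.v = torch.roll(self.v, shifts=-roll_len, dims=2)
        self.k = torch.roll(self.k, shifts=-roll_len, dims=3)
        self.v[:, :, -roll_len:, :] = 0
        self.k[:, :, :, -roll_len:, :] = 0
        return start_pos - roll_len
```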
@@ -2,8 +2,8 @@ import os
 import math
 import torch
 import torch.nn as nn
-import awq_inference_engine
 from torch.nn import functional as F
+from awq.modules.fused.cache import WindowedCache
 try:
     import ft_inference_engine
@@ -25,11 +25,7 @@ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
     shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
     return freqs_cis.view(*shape)
-def apply_rotary_emb(
-    xq: torch.Tensor,
-    xk: torch.Tensor,
-    freqs_cis: torch.Tensor,
-):
+def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor):
     xq_ = torch.view_as_complex(
         xq.float().reshape(*xq.shape[:-1], 2, -1).transpose(-2, -1).contiguous()
     )
@@ -65,6 +61,49 @@ def build_alibi_bias(
     slopes = slopes.squeeze(0).squeeze(-1).squeeze(-1)
     return slopes.to(dtype=dtype), alibi_bias.to(dtype=dtype)
+def get_attention_shapes(attention_shapes, max_seq_len, cache_batch_size, n_heads, n_kv_heads, head_dim):
+    if attention_shapes is not None:
+        attention_shapes = attention_shapes
+    elif n_kv_heads == 0:
+        attention_shapes = {
+            # following fastertransformer definition
+            "cache_v": (cache_batch_size, n_heads, max_seq_len, head_dim,),
+            # 8: pack 8 fp16 in FT, if fp32 then use 4
+            "cache_k": (cache_batch_size, n_heads, head_dim // 8, max_seq_len, 8,),
+            "xqkv_view": (-1, n_heads, head_dim),
+            "xq_slice": lambda xqkv: xqkv[:, :, 0],
+            "xk_slice": lambda xqkv: xqkv[:, :, 1],
+            "xv_slice": lambda xqkv: xqkv[:, :, 2],
+            "xq_view": (n_heads, head_dim),
+            "xk_view": (n_heads, head_dim),
+            "xv_view": (n_heads, head_dim),
+            "xk_reshape": (n_heads, head_dim // 8, 8),
+            "single_xq_view": (n_heads, head_dim),
+            "single_xk_view": (n_heads, head_dim),
+            "single_xv_view": (n_heads, head_dim)
+        }
+    else:
+        attention_shapes = {
+            # following fastertransformer definition
+            "cache_v": (cache_batch_size, n_kv_heads, max_seq_len, head_dim,),
+            # 8: pack 8 fp16 in FT, if fp32 then use 4
+            "cache_k": (cache_batch_size, n_kv_heads, head_dim // 8, max_seq_len, 8,),
+            "xqkv_view": (n_heads + n_kv_heads * 2, head_dim),
+            "xq_slice": lambda xqkv: xqkv[:, :, 0 : n_heads],
+            "xk_slice": lambda xqkv: xqkv[:, :, n_heads : (n_heads + n_kv_heads)],
+            "xv_slice": lambda xqkv: xqkv[:, :, -n_kv_heads :],
+            "xq_view": (n_heads, head_dim),
+            "xk_view": (n_kv_heads, head_dim),
+            "xv_view": (n_kv_heads, head_dim),
+            "xk_reshape": (n_kv_heads, head_dim // 8, 8),
+            "single_xq_view": (n_heads, head_dim),
+            "single_xk_view": (n_kv_heads, head_dim),
+            "single_xv_view": (n_kv_heads, head_dim)
+        }
+    return attention_shapes
 class QuantAttentionFused(nn.Module):
     def __init__(self, hidden_size, n_heads, n_kv_heads, qkv_layer, o_proj, dev, max_seq_len,
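For orientation, here is what the new helper returns for a hypothetical single-batch, 32-head, `head_dim=128` configuration (illustrative numbers, not part of the commit). The `8` in `cache_k` follows the FasterTransformer convention of packing 8 fp16 values per group, as the in-code comment notes.

```python
# Illustrative only: hypothetical LLaMA-7B-like numbers, not part of the commit.
shapes = get_attention_shapes(
    attention_shapes=None,
    max_seq_len=2048,
    cache_batch_size=1,
    n_heads=32,
    n_kv_heads=0,   # 0 selects the MHA branch; a GQA model would pass its KV head count
    head_dim=128,
)
print(shapes["cache_v"])  # (1, 32, 2048, 128)
print(shapes["cache_k"])  # (1, 32, 16, 2048, 8) -- head_dim split into 128 // 8 = 16 packs of 8 fp16
```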
@@ -81,9 +120,15 @@ class QuantAttentionFused(nn.Module):
         self.use_alibi = use_alibi
         self.cache_batch_size = int(os.getenv("AWQ_BATCH_SIZE", "1"))
         self.max_seq_len = max_seq_len
-        self.attention_shapes = self._get_attention_shapes(attention_shapes, max_seq_len)
-        self.cache_v = ( torch.zeros(self.attention_shapes["cache_v"]).to(dev).half() )
-        self.cache_k = ( torch.zeros(self.attention_shapes["cache_k"]).to(dev).half() )
+        # attention shapes for self attention
+        self.attention_shapes = get_attention_shapes(
+            attention_shapes, max_seq_len, self.cache_batch_size, n_heads, n_kv_heads, self.head_dim
+        )
+        # cache store that rolls cache
+        self.cache = WindowedCache(
+            self.attention_shapes["cache_v"], self.attention_shapes["cache_k"], dev
+        )
         if use_alibi:
             alibi_slopes, alibi_bias = build_alibi_bias(self.n_heads, max_seq_len)
@@ -100,55 +145,7 @@ class QuantAttentionFused(nn.Module):
             self.alibi_slopes = None
             self.is_neox = True
-    def _get_attention_shapes(self, attention_shapes, max_seq_len):
-        if attention_shapes is not None:
-            attention_shapes = attention_shapes
-        elif self.n_kv_heads == 0:
-            attention_shapes = {
-                # following fastertransformer definition
-                "cache_v": (self.cache_batch_size, self.n_heads, max_seq_len, self.head_dim,),
-                # 8: pack 8 fp16 in FT, if fp32 then use 4
-                "cache_k": (self.cache_batch_size, self.n_heads, self.head_dim // 8, max_seq_len, 8,),
-                "xqkv_view": (-1, self.n_heads, self.head_dim),
-                "xq_slice": lambda xqkv: xqkv[:, :, 0],
-                "xk_slice": lambda xqkv: xqkv[:, :, 1],
-                "xv_slice": lambda xqkv: xqkv[:, :, 2],
-                "xq_view": (self.n_heads, self.head_dim),
-                "xk_view": (self.n_heads, self.head_dim),
-                "xv_view": (self.n_heads, self.head_dim),
-                "xk_reshape": (self.n_heads, self.head_dim // 8, 8),
-                "single_xq_view": (self.n_heads, self.head_dim),
-                "single_xk_view": (self.n_heads, self.head_dim),
-                "single_xv_view": (self.n_heads, self.head_dim)
-            }
-        else:
-            attention_shapes = {
-                # following fastertransformer definition
-                "cache_v": (self.cache_batch_size, self.n_kv_heads, max_seq_len, self.head_dim,),
-                # 8: pack 8 fp16 in FT, if fp32 then use 4
-                "cache_k": (self.cache_batch_size, self.n_kv_heads, self.head_dim // 8, max_seq_len, 8,),
-                "xqkv_view": (self.n_heads + self.n_kv_heads * 2, self.head_dim),
-                "xq_slice": lambda xqkv: xqkv[:, :, 0 : self.n_heads],
-                "xk_slice": lambda xqkv: xqkv[:, :, self.n_heads : (self.n_heads + self.n_kv_heads)],
-                "xv_slice": lambda xqkv: xqkv[:, :, -self.n_kv_heads :],
-                "xq_view": (self.n_heads, self.head_dim),
-                "xk_view": (self.n_kv_heads, self.head_dim),
-                "xv_view": (self.n_kv_heads, self.head_dim),
-                "xk_reshape": (self.n_kv_heads, self.head_dim // 8, 8),
-                "single_xq_view": (self.n_heads, self.head_dim),
-                "single_xk_view": (self.n_kv_heads, self.head_dim),
-                "single_xv_view": (self.n_kv_heads, self.head_dim)
-            }
-        return attention_shapes
-    def forward(
-        self,
-        hidden_states:torch.Tensor, past_key_value=None, attention_mask=None, position_ids=None,
-        output_attentions=False, use_cache=False, *args, **kwargs
-    ):
+    def forward(self, hidden_states:torch.Tensor, attention_mask=None, *args, **kwargs):
         bsz, seqlen, _ = hidden_states.shape
         if bsz != self.cache_batch_size:
             raise RuntimeError(
@@ -157,14 +154,8 @@ class QuantAttentionFused(nn.Module):
             )
         if self.start_pos > self.max_seq_len or self.start_pos + seqlen > self.max_seq_len:
-            # Roll cache to the left
-            roll_len = self.start_pos
-            self.cache_v = torch.roll(self.cache_v, shifts=-roll_len, dims=2)
-            self.cache_k = torch.roll(self.cache_k, shifts=-roll_len, dims=3)
-            # Zero out the new part
-            self.cache_v[:, :, -roll_len:, :] = 0
-            self.cache_k[:, :, :, -roll_len:, :] = 0
-            self.start_pos = 0
+            excess_length = self.start_pos + seqlen - self.max_seq_len
+            self.start_pos = self.cache.roll_kv(excess_length, self.start_pos)
         xqkv = self.qkv_proj(hidden_states)
         xqkv = xqkv.view((bsz, seqlen) + self.attention_shapes["xqkv_view"])
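A quick sanity check of the new windowing arithmetic with made-up numbers, assuming `roll_kv` shifts the cache left by `excess_length` and returns the adjusted position, as the inline code it replaces did:

```python
# Made-up numbers to sanity-check the windowing math above.
max_seq_len, start_pos, seqlen = 2048, 2040, 16
excess_length = start_pos + seqlen - max_seq_len   # 8 positions past the window
# roll_kv (assumed) drops the 8 oldest positions and returns start_pos - excess_length,
# so the 16 new tokens fit exactly at the end of the window:
new_start_pos = start_pos - excess_length          # 2032; 2032 + 16 == max_seq_len
```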
@@ -181,8 +172,7 @@
         if not self.use_alibi:
             xq, xk = apply_rotary_emb(xq, xk, freqs_cis=self.freqs_cis[self.start_pos : self.start_pos + seqlen])
-        self.cache_k = self.cache_k.to(xq)
-        self.cache_v = self.cache_v.to(xq)
+        self.cache.to(xq)
         values_store = xv.transpose(2, 1)
         keys_store = (
@@ -191,13 +181,10 @@
             .contiguous()
         )
-        self.cache_v[:bsz, :, self.start_pos : self.start_pos + seqlen, :] = values_store
-        self.cache_k[:bsz, :, :, self.start_pos : self.start_pos + seqlen, :] = keys_store
+        self.cache.update_kv(values_store, keys_store, bsz, self.start_pos, seqlen)
         if seqlen == 1:
-            xv = self.cache_v[:bsz, :, : self.start_pos + seqlen, :].transpose(1, 2).contiguous()
-            xk = self.cache_k[:bsz, :, :, : self.start_pos + seqlen, :].transpose(2, 3).contiguous()
-            xk = xk.reshape(xk.shape[:-2] + (self.head_dim,)).transpose(1, 2).contiguous()
+            xv, xk = self.cache.get_kv(bsz, self.start_pos, seqlen, self.head_dim)
         keys = xk
         values = xv
@@ -229,8 +216,8 @@
                 xq, # query
                 xk, # key
                 xv, # value
-                self.cache_k, # key cache
-                self.cache_v, # value cache
+                self.cache.k, # key cache
+                self.cache.v, # value cache
                 None, # length per sample
                 self.alibi_slopes, # alibi slopes
                 self.start_pos, # timestep
@@ -241,11 +228,8 @@
         attention_weight = attention_weight.reshape(bsz, 1, -1)
         attn_output = self.o_proj(attention_weight)
-        if use_cache:
-            self.start_pos += seqlen
-        else:
-            self.start_pos = 0
+        self.start_pos += seqlen
-        # past_key_value is replaced with cache_v, cache_k, returning None
-        return attn_output, attention_weight, None
+        # past_key_value is replaced with cache_v, cache_k, returning empty data
+        past_key_value = [torch.Tensor([ [ [[0]], [[0]], [[0]] ] ])]
+        return attn_output, attention_weight, past_key_value
\ No newline at end of file