Unverified Commit 1b68975a authored by Casper, committed by GitHub

Merge pull request #95 from casper-hansen/cache_refactor

Refactor cache and embedding modules
parents c9e45270 b13e2a85
awq/modules/fused/attn.py

@@ -2,8 +2,9 @@ import os
 import math
 import torch
 import torch.nn as nn
-import awq_inference_engine
 from torch.nn import functional as F
+from awq.modules.fused.cache import WindowedCache
+from awq.utils.fused_utils import get_attention_shapes
 
 try:
     import ft_inference_engine
@@ -11,60 +12,78 @@ try:
 except:
     FT_INSTALLED = False
 
-def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
-    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
-    t = torch.arange(end, device=freqs.device)  # type: ignore
-    freqs = torch.outer(t, freqs).float()  # type: ignore
-    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
-    return freqs_cis
-
-def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
-    ndim = x.ndim
-    assert 0 <= 1 < ndim
-    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
-    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
-    return freqs_cis.view(*shape)
-
-def apply_rotary_emb(
-    xq: torch.Tensor,
-    xk: torch.Tensor,
-    freqs_cis: torch.Tensor,
-):
-    xq_ = torch.view_as_complex(
-        xq.float().reshape(*xq.shape[:-1], 2, -1).transpose(-2, -1).contiguous()
-    )
-    xk_ = torch.view_as_complex(
-        xk.float().reshape(*xk.shape[:-1], 2, -1).transpose(-2, -1).contiguous()
-    )
-    freqs_cis = reshape_for_broadcast(freqs_cis, xq_).to(xq_.device)
-    xq_out = torch.view_as_real(xq_ * freqs_cis).transpose(-2, -1).flatten(3)
-    xk_out = torch.view_as_real(xk_ * freqs_cis).transpose(-2, -1).flatten(3)
-    return xq_out.type_as(xq), xk_out.type_as(xk)
-
-def gen_slopes(n_heads, alibi_bias_max=8):
-    _n_heads = 2 ** math.ceil(math.log2(n_heads))
-    m = torch.arange(1, _n_heads + 1, dtype=torch.float32)
-    m = m.mul(alibi_bias_max / _n_heads)
-    slopes = 1.0 / torch.pow(2, m)
-    if _n_heads != n_heads:
-        slopes = torch.concat([slopes[1::2], slopes[::2]])[:n_heads]
-    return slopes.view(1, n_heads, 1, 1)
-
-def build_alibi_bias(
-    n_heads, seq_len, full=False, alibi_bias_max=8, dtype=torch.float32
-):
-    alibi_bias = torch.arange(1 - seq_len, 1, dtype=torch.int32).view(1, 1, 1, seq_len)
-    if full:
-        alibi_bias = alibi_bias - torch.arange(1 - seq_len, 1, dtype=torch.int32).view(
-            1, 1, seq_len, 1
-        )
-    alibi_bias = alibi_bias.abs().mul(-1)
-    slopes = gen_slopes(n_heads, alibi_bias_max)
-    alibi_bias = alibi_bias * slopes
-    slopes = slopes.squeeze(0).squeeze(-1).squeeze(-1)
-    return slopes.to(dtype=dtype), alibi_bias.to(dtype=dtype)
+class RoPE(nn.Module):
+    def __init__(self, hidden_size, n_heads, max_seq_len, device):
+        super(RoPE, self).__init__()
+
+        self.freqs_cis = nn.Parameter(
+            self.precompute_freqs_cis(hidden_size // n_heads, max_seq_len * 2).to(device),
+            requires_grad=False
+        )
+
+    @staticmethod
+    def precompute_freqs_cis(dim: int, end: int, theta=10000.0):
+        freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
+        t = torch.arange(end)
+        freqs = torch.outer(t, freqs).float()
+        freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
+        return freqs_cis
+
+    @staticmethod
+    def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
+        ndim = x.ndim
+        assert 0 <= 1 < ndim
+        assert freqs_cis.shape == (x.shape[1], x.shape[-1])
+        shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
+        return freqs_cis.view(*shape)
+
+    def forward(self, xq: torch.Tensor, xk: torch.Tensor, start_pos: int, seqlen: int):
+        xq_ = torch.view_as_complex(
+            xq.float().reshape(*xq.shape[:-1], 2, -1).transpose(-2, -1).contiguous()
+        )
+        xk_ = torch.view_as_complex(
+            xk.float().reshape(*xk.shape[:-1], 2, -1).transpose(-2, -1).contiguous()
+        )
+        freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]
+        freqs_cis = self.reshape_for_broadcast(freqs_cis, xq_).to(xq_.device)
+        xq_out = torch.view_as_real(xq_ * freqs_cis).transpose(-2, -1).flatten(3)
+        xk_out = torch.view_as_real(xk_ * freqs_cis).transpose(-2, -1).flatten(3)
+        return xq_out.type_as(xq), xk_out.type_as(xk)
+
+class ALiBi(nn.Module):
+    def __init__(self, n_heads, max_seq_len, device, alibi_bias_max=8):
+        super(ALiBi, self).__init__()
+
+        # Initialize ALiBi slopes and bias
+        slopes, bias = self.build_alibi_bias(n_heads, max_seq_len, alibi_bias_max=alibi_bias_max)
+        self.slopes = nn.Parameter(slopes.float().to(device), requires_grad=False)
+        self.bias = nn.Parameter(bias.float().to(device), requires_grad=False)
+
+    @staticmethod
+    def gen_slopes(n_heads, alibi_bias_max=8):
+        _n_heads = 2 ** math.ceil(math.log2(n_heads))
+        m = torch.arange(1, _n_heads + 1, dtype=torch.float32)
+        m = m.mul(alibi_bias_max / _n_heads)
+        slopes = 1.0 / torch.pow(2, m)
+        if _n_heads != n_heads:
+            slopes = torch.cat([slopes[1::2], slopes[::2]])[:n_heads]
+        return slopes.view(1, n_heads, 1, 1)
+
+    @staticmethod
+    def build_alibi_bias(n_heads, seq_len, alibi_bias_max=8, dtype=torch.float32):
+        alibi_bias = torch.arange(1 - seq_len, 1, dtype=torch.int32).view(1, 1, 1, seq_len)
+        slopes = ALiBi.gen_slopes(n_heads, alibi_bias_max)
+        alibi_bias = alibi_bias * slopes
+        slopes = slopes.squeeze(0).squeeze(-1).squeeze(-1)
+        return slopes.to(dtype=dtype), alibi_bias.to(dtype=dtype)
+
+    def forward(self, scores, seqlen):
+        scores += self.bias[..., :seqlen]
+        return scores
 
 class QuantAttentionFused(nn.Module):
     def __init__(self, hidden_size, n_heads, n_kv_heads, qkv_layer, o_proj, dev, max_seq_len,
@@ -81,74 +100,27 @@ class QuantAttentionFused(nn.Module):
         self.use_alibi = use_alibi
         self.cache_batch_size = int(os.getenv("AWQ_BATCH_SIZE", "1"))
         self.max_seq_len = max_seq_len
-        self.attention_shapes = self._get_attention_shapes(attention_shapes, max_seq_len)
-        self.cache_v = ( torch.zeros(self.attention_shapes["cache_v"]).to(dev).half() )
-        self.cache_k = ( torch.zeros(self.attention_shapes["cache_k"]).to(dev).half() )
+
+        # attention shapes for self attention
+        self.attention_shapes = get_attention_shapes(
+            attention_shapes, max_seq_len, self.cache_batch_size, n_heads, n_kv_heads, self.head_dim
+        )
+
+        # cache store that rolls cache
+        self.cache = WindowedCache(
+            self.attention_shapes["cache_v"], self.attention_shapes["cache_k"], dev
+        )
 
         if use_alibi:
-            alibi_slopes, alibi_bias = build_alibi_bias(self.n_heads, max_seq_len)
-            self.alibi_slopes = alibi_slopes.float().to(dev)
-            self.alibi_bias = alibi_bias.float().to(dev)
+            self.alibi = ALiBi(n_heads, max_seq_len, dev)
             self.rotary_dim = 0
             self.is_neox = False
         else:
-            self.freqs_cis = precompute_freqs_cis(
-                hidden_size // n_heads,
-                max_seq_len * 2,
-            ).to(dev)
+            self.alibi = None
+            self.rope = RoPE(hidden_size, n_heads, max_seq_len, dev)
             self.rotary_dim = self.head_dim
-            self.alibi_slopes = None
             self.is_neox = True
 
-    def _get_attention_shapes(self, attention_shapes, max_seq_len):
-        if attention_shapes is not None:
-            attention_shapes = attention_shapes
-
-        elif self.n_kv_heads == 0:
-            attention_shapes = {
-                # following fastertransformer definition
-                "cache_v": (self.cache_batch_size, self.n_heads, max_seq_len, self.head_dim,),
-                # 8: pack 8 fp16 in FT, if fp32 then use 4
-                "cache_k": (self.cache_batch_size, self.n_heads, self.head_dim // 8, max_seq_len, 8,),
-                "xqkv_view": (-1, self.n_heads, self.head_dim),
-                "xq_slice": lambda xqkv: xqkv[:, :, 0],
-                "xk_slice": lambda xqkv: xqkv[:, :, 1],
-                "xv_slice": lambda xqkv: xqkv[:, :, 2],
-                "xq_view": (self.n_heads, self.head_dim),
-                "xk_view": (self.n_heads, self.head_dim),
-                "xv_view": (self.n_heads, self.head_dim),
-                "xk_reshape": (self.n_heads, self.head_dim // 8, 8),
-                "single_xq_view": (self.n_heads, self.head_dim),
-                "single_xk_view": (self.n_heads, self.head_dim),
-                "single_xv_view": (self.n_heads, self.head_dim)
-            }
-
-        else:
-            attention_shapes = {
-                # following fastertransformer definition
-                "cache_v": (self.cache_batch_size, self.n_kv_heads, max_seq_len, self.head_dim,),
-                # 8: pack 8 fp16 in FT, if fp32 then use 4
-                "cache_k": (self.cache_batch_size, self.n_kv_heads, self.head_dim // 8, max_seq_len, 8,),
-                "xqkv_view": (self.n_heads + self.n_kv_heads * 2, self.head_dim),
-                "xq_slice": lambda xqkv: xqkv[:, :, 0 : self.n_heads],
-                "xk_slice": lambda xqkv: xqkv[:, :, self.n_heads : (self.n_heads + self.n_kv_heads)],
-                "xv_slice": lambda xqkv: xqkv[:, :, -self.n_kv_heads :],
-                "xq_view": (self.n_heads, self.head_dim),
-                "xk_view": (self.n_kv_heads, self.head_dim),
-                "xv_view": (self.n_kv_heads, self.head_dim),
-                "xk_reshape": (self.n_kv_heads, self.head_dim // 8, 8),
-                "single_xq_view": (self.n_heads, self.head_dim),
-                "single_xk_view": (self.n_kv_heads, self.head_dim),
-                "single_xv_view": (self.n_kv_heads, self.head_dim)
-            }
-
-        return attention_shapes
-
-    def forward(
-        self,
-        hidden_states:torch.Tensor, past_key_value=None, attention_mask=None, position_ids=None,
-        output_attentions=False, use_cache=False, *args, **kwargs
-    ):
+    def forward(self, hidden_states:torch.Tensor, attention_mask=None, *args, **kwargs):
         bsz, seqlen, _ = hidden_states.shape
         if bsz != self.cache_batch_size:
             raise RuntimeError(
@@ -157,14 +129,8 @@ class QuantAttentionFused(nn.Module):
             )
 
         if self.start_pos > self.max_seq_len or self.start_pos + seqlen > self.max_seq_len:
-            # Roll cache to the left
-            roll_len = self.start_pos
-            self.cache_v = torch.roll(self.cache_v, shifts=-roll_len, dims=2)
-            self.cache_k = torch.roll(self.cache_k, shifts=-roll_len, dims=3)
-            # Zero out the new part
-            self.cache_v[:, :, -roll_len:, :] = 0
-            self.cache_k[:, :, :, -roll_len:, :] = 0
-            self.start_pos = 0
+            excess_length = self.start_pos + seqlen - self.max_seq_len
+            self.start_pos = self.cache.roll_kv(excess_length, self.start_pos)
 
         xqkv = self.qkv_proj(hidden_states)
         xqkv = xqkv.view((bsz, seqlen) + self.attention_shapes["xqkv_view"])
@@ -179,10 +145,9 @@ class QuantAttentionFused(nn.Module):
         xv = xv.view((bsz, seqlen) + self.attention_shapes["xv_view"])
 
         if not self.use_alibi:
-            xq, xk = apply_rotary_emb(xq, xk, freqs_cis=self.freqs_cis[self.start_pos : self.start_pos + seqlen])
+            xq, xk = self.rope.forward(xq, xk, self.start_pos, seqlen)
 
-        self.cache_k = self.cache_k.to(xq)
-        self.cache_v = self.cache_v.to(xq)
+        self.cache.to(xq)
 
         values_store = xv.transpose(2, 1)
         keys_store = (
@@ -191,13 +156,10 @@ class QuantAttentionFused(nn.Module):
             .contiguous()
         )
 
-        self.cache_v[:bsz, :, self.start_pos : self.start_pos + seqlen, :] = values_store
-        self.cache_k[:bsz, :, :, self.start_pos : self.start_pos + seqlen, :] = keys_store
+        self.cache.update_kv(values_store, keys_store, bsz, self.start_pos, seqlen)
 
         if seqlen == 1:
-            xv = self.cache_v[:bsz, :, : self.start_pos + seqlen, :].transpose(1, 2).contiguous()
-            xk = self.cache_k[:bsz, :, :, : self.start_pos + seqlen, :].transpose(2, 3).contiguous()
-            xk = xk.reshape(xk.shape[:-2] + (self.head_dim,)).transpose(1, 2).contiguous()
+            xv, xk = self.cache.get_kv(bsz, self.start_pos, seqlen, self.head_dim)
 
         keys = xk
         values = xv
@@ -212,7 +174,7 @@ class QuantAttentionFused(nn.Module):
             scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
 
             if self.use_alibi:
-                scores += self.alibi_bias[..., :seqlen]
+                scores = self.alibi.forward(scores, seqlen)
 
             if attention_mask is not None:
                 scores = scores + attention_mask  # (bs, n_local_heads, slen, cache_len + slen)
@@ -225,14 +187,15 @@ class QuantAttentionFused(nn.Module):
             xk = xk.view((bsz,) + self.attention_shapes["single_xk_view"])
             xv = xv.view((bsz,) + self.attention_shapes["single_xv_view"])
 
+            alibi_slopes = self.alibi.slopes if self.alibi is not None else None
             attention_weight = ft_inference_engine.single_query_attention(
                 xq, # query
                 xk, # key
                 xv, # value
-                self.cache_k, # key cache
-                self.cache_v, # value cache
+                self.cache.k, # key cache
+                self.cache.v, # value cache
                 None, # length per sample
-                self.alibi_slopes, # alibi slopes
+                alibi_slopes, # alibi slopes
                 self.start_pos, # timestep
                 self.rotary_dim, # rotary embedding dimension
                 10000, # rotary embedding base
@@ -241,11 +204,8 @@ class QuantAttentionFused(nn.Module):
             attention_weight = attention_weight.reshape(bsz, 1, -1)
 
         attn_output = self.o_proj(attention_weight)
+        self.start_pos += seqlen
 
-        if use_cache:
-            self.start_pos += seqlen
-        else:
-            self.start_pos = 0
-
-        # past_key_value is replaced with cache_v, cache_k, returning None
-        return attn_output, attention_weight, None
+        # past_key_value is replaced with cache_v, cache_k, returning empty data
+        past_key_value = [torch.Tensor([ [ [[0]], [[0]], [[0]] ] ])]
+
+        return attn_output, attention_weight, past_key_value
\ No newline at end of file
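For reviewers who want to poke at the refactored rotary embedding in isolation, here is a minimal sketch (not part of the PR). It assumes the RoPE class above is importable from awq.modules.fused.attn once this branch is installed, and only checks that applying the embedding preserves query/key shapes.

# Minimal sketch, assuming RoPE is importable from awq.modules.fused.attn as in this diff.
import torch
from awq.modules.fused.attn import RoPE

hidden_size, n_heads, max_seq_len = 128, 4, 64
head_dim = hidden_size // n_heads

rope = RoPE(hidden_size, n_heads, max_seq_len, device="cpu")

bsz, seqlen, start_pos = 1, 8, 0
xq = torch.randn(bsz, seqlen, n_heads, head_dim)
xk = torch.randn(bsz, seqlen, n_heads, head_dim)

# forward() slices the precomputed freqs_cis with [start_pos : start_pos + seqlen],
# which is the slicing the caller previously did by hand with apply_rotary_emb.
xq_rot, xk_rot = rope.forward(xq, xk, start_pos, seqlen)
assert xq_rot.shape == xq.shape and xk_rot.shape == xk.shape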
awq/modules/fused/cache.py (new file)

import torch

class WindowedCache:
    def __init__(self, cache_v_shape, cache_k_shape, device):
        """
        The window size is the same as the max_new_tokens. The window will
        automatically roll once max_new_tokens is exceeded.
        """
        # [batch_size, n_kv_heads, max_seq_len, head_dim]
        self.v = torch.zeros(cache_v_shape).to(device).half()
        # [batch_size, n_kv_heads, head_dim // pack_factor, max_seq_len, pack_factor]
        self.k = torch.zeros(cache_k_shape).to(device).half()

    def get_kv(self, batch_size, start_pos, seqlen, head_dim):
        xv = self.v[:batch_size, :, : start_pos + seqlen, :].transpose(1, 2).contiguous()
        xk = self.k[:batch_size, :, :, : start_pos + seqlen, :].transpose(2, 3).contiguous()
        xk = xk.reshape(xk.shape[:-2] + (head_dim,)).transpose(1, 2).contiguous()

        return xv, xk

    def update_kv(self, values_store, keys_store, batch_size, start_pos, seqlen):
        self.v[:batch_size, :, start_pos : start_pos + seqlen, :] = values_store
        self.k[:batch_size, :, :, start_pos : start_pos + seqlen, :] = keys_store

    def roll_kv(self, roll_len, start_pos):
        # Roll only the necessary part of the cache to the left
        self.v[:, :, :-roll_len, :] = self.v[:, :, roll_len:, :]
        self.k[:, :, :, :-roll_len, :] = self.k[:, :, :, roll_len:, :]

        # Zero out the new part
        self.v[:, :, -roll_len:, :] = 0
        self.k[:, :, :, -roll_len:, :] = 0

        return start_pos - roll_len

    def to(self, device):
        self.k = self.k.to(device)
        self.v = self.v.to(device)
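A quick illustrative sketch of how the window rolls (not part of the PR, written against the class above): once start_pos + seqlen would exceed the window, roll_kv drops the oldest positions and returns the corrected write position, which is exactly how QuantAttentionFused now uses it. The cache shapes follow the fastertransformer layout from get_attention_shapes (pack factor 8 for fp16).

# Illustrative sketch: exercising WindowedCache directly on the CPU.
import torch
from awq.modules.fused.cache import WindowedCache

batch_size, n_kv_heads, max_seq_len, head_dim = 1, 2, 16, 64
cache_v_shape = (batch_size, n_kv_heads, max_seq_len, head_dim)
cache_k_shape = (batch_size, n_kv_heads, head_dim // 8, max_seq_len, 8)

cache = WindowedCache(cache_v_shape, cache_k_shape, device="cpu")

start_pos, seqlen = 14, 4                       # would overflow the 16-token window by 2
excess_length = start_pos + seqlen - max_seq_len

# Drop the two oldest positions, shift the rest left, and get the corrected
# position to write the next tokens at.
start_pos = cache.roll_kv(excess_length, start_pos)
print(start_pos)  # 12, so 12 + 4 new tokens fit the window again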
awq/utils/fused_utils.py (new file)

def get_attention_shapes(attention_shapes, max_seq_len, cache_batch_size, n_heads, n_kv_heads, head_dim):
    if attention_shapes is not None:
        attention_shapes = attention_shapes

    elif n_kv_heads == 0:
        attention_shapes = {
            # following fastertransformer definition
            "cache_v": (cache_batch_size, n_heads, max_seq_len, head_dim,),
            # 8: pack 8 fp16 in FT, if fp32 then use 4
            "cache_k": (cache_batch_size, n_heads, head_dim // 8, max_seq_len, 8,),
            "xqkv_view": (-1, n_heads, head_dim),
            "xq_slice": lambda xqkv: xqkv[:, :, 0],
            "xk_slice": lambda xqkv: xqkv[:, :, 1],
            "xv_slice": lambda xqkv: xqkv[:, :, 2],
            "xq_view": (n_heads, head_dim),
            "xk_view": (n_heads, head_dim),
            "xv_view": (n_heads, head_dim),
            "xk_reshape": (n_heads, head_dim // 8, 8),
            "single_xq_view": (n_heads, head_dim),
            "single_xk_view": (n_heads, head_dim),
            "single_xv_view": (n_heads, head_dim)
        }

    else:
        attention_shapes = {
            # following fastertransformer definition
            "cache_v": (cache_batch_size, n_kv_heads, max_seq_len, head_dim,),
            # 8: pack 8 fp16 in FT, if fp32 then use 4
            "cache_k": (cache_batch_size, n_kv_heads, head_dim // 8, max_seq_len, 8,),
            "xqkv_view": (n_heads + n_kv_heads * 2, head_dim),
            "xq_slice": lambda xqkv: xqkv[:, :, 0 : n_heads],
            "xk_slice": lambda xqkv: xqkv[:, :, n_heads : (n_heads + n_kv_heads)],
            "xv_slice": lambda xqkv: xqkv[:, :, -n_kv_heads :],
            "xq_view": (n_heads, head_dim),
            "xk_view": (n_kv_heads, head_dim),
            "xv_view": (n_kv_heads, head_dim),
            "xk_reshape": (n_kv_heads, head_dim // 8, 8),
            "single_xq_view": (n_heads, head_dim),
            "single_xk_view": (n_kv_heads, head_dim),
            "single_xv_view": (n_kv_heads, head_dim)
        }

    return attention_shapes
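To make the two branches concrete, here is a small illustrative call (not part of the PR) for a grouped-query configuration; the expected values follow directly from the dictionary above.

# Illustrative sketch: shapes for 32 query heads sharing 8 KV heads, head_dim=128.
from awq.utils.fused_utils import get_attention_shapes

shapes = get_attention_shapes(
    attention_shapes=None, max_seq_len=2048, cache_batch_size=1,
    n_heads=32, n_kv_heads=8, head_dim=128,
)
print(shapes["cache_v"])    # (1, 8, 2048, 128)
print(shapes["cache_k"])    # (1, 8, 16, 2048, 8)  -> head_dim packed 8 fp16 values at a time
print(shapes["xqkv_view"])  # (48, 128)            -> 32 query heads + 2 * 8 KV heads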
Example generation script:

@@ -4,7 +4,7 @@ from transformers import AutoTokenizer, TextStreamer
 quant_path = "TheBloke/Mistral-7B-OpenOrca-AWQ"
 
 # Load model
-model = AutoAWQForCausalLM.from_quantized(quant_path, fuse_layers=False, safetensors=True)
+model = AutoAWQForCausalLM.from_quantized(quant_path, fuse_layers=True, safetensors=True)
 tokenizer = AutoTokenizer.from_pretrained(quant_path, trust_remote_code=True)
 streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
 
@@ -16,8 +16,12 @@ You are MistralOrca, a large language model trained by Alignment Lab AI. Write o
 {prompt}<|im_end|>
 <|im_start|>assistant"""
 
+prompt = "You're standing on the surface of the Earth. "\
+    "You walk one mile south, one mile west and one mile north. "\
+    "You end up exactly where you started. Where are you?"
+
 tokens = tokenizer(
-    prompt_template.format(prompt="Why is ice cream so good, yes so good?"),
+    prompt_template.format(prompt=prompt),
     return_tensors='pt'
 ).input_ids.cuda()
 
@@ -26,4 +30,4 @@ generation_output = model.generate(
     tokens,
     streamer=streamer,
     max_new_tokens=512
 )
\ No newline at end of file