Unverified commit 1b68975a authored by Casper, committed by GitHub

Merge pull request #95 from casper-hansen/cache_refactor

Refactor cache and embedding modules
parents c9e45270 b13e2a85
@@ -2,8 +2,9 @@ import os
import math
import torch
import torch.nn as nn
import awq_inference_engine
from torch.nn import functional as F
from awq.modules.fused.cache import WindowedCache
from awq.utils.fused_utils import get_attention_shapes
try:
import ft_inference_engine
@@ -11,60 +12,78 @@ try:
except:
FT_INSTALLED = False
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
t = torch.arange(end, device=freqs.device) # type: ignore
freqs = torch.outer(t, freqs).float() # type: ignore
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
return freqs_cis
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
ndim = x.ndim
assert 0 <= 1 < ndim
assert freqs_cis.shape == (x.shape[1], x.shape[-1])
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
return freqs_cis.view(*shape)
def apply_rotary_emb(
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cis: torch.Tensor,
):
xq_ = torch.view_as_complex(
xq.float().reshape(*xq.shape[:-1], 2, -1).transpose(-2, -1).contiguous()
)
xk_ = torch.view_as_complex(
xk.float().reshape(*xk.shape[:-1], 2, -1).transpose(-2, -1).contiguous()
)
freqs_cis = reshape_for_broadcast(freqs_cis, xq_).to(xq_.device)
xq_out = torch.view_as_real(xq_ * freqs_cis).transpose(-2, -1).flatten(3)
xk_out = torch.view_as_real(xk_ * freqs_cis).transpose(-2, -1).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)
def gen_slopes(n_heads, alibi_bias_max=8):
_n_heads = 2 ** math.ceil(math.log2(n_heads))
m = torch.arange(1, _n_heads + 1, dtype=torch.float32)
m = m.mul(alibi_bias_max / _n_heads)
slopes = 1.0 / torch.pow(2, m)
if _n_heads != n_heads:
slopes = torch.concat([slopes[1::2], slopes[::2]])[:n_heads]
return slopes.view(1, n_heads, 1, 1)
def build_alibi_bias(
n_heads, seq_len, full=False, alibi_bias_max=8, dtype=torch.float32
):
alibi_bias = torch.arange(1 - seq_len, 1, dtype=torch.int32).view(1, 1, 1, seq_len)
if full:
alibi_bias = alibi_bias - torch.arange(1 - seq_len, 1, dtype=torch.int32).view(
1, 1, seq_len, 1
class RoPE(nn.Module):
def __init__(self, hidden_size, n_heads, max_seq_len, device):
super(RoPE, self).__init__()
self.freqs_cis = nn.Parameter(
self.precompute_freqs_cis(hidden_size // n_heads, max_seq_len * 2).to(device),
requires_grad=False
)
@staticmethod
def precompute_freqs_cis(dim: int, end: int, theta=10000.0):
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
t = torch.arange(end)
freqs = torch.outer(t, freqs).float()
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
return freqs_cis
@staticmethod
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
ndim = x.ndim
assert 0 <= 1 < ndim
assert freqs_cis.shape == (x.shape[1], x.shape[-1])
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
return freqs_cis.view(*shape)
def forward(self, xq: torch.Tensor, xk: torch.Tensor, start_pos: int, seqlen: int):
xq_ = torch.view_as_complex(
xq.float().reshape(*xq.shape[:-1], 2, -1).transpose(-2, -1).contiguous()
)
alibi_bias = alibi_bias.abs().mul(-1)
slopes = gen_slopes(n_heads, alibi_bias_max)
alibi_bias = alibi_bias * slopes
slopes = slopes.squeeze(0).squeeze(-1).squeeze(-1)
return slopes.to(dtype=dtype), alibi_bias.to(dtype=dtype)
xk_ = torch.view_as_complex(
xk.float().reshape(*xk.shape[:-1], 2, -1).transpose(-2, -1).contiguous()
)
freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]
freqs_cis = self.reshape_for_broadcast(freqs_cis, xq_).to(xq_.device)
xq_out = torch.view_as_real(xq_ * freqs_cis).transpose(-2, -1).flatten(3)
xk_out = torch.view_as_real(xk_ * freqs_cis).transpose(-2, -1).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)
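# Usage sketch (illustrative; the hidden_size=4096, n_heads=32, seqlen=16 and
# CUDA device values below are assumptions, not taken from this diff):
#   rope = RoPE(hidden_size=4096, n_heads=32, max_seq_len=2048, device="cuda")
#   xq = torch.randn(1, 16, 32, 128, dtype=torch.float16, device="cuda")
#   xk = torch.randn(1, 16, 32, 128, dtype=torch.float16, device="cuda")
#   xq, xk = rope.forward(xq, xk, start_pos=0, seqlen=16)  # shapes unchanged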
class ALiBi(nn.Module):
def __init__(self, n_heads, max_seq_len, device, alibi_bias_max=8):
super(ALiBi, self).__init__()
# Initialize ALiBi slopes and bias
slopes, bias = self.build_alibi_bias(n_heads, max_seq_len, alibi_bias_max=alibi_bias_max)
self.slopes = nn.Parameter(slopes.float().to(device), requires_grad=False)
self.bias = nn.Parameter(bias.float().to(device), requires_grad=False)
@staticmethod
def gen_slopes(n_heads, alibi_bias_max=8):
_n_heads = 2 ** math.ceil(math.log2(n_heads))
m = torch.arange(1, _n_heads + 1, dtype=torch.float32)
m = m.mul(alibi_bias_max / _n_heads)
slopes = 1.0 / torch.pow(2, m)
if _n_heads != n_heads:
slopes = torch.cat([slopes[1::2], slopes[::2]])[:n_heads]
return slopes.view(1, n_heads, 1, 1)
@staticmethod
def build_alibi_bias(n_heads, seq_len, alibi_bias_max=8, dtype=torch.float32):
alibi_bias = torch.arange(1 - seq_len, 1, dtype=torch.int32).view(1, 1, 1, seq_len)
slopes = ALiBi.gen_slopes(n_heads, alibi_bias_max)
alibi_bias = alibi_bias * slopes
slopes = slopes.squeeze(0).squeeze(-1).squeeze(-1)
return slopes.to(dtype=dtype), alibi_bias.to(dtype=dtype)
def forward(self, scores, seqlen):
scores += self.bias[..., :seqlen]
return scores
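# Worked example (assumed values): with n_heads=8 and the default
# alibi_bias_max=8, gen_slopes computes m = [1, 2, ..., 8] and
# slopes = 2**-m = [0.5, 0.25, ..., 0.00390625], one slope per head, so:
#   alibi = ALiBi(n_heads=8, max_seq_len=2048, device="cpu")
#   alibi.slopes.shape  # torch.Size([8])
#   alibi.bias.shape    # torch.Size([1, 8, 1, 2048]), added to scores in forward()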
class QuantAttentionFused(nn.Module):
def __init__(self, hidden_size, n_heads, n_kv_heads, qkv_layer, o_proj, dev, max_seq_len,
@@ -81,74 +100,27 @@ class QuantAttentionFused(nn.Module):
self.use_alibi = use_alibi
self.cache_batch_size = int(os.getenv("AWQ_BATCH_SIZE", "1"))
self.max_seq_len = max_seq_len
self.attention_shapes = self._get_attention_shapes(attention_shapes, max_seq_len)
self.cache_v = ( torch.zeros(self.attention_shapes["cache_v"]).to(dev).half() )
self.cache_k = ( torch.zeros(self.attention_shapes["cache_k"]).to(dev).half() )
# Attention shapes for self-attention
self.attention_shapes = get_attention_shapes(
attention_shapes, max_seq_len, self.cache_batch_size, n_heads, n_kv_heads, self.head_dim
)
# Cache store that rolls the KV cache once the sequence window is exceeded
self.cache = WindowedCache(
self.attention_shapes["cache_v"], self.attention_shapes["cache_k"], dev
)
if use_alibi:
alibi_slopes, alibi_bias = build_alibi_bias(self.n_heads, max_seq_len)
self.alibi_slopes = alibi_slopes.float().to(dev)
self.alibi_bias = alibi_bias.float().to(dev)
self.alibi = ALiBi(n_heads, max_seq_len, dev)
self.rotary_dim = 0
self.is_neox = False
else:
self.freqs_cis = precompute_freqs_cis(
hidden_size // n_heads,
max_seq_len * 2,
).to(dev)
self.alibi = None
self.rope = RoPE(hidden_size, n_heads, max_seq_len, dev)
self.rotary_dim = self.head_dim
self.alibi_slopes = None
self.is_neox = True
def _get_attention_shapes(self, attention_shapes, max_seq_len):
if attention_shapes is not None:
attention_shapes = attention_shapes
elif self.n_kv_heads == 0:
attention_shapes = {
# following fastertransformer definition
"cache_v": (self.cache_batch_size, self.n_heads, max_seq_len, self.head_dim,),
# 8: pack 8 fp16 in FT, if fp32 then use 4
"cache_k": (self.cache_batch_size, self.n_heads, self.head_dim // 8, max_seq_len, 8,),
"xqkv_view": (-1, self.n_heads, self.head_dim),
"xq_slice": lambda xqkv: xqkv[:, :, 0],
"xk_slice": lambda xqkv: xqkv[:, :, 1],
"xv_slice": lambda xqkv: xqkv[:, :, 2],
"xq_view": (self.n_heads, self.head_dim),
"xk_view": (self.n_heads, self.head_dim),
"xv_view": (self.n_heads, self.head_dim),
"xk_reshape": (self.n_heads, self.head_dim // 8, 8),
"single_xq_view": (self.n_heads, self.head_dim),
"single_xk_view": (self.n_heads, self.head_dim),
"single_xv_view": (self.n_heads, self.head_dim)
}
else:
attention_shapes = {
# following fastertransformer definition
"cache_v": (self.cache_batch_size, self.n_kv_heads, max_seq_len, self.head_dim,),
# 8: pack 8 fp16 in FT, if fp32 then use 4
"cache_k": (self.cache_batch_size, self.n_kv_heads, self.head_dim // 8, max_seq_len, 8,),
"xqkv_view": (self.n_heads + self.n_kv_heads * 2, self.head_dim),
"xq_slice": lambda xqkv: xqkv[:, :, 0 : self.n_heads],
"xk_slice": lambda xqkv: xqkv[:, :, self.n_heads : (self.n_heads + self.n_kv_heads)],
"xv_slice": lambda xqkv: xqkv[:, :, -self.n_kv_heads :],
"xq_view": (self.n_heads, self.head_dim),
"xk_view": (self.n_kv_heads, self.head_dim),
"xv_view": (self.n_kv_heads, self.head_dim),
"xk_reshape": (self.n_kv_heads, self.head_dim // 8, 8),
"single_xq_view": (self.n_heads, self.head_dim),
"single_xk_view": (self.n_kv_heads, self.head_dim),
"single_xv_view": (self.n_kv_heads, self.head_dim)
}
return attention_shapes
def forward(
self,
hidden_states:torch.Tensor, past_key_value=None, attention_mask=None, position_ids=None,
output_attentions=False, use_cache=False, *args, **kwargs
):
def forward(self, hidden_states:torch.Tensor, attention_mask=None, *args, **kwargs):
bsz, seqlen, _ = hidden_states.shape
if bsz != self.cache_batch_size:
raise RuntimeError(
@@ -157,14 +129,8 @@ class QuantAttentionFused(nn.Module):
)
if self.start_pos > self.max_seq_len or self.start_pos + seqlen > self.max_seq_len:
# Roll cache to the left
roll_len = self.start_pos
self.cache_v = torch.roll(self.cache_v, shifts=-roll_len, dims=2)
self.cache_k = torch.roll(self.cache_k, shifts=-roll_len, dims=3)
# Zero out the new part
self.cache_v[:, :, -roll_len:, :] = 0
self.cache_k[:, :, :, -roll_len:, :] = 0
self.start_pos = 0
excess_length = self.start_pos + seqlen - self.max_seq_len
self.start_pos = self.cache.roll_kv(excess_length, self.start_pos)
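# Numeric illustration (assumed values): with max_seq_len=2048, start_pos=2040
# and seqlen=16, excess_length = 2040 + 16 - 2048 = 8, so roll_kv shifts the
# cache left by 8 slots and returns start_pos = 2032; the 16 incoming tokens
# then fit into positions 2032..2047.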
xqkv = self.qkv_proj(hidden_states)
xqkv = xqkv.view((bsz, seqlen) + self.attention_shapes["xqkv_view"])
@@ -179,10 +145,9 @@ class QuantAttentionFused(nn.Module):
xv = xv.view((bsz, seqlen) + self.attention_shapes["xv_view"])
if not self.use_alibi:
xq, xk = apply_rotary_emb(xq, xk, freqs_cis=self.freqs_cis[self.start_pos : self.start_pos + seqlen])
xq, xk = self.rope.forward(xq, xk, self.start_pos, seqlen)
self.cache_k = self.cache_k.to(xq)
self.cache_v = self.cache_v.to(xq)
self.cache.to(xq)
values_store = xv.transpose(2, 1)
keys_store = (
@@ -191,13 +156,10 @@ class QuantAttentionFused(nn.Module):
.contiguous()
)
self.cache_v[:bsz, :, self.start_pos : self.start_pos + seqlen, :] = values_store
self.cache_k[:bsz, :, :, self.start_pos : self.start_pos + seqlen, :] = keys_store
self.cache.update_kv(values_store, keys_store, bsz, self.start_pos, seqlen)
if seqlen == 1:
xv = self.cache_v[:bsz, :, : self.start_pos + seqlen, :].transpose(1, 2).contiguous()
xk = self.cache_k[:bsz, :, :, : self.start_pos + seqlen, :].transpose(2, 3).contiguous()
xk = xk.reshape(xk.shape[:-2] + (self.head_dim,)).transpose(1, 2).contiguous()
xv, xk = self.cache.get_kv(bsz, self.start_pos, seqlen, self.head_dim)
keys = xk
values = xv
@@ -212,7 +174,7 @@ class QuantAttentionFused(nn.Module):
scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
if self.use_alibi:
scores += self.alibi_bias[..., :seqlen]
scores = self.alibi.forward(scores, seqlen)
if attention_mask is not None:
scores = scores + attention_mask # (bs, n_local_heads, slen, cache_len + slen)
@@ -225,14 +187,15 @@ class QuantAttentionFused(nn.Module):
xk = xk.view((bsz,) + self.attention_shapes["single_xk_view"])
xv = xv.view((bsz,) + self.attention_shapes["single_xv_view"])
alibi_slopes = self.alibi.slopes if self.alibi is not None else None
attention_weight = ft_inference_engine.single_query_attention(
xq, # query
xk, # key
xv, # value
self.cache_k, # key cache
self.cache_v, # value cache
self.cache.k, # key cache
self.cache.v, # value cache
None, # length per sample
self.alibi_slopes, # alibi slopes
alibi_slopes, # alibi slopes
self.start_pos, # timestep
self.rotary_dim, # rotary embedding dimension
10000, # rotary embedding base
@@ -241,11 +204,8 @@ class QuantAttentionFused(nn.Module):
attention_weight = attention_weight.reshape(bsz, 1, -1)
attn_output = self.o_proj(attention_weight)
if use_cache:
self.start_pos += seqlen
else:
self.start_pos = 0
self.start_pos += seqlen
# past_key_value is replaced with cache_v, cache_k, returning None
return attn_output, attention_weight, None
# past_key_value is replaced with cache_v/cache_k, so return placeholder data instead
past_key_value = [torch.Tensor([ [ [[0]], [[0]], [[0]] ] ])]
return attn_output, attention_weight, past_key_value
\ No newline at end of file
import torch
class WindowedCache:
def __init__(self, cache_v_shape, cache_k_shape, device):
"""
The window size is the same as max_new_tokens. The window
automatically rolls once max_new_tokens is exceeded.
"""
# [batch_size, n_kv_heads, max_seq_len, head_dim]
self.v = torch.zeros(cache_v_shape).to(device).half()
# [batch_size, n_kv_heads, head_dim // pack_factor, max_seq_len, pack_factor]
self.k = torch.zeros(cache_k_shape).to(device).half()
def get_kv(self, batch_size, start_pos, seqlen, head_dim):
xv = self.v[:batch_size, :, : start_pos + seqlen, :].transpose(1, 2).contiguous()
xk = self.k[:batch_size, :, :, : start_pos + seqlen, :].transpose(2, 3).contiguous()
xk = xk.reshape(xk.shape[:-2] + (head_dim,)).transpose(1, 2).contiguous()
return xv, xk
def update_kv(self, values_store, keys_store, batch_size, start_pos, seqlen):
self.v[:batch_size, :, start_pos : start_pos + seqlen, :] = values_store
self.k[:batch_size, :, :, start_pos : start_pos + seqlen, :] = keys_store
def roll_kv(self, roll_len, start_pos):
# Roll only the necessary part of the cache to the left
self.v[:, :, :-roll_len, :] = self.v[:, :, roll_len:, :]
self.k[:, :, :, :-roll_len, :] = self.k[:, :, :, roll_len:, :]
# Zero out the new part
self.v[:, :, -roll_len:, :] = 0
self.k[:, :, :, -roll_len:, :] = 0
return start_pos - roll_len
def to(self, device):
self.k = self.k.to(device)
self.v = self.v.to(device)
\ No newline at end of file
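A minimal usage sketch of the new WindowedCache, matching the fp16 packing used above (the batch size, head count, head_dim and seqlen values are assumptions for illustration; CPU is used only to keep the sketch self-contained):

import torch
from awq.modules.fused.cache import WindowedCache

batch_size, n_kv_heads, head_dim, max_seq_len = 1, 8, 128, 2048
cache = WindowedCache(
    cache_v_shape=(batch_size, n_kv_heads, max_seq_len, head_dim),
    cache_k_shape=(batch_size, n_kv_heads, head_dim // 8, max_seq_len, 8),
    device="cpu",
)

# Write 16 new tokens at positions 0..15, then read the visible window back.
values_store = torch.zeros(batch_size, n_kv_heads, 16, head_dim).half()
keys_store = torch.zeros(batch_size, n_kv_heads, head_dim // 8, 16, 8).half()
cache.update_kv(values_store, keys_store, batch_size, start_pos=0, seqlen=16)
xv, xk = cache.get_kv(batch_size, start_pos=0, seqlen=16, head_dim=head_dim)
# xv: (1, 16, 8, 128), xk: (1, 16, 8, 128) -- keys are unpacked back to head_dim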
def get_attention_shapes(attention_shapes, max_seq_len, cache_batch_size, n_heads, n_kv_heads, head_dim):
if attention_shapes is not None:
attention_shapes = attention_shapes
elif n_kv_heads == 0:
attention_shapes = {
# following fastertransformer definition
"cache_v": (cache_batch_size, n_heads, max_seq_len, head_dim,),
# 8: pack 8 fp16 in FT, if fp32 then use 4
"cache_k": (cache_batch_size, n_heads, head_dim // 8, max_seq_len, 8,),
"xqkv_view": (-1, n_heads, head_dim),
"xq_slice": lambda xqkv: xqkv[:, :, 0],
"xk_slice": lambda xqkv: xqkv[:, :, 1],
"xv_slice": lambda xqkv: xqkv[:, :, 2],
"xq_view": (n_heads, head_dim),
"xk_view": (n_heads, head_dim),
"xv_view": (n_heads, head_dim),
"xk_reshape": (n_heads, head_dim // 8, 8),
"single_xq_view": (n_heads, head_dim),
"single_xk_view": (n_heads, head_dim),
"single_xv_view": (n_heads, head_dim)
}
else:
attention_shapes = {
# following fastertransformer definition
"cache_v": (cache_batch_size, n_kv_heads, max_seq_len, head_dim,),
# 8: pack 8 fp16 in FT, if fp32 then use 4
"cache_k": (cache_batch_size, n_kv_heads, head_dim // 8, max_seq_len, 8,),
"xqkv_view": (n_heads + n_kv_heads * 2, head_dim),
"xq_slice": lambda xqkv: xqkv[:, :, 0 : n_heads],
"xk_slice": lambda xqkv: xqkv[:, :, n_heads : (n_heads + n_kv_heads)],
"xv_slice": lambda xqkv: xqkv[:, :, -n_kv_heads :],
"xq_view": (n_heads, head_dim),
"xk_view": (n_kv_heads, head_dim),
"xv_view": (n_kv_heads, head_dim),
"xk_reshape": (n_kv_heads, head_dim // 8, 8),
"single_xq_view": (n_heads, head_dim),
"single_xk_view": (n_kv_heads, head_dim),
"single_xv_view": (n_kv_heads, head_dim)
}
return attention_shapes
\ No newline at end of file
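For reference, a sketch of what this helper produces for one assumed grouped-query configuration (n_heads=32, n_kv_heads=8, head_dim=128 and a cache batch size of 1 are illustrative values, not taken from this diff):

from awq.utils.fused_utils import get_attention_shapes

shapes = get_attention_shapes(
    attention_shapes=None,
    max_seq_len=2048,
    cache_batch_size=1,
    n_heads=32,
    n_kv_heads=8,
    head_dim=128,
)
assert shapes["cache_v"] == (1, 8, 2048, 128)
assert shapes["cache_k"] == (1, 8, 16, 2048, 8)  # head_dim packed 8 fp16 values at a time
assert shapes["xqkv_view"] == (48, 128)          # 32 query heads + 2 * 8 kv heads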
@@ -4,7 +4,7 @@ from transformers import AutoTokenizer, TextStreamer
quant_path = "TheBloke/Mistral-7B-OpenOrca-AWQ"
# Load model
model = AutoAWQForCausalLM.from_quantized(quant_path, fuse_layers=False, safetensors=True)
model = AutoAWQForCausalLM.from_quantized(quant_path, fuse_layers=True, safetensors=True)
tokenizer = AutoTokenizer.from_pretrained(quant_path, trust_remote_code=True)
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
@@ -16,8 +16,12 @@ You are MistralOrca, a large language model trained by Alignment Lab AI. Write o
{prompt}<|im_end|>
<|im_start|>assistant"""
prompt = "You're standing on the surface of the Earth. "\
"You walk one mile south, one mile west and one mile north. "\
"You end up exactly where you started. Where are you?"
tokens = tokenizer(
prompt_template.format(prompt="Why is ice cream so good, yes so good?"),
prompt_template.format(prompt=prompt),
return_tensors='pt'
).input_ids.cuda()
@@ -26,4 +30,4 @@ generation_output = model.generate(
tokens,
streamer=streamer,
max_new_tokens=512
)
)
\ No newline at end of file
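One related note: with fuse_layers=True, QuantAttentionFused above sizes its KV cache from the AWQ_BATCH_SIZE environment variable (default "1") and raises a RuntimeError when the prompt batch size differs, so batched generation would presumably need something along these lines (a sketch, not part of this diff):

import os
os.environ["AWQ_BATCH_SIZE"] = "2"  # set before loading so the fused cache is sized for 2 prompts

from awq import AutoAWQForCausalLM

quant_path = "TheBloke/Mistral-7B-OpenOrca-AWQ"
model = AutoAWQForCausalLM.from_quantized(quant_path, fuse_layers=True, safetensors=True)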