Commit 306de683 authored by Casper Hansen

Move attention shapes to fused_utils

parent 428504e4
@@ -4,6 +4,7 @@ import torch
import torch.nn as nn
from torch.nn import functional as F
from awq.modules.fused.cache import WindowedCache
from awq.utils.fused_utils import get_attention_shapes
try:
import ft_inference_engine
@@ -70,50 +71,6 @@ class ALiBi(nn.Module):
scores += self.bias[..., :seqlen]
return scores
def get_attention_shapes(attention_shapes, max_seq_len, cache_batch_size, n_heads, n_kv_heads, head_dim):
if attention_shapes is not None:
attention_shapes = attention_shapes
elif n_kv_heads == 0:
attention_shapes = {
# following fastertransformer definition
"cache_v": (cache_batch_size, n_heads, max_seq_len, head_dim,),
# 8: pack 8 fp16 in FT, if fp32 then use 4
"cache_k": (cache_batch_size, n_heads, head_dim // 8, max_seq_len, 8,),
"xqkv_view": (-1, n_heads, head_dim),
"xq_slice": lambda xqkv: xqkv[:, :, 0],
"xk_slice": lambda xqkv: xqkv[:, :, 1],
"xv_slice": lambda xqkv: xqkv[:, :, 2],
"xq_view": (n_heads, head_dim),
"xk_view": (n_heads, head_dim),
"xv_view": (n_heads, head_dim),
"xk_reshape": (n_heads, head_dim // 8, 8),
"single_xq_view": (n_heads, head_dim),
"single_xk_view": (n_heads, head_dim),
"single_xv_view": (n_heads, head_dim)
}
else:
attention_shapes = {
# following fastertransformer definition
"cache_v": (cache_batch_size, n_kv_heads, max_seq_len, head_dim,),
# 8: pack 8 fp16 in FT, if fp32 then use 4
"cache_k": (cache_batch_size, n_kv_heads, head_dim // 8, max_seq_len, 8,),
"xqkv_view": (n_heads + n_kv_heads * 2, head_dim),
"xq_slice": lambda xqkv: xqkv[:, :, 0 : n_heads],
"xk_slice": lambda xqkv: xqkv[:, :, n_heads : (n_heads + n_kv_heads)],
"xv_slice": lambda xqkv: xqkv[:, :, -n_kv_heads :],
"xq_view": (n_heads, head_dim),
"xk_view": (n_kv_heads, head_dim),
"xv_view": (n_kv_heads, head_dim),
"xk_reshape": (n_kv_heads, head_dim // 8, 8),
"single_xq_view": (n_heads, head_dim),
"single_xk_view": (n_kv_heads, head_dim),
"single_xv_view": (n_kv_heads, head_dim)
}
return attention_shapes
class QuantAttentionFused(nn.Module):
def __init__(self, hidden_size, n_heads, n_kv_heads, qkv_layer, o_proj, dev, max_seq_len,
use_alibi=False, attention_shapes=None):
......
def get_attention_shapes(attention_shapes, max_seq_len, cache_batch_size, n_heads, n_kv_heads, head_dim):
if attention_shapes is not None:
attention_shapes = attention_shapes
elif n_kv_heads == 0:
attention_shapes = {
# following fastertransformer definition
"cache_v": (cache_batch_size, n_heads, max_seq_len, head_dim,),
# 8: pack 8 fp16 in FT, if fp32 then use 4
"cache_k": (cache_batch_size, n_heads, head_dim // 8, max_seq_len, 8,),
"xqkv_view": (-1, n_heads, head_dim),
"xq_slice": lambda xqkv: xqkv[:, :, 0],
"xk_slice": lambda xqkv: xqkv[:, :, 1],
"xv_slice": lambda xqkv: xqkv[:, :, 2],
"xq_view": (n_heads, head_dim),
"xk_view": (n_heads, head_dim),
"xv_view": (n_heads, head_dim),
"xk_reshape": (n_heads, head_dim // 8, 8),
"single_xq_view": (n_heads, head_dim),
"single_xk_view": (n_heads, head_dim),
"single_xv_view": (n_heads, head_dim)
}
else:
attention_shapes = {
# following fastertransformer definition
"cache_v": (cache_batch_size, n_kv_heads, max_seq_len, head_dim,),
# 8: pack 8 fp16 in FT, if fp32 then use 4
"cache_k": (cache_batch_size, n_kv_heads, head_dim // 8, max_seq_len, 8,),
"xqkv_view": (n_heads + n_kv_heads * 2, head_dim),
"xq_slice": lambda xqkv: xqkv[:, :, 0 : n_heads],
"xk_slice": lambda xqkv: xqkv[:, :, n_heads : (n_heads + n_kv_heads)],
"xv_slice": lambda xqkv: xqkv[:, :, -n_kv_heads :],
"xq_view": (n_heads, head_dim),
"xk_view": (n_kv_heads, head_dim),
"xv_view": (n_kv_heads, head_dim),
"xk_reshape": (n_kv_heads, head_dim // 8, 8),
"single_xq_view": (n_heads, head_dim),
"single_xk_view": (n_kv_heads, head_dim),
"single_xv_view": (n_kv_heads, head_dim)
}
return attention_shapes
\ No newline at end of file
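
After this move, callers import the helper from awq.utils.fused_utils instead of defining it locally in the fused attention module. Below is a minimal usage sketch, assuming the package is importable at this commit; the model dimensions (32 heads, 8 KV heads, head_dim 128, max_seq_len 2048, cache batch 1) are illustrative values, not taken from the diff.

from awq.utils.fused_utils import get_attention_shapes

# Illustrative model dimensions (assumptions, not from the diff).
n_heads, n_kv_heads, head_dim = 32, 8, 128
max_seq_len, cache_batch_size = 2048, 1

shapes = get_attention_shapes(
    attention_shapes=None,              # None -> derive the shape dict from the head counts
    max_seq_len=max_seq_len,
    cache_batch_size=cache_batch_size,
    n_heads=n_heads,
    n_kv_heads=n_kv_heads,
    head_dim=head_dim,
)

# n_kv_heads > 0 selects the grouped-query branch: the K/V cache and views use
# n_kv_heads, while the query view keeps the full n_heads.
assert shapes["cache_k"] == (cache_batch_size, n_kv_heads, head_dim // 8, max_seq_len, 8)
assert shapes["cache_v"] == (cache_batch_size, n_kv_heads, max_seq_len, head_dim)
assert shapes["xq_view"] == (n_heads, head_dim)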