Commit a8c9afd5 authored by Casper Hansen

Refactor view/reshaping into a predefined dict

parent 4517b3f2
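The gist of the change, as a self-contained sketch (sizes and names here are assumed for illustration, not taken from the module): view shapes and QKV slicing rules live in one dict, and full shapes are built by tuple concatenation, e.g. `(bsz, seqlen) + shapes["xqkv_view"]`.

```python
# Illustrative sketch of the refactor: shapes and slicing rules are looked up
# in a dict instead of being hard-coded at every call site. Sizes are assumed.
import torch

bsz, seqlen, n_heads, head_dim = 2, 4, 8, 64

attention_shapes = {
    "xqkv_view": (-1, n_heads, head_dim),    # trailing dims appended to (bsz, seqlen)
    "xq_slice": lambda xqkv: xqkv[:, :, 0],  # query slice of the fused tensor
    "xk_slice": lambda xqkv: xqkv[:, :, 1],  # key slice
    "xv_slice": lambda xqkv: xqkv[:, :, 2],  # value slice
}

# A fused QKV projection packs q, k and v into one tensor.
xqkv = torch.randn(bsz, seqlen, 3 * n_heads * head_dim)
xqkv = xqkv.view((bsz, seqlen) + attention_shapes["xqkv_view"])

xq = attention_shapes["xq_slice"](xqkv)
print(xq.shape)  # torch.Size([2, 4, 8, 64])
```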
@@ -124,19 +124,28 @@ class QuantAttentionFused(nn.Module):
         self.start_pos = 0
         self.use_alibi = use_alibi
         self.cache_batch_size = 1
-        # following fastertransformer definition
+        self.attention_shapes = {
+            # following fastertransformer definition
+            "cache_v": (self.cache_batch_size, self.n_local_heads, max_seq_len, self.head_dim,),
+            # 8: pack 8 fp16 in FT, if fp32 then use 4
+            "cache_k": (self.cache_batch_size, self.n_local_heads, self.head_dim // 8, max_seq_len, 8,),
+            "xqkv_view": (-1, self.n_local_heads, self.head_dim),
+            "xq_slice": lambda xqkv: xqkv[:, :, 0],
+            "xk_slice": lambda xqkv: xqkv[:, :, 1],
+            "xv_slice": lambda xqkv: xqkv[:, :, 2],
+            "xk_view": (self.n_local_heads, self.head_dim),
+            "xv_view": (self.n_local_heads, self.head_dim),
+            "single_xq_view": (self.n_local_heads, self.head_dim),
+            "single_xk_view": (self.n_local_heads, self.head_dim),
+            "single_xv_view": (self.n_local_heads, self.head_dim)
+        }
         self.cache_v = (
-            torch.zeros(
-                ( self.cache_batch_size, self.n_local_heads, max_seq_len, self.head_dim, )
-            ).to(dev).half()
+            torch.zeros(self.attention_shapes["cache_v"]).to(dev).half()
         )
-        # 8: pack 8 fp16 in FT, if fp32 then use 4
         self.cache_k = (
-            torch.zeros(
-                ( self.cache_batch_size, self.n_local_heads, self.head_dim // 8, max_seq_len, 8, )
-            ).to(dev).half()
+            torch.zeros(self.attention_shapes["cache_k"]).to(dev).half()
        )
         if use_alibi:
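For reference, a minimal sketch of the two cache layouts now named in the dict, under assumed sizes. `cache_k` follows the FasterTransformer layout the comments reference: the head dimension is split into groups of 8 packed fp16 values (4 for fp32).

```python
# Standalone sketch of the cache shapes (assumed sizes, fp16 packing of 8).
import torch

cache_batch_size, n_local_heads, max_seq_len, head_dim = 1, 8, 2048, 64

cache_v = torch.zeros(cache_batch_size, n_local_heads, max_seq_len, head_dim)
# head_dim is split into head_dim // 8 groups of 8 packed fp16 values.
cache_k = torch.zeros(cache_batch_size, n_local_heads, head_dim // 8, max_seq_len, 8)

print(cache_v.shape)  # torch.Size([1, 8, 2048, 64])
print(cache_k.shape)  # torch.Size([1, 8, 8, 2048, 8])
```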
@@ -160,15 +169,15 @@ class QuantAttentionFused(nn.Module):
     ):
         bsz, seqlen, _ = hidden_states.shape
         xqkv = self.qkv_proj(hidden_states)
-        xqkv = xqkv.view(bsz, seqlen, -1, self.n_local_heads, self.head_dim)
-        xq = xqkv[:, :, 0]
-        xk = xqkv[:, :, 1]
-        xv = xqkv[:, :, 2]
+        xqkv = xqkv.view((bsz, seqlen) + self.attention_shapes["xqkv_view"])
+        xq = self.attention_shapes["xq_slice"](xqkv)
+        xk = self.attention_shapes["xk_slice"](xqkv)
+        xv = self.attention_shapes["xv_slice"](xqkv)
         if seqlen > 1:
-            xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
-            xk = xk.view(bsz, seqlen, self.n_local_heads, self.head_dim)
-            xv = xv.view(bsz, seqlen, self.n_local_heads, self.head_dim)
+            xk = xk.view((bsz, seqlen) + self.attention_shapes["xk_view"])
+            xv = xv.view((bsz, seqlen) + self.attention_shapes["xv_view"])
             if not self.use_alibi:
                 xq, xk = apply_rotary_emb(xq, xk, freqs_cis=self.freqs_cis[self.start_pos : self.start_pos + seqlen])
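In the prefill branch (`seqlen > 1`) the new `view` calls are shape-preserving under this particular layout, since the slices already come out as `(bsz, seqlen, n_local_heads, head_dim)`; routing them through the dict presumably leaves room for layouts where the sliced and attention shapes differ. A quick sketch with assumed sizes:

```python
# Sketch (assumed sizes): for this layout the re-view is an identity on shape.
import torch

bsz, seqlen, n_heads, head_dim = 2, 4, 8, 64
xk_view = (n_heads, head_dim)

xk = torch.randn(bsz, seqlen, n_heads, head_dim)  # as produced by xk_slice
assert xk.view((bsz, seqlen) + xk_view).shape == xk.shape
```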
@@ -205,9 +214,13 @@ class QuantAttentionFused(nn.Module):
             output = torch.matmul(scores, values) # (bs, n_local_heads, slen, head_dim)
             attention_weight = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
         else:
-            xq = xq[:, 0, :, :]
-            xk = xk[:, 0, :, :]
-            xv = xv[:, 0, :, :]
+            # xq = xq[:, 0, :, :]
+            # xk = xk[:, 0, :, :]
+            # xv = xv[:, 0, :, :]
+            xq = xq.view((bsz,) + self.attention_shapes["single_xq_view"])
+            xk = xk.view((bsz,) + self.attention_shapes["single_xk_view"])
+            xv = xv.view((bsz,) + self.attention_shapes["single_xv_view"])
             past_key_value = (xk, xv) if use_cache else None
             attention_weight = awq_inference_engine.single_query_attention(
                 xq, # query
......
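The decode branch swaps indexing for a `view`: with `seqlen == 1`, reshaping to `(bsz,) + single_xq_view` drops the length-1 axis and matches the old `xq[:, 0, :, :]`. A sketch with assumed sizes:

```python
# Sketch (assumed sizes): the view is equivalent to indexing away seqlen == 1.
import torch

bsz, n_heads, head_dim = 2, 8, 64
single_xq_view = (n_heads, head_dim)

xq = torch.randn(bsz, 1, n_heads, head_dim)
assert torch.equal(xq.view((bsz,) + single_xq_view), xq[:, 0, :, :])
```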