Commit a8c9afd5 authored by Casper Hansen

Refactor view/reshaping into a predefined dict

parent 4517b3f2
@@ -124,19 +124,28 @@ class QuantAttentionFused(nn.Module):
         self.start_pos = 0
         self.use_alibi = use_alibi
         self.cache_batch_size = 1
+        self.attention_shapes = {
+            # following fastertransformer definition
+            "cache_v": (self.cache_batch_size, self.n_local_heads, max_seq_len, self.head_dim,),
+            # 8: pack 8 fp16 in FT, if fp32 then use 4
+            "cache_k": (self.cache_batch_size, self.n_local_heads, self.head_dim // 8, max_seq_len, 8,),
+            "xqkv_view": (-1, self.n_local_heads, self.head_dim),
+            "xq_slice": lambda xqkv: xqkv[:, :, 0],
+            "xk_slice": lambda xqkv: xqkv[:, :, 1],
+            "xv_slice": lambda xqkv: xqkv[:, :, 2],
+            "xk_view": (self.n_local_heads, self.head_dim),
+            "xv_view": (self.n_local_heads, self.head_dim),
+            "single_xq_view": (self.n_local_heads, self.head_dim),
+            "single_xk_view": (self.n_local_heads, self.head_dim),
+            "single_xv_view": (self.n_local_heads, self.head_dim)
+        }
 
-        # following fastertransformer definition
         self.cache_v = (
-            torch.zeros(
-                (self.cache_batch_size, self.n_local_heads, max_seq_len, self.head_dim,)
-            ).to(dev).half()
+            torch.zeros(self.attention_shapes["cache_v"]).to(dev).half()
         )
 
-        # 8: pack 8 fp16 in FT, if fp32 then use 4
         self.cache_k = (
-            torch.zeros(
-                (self.cache_batch_size, self.n_local_heads, self.head_dim // 8, max_seq_len, 8,)
-            ).to(dev).half()
+            torch.zeros(self.attention_shapes["cache_k"]).to(dev).half()
         )
 
         if use_alibi:
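Note: a minimal sketch (with assumed toy hyperparameters, not values from this repo) of the two cache layouts the dict encodes. The split cache_k layout groups 8 fp16 values contiguously so FasterTransformer can read 16 bytes per access:

    import torch

    # assumed toy sizes, not values from this repo
    cache_batch_size, n_local_heads, head_dim, max_seq_len = 1, 32, 128, 2048

    # cache_v: one fp16 value per (batch, head, position, dim) slot
    cache_v = torch.zeros(
        (cache_batch_size, n_local_heads, max_seq_len, head_dim)
    ).half()

    # cache_k: pack 8 contiguous fp16 values (8 * 2 bytes = 16 bytes,
    # one vectorized load); an fp32 cache would pack 4 instead
    cache_k = torch.zeros(
        (cache_batch_size, n_local_heads, head_dim // 8, max_seq_len, 8)
    ).half()

    # same number of elements, different memory layout
    assert cache_v.numel() == cache_k.numel()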
@@ -160,15 +169,15 @@ class QuantAttentionFused(nn.Module):
     ):
         bsz, seqlen, _ = hidden_states.shape
         xqkv = self.qkv_proj(hidden_states)
-        xqkv = xqkv.view(bsz, seqlen, -1, self.n_local_heads, self.head_dim)
-        xq = xqkv[:, :, 0]
-        xk = xqkv[:, :, 1]
-        xv = xqkv[:, :, 2]
+        xqkv = xqkv.view((bsz, seqlen) + self.attention_shapes["xqkv_view"])
+        xq = self.attention_shapes["xq_slice"](xqkv)
+        xk = self.attention_shapes["xk_slice"](xqkv)
+        xv = self.attention_shapes["xv_slice"](xqkv)
 
         if seqlen > 1:
             xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
-            xk = xk.view(bsz, seqlen, self.n_local_heads, self.head_dim)
-            xv = xv.view(bsz, seqlen, self.n_local_heads, self.head_dim)
+            xk = xk.view((bsz, seqlen) + self.attention_shapes["xk_view"])
+            xv = xv.view((bsz, seqlen) + self.attention_shapes["xv_view"])
 
             if not self.use_alibi:
                 xq, xk = apply_rotary_emb(xq, xk, freqs_cis=self.freqs_cis[self.start_pos : self.start_pos + seqlen])
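Note: what the dict-driven view and slices compute, shown standalone with assumed toy sizes. qkv_proj emits Q, K and V fused along the hidden dim, so one view plus three slices recovers the per-head tensors:

    import torch

    # assumed toy sizes, not values from this repo
    bsz, seqlen, n_local_heads, head_dim = 2, 5, 4, 8

    # qkv_proj output: Q, K, V fused along the hidden dim
    xqkv = torch.randn(bsz, seqlen, 3 * n_local_heads * head_dim)

    xqkv_view = (-1, n_local_heads, head_dim)
    xqkv = xqkv.view((bsz, seqlen) + xqkv_view)  # (bsz, seqlen, 3, heads, head_dim)

    xq = xqkv[:, :, 0]  # what "xq_slice" selects
    xk = xqkv[:, :, 1]  # "xk_slice"
    xv = xqkv[:, :, 2]  # "xv_slice"
    assert xq.shape == (bsz, seqlen, n_local_heads, head_dim)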
@@ -205,9 +214,13 @@ class QuantAttentionFused(nn.Module):
             output = torch.matmul(scores, values)  # (bs, n_local_heads, slen, head_dim)
             attention_weight = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
         else:
-            xq = xq[:, 0, :, :]
-            xk = xk[:, 0, :, :]
-            xv = xv[:, 0, :, :]
+            # xq = xq[:, 0, :, :]
+            # xk = xk[:, 0, :, :]
+            # xv = xv[:, 0, :, :]
+            xq = xq.view((bsz,) + self.attention_shapes["single_xq_view"])
+            xk = xk.view((bsz,) + self.attention_shapes["single_xk_view"])
+            xv = xv.view((bsz,) + self.attention_shapes["single_xv_view"])
 
             past_key_value = (xk, xv) if use_cache else None
             attention_weight = awq_inference_engine.single_query_attention(
                 xq, # query
...
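Note: on the decode path seqlen is 1, so indexing away the sequence dim and viewing through the dict produce the same tensor; a quick check under assumed toy sizes:

    import torch

    bsz, n_local_heads, head_dim = 2, 4, 8  # assumed toy sizes

    xq = torch.randn(bsz, 1, n_local_heads, head_dim)  # seqlen == 1 at decode time

    old = xq[:, 0, :, :]                               # old: index away the seq dim
    new = xq.view((bsz,) + (n_local_heads, head_dim))  # new: view via "single_xq_view"
    assert torch.equal(old, new)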