import math

import torch
import torch.nn as nn
from torch.nn import functional as F

import awq_inference_engine


class QuantLlamaRotaryEmbedding(nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (
            self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
        )
        self.register_buffer("inv_freq", inv_freq)

        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings,
            device=self.inv_freq.device,
            dtype=torch.get_default_dtype(),
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        t = torch.arange(
            self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
        )
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)

        # The cache stores [cos | sin] concatenated along the last dim, which is
        # the layout the rotary kernel expects; it is kept in fp16 to match the
        # quantized inference path.
        cos = freqs.cos()
        sin = freqs.sin()
        cache = torch.cat((cos, sin), dim=-1)
        self.register_buffer("cos_sin_cache", cache.half(), persistent=False)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        positions: torch.Tensor,
    ):
        # Apply rotary embedding to the query and key before passing them to
        # the attention op.  The kernel rotates both tensors in place, so they
        # must be contiguous.
        query = query.contiguous()
        key = key.contiguous()
        awq_inference_engine.rotary_embedding_neox(
            positions, query, key, self.dim, self.cos_sin_cache
        )
        return query, key


class QuantLlamaAttentionFused(nn.Module):
    def __init__(
        self, hidden_size, num_heads, qkv_layer, o_proj, dev, max_position_embeddings
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.n_local_heads = num_heads
        self.head_dim = self.hidden_size // num_heads
        self.qkv_proj = qkv_layer
        self.o_proj = o_proj
        self.start_pos = 0  # next write offset into the KV cache

        # Caches follow the FasterTransformer layout.
        self.cache_v = (
            torch.zeros(
                (
                    1,
                    self.n_local_heads,
                    max_position_embeddings,
                    self.head_dim,
                )
            )
            .to(dev)
            .half()
        )
        # The key cache packs 8 fp16 values along the last dim as in
        # FasterTransformer; use 4 if the cache is fp32.
        self.cache_k = (
            torch.zeros(
                (
                    1,
                    self.n_local_heads,
                    self.head_dim // 8,
                    max_position_embeddings,
                    8,
                )
            )
            .to(dev)
            .half()
        )

        self.rotary_emb = QuantLlamaRotaryEmbedding(
            dim=hidden_size // num_heads,
            max_position_embeddings=max_position_embeddings,
            device=dev,
        )

    def forward(
        self,
        hidden_states,
        past_key_value=None,
        attention_mask=None,
        position_ids=None,
        output_attentions=False,
        use_cache=False,
    ):
        bsz, seqlen, _ = hidden_states.shape

        # Single fused QKV projection, then split into query/key/value.
        xqkv = self.qkv_proj(hidden_states)
        xqkv = xqkv.view(bsz, seqlen, -1, self.n_local_heads, self.head_dim)
        xq = xqkv[:, :, 0]
        xk = xqkv[:, :, 1]
        xv = xqkv[:, :, 2]

        if seqlen > 1:
            # Prefill: apply rotary embedding in Python and scatter the keys
            # and values into the packed caches.
            xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
            xk = xk.view(bsz, seqlen, self.n_local_heads, self.head_dim)
            xv = xv.view(bsz, seqlen, self.n_local_heads, self.head_dim)

            xq, xk = self.rotary_emb(xq, xk, position_ids)

            self.cache_k = self.cache_k.to(xq)
            self.cache_v = self.cache_v.to(xq)

            values_store = xv.transpose(2, 1)
            keys_store = (
                xk.reshape(bsz, seqlen, self.n_local_heads, self.head_dim // 8, 8)
                .permute(0, 2, 3, 1, 4)
                .contiguous()
            )

            self.cache_v[:bsz, :, self.start_pos : self.start_pos + seqlen, :] = (
                values_store
            )
            self.cache_k[:bsz, :, :, self.start_pos : self.start_pos + seqlen, :] = (
                keys_store
            )

            keys = xk
            values = xv
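            # The caches above feed the fused decode kernel; attention over the
            # prompt tokens themselves is computed below in plain PyTorch
            # (matmul -> softmax -> matmul) on the freshly rotated xq/xk/xv.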
            past_key_value = (xk, xv) if use_cache else None

            xq = xq.transpose(1, 2)
            keys = keys.transpose(1, 2)
            values = values.transpose(1, 2)
            scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
            if attention_mask is not None:
                # (bs, n_local_heads, slen, cache_len + slen)
                scores = scores + attention_mask
            scores = F.softmax(scores.float(), dim=-1).type_as(xq)
            # (bs, n_local_heads, slen, head_dim)
            output = torch.matmul(scores, values)
            output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
        else:
            # Decode: a single new token per sequence is handled by the fused
            # single-query kernel over the packed caches (rotary embedding is
            # not applied in Python on this path).
            xq = xq[:, 0, :, :]
            xk = xk[:, 0, :, :]
            xv = xv[:, 0, :, :]
            past_key_value = (xk, xv) if use_cache else None
            output = awq_inference_engine.single_query_attention(
                xq,
                xk,
                xv,
                self.cache_k,
                self.cache_v,
                None,
                None,
                self.start_pos,
                self.head_dim,
                10000,
                True,
            )
            output = output.reshape(bsz, 1, -1)

        attn_output = self.o_proj(output)

        # Advance the cache write offset while generating; reset it once the
        # cache is no longer in use.
        if use_cache:
            self.start_pos += seqlen
        else:
            self.start_pos = 0

        return attn_output, None, past_key_value
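

# ---------------------------------------------------------------------------
# Minimal usage sketch, assuming a CUDA device and the compiled
# `awq_inference_engine` extension are available.  Plain fp16 nn.Linear layers
# stand in for the quantized qkv/o projections normally produced by AWQ; the
# module only requires callables with matching shapes, so these stand-ins are
# illustrative, not the real quantized layers.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    dev = "cuda:0"
    hidden_size, num_heads, max_pos = 4096, 32, 2048

    # Stand-in projections: qkv maps hidden_size -> 3 * hidden_size so the
    # (bsz, seqlen, 3, num_heads, head_dim) view in forward() works out.
    qkv_layer = nn.Linear(hidden_size, 3 * hidden_size, bias=False).half().to(dev)
    o_proj = nn.Linear(hidden_size, hidden_size, bias=False).half().to(dev)

    attn = QuantLlamaAttentionFused(
        hidden_size, num_heads, qkv_layer, o_proj, dev, max_pos
    )

    # Prefill with an 8-token prompt (a real model would also pass the causal
    # attention_mask here), then decode one token from the cache.
    bsz, seqlen = 1, 8
    hidden_states = torch.randn(
        bsz, seqlen, hidden_size, dtype=torch.float16, device=dev
    )
    position_ids = torch.arange(seqlen, device=dev).unsqueeze(0)
    out, _, _ = attn(hidden_states, position_ids=position_ids, use_cache=True)
    print(out.shape)  # (1, 8, 4096)

    next_token = torch.randn(bsz, 1, hidden_size, dtype=torch.float16, device=dev)
    next_pos = torch.tensor([[seqlen]], device=dev)
    out, _, _ = attn(next_token, position_ids=next_pos, use_cache=True)
    print(out.shape)  # (1, 1, 4096)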