Commit 64e6b3e1 authored by Casper Hansen's avatar Casper Hansen
Browse files

Use apply_rotary_emb

parent a11c313a
......@@ -4,6 +4,36 @@ import torch.nn as nn
import awq_inference_engine
from torch.nn import functional as F
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
t = torch.arange(end, device=freqs.device) # type: ignore
freqs = torch.outer(t, freqs).float() # type: ignore
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
return freqs_cis
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
ndim = x.ndim
assert 0 <= 1 < ndim
assert freqs_cis.shape == (x.shape[1], x.shape[-1])
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
return freqs_cis.view(*shape)
def apply_rotary_emb(
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cis: torch.Tensor,
):
xq_ = torch.view_as_complex(
xq.float().reshape(*xq.shape[:-1], 2, -1).transpose(-2, -1).contiguous()
)
xk_ = torch.view_as_complex(
xk.float().reshape(*xk.shape[:-1], 2, -1).transpose(-2, -1).contiguous()
)
freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
xq_out = torch.view_as_real(xq_ * freqs_cis).transpose(-2, -1).flatten(3)
xk_out = torch.view_as_real(xk_ * freqs_cis).transpose(-2, -1).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)
class QuantLlamaRotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
......@@ -37,8 +67,6 @@ class QuantLlamaRotaryEmbedding(nn.Module):
sin = freqs.sin()
cache = torch.cat((cos, sin), dim=-1)
# self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
# self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
self.register_buffer("cos_sin_cache", cache.half(), persistent=False)
def forward(
......@@ -61,7 +89,6 @@ class QuantLlamaRotaryEmbedding(nn.Module):
)
return query, key
class QuantLlamaAttentionFused(nn.Module):
def __init__(self, hidden_size, num_heads, qkv_layer, o_proj, dev, max_position_embeddings):
super().__init__()
......@@ -85,12 +112,10 @@ class QuantLlamaAttentionFused(nn.Module):
( 1, self.n_local_heads, self.head_dim // 8, max_position_embeddings, 8, )
).to(dev).half()
)
self.rotary_emb = QuantLlamaRotaryEmbedding(
dim=hidden_size // num_heads,
max_position_embeddings=max_position_embeddings,
device=dev
)
self.freqs_cis = precompute_freqs_cis(
hidden_size // num_heads,
max_position_embeddings * 2,
).to(dev)
def forward(
self,
......@@ -108,7 +133,7 @@ class QuantLlamaAttentionFused(nn.Module):
xk = xk.view(bsz, seqlen, self.n_local_heads, self.head_dim)
xv = xv.view(bsz, seqlen, self.n_local_heads, self.head_dim)
xq, xk = self.rotary_emb(xq, xk, position_ids)
xq, xk = apply_rotary_emb(xq, xk, freqs_cis=self.freqs_cis[self.start_pos : self.start_pos + seqlen])
self.cache_k = self.cache_k.to(xq)
self.cache_v = self.cache_v.to(xq)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment