Commit 69733d2c authored by Casper Hansen

Create RoPE module

parent 306de683
@@ -12,30 +12,44 @@ try:
 except:
     FT_INSTALLED = False
 
-def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
-    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
-    t = torch.arange(end, device=freqs.device)  # type: ignore
-    freqs = torch.outer(t, freqs).float()  # type: ignore
-    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
-    return freqs_cis
-
-def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
-    ndim = x.ndim
-    assert 0 <= 1 < ndim
-    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
-    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
-    return freqs_cis.view(*shape)
-
-def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor):
-    xq_ = torch.view_as_complex(
-        xq.float().reshape(*xq.shape[:-1], 2, -1).transpose(-2, -1).contiguous()
-    )
-    xk_ = torch.view_as_complex(
-        xk.float().reshape(*xk.shape[:-1], 2, -1).transpose(-2, -1).contiguous()
-    )
-    freqs_cis = reshape_for_broadcast(freqs_cis, xq_).to(xq_.device)
-    xq_out = torch.view_as_real(xq_ * freqs_cis).transpose(-2, -1).flatten(3)
-    xk_out = torch.view_as_real(xk_ * freqs_cis).transpose(-2, -1).flatten(3)
-    return xq_out.type_as(xq), xk_out.type_as(xk)
+class RoPE(nn.Module):
+    def __init__(self, hidden_size, n_heads, max_seq_len, device):
+        super(RoPE, self).__init__()
+        self.freqs_cis = nn.Parameter(
+            self.precompute_freqs_cis(hidden_size // n_heads, max_seq_len * 2).to(device),
+            requires_grad=False
+        )
+
+    @staticmethod
+    def precompute_freqs_cis(dim: int, end: int, theta=10000.0):
+        freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
+        t = torch.arange(end)
+        freqs = torch.outer(t, freqs).float()
+        freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
+        return freqs_cis
+
+    @staticmethod
+    def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
+        ndim = x.ndim
+        assert 0 <= 1 < ndim
+        assert freqs_cis.shape == (x.shape[1], x.shape[-1])
+        shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
+        return freqs_cis.view(*shape)
+
+    def forward(self, xq: torch.Tensor, xk: torch.Tensor, start_pos: int, seqlen: int):
+        xq_ = torch.view_as_complex(
+            xq.float().reshape(*xq.shape[:-1], 2, -1).transpose(-2, -1).contiguous()
+        )
+        xk_ = torch.view_as_complex(
+            xk.float().reshape(*xk.shape[:-1], 2, -1).transpose(-2, -1).contiguous()
+        )
+        freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]
+        freqs_cis = self.reshape_for_broadcast(freqs_cis, xq_).to(xq_.device)
+        xq_out = torch.view_as_real(xq_ * freqs_cis).transpose(-2, -1).flatten(3)
+        xk_out = torch.view_as_real(xk_ * freqs_cis).transpose(-2, -1).flatten(3)
+        return xq_out.type_as(xq), xk_out.type_as(xk)
 
 class ALiBi(nn.Module):
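Reviewer note: a quick way to sanity-check the new module is to run it standalone on dummy tensors, as in the sketch below. This is not code from the commit; the [bsz, seqlen, n_heads, head_dim] layout and the CPU device are assumptions read off the reshape_for_broadcast assertion.

    # Sketch only (assumes RoPE is importable from this module and torch is installed).
    import torch

    hidden_size, n_heads, max_seq_len = 64, 4, 128
    head_dim = hidden_size // n_heads
    rope = RoPE(hidden_size, n_heads, max_seq_len, device="cpu")

    # The table holds one unit-modulus complex rotation per (position, dim pair).
    assert rope.freqs_cis.shape == (max_seq_len * 2, head_dim // 2)
    assert torch.allclose(rope.freqs_cis.abs(), torch.ones_like(rope.freqs_cis.abs()))

    bsz, seqlen, start_pos = 2, 16, 0
    xq = torch.randn(bsz, seqlen, n_heads, head_dim)
    xk = torch.randn(bsz, seqlen, n_heads, head_dim)
    xq_rot, xk_rot = rope.forward(xq, xk, start_pos, seqlen)

    # The rotation leaves the query/key shapes unchanged.
    assert xq_rot.shape == xq.shape and xk_rot.shape == xk.shape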
@@ -101,12 +115,9 @@ class QuantAttentionFused(nn.Module):
             self.rotary_dim = 0
             self.is_neox = False
         else:
-            self.freqs_cis = precompute_freqs_cis(
-                hidden_size // n_heads,
-                max_seq_len * 2,
-            ).to(dev)
+            self.alibi = None
+            self.rope = RoPE(hidden_size, n_heads, max_seq_len, dev)
             self.rotary_dim = self.head_dim
-            self.alibi_slopes = None
             self.is_neox = True
 
     def forward(self, hidden_states:torch.Tensor, attention_mask=None, *args, **kwargs):
@@ -134,7 +145,7 @@ class QuantAttentionFused(nn.Module):
         xv = xv.view((bsz, seqlen) + self.attention_shapes["xv_view"])
 
         if not self.use_alibi:
-            xq, xk = apply_rotary_emb(xq, xk, freqs_cis=self.freqs_cis[self.start_pos : self.start_pos + seqlen])
+            xq, xk = self.rope.forward(xq, xk, self.start_pos, seqlen)
 
         self.cache.to(xq)
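Behaviourally this hunk is a no-op: the position slicing the caller used to do on self.freqs_cis now happens inside RoPE.forward. A hedged sketch of what that means at a non-zero decode step (dummy shapes, not commit code):

    # Sketch only: forward(start_pos, seqlen) consumes exactly the rows the caller used to slice.
    import torch

    rope = RoPE(hidden_size=64, n_heads=4, max_seq_len=128, device="cpu")
    start_pos, seqlen = 5, 8
    xq = torch.randn(1, seqlen, 4, 16)
    xk = torch.randn(1, seqlen, 4, 16)

    xq_rot, xk_rot = rope.forward(xq, xk, start_pos, seqlen)

    # Same rows as the old call site's self.freqs_cis[self.start_pos : self.start_pos + seqlen].
    used = rope.freqs_cis[start_pos : start_pos + seqlen]
    assert used.shape == (seqlen, (64 // 4) // 2)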
@@ -176,6 +187,7 @@ class QuantAttentionFused(nn.Module):
             xk = xk.view((bsz,) + self.attention_shapes["single_xk_view"])
             xv = xv.view((bsz,) + self.attention_shapes["single_xv_view"])
 
+            alibi_slopes = self.alibi.slopes if self.alibi is not None else None
             attention_weight = ft_inference_engine.single_query_attention(
                 xq, # query
                 xk, # key
@@ -183,7 +195,7 @@ class QuantAttentionFused(nn.Module):
                 self.cache.k, # key cache
                 self.cache.v, # value cache
                 None, # length per sample
-                self.alibi.slopes, # alibi slopes
+                alibi_slopes, # alibi slopes
                 self.start_pos, # timestep
                 self.rotary_dim, # rotary embedding dimension
                 10000, # rotary embedding base
...
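Design note on the last two hunks: since the non-ALiBi branch of the constructor now sets self.alibi = None, reading self.alibi.slopes directly at the FasterTransformer call site would raise AttributeError on the RoPE path, so the call goes through a local alibi_slopes that falls back to None. A minimal reproduction of the pattern (class names here are invented for the sketch, not from the commit):

    # Illustrative only: the guard lets one call site serve both the ALiBi and RoPE paths.
    class _FakeAlibi:
        slopes = [0.5, 0.25]

    class _FakeAttention:
        def __init__(self, use_alibi):
            self.alibi = _FakeAlibi() if use_alibi else None  # mirrors the constructor change

    attn = _FakeAttention(use_alibi=False)
    alibi_slopes = attn.alibi.slopes if attn.alibi is not None else None
    assert alibi_slopes is None  # RoPE path: the kernel receives None instead of a crash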