Commit 69733d2c authored by Casper Hansen

Create RoPE module

parent 306de683
@@ -12,31 +12,45 @@ try:
 except:
     FT_INSTALLED = False
 
-def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
-    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
-    t = torch.arange(end, device=freqs.device) # type: ignore
-    freqs = torch.outer(t, freqs).float() # type: ignore
-    freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
-    return freqs_cis
-
-def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
-    ndim = x.ndim
-    assert 0 <= 1 < ndim
-    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
-    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
-    return freqs_cis.view(*shape)
-
-def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor):
-    xq_ = torch.view_as_complex(
-        xq.float().reshape(*xq.shape[:-1], 2, -1).transpose(-2, -1).contiguous()
-    )
-    xk_ = torch.view_as_complex(
-        xk.float().reshape(*xk.shape[:-1], 2, -1).transpose(-2, -1).contiguous()
-    )
-    freqs_cis = reshape_for_broadcast(freqs_cis, xq_).to(xq_.device)
-    xq_out = torch.view_as_real(xq_ * freqs_cis).transpose(-2, -1).flatten(3)
-    xk_out = torch.view_as_real(xk_ * freqs_cis).transpose(-2, -1).flatten(3)
-    return xq_out.type_as(xq), xk_out.type_as(xk)
+class RoPE(nn.Module):
+    def __init__(self, hidden_size, n_heads, max_seq_len, device):
+        super(RoPE, self).__init__()
+
+        self.freqs_cis = nn.Parameter(
+            self.precompute_freqs_cis(hidden_size // n_heads, max_seq_len * 2).to(device),
+            requires_grad=False
+        )
+
+    @staticmethod
+    def precompute_freqs_cis(dim: int, end: int, theta=10000.0):
+        freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
+        t = torch.arange(end)
+        freqs = torch.outer(t, freqs).float()
+        freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
+        return freqs_cis
+
+    @staticmethod
+    def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
+        ndim = x.ndim
+        assert 0 <= 1 < ndim
+        assert freqs_cis.shape == (x.shape[1], x.shape[-1])
+        shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
+        return freqs_cis.view(*shape)
+
+    def forward(self, xq: torch.Tensor, xk: torch.Tensor, start_pos: int, seqlen: int):
+        xq_ = torch.view_as_complex(
+            xq.float().reshape(*xq.shape[:-1], 2, -1).transpose(-2, -1).contiguous()
+        )
+        xk_ = torch.view_as_complex(
+            xk.float().reshape(*xk.shape[:-1], 2, -1).transpose(-2, -1).contiguous()
+        )
+        freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]
+        freqs_cis = self.reshape_for_broadcast(freqs_cis, xq_).to(xq_.device)
+        xq_out = torch.view_as_real(xq_ * freqs_cis).transpose(-2, -1).flatten(3)
+        xk_out = torch.view_as_real(xk_ * freqs_cis).transpose(-2, -1).flatten(3)
+        return xq_out.type_as(xq), xk_out.type_as(xk)
 
 class ALiBi(nn.Module):
     def __init__(self, n_heads, max_seq_len, device, alibi_bias_max=8):
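
A minimal usage sketch of the new module (not part of the commit; the sizes below are illustrative and assume the RoPE class above is in scope, with xq/xk laid out as (batch, seqlen, n_heads, head_dim)):

    import torch

    rope = RoPE(hidden_size=4096, n_heads=32, max_seq_len=2048, device="cpu")

    bsz, seqlen, n_heads, head_dim = 1, 8, 32, 128  # head_dim = hidden_size // n_heads
    xq = torch.randn(bsz, seqlen, n_heads, head_dim)
    xk = torch.randn(bsz, seqlen, n_heads, head_dim)

    # The module slices its cached freqs_cis from start_pos internally,
    # so callers only pass the current position and the sequence length.
    xq_rot, xk_rot = rope.forward(xq, xk, start_pos=0, seqlen=seqlen)
    assert xq_rot.shape == xq.shape and xk_rot.shape == xk.shape
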
@@ -101,12 +115,9 @@ class QuantAttentionFused(nn.Module):
             self.rotary_dim = 0
             self.is_neox = False
         else:
-            self.freqs_cis = precompute_freqs_cis(
-                hidden_size // n_heads,
-                max_seq_len * 2,
-            ).to(dev)
+            self.alibi = None
+            self.rope = RoPE(hidden_size, n_heads, max_seq_len, dev)
             self.rotary_dim = self.head_dim
-            self.alibi_slopes = None
             self.is_neox = True
 
     def forward(self, hidden_states:torch.Tensor, attention_mask=None, *args, **kwargs):
@@ -134,7 +145,7 @@ class QuantAttentionFused(nn.Module):
         xv = xv.view((bsz, seqlen) + self.attention_shapes["xv_view"])
 
         if not self.use_alibi:
-            xq, xk = apply_rotary_emb(xq, xk, freqs_cis=self.freqs_cis[self.start_pos : self.start_pos + seqlen])
+            xq, xk = self.rope.forward(xq, xk, self.start_pos, seqlen)
 
         self.cache.to(xq)
@@ -176,6 +187,7 @@ class QuantAttentionFused(nn.Module):
            xk = xk.view((bsz,) + self.attention_shapes["single_xk_view"])
            xv = xv.view((bsz,) + self.attention_shapes["single_xv_view"])

+           alibi_slopes = self.alibi.slopes if self.alibi is not None else None
            attention_weight = ft_inference_engine.single_query_attention(
                xq, # query
                xk, # key
@@ -183,7 +195,7 @@ class QuantAttentionFused(nn.Module):
                self.cache.k, # key cache
                self.cache.v, # value cache
                None, # length per sample
-               self.alibi.slopes, # alibi slopes
+               alibi_slopes, # alibi slopes
                self.start_pos, # timestep
                self.rotary_dim, # rotary embedding dimension
                10000, # rotary embedding base
...
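
As a sanity check of the refactor (a sketch, not part of the commit, assuming both the removed apply_rotary_emb free function from the previous revision and the new RoPE module are importable), the module should reproduce the free function's output, since the underlying math is unchanged and only the freqs_cis slicing moved inside the module:

    import torch

    rope = RoPE(hidden_size=4096, n_heads=32, max_seq_len=2048, device="cpu")
    xq = torch.randn(1, 8, 32, 128)
    xk = torch.randn(1, 8, 32, 128)

    # Old API: the caller slices freqs_cis manually before applying the rotation.
    old_q, old_k = apply_rotary_emb(xq, xk, freqs_cis=rope.freqs_cis[0:8])
    # New API: the module slices from start_pos internally.
    new_q, new_k = rope.forward(xq, xk, start_pos=0, seqlen=8)

    assert torch.allclose(old_q, new_q) and torch.allclose(old_k, new_k)
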