Commit 73c5e2bf authored by Casper Hansen

Fix alibi, rotary arg, neox, arg

parent 06bea896
@@ -144,13 +144,15 @@ class QuantAttentionFused(nn.Module):
             self.alibi_slopes = alibi_slopes.float().to(dev)
             self.alibi_bias = alibi_bias.float().to(dev)
             self.rotary_dim = 0
+            self.is_neox = False
         else:
             self.freqs_cis = precompute_freqs_cis(
                 hidden_size // num_heads,
                 max_seq_len * 2,
             ).to(dev)
-            self.rotary_dim = 0
+            self.rotary_dim = self.head_dim
             self.alibi_slopes = None
+            self.is_neox = True
 
     def forward(
         self,
@@ -168,7 +170,8 @@ class QuantAttentionFused(nn.Module):
         xk = xk.view(bsz, seqlen, self.n_local_heads, self.head_dim)
         xv = xv.view(bsz, seqlen, self.n_local_heads, self.head_dim)
-        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=self.freqs_cis[self.start_pos : self.start_pos + seqlen])
+        if not self.use_alibi:
+            xq, xk = apply_rotary_emb(xq, xk, freqs_cis=self.freqs_cis[self.start_pos : self.start_pos + seqlen])
 
         self.cache_k = self.cache_k.to(xq)
         self.cache_v = self.cache_v.to(xq)
@@ -217,7 +220,7 @@ class QuantAttentionFused(nn.Module):
                 self.start_pos, # timestep
                 self.rotary_dim, # rotary embedding dimension
                 10000, # rotary embedding base
-                False, # is neox
+                self.is_neox, # is neox
             )
             output = output.reshape(bsz, 1, -1)
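For context, the diff selects one of two positional-encoding paths when the module is constructed and gates the rotary rotation in the forward pass on the same flag. The standalone sketch below mirrors only that branching; it is not the QuantAttentionFused implementation. The rotary helpers follow the LLaMA-style precompute_freqs_cis/apply_rotary_emb that the diff already calls, while the ToyFusedAttentionSetup wrapper and its constructor arguments are hypothetical names introduced just for this illustration.

import torch


def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:
    # LLaMA-style rotary table: complex rotations for positions 0..end-1 over dim // 2 frequencies.
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
    t = torch.arange(end).float()
    freqs = torch.outer(t, freqs)
    return torch.polar(torch.ones_like(freqs), freqs)  # complex64, shape (end, dim // 2)


def apply_rotary_emb(xq, xk, freqs_cis):
    # Rotate query/key pairs in the complex plane; xq/xk have shape (bsz, seqlen, n_heads, head_dim).
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    freqs_cis = freqs_cis.view(1, xq_.shape[1], 1, xq_.shape[-1])  # broadcast over batch and heads
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)


class ToyFusedAttentionSetup:
    """Hypothetical wrapper that mirrors only the branching this commit touches."""

    def __init__(self, hidden_size, num_heads, max_seq_len, use_alibi):
        self.head_dim = hidden_size // num_heads
        self.use_alibi = use_alibi
        self.start_pos = 0
        if use_alibi:
            # ALiBi path: no rotary table, rotate nothing, and do not flag the kernel as NeoX-style.
            self.freqs_cis = None
            self.rotary_dim = 0
            self.is_neox = False
        else:
            # Rotary path: precompute the table once, rotate the full head_dim, pass is_neox=True.
            self.freqs_cis = precompute_freqs_cis(self.head_dim, max_seq_len * 2)
            self.rotary_dim = self.head_dim
            self.is_neox = True

    def maybe_rotate(self, xq, xk):
        # The forward-pass fix: only apply rotary embeddings when ALiBi is not in use.
        if not self.use_alibi:
            seqlen = xq.shape[1]
            freqs = self.freqs_cis[self.start_pos : self.start_pos + seqlen]
            xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs)
        return xq, xk


if __name__ == "__main__":
    cfg = ToyFusedAttentionSetup(hidden_size=64, num_heads=4, max_seq_len=16, use_alibi=False)
    xq = torch.randn(1, 3, 4, 16)
    xk = torch.randn(1, 3, 4, 16)
    xq, xk = cfg.maybe_rotate(xq, xk)
    print(cfg.rotary_dim, cfg.is_neox, xq.shape)  # 16 True torch.Size([1, 3, 4, 16])

The point of the fix is visible in the constructor and in maybe_rotate: an ALiBi model neither rotates its queries/keys nor reports a rotary dimension, so rotary_dim stays 0 and is_neox stays False on that path, while the rotary path rotates the full head_dim and forwards is_neox=True to the kernel call shown in the last hunk.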