Commit 428504e4 authored by Casper Hansen

Create ALiBi module

parent 7a3d06d6
@@ -37,29 +37,38 @@ def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor
     xk_out = torch.view_as_real(xk_ * freqs_cis).transpose(-2, -1).flatten(3)
     return xq_out.type_as(xq), xk_out.type_as(xk)
 
-def gen_slopes(n_heads, alibi_bias_max=8):
-    _n_heads = 2 ** math.ceil(math.log2(n_heads))
-    m = torch.arange(1, _n_heads + 1, dtype=torch.float32)
-    m = m.mul(alibi_bias_max / _n_heads)
-    slopes = 1.0 / torch.pow(2, m)
-    if _n_heads != n_heads:
-        slopes = torch.concat([slopes[1::2], slopes[::2]])[:n_heads]
-    return slopes.view(1, n_heads, 1, 1)
-
-def build_alibi_bias(
-    n_heads, seq_len, full=False, alibi_bias_max=8, dtype=torch.float32
-):
-    alibi_bias = torch.arange(1 - seq_len, 1, dtype=torch.int32).view(1, 1, 1, seq_len)
-    if full:
-        alibi_bias = alibi_bias - torch.arange(1 - seq_len, 1, dtype=torch.int32).view(
-            1, 1, seq_len, 1
-        )
-        alibi_bias = alibi_bias.abs().mul(-1)
-    slopes = gen_slopes(n_heads, alibi_bias_max)
-    alibi_bias = alibi_bias * slopes
-    slopes = slopes.squeeze(0).squeeze(-1).squeeze(-1)
-    return slopes.to(dtype=dtype), alibi_bias.to(dtype=dtype)
+class ALiBi(nn.Module):
+    def __init__(self, n_heads, max_seq_len, device, alibi_bias_max=8):
+        super(ALiBi, self).__init__()
+
+        # Initialize ALiBi slopes and bias
+        slopes, bias = self.build_alibi_bias(n_heads, max_seq_len, alibi_bias_max=alibi_bias_max)
+        self.slopes = nn.Parameter(slopes.float().to(device), requires_grad=False)
+        self.bias = nn.Parameter(bias.float().to(device), requires_grad=False)
+
+    @staticmethod
+    def gen_slopes(n_heads, alibi_bias_max=8):
+        _n_heads = 2 ** math.ceil(math.log2(n_heads))
+        m = torch.arange(1, _n_heads + 1, dtype=torch.float32)
+        m = m.mul(alibi_bias_max / _n_heads)
+        slopes = 1.0 / torch.pow(2, m)
+
+        if _n_heads != n_heads:
+            slopes = torch.cat([slopes[1::2], slopes[::2]])[:n_heads]
+
+        return slopes.view(1, n_heads, 1, 1)
+
+    @staticmethod
+    def build_alibi_bias(n_heads, seq_len, alibi_bias_max=8, dtype=torch.float32):
+        alibi_bias = torch.arange(1 - seq_len, 1, dtype=torch.int32).view(1, 1, 1, seq_len)
+        slopes = ALiBi.gen_slopes(n_heads, alibi_bias_max)
+
+        alibi_bias = alibi_bias * slopes
+        slopes = slopes.squeeze(0).squeeze(-1).squeeze(-1)
+
+        return slopes.to(dtype=dtype), alibi_bias.to(dtype=dtype)
+
+    def forward(self, scores, seqlen):
+        scores += self.bias[..., :seqlen]
+        return scores
 
 def get_attention_shapes(attention_shapes, max_seq_len, cache_batch_size, n_heads, n_kv_heads, head_dim):
     if attention_shapes is not None:
@@ -131,9 +140,7 @@ class QuantAttentionFused(nn.Module):
         )
 
         if use_alibi:
-            alibi_slopes, alibi_bias = build_alibi_bias(self.n_heads, max_seq_len)
-            self.alibi_slopes = alibi_slopes.float().to(dev)
-            self.alibi_bias = alibi_bias.float().to(dev)
+            self.alibi = ALiBi(n_heads, max_seq_len, dev)
             self.rotary_dim = 0
             self.is_neox = False
         else:
@@ -199,7 +206,7 @@ class QuantAttentionFused(nn.Module):
             scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
             if self.use_alibi:
-                scores += self.alibi_bias[..., :seqlen]
+                scores = self.alibi.forward(scores, seqlen)
 
             if attention_mask is not None:
                 scores = scores + attention_mask  # (bs, n_local_heads, slen, cache_len + slen)
@@ -219,7 +226,7 @@ class QuantAttentionFused(nn.Module):
                 self.cache.k,  # key cache
                 self.cache.v,  # value cache
                 None,  # length per sample
-                self.alibi_slopes,  # alibi slopes
+                self.alibi.slopes,  # alibi slopes
                 self.start_pos,  # timestep
                 self.rotary_dim,  # rotary embedding dimension
                 10000,  # rotary embedding base
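For reference, the bias that the new ALiBi module adds to the attention scores can be reproduced standalone. The sketch below is not part of the commit: it restates the slope and bias construction from ALiBi.gen_slopes and ALiBi.build_alibi_bias above and applies the sliced bias to a dummy score tensor, mirroring ALiBi.forward. The n_heads, max_seq_len, and seqlen values and the score shape are illustrative assumptions.

import math
import torch

def gen_slopes(n_heads, alibi_bias_max=8):
    # Same computation as ALiBi.gen_slopes in the diff: geometric per-head
    # slopes, computed for the next power of two and re-interleaved if needed.
    _n_heads = 2 ** math.ceil(math.log2(n_heads))
    m = torch.arange(1, _n_heads + 1, dtype=torch.float32)
    m = m.mul(alibi_bias_max / _n_heads)
    slopes = 1.0 / torch.pow(2, m)
    if _n_heads != n_heads:
        slopes = torch.cat([slopes[1::2], slopes[::2]])[:n_heads]
    return slopes.view(1, n_heads, 1, 1)

def build_alibi_bias(n_heads, seq_len, alibi_bias_max=8, dtype=torch.float32):
    # Relative-position offsets (1 - seq_len, ..., 0), scaled by each head's slope.
    alibi_bias = torch.arange(1 - seq_len, 1, dtype=torch.int32).view(1, 1, 1, seq_len)
    slopes = gen_slopes(n_heads, alibi_bias_max)
    alibi_bias = alibi_bias * slopes
    slopes = slopes.squeeze(0).squeeze(-1).squeeze(-1)
    return slopes.to(dtype=dtype), alibi_bias.to(dtype=dtype)

# Illustrative values, not taken from the commit.
n_heads, max_seq_len, seqlen = 8, 2048, 16
slopes, bias = build_alibi_bias(n_heads, max_seq_len)

# Dummy attention scores: (batch, n_heads, query_len, key_len).
scores = torch.randn(1, n_heads, seqlen, seqlen)
scores = scores + bias[..., :seqlen]  # what ALiBi.forward does with the cached bias
print(slopes.shape, bias.shape, scores.shape)

The (1, n_heads, 1, seq_len) bias broadcasts over the query dimension, which is why the forward pass only needs to slice the last dimension to the current sequence length.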