Commit adc5304b authored by Casper Hansen

Implement ALiBi.

parent 48be2ee2
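
ALiBi (Attention with Linear Biases) replaces positional embeddings with a per-head additive penalty on the attention scores: each head gets a fixed slope, and the score for query position i attending to key position j is shifted by -slope * (i - j). A minimal conceptual sketch of that bias (illustration only; the commit's own helpers follow in the first hunk below):

```python
import torch

def alibi_bias_sketch(n_heads: int, seq_len: int, alibi_bias_max: float = 8.0) -> torch.Tensor:
    # Slopes form a geometric sequence per head: 2 ** -(alibi_bias_max / n_heads * h), h = 1..n_heads.
    # This matches gen_slopes() below when n_heads is a power of two.
    slopes = 2.0 ** (-alibi_bias_max / n_heads * torch.arange(1, n_heads + 1, dtype=torch.float32))
    i = torch.arange(seq_len).view(seq_len, 1)      # query positions
    j = torch.arange(seq_len).view(1, seq_len)      # key positions
    distance = (i - j).clamp(min=0)                 # causal: only earlier keys are penalized
    return -slopes.view(n_heads, 1, 1) * distance   # (n_heads, seq_len, seq_len), zeros on the diagonal
```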
@@ -34,6 +34,30 @@ def apply_rotary_emb(
     xk_out = torch.view_as_real(xk_ * freqs_cis).transpose(-2, -1).flatten(3)
     return xq_out.type_as(xq), xk_out.type_as(xk)

+def gen_slopes(n_heads, alibi_bias_max=8):
+    _n_heads = 2 ** math.ceil(math.log2(n_heads))
+    m = torch.arange(1, _n_heads + 1, dtype=torch.float32)
+    m = m.mul(alibi_bias_max / _n_heads)
+    slopes = 1.0 / torch.pow(2, m)
+
+    if _n_heads != n_heads:
+        slopes = torch.concat([slopes[1::2], slopes[::2]])[:n_heads]
+
+    return slopes.view(1, n_heads, 1, 1)
+
+def build_alibi_bias(
+    n_heads, seq_len, full=False, alibi_bias_max=8, dtype=torch.float32
+):
+    alibi_bias = torch.arange(1 - seq_len, 1, dtype=torch.int32).view(1, 1, 1, seq_len)
+    if full:
+        alibi_bias = alibi_bias - torch.arange(1 - seq_len, 1, dtype=torch.int32).view(
+            1, 1, seq_len, 1
+        )
+        alibi_bias = alibi_bias.abs().mul(-1)
+
+    slopes = gen_slopes(n_heads, alibi_bias_max)
+    alibi_bias = alibi_bias * slopes
+    slopes = slopes.squeeze(0).squeeze(-1).squeeze(-1)
+    return slopes.to(dtype=dtype), alibi_bias.to(dtype=dtype)
+
 class QuantLlamaRotaryEmbedding(nn.Module):
     def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
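
A quick sanity check of the two helpers above (a usage sketch; it assumes `gen_slopes` and `build_alibi_bias` are in scope as defined in this hunk). `gen_slopes` returns one slope per head shaped `(1, n_heads, 1, 1)`, rounding non-power-of-two head counts up to the next power of two and interleaving the resulting slopes. `build_alibi_bias` with `full=False` returns the squeezed slopes `(n_heads,)` together with a single bias row of shape `(1, n_heads, 1, seq_len)`; because that row differs from the full `(seq_len, seq_len)` bias only by a per-row constant, adding it to every query row yields the same softmax weights under a causal mask.

```python
import torch

# Assumes gen_slopes / build_alibi_bias from the hunk above are in scope.
n_heads, seq_len = 8, 6

slopes_4d = gen_slopes(n_heads)                 # shape (1, 8, 1, 1)
print(slopes_4d.flatten().tolist())             # [0.5, 0.25, 0.125, ..., 0.00390625]

slopes, alibi_bias = build_alibi_bias(n_heads, seq_len)
print(slopes.shape, alibi_bias.shape)           # torch.Size([8]) torch.Size([1, 8, 1, 6])
print(alibi_bias[0, 0, 0].tolist())             # head 0: [-2.5, -2.0, -1.5, -1.0, -0.5, 0.0]
```

The module stores these as `self.alibi_slopes` (handed to the fused decode kernel) and `self.alibi_bias` (added to the prompt-phase scores), as the later hunks show.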
@@ -89,8 +113,8 @@ class QuantLlamaRotaryEmbedding(nn.Module):
         )
         return query, key

-class QuantLlamaAttentionFused(nn.Module):
-    def __init__(self, hidden_size, num_heads, qkv_layer, o_proj, dev, max_position_embeddings):
+class QuantAttentionFused(nn.Module):
+    def __init__(self, hidden_size, num_heads, qkv_layer, o_proj, dev, max_seq_len, use_alibi=False):
         super().__init__()
         self.hidden_size = hidden_size
         self.n_local_heads = num_heads
@@ -98,44 +122,35 @@ class QuantLlamaAttentionFused(nn.Module):
         self.qkv_proj = qkv_layer
         self.o_proj = o_proj
         self.start_pos = 0
-        self.use_sdpa_torch = False
+        self.use_alibi = use_alibi
+        self.cache_batch_size = 1

         # following fastertransformer definition
         self.cache_v = (
             torch.zeros(
-                ( 1, self.n_local_heads, max_position_embeddings, self.head_dim, )
+                ( self.cache_batch_size, self.n_local_heads, max_seq_len, self.head_dim, )
             ).to(dev).half()
         )
         # 8: pack 8 fp16 in FT, if fp32 then use 4
         self.cache_k = (
             torch.zeros(
-                ( 1, self.n_local_heads, self.head_dim // 8, max_position_embeddings, 8, )
+                ( self.cache_batch_size, self.n_local_heads, self.head_dim // 8, max_seq_len, 8, )
             ).to(dev).half()
         )

-        self.freqs_cis = precompute_freqs_cis(
-            hidden_size // num_heads,
-            max_position_embeddings * 2,
-        ).to(dev)
+        if use_alibi:
+            alibi_slopes, alibi_bias = build_alibi_bias(self.n_local_heads, max_seq_len)
+            self.alibi_slopes = alibi_slopes.float().to(dev)
+            self.alibi_bias = alibi_bias.float().to(dev)
+            self.rotary_dim = 0
+        else:
+            self.freqs_cis = precompute_freqs_cis(
+                hidden_size // num_heads,
+                max_seq_len * 2,
+            ).to(dev)
+            self.rotary_dim = 0
+            self.alibi_slopes = None

-    def _multi_query_attention_torch(self, query, key, value, batch_size, seqlen, use_cache, past_key_value):
-        # faster prompt processing
-        query = query.view(batch_size, seqlen, self.n_local_heads, self.head_dim).transpose(1, 2)
-        key = key.view(batch_size, seqlen, self.n_local_heads, self.head_dim).transpose(1, 2)
-        value = value.view(batch_size, seqlen, self.n_local_heads, self.head_dim).transpose(1, 2)
-
-        if use_cache:
-            key = key.contiguous()
-            value = value.contiguous()
-            query = query.contiguous()
-
-        output = F.scaled_dot_product_attention(query, key, value, is_causal=past_key_value is None)
-        del query, key, value
-        output = output.transpose(1, 2).reshape(batch_size, seqlen, self.hidden_size)
-        return output

     def forward(
         self,
@@ -172,15 +187,17 @@ class QuantLlamaAttentionFused(nn.Module):
             values = xv
             past_key_value = (xk, xv) if use_cache else None

-            if self.use_sdpa_torch:
-                output = self._multi_query_attention_torch(xq, xk, xv, bsz, seqlen, True, past_key_value)
-            else:
-                xq = xq.transpose(1, 2)
-                keys = keys.transpose(1, 2)
-                values = values.transpose(1, 2)
-                scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
-                if attention_mask is not None:
-                    scores = scores + attention_mask  # (bs, n_local_heads, slen, cache_len + slen)
-                scores = F.softmax(scores.float(), dim=-1).type_as(xq)
-                output = torch.matmul(scores, values)  # (bs, n_local_heads, slen, head_dim)
-                output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
+            xq = xq.transpose(1, 2)
+            keys = keys.transpose(1, 2)
+            values = values.transpose(1, 2)
+            scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+            if self.use_alibi:
+                scores += self.alibi_bias[..., :seqlen]
+
+            if attention_mask is not None:
+                scores = scores + attention_mask  # (bs, n_local_heads, slen, cache_len + slen)
+            scores = F.softmax(scores.float(), dim=-1).type_as(xq)
+            output = torch.matmul(scores, values)  # (bs, n_local_heads, slen, head_dim)
+            output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
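
The prompt path now adds `self.alibi_bias[..., :seqlen]` to the raw scores before the attention mask and softmax. For reference, the same prompt-phase result can be reproduced with `torch.nn.functional.scaled_dot_product_attention` by folding the ALiBi bias and an additive causal mask into a single `attn_mask` (a standalone sketch on random tensors, not the module's code; `is_causal=True` is not used since the mask is supplied explicitly):

```python
import math
import torch
import torch.nn.functional as F

bsz, n_heads, seqlen, head_dim = 2, 8, 5, 64
q = torch.randn(bsz, n_heads, seqlen, head_dim)
k = torch.randn(bsz, n_heads, seqlen, head_dim)
v = torch.randn(bsz, n_heads, seqlen, head_dim)

# Per-head ALiBi bias for the last query row, broadcast over rows
# (same layout as build_alibi_bias(..., full=False) above).
slopes = 1.0 / torch.pow(2, torch.arange(1, n_heads + 1) * (8.0 / n_heads))
alibi_bias = torch.arange(1 - seqlen, 1).view(1, 1, 1, seqlen) * slopes.view(1, n_heads, 1, 1)

causal = torch.full((seqlen, seqlen), float("-inf")).triu(1)   # -inf above the diagonal

# Explicit path, mirroring the diff: matmul -> +alibi -> +mask -> softmax -> matmul.
scores = q @ k.transpose(2, 3) / math.sqrt(head_dim)
ref = torch.softmax(scores + alibi_bias + causal, dim=-1) @ v

# SDPA path: the additive attn_mask carries both the causal mask and the ALiBi bias.
out = F.scaled_dot_product_attention(q, k, v, attn_mask=alibi_bias + causal)

print(torch.allclose(ref, out, atol=1e-5))  # True
```

The fused module keeps the explicit matmul/softmax form for the prompt phase; the kernel-side `alibi_slopes` handle the token-by-token phase shown in the next hunk.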
@@ -190,17 +207,17 @@ class QuantLlamaAttentionFused(nn.Module):
             xv = xv[:, 0, :, :]
             past_key_value = (xk, xv) if use_cache else None
             output = awq_inference_engine.single_query_attention(
-                xq,
-                xk,
-                xv,
-                self.cache_k,
-                self.cache_v,
-                None,
-                None,
-                self.start_pos,
-                self.head_dim,
-                10000,
-                True,
+                xq, # query
+                xk, # key
+                xv, # value
+                self.cache_k, # key cache
+                self.cache_v, # value cache
+                None, # length per sample
+                self.alibi_slopes, # alibi slopes
+                self.start_pos, # timestep
+                self.rotary_dim, # rotary embedding dimension
+                10000, # rotary embedding base
+                False, # is neox
             )
             output = output.reshape(bsz, 1, -1)
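
In the token-by-token branch the bias is not added in Python; the per-head slopes are passed straight to `awq_inference_engine.single_query_attention`, with the rotary dimension and the neox flag set from the ALiBi configuration, so the kernel applies the linear bias itself. Roughly, an ALiBi-aware single-query step computes the following (a pure-PyTorch illustration with a hypothetical helper name and a plain `(bsz, n_heads, t+1, head_dim)` cache layout, not the FasterTransformer-packed cache or the actual kernel):

```python
import math
import torch

def single_query_alibi_attention_sketch(q, k_cache, v_cache, alibi_slopes, timestep):
    """Illustrative ALiBi decode step for one new token (not the fused kernel).

    q:                (bsz, n_heads, head_dim)                query at position `timestep`
    k_cache, v_cache: (bsz, n_heads, timestep + 1, head_dim)  keys/values for positions 0..timestep
    alibi_slopes:     (n_heads,)                              e.g. the first output of build_alibi_bias
    """
    bsz, n_heads, head_dim = q.shape
    key_pos = torch.arange(timestep + 1)                             # j = 0..timestep
    bias = -alibi_slopes.view(1, n_heads, 1) * (timestep - key_pos)  # penalize distant keys

    scores = torch.einsum("bhd,bhjd->bhj", q, k_cache) / math.sqrt(head_dim)
    probs = torch.softmax(scores + bias, dim=-1)
    return torch.einsum("bhj,bhjd->bhd", probs, v_cache)             # (bsz, n_heads, head_dim)
```

The real kernel additionally handles the packed fp16 key-cache layout (`head_dim // 8 x seq x 8`) and optional in-kernel rotary embedding, both visible in the `__init__` hunk above.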