OpenDAS / AutoAWQ, commit adc5304b
Authored Sep 09, 2023 by Casper Hansen
Parent: 48be2ee2

Implement ALiBi.
Showing 1 changed file with 67 additions and 50 deletions.

awq/modules/fused/attn.py (+67, -50)
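For context: ALiBi (Attention with Linear Biases) drops rotary position embeddings and instead adds a fixed, per-head linear penalty to the attention scores, so a head with slope m contributes m * (j - i) to the score of query position i attending to key position j, which is non-positive under a causal mask. A minimal, self-contained sketch of that idea with made-up sizes follows; it is an illustration only, not code from this commit.

import torch

seq_len = 5      # illustrative
slope = 0.25     # one head's slope; the commit derives the real ones in gen_slopes()

pos = torch.arange(seq_len)
distance = pos[None, :] - pos[:, None]          # (seq_len, seq_len), entry (i, j) = j - i

# ALiBi adds slope * (j - i); under a causal mask only j <= i survives, so the bias
# is 0 on the diagonal and grows more negative the further away the key is.
bias = slope * distance
causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
print(bias.masked_fill(~causal, float("-inf")))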
@@ -34,6 +34,30 @@ def apply_rotary_emb(
     xk_out = torch.view_as_real(xk_ * freqs_cis).transpose(-2, -1).flatten(3)
     return xq_out.type_as(xq), xk_out.type_as(xk)
 
 
+def gen_slopes(n_heads, alibi_bias_max=8):
+    _n_heads = 2 ** math.ceil(math.log2(n_heads))
+    m = torch.arange(1, _n_heads + 1, dtype=torch.float32)
+    m = m.mul(alibi_bias_max / _n_heads)
+    slopes = 1.0 / torch.pow(2, m)
+
+    if _n_heads != n_heads:
+        slopes = torch.concat([slopes[1::2], slopes[::2]])[:n_heads]
+
+    return slopes.view(1, n_heads, 1, 1)
+
+
+def build_alibi_bias(n_heads, seq_len, full=False, alibi_bias_max=8, dtype=torch.float32):
+    alibi_bias = torch.arange(1 - seq_len, 1, dtype=torch.int32).view(1, 1, 1, seq_len)
+    if full:
+        alibi_bias = alibi_bias - torch.arange(1 - seq_len, 1, dtype=torch.int32).view(1, 1, seq_len, 1)
+        alibi_bias = alibi_bias.abs().mul(-1)
+
+    slopes = gen_slopes(n_heads, alibi_bias_max)
+    alibi_bias = alibi_bias * slopes
+    slopes = slopes.squeeze(0).squeeze(-1).squeeze(-1)
+
+    return slopes.to(dtype=dtype), alibi_bias.to(dtype=dtype)
+
+
 class QuantLlamaRotaryEmbedding(nn.Module):
     def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
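A quick shape check of the two helpers added above. The snippet assumes gen_slopes and build_alibi_bias are in scope (for example pasted into a REPL together with import math and import torch); the sizes are illustrative, not taken from the commit.

import torch

n_heads, seq_len = 12, 6                          # illustrative sizes

slopes = gen_slopes(n_heads)                      # per-head slopes, broadcastable over scores
alibi_slopes, alibi_bias = build_alibi_bias(n_heads, seq_len)

print(slopes.shape)        # torch.Size([1, 12, 1, 1])
print(alibi_slopes.shape)  # torch.Size([12])  -> squeezed slopes, later handed to the decode kernel
print(alibi_bias.shape)    # torch.Size([1, 12, 1, 6])  -> added to the scores in forward()
print(alibi_bias[0, 0, 0]) # non-positive values, 0 at the last (current) position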
@@ -89,8 +113,8 @@ class QuantLlamaRotaryEmbedding(nn.Module):
         )
         return query, key
 
-class QuantLlamaAttentionFused(nn.Module):
-    def __init__(self, hidden_size, num_heads, qkv_layer, o_proj, dev, max_position_embeddings):
+class QuantAttentionFused(nn.Module):
+    def __init__(self, hidden_size, num_heads, qkv_layer, o_proj, dev, max_seq_len, use_alibi=False):
         super().__init__()
         self.hidden_size = hidden_size
         self.n_local_heads = num_heads
@@ -98,44 +122,35 @@ class QuantLlamaAttentionFused(nn.Module):
         self.qkv_proj = qkv_layer
         self.o_proj = o_proj
         self.start_pos = 0
-        self.use_sdpa_torch = False
+        self.use_alibi = use_alibi
+        self.cache_batch_size = 1
 
         # following fastertransformer definition
         self.cache_v = (
             torch.zeros(
-                (1, self.n_local_heads, max_position_embeddings, self.head_dim,)
+                (self.cache_batch_size, self.n_local_heads, max_seq_len, self.head_dim,)
             ).to(dev).half()
         )
         # 8: pack 8 fp16 in FT, if fp32 then use 4
         self.cache_k = (
             torch.zeros(
-                (1, self.n_local_heads, self.head_dim // 8, max_position_embeddings, 8,)
+                (self.cache_batch_size, self.n_local_heads, self.head_dim // 8, max_seq_len, 8,)
            ).to(dev).half()
         )
 
-        self.freqs_cis = precompute_freqs_cis(
-            hidden_size // num_heads,
-            max_position_embeddings * 2,
-        ).to(dev)
-
-    def _multi_query_attention_torch(self, query, key, value, batch_size, seqlen, use_cache, past_key_value):
-        # faster prompt processing
-        query = query.view(batch_size, seqlen, self.n_local_heads, self.head_dim).transpose(1, 2)
-        key = key.view(batch_size, seqlen, self.n_local_heads, self.head_dim).transpose(1, 2)
-        value = value.view(batch_size, seqlen, self.n_local_heads, self.head_dim).transpose(1, 2)
-
-        if use_cache:
-            key = key.contiguous()
-            value = value.contiguous()
-            query = query.contiguous()
-
-        output = F.scaled_dot_product_attention(query, key, value, is_causal=past_key_value is None)
-        del query, key, value
-
-        output = output.transpose(1, 2).reshape(batch_size, seqlen, self.hidden_size)
-
-        return output
+        if use_alibi:
+            alibi_slopes, alibi_bias = build_alibi_bias(self.n_local_heads, max_seq_len)
+            self.alibi_slopes = alibi_slopes.float().to(dev)
+            self.alibi_bias = alibi_bias.float().to(dev)
+            self.rotary_dim = 0
+        else:
+            self.freqs_cis = precompute_freqs_cis(
+                hidden_size // num_heads,
+                max_seq_len * 2,
+            ).to(dev)
+            self.rotary_dim = 0
+            self.alibi_slopes = None
 
     def forward(
         self,
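The cache shapes above follow the FasterTransformer convention noted in the comments: values are cached as (batch, heads, max_seq_len, head_dim), while keys split head_dim into groups of 8 fp16 values so the kernel can load them vectorized. Below is a sketch of how an unpacked key tensor would map onto cache_k's shape, assuming the usual FT interpretation of that layout; the mapping is an assumption, not something the commit spells out.

import torch

batch, n_heads, max_seq_len, head_dim = 1, 8, 16, 64   # illustrative sizes

# Unpacked keys as produced by the QKV projection: (batch, heads, seq, head_dim)
k = torch.randn(batch, n_heads, max_seq_len, head_dim, dtype=torch.float16)

# Assumed FT-style packed layout matching cache_k above: (batch, heads, head_dim // 8, seq, 8),
# i.e. 8 consecutive fp16 values of each head dimension stay contiguous in memory.
packed = (
    k.view(batch, n_heads, max_seq_len, head_dim // 8, 8)
     .permute(0, 1, 3, 2, 4)
     .contiguous()
)
print(packed.shape)  # torch.Size([1, 8, 8, 16, 8])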
@@ -172,15 +187,17 @@ class QuantLlamaAttentionFused(nn.Module):
             values = xv
 
             past_key_value = (xk, xv) if use_cache else None
-            if self.use_sdpa_torch:
-                output = self._multi_query_attention_torch(xq, xk, xv, bsz, seqlen, True, past_key_value)
-            else:
-                xq = xq.transpose(1, 2)
-                keys = keys.transpose(1, 2)
-                values = values.transpose(1, 2)
-                scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
-                if attention_mask is not None:
-                    scores = scores + attention_mask  # (bs, n_local_heads, slen, cache_len + slen)
-                scores = F.softmax(scores.float(), dim=-1).type_as(xq)
-                output = torch.matmul(scores, values)  # (bs, n_local_heads, slen, head_dim)
-                output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
+            xq = xq.transpose(1, 2)
+            keys = keys.transpose(1, 2)
+            values = values.transpose(1, 2)
+            scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+            if self.use_alibi:
+                scores += self.alibi_bias[..., :seqlen]
+
+            if attention_mask is not None:
+                scores = scores + attention_mask  # (bs, n_local_heads, slen, cache_len + slen)
+            scores = F.softmax(scores.float(), dim=-1).type_as(xq)
+            output = torch.matmul(scores, values)  # (bs, n_local_heads, slen, head_dim)
+            output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
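A standalone sketch of what the prompt-processing branch now computes with the precomputed bias. The shapes mirror the code above; the concrete sizes are made up and the slope schedule is simplified relative to gen_slopes.

import math
import torch
import torch.nn.functional as F

bsz, n_heads, seqlen, head_dim = 1, 4, 8, 32   # illustrative
max_seq_len = 16

xq = torch.randn(bsz, n_heads, seqlen, head_dim)
keys = torch.randn(bsz, n_heads, seqlen, head_dim)
values = torch.randn(bsz, n_heads, seqlen, head_dim)

# Precomputed once in __init__ via build_alibi_bias: shape (1, n_heads, 1, max_seq_len).
slopes = (1.0 / torch.pow(2, torch.arange(1, n_heads + 1, dtype=torch.float32))).view(1, n_heads, 1, 1)
alibi_bias = slopes * torch.arange(1 - max_seq_len, 1, dtype=torch.float32).view(1, 1, 1, max_seq_len)

scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(head_dim)
scores = scores + alibi_bias[..., :seqlen]     # same slicing as in forward()
scores = F.softmax(scores.float(), dim=-1).type_as(xq)
output = torch.matmul(scores, values)
print(output.shape)  # torch.Size([1, 4, 8, 32])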
@@ -190,17 +207,17 @@ class QuantLlamaAttentionFused(nn.Module):
             xv = xv[:, 0, :, :]
             past_key_value = (xk, xv) if use_cache else None
             output = awq_inference_engine.single_query_attention(
-                xq,
-                xk,
-                xv,
-                self.cache_k,
-                self.cache_v,
-                None,
-                None,
-                self.start_pos,
-                self.head_dim,
-                10000,
-                True,
+                xq, # query
+                xk, # key
+                xv, # value
+                self.cache_k, # key cache
+                self.cache_v, # value cache
+                None, # length per sample
+                self.alibi_slopes, # alibi slopes
+                self.start_pos, # timestep
+                self.rotary_dim, # rotary embedding dimension
+                10000, # rotary embedding base
+                False, # is neox
             )
             output = output.reshape(bsz, 1, -1)
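In the single-token decode path the bias is not added in Python; only self.alibi_slopes is handed to the fused kernel (with rotary_dim passed as the rotary embedding dimension and the is-neox flag set to False), and the kernel is expected to apply the linear penalty itself. A rough sketch of the equivalent per-step bias, assuming the standard ALiBi formulation rather than the kernel's actual internals:

import torch

n_heads, start_pos = 4, 10                      # illustrative: decoding the token at position start_pos
slopes = torch.tensor([0.25, 0.0625, 0.015625, 0.00390625])  # equals gen_slopes(4).flatten()

# Bias of the single query (at start_pos) against key positions 0..start_pos, per head:
key_pos = torch.arange(start_pos + 1)
bias = slopes[:, None] * (key_pos[None, :] - start_pos)      # (n_heads, start_pos + 1), all <= 0
print(bias[0])  # head 0: most negative for the oldest keys, 0 for the current token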