OpenDAS / AutoAWQ / Commits

Commit 428504e4
authored Oct 06, 2023 by Casper Hansen
parent 7a3d06d6

Create ALiBi module

Showing 1 changed file with 35 additions and 28 deletions.

awq/modules/fused/attn.py  (+35, -28)
@@ -37,29 +37,38 @@ def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor
     xk_out = torch.view_as_real(xk_ * freqs_cis).transpose(-2, -1).flatten(3)
     return xq_out.type_as(xq), xk_out.type_as(xk)
 
-def gen_slopes(n_heads, alibi_bias_max=8):
-    _n_heads = 2 ** math.ceil(math.log2(n_heads))
-    m = torch.arange(1, _n_heads + 1, dtype=torch.float32)
-    m = m.mul(alibi_bias_max / _n_heads)
-    slopes = 1.0 / torch.pow(2, m)
-
-    if _n_heads != n_heads:
-        slopes = torch.concat([slopes[1::2], slopes[::2]])[:n_heads]
-
-    return slopes.view(1, n_heads, 1, 1)
-
-def build_alibi_bias(
-    n_heads, seq_len, full=False, alibi_bias_max=8, dtype=torch.float32
-):
-    alibi_bias = torch.arange(1 - seq_len, 1, dtype=torch.int32).view(1, 1, 1, seq_len)
-    if full:
-        alibi_bias = alibi_bias - torch.arange(1 - seq_len, 1, dtype=torch.int32).view(
-            1, 1, seq_len, 1
-        )
-        alibi_bias = alibi_bias.abs().mul(-1)
-    slopes = gen_slopes(n_heads, alibi_bias_max)
-    alibi_bias = alibi_bias * slopes
-    slopes = slopes.squeeze(0).squeeze(-1).squeeze(-1)
-    return slopes.to(dtype=dtype), alibi_bias.to(dtype=dtype)
+class ALiBi(nn.Module):
+    def __init__(self, n_heads, max_seq_len, device, alibi_bias_max=8):
+        super(ALiBi, self).__init__()
+
+        # Initialize ALiBi slopes and bias
+        slopes, bias = self.build_alibi_bias(n_heads, max_seq_len, alibi_bias_max=alibi_bias_max)
+        self.slopes = nn.Parameter(slopes.float().to(device), requires_grad=False)
+        self.bias = nn.Parameter(bias.float().to(device), requires_grad=False)
+
+    @staticmethod
+    def gen_slopes(n_heads, alibi_bias_max=8):
+        _n_heads = 2 ** math.ceil(math.log2(n_heads))
+        m = torch.arange(1, _n_heads + 1, dtype=torch.float32)
+        m = m.mul(alibi_bias_max / _n_heads)
+        slopes = 1.0 / torch.pow(2, m)
+
+        if _n_heads != n_heads:
+            slopes = torch.cat([slopes[1::2], slopes[::2]])[:n_heads]
+
+        return slopes.view(1, n_heads, 1, 1)
+
+    @staticmethod
+    def build_alibi_bias(n_heads, seq_len, alibi_bias_max=8, dtype=torch.float32):
+        alibi_bias = torch.arange(1 - seq_len, 1, dtype=torch.int32).view(1, 1, 1, seq_len)
+        slopes = ALiBi.gen_slopes(n_heads, alibi_bias_max)
+        alibi_bias = alibi_bias * slopes
+        slopes = slopes.squeeze(0).squeeze(-1).squeeze(-1)
+        return slopes.to(dtype=dtype), alibi_bias.to(dtype=dtype)
+
+    def forward(self, scores, seqlen):
+        scores += self.bias[..., :seqlen]
+        return scores
 
 def get_attention_shapes(attention_shapes, max_seq_len, cache_batch_size, n_heads, n_kv_heads, head_dim):
     if attention_shapes is not None:
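
For reference, a minimal usage sketch of the new module (not part of the diff). It assumes ALiBi is importable as awq.modules.fused.attn.ALiBi and only checks the shapes of the registered slopes/bias and of the biased scores:

    import torch
    from awq.modules.fused.attn import ALiBi  # import path assumed from the file above

    n_heads, max_seq_len, seqlen = 12, 128, 16
    alibi = ALiBi(n_heads, max_seq_len, "cpu")

    # Dummy attention scores: (batch, n_heads, q_len, kv_len)
    scores = torch.zeros(1, n_heads, seqlen, seqlen)
    scores = alibi.forward(scores, seqlen)

    print(alibi.slopes.shape)  # torch.Size([12]): one slope per head
    print(alibi.bias.shape)    # torch.Size([1, 12, 1, 128]): bias precomputed up to max_seq_len
    print(scores.shape)        # torch.Size([1, 12, 16, 16]): sliced bias broadcast over query rows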
@@ -131,9 +140,7 @@ class QuantAttentionFused(nn.Module):
         )
         if use_alibi:
-            alibi_slopes, alibi_bias = build_alibi_bias(self.n_heads, max_seq_len)
-            self.alibi_slopes = alibi_slopes.float().to(dev)
-            self.alibi_bias = alibi_bias.float().to(dev)
+            self.alibi = ALiBi(n_heads, max_seq_len, dev)
             self.rotary_dim = 0
             self.is_neox = False
         else:
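
A quick equivalence check, written as a sketch under the assumption that this one-line swap is behavior-preserving (the commit itself does not include such a test): the module's buffers should match what the removed build_alibi_bias(self.n_heads, max_seq_len) call produced for full=False.

    import torch
    from awq.modules.fused.attn import ALiBi  # import path assumed

    n_heads, max_seq_len = 8, 64
    alibi = ALiBi(n_heads, max_seq_len, "cpu")

    # Recompute what the removed helper returned for full=False.
    slopes = ALiBi.gen_slopes(n_heads)  # shape (1, n_heads, 1, 1)
    bias = torch.arange(1 - max_seq_len, 1, dtype=torch.int32).view(1, 1, 1, max_seq_len) * slopes

    assert torch.allclose(alibi.slopes, slopes.flatten())
    assert torch.allclose(alibi.bias, bias.float())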
@@ -199,7 +206,7 @@ class QuantAttentionFused(nn.Module):
             scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
             if self.use_alibi:
-                scores += self.alibi_bias[..., :seqlen]
+                scores = self.alibi.forward(scores, seqlen)
 
             if attention_mask is not None:
                 scores = scores + attention_mask  # (bs, n_local_heads, slen, cache_len + slen)
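
To make the effect of self.alibi.forward on the scores concrete, a small illustration using the slopes produced for n_heads=4 with the default alibi_bias_max=8 (example values, not taken from the commit):

    import torch
    from awq.modules.fused.attn import ALiBi  # import path assumed

    alibi = ALiBi(n_heads=4, max_seq_len=8, device="cpu")

    print(alibi.slopes)
    # [0.25, 0.0625, 0.015625, 0.00390625], i.e. 2 ** (-8 * (h + 1) / n_heads) for h = 0..3

    print(alibi.bias[0, 0, 0])
    # [-1.75, -1.5, -1.25, -1.0, -0.75, -0.5, -0.25, 0.0] for head 0:
    # each row of that head's scores gets slope * (key_position - (max_seq_len - 1)) added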
@@ -219,7 +226,7 @@ class QuantAttentionFused(nn.Module):
                 self.cache.k,  # key cache
                 self.cache.v,  # value cache
                 None,  # length per sample
-                self.alibi_slopes,  # alibi slopes
+                self.alibi.slopes,  # alibi slopes
                 self.start_pos,  # timestep
                 self.rotary_dim,  # rotary embedding dimension
                 10000,  # rotary embedding base