OpenDAS / AutoAWQ · Commits · bf76e108

Commit bf76e108, authored Sep 20, 2023 by Casper Hansen
Parent: 63a12504

Removed unused module

Showing 1 changed file with 0 additions and 54 deletions:
awq/modules/fused/attn.py (+0, -54)
@@ -60,60 +60,6 @@ def build_alibi_bias(
     return slopes.to(dtype=dtype), alibi_bias.to(dtype=dtype)
 
-
-class QuantLlamaRotaryEmbedding(nn.Module):
-    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
-        super().__init__()
-
-        self.dim = dim
-        self.max_position_embeddings = max_position_embeddings
-        self.base = base
-        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
-        self.register_buffer("inv_freq", inv_freq)
-
-        # Build here to make `torch.jit.trace` work.
-        self._set_cos_sin_cache(
-            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype(),
-        )
-
-    def _set_cos_sin_cache(self, seq_len, device, dtype):
-        self.max_seq_len_cached = seq_len
-        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
-
-        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
-        # Different from paper, but it uses a different permutation in order to obtain the same calculation
-        emb = torch.cat((freqs, freqs), dim=-1)
-
-        cos = freqs.cos()
-        sin = freqs.sin()
-        cache = torch.cat((cos, sin), dim=-1)
-        self.register_buffer("cos_sin_cache", cache.half(), persistent=False)
-
-    def forward(
-        self,
-        query: torch.Tensor,
-        key: torch.Tensor,
-        positions: torch.Tensor,
-    ):
-        # Apply rotary embedding to the query and key before passing them
-        # to the attention op.
-        # print(positions.shape, query.shape, key.shape, self.cos_sin_cache.shape)
-        query = query.contiguous()
-        key = key.contiguous()
-        awq_inference_engine.rotary_embedding_neox(
-            positions,
-            query,
-            key,
-            self.dim,
-            self.cos_sin_cache
-        )
-        return query, key
-
-
 class QuantAttentionFused(nn.Module):
     def __init__(self, hidden_size, num_heads, qkv_layer, o_proj, dev, max_seq_len, use_alibi=False, attention_shapes=None):
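For context, the removed QuantLlamaRotaryEmbedding only built a [cos | sin] cache and delegated the actual rotation to the fused awq_inference_engine.rotary_embedding_neox kernel. The sketch below is a minimal pure-PyTorch reference (not from this repository) of a NeoX-style, rotate-half rotary application over that cache layout; the helper names build_cos_sin_cache and apply_rotary_neox are hypothetical, and the in-kernel convention is an assumption rather than something taken from this diff.

    import torch

    def build_cos_sin_cache(dim, max_seq_len, base=10000, device=None, dtype=torch.float16):
        # Same cache layout as the removed _set_cos_sin_cache: [cos | sin], each half of width dim // 2.
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device).float() / dim))
        t = torch.arange(max_seq_len, device=device, dtype=inv_freq.dtype)
        freqs = torch.einsum("i,j->ij", t, inv_freq)            # (max_seq_len, dim // 2)
        return torch.cat((freqs.cos(), freqs.sin()), dim=-1).to(dtype)

    def apply_rotary_neox(x, positions, cos_sin_cache, dim):
        # x: (num_tokens, num_heads, dim); positions: (num_tokens,)
        # NeoX / rotate-half convention (assumed): rotate pairs (x_i, x_{i + dim//2}).
        cos, sin = cos_sin_cache[positions].to(x.dtype).chunk(2, dim=-1)
        cos, sin = cos.unsqueeze(-2), sin.unsqueeze(-2)         # broadcast over heads
        x1, x2 = x[..., : dim // 2], x[..., dim // 2:]
        return torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1)

    # Example: rotate queries for 5 tokens, 8 heads, head_dim 64.
    cache = build_cos_sin_cache(dim=64, max_seq_len=2048)
    q = torch.randn(5, 8, 64, dtype=torch.float16)
    positions = torch.arange(5)
    q_rot = apply_rotary_neox(q, positions, cache, dim=64)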