Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
AutoAWQ
Commits
d3625d1c
Commit
d3625d1c
authored
Sep 13, 2023
by
Casper Hansen
Browse files
Set batch size for attention shapes
parent
fdff74d6
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
14 additions
and
4 deletions
+14
-4
awq/modules/fused/attn.py
awq/modules/fused/attn.py
+9
-2
awq/modules/fused/block.py
awq/modules/fused/block.py
+5
-2
No files found.
awq/modules/fused/attn.py
View file @
d3625d1c
import
os
import
math
import
torch
import
torch.nn
as
nn
...
...
@@ -114,7 +115,8 @@ class QuantLlamaRotaryEmbedding(nn.Module):
return
query
,
key
class
QuantAttentionFused
(
nn
.
Module
):
def
__init__
(
self
,
hidden_size
,
num_heads
,
qkv_layer
,
o_proj
,
dev
,
max_seq_len
,
use_alibi
=
False
,
attention_shapes
=
None
):
def
__init__
(
self
,
hidden_size
,
num_heads
,
qkv_layer
,
o_proj
,
dev
,
max_seq_len
,
use_alibi
=
False
,
attention_shapes
=
None
):
super
().
__init__
()
self
.
hidden_size
=
hidden_size
self
.
n_local_heads
=
num_heads
...
...
@@ -123,7 +125,7 @@ class QuantAttentionFused(nn.Module):
self
.
o_proj
=
o_proj
self
.
start_pos
=
0
self
.
use_alibi
=
use_alibi
self
.
cache_batch_size
=
1
self
.
cache_batch_size
=
int
(
os
.
getenv
(
"AWQ_BATCH_SIZE"
,
"1"
))
self
.
attention_shapes
=
attention_shapes
if
attention_shapes
is
not
None
else
{
# following fastertransformer definition
"cache_v"
:
(
self
.
cache_batch_size
,
self
.
n_local_heads
,
max_seq_len
,
self
.
head_dim
,),
...
...
@@ -170,6 +172,11 @@ class QuantAttentionFused(nn.Module):
hidden_states
,
past_key_value
=
None
,
attention_mask
=
None
,
position_ids
=
None
,
output_attentions
=
False
,
use_cache
=
False
):
bsz
,
seqlen
,
_
=
hidden_states
.
shape
if
bsz
!=
self
.
cache_batch_size
:
raise
RuntimeError
(
f
"Batch size is incorrectly set - input batch size
{
bsz
}
, kv-cache batch size
{
self
.
cache_batch_size
}
. "
f
"Use: AutoAWQForCausalLM.from_quantized(batch_size=
{
bsz
}
)"
)
xqkv
=
self
.
qkv_proj
(
hidden_states
)
xqkv
=
xqkv
.
view
((
bsz
,
seqlen
)
+
self
.
attention_shapes
[
"xqkv_view"
])
...
...
awq/modules/fused/block.py
View file @
d3625d1c
import
os
import
torch.nn
as
nn
from
awq.modules.fused.attn
import
QuantAttentionFused
...
...
@@ -34,7 +35,7 @@ class FalconDecoderLayer(nn.Module):
self
.
n_heads
=
n_heads
self
.
hidden_size
=
hidden_size
self
.
new_decoder_arch
=
new_decoder_arch
attention_shapes
=
self
.
_get_attention_shapes
(
1
,
n_heads
,
max_seq_len
,
self
.
hidden_size
//
n_heads
,
new_decoder_arch
)
attention_shapes
=
self
.
_get_attention_shapes
(
n_heads
,
max_seq_len
,
self
.
hidden_size
//
n_heads
,
new_decoder_arch
)
# TODO: Falcon has ALiBi implemented but which model uses it?
self
.
attn
=
QuantAttentionFused
(
...
...
@@ -51,7 +52,9 @@ class FalconDecoderLayer(nn.Module):
self
.
mlp
=
mlp
def
_get_attention_shapes
(
self
,
batch_size
,
n_heads
,
max_seq_len
,
head_dim
,
new_decoder_arch
):
def
_get_attention_shapes
(
self
,
n_heads
,
max_seq_len
,
head_dim
,
new_decoder_arch
):
batch_size
=
int
(
os
.
getenv
(
"AWQ_BATCH_SIZE"
,
"1"
))
if
new_decoder_arch
:
kv_heads
=
8
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment