Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
AutoAWQ
Commits
c57da6b8
Unverified
Commit
c57da6b8
authored
Sep 27, 2023
by
Casper
Committed by
GitHub
Sep 27, 2023
Browse files
Merge pull request #75 from casper-hansen/fix_runtime
Fix KV cache shapes error
parents
8eb26eb2
cba9a28c
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
43 additions
and
28 deletions
+43
-28
awq/modules/fused/attn.py
awq/modules/fused/attn.py
+37
-27
awq/utils/utils.py
awq/utils/utils.py
+6
-1
No files found.
awq/modules/fused/attn.py
View file @
c57da6b8
...
@@ -80,12 +80,32 @@ class QuantAttentionFused(nn.Module):
...
@@ -80,12 +80,32 @@ class QuantAttentionFused(nn.Module):
self
.
start_pos
=
0
self
.
start_pos
=
0
self
.
use_alibi
=
use_alibi
self
.
use_alibi
=
use_alibi
self
.
cache_batch_size
=
int
(
os
.
getenv
(
"AWQ_BATCH_SIZE"
,
"1"
))
self
.
cache_batch_size
=
int
(
os
.
getenv
(
"AWQ_BATCH_SIZE"
,
"1"
))
self
.
max_seq_len
=
max_seq_len
self
.
attention_shapes
=
self
.
_get_attention_shapes
(
attention_shapes
,
max_seq_len
)
self
.
cache_v
=
(
torch
.
zeros
(
self
.
attention_shapes
[
"cache_v"
]).
to
(
dev
).
half
()
)
self
.
cache_k
=
(
torch
.
zeros
(
self
.
attention_shapes
[
"cache_k"
]).
to
(
dev
).
half
()
)
if
use_alibi
:
alibi_slopes
,
alibi_bias
=
build_alibi_bias
(
self
.
n_heads
,
max_seq_len
)
self
.
alibi_slopes
=
alibi_slopes
.
float
().
to
(
dev
)
self
.
alibi_bias
=
alibi_bias
.
float
().
to
(
dev
)
self
.
rotary_dim
=
0
self
.
is_neox
=
False
else
:
self
.
freqs_cis
=
precompute_freqs_cis
(
hidden_size
//
n_heads
,
max_seq_len
*
2
,
).
to
(
dev
)
self
.
rotary_dim
=
self
.
head_dim
self
.
alibi_slopes
=
None
self
.
is_neox
=
True
def
_get_attention_shapes
(
self
,
attention_shapes
,
max_seq_len
):
if
attention_shapes
is
not
None
:
if
attention_shapes
is
not
None
:
self
.
attention_shapes
=
attention_shapes
attention_shapes
=
attention_shapes
elif
self
.
n_kv_heads
==
0
:
elif
self
.
n_kv_heads
==
0
:
self
.
attention_shapes
=
{
attention_shapes
=
{
# following fastertransformer definition
# following fastertransformer definition
"cache_v"
:
(
self
.
cache_batch_size
,
self
.
n_heads
,
max_seq_len
,
self
.
head_dim
,),
"cache_v"
:
(
self
.
cache_batch_size
,
self
.
n_heads
,
max_seq_len
,
self
.
head_dim
,),
# 8: pack 8 fp16 in FT, if fp32 then use 4
# 8: pack 8 fp16 in FT, if fp32 then use 4
...
@@ -104,7 +124,7 @@ class QuantAttentionFused(nn.Module):
...
@@ -104,7 +124,7 @@ class QuantAttentionFused(nn.Module):
}
}
else
:
else
:
self
.
attention_shapes
=
{
attention_shapes
=
{
# following fastertransformer definition
# following fastertransformer definition
"cache_v"
:
(
self
.
cache_batch_size
,
self
.
n_kv_heads
,
max_seq_len
,
self
.
head_dim
,),
"cache_v"
:
(
self
.
cache_batch_size
,
self
.
n_kv_heads
,
max_seq_len
,
self
.
head_dim
,),
# 8: pack 8 fp16 in FT, if fp32 then use 4
# 8: pack 8 fp16 in FT, if fp32 then use 4
...
@@ -122,32 +142,11 @@ class QuantAttentionFused(nn.Module):
...
@@ -122,32 +142,11 @@ class QuantAttentionFused(nn.Module):
"single_xv_view"
:
(
self
.
n_kv_heads
,
self
.
head_dim
)
"single_xv_view"
:
(
self
.
n_kv_heads
,
self
.
head_dim
)
}
}
self
.
cache_v
=
(
return
attention_shapes
torch
.
zeros
(
self
.
attention_shapes
[
"cache_v"
]).
to
(
dev
).
half
()
)
self
.
cache_k
=
(
torch
.
zeros
(
self
.
attention_shapes
[
"cache_k"
]).
to
(
dev
).
half
()
)
if
use_alibi
:
alibi_slopes
,
alibi_bias
=
build_alibi_bias
(
self
.
n_heads
,
max_seq_len
)
self
.
alibi_slopes
=
alibi_slopes
.
float
().
to
(
dev
)
self
.
alibi_bias
=
alibi_bias
.
float
().
to
(
dev
)
self
.
rotary_dim
=
0
self
.
is_neox
=
False
else
:
self
.
freqs_cis
=
precompute_freqs_cis
(
hidden_size
//
n_heads
,
max_seq_len
*
2
,
).
to
(
dev
)
self
.
rotary_dim
=
self
.
head_dim
self
.
alibi_slopes
=
None
self
.
is_neox
=
True
def
forward
(
def
forward
(
self
,
self
,
hidden_states
,
past_key_value
=
None
,
attention_mask
=
None
,
position_ids
=
None
,
output_attentions
=
False
,
use_cache
=
False
hidden_states
:
torch
.
Tensor
,
past_key_value
=
None
,
attention_mask
=
None
,
position_ids
=
None
,
output_attentions
=
False
,
use_cache
=
False
):
):
bsz
,
seqlen
,
_
=
hidden_states
.
shape
bsz
,
seqlen
,
_
=
hidden_states
.
shape
if
bsz
!=
self
.
cache_batch_size
:
if
bsz
!=
self
.
cache_batch_size
:
...
@@ -155,6 +154,17 @@ class QuantAttentionFused(nn.Module):
...
@@ -155,6 +154,17 @@ class QuantAttentionFused(nn.Module):
f
"Batch size is incorrectly set - input batch size
{
bsz
}
, kv-cache batch size
{
self
.
cache_batch_size
}
. "
f
"Batch size is incorrectly set - input batch size
{
bsz
}
, kv-cache batch size
{
self
.
cache_batch_size
}
. "
f
"Use: AutoAWQForCausalLM.from_quantized(batch_size=
{
bsz
}
)"
f
"Use: AutoAWQForCausalLM.from_quantized(batch_size=
{
bsz
}
)"
)
)
if
self
.
start_pos
>
self
.
max_seq_len
or
self
.
start_pos
+
seqlen
>
self
.
max_seq_len
:
# Roll cache to the left
roll_len
=
self
.
start_pos
self
.
cache_v
=
torch
.
roll
(
self
.
cache_v
,
shifts
=-
roll_len
,
dims
=
2
)
self
.
cache_k
=
torch
.
roll
(
self
.
cache_k
,
shifts
=-
roll_len
,
dims
=
3
)
# Zero out the new part
self
.
cache_v
[:,
:,
-
roll_len
:,
:]
=
0
self
.
cache_k
[:,
:,
:,
-
roll_len
:,
:]
=
0
self
.
start_pos
=
0
xqkv
=
self
.
qkv_proj
(
hidden_states
)
xqkv
=
self
.
qkv_proj
(
hidden_states
)
xqkv
=
xqkv
.
view
((
bsz
,
seqlen
)
+
self
.
attention_shapes
[
"xqkv_view"
])
xqkv
=
xqkv
.
view
((
bsz
,
seqlen
)
+
self
.
attention_shapes
[
"xqkv_view"
])
...
...
awq/utils/utils.py
View file @
c57da6b8
...
@@ -60,3 +60,8 @@ def clear_memory(weight=None):
...
@@ -60,3 +60,8 @@ def clear_memory(weight=None):
del
weight
del
weight
gc
.
collect
()
gc
.
collect
()
torch
.
cuda
.
empty_cache
()
torch
.
cuda
.
empty_cache
()
def
compute_memory_used_pct
(
device
):
memory_used
=
torch
.
cuda
.
max_memory_allocated
(
device
)
/
(
1024
**
3
)
memory_pct
=
memory_used
/
(
torch
.
cuda
.
get_device_properties
(
device
).
total_memory
/
(
1024
**
3
))
*
100
return
memory_pct
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment