OpenDAS / AutoAWQ · Commit 3eda6562 (unverified)

Fix performance regression (#148)

Authored Nov 03, 2023 by Casper; committed by GitHub on Nov 03, 2023. Parent: bf64abd8.

Showing 4 changed files with 33 additions and 16 deletions (+33 −16):

awq/modules/fused/attn.py    +10 −4
awq/modules/fused/cache.py   +21 −10
examples/benchmark.py        +1 −1
setup.py                     +1 −1
awq/modules/fused/attn.py

```diff
@@ -107,7 +107,7 @@ class QuantAttentionFused(nn.Module):
         )
         # cache store that rolls cache
         self.cache = WindowedCache(
-            self.attention_shapes["cache_v"], self.attention_shapes["cache_k"], dev
+            self.attention_shapes["cache_v"], self.attention_shapes["cache_k"], self.max_seq_len, dev
         )

         if use_alibi:
@@ -128,9 +128,14 @@ class QuantAttentionFused(nn.Module):
                 f"Use: AutoAWQForCausalLM.from_quantized(batch_size={bsz})"
             )

-        if self.start_pos > self.max_seq_len or self.start_pos + seqlen > self.max_seq_len:
-            excess_length = self.start_pos + seqlen - self.max_seq_len
-            self.start_pos = self.cache.roll_kv(excess_length, self.start_pos)
+        will_cache_be_exceeded = self.start_pos + seqlen > self.max_seq_len
+
+        # Reset and avoid retaining state when processing context
+        if will_cache_be_exceeded and seqlen > 1:
+            self.start_pos = self.cache.roll_kv_n_steps(self.start_pos, n=self.start_pos)
+        # Slowly roll out old tokens without performance hit if exceeded during decoding
+        elif will_cache_be_exceeded and seqlen == 1:
+            self.start_pos = self.cache.roll_kv_n_steps(self.start_pos, n=100)

         xqkv = self.qkv_proj(hidden_states)
         xqkv = xqkv.view((bsz, seqlen) + self.attention_shapes["xqkv_view"])
@@ -158,6 +163,7 @@ class QuantAttentionFused(nn.Module):
         self.cache.update_kv(values_store, keys_store, bsz, self.start_pos, seqlen)

+        # Only necessary to retrieve from cache when we are not processing context
         if seqlen == 1:
             xv, xk = self.cache.get_kv(bsz, self.start_pos, seqlen, self.head_dim)
```
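For context on the hunks above: the removed `roll_kv` path shifted the cache by the exact overflow length on every decoding step once `max_seq_len` was exceeded, so every generated token paid a cache-wide copy. The new `roll_kv_n_steps` path rolls 100 slots at a time during decoding, amortizing that copy across many tokens, and resets the window when a fresh context is processed. The sketch below illustrates only the amortization pattern; the tensor shapes and the `roll_n_steps` helper are illustrative, not AutoAWQ's implementation.

```python
import torch

# Minimal sketch of amortized cache rolling (illustrative shapes, not AutoAWQ code).
max_seq_len, n_kv_heads, head_dim = 512, 8, 64
cache = torch.zeros(1, n_kv_heads, max_seq_len, head_dim)  # [batch, heads, seq, head_dim]
start_pos = max_seq_len  # pretend the window is already full

def roll_n_steps(cache: torch.Tensor, start_pos: int, n: int):
    """Shift the cache n slots to the left and zero the freed tail."""
    n = min(n, cache.shape[2])
    cache = torch.roll(cache, shifts=-n, dims=2)
    cache[:, :, -n:, :] = 0
    return cache, start_pos - n

# Decode loop: instead of rolling by one slot per token (a full copy every step),
# roll 100 slots once and then write the next 100 tokens without any copies.
for _ in range(10):
    if start_pos + 1 > max_seq_len:
        cache, start_pos = roll_n_steps(cache, start_pos, n=100)
    # ... write the new token's K/V at index start_pos ...
    start_pos += 1
```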
awq/modules/fused/cache.py

```diff
 import torch

 class WindowedCache:
-    def __init__(self, cache_v_shape, cache_k_shape, device):
+    def __init__(self, cache_v_shape, cache_k_shape, max_seq_len, device):
         """
         The window size is the same as the max_new_tokens. The window will
         automatically roll once max_new_tokens is exceeded.
@@ -10,8 +10,12 @@ class WindowedCache:
         self.v = torch.zeros(cache_v_shape).to(device).half()
         # [batch_size, n_kv_heads, head_dim // pack_factor, max_seq_len, pack_factor]
         self.k = torch.zeros(cache_k_shape).to(device).half()
+        self.max_seq_len = max_seq_len

     def get_kv(self, batch_size, start_pos, seqlen, head_dim):
         """
         Gets the key-value store in correct shapes.
         """
         xv = self.v[:batch_size, :, :start_pos + seqlen, :].transpose(1, 2).contiguous()
         xk = self.k[:batch_size, :, :, :start_pos + seqlen, :].transpose(2, 3).contiguous()
         xk = xk.reshape(xk.shape[:-2] + (head_dim,)).transpose(1, 2).contiguous()
@@ -19,19 +23,26 @@ class WindowedCache:
         return xv, xk

     def update_kv(self, values_store, keys_store, batch_size, start_pos, seqlen):
         """
         Updates the values in the key-value store.
         """
         self.v[:batch_size, :, start_pos : start_pos + seqlen, :] = values_store
         self.k[:batch_size, :, :, start_pos : start_pos + seqlen, :] = keys_store

-    def roll_kv(self, roll_len, start_pos):
-        # Roll only the necessary part of the cache to the left
-        self.v[:, :, :-roll_len, :] = self.v[:, :, roll_len:, :]
-        self.k[:, :, :, :-roll_len, :] = self.k[:, :, :, roll_len:, :]
-
-        # Zero out the new part
-        self.v[:, :, -roll_len:, :] = 0
-        self.k[:, :, :, -roll_len:, :] = 0
-
-        return start_pos - roll_len
+    def roll_kv_n_steps(self, start_pos, n=100):
+        """
+        Roll cache n to the left.
+        """
+        n = min(n, self.max_seq_len)
+        # Roll cache to the left
+        self.v = torch.roll(self.v, shifts=-n, dims=2)
+        self.k = torch.roll(self.k, shifts=-n, dims=3)
+
+        # Zero out the new part
+        self.v[:, :, -n:, :] = 0
+        self.k[:, :, :, -n:, :] = 0
+
+        return start_pos - n

     def to(self, device):
         self.k = self.k.to(device)
```
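A short usage sketch of the updated `WindowedCache` interface follows. The shape values are made up for illustration (in practice they come from the model's `attention_shapes`), and `"cpu"` is used only so the snippet runs anywhere AutoAWQ is installed.

```python
import torch
from awq.modules.fused.cache import WindowedCache

# Illustrative shapes only; real values come from the model's attention_shapes.
batch_size, n_kv_heads, max_seq_len, head_dim, pack_factor = 1, 8, 512, 64, 8
cache_v_shape = (batch_size, n_kv_heads, max_seq_len, head_dim)
cache_k_shape = (batch_size, n_kv_heads, head_dim // pack_factor, max_seq_len, pack_factor)

# New signature: max_seq_len is stored so roll_kv_n_steps can clamp n.
cache = WindowedCache(cache_v_shape, cache_k_shape, max_seq_len, "cpu")

start_pos = max_seq_len                               # pretend the window is full
start_pos = cache.roll_kv_n_steps(start_pos, n=100)   # roll 100 slots in one go
assert start_pos == max_seq_len - 100
```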
examples/benchmark.py

```diff
@@ -85,7 +85,7 @@ def run_round(model_path, quant_file, n_generate, input_ids, batch_size, safeten
         "Prefill tokens/s": prefill_tokens_per_second,
         "Decode tokens/s": decode_tokens_per_second,
         "Memory (VRAM)": f"{memory_used:.2f} GB ({memory_pct:.2f}%)"
-    }, model.quant_config["version"]
+    }, model.quant_config.version

 def main(args):
     rounds = [
```
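The benchmark change tracks an API shift: `quant_config` is now read with attribute access (`model.quant_config.version`) rather than dictionary indexing. As a rough illustration of the difference, here is a hypothetical stand-in config object; the `QuantConfig` dataclass below is not AutoAWQ's actual class.

```python
from dataclasses import dataclass

@dataclass
class QuantConfig:
    # Hypothetical stand-in for an attribute-style quantization config.
    version: str = "GEMM"
    w_bit: int = 4
    q_group_size: int = 128

config = QuantConfig()
print(config.version)  # attribute access, as in the updated benchmark
# config["version"] would now raise TypeError: 'QuantConfig' object is not subscriptable
```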
setup.py

```diff
@@ -13,7 +13,7 @@ PYPI_BUILD = os.getenv("PYPI_BUILD", "0") == "1"
 if not PYPI_BUILD:
     try:
         CUDA_VERSION = "".join(os.environ.get("CUDA_VERSION", torch.version.cuda).split("."))[:3]
-        AUTOAWQ_VERSION += f"cu+{CUDA_VERSION}"
+        AUTOAWQ_VERSION += f"+cu{CUDA_VERSION}"
     except Exception as ex:
         raise RuntimeError("Your system must have an Nvidia GPU for installing AutoAWQ")
```
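The setup.py fix moves the CUDA tag into a PEP 440 local version label: `<base>+cu<CUDA_VERSION>` parses, whereas the old `<base>cu+<CUDA_VERSION>` does not. A quick check with the `packaging` library (the `0.1.6` base version here is only an example, not the release the commit was cut from):

```python
from packaging.version import Version, InvalidVersion

# "+cu118" is a valid PEP 440 local version label.
print(Version("0.1.6+cu118"))   # -> 0.1.6+cu118

# The old ordering produced a string that is not a valid version.
try:
    Version("0.1.6cu+118")
except InvalidVersion as err:
    print("invalid version:", err)
```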