OpenDAS / AutoAWQ · Commits

Commit c5581b27 (unverified)
Adaptive batch sizing (#181)
Authored Nov 11, 2023 by Casper; committed by GitHub on Nov 11, 2023
Parent: df909e83

Showing 3 changed files with 18 additions and 6 deletions:
- README.md (+1, -1)
- awq/modules/fused/attn.py (+7, -4)
- awq/modules/fused/cache.py (+10, -1)
README.md

@@ -120,7 +120,7 @@ Fused modules are a large part of the speedup you get from AutoAWQ. The idea is
 - Fused modules are activated when you use `fuse_layers=True`.
 - A custom cache is implemented. It preallocates based on batch size and sequence length.
-- You cannot change the sequence length or batch size after you have created your model.
+- You cannot change the sequence length after you have created your model.
 - Reference: `AutoAWQForCausalLM.from_quantized(max_new_tokens=seq_len, batch_size=batch_size)`
 - The main accelerator in the fused modules comes from FasterTransformer, which is only compatible with Linux.
 - The `past_key_values` from `model.generate()` are only dummy values, so they cannot be used after generation.
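The README change reflects what this commit enables: the batch size passed to `from_quantized()` is no longer a hard limit, since the fused cache now resizes itself. Below is a minimal usage sketch of that behavior; the checkpoint name and prompts are placeholders, and the call signature follows the README bullet above rather than any additional API introduced here.

```python
# Illustrative sketch only: model path and prompts are hypothetical examples.
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

model_path = "TheBloke/Mistral-7B-v0.1-AWQ"  # placeholder AWQ checkpoint
model = AutoAWQForCausalLM.from_quantized(
    model_path, fuse_layers=True, max_new_tokens=512, batch_size=1
)
tokenizer = AutoTokenizer.from_pretrained(model_path)
tokenizer.pad_token = tokenizer.eos_token  # ensure padding works for batching

# Before this commit, generating with a batch size other than the one passed
# to from_quantized() raised a RuntimeError; now the fused cache reallocates.
inputs = tokenizer(["Hello", "World"], return_tensors="pt", padding=True).to("cuda")
out = model.generate(**inputs, max_new_tokens=32)
```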
awq/modules/fused/attn.py

@@ -123,11 +123,14 @@ class QuantAttentionFused(nn.Module):
     def forward(self, hidden_states: torch.Tensor, attention_mask=None, *args, **kwargs):
         bsz, seqlen, _ = hidden_states.shape
 
+        # Reallocate cache if batch size changes
         if bsz != self.cache_batch_size:
-            raise RuntimeError(
-                f"Batch size is incorrectly set - input batch size {bsz}, kv-cache batch size {self.cache_batch_size}. "
-                f"Use: AutoAWQForCausalLM.from_quantized(batch_size={bsz})"
-            )
+            if bsz > self.cache_batch_size:
+                self.cache.increase_batch_size(bsz)
+                self.cache_batch_size = bsz
+            elif bsz < self.cache_batch_size:
+                self.cache.decrease_batch_size(bsz)
+                self.cache_batch_size = bsz
 
         xqkv = self.qkv_proj(hidden_states)
         xqkv = xqkv.view((bsz, seqlen) + self.attention_shapes["xqkv_view"])
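The hunk above replaces a hard failure with a resize-on-mismatch pattern: when the incoming batch size differs from the cached one, the cache is grown or shrunk and the tracked size updated. Here is a self-contained sketch of that pattern using a hypothetical `ToyCache` class (not AutoAWQ's `WindowedCache`), with assumed tensor shapes:

```python
# Standalone sketch of the resize-on-mismatch pattern; ToyCache is hypothetical.
import torch

class ToyCache:
    def __init__(self, bsz, heads, seqlen, head_dim):
        self.k = torch.zeros(bsz, heads, seqlen, head_dim)
        self.v = torch.zeros(bsz, heads, seqlen, head_dim)

    def increase_batch_size(self, to_bsz):
        # Growing allocates a fresh zeroed cache; old contents are discarded.
        self.k = torch.zeros(to_bsz, *self.k.shape[1:], dtype=self.k.dtype)
        self.v = torch.zeros(to_bsz, *self.v.shape[1:], dtype=self.v.dtype)

    def decrease_batch_size(self, to_bsz):
        # Shrinking keeps the leading to_bsz entries of the existing cache.
        self.k = self.k[:to_bsz]
        self.v = self.v[:to_bsz]

cache, cache_batch_size = ToyCache(1, 8, 128, 64), 1
bsz = 4  # new incoming batch size
if bsz != cache_batch_size:
    if bsz > cache_batch_size:
        cache.increase_batch_size(bsz)
    else:
        cache.decrease_batch_size(bsz)
    cache_batch_size = bsz
assert cache.k.shape[0] == 4
```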
awq/modules/fused/cache.py

@@ -47,4 +47,13 @@ class WindowedCache:
     def to(self, device):
         self.k = self.k.to(device)
-        self.v = self.v.to(device)
\ No newline at end of file
+        self.v = self.v.to(device)
+
+    def increase_batch_size(self, to_bsz):
+        """Dynamically allocate new kv when batch size changes."""
+        self.v = torch.zeros(to_bsz, *self.v.shape[1:], dtype=self.v.dtype, device=self.v.device)
+        self.k = torch.zeros(to_bsz, *self.k.shape[1:], dtype=self.k.dtype, device=self.k.device)
+
+    def decrease_batch_size(self, to_bsz):
+        """Dynamically remove part of cache if batch size changes."""
+        self.v = self.v[:to_bsz, :, :, :]
+        self.k = self.k[:to_bsz, :, :, :, :]
\ No newline at end of file
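Note the asymmetry in the two new methods: growing zero-fills fresh tensors (previous cache contents are dropped), while shrinking keeps a prefix slice along the batch dimension. The key cache is sliced with five indices versus four for the value cache, consistent with an extra packing dimension in the FasterTransformer key layout. A shape-only sketch, with all dimension sizes assumed for illustration:

```python
# Shape-only check of both paths; all sizes here are assumed, and the 5-D key
# layout mirrors the extra packing dim used by FasterTransformer-style caches.
import torch

v = torch.zeros(4, 8, 128, 64)    # (batch, heads, seqlen, head_dim), assumed
k = torch.zeros(4, 8, 8, 128, 8)  # 5-D packed key layout, assumed sizes

# decrease_batch_size path: slice the leading batch dimension.
v2, k2 = v[:2, :, :, :], k[:2, :, :, :, :]
assert v2.shape[0] == k2.shape[0] == 2

# increase_batch_size path: allocate fresh zeroed tensors; old data is dropped.
v8 = torch.zeros(8, *v.shape[1:], dtype=v.dtype, device=v.device)
assert v8.shape == (8, 8, 128, 64)
```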