Commit 811a9caa: Make static cache compatible with torch.export (#32168)
chenpangpang/transformers, unverified, authored Jul 29, 2024 by Guang Yang, committed by GitHub on Jul 29, 2024
Parent: 7f5d644e
Showing 2 changed files with 80 additions and 10 deletions (+80 -10):
  src/transformers/cache_utils.py    +22 -10
  tests/utils/test_cache_utils.py    +58 -0
src/transformers/cache_utils.py
@@ -23,12 +23,14 @@ if is_hqq_available():
 logger = logging.get_logger(__name__)
 
 
-@dataclass
-class Cache:
+class Cache(torch.nn.Module):
     """
     Base, abstract class for all caches. The actual data structure is specific to each subclass.
     """
 
+    def __init__(self):
+        super().__init__()
+
     def update(
         self,
         key_states: torch.Tensor,
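The switch from a plain `@dataclass` to a `torch.nn.Module` base is what enables the rest of this commit: only state registered on a module (for example via `register_buffer`) is tracked by `torch.export()`, so in-place cache writes become recorded buffer mutations instead of opaque side effects. A minimal sketch of that mechanism, assuming torch >= 2.3 and using a hypothetical `ToyCache` module that is not part of this commit:

import torch
from torch.export import export


class ToyCache(torch.nn.Module):
    # Hypothetical single-buffer cache, only to illustrate buffer mutation under export.
    def __init__(self, max_len: int = 8, head_dim: int = 4):
        super().__init__()
        # Registered buffers become part of the exported program's state.
        self.register_buffer("key_cache_0", torch.zeros(max_len, head_dim))

    def forward(self, pos: torch.Tensor, key: torch.Tensor) -> torch.Tensor:
        # In-place write into a registered buffer: torch.export records it as a
        # buffer mutation in the graph signature instead of rejecting the side effect.
        self.key_cache_0.index_copy_(0, pos, key)
        return self.key_cache_0.clone()


ep = export(ToyCache(), args=(torch.tensor([0]), torch.ones(1, 4)))
print(type(ep).__name__)  # ExportedProgram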
@@ -299,6 +301,7 @@ class DynamicCache(Cache):
     """
 
     def __init__(self) -> None:
+        super().__init__()
         self.key_cache: List[torch.Tensor] = []
         self.value_cache: List[torch.Tensor] = []
         self._seen_tokens = 0  # Used in `generate` to keep tally of how many tokens the cache has seen
@@ -461,6 +464,7 @@ class QuantizedCache(DynamicCache):
     """
 
     def __init__(self, cache_config: QuantizedCacheConfig) -> None:
+        super().__init__()
         self._quantized_key_cache: List[torch.Tensor] = []
         self._quantized_value_cache: List[torch.Tensor] = []
@@ -634,6 +638,7 @@ class SinkCache(Cache):
     """
 
     def __init__(self, window_length: int, num_sink_tokens: int) -> None:
+        super().__init__()
         self.key_cache: List[torch.Tensor] = []
         self.value_cache: List[torch.Tensor] = []
         self.window_length = window_length
@@ -786,7 +791,7 @@ class SinkCache(Cache):
 class StaticCache(Cache):
     """
-    Static Cache class to be used with `torch.compile(model)`.
+    Static Cache class to be used with `torch.compile(model)` and `torch.export()`.
 
     Parameters:
         config (`PretrainedConfig):
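The updated docstring pairs `torch.compile(model)` with `torch.export()`. For the compile side, the usual entry point is the "static" cache implementation during generation; a hedged usage sketch (the model choice and compile settings are illustrative, not prescribed by this commit):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
model = AutoModelForCausalLM.from_pretrained("google/gemma-2b")

# Fixed-shape cache tensors are what allow the decode step to be compiled fullgraph.
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

inputs = tokenizer(["The best color is"], return_tensors="pt")
out = model.generate(**inputs, max_new_tokens=8, cache_implementation="static")
print(tokenizer.batch_decode(out, skip_special_tokens=True))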
@@ -817,18 +822,22 @@ class StaticCache(Cache):
         self.key_cache: List[torch.Tensor] = []
         self.value_cache: List[torch.Tensor] = []
         # Note: There will be significant perf decrease if switching to use 5D tensors instead.
         cache_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
-        for _ in range(config.num_hidden_layers):
+        for idx in range(config.num_hidden_layers):
+            # Note: `torch.export()` requires mutations to be registered as buffers.
+            self.register_buffer(f"key_cache_{idx}", torch.zeros(cache_shape, dtype=dtype, device=device))
+            self.register_buffer(f"value_cache_{idx}", torch.zeros(cache_shape, dtype=dtype, device=device))
+            key_cache = getattr(self, f"key_cache_{idx}")
+            value_cache = getattr(self, f"value_cache_{idx}")
             # Note: `mark_static_address` is used to tag the cache as a fixed data pointer, preventing cuda graph
             # breaks when updating the cache. It can't be used if the cache code is being compiled (but in that case
             # it is not needed anyway)
-            new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
-            new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
-            if not is_torchdynamo_compiling():
-                torch._dynamo.mark_static_address(new_layer_key_cache)
-                torch._dynamo.mark_static_address(new_layer_value_cache)
-            self.key_cache.append(new_layer_key_cache)
-            self.value_cache.append(new_layer_value_cache)
+            torch._dynamo.mark_static_address(key_cache)
+            torch._dynamo.mark_static_address(value_cache)
+            self.key_cache.append(key_cache)
+            self.value_cache.append(value_cache)
 
     def update(
         self,
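In the loop above, each per-layer tensor is registered as a named buffer and then read back with `getattr` before being appended to the plain Python lists; the list entry and the buffer are the same tensor object, so existing code that indexes `self.key_cache[layer_idx]` keeps mutating registered state that `torch.export()` can see. A standalone sketch of that pattern (the `PerLayerBuffers` module and its shapes are hypothetical, not the library's `update()` implementation):

import torch


class PerLayerBuffers(torch.nn.Module):
    def __init__(self, num_layers: int = 2, cache_shape=(1, 1, 8, 4)):
        super().__init__()
        self.key_cache = []
        for idx in range(num_layers):
            # Register the layer's cache so module machinery (export, state_dict) tracks it...
            self.register_buffer(f"key_cache_{idx}", torch.zeros(cache_shape))
            # ...and keep a list view: getattr returns the registered tensor itself, so the
            # list entry aliases the buffer rather than copying it.
            self.key_cache.append(getattr(self, f"key_cache_{idx}"))

    def forward(self, layer_idx: int, pos: torch.Tensor, key: torch.Tensor) -> torch.Tensor:
        k_out = self.key_cache[layer_idx]
        k_out.index_copy_(2, pos, key)  # the in-place write lands in the registered buffer
        return k_out.clone()


m = PerLayerBuffers()
m(0, torch.tensor([0]), torch.ones(1, 1, 1, 4))
assert torch.equal(m.key_cache_0[:, :, 0], torch.ones(1, 1, 4))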
@@ -928,6 +937,7 @@ class SlidingWindowCache(StaticCache):
     """
 
     def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, device, dtype=None) -> None:
+        super().__init__()
         if not hasattr(config, "sliding_window") or config.sliding_window is None:
             raise ValueError(
                 "Setting `cache_implementation` to 'sliding_window' requires the model config supporting "
@@ -1005,6 +1015,7 @@ class EncoderDecoderCache(Cache):
     """
 
     def __init__(self, self_attention_cache: Cache, cross_attention_cache: Cache):
+        super().__init__()
         self.self_attention_cache = self_attention_cache
         self.cross_attention_cache = cross_attention_cache
@@ -1148,6 +1159,7 @@ class EncoderDecoderCache(Cache):
 class HybridCache(Cache):
     def __init__(self, config: PretrainedConfig, max_batch_size, max_cache_len, device="cpu", dtype=None) -> None:
+        super().__init__()
         if not hasattr(config, "sliding_window") or config.sliding_window is None:
             raise ValueError(
                 "Setting `cache_implementation` to 'sliding_window' requires the model config supporting "
tests/utils/test_cache_utils.py
@@ -15,12 +15,14 @@
 import unittest
 
+from packaging import version
 from parameterized import parameterized
 
 from transformers import set_seed
 from transformers.testing_utils import (
     is_torch_available,
     require_auto_gptq,
+    require_read_token,
     require_torch,
     require_torch_gpu,
     slow,
@@ -32,6 +34,7 @@ if is_torch_available():
     import torch
 
     from transformers import (
+        AutoConfig,
         AutoModelForCausalLM,
         AutoTokenizer,
         DynamicCache,
@@ -164,6 +167,61 @@ class CacheTest(unittest.TestCase):
         self.assertTrue(cached_keys.shape == (1, 1, 10, 128))
         self.assertTrue(cached_values.shape == (1, 1, 10, 128))
 
+    @slow
+    @require_read_token
+    def test_static_cache_exportability(self):
+        """
+        Tests that static cache works with `torch.export()`
+        """
+        if version.parse(torch.__version__) < version.parse("2.3"):
+            self.skipTest(reason="This test requires torch >= 2.3 to run.")
+
+        device = "cpu"
+        dtype = torch.float32
+        max_batch_size = 1
+
+        config = AutoConfig.from_pretrained(
+            "google/gemma-2b",
+            torch_dtype=dtype,
+            use_cache=True,
+        )
+        m = AutoModelForCausalLM.from_pretrained(
+            "google/gemma-2b",
+            config=config,
+            torch_dtype=dtype,
+            attn_implementation="sdpa",  # Export and ExecuTorch only works for SdpaAttention
+        ).to(device)
+        tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
+        inputs = tokenizer(["The best color is"], return_tensors="pt").to(device)["input_ids"]
+
+        class ExportatibleModelWithStaticCache(torch.nn.Module):
+            def __init__(self, config, model):
+                super().__init__()
+                self.config = config
+                self.model = model
+                self.static_cache = StaticCache(
+                    config=config, max_batch_size=max_batch_size, max_cache_len=config.max_length, device=device
+                )
+
+            def forward(self, tokens: torch.Tensor, input_pos: torch.Tensor):
+                outs = self.model(
+                    input_ids=tokens,
+                    attention_mask=None,
+                    position_ids=input_pos.unsqueeze(0),
+                    cache_position=input_pos,
+                    past_key_values=self.static_cache,
+                    use_cache=True,
+                )
+                return outs.logits
+
+        set_seed(0)
+        with torch.no_grad():
+            from torch.export import ExportedProgram, export
+
+            model = ExportatibleModelWithStaticCache(config, m)
+            exported_program = export(model, args=(inputs,), kwargs={"input_pos": torch.arange(1)})
+            self.assertTrue(isinstance(exported_program, ExportedProgram))
+
     @require_torch_gpu
     @slow
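A natural follow-on check (not part of this test) is to run the exported artifact and compare it with eager execution. `ExportedProgram.module()` returns a callable graph module, so, continuing with the `model`, `inputs`, and `exported_program` defined in the test above:

with torch.no_grad():
    eager_logits = model(inputs, input_pos=torch.arange(1))
    exported_logits = exported_program.module()(inputs, input_pos=torch.arange(1))
    # Both start from zero-initialized static caches, so CPU/float32 outputs should match closely.
    torch.testing.assert_close(eager_logits, exported_logits)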