chenpangpang / transformers · Commit 811a9caa
"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "e2c935f5615a3c15ee7439fa8a560edd5f13a457"
Unverified commit 811a9caa, authored Jul 29, 2024 by Guang Yang, committed by GitHub on Jul 29, 2024

Make static cache compatible with torch.export (#32168)
Parent: 7f5d644e

Showing 2 changed files with 80 additions and 10 deletions:

src/transformers/cache_utils.py    +22 -10
tests/utils/test_cache_utils.py    +58 -0
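At a glance, the change makes `Cache` a `torch.nn.Module` and turns the per-layer tensors of `StaticCache` into registered buffers, so a model that owns a `StaticCache` can be captured by `torch.export()`. A minimal sketch of the resulting behaviour (not code from this commit; the tiny `LlamaConfig` below is a stand-in so the snippet runs without downloading a checkpoint):

import torch
from transformers import LlamaConfig, StaticCache

# Stand-in config with hypothetical values, just to keep the sketch self-contained.
config = LlamaConfig(num_hidden_layers=2, hidden_size=64, num_attention_heads=4, num_key_value_heads=4)
cache = StaticCache(config=config, max_batch_size=1, max_cache_len=16, device="cpu")

# After this commit, every cache is an nn.Module ...
assert isinstance(cache, torch.nn.Module)
# ... and StaticCache registers one key/value buffer pair per layer
# ("key_cache_0", "value_cache_0", ...), which is what lets torch.export()
# record cache updates as buffer mutations.
print(sorted(name for name, _ in cache.named_buffers()))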
src/transformers/cache_utils.py
@@ -23,12 +23,14 @@ if is_hqq_available():
 logger = logging.get_logger(__name__)
 
 
-@dataclass
-class Cache:
+class Cache(torch.nn.Module):
     """
     Base, abstract class for all caches. The actual data structure is specific to each subclass.
     """
 
+    def __init__(self):
+        super().__init__()
+
     def update(
         self,
         key_states: torch.Tensor,
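Because `Cache` now subclasses `torch.nn.Module`, each cache class in the hunks below also gains a `super().__init__()` call: `nn.Module` must set up its internal state before buffers (or parameters and submodules) can be attached to an instance. A tiny illustration of that requirement (mine, not part of the commit):

import torch

class BrokenCache(torch.nn.Module):
    def __init__(self):
        # super().__init__() deliberately omitted
        self.register_buffer("key_cache_0", torch.zeros(1))

try:
    BrokenCache()
except AttributeError as err:
    # "cannot assign buffer before Module.__init__() call"
    print(err)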
@@ -299,6 +301,7 @@ class DynamicCache(Cache):
     """
 
     def __init__(self) -> None:
+        super().__init__()
         self.key_cache: List[torch.Tensor] = []
         self.value_cache: List[torch.Tensor] = []
         self._seen_tokens = 0  # Used in `generate` to keep tally of how many tokens the cache has seen
@@ -461,6 +464,7 @@ class QuantizedCache(DynamicCache):
     """
 
     def __init__(self, cache_config: QuantizedCacheConfig) -> None:
+        super().__init__()
         self._quantized_key_cache: List[torch.Tensor] = []
         self._quantized_value_cache: List[torch.Tensor] = []
@@ -634,6 +638,7 @@ class SinkCache(Cache):
     """
 
     def __init__(self, window_length: int, num_sink_tokens: int) -> None:
+        super().__init__()
         self.key_cache: List[torch.Tensor] = []
         self.value_cache: List[torch.Tensor] = []
         self.window_length = window_length
@@ -786,7 +791,7 @@ class SinkCache(Cache):
 
 class StaticCache(Cache):
     """
-    Static Cache class to be used with `torch.compile(model)`.
+    Static Cache class to be used with `torch.compile(model)` and `torch.export()`.
 
     Parameters:
         config (`PretrainedConfig):
@@ -817,18 +822,22 @@ class StaticCache(Cache):
         self.key_cache: List[torch.Tensor] = []
         self.value_cache: List[torch.Tensor] = []
+        # Note: There will be significant perf decrease if switching to use 5D tensors instead.
         cache_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
-        for _ in range(config.num_hidden_layers):
+        for idx in range(config.num_hidden_layers):
+            # Note: `torch.export()`` requires mutations to be registered as buffers.
+            self.register_buffer(f"key_cache_{idx}", torch.zeros(cache_shape, dtype=dtype, device=device))
+            self.register_buffer(f"value_cache_{idx}", torch.zeros(cache_shape, dtype=dtype, device=device))
+            key_cache = getattr(self, f"key_cache_{idx}")
+            value_cache = getattr(self, f"value_cache_{idx}")
             # Note: `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph
             # breaks when updating the cache. It can't be used if the cache code is being compiled (but in that case
             # it is not needed anyway)
-            new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
-            new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
             if not is_torchdynamo_compiling():
-                torch._dynamo.mark_static_address(new_layer_key_cache)
-                torch._dynamo.mark_static_address(new_layer_value_cache)
-            self.key_cache.append(new_layer_key_cache)
-            self.value_cache.append(new_layer_value_cache)
+                torch._dynamo.mark_static_address(key_cache)
+                torch._dynamo.mark_static_address(value_cache)
+            self.key_cache.append(key_cache)
+            self.value_cache.append(value_cache)
 
     def update(
         self,
@@ -928,6 +937,7 @@ class SlidingWindowCache(StaticCache):
     """
 
     def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, device, dtype=None) -> None:
+        super().__init__()
         if not hasattr(config, "sliding_window") or config.sliding_window is None:
             raise ValueError(
                 "Setting `cache_implementation` to 'sliding_window' requires the model config supporting "
@@ -1005,6 +1015,7 @@ class EncoderDecoderCache(Cache):
     """
 
     def __init__(self, self_attention_cache: Cache, cross_attention_cache: Cache):
+        super().__init__()
         self.self_attention_cache = self_attention_cache
         self.cross_attention_cache = cross_attention_cache
@@ -1148,6 +1159,7 @@ class EncoderDecoderCache(Cache):
 
 class HybridCache(Cache):
     def __init__(self, config: PretrainedConfig, max_batch_size, max_cache_len, device="cpu", dtype=None) -> None:
+        super().__init__()
         if not hasattr(config, "sliding_window") or config.sliding_window is None:
             raise ValueError(
                 "Setting `cache_implementation` to 'sliding_window' requires the model config supporting "
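The `register_buffer` calls added to `StaticCache` above are the heart of the fix: as the in-diff note says, `torch.export()` requires tensors that are mutated during forward to be registered as buffers rather than kept as plain Python attributes. A standalone sketch of that mechanism, assuming torch >= 2.3 (the toy `TinyCache` is mine, not code from this commit):

import torch
from torch.export import export


class TinyCache(torch.nn.Module):
    def __init__(self, max_len: int = 8, dim: int = 4):
        super().__init__()
        # Registered buffer: export() captures in-place writes to it as buffer mutations.
        self.register_buffer("key_cache_0", torch.zeros(max_len, dim))

    def forward(self, key_states: torch.Tensor, cache_position: torch.Tensor):
        # Write the new key states into the preallocated slots, similar in spirit to StaticCache.update.
        self.key_cache_0.index_copy_(0, cache_position, key_states)
        return self.key_cache_0.clone()


exported = export(TinyCache(), args=(torch.ones(1, 4), torch.tensor([0])))
print(type(exported))  # an ExportedProgram whose graph records the buffer mutation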
tests/utils/test_cache_utils.py
@@ -15,12 +15,14 @@
 import unittest
 
+from packaging import version
 from parameterized import parameterized
 
 from transformers import set_seed
 from transformers.testing_utils import (
     is_torch_available,
     require_auto_gptq,
+    require_read_token,
     require_torch,
     require_torch_gpu,
     slow,
@@ -32,6 +34,7 @@ if is_torch_available():
     import torch
 
     from transformers import (
+        AutoConfig,
         AutoModelForCausalLM,
         AutoTokenizer,
         DynamicCache,
@@ -164,6 +167,61 @@ class CacheTest(unittest.TestCase):
         self.assertTrue(cached_keys.shape == (1, 1, 10, 128))
         self.assertTrue(cached_values.shape == (1, 1, 10, 128))
 
+    @slow
+    @require_read_token
+    def test_static_cache_exportability(self):
+        """
+        Tests that static cache works with `torch.export()`
+        """
+        if version.parse(torch.__version__) < version.parse("2.3"):
+            self.skipTest(reason="This test requires torch >= 2.3 to run.")
+
+        device = "cpu"
+        dtype = torch.float32
+        max_batch_size = 1
+
+        config = AutoConfig.from_pretrained(
+            "google/gemma-2b",
+            torch_dtype=dtype,
+            use_cache=True,
+        )
+        m = AutoModelForCausalLM.from_pretrained(
+            "google/gemma-2b",
+            config=config,
+            torch_dtype=dtype,
+            attn_implementation="sdpa",  # Export and ExecuTorch only works for SdpaAttention
+        ).to(device)
+        tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
+        inputs = tokenizer(["The best color is"], return_tensors="pt").to(device)["input_ids"]
+
+        class ExportatibleModelWithStaticCache(torch.nn.Module):
+            def __init__(self, config, model):
+                super().__init__()
+                self.config = config
+                self.model = model
+                self.static_cache = StaticCache(
+                    config=config, max_batch_size=max_batch_size, max_cache_len=config.max_length, device=device
+                )
+
+            def forward(self, tokens: torch.Tensor, input_pos: torch.Tensor):
+                outs = self.model(
+                    input_ids=tokens,
+                    attention_mask=None,
+                    position_ids=input_pos.unsqueeze(0),
+                    cache_position=input_pos,
+                    past_key_values=self.static_cache,
+                    use_cache=True,
+                )
+                return outs.logits
+
+        set_seed(0)
+        with torch.no_grad():
+            from torch.export import ExportedProgram, export
+
+            model = ExportatibleModelWithStaticCache(config, m)
+            exported_program = export(model, args=(inputs,), kwargs={"input_pos": torch.arange(1)})
+            self.assertTrue(isinstance(exported_program, ExportedProgram))
+
     @require_torch_gpu
     @slow
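The new test stops at asserting that `torch.export()` returns an `ExportedProgram`. As a hedged follow-up (not part of the commit), the exported artifact could be turned back into a callable module and run on the same example inputs; exact call conventions may vary slightly across torch versions:

# Hypothetical continuation using the variables from the test body above.
res = exported_program.module()(inputs, input_pos=torch.arange(1))
print(res.shape)  # logits of shape (batch_size, sequence_length, vocab_size)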