chenpangpang / transformers / Commits / 3d8bd119

Commit 3d8bd119 (unverified), authored Aug 06, 2024 by Joao Gante; committed by GitHub, Aug 06, 2024.

Generate: fix end to end compilation (#32465)

parent 6a03942d
Showing 2 changed files with 24 additions and 20 deletions (+24 -20):

  src/transformers/cache_utils.py        +15 -12
  src/transformers/generation/utils.py    +9  -8
src/transformers/cache_utils.py  (view file @ 3d8bd119)

```diff
@@ -1024,19 +1024,22 @@ class StaticCache(Cache):
         # Note: There will be significant perf decrease if switching to use 5D tensors instead.
         cache_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
         for idx in range(config.num_hidden_layers):
-            # Note: `torch.export()` requires mutations to be registered as buffers.
-            self.register_buffer(f"key_cache_{idx}", torch.zeros(cache_shape, dtype=dtype, device=device))
-            self.register_buffer(f"value_cache_{idx}", torch.zeros(cache_shape, dtype=dtype, device=device))
-            key_cache = getattr(self, f"key_cache_{idx}")
-            value_cache = getattr(self, f"value_cache_{idx}")
-            # Note: `mark_static_address` is used to tag the cache as a fixed data pointer, preventing cuda graph
-            # breaks when updating the cache. It can't be used if the cache code is being compiled (but in that case
-            # it is not needed anyway)
+            new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
+            new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
+            # Notes:
+            # 1. `mark_static_address` is used to tag the cache as a fixed data pointer, preventing cuda graph
+            #    breaks when updating the cache. It can't be used if the cache code is being compiled (but in that case
+            #    it is not needed anyway)
+            # 2. `torch.export()` requires mutations to be registered as buffers.
             if not is_torchdynamo_compiling():
-                torch._dynamo.mark_static_address(key_cache)
-                torch._dynamo.mark_static_address(value_cache)
-            self.key_cache.append(key_cache)
-            self.value_cache.append(value_cache)
+                self.register_buffer(f"key_cache_{idx}", torch.zeros(cache_shape, dtype=dtype, device=device))
+                self.register_buffer(f"value_cache_{idx}", torch.zeros(cache_shape, dtype=dtype, device=device))
+                new_layer_key_cache = getattr(self, f"key_cache_{idx}")
+                new_layer_value_cache = getattr(self, f"value_cache_{idx}")
+                torch._dynamo.mark_static_address(new_layer_key_cache)
+                torch._dynamo.mark_static_address(new_layer_value_cache)
+            self.key_cache.append(new_layer_key_cache)
+            self.value_cache.append(new_layer_value_cache)

     def update(
         self,
```
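In the old code, `register_buffer` and the `getattr` round-trip ran unconditionally and only `mark_static_address` was guarded; after the change, the tensors are created up front and the entire buffer-registration path is skipped when the cache code is itself being compiled. Below is a minimal standalone sketch of the resulting pattern; it is illustrative, not the transformers implementation, and `torch.compiler.is_compiling()` (PyTorch >= 2.3) stands in for the library's `is_torchdynamo_compiling()` helper.

```python
import torch
import torch.nn as nn


class ToyStaticCache(nn.Module):
    """Illustrative sketch of the guarded buffer-registration pattern."""

    def __init__(self, num_layers: int, cache_shape: tuple, dtype=torch.float32, device="cpu"):
        super().__init__()
        self.key_cache = []
        for idx in range(num_layers):
            # Always materialize the tensor, so __init__ also traces cleanly
            # when it runs inside a compiled region.
            new_layer_key_cache = torch.zeros(cache_shape, dtype=dtype, device=device)
            if not torch.compiler.is_compiling():
                # Eager-only path: `torch.export()` wants mutated tensors
                # registered as buffers, and `mark_static_address` pins the
                # data pointer so cuda graphs survive cache updates.
                self.register_buffer(f"key_cache_{idx}", new_layer_key_cache)
                new_layer_key_cache = getattr(self, f"key_cache_{idx}")
                torch._dynamo.mark_static_address(new_layer_key_cache)
            self.key_cache.append(new_layer_key_cache)


cache = ToyStaticCache(num_layers=2, cache_shape=(1, 8, 16, 64))
```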
src/transformers/generation/utils.py  (view file @ 3d8bd119)

```diff
@@ -1429,7 +1429,9 @@ class GenerationMixin:
         model_kwargs["cache_position"] = cache_position
         return model_kwargs

-    def _get_cache(self, cache_implementation: str, max_batch_size: int, max_cache_len: int, model_kwargs) -> Cache:
+    def _get_cache(
+        self, cache_implementation: str, max_batch_size: int, max_cache_len: int, device: torch.device, model_kwargs
+    ) -> Cache:
         """
         Sets a cache for `generate` that will persist across calls. A new cache will only be initialized if a
         new `generate` call requires a larger cache or uses a different batch size.
```
```diff
@@ -1477,7 +1479,7 @@ class GenerationMixin:
             "config": self.config,
             "max_batch_size": max_batch_size,
             "max_cache_len": max_cache_len,
-            "device": self.device,
+            "device": device,
             "dtype": cache_dtype,
         }
         self._cache = cache_cls(**cache_kwargs)
```
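An aside on the two hunks above: the cache device is no longer inferred from `self.device` but is passed down explicitly by the caller of `_get_cache`. A hedged sketch of the general idea follows; `allocate_kv` is a hypothetical helper, not transformers code, illustrating why explicit device plumbing is more robust than reading a device off the model (with sharded or offloaded weights, for example, the module-level device need not match where the cache should live).

```python
import torch


def allocate_kv(cache_shape: tuple, device: torch.device, dtype=torch.float16):
    # Hypothetical helper: the caller decides where the cache lives (e.g. the
    # device of the input ids) instead of this code guessing from the model.
    return torch.zeros(cache_shape, dtype=dtype, device=device)


kv = allocate_kv((1, 8, 1024, 64), device=torch.device("cpu"))
print(kv.shape, kv.device)
```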
```diff
@@ -1813,12 +1815,11 @@ class GenerationMixin:
                     "issue: https://github.com/huggingface/transformers/issues/28981"
                 )
                 model_kwargs[cache_name] = self._get_cache(
-                    generation_config.cache_implementation,
-                    getattr(generation_config, "num_beams", 1)
-                    * getattr(generation_config, "num_return_sequences", 1)
-                    * batch_size,
-                    generation_config.max_length,
-                    model_kwargs,
+                    cache_implementation=generation_config.cache_implementation,
+                    max_batch_size=generation_config.num_beams * generation_config.num_return_sequences * batch_size,
+                    max_cache_len=generation_config.max_length,
+                    device=device,
+                    model_kwargs=model_kwargs,
                 )
             elif generation_config.cache_implementation == "quantized":
                 if not self._supports_quantized_cache:
```
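The commit title is about compiling `generate` end to end, and the static cache patched above is what makes the tensor shapes fixed enough for that. The following usage sketch is hedged: it is not taken from the commit, "gpt2" is just a stand-in checkpoint, and whether a given model compiles without graph breaks depends on the transformers and PyTorch versions involved.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

# cache_implementation="static" selects the StaticCache shown above; its
# pre-allocated, fixed-shape tensors are what let torch.compile trace the
# decoding loop without per-step recompilation.
inputs = tokenizer("Hello", return_tensors="pt")
compiled_generate = torch.compile(model.generate)
outputs = compiled_generate(**inputs, cache_implementation="static", max_new_tokens=8)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```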