Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
3d8bd119
Unverified
Commit
3d8bd119
authored
Aug 06, 2024
by
Joao Gante
Committed by
GitHub
Aug 06, 2024
Browse files
Generate: fix end to end compilation (#32465)
parent
6a03942d
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
24 additions
and
20 deletions
+24
-20
src/transformers/cache_utils.py
src/transformers/cache_utils.py
+15
-12
src/transformers/generation/utils.py
src/transformers/generation/utils.py
+9
-8
No files found.
src/transformers/cache_utils.py
View file @
3d8bd119
...
@@ -1024,19 +1024,22 @@ class StaticCache(Cache):
...
@@ -1024,19 +1024,22 @@ class StaticCache(Cache):
# Note: There will be significant perf decrease if switching to use 5D tensors instead.
# Note: There will be significant perf decrease if switching to use 5D tensors instead.
cache_shape
=
(
max_batch_size
,
self
.
num_key_value_heads
,
self
.
max_cache_len
,
self
.
head_dim
)
cache_shape
=
(
max_batch_size
,
self
.
num_key_value_heads
,
self
.
max_cache_len
,
self
.
head_dim
)
for
idx
in
range
(
config
.
num_hidden_layers
):
for
idx
in
range
(
config
.
num_hidden_layers
):
# Note: `torch.export()`` requires mutations to be registered as buffers.
new_layer_key_cache
=
torch
.
zeros
(
cache_shape
,
dtype
=
self
.
dtype
,
device
=
device
)
self
.
register_buffer
(
f
"key_cache_
{
idx
}
"
,
torch
.
zeros
(
cache_shape
,
dtype
=
dtype
,
device
=
device
))
new_layer_value_cache
=
torch
.
zeros
(
cache_shape
,
dtype
=
self
.
dtype
,
device
=
device
)
self
.
register_buffer
(
f
"value_cache_
{
idx
}
"
,
torch
.
zeros
(
cache_shape
,
dtype
=
dtype
,
device
=
device
))
# Notes:
key_cache
=
getattr
(
self
,
f
"key_cache_
{
idx
}
"
)
# 1. `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph
value_cache
=
getattr
(
self
,
f
"value_cache_
{
idx
}
"
)
# Note: `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph
# breaks when updating the cache. It can't be used if the cache code is being compiled (but in that case
# breaks when updating the cache. It can't be used if the cache code is being compiled (but in that case
# it is not needed anyway)
# it is not needed anyway)
# 2. `torch.export()` requires mutations to be registered as buffers.
if
not
is_torchdynamo_compiling
():
if
not
is_torchdynamo_compiling
():
torch
.
_dynamo
.
mark_static_address
(
key_cache
)
self
.
register_buffer
(
f
"key_cache_
{
idx
}
"
,
torch
.
zeros
(
cache_shape
,
dtype
=
dtype
,
device
=
device
))
torch
.
_dynamo
.
mark_static_address
(
value_cache
)
self
.
register_buffer
(
f
"value_cache_
{
idx
}
"
,
torch
.
zeros
(
cache_shape
,
dtype
=
dtype
,
device
=
device
))
self
.
key_cache
.
append
(
key_cache
)
new_layer_key_cache
=
getattr
(
self
,
f
"key_cache_
{
idx
}
"
)
self
.
value_cache
.
append
(
value_cache
)
new_layer_value_cache
=
getattr
(
self
,
f
"value_cache_
{
idx
}
"
)
torch
.
_dynamo
.
mark_static_address
(
new_layer_key_cache
)
torch
.
_dynamo
.
mark_static_address
(
new_layer_value_cache
)
self
.
key_cache
.
append
(
new_layer_key_cache
)
self
.
value_cache
.
append
(
new_layer_value_cache
)
def
update
(
def
update
(
self
,
self
,
...
...
src/transformers/generation/utils.py
View file @
3d8bd119
...
@@ -1429,7 +1429,9 @@ class GenerationMixin:
...
@@ -1429,7 +1429,9 @@ class GenerationMixin:
model_kwargs
[
"cache_position"
]
=
cache_position
model_kwargs
[
"cache_position"
]
=
cache_position
return
model_kwargs
return
model_kwargs
def
_get_cache
(
self
,
cache_implementation
:
str
,
max_batch_size
:
int
,
max_cache_len
:
int
,
model_kwargs
)
->
Cache
:
def
_get_cache
(
self
,
cache_implementation
:
str
,
max_batch_size
:
int
,
max_cache_len
:
int
,
device
:
torch
.
device
,
model_kwargs
)
->
Cache
:
"""
"""
Sets a cache for `generate`, that will persist across calls. A new cache will only be initialized a
Sets a cache for `generate`, that will persist across calls. A new cache will only be initialized a
new `generate` call requires a larger cache or uses a different batch size.
new `generate` call requires a larger cache or uses a different batch size.
...
@@ -1477,7 +1479,7 @@ class GenerationMixin:
...
@@ -1477,7 +1479,7 @@ class GenerationMixin:
"config"
:
self
.
config
,
"config"
:
self
.
config
,
"max_batch_size"
:
max_batch_size
,
"max_batch_size"
:
max_batch_size
,
"max_cache_len"
:
max_cache_len
,
"max_cache_len"
:
max_cache_len
,
"device"
:
self
.
device
,
"device"
:
device
,
"dtype"
:
cache_dtype
,
"dtype"
:
cache_dtype
,
}
}
self
.
_cache
=
cache_cls
(
**
cache_kwargs
)
self
.
_cache
=
cache_cls
(
**
cache_kwargs
)
...
@@ -1813,12 +1815,11 @@ class GenerationMixin:
...
@@ -1813,12 +1815,11 @@ class GenerationMixin:
"issue: https://github.com/huggingface/transformers/issues/28981"
"issue: https://github.com/huggingface/transformers/issues/28981"
)
)
model_kwargs
[
cache_name
]
=
self
.
_get_cache
(
model_kwargs
[
cache_name
]
=
self
.
_get_cache
(
generation_config
.
cache_implementation
,
cache_implementation
=
generation_config
.
cache_implementation
,
getattr
(
generation_config
,
"num_beams"
,
1
)
max_batch_size
=
generation_config
.
num_beams
*
generation_config
.
num_return_sequences
*
batch_size
,
*
getattr
(
generation_config
,
"num_return_sequences"
,
1
)
max_cache_len
=
generation_config
.
max_length
,
*
batch_size
,
device
=
device
,
generation_config
.
max_length
,
model_kwargs
=
model_kwargs
,
model_kwargs
,
)
)
elif
generation_config
.
cache_implementation
==
"quantized"
:
elif
generation_config
.
cache_implementation
==
"quantized"
:
if
not
self
.
_supports_quantized_cache
:
if
not
self
.
_supports_quantized_cache
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment