chenpangpang / transformers / Commits / ef177a5e

Unverified commit ef177a5e, authored Jul 31, 2024 by Joao Gante, committed by GitHub on Jul 31, 2024

Gemma 2: support assisted generation (#32357)

parent 5f1fcc29
Showing 3 changed files with 20 additions and 4 deletions (+20, -4):

    src/transformers/generation/candidate_generator.py    +3  -0
    src/transformers/generation/utils.py                  +14  -0
    src/transformers/models/gemma2/modeling_gemma2.py      +3  -4
src/transformers/generation/candidate_generator.py

@@ -171,6 +171,9 @@ class AssistedCandidateGenerator(CandidateGenerator):
                 "Please pass in `min_length` into `.generate()` instead"
             )
 
+        # We need to roll back the cache in assisted generation, only DynamicCache is supported
+        self.generation_config.cache_implementation = None
+
     def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
         """
         Fetches the candidates to be tried for the current input.
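Note on the hunk above: it forces the assistant's generation config onto a dynamic cache. The reason is that assisted generation must shrink the cache whenever the main model rejects part of the candidate sequence, and only DynamicCache can shrink. A minimal sketch, assuming DynamicCache's `crop` method (the rollback primitive in this version of transformers); the position counts are illustrative:

    from transformers import DynamicCache

    cache = DynamicCache()
    # ...assume forward passes have filled the cache up to 20 positions,
    # but the main model accepted only the first 15 candidate tokens...
    n_accepted = 15
    cache.crop(n_accepted)  # DynamicCache can shrink; static caches cannot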
src/transformers/generation/utils.py

@@ -1779,6 +1779,20 @@ class GenerationMixin:
             cache_name = "cache_params"
         else:
             cache_name = "past_key_values"
 
+        # TODO(joao): support static caches in assisted generation. assisted generation needs to roll back caches,
+        # which is only supported in dynamic caches atm
+        if (
+            assistant_model is not None
+            and generation_config.cache_implementation is not None
+            and self._supports_default_dynamic_cache()
+        ):
+            logger.warning_once(
+                "An assistant model is provided, using a dynamic cache instead of a cache of type="
+                f"'{generation_config.cache_implementation}'."
+            )
+            generation_config.cache_implementation = None
+
         if (model_kwargs.get(cache_name) is not None) and is_torchdynamo_compiling():
             raise ValueError(
                 "Passing `past_key_values` is not supported when compiling `model.generate` with torch.compile -- you "
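For context, a hedged end-to-end sketch of what this guard enables (the checkpoint names are illustrative; any compatible Gemma 2 main/assistant pair should behave the same): when an assistant model is passed alongside a configured cache implementation, `generate()` now warns once and falls back to a dynamic cache instead of failing.

    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b")
    model = AutoModelForCausalLM.from_pretrained("google/gemma-2-9b")
    assistant = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b")

    inputs = tokenizer("The capital of France is", return_tensors="pt")
    # With assistant_model set, any configured cache_implementation is dropped
    # (with a one-time warning) so the rollback-capable DynamicCache is used.
    out = model.generate(**inputs, assistant_model=assistant, max_new_tokens=20)
    print(tokenizer.decode(out[0], skip_special_tokens=True))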
src/transformers/models/gemma2/modeling_gemma2.py

@@ -27,7 +27,7 @@ from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
 from ...activations import ACT2FN
-from ...cache_utils import Cache
+from ...cache_utils import Cache, HybridCache
 from ...modeling_outputs import (
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,

@@ -591,10 +591,9 @@ class Gemma2PreTrainedModel(PreTrainedModel):
     _skip_keys_device_placement = ["past_key_values"]
     _supports_flash_attn_2 = True
     _supports_sdpa = True
-    _supports_cache_class = False
+    _supports_cache_class = True
     _supports_quantized_cache = False
     _supports_static_cache = True
-    _is_stateful = True
 
     def _init_weights(self, module):
         std = self.config.initializer_range
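The flag changes above are what actually unlock the feature: `_supports_cache_class = True` lets Gemma 2 accept Cache objects such as DynamicCache, and dropping the `_is_stateful` override means the class falls back to the PreTrainedModel default, so `generate()`'s stateful-model check no longer blocks assisted generation. A quick illustrative check (the default value of `_is_stateful` on PreTrainedModel is assumed to be False here):

    from transformers import Gemma2ForCausalLM

    print(Gemma2ForCausalLM._supports_cache_class)  # True after this commit
    # _is_stateful is no longer overridden, so the PreTrainedModel default
    # applies and the assisted-generation stateful-model check passes.
    print(Gemma2ForCausalLM._is_stateful)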
@@ -841,7 +840,7 @@ class Gemma2Model(Gemma2PreTrainedModel):
         dtype, device = input_tensor.dtype, input_tensor.device
         min_dtype = torch.finfo(dtype).min
         sequence_length = input_tensor.shape[1]
-        if past_key_values is not None:
+        if isinstance(past_key_values, HybridCache):
             target_length = past_key_values.get_max_length()
         else:
             target_length = attention_mask.shape[-1] if attention_mask is not None else input_tensor.shape[1]
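Note on the isinstance change: Gemma 2's default HybridCache is preallocated and reports a fixed maximum length, whereas the DynamicCache used during assisted generation grows on demand (its `get_max_length()` returns None in this version), so the mask's target length must come from the attention mask instead. A small sketch of the fallback branch, with illustrative shapes:

    import torch
    from transformers import DynamicCache

    past_key_values = DynamicCache()     # not a HybridCache
    attention_mask = torch.ones(1, 7, dtype=torch.long)
    input_tensor = torch.zeros(1, 7, 8)  # (batch, seq_len, hidden)

    # Mirrors the patched else-branch: derive target_length from the mask.
    target_length = attention_mask.shape[-1] if attention_mask is not None else input_tensor.shape[1]
    print(target_length)  # 7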