Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
ef177a5e
Unverified
Commit
ef177a5e
authored
Jul 31, 2024
by
Joao Gante
Committed by
GitHub
Jul 31, 2024
Browse files
Gemma 2: support assisted generation (#32357)
parent
5f1fcc29
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
20 additions
and
4 deletions
+20
-4
src/transformers/generation/candidate_generator.py
src/transformers/generation/candidate_generator.py
+3
-0
src/transformers/generation/utils.py
src/transformers/generation/utils.py
+14
-0
src/transformers/models/gemma2/modeling_gemma2.py
src/transformers/models/gemma2/modeling_gemma2.py
+3
-4
No files found.
src/transformers/generation/candidate_generator.py
View file @
ef177a5e
...
...
@@ -171,6 +171,9 @@ class AssistedCandidateGenerator(CandidateGenerator):
"Please pass in `min_length` into `.generate()` instead"
)
# We need to roll back the cache in assisted generation, only DynamicCache is supported
self
.
generation_config
.
cache_implementation
=
None
def
get_candidates
(
self
,
input_ids
:
torch
.
LongTensor
)
->
Tuple
[
torch
.
LongTensor
,
Optional
[
torch
.
FloatTensor
]]:
"""
Fetches the candidates to be tried for the current input.
...
...
src/transformers/generation/utils.py
View file @
ef177a5e
...
...
@@ -1779,6 +1779,20 @@ class GenerationMixin:
cache_name
=
"cache_params"
else
:
cache_name
=
"past_key_values"
# TODO(joao): support static caches in assisted generation. assisted generation needs to roll back caches,
# which is only supported in dynamic caches atm
if
(
assistant_model
is
not
None
and
generation_config
.
cache_implementation
is
not
None
and
self
.
_supports_default_dynamic_cache
()
):
logger
.
warning_once
(
"An assistant model is provided, using a dynamic cache instead of a cache of type="
f
"'
{
generation_config
.
cache_implementation
}
'."
)
generation_config
.
cache_implementation
=
None
if
(
model_kwargs
.
get
(
cache_name
)
is
not
None
)
and
is_torchdynamo_compiling
():
raise
ValueError
(
"Passing `past_key_values` is not supported when compiling `model.generate` with torch.compile -- you "
...
...
src/transformers/models/gemma2/modeling_gemma2.py
View file @
ef177a5e
...
...
@@ -27,7 +27,7 @@ from torch import nn
from
torch.nn
import
BCEWithLogitsLoss
,
CrossEntropyLoss
,
MSELoss
from
...activations
import
ACT2FN
from
...cache_utils
import
Cache
from
...cache_utils
import
Cache
,
HybridCache
from
...modeling_outputs
import
(
BaseModelOutputWithPast
,
CausalLMOutputWithPast
,
...
...
@@ -591,10 +591,9 @@ class Gemma2PreTrainedModel(PreTrainedModel):
_skip_keys_device_placement
=
[
"past_key_values"
]
_supports_flash_attn_2
=
True
_supports_sdpa
=
True
_supports_cache_class
=
Fals
e
_supports_cache_class
=
Tru
e
_supports_quantized_cache
=
False
_supports_static_cache
=
True
_is_stateful
=
True
def
_init_weights
(
self
,
module
):
std
=
self
.
config
.
initializer_range
...
...
@@ -841,7 +840,7 @@ class Gemma2Model(Gemma2PreTrainedModel):
dtype
,
device
=
input_tensor
.
dtype
,
input_tensor
.
device
min_dtype
=
torch
.
finfo
(
dtype
).
min
sequence_length
=
input_tensor
.
shape
[
1
]
if
past_key_values
is
not
None
:
if
isinstance
(
past_key_values
,
HybridCache
)
:
target_length
=
past_key_values
.
get_max_length
()
else
:
target_length
=
attention_mask
.
shape
[
-
1
]
if
attention_mask
is
not
None
else
input_tensor
.
shape
[
1
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment