Commit 9d889f87 (unverified)
Authored May 16, 2024 by Joao Gante; committed May 16, 2024 by GitHub

Cache: add new flag to distinguish models that `Cache` but not static cache (#30800)

* jamba cache
* new flag
* generate exception

Parent: 17cc71e1
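For readers skimming the diff, the sketch below is not part of this commit; it only illustrates the distinction the new flag draws, using two of the classes changed here. Models that work with `StaticCache` set both `_supports_cache_class` and the new `_supports_static_cache`; models that only handle the dynamic `Cache` API set just `_supports_cache_class`.

# Illustrative only: inspect the class-level flags this commit touches.
from transformers import LlamaForCausalLM, MistralForCausalLM

print(LlamaForCausalLM._supports_cache_class)     # True
print(LlamaForCausalLM._supports_static_cache)    # True  -> added in modeling_llama.py below
print(MistralForCausalLM._supports_cache_class)   # True  -> added in modeling_mistral.py below
print(MistralForCausalLM._supports_static_cache)  # False -> default inherited from PreTrainedModel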
Showing 19 changed files with 23 additions and 3 deletions (+23, -3)
src/transformers/generation/utils.py                                  +5 -0
src/transformers/modeling_utils.py                                    +2 -1
src/transformers/models/cohere/modeling_cohere.py                     +1 -0
src/transformers/models/dbrx/modeling_dbrx.py                         +1 -0
src/transformers/models/gemma/modeling_gemma.py                       +1 -0
src/transformers/models/idefics2/modeling_idefics2.py                 +1 -0
src/transformers/models/jamba/modeling_jamba.py                       +1 -0
src/transformers/models/llama/modeling_llama.py                       +1 -0
src/transformers/models/mistral/modeling_mistral.py                   +1 -0
src/transformers/models/mixtral/modeling_mixtral.py                   +1 -0
src/transformers/models/olmo/modeling_olmo.py                         +1 -0
src/transformers/models/persimmon/modeling_persimmon.py               +1 -0
src/transformers/models/phi/modeling_phi.py                           +1 -0
src/transformers/models/phi3/modeling_phi3.py                         +1 -0
src/transformers/models/qwen2/modeling_qwen2.py                       +1 -0
src/transformers/models/qwen2_moe/modeling_qwen2_moe.py               +1 -0
src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py   +0 -1
src/transformers/models/starcoder2/modeling_starcoder2.py             +1 -0
tests/test_modeling_common.py                                         +1 -1
src/transformers/generation/utils.py

@@ -1616,6 +1616,11 @@ class GenerationMixin:
                     "issue: https://github.com/huggingface/transformers/issues/28981."
                 )
         if generation_config.cache_implementation == "static":
+            if not self._supports_static_cache:
+                raise ValueError(
+                    "This model does not support `cache_implementation='static'`. Please check the following "
+                    "issue: https://github.com/huggingface/transformers/issues/28981"
+                )
             model_kwargs["past_key_values"] = self._get_static_cache(batch_size, generation_config.max_length)
 
         self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)
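As a rough usage sketch (not from the commit; the checkpoint name is only an example), requesting the static cache on a model whose class does not set `_supports_static_cache` now fails early with the error message added above, instead of breaking later inside the static-cache code path:

from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "mistralai/Mistral-7B-v0.1"  # example: its class sets _supports_cache_class but not _supports_static_cache
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

inputs = tokenizer("Hello", return_tensors="pt")
try:
    model.generate(**inputs, max_new_tokens=8, cache_implementation="static")
except ValueError as err:
    # "This model does not support `cache_implementation='static'`. Please check the following issue: ..."
    print(err)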
src/transformers/modeling_utils.py

@@ -1280,8 +1280,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin
     # SDPA support
     _supports_sdpa = False
 
-    # Has support for a `Cache` instance as `past_key_values`
+    # Has support for a `Cache` instance as `past_key_values`? Does it support a `StaticCache`?
     _supports_cache_class = False
+    _supports_static_cache = False
 
     @property
     def dummy_inputs(self) -> Dict[str, torch.Tensor]:
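A minimal sketch (the `My*` classes are hypothetical, not from this commit) of how a model author is expected to set the new pair of class attributes defined above, mirroring the per-model changes that follow:

from transformers import PretrainedConfig, PreTrainedModel

class MyConfig(PretrainedConfig):
    model_type = "my-model"

class MyPreTrainedModel(PreTrainedModel):
    config_class = MyConfig

    # Has support for a `Cache` instance as `past_key_values`? Does it support a `StaticCache`?
    _supports_cache_class = True    # the model accepts a `Cache` object as past_key_values
    _supports_static_cache = False  # leave False unless generation with StaticCache actually works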
src/transformers/models/cohere/modeling_cohere.py

@@ -720,6 +720,7 @@ class CoherePreTrainedModel(PreTrainedModel):
     _supports_flash_attn_2 = True
     _supports_sdpa = True
     _supports_cache_class = True
+    _supports_static_cache = True
 
     def _init_weights(self, module):
         std = self.config.initializer_range
src/transformers/models/dbrx/modeling_dbrx.py

@@ -938,6 +938,7 @@ class DbrxPreTrainedModel(PreTrainedModel):
     _supports_flash_attn_2 = True
     _supports_sdpa = True
     _supports_cache_class = True
+    _supports_static_cache = True
 
     def _init_weights(self, module: nn.Module):
         std = self.config.initializer_range
src/transformers/models/gemma/modeling_gemma.py

@@ -703,6 +703,7 @@ class GemmaPreTrainedModel(PreTrainedModel):
     _supports_flash_attn_2 = True
     _supports_sdpa = True
     _supports_cache_class = True
+    _supports_static_cache = True
 
     def _init_weights(self, module):
         std = self.config.initializer_range
src/transformers/models/idefics2/modeling_idefics2.py

@@ -1341,6 +1341,7 @@ class Idefics2PreTrainedModel(PreTrainedModel):
     _no_split_modules = ["Idefics2VisionAttention", "Idefics2MLP", "Idefics2PerceiverLayer", "Idefics2DecoderLayer"]
     _skip_keys_device_placement = "past_key_values"
     _supports_flash_attn_2 = True
+    _supports_cache_class = True
 
     def _init_weights(self, module):
         # important: this ported version of Idefics2 isn't meant for training from scratch - only
src/transformers/models/jamba/modeling_jamba.py

@@ -1261,6 +1261,7 @@ class JambaPreTrainedModel(PreTrainedModel):
     _skip_keys_device_placement = "past_key_values"
     _supports_flash_attn_2 = True
     _supports_sdpa = True
+    _supports_cache_class = True  # Note: only supports HybridMambaAttentionDynamicCache
 
     def _init_weights(self, module):
         std = self.config.initializer_range
src/transformers/models/llama/modeling_llama.py

@@ -799,6 +799,7 @@ class LlamaPreTrainedModel(PreTrainedModel):
     _supports_flash_attn_2 = True
     _supports_sdpa = True
     _supports_cache_class = True
+    _supports_static_cache = True
 
     def _init_weights(self, module):
         std = self.config.initializer_range
src/transformers/models/mistral/modeling_mistral.py

@@ -810,6 +810,7 @@ class MistralPreTrainedModel(PreTrainedModel):
     _skip_keys_device_placement = "past_key_values"
     _supports_flash_attn_2 = True
     _supports_sdpa = True
+    _supports_cache_class = True
 
     def _init_weights(self, module):
         std = self.config.initializer_range
src/transformers/models/mixtral/modeling_mixtral.py

@@ -989,6 +989,7 @@ class MixtralPreTrainedModel(PreTrainedModel):
     _skip_keys_device_placement = "past_key_values"
     _supports_flash_attn_2 = True
     _supports_sdpa = True
+    _supports_cache_class = True
 
     def _init_weights(self, module):
         std = self.config.initializer_range
src/transformers/models/olmo/modeling_olmo.py

@@ -776,6 +776,7 @@ class OlmoPreTrainedModel(PreTrainedModel):
     _supports_flash_attn_2 = True
     _supports_sdpa = True
     _supports_cache_class = True
+    _supports_static_cache = True
 
     def _init_weights(self, module):
         std = self.config.initializer_range
src/transformers/models/persimmon/modeling_persimmon.py

@@ -457,6 +457,7 @@ class PersimmonPreTrainedModel(PreTrainedModel):
     supports_gradient_checkpointing = True
     _no_split_modules = ["PersimmonDecoderLayer"]
     _skip_keys_device_placement = "past_key_values"
+    _supports_cache_class = True
 
     def _init_weights(self, module):
         std = self.config.initializer_range
src/transformers/models/phi/modeling_phi.py

@@ -825,6 +825,7 @@ class PhiPreTrainedModel(PreTrainedModel):
     _skip_keys_device_placement = "past_key_values"
     _supports_flash_attn_2 = True
     _supports_sdpa = True
+    _supports_cache_class = True
 
     def _init_weights(self, module):
         std = self.config.initializer_range
src/transformers/models/phi3/modeling_phi3.py

@@ -921,6 +921,7 @@ class Phi3PreTrainedModel(PreTrainedModel):
     _skip_keys_device_placement = "past_key_values"
     _supports_flash_attn_2 = True
     _supports_sdpa = False
+    _supports_cache_class = True
 
     _version = "0.0.5"
src/transformers/models/qwen2/modeling_qwen2.py

@@ -821,6 +821,7 @@ class Qwen2PreTrainedModel(PreTrainedModel):
     _skip_keys_device_placement = "past_key_values"
     _supports_flash_attn_2 = True
     _supports_sdpa = True
+    _supports_cache_class = True
 
     def _init_weights(self, module):
         std = self.config.initializer_range
src/transformers/models/qwen2_moe/modeling_qwen2_moe.py

@@ -975,6 +975,7 @@ class Qwen2MoePreTrainedModel(PreTrainedModel):
     _skip_keys_device_placement = "past_key_values"
     _supports_flash_attn_2 = True
     _supports_sdpa = True
+    _supports_cache_class = True
 
     def _init_weights(self, module):
         std = self.config.initializer_range
src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py

@@ -541,7 +541,6 @@ class RecurrentGemmaPreTrainedModel(PreTrainedModel):
     _skip_keys_device_placement = ["cache"]
     _supports_flash_attn_2 = False
     _supports_sdpa = False  # we can't compare with eager for now
-    _supports_cache_class = True
 
     def _init_weights(self, module):
         std = math.sqrt(self.config.w_init_variance_scale / self.config.conv1d_width)
src/transformers/models/starcoder2/modeling_starcoder2.py

@@ -799,6 +799,7 @@ class Starcoder2PreTrainedModel(PreTrainedModel):
     _skip_keys_device_placement = "past_key_values"
     _supports_flash_attn_2 = True
     _supports_sdpa = True
+    _supports_cache_class = True
 
     def _init_weights(self, module):
         std = self.config.initializer_range
tests/test_modeling_common.py

@@ -4365,7 +4365,7 @@ class ModelTesterMixin:
             self.skipTest("Model architecture has no generative classes, and thus not necessarily supporting 4D masks")
 
         for model_class in self.all_generative_model_classes:
-            if not model_class._supports_cache_class:
+            if not model_class._supports_static_cache:
                 self.skipTest(f"{model_class.__name__} is not guaranteed to work with custom 4D attention masks")
 
             config, _ = self.model_tester.prepare_config_and_inputs_for_common()
             model = model_class(config).to(device=torch_device, dtype=torch.float32)