Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
9d889f87
Unverified
Commit
9d889f87
authored
May 16, 2024
by
Joao Gante
Committed by
GitHub
May 16, 2024
Browse files
Cache: add new flag to distinguish models that support `Cache` but not static cache (#30800)
* jamba cache * new flag * generate exception
parent
17cc71e1
Changes
19
Hide whitespace changes
Inline
Side-by-side
Showing
19 changed files
with
23 additions
and
3 deletions
+23
-3
src/transformers/generation/utils.py
src/transformers/generation/utils.py
+5
-0
src/transformers/modeling_utils.py
src/transformers/modeling_utils.py
+2
-1
src/transformers/models/cohere/modeling_cohere.py
src/transformers/models/cohere/modeling_cohere.py
+1
-0
src/transformers/models/dbrx/modeling_dbrx.py
src/transformers/models/dbrx/modeling_dbrx.py
+1
-0
src/transformers/models/gemma/modeling_gemma.py
src/transformers/models/gemma/modeling_gemma.py
+1
-0
src/transformers/models/idefics2/modeling_idefics2.py
src/transformers/models/idefics2/modeling_idefics2.py
+1
-0
src/transformers/models/jamba/modeling_jamba.py
src/transformers/models/jamba/modeling_jamba.py
+1
-0
src/transformers/models/llama/modeling_llama.py
src/transformers/models/llama/modeling_llama.py
+1
-0
src/transformers/models/mistral/modeling_mistral.py
src/transformers/models/mistral/modeling_mistral.py
+1
-0
src/transformers/models/mixtral/modeling_mixtral.py
src/transformers/models/mixtral/modeling_mixtral.py
+1
-0
src/transformers/models/olmo/modeling_olmo.py
src/transformers/models/olmo/modeling_olmo.py
+1
-0
src/transformers/models/persimmon/modeling_persimmon.py
src/transformers/models/persimmon/modeling_persimmon.py
+1
-0
src/transformers/models/phi/modeling_phi.py
src/transformers/models/phi/modeling_phi.py
+1
-0
src/transformers/models/phi3/modeling_phi3.py
src/transformers/models/phi3/modeling_phi3.py
+1
-0
src/transformers/models/qwen2/modeling_qwen2.py
src/transformers/models/qwen2/modeling_qwen2.py
+1
-0
src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
+1
-0
src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py
...ormers/models/recurrent_gemma/modeling_recurrent_gemma.py
+0
-1
src/transformers/models/starcoder2/modeling_starcoder2.py
src/transformers/models/starcoder2/modeling_starcoder2.py
+1
-0
tests/test_modeling_common.py
tests/test_modeling_common.py
+1
-1
No files found.
src/transformers/generation/utils.py
View file @
9d889f87
...
...
@@ -1616,6 +1616,11 @@ class GenerationMixin:
"issue: https://github.com/huggingface/transformers/issues/28981."
)
if
generation_config
.
cache_implementation
==
"static"
:
if
not
self
.
_supports_static_cache
:
raise
ValueError
(
"This model does not support `cache_implementation='static'`. Please check the following "
"issue: https://github.com/huggingface/transformers/issues/28981"
)
model_kwargs
[
"past_key_values"
]
=
self
.
_get_static_cache
(
batch_size
,
generation_config
.
max_length
)
self
.
_validate_generated_length
(
generation_config
,
input_ids_length
,
has_default_max_length
)
...
...
src/transformers/modeling_utils.py
View file @
9d889f87
...
...
@@ -1280,8 +1280,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
# SDPA support
_supports_sdpa
=
False
# Has support for a `Cache` instance as `past_key_values`
# Has support for a `Cache` instance as `past_key_values`? Does it support a `StaticCache`?
_supports_cache_class
=
False
_supports_static_cache
=
False
@
property
def
dummy_inputs
(
self
)
->
Dict
[
str
,
torch
.
Tensor
]:
...
...
src/transformers/models/cohere/modeling_cohere.py
View file @
9d889f87
...
...
@@ -720,6 +720,7 @@ class CoherePreTrainedModel(PreTrainedModel):
_supports_flash_attn_2
=
True
_supports_sdpa
=
True
_supports_cache_class
=
True
_supports_static_cache
=
True
def
_init_weights
(
self
,
module
):
std
=
self
.
config
.
initializer_range
...
...
src/transformers/models/dbrx/modeling_dbrx.py
View file @
9d889f87
...
...
@@ -938,6 +938,7 @@ class DbrxPreTrainedModel(PreTrainedModel):
_supports_flash_attn_2
=
True
_supports_sdpa
=
True
_supports_cache_class
=
True
_supports_static_cache
=
True
def
_init_weights
(
self
,
module
:
nn
.
Module
):
std
=
self
.
config
.
initializer_range
...
...
src/transformers/models/gemma/modeling_gemma.py
View file @
9d889f87
...
...
@@ -703,6 +703,7 @@ class GemmaPreTrainedModel(PreTrainedModel):
_supports_flash_attn_2
=
True
_supports_sdpa
=
True
_supports_cache_class
=
True
_supports_static_cache
=
True
def
_init_weights
(
self
,
module
):
std
=
self
.
config
.
initializer_range
...
...
src/transformers/models/idefics2/modeling_idefics2.py
View file @
9d889f87
...
...
@@ -1341,6 +1341,7 @@ class Idefics2PreTrainedModel(PreTrainedModel):
_no_split_modules
=
[
"Idefics2VisionAttention"
,
"Idefics2MLP"
,
"Idefics2PerceiverLayer"
,
"Idefics2DecoderLayer"
]
_skip_keys_device_placement
=
"past_key_values"
_supports_flash_attn_2
=
True
_supports_cache_class
=
True
def
_init_weights
(
self
,
module
):
# important: this ported version of Idefics2 isn't meant for training from scratch - only
...
...
src/transformers/models/jamba/modeling_jamba.py
View file @
9d889f87
...
...
@@ -1261,6 +1261,7 @@ class JambaPreTrainedModel(PreTrainedModel):
_skip_keys_device_placement
=
"past_key_values"
_supports_flash_attn_2
=
True
_supports_sdpa
=
True
_supports_cache_class
=
True
# Note: only supports HybridMambaAttentionDynamicCache
def
_init_weights
(
self
,
module
):
std
=
self
.
config
.
initializer_range
...
...
src/transformers/models/llama/modeling_llama.py
View file @
9d889f87
...
...
@@ -799,6 +799,7 @@ class LlamaPreTrainedModel(PreTrainedModel):
_supports_flash_attn_2
=
True
_supports_sdpa
=
True
_supports_cache_class
=
True
_supports_static_cache
=
True
def
_init_weights
(
self
,
module
):
std
=
self
.
config
.
initializer_range
...
...
src/transformers/models/mistral/modeling_mistral.py
View file @
9d889f87
...
...
@@ -810,6 +810,7 @@ class MistralPreTrainedModel(PreTrainedModel):
_skip_keys_device_placement
=
"past_key_values"
_supports_flash_attn_2
=
True
_supports_sdpa
=
True
_supports_cache_class
=
True
def
_init_weights
(
self
,
module
):
std
=
self
.
config
.
initializer_range
...
...
src/transformers/models/mixtral/modeling_mixtral.py
View file @
9d889f87
...
...
@@ -989,6 +989,7 @@ class MixtralPreTrainedModel(PreTrainedModel):
_skip_keys_device_placement
=
"past_key_values"
_supports_flash_attn_2
=
True
_supports_sdpa
=
True
_supports_cache_class
=
True
def
_init_weights
(
self
,
module
):
std
=
self
.
config
.
initializer_range
...
...
src/transformers/models/olmo/modeling_olmo.py
View file @
9d889f87
...
...
@@ -776,6 +776,7 @@ class OlmoPreTrainedModel(PreTrainedModel):
_supports_flash_attn_2
=
True
_supports_sdpa
=
True
_supports_cache_class
=
True
_supports_static_cache
=
True
def
_init_weights
(
self
,
module
):
std
=
self
.
config
.
initializer_range
...
...
src/transformers/models/persimmon/modeling_persimmon.py
View file @
9d889f87
...
...
@@ -457,6 +457,7 @@ class PersimmonPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing
=
True
_no_split_modules
=
[
"PersimmonDecoderLayer"
]
_skip_keys_device_placement
=
"past_key_values"
_supports_cache_class
=
True
def
_init_weights
(
self
,
module
):
std
=
self
.
config
.
initializer_range
...
...
src/transformers/models/phi/modeling_phi.py
View file @
9d889f87
...
...
@@ -825,6 +825,7 @@ class PhiPreTrainedModel(PreTrainedModel):
_skip_keys_device_placement
=
"past_key_values"
_supports_flash_attn_2
=
True
_supports_sdpa
=
True
_supports_cache_class
=
True
def
_init_weights
(
self
,
module
):
std
=
self
.
config
.
initializer_range
...
...
src/transformers/models/phi3/modeling_phi3.py
View file @
9d889f87
...
...
@@ -921,6 +921,7 @@ class Phi3PreTrainedModel(PreTrainedModel):
_skip_keys_device_placement
=
"past_key_values"
_supports_flash_attn_2
=
True
_supports_sdpa
=
False
_supports_cache_class
=
True
_version
=
"0.0.5"
...
...
src/transformers/models/qwen2/modeling_qwen2.py
View file @
9d889f87
...
...
@@ -821,6 +821,7 @@ class Qwen2PreTrainedModel(PreTrainedModel):
_skip_keys_device_placement
=
"past_key_values"
_supports_flash_attn_2
=
True
_supports_sdpa
=
True
_supports_cache_class
=
True
def
_init_weights
(
self
,
module
):
std
=
self
.
config
.
initializer_range
...
...
src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
View file @
9d889f87
...
...
@@ -975,6 +975,7 @@ class Qwen2MoePreTrainedModel(PreTrainedModel):
_skip_keys_device_placement
=
"past_key_values"
_supports_flash_attn_2
=
True
_supports_sdpa
=
True
_supports_cache_class
=
True
def
_init_weights
(
self
,
module
):
std
=
self
.
config
.
initializer_range
...
...
src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py
View file @
9d889f87
...
...
@@ -541,7 +541,6 @@ class RecurrentGemmaPreTrainedModel(PreTrainedModel):
_skip_keys_device_placement
=
[
"cache"
]
_supports_flash_attn_2
=
False
_supports_sdpa
=
False
# we can't compare with eager for now
_supports_cache_class
=
True
def
_init_weights
(
self
,
module
):
std
=
math
.
sqrt
(
self
.
config
.
w_init_variance_scale
/
self
.
config
.
conv1d_width
)
...
...
src/transformers/models/starcoder2/modeling_starcoder2.py
View file @
9d889f87
...
...
@@ -799,6 +799,7 @@ class Starcoder2PreTrainedModel(PreTrainedModel):
_skip_keys_device_placement
=
"past_key_values"
_supports_flash_attn_2
=
True
_supports_sdpa
=
True
_supports_cache_class
=
True
def
_init_weights
(
self
,
module
):
std
=
self
.
config
.
initializer_range
...
...
tests/test_modeling_common.py
View file @
9d889f87
...
...
@@ -4365,7 +4365,7 @@ class ModelTesterMixin:
self
.
skipTest
(
"Model architecture has no generative classes, and thus not necessarily supporting 4D masks"
)
for
model_class
in
self
.
all_generative_model_classes
:
if
not
model_class
.
_supports_
cache_class
:
if
not
model_class
.
_supports_
static_cache
:
self
.
skipTest
(
f
"
{
model_class
.
__name__
}
is not guaranteed to work with custom 4D attention masks"
)
config
,
_
=
self
.
model_tester
.
prepare_config_and_inputs_for_common
()
model
=
model_class
(
config
).
to
(
device
=
torch_device
,
dtype
=
torch
.
float32
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment