Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
f26e4073
Unverified
Commit
f26e4073
authored
May 08, 2024
by
Joao Gante
Committed by
GitHub
May 08, 2024
Browse files
Cache: models return input cache type (#30716)
parent
71c19850
Changes
11
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
30 additions
and
70 deletions
+30
-70
src/transformers/models/cohere/modeling_cohere.py
src/transformers/models/cohere/modeling_cohere.py
+6
-5
src/transformers/models/dbrx/modeling_dbrx.py
src/transformers/models/dbrx/modeling_dbrx.py
+6
-7
src/transformers/models/gemma/modeling_gemma.py
src/transformers/models/gemma/modeling_gemma.py
+6
-7
src/transformers/models/llama/modeling_llama.py
src/transformers/models/llama/modeling_llama.py
+6
-7
src/transformers/models/olmo/modeling_olmo.py
src/transformers/models/olmo/modeling_olmo.py
+6
-7
tests/models/cohere/test_modeling_cohere.py
tests/models/cohere/test_modeling_cohere.py
+0
-7
tests/models/dbrx/test_modeling_dbrx.py
tests/models/dbrx/test_modeling_dbrx.py
+0
-7
tests/models/gemma/test_modeling_gemma.py
tests/models/gemma/test_modeling_gemma.py
+0
-6
tests/models/llama/test_modeling_llama.py
tests/models/llama/test_modeling_llama.py
+0
-5
tests/models/olmo/test_modeling_olmo.py
tests/models/olmo/test_modeling_olmo.py
+0
-5
tests/models/recurrent_gemma/test_modeling_recurrent_gemma.py
...s/models/recurrent_gemma/test_modeling_recurrent_gemma.py
+0
-7
No files found.
src/transformers/models/cohere/modeling_cohere.py
View file @
f26e4073
...
...
@@ -881,7 +881,9 @@ class CohereModel(CoherePreTrainedModel):
inputs_embeds
=
self
.
embed_tokens
(
input_ids
)
past_seen_tokens
=
0
return_legacy_cache
=
False
if
use_cache
and
not
isinstance
(
past_key_values
,
Cache
):
# kept for BC (non `Cache` `past_key_values` inputs)
return_legacy_cache
=
True
past_key_values
=
DynamicCache
.
from_legacy_cache
(
past_key_values
)
if
cache_position
is
None
:
...
...
@@ -943,11 +945,10 @@ class CohereModel(CoherePreTrainedModel):
if
output_hidden_states
:
all_hidden_states
+=
(
hidden_states
,)
next_cache
=
None
if
use_cache
:
next_cache
=
(
next_decoder_cache
.
to_legacy_cache
()
if
isinstance
(
next_decoder_cache
,
Cache
)
else
next_decoder_cache
)
next_cache
=
next_decoder_cache
if
use_cache
else
None
if
return_legacy_cache
:
next_cache
=
next_cache
.
to_legacy_cache
()
if
not
return_dict
:
return
tuple
(
v
for
v
in
[
hidden_states
,
next_cache
,
all_hidden_states
,
all_self_attns
]
if
v
is
not
None
)
return
BaseModelOutputWithPast
(
...
...
src/transformers/models/dbrx/modeling_dbrx.py
View file @
f26e4073
...
...
@@ -1115,7 +1115,9 @@ class DbrxModel(DbrxPreTrainedModel):
inputs_embeds
=
nn
.
functional
.
dropout
(
inputs_embeds
,
p
=
self
.
emb_pdrop
,
training
=
self
.
training
)
return_legacy_cache
=
False
if
use_cache
and
not
isinstance
(
past_key_values
,
Cache
):
# kept for BC (non `Cache` `past_key_values` inputs)
return_legacy_cache
=
True
past_key_values
=
DynamicCache
.
from_legacy_cache
(
past_key_values
)
if
cache_position
is
None
:
...
...
@@ -1182,13 +1184,10 @@ class DbrxModel(DbrxPreTrainedModel):
if
output_hidden_states
:
all_hidden_states
+=
(
hidden_states
,)
next_cache
=
None
if
use_cache
:
next_cache
=
(
next_decoder_cache
.
to_legacy_cache
()
if
isinstance
(
next_decoder_cache
,
DynamicCache
)
else
next_decoder_cache
)
next_cache
=
next_decoder_cache
if
use_cache
else
None
if
return_legacy_cache
:
next_cache
=
next_cache
.
to_legacy_cache
()
if
not
return_dict
:
return
tuple
(
v
...
...
src/transformers/models/gemma/modeling_gemma.py
View file @
f26e4073
...
...
@@ -865,7 +865,9 @@ class GemmaModel(GemmaPreTrainedModel):
if
inputs_embeds
is
None
:
inputs_embeds
=
self
.
embed_tokens
(
input_ids
)
return_legacy_cache
=
False
if
use_cache
and
not
isinstance
(
past_key_values
,
Cache
):
# kept for BC (non `Cache` `past_key_values` inputs)
return_legacy_cache
=
True
past_key_values
=
DynamicCache
.
from_legacy_cache
(
past_key_values
)
if
cache_position
is
None
:
...
...
@@ -933,13 +935,10 @@ class GemmaModel(GemmaPreTrainedModel):
if
output_hidden_states
:
all_hidden_states
+=
(
hidden_states
,)
next_cache
=
None
if
use_cache
:
next_cache
=
(
next_decoder_cache
.
to_legacy_cache
()
if
isinstance
(
next_decoder_cache
,
DynamicCache
)
else
next_decoder_cache
)
next_cache
=
next_decoder_cache
if
use_cache
else
None
if
return_legacy_cache
:
next_cache
=
next_cache
.
to_legacy_cache
()
if
not
return_dict
:
return
tuple
(
v
for
v
in
[
hidden_states
,
next_cache
,
all_hidden_states
,
all_self_attns
]
if
v
is
not
None
)
return
BaseModelOutputWithPast
(
...
...
src/transformers/models/llama/modeling_llama.py
View file @
f26e4073
...
...
@@ -960,7 +960,9 @@ class LlamaModel(LlamaPreTrainedModel):
if
inputs_embeds
is
None
:
inputs_embeds
=
self
.
embed_tokens
(
input_ids
)
return_legacy_cache
=
False
if
use_cache
and
not
isinstance
(
past_key_values
,
Cache
):
# kept for BC (non `Cache` `past_key_values` inputs)
return_legacy_cache
=
True
past_key_values
=
DynamicCache
.
from_legacy_cache
(
past_key_values
)
if
cache_position
is
None
:
...
...
@@ -1021,13 +1023,10 @@ class LlamaModel(LlamaPreTrainedModel):
if
output_hidden_states
:
all_hidden_states
+=
(
hidden_states
,)
next_cache
=
None
if
use_cache
:
next_cache
=
(
next_decoder_cache
.
to_legacy_cache
()
if
isinstance
(
next_decoder_cache
,
DynamicCache
)
else
next_decoder_cache
)
next_cache
=
next_decoder_cache
if
use_cache
else
None
if
return_legacy_cache
:
next_cache
=
next_cache
.
to_legacy_cache
()
if
not
return_dict
:
return
tuple
(
v
for
v
in
[
hidden_states
,
next_cache
,
all_hidden_states
,
all_self_attns
]
if
v
is
not
None
)
return
BaseModelOutputWithPast
(
...
...
src/transformers/models/olmo/modeling_olmo.py
View file @
f26e4073
...
...
@@ -938,7 +938,9 @@ class OlmoModel(OlmoPreTrainedModel):
if
inputs_embeds
is
None
:
inputs_embeds
=
self
.
embed_tokens
(
input_ids
)
return_legacy_cache
=
False
if
use_cache
and
not
isinstance
(
past_key_values
,
Cache
):
# kept for BC (non `Cache` `past_key_values` inputs)
return_legacy_cache
=
True
past_key_values
=
DynamicCache
.
from_legacy_cache
(
past_key_values
)
if
cache_position
is
None
:
...
...
@@ -999,13 +1001,10 @@ class OlmoModel(OlmoPreTrainedModel):
if
output_hidden_states
:
all_hidden_states
+=
(
hidden_states
,)
next_cache
=
None
if
use_cache
:
next_cache
=
(
next_decoder_cache
.
to_legacy_cache
()
if
isinstance
(
next_decoder_cache
,
DynamicCache
)
else
next_decoder_cache
)
next_cache
=
next_decoder_cache
if
use_cache
else
None
if
return_legacy_cache
:
next_cache
=
next_cache
.
to_legacy_cache
()
if
not
return_dict
:
return
tuple
(
v
for
v
in
[
hidden_states
,
next_cache
,
all_hidden_states
,
all_self_attns
]
if
v
is
not
None
)
return
BaseModelOutputWithPast
(
...
...
tests/models/cohere/test_modeling_cohere.py
View file @
f26e4073
...
...
@@ -16,8 +16,6 @@
import
unittest
from
parameterized
import
parameterized
from
transformers
import
CohereConfig
,
is_torch_available
from
transformers.testing_utils
import
(
require_bitsandbytes
,
...
...
@@ -296,11 +294,6 @@ class CohereModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
def
test_config
(
self
):
self
.
config_tester
.
run_common_tests
()
@
unittest
.
skip
(
"TODO @gante fix this for Cohere"
)
@
parameterized
.
expand
([(
1
,
False
),
(
1
,
True
),
(
4
,
False
)])
def
test_new_cache_format
(
self
,
num_beams
,
do_sample
):
pass
def
test_model
(
self
):
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
self
.
model_tester
.
create_and_check_model
(
*
config_and_inputs
)
...
...
tests/models/dbrx/test_modeling_dbrx.py
View file @
f26e4073
...
...
@@ -17,8 +17,6 @@
import
unittest
from
parameterized
import
parameterized
from
transformers
import
DbrxConfig
,
is_torch_available
from
transformers.testing_utils
import
require_torch
,
slow
,
torch_device
...
...
@@ -357,11 +355,6 @@ class DbrxModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
def
test_tied_weights_keys
(
self
):
pass
@
unittest
.
skip
(
"TODO @gante fix this for Llama"
)
@
parameterized
.
expand
([(
1
,
False
),
(
1
,
True
),
(
4
,
False
)])
def
test_new_cache_format
(
self
,
num_beams
,
do_sample
):
pass
@
require_torch
class
DbrxModelIntegrationTest
(
unittest
.
TestCase
):
...
...
tests/models/gemma/test_modeling_gemma.py
View file @
f26e4073
...
...
@@ -17,7 +17,6 @@ import tempfile
import
unittest
import
pytest
from
parameterized
import
parameterized
from
transformers
import
AutoModelForCausalLM
,
AutoTokenizer
,
GemmaConfig
,
is_torch_available
from
transformers.testing_utils
import
(
...
...
@@ -367,11 +366,6 @@ class GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
result
=
model
(
input_ids
,
attention_mask
=
attention_mask
,
labels
=
sequence_labels
)
self
.
assertEqual
(
result
.
logits
.
shape
,
(
self
.
model_tester
.
batch_size
,
self
.
model_tester
.
num_labels
))
@
unittest
.
skip
(
"TODO @gante fix this for Llama"
)
@
parameterized
.
expand
([(
1
,
False
),
(
1
,
True
),
(
4
,
False
)])
def
test_new_cache_format
(
self
,
num_beams
,
do_sample
):
pass
@
unittest
.
skip
(
"Gemma buffers include complex numbers, which breaks this test"
)
def
test_save_load_fast_init_from_base
(
self
):
pass
...
...
tests/models/llama/test_modeling_llama.py
View file @
f26e4073
...
...
@@ -591,11 +591,6 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
msg
=
f
"
\n
{
tokenizer
.
batch_decode
(
res_eager
)
}
\n
vs
\n
{
tokenizer
.
batch_decode
(
res_sdpa
)
}
"
,
)
@
unittest
.
skip
(
"TODO @gante fix this for Llama"
)
@
parameterized
.
expand
([(
1
,
False
),
(
1
,
True
),
(
4
,
False
)])
def
test_new_cache_format
(
self
,
num_beams
,
do_sample
):
pass
@
require_torch_gpu
class
LlamaIntegrationTest
(
unittest
.
TestCase
):
...
...
tests/models/olmo/test_modeling_olmo.py
View file @
f26e4073
...
...
@@ -353,11 +353,6 @@ class OlmoModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
# The output should be different for long inputs
self
.
assertFalse
(
torch
.
allclose
(
original_long_output
,
scaled_long_output
,
atol
=
1e-5
))
@
unittest
.
skip
(
"TODO @gante fix this for OLMo"
)
@
parameterized
.
expand
([(
1
,
False
),
(
1
,
True
),
(
4
,
False
)])
def
test_new_cache_format
(
self
,
num_beams
,
do_sample
):
pass
@
require_torch
class
OlmoIntegrationTest
(
unittest
.
TestCase
):
...
...
tests/models/recurrent_gemma/test_modeling_recurrent_gemma.py
View file @
f26e4073
...
...
@@ -15,8 +15,6 @@
""" Testing suite for the PyTorch RecurrentGemma model. """
import
unittest
from
parameterized
import
parameterized
from
transformers
import
AutoModelForCausalLM
,
AutoTokenizer
,
RecurrentGemmaConfig
,
is_torch_available
,
set_seed
from
transformers.testing_utils
import
(
require_bitsandbytes
,
...
...
@@ -330,11 +328,6 @@ class RecurrentGemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineT
config_and_inputs
[
0
].
position_embedding_type
=
type
self
.
model_tester
.
create_and_check_model
(
*
config_and_inputs
)
@
unittest
.
skip
(
"Recurrent gemma does not use legacy cache"
)
@
parameterized
.
expand
([(
1
,
False
),
(
1
,
True
),
(
4
,
False
)])
def
test_new_cache_format
(
self
,
num_beams
,
do_sample
):
pass
def
test_save_load_fast_init_from_base
(
self
):
pass
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment