transformers · Commit f26e4073 (unverified)

Cache: models return input cache type (#30716)

Authored May 08, 2024 by Joao Gante; committed by GitHub, May 08, 2024
Parent commit: 71c19850
Showing 11 changed files with 30 additions and 70 deletions (+30 −70)
src/transformers/models/cohere/modeling_cohere.py              +6 −5
src/transformers/models/dbrx/modeling_dbrx.py                  +6 −7
src/transformers/models/gemma/modeling_gemma.py                +6 −7
src/transformers/models/llama/modeling_llama.py                +6 −7
src/transformers/models/olmo/modeling_olmo.py                  +6 −7
tests/models/cohere/test_modeling_cohere.py                    +0 −7
tests/models/dbrx/test_modeling_dbrx.py                        +0 −7
tests/models/gemma/test_modeling_gemma.py                      +0 −6
tests/models/llama/test_modeling_llama.py                      +0 −5
tests/models/olmo/test_modeling_olmo.py                        +0 −5
tests/models/recurrent_gemma/test_modeling_recurrent_gemma.py  +0 −7
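In short: each model's `forward` now returns `past_key_values` in whatever format the caller supplied, instead of always converting back to the legacy tuple format. A minimal sketch of that contract, assuming a transformers build at this commit (the checkpoint path is a placeholder, not from this commit):

```python
# Sketch of the contract this commit establishes: the cache comes back in
# the same format it went in. The checkpoint path is a placeholder.
import torch
from transformers import AutoModelForCausalLM
from transformers.cache_utils import DynamicCache

model = AutoModelForCausalLM.from_pretrained("path/to/llama-checkpoint")
input_ids = torch.tensor([[1, 2, 3]])

# Passing a `Cache` instance: the returned cache keeps that type.
out = model(input_ids, past_key_values=DynamicCache(), use_cache=True)
assert isinstance(out.past_key_values, DynamicCache)

# Passing nothing (or a legacy tuple of (key, value) tensors): the legacy
# tuple format is still returned, via the `return_legacy_cache` branch below.
out = model(input_ids, use_cache=True)
assert isinstance(out.past_key_values, tuple)
```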
src/transformers/models/cohere/modeling_cohere.py

```diff
@@ -881,7 +881,9 @@ class CohereModel(CoherePreTrainedModel):
             inputs_embeds = self.embed_tokens(input_ids)
 
-        past_seen_tokens = 0
+        return_legacy_cache = False
         if use_cache and not isinstance(past_key_values, Cache):  # kept for BC (non `Cache` `past_key_values` inputs)
+            return_legacy_cache = True
             past_key_values = DynamicCache.from_legacy_cache(past_key_values)
 
         if cache_position is None:
@@ -943,11 +945,10 @@ class CohereModel(CoherePreTrainedModel):
         if output_hidden_states:
             all_hidden_states += (hidden_states,)
 
-        next_cache = None
-        if use_cache:
-            next_cache = (
-                next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache
-            )
+        next_cache = next_decoder_cache if use_cache else None
+        if return_legacy_cache:
+            next_cache = next_cache.to_legacy_cache()
 
         if not return_dict:
             return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
         return BaseModelOutputWithPast(
```
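The same two-part pattern repeats in each modeling file below; only the class names and the `isinstance` target in the removed code differ. Pulled out of `forward` for illustration (the helper names here are ours; the models inline this logic rather than calling named helpers):

```python
# Illustrative helpers mirroring the diff above.
from transformers.cache_utils import Cache, DynamicCache

def normalize_cache(past_key_values, use_cache):
    """Input side: upgrade legacy tuple caches to `DynamicCache` and
    remember whether the caller used the legacy format."""
    return_legacy_cache = False
    if use_cache and not isinstance(past_key_values, Cache):
        # kept for BC (non `Cache` `past_key_values` inputs)
        return_legacy_cache = True
        past_key_values = DynamicCache.from_legacy_cache(past_key_values)
    return past_key_values, return_legacy_cache

def denormalize_cache(next_decoder_cache, use_cache, return_legacy_cache):
    """Output side: convert back to the legacy tuple format only when
    that is what the caller originally passed in."""
    next_cache = next_decoder_cache if use_cache else None
    if return_legacy_cache:
        next_cache = next_cache.to_legacy_cache()
    return next_cache
```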
src/transformers/models/dbrx/modeling_dbrx.py

```diff
@@ -1115,7 +1115,9 @@ class DbrxModel(DbrxPreTrainedModel):
         inputs_embeds = nn.functional.dropout(inputs_embeds, p=self.emb_pdrop, training=self.training)
 
+        return_legacy_cache = False
         if use_cache and not isinstance(past_key_values, Cache):  # kept for BC (non `Cache` `past_key_values` inputs)
+            return_legacy_cache = True
             past_key_values = DynamicCache.from_legacy_cache(past_key_values)
 
         if cache_position is None:
@@ -1182,13 +1184,10 @@ class DbrxModel(DbrxPreTrainedModel):
         if output_hidden_states:
             all_hidden_states += (hidden_states,)
 
-        next_cache = None
-        if use_cache:
-            next_cache = (
-                next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, DynamicCache) else next_decoder_cache
-            )
+        next_cache = next_decoder_cache if use_cache else None
+        if return_legacy_cache:
+            next_cache = next_cache.to_legacy_cache()
 
         if not return_dict:
             return tuple(
                 v
```
src/transformers/models/gemma/modeling_gemma.py

```diff
@@ -865,7 +865,9 @@ class GemmaModel(GemmaPreTrainedModel):
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
 
+        return_legacy_cache = False
         if use_cache and not isinstance(past_key_values, Cache):  # kept for BC (non `Cache` `past_key_values` inputs)
+            return_legacy_cache = True
             past_key_values = DynamicCache.from_legacy_cache(past_key_values)
 
         if cache_position is None:
@@ -933,13 +935,10 @@ class GemmaModel(GemmaPreTrainedModel):
         if output_hidden_states:
             all_hidden_states += (hidden_states,)
 
-        next_cache = None
-        if use_cache:
-            next_cache = (
-                next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, DynamicCache) else next_decoder_cache
-            )
+        next_cache = next_decoder_cache if use_cache else None
+        if return_legacy_cache:
+            next_cache = next_cache.to_legacy_cache()
 
         if not return_dict:
             return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
         return BaseModelOutputWithPast(
```
src/transformers/models/llama/modeling_llama.py

```diff
@@ -960,7 +960,9 @@ class LlamaModel(LlamaPreTrainedModel):
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
 
+        return_legacy_cache = False
         if use_cache and not isinstance(past_key_values, Cache):  # kept for BC (non `Cache` `past_key_values` inputs)
+            return_legacy_cache = True
             past_key_values = DynamicCache.from_legacy_cache(past_key_values)
 
         if cache_position is None:
@@ -1021,13 +1023,10 @@ class LlamaModel(LlamaPreTrainedModel):
         if output_hidden_states:
             all_hidden_states += (hidden_states,)
 
-        next_cache = None
-        if use_cache:
-            next_cache = (
-                next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, DynamicCache) else next_decoder_cache
-            )
+        next_cache = next_decoder_cache if use_cache else None
+        if return_legacy_cache:
+            next_cache = next_cache.to_legacy_cache()
 
         if not return_dict:
             return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
         return BaseModelOutputWithPast(
```
src/transformers/models/olmo/modeling_olmo.py

```diff
@@ -938,7 +938,9 @@ class OlmoModel(OlmoPreTrainedModel):
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
 
+        return_legacy_cache = False
         if use_cache and not isinstance(past_key_values, Cache):  # kept for BC (non `Cache` `past_key_values` inputs)
+            return_legacy_cache = True
             past_key_values = DynamicCache.from_legacy_cache(past_key_values)
 
         if cache_position is None:
@@ -999,13 +1001,10 @@ class OlmoModel(OlmoPreTrainedModel):
         if output_hidden_states:
             all_hidden_states += (hidden_states,)
 
-        next_cache = None
-        if use_cache:
-            next_cache = (
-                next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, DynamicCache) else next_decoder_cache
-            )
+        next_cache = next_decoder_cache if use_cache else None
+        if return_legacy_cache:
+            next_cache = next_cache.to_legacy_cache()
 
         if not return_dict:
             return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
         return BaseModelOutputWithPast(
```
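For reference, the legacy format these conversions target is a tuple over layers of `(key, value)` tensor pairs. A quick round-trip through the two conversion methods used in the diffs (shapes are illustrative):

```python
import torch
from transformers.cache_utils import DynamicCache

# Legacy cache: tuple over layers of (key, value) tensors shaped
# (batch, num_heads, seq_len, head_dim); the shapes here are made up.
legacy = tuple((torch.zeros(1, 2, 5, 4), torch.zeros(1, 2, 5, 4)) for _ in range(3))

cache = DynamicCache.from_legacy_cache(legacy)  # used on the input side
round_tripped = cache.to_legacy_cache()         # used on the output side

assert len(round_tripped) == len(legacy)
assert round_tripped[0][0].shape == legacy[0][0].shape
```

Because the round-trip is lossless, a single `return_legacy_cache` flag is enough state to restore the caller's format on the way out.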
tests/models/cohere/test_modeling_cohere.py

```diff
@@ -16,8 +16,6 @@
 import unittest
 
-from parameterized import parameterized
-
 from transformers import CohereConfig, is_torch_available
 from transformers.testing_utils import (
     require_bitsandbytes,
@@ -296,11 +294,6 @@ class CohereModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
     def test_config(self):
         self.config_tester.run_common_tests()
 
-    @unittest.skip("TODO @gante fix this for Cohere")
-    @parameterized.expand([(1, False), (1, True), (4, False)])
-    def test_new_cache_format(self, num_beams, do_sample):
-        pass
-
     def test_model(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_model(*config_and_inputs)
```
tests/models/dbrx/test_modeling_dbrx.py

```diff
@@ -17,8 +17,6 @@
 import unittest
 
-from parameterized import parameterized
-
 from transformers import DbrxConfig, is_torch_available
 from transformers.testing_utils import require_torch, slow, torch_device
@@ -357,11 +355,6 @@ class DbrxModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
     def test_tied_weights_keys(self):
         pass
 
-    @unittest.skip("TODO @gante fix this for Llama")
-    @parameterized.expand([(1, False), (1, True), (4, False)])
-    def test_new_cache_format(self, num_beams, do_sample):
-        pass
-
 
 @require_torch
 class DbrxModelIntegrationTest(unittest.TestCase):
```
tests/models/gemma/test_modeling_gemma.py

```diff
@@ -17,7 +17,6 @@ import tempfile
 import unittest
 
 import pytest
-from parameterized import parameterized
 
 from transformers import AutoModelForCausalLM, AutoTokenizer, GemmaConfig, is_torch_available
 from transformers.testing_utils import (
@@ -367,11 +366,6 @@ class GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
         result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
         self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
 
-    @unittest.skip("TODO @gante fix this for Llama")
-    @parameterized.expand([(1, False), (1, True), (4, False)])
-    def test_new_cache_format(self, num_beams, do_sample):
-        pass
-
     @unittest.skip("Gemma buffers include complex numbers, which breaks this test")
     def test_save_load_fast_init_from_base(self):
         pass
```
tests/models/llama/test_modeling_llama.py

```diff
@@ -591,11 +591,6 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
             msg=f"\n{tokenizer.batch_decode(res_eager)}\nvs\n{tokenizer.batch_decode(res_sdpa)}",
         )
 
-    @unittest.skip("TODO @gante fix this for Llama")
-    @parameterized.expand([(1, False), (1, True), (4, False)])
-    def test_new_cache_format(self, num_beams, do_sample):
-        pass
-
 
 @require_torch_gpu
 class LlamaIntegrationTest(unittest.TestCase):
```
tests/models/olmo/test_modeling_olmo.py

```diff
@@ -353,11 +353,6 @@ class OlmoModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
         # The output should be different for long inputs
         self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5))
 
-    @unittest.skip("TODO @gante fix this for OLMo")
-    @parameterized.expand([(1, False), (1, True), (4, False)])
-    def test_new_cache_format(self, num_beams, do_sample):
-        pass
-
 
 @require_torch
 class OlmoIntegrationTest(unittest.TestCase):
```
tests/models/recurrent_gemma/test_modeling_recurrent_gemma.py

```diff
@@ -15,8 +15,6 @@
 """ Testing suite for the PyTorch RecurrentGemma model. """
 import unittest
 
-from parameterized import parameterized
-
 from transformers import AutoModelForCausalLM, AutoTokenizer, RecurrentGemmaConfig, is_torch_available, set_seed
 from transformers.testing_utils import (
     require_bitsandbytes,
@@ -330,11 +328,6 @@ class RecurrentGemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineT
             config_and_inputs[0].position_embedding_type = type
             self.model_tester.create_and_check_model(*config_and_inputs)
 
-    @unittest.skip("Recurrent gemma does not use legacy cache")
-    @parameterized.expand([(1, False), (1, True), (4, False)])
-    def test_new_cache_format(self, num_beams, do_sample):
-        pass
-
     def test_save_load_fast_init_from_base(self):
         pass
```
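Each test file above drops a skipped per-model `test_new_cache_format` override, along with the `parameterized` import where it is no longer used. As a hypothetical illustration of the equivalence the cache refactor is meant to guarantee — this is not the repository's actual test, and the checkpoint path is a placeholder:

```python
import torch
from transformers import AutoModelForCausalLM
from transformers.cache_utils import DynamicCache

model = AutoModelForCausalLM.from_pretrained("path/to/tiny-test-checkpoint")
model.eval()
prompt = torch.tensor([[1, 2, 3, 4]])

with torch.no_grad():
    # Prefill once; at this commit the default return is the legacy format.
    legacy_cache = model(prompt, use_cache=True).past_key_values
    cache_obj = DynamicCache.from_legacy_cache(legacy_cache)

    # Continuing with either cache format should produce identical logits.
    next_token = torch.tensor([[5]])
    logits_legacy = model(next_token, past_key_values=legacy_cache, use_cache=True).logits
    logits_cache = model(next_token, past_key_values=cache_obj, use_cache=True).logits

torch.testing.assert_close(logits_legacy, logits_cache)
```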