Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
f26e4073
Unverified
Commit
f26e4073
authored
May 08, 2024
by
Joao Gante
Committed by
GitHub
May 08, 2024
Browse files
Cache: models return input cache type (#30716)
parent
71c19850
Changes
11
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
30 additions
and
70 deletions
+30
-70
src/transformers/models/cohere/modeling_cohere.py
src/transformers/models/cohere/modeling_cohere.py
+6
-5
src/transformers/models/dbrx/modeling_dbrx.py
src/transformers/models/dbrx/modeling_dbrx.py
+6
-7
src/transformers/models/gemma/modeling_gemma.py
src/transformers/models/gemma/modeling_gemma.py
+6
-7
src/transformers/models/llama/modeling_llama.py
src/transformers/models/llama/modeling_llama.py
+6
-7
src/transformers/models/olmo/modeling_olmo.py
src/transformers/models/olmo/modeling_olmo.py
+6
-7
tests/models/cohere/test_modeling_cohere.py
tests/models/cohere/test_modeling_cohere.py
+0
-7
tests/models/dbrx/test_modeling_dbrx.py
tests/models/dbrx/test_modeling_dbrx.py
+0
-7
tests/models/gemma/test_modeling_gemma.py
tests/models/gemma/test_modeling_gemma.py
+0
-6
tests/models/llama/test_modeling_llama.py
tests/models/llama/test_modeling_llama.py
+0
-5
tests/models/olmo/test_modeling_olmo.py
tests/models/olmo/test_modeling_olmo.py
+0
-5
tests/models/recurrent_gemma/test_modeling_recurrent_gemma.py
...s/models/recurrent_gemma/test_modeling_recurrent_gemma.py
+0
-7
No files found.
src/transformers/models/cohere/modeling_cohere.py
View file @
f26e4073
...
@@ -881,7 +881,9 @@ class CohereModel(CoherePreTrainedModel):
...
@@ -881,7 +881,9 @@ class CohereModel(CoherePreTrainedModel):
inputs_embeds
=
self
.
embed_tokens
(
input_ids
)
inputs_embeds
=
self
.
embed_tokens
(
input_ids
)
past_seen_tokens
=
0
past_seen_tokens
=
0
return_legacy_cache
=
False
if
use_cache
and
not
isinstance
(
past_key_values
,
Cache
):
# kept for BC (non `Cache` `past_key_values` inputs)
if
use_cache
and
not
isinstance
(
past_key_values
,
Cache
):
# kept for BC (non `Cache` `past_key_values` inputs)
return_legacy_cache
=
True
past_key_values
=
DynamicCache
.
from_legacy_cache
(
past_key_values
)
past_key_values
=
DynamicCache
.
from_legacy_cache
(
past_key_values
)
if
cache_position
is
None
:
if
cache_position
is
None
:
...
@@ -943,11 +945,10 @@ class CohereModel(CoherePreTrainedModel):
...
@@ -943,11 +945,10 @@ class CohereModel(CoherePreTrainedModel):
if
output_hidden_states
:
if
output_hidden_states
:
all_hidden_states
+=
(
hidden_states
,)
all_hidden_states
+=
(
hidden_states
,)
next_cache
=
None
next_cache
=
next_decoder_cache
if
use_cache
else
None
if
use_cache
:
if
return_legacy_cache
:
next_cache
=
(
next_cache
=
next_cache
.
to_legacy_cache
()
next_decoder_cache
.
to_legacy_cache
()
if
isinstance
(
next_decoder_cache
,
Cache
)
else
next_decoder_cache
)
if
not
return_dict
:
if
not
return_dict
:
return
tuple
(
v
for
v
in
[
hidden_states
,
next_cache
,
all_hidden_states
,
all_self_attns
]
if
v
is
not
None
)
return
tuple
(
v
for
v
in
[
hidden_states
,
next_cache
,
all_hidden_states
,
all_self_attns
]
if
v
is
not
None
)
return
BaseModelOutputWithPast
(
return
BaseModelOutputWithPast
(
...
...
src/transformers/models/dbrx/modeling_dbrx.py
View file @
f26e4073
...
@@ -1115,7 +1115,9 @@ class DbrxModel(DbrxPreTrainedModel):
...
@@ -1115,7 +1115,9 @@ class DbrxModel(DbrxPreTrainedModel):
inputs_embeds
=
nn
.
functional
.
dropout
(
inputs_embeds
,
p
=
self
.
emb_pdrop
,
training
=
self
.
training
)
inputs_embeds
=
nn
.
functional
.
dropout
(
inputs_embeds
,
p
=
self
.
emb_pdrop
,
training
=
self
.
training
)
return_legacy_cache
=
False
if
use_cache
and
not
isinstance
(
past_key_values
,
Cache
):
# kept for BC (non `Cache` `past_key_values` inputs)
if
use_cache
and
not
isinstance
(
past_key_values
,
Cache
):
# kept for BC (non `Cache` `past_key_values` inputs)
return_legacy_cache
=
True
past_key_values
=
DynamicCache
.
from_legacy_cache
(
past_key_values
)
past_key_values
=
DynamicCache
.
from_legacy_cache
(
past_key_values
)
if
cache_position
is
None
:
if
cache_position
is
None
:
...
@@ -1182,13 +1184,10 @@ class DbrxModel(DbrxPreTrainedModel):
...
@@ -1182,13 +1184,10 @@ class DbrxModel(DbrxPreTrainedModel):
if
output_hidden_states
:
if
output_hidden_states
:
all_hidden_states
+=
(
hidden_states
,)
all_hidden_states
+=
(
hidden_states
,)
next_cache
=
None
next_cache
=
next_decoder_cache
if
use_cache
else
None
if
use_cache
:
if
return_legacy_cache
:
next_cache
=
(
next_cache
=
next_cache
.
to_legacy_cache
()
next_decoder_cache
.
to_legacy_cache
()
if
isinstance
(
next_decoder_cache
,
DynamicCache
)
else
next_decoder_cache
)
if
not
return_dict
:
if
not
return_dict
:
return
tuple
(
return
tuple
(
v
v
...
...
src/transformers/models/gemma/modeling_gemma.py
View file @
f26e4073
...
@@ -865,7 +865,9 @@ class GemmaModel(GemmaPreTrainedModel):
...
@@ -865,7 +865,9 @@ class GemmaModel(GemmaPreTrainedModel):
if
inputs_embeds
is
None
:
if
inputs_embeds
is
None
:
inputs_embeds
=
self
.
embed_tokens
(
input_ids
)
inputs_embeds
=
self
.
embed_tokens
(
input_ids
)
return_legacy_cache
=
False
if
use_cache
and
not
isinstance
(
past_key_values
,
Cache
):
# kept for BC (non `Cache` `past_key_values` inputs)
if
use_cache
and
not
isinstance
(
past_key_values
,
Cache
):
# kept for BC (non `Cache` `past_key_values` inputs)
return_legacy_cache
=
True
past_key_values
=
DynamicCache
.
from_legacy_cache
(
past_key_values
)
past_key_values
=
DynamicCache
.
from_legacy_cache
(
past_key_values
)
if
cache_position
is
None
:
if
cache_position
is
None
:
...
@@ -933,13 +935,10 @@ class GemmaModel(GemmaPreTrainedModel):
...
@@ -933,13 +935,10 @@ class GemmaModel(GemmaPreTrainedModel):
if
output_hidden_states
:
if
output_hidden_states
:
all_hidden_states
+=
(
hidden_states
,)
all_hidden_states
+=
(
hidden_states
,)
next_cache
=
None
next_cache
=
next_decoder_cache
if
use_cache
else
None
if
use_cache
:
if
return_legacy_cache
:
next_cache
=
(
next_cache
=
next_cache
.
to_legacy_cache
()
next_decoder_cache
.
to_legacy_cache
()
if
isinstance
(
next_decoder_cache
,
DynamicCache
)
else
next_decoder_cache
)
if
not
return_dict
:
if
not
return_dict
:
return
tuple
(
v
for
v
in
[
hidden_states
,
next_cache
,
all_hidden_states
,
all_self_attns
]
if
v
is
not
None
)
return
tuple
(
v
for
v
in
[
hidden_states
,
next_cache
,
all_hidden_states
,
all_self_attns
]
if
v
is
not
None
)
return
BaseModelOutputWithPast
(
return
BaseModelOutputWithPast
(
...
...
src/transformers/models/llama/modeling_llama.py
View file @
f26e4073
...
@@ -960,7 +960,9 @@ class LlamaModel(LlamaPreTrainedModel):
...
@@ -960,7 +960,9 @@ class LlamaModel(LlamaPreTrainedModel):
if
inputs_embeds
is
None
:
if
inputs_embeds
is
None
:
inputs_embeds
=
self
.
embed_tokens
(
input_ids
)
inputs_embeds
=
self
.
embed_tokens
(
input_ids
)
return_legacy_cache
=
False
if
use_cache
and
not
isinstance
(
past_key_values
,
Cache
):
# kept for BC (non `Cache` `past_key_values` inputs)
if
use_cache
and
not
isinstance
(
past_key_values
,
Cache
):
# kept for BC (non `Cache` `past_key_values` inputs)
return_legacy_cache
=
True
past_key_values
=
DynamicCache
.
from_legacy_cache
(
past_key_values
)
past_key_values
=
DynamicCache
.
from_legacy_cache
(
past_key_values
)
if
cache_position
is
None
:
if
cache_position
is
None
:
...
@@ -1021,13 +1023,10 @@ class LlamaModel(LlamaPreTrainedModel):
...
@@ -1021,13 +1023,10 @@ class LlamaModel(LlamaPreTrainedModel):
if
output_hidden_states
:
if
output_hidden_states
:
all_hidden_states
+=
(
hidden_states
,)
all_hidden_states
+=
(
hidden_states
,)
next_cache
=
None
next_cache
=
next_decoder_cache
if
use_cache
else
None
if
use_cache
:
if
return_legacy_cache
:
next_cache
=
(
next_cache
=
next_cache
.
to_legacy_cache
()
next_decoder_cache
.
to_legacy_cache
()
if
isinstance
(
next_decoder_cache
,
DynamicCache
)
else
next_decoder_cache
)
if
not
return_dict
:
if
not
return_dict
:
return
tuple
(
v
for
v
in
[
hidden_states
,
next_cache
,
all_hidden_states
,
all_self_attns
]
if
v
is
not
None
)
return
tuple
(
v
for
v
in
[
hidden_states
,
next_cache
,
all_hidden_states
,
all_self_attns
]
if
v
is
not
None
)
return
BaseModelOutputWithPast
(
return
BaseModelOutputWithPast
(
...
...
src/transformers/models/olmo/modeling_olmo.py
View file @
f26e4073
...
@@ -938,7 +938,9 @@ class OlmoModel(OlmoPreTrainedModel):
...
@@ -938,7 +938,9 @@ class OlmoModel(OlmoPreTrainedModel):
if
inputs_embeds
is
None
:
if
inputs_embeds
is
None
:
inputs_embeds
=
self
.
embed_tokens
(
input_ids
)
inputs_embeds
=
self
.
embed_tokens
(
input_ids
)
return_legacy_cache
=
False
if
use_cache
and
not
isinstance
(
past_key_values
,
Cache
):
# kept for BC (non `Cache` `past_key_values` inputs)
if
use_cache
and
not
isinstance
(
past_key_values
,
Cache
):
# kept for BC (non `Cache` `past_key_values` inputs)
return_legacy_cache
=
True
past_key_values
=
DynamicCache
.
from_legacy_cache
(
past_key_values
)
past_key_values
=
DynamicCache
.
from_legacy_cache
(
past_key_values
)
if
cache_position
is
None
:
if
cache_position
is
None
:
...
@@ -999,13 +1001,10 @@ class OlmoModel(OlmoPreTrainedModel):
...
@@ -999,13 +1001,10 @@ class OlmoModel(OlmoPreTrainedModel):
if
output_hidden_states
:
if
output_hidden_states
:
all_hidden_states
+=
(
hidden_states
,)
all_hidden_states
+=
(
hidden_states
,)
next_cache
=
None
next_cache
=
next_decoder_cache
if
use_cache
else
None
if
use_cache
:
if
return_legacy_cache
:
next_cache
=
(
next_cache
=
next_cache
.
to_legacy_cache
()
next_decoder_cache
.
to_legacy_cache
()
if
isinstance
(
next_decoder_cache
,
DynamicCache
)
else
next_decoder_cache
)
if
not
return_dict
:
if
not
return_dict
:
return
tuple
(
v
for
v
in
[
hidden_states
,
next_cache
,
all_hidden_states
,
all_self_attns
]
if
v
is
not
None
)
return
tuple
(
v
for
v
in
[
hidden_states
,
next_cache
,
all_hidden_states
,
all_self_attns
]
if
v
is
not
None
)
return
BaseModelOutputWithPast
(
return
BaseModelOutputWithPast
(
...
...
tests/models/cohere/test_modeling_cohere.py
View file @
f26e4073
...
@@ -16,8 +16,6 @@
...
@@ -16,8 +16,6 @@
import
unittest
import
unittest
from
parameterized
import
parameterized
from
transformers
import
CohereConfig
,
is_torch_available
from
transformers
import
CohereConfig
,
is_torch_available
from
transformers.testing_utils
import
(
from
transformers.testing_utils
import
(
require_bitsandbytes
,
require_bitsandbytes
,
...
@@ -296,11 +294,6 @@ class CohereModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
...
@@ -296,11 +294,6 @@ class CohereModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
def
test_config
(
self
):
def
test_config
(
self
):
self
.
config_tester
.
run_common_tests
()
self
.
config_tester
.
run_common_tests
()
@
unittest
.
skip
(
"TODO @gante fix this for Cohere"
)
@
parameterized
.
expand
([(
1
,
False
),
(
1
,
True
),
(
4
,
False
)])
def
test_new_cache_format
(
self
,
num_beams
,
do_sample
):
pass
def
test_model
(
self
):
def
test_model
(
self
):
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
self
.
model_tester
.
create_and_check_model
(
*
config_and_inputs
)
self
.
model_tester
.
create_and_check_model
(
*
config_and_inputs
)
...
...
tests/models/dbrx/test_modeling_dbrx.py
View file @
f26e4073
...
@@ -17,8 +17,6 @@
...
@@ -17,8 +17,6 @@
import
unittest
import
unittest
from
parameterized
import
parameterized
from
transformers
import
DbrxConfig
,
is_torch_available
from
transformers
import
DbrxConfig
,
is_torch_available
from
transformers.testing_utils
import
require_torch
,
slow
,
torch_device
from
transformers.testing_utils
import
require_torch
,
slow
,
torch_device
...
@@ -357,11 +355,6 @@ class DbrxModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
...
@@ -357,11 +355,6 @@ class DbrxModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
def
test_tied_weights_keys
(
self
):
def
test_tied_weights_keys
(
self
):
pass
pass
@
unittest
.
skip
(
"TODO @gante fix this for Llama"
)
@
parameterized
.
expand
([(
1
,
False
),
(
1
,
True
),
(
4
,
False
)])
def
test_new_cache_format
(
self
,
num_beams
,
do_sample
):
pass
@
require_torch
@
require_torch
class
DbrxModelIntegrationTest
(
unittest
.
TestCase
):
class
DbrxModelIntegrationTest
(
unittest
.
TestCase
):
...
...
tests/models/gemma/test_modeling_gemma.py
View file @
f26e4073
...
@@ -17,7 +17,6 @@ import tempfile
...
@@ -17,7 +17,6 @@ import tempfile
import
unittest
import
unittest
import
pytest
import
pytest
from
parameterized
import
parameterized
from
transformers
import
AutoModelForCausalLM
,
AutoTokenizer
,
GemmaConfig
,
is_torch_available
from
transformers
import
AutoModelForCausalLM
,
AutoTokenizer
,
GemmaConfig
,
is_torch_available
from
transformers.testing_utils
import
(
from
transformers.testing_utils
import
(
...
@@ -367,11 +366,6 @@ class GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
...
@@ -367,11 +366,6 @@ class GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
result
=
model
(
input_ids
,
attention_mask
=
attention_mask
,
labels
=
sequence_labels
)
result
=
model
(
input_ids
,
attention_mask
=
attention_mask
,
labels
=
sequence_labels
)
self
.
assertEqual
(
result
.
logits
.
shape
,
(
self
.
model_tester
.
batch_size
,
self
.
model_tester
.
num_labels
))
self
.
assertEqual
(
result
.
logits
.
shape
,
(
self
.
model_tester
.
batch_size
,
self
.
model_tester
.
num_labels
))
@
unittest
.
skip
(
"TODO @gante fix this for Llama"
)
@
parameterized
.
expand
([(
1
,
False
),
(
1
,
True
),
(
4
,
False
)])
def
test_new_cache_format
(
self
,
num_beams
,
do_sample
):
pass
@
unittest
.
skip
(
"Gemma buffers include complex numbers, which breaks this test"
)
@
unittest
.
skip
(
"Gemma buffers include complex numbers, which breaks this test"
)
def
test_save_load_fast_init_from_base
(
self
):
def
test_save_load_fast_init_from_base
(
self
):
pass
pass
...
...
tests/models/llama/test_modeling_llama.py
View file @
f26e4073
...
@@ -591,11 +591,6 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
...
@@ -591,11 +591,6 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
msg
=
f
"
\n
{
tokenizer
.
batch_decode
(
res_eager
)
}
\n
vs
\n
{
tokenizer
.
batch_decode
(
res_sdpa
)
}
"
,
msg
=
f
"
\n
{
tokenizer
.
batch_decode
(
res_eager
)
}
\n
vs
\n
{
tokenizer
.
batch_decode
(
res_sdpa
)
}
"
,
)
)
@
unittest
.
skip
(
"TODO @gante fix this for Llama"
)
@
parameterized
.
expand
([(
1
,
False
),
(
1
,
True
),
(
4
,
False
)])
def
test_new_cache_format
(
self
,
num_beams
,
do_sample
):
pass
@
require_torch_gpu
@
require_torch_gpu
class
LlamaIntegrationTest
(
unittest
.
TestCase
):
class
LlamaIntegrationTest
(
unittest
.
TestCase
):
...
...
tests/models/olmo/test_modeling_olmo.py
View file @
f26e4073
...
@@ -353,11 +353,6 @@ class OlmoModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
...
@@ -353,11 +353,6 @@ class OlmoModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
# The output should be different for long inputs
# The output should be different for long inputs
self
.
assertFalse
(
torch
.
allclose
(
original_long_output
,
scaled_long_output
,
atol
=
1e-5
))
self
.
assertFalse
(
torch
.
allclose
(
original_long_output
,
scaled_long_output
,
atol
=
1e-5
))
@
unittest
.
skip
(
"TODO @gante fix this for OLMo"
)
@
parameterized
.
expand
([(
1
,
False
),
(
1
,
True
),
(
4
,
False
)])
def
test_new_cache_format
(
self
,
num_beams
,
do_sample
):
pass
@
require_torch
@
require_torch
class
OlmoIntegrationTest
(
unittest
.
TestCase
):
class
OlmoIntegrationTest
(
unittest
.
TestCase
):
...
...
tests/models/recurrent_gemma/test_modeling_recurrent_gemma.py
View file @
f26e4073
...
@@ -15,8 +15,6 @@
...
@@ -15,8 +15,6 @@
""" Testing suite for the PyTorch RecurrentGemma model. """
""" Testing suite for the PyTorch RecurrentGemma model. """
import
unittest
import
unittest
from
parameterized
import
parameterized
from
transformers
import
AutoModelForCausalLM
,
AutoTokenizer
,
RecurrentGemmaConfig
,
is_torch_available
,
set_seed
from
transformers
import
AutoModelForCausalLM
,
AutoTokenizer
,
RecurrentGemmaConfig
,
is_torch_available
,
set_seed
from
transformers.testing_utils
import
(
from
transformers.testing_utils
import
(
require_bitsandbytes
,
require_bitsandbytes
,
...
@@ -330,11 +328,6 @@ class RecurrentGemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineT
...
@@ -330,11 +328,6 @@ class RecurrentGemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineT
config_and_inputs
[
0
].
position_embedding_type
=
type
config_and_inputs
[
0
].
position_embedding_type
=
type
self
.
model_tester
.
create_and_check_model
(
*
config_and_inputs
)
self
.
model_tester
.
create_and_check_model
(
*
config_and_inputs
)
@
unittest
.
skip
(
"Recurrent gemma does not use legacy cache"
)
@
parameterized
.
expand
([(
1
,
False
),
(
1
,
True
),
(
4
,
False
)])
def
test_new_cache_format
(
self
,
num_beams
,
do_sample
):
pass
def
test_save_load_fast_init_from_base
(
self
):
def
test_save_load_fast_init_from_base
(
self
):
pass
pass
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment