Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
83259e40
Unverified
Commit
83259e40
authored
Jun 19, 2024
by
Joao Gante
Committed by
GitHub
Jun 19, 2024
Browse files
Mamba: add generative tests (#31478)
parent
7d683f7b
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
83 additions
and
56 deletions
+83
-56
src/transformers/generation/utils.py
src/transformers/generation/utils.py
+11
-0
src/transformers/modeling_utils.py
src/transformers/modeling_utils.py
+1
-0
src/transformers/models/jamba/modeling_jamba.py
src/transformers/models/jamba/modeling_jamba.py
+1
-0
src/transformers/models/mamba/modeling_mamba.py
src/transformers/models/mamba/modeling_mamba.py
+1
-0
src/transformers/models/rwkv/modeling_rwkv.py
src/transformers/models/rwkv/modeling_rwkv.py
+1
-0
tests/generation/test_utils.py
tests/generation/test_utils.py
+66
-40
tests/models/jamba/test_modeling_jamba.py
tests/models/jamba/test_modeling_jamba.py
+0
-4
tests/models/mamba/test_modeling_mamba.py
tests/models/mamba/test_modeling_mamba.py
+2
-12
No files found.
src/transformers/generation/utils.py
View file @
83259e40
...
@@ -1830,6 +1830,12 @@ class GenerationMixin:
...
@@ -1830,6 +1830,12 @@ class GenerationMixin:
raise
ValueError
(
"assisted generate requires `use_cache=True`"
)
raise
ValueError
(
"assisted generate requires `use_cache=True`"
)
if
generation_config
.
cache_implementation
==
"static"
:
if
generation_config
.
cache_implementation
==
"static"
:
raise
ValueError
(
"assisted generate is not supported with `static_cache`"
)
raise
ValueError
(
"assisted generate is not supported with `static_cache`"
)
if
self
.
_is_stateful
:
# In assisted generation we need the ability to confirm whether the model would pick certain tokens,
# which is not possible with stateful models (they can't reset to a previous subset of generated text)
raise
ValueError
(
f
"assisted generation is not supported with stateful models, such as
{
self
.
__class__
.
__name__
}
"
)
# 11. Get the candidate generator, given the parameterization
# 11. Get the candidate generator, given the parameterization
candidate_generator
=
self
.
_get_candidate_generator
(
candidate_generator
=
self
.
_get_candidate_generator
(
...
@@ -1867,6 +1873,11 @@ class GenerationMixin:
...
@@ -1867,6 +1873,11 @@ class GenerationMixin:
elif
generation_mode
==
GenerationMode
.
CONTRASTIVE_SEARCH
:
elif
generation_mode
==
GenerationMode
.
CONTRASTIVE_SEARCH
:
if
not
model_kwargs
[
"use_cache"
]:
if
not
model_kwargs
[
"use_cache"
]:
raise
ValueError
(
"Contrastive search requires `use_cache=True`"
)
raise
ValueError
(
"Contrastive search requires `use_cache=True`"
)
if
self
.
_is_stateful
:
# Just like assisted generation, we need to be able to rollback to a previous state (see comment above)
raise
ValueError
(
f
"contrastive search is not supported with stateful models, such as
{
self
.
__class__
.
__name__
}
"
)
result
=
self
.
_contrastive_search
(
result
=
self
.
_contrastive_search
(
input_ids
,
input_ids
,
...
...
src/transformers/modeling_utils.py
View file @
83259e40
...
@@ -1281,6 +1281,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
...
@@ -1281,6 +1281,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
is_parallelizable
=
False
is_parallelizable
=
False
supports_gradient_checkpointing
=
False
supports_gradient_checkpointing
=
False
_is_stateful
=
False
# Flash Attention 2 support
# Flash Attention 2 support
_supports_flash_attn_2
=
False
_supports_flash_attn_2
=
False
...
...
src/transformers/models/jamba/modeling_jamba.py
View file @
83259e40
...
@@ -1266,6 +1266,7 @@ class JambaPreTrainedModel(PreTrainedModel):
...
@@ -1266,6 +1266,7 @@ class JambaPreTrainedModel(PreTrainedModel):
_supports_flash_attn_2
=
True
_supports_flash_attn_2
=
True
_supports_sdpa
=
True
_supports_sdpa
=
True
_supports_cache_class
=
True
# Note: only supports HybridMambaAttentionDynamicCache
_supports_cache_class
=
True
# Note: only supports HybridMambaAttentionDynamicCache
_is_stateful
=
True
def
_init_weights
(
self
,
module
):
def
_init_weights
(
self
,
module
):
std
=
self
.
config
.
initializer_range
std
=
self
.
config
.
initializer_range
...
...
src/transformers/models/mamba/modeling_mamba.py
View file @
83259e40
...
@@ -354,6 +354,7 @@ class MambaPreTrainedModel(PreTrainedModel):
...
@@ -354,6 +354,7 @@ class MambaPreTrainedModel(PreTrainedModel):
base_model_prefix
=
"backbone"
base_model_prefix
=
"backbone"
_no_split_modules
=
[
"MambaBlock"
]
_no_split_modules
=
[
"MambaBlock"
]
supports_gradient_checkpointing
=
True
supports_gradient_checkpointing
=
True
_is_stateful
=
True
def
_init_weights
(
self
,
module
):
def
_init_weights
(
self
,
module
):
"""Initialize the weights."""
"""Initialize the weights."""
...
...
src/transformers/models/rwkv/modeling_rwkv.py
View file @
83259e40
...
@@ -394,6 +394,7 @@ class RwkvPreTrainedModel(PreTrainedModel):
...
@@ -394,6 +394,7 @@ class RwkvPreTrainedModel(PreTrainedModel):
_no_split_modules
=
[
"RwkvBlock"
]
_no_split_modules
=
[
"RwkvBlock"
]
_keep_in_fp32_modules
=
[
"time_decay"
,
"time_first"
]
_keep_in_fp32_modules
=
[
"time_decay"
,
"time_first"
]
supports_gradient_checkpointing
=
True
supports_gradient_checkpointing
=
True
_is_stateful
=
True
def
_init_weights
(
self
,
module
):
def
_init_weights
(
self
,
module
):
"""Initialize the weights."""
"""Initialize the weights."""
...
...
tests/generation/test_utils.py
View file @
83259e40
...
@@ -102,7 +102,11 @@ class GenerationTesterMixin:
...
@@ -102,7 +102,11 @@ class GenerationTesterMixin:
if
isinstance
(
config
.
eos_token_id
,
int
):
if
isinstance
(
config
.
eos_token_id
,
int
):
config
.
eos_token_id
=
[
config
.
eos_token_id
]
config
.
eos_token_id
=
[
config
.
eos_token_id
]
config
.
pad_token_id
=
config
.
eos_token_id
[
0
]
config
.
pad_token_id
=
config
.
eos_token_id
[
0
]
attention_mask
=
torch
.
ones_like
(
input_ids
,
dtype
=
torch
.
long
)
if
self
.
has_attentions
:
attention_mask
=
torch
.
ones_like
(
input_ids
,
dtype
=
torch
.
long
)
else
:
attention_mask
=
None
# It is important set set the eos_token_id to None to ensure that no sequences
# It is important set set the eos_token_id to None to ensure that no sequences
# shorter than `max_length` can be generated
# shorter than `max_length` can be generated
...
@@ -437,7 +441,7 @@ class GenerationTesterMixin:
...
@@ -437,7 +441,7 @@ class GenerationTesterMixin:
output_scores
=
True
,
output_scores
=
True
,
output_logits
=
True
,
output_logits
=
True
,
output_hidden_states
=
True
,
output_hidden_states
=
True
,
output_attentions
=
True
,
output_attentions
=
self
.
has_attentions
,
return_dict_in_generate
=
True
,
return_dict_in_generate
=
True
,
)
)
...
@@ -471,7 +475,7 @@ class GenerationTesterMixin:
...
@@ -471,7 +475,7 @@ class GenerationTesterMixin:
output_scores
=
True
,
output_scores
=
True
,
output_logits
=
True
,
output_logits
=
True
,
output_hidden_states
=
True
,
output_hidden_states
=
True
,
output_attentions
=
True
,
output_attentions
=
self
.
has_attentions
,
return_dict_in_generate
=
True
,
return_dict_in_generate
=
True
,
)
)
...
@@ -529,7 +533,7 @@ class GenerationTesterMixin:
...
@@ -529,7 +533,7 @@ class GenerationTesterMixin:
output_scores
=
True
,
output_scores
=
True
,
output_logits
=
True
,
output_logits
=
True
,
output_hidden_states
=
True
,
output_hidden_states
=
True
,
output_attentions
=
True
,
output_attentions
=
self
.
has_attentions
,
return_dict_in_generate
=
True
,
return_dict_in_generate
=
True
,
)
)
...
@@ -595,7 +599,7 @@ class GenerationTesterMixin:
...
@@ -595,7 +599,7 @@ class GenerationTesterMixin:
output_scores
=
True
,
output_scores
=
True
,
output_logits
=
True
,
output_logits
=
True
,
output_hidden_states
=
True
,
output_hidden_states
=
True
,
output_attentions
=
True
,
output_attentions
=
self
.
has_attentions
,
return_dict_in_generate
=
True
,
return_dict_in_generate
=
True
,
)
)
if
model
.
config
.
is_encoder_decoder
:
if
model
.
config
.
is_encoder_decoder
:
...
@@ -642,7 +646,7 @@ class GenerationTesterMixin:
...
@@ -642,7 +646,7 @@ class GenerationTesterMixin:
output_scores
=
True
,
output_scores
=
True
,
output_logits
=
True
,
output_logits
=
True
,
output_hidden_states
=
True
,
output_hidden_states
=
True
,
output_attentions
=
True
,
output_attentions
=
self
.
has_attentions
,
return_dict_in_generate
=
True
,
return_dict_in_generate
=
True
,
)
)
...
@@ -733,7 +737,7 @@ class GenerationTesterMixin:
...
@@ -733,7 +737,7 @@ class GenerationTesterMixin:
output_scores
=
True
,
output_scores
=
True
,
output_logits
=
True
,
output_logits
=
True
,
output_hidden_states
=
True
,
output_hidden_states
=
True
,
output_attentions
=
True
,
output_attentions
=
self
.
has_attentions
,
return_dict_in_generate
=
True
,
return_dict_in_generate
=
True
,
)
)
...
@@ -834,7 +838,7 @@ class GenerationTesterMixin:
...
@@ -834,7 +838,7 @@ class GenerationTesterMixin:
output_scores
=
True
,
output_scores
=
True
,
output_logits
=
True
,
output_logits
=
True
,
output_hidden_states
=
True
,
output_hidden_states
=
True
,
output_attentions
=
True
,
output_attentions
=
self
.
has_attentions
,
return_dict_in_generate
=
True
,
return_dict_in_generate
=
True
,
)
)
if
model
.
config
.
is_encoder_decoder
:
if
model
.
config
.
is_encoder_decoder
:
...
@@ -952,7 +956,7 @@ class GenerationTesterMixin:
...
@@ -952,7 +956,7 @@ class GenerationTesterMixin:
output_scores
=
True
,
output_scores
=
True
,
output_logits
=
True
,
output_logits
=
True
,
output_hidden_states
=
True
,
output_hidden_states
=
True
,
output_attentions
=
True
,
output_attentions
=
self
.
has_attentions
,
return_dict_in_generate
=
True
,
return_dict_in_generate
=
True
,
)
)
...
@@ -973,6 +977,9 @@ class GenerationTesterMixin:
...
@@ -973,6 +977,9 @@ class GenerationTesterMixin:
def
test_contrastive_generate
(
self
):
def
test_contrastive_generate
(
self
):
for
model_class
in
self
.
all_generative_model_classes
:
for
model_class
in
self
.
all_generative_model_classes
:
if
model_class
.
_is_stateful
:
self
.
skipTest
(
"Stateful models don't support contrastive search generation"
)
# won't fix: FSMT and Reformer have a different cache variable type (and format).
# won't fix: FSMT and Reformer have a different cache variable type (and format).
if
any
(
model_name
in
model_class
.
__name__
.
lower
()
for
model_name
in
[
"fsmt"
,
"reformer"
]):
if
any
(
model_name
in
model_class
.
__name__
.
lower
()
for
model_name
in
[
"fsmt"
,
"reformer"
]):
self
.
skipTest
(
"Won't fix: old model with different cache format"
)
self
.
skipTest
(
"Won't fix: old model with different cache format"
)
...
@@ -997,6 +1004,9 @@ class GenerationTesterMixin:
...
@@ -997,6 +1004,9 @@ class GenerationTesterMixin:
def
test_contrastive_generate_dict_outputs_use_cache
(
self
):
def
test_contrastive_generate_dict_outputs_use_cache
(
self
):
for
model_class
in
self
.
all_generative_model_classes
:
for
model_class
in
self
.
all_generative_model_classes
:
if
model_class
.
_is_stateful
:
self
.
skipTest
(
"Stateful models don't support contrastive search generation"
)
# won't fix: FSMT and Reformer have a different cache variable type (and format).
# won't fix: FSMT and Reformer have a different cache variable type (and format).
if
any
(
model_name
in
model_class
.
__name__
.
lower
()
for
model_name
in
[
"fsmt"
,
"reformer"
]):
if
any
(
model_name
in
model_class
.
__name__
.
lower
()
for
model_name
in
[
"fsmt"
,
"reformer"
]):
self
.
skipTest
(
"Won't fix: old model with different cache format"
)
self
.
skipTest
(
"Won't fix: old model with different cache format"
)
...
@@ -1017,7 +1027,7 @@ class GenerationTesterMixin:
...
@@ -1017,7 +1027,7 @@ class GenerationTesterMixin:
output_scores
=
True
,
output_scores
=
True
,
output_logits
=
True
,
output_logits
=
True
,
output_hidden_states
=
True
,
output_hidden_states
=
True
,
output_attentions
=
True
,
output_attentions
=
self
.
has_attentions
,
return_dict_in_generate
=
True
,
return_dict_in_generate
=
True
,
)
)
...
@@ -1030,9 +1040,12 @@ class GenerationTesterMixin:
...
@@ -1030,9 +1040,12 @@ class GenerationTesterMixin:
def
test_contrastive_generate_low_memory
(
self
):
def
test_contrastive_generate_low_memory
(
self
):
# Check that choosing 'low_memory' does not change the model output
# Check that choosing 'low_memory' does not change the model output
for
model_class
in
self
.
all_generative_model_classes
:
for
model_class
in
self
.
all_generative_model_classes
:
if
model_class
.
_is_stateful
:
self
.
skipTest
(
"Stateful models don't support contrastive search generation"
)
if
any
(
model_name
in
model_class
.
__name__
.
lower
()
for
model_name
in
[
"fsmt"
,
"reformer"
,
"speech2text"
]):
if
any
(
model_name
in
model_class
.
__name__
.
lower
()
for
model_name
in
[
"fsmt"
,
"reformer"
,
"speech2text"
]):
self
.
skipTest
(
"Won't fix: old model with different cache format"
)
self
.
skipTest
(
"Won't fix: old model with different cache format"
)
if
any
(
model_name
in
model_class
.
__name__
.
lower
()
for
model_name
in
[
"gptbigcode"
,
"jamba"
]):
if
any
(
model_name
in
model_class
.
__name__
.
lower
()
for
model_name
in
[
"gptbigcode"
]):
self
.
skipTest
(
"TODO: fix me"
)
self
.
skipTest
(
"TODO: fix me"
)
config
,
input_ids
,
attention_mask
=
self
.
_get_input_ids_and_config
(
batch_size
=
1
)
config
,
input_ids
,
attention_mask
=
self
.
_get_input_ids_and_config
(
batch_size
=
1
)
...
@@ -1069,6 +1082,8 @@ class GenerationTesterMixin:
...
@@ -1069,6 +1082,8 @@ class GenerationTesterMixin:
def
test_beam_search_low_memory
(
self
):
def
test_beam_search_low_memory
(
self
):
# Check that choosing 'low_memory' does not change the model output
# Check that choosing 'low_memory' does not change the model output
for
model_class
in
self
.
all_generative_model_classes
:
for
model_class
in
self
.
all_generative_model_classes
:
if
model_class
.
_is_stateful
:
self
.
skipTest
(
"May fix in the future: need custom cache handling"
)
if
any
(
model_name
in
model_class
.
__name__
.
lower
()
for
model_name
in
[
"fsmt"
,
"reformer"
]):
if
any
(
model_name
in
model_class
.
__name__
.
lower
()
for
model_name
in
[
"fsmt"
,
"reformer"
]):
self
.
skipTest
(
"Won't fix: old model with different cache format"
)
self
.
skipTest
(
"Won't fix: old model with different cache format"
)
if
any
(
if
any
(
...
@@ -1115,6 +1130,8 @@ class GenerationTesterMixin:
...
@@ -1115,6 +1130,8 @@ class GenerationTesterMixin:
# - assisted_decoding does not support `batch_size > 1`
# - assisted_decoding does not support `batch_size > 1`
for
model_class
in
self
.
all_generative_model_classes
:
for
model_class
in
self
.
all_generative_model_classes
:
if
model_class
.
_is_stateful
:
self
.
skipTest
(
"Stateful models don't support assisted generation"
)
if
any
(
model_name
in
model_class
.
__name__
.
lower
()
for
model_name
in
[
"fsmt"
,
"reformer"
]):
if
any
(
model_name
in
model_class
.
__name__
.
lower
()
for
model_name
in
[
"fsmt"
,
"reformer"
]):
self
.
skipTest
(
"Won't fix: old model with different cache format"
)
self
.
skipTest
(
"Won't fix: old model with different cache format"
)
if
any
(
if
any
(
...
@@ -1156,7 +1173,7 @@ class GenerationTesterMixin:
...
@@ -1156,7 +1173,7 @@ class GenerationTesterMixin:
"output_scores"
:
True
,
"output_scores"
:
True
,
"output_logits"
:
True
,
"output_logits"
:
True
,
"output_hidden_states"
:
True
,
"output_hidden_states"
:
True
,
"output_attentions"
:
True
,
"output_attentions"
:
self
.
has_attentions
,
"return_dict_in_generate"
:
True
,
"return_dict_in_generate"
:
True
,
}
}
output_greedy
=
model
.
generate
(
input_ids
,
attention_mask
=
attention_mask
,
**
generation_kwargs
)
output_greedy
=
model
.
generate
(
input_ids
,
attention_mask
=
attention_mask
,
**
generation_kwargs
)
...
@@ -1184,6 +1201,8 @@ class GenerationTesterMixin:
...
@@ -1184,6 +1201,8 @@ class GenerationTesterMixin:
# This test is mostly a copy of test_assisted_decoding_matches_greedy_search
# This test is mostly a copy of test_assisted_decoding_matches_greedy_search
for
model_class
in
self
.
all_generative_model_classes
:
for
model_class
in
self
.
all_generative_model_classes
:
if
model_class
.
_is_stateful
:
self
.
skipTest
(
"Stateful models don't support assisted generation"
)
if
any
(
model_name
in
model_class
.
__name__
.
lower
()
for
model_name
in
[
"fsmt"
,
"reformer"
]):
if
any
(
model_name
in
model_class
.
__name__
.
lower
()
for
model_name
in
[
"fsmt"
,
"reformer"
]):
self
.
skipTest
(
"Won't fix: old model with different cache format"
)
self
.
skipTest
(
"Won't fix: old model with different cache format"
)
if
any
(
if
any
(
...
@@ -1225,7 +1244,7 @@ class GenerationTesterMixin:
...
@@ -1225,7 +1244,7 @@ class GenerationTesterMixin:
"output_scores"
:
True
,
"output_scores"
:
True
,
"output_logits"
:
True
,
"output_logits"
:
True
,
"output_hidden_states"
:
True
,
"output_hidden_states"
:
True
,
"output_attentions"
:
True
,
"output_attentions"
:
self
.
has_attentions
,
"return_dict_in_generate"
:
True
,
"return_dict_in_generate"
:
True
,
}
}
...
@@ -1244,6 +1263,8 @@ class GenerationTesterMixin:
...
@@ -1244,6 +1263,8 @@ class GenerationTesterMixin:
# match sample for the same seed, as the forward pass does not return the exact same logits (due to matmul with
# match sample for the same seed, as the forward pass does not return the exact same logits (due to matmul with
# different shapes, see https://github.com/huggingface/transformers/issues/25420#issuecomment-1775317535).
# different shapes, see https://github.com/huggingface/transformers/issues/25420#issuecomment-1775317535).
for
model_class
in
self
.
all_generative_model_classes
:
for
model_class
in
self
.
all_generative_model_classes
:
if
model_class
.
_is_stateful
:
self
.
skipTest
(
"Stateful models don't support assisted generation"
)
if
any
(
model_name
in
model_class
.
__name__
.
lower
()
for
model_name
in
[
"fsmt"
,
"reformer"
]):
if
any
(
model_name
in
model_class
.
__name__
.
lower
()
for
model_name
in
[
"fsmt"
,
"reformer"
]):
self
.
skipTest
(
"Won't fix: old model with different cache format"
)
self
.
skipTest
(
"Won't fix: old model with different cache format"
)
if
any
(
if
any
(
...
@@ -1289,7 +1310,7 @@ class GenerationTesterMixin:
...
@@ -1289,7 +1310,7 @@ class GenerationTesterMixin:
"output_scores"
:
True
,
"output_scores"
:
True
,
"output_logits"
:
True
,
"output_logits"
:
True
,
"output_hidden_states"
:
True
,
"output_hidden_states"
:
True
,
"output_attentions"
:
True
,
"output_attentions"
:
self
.
has_attentions
,
"return_dict_in_generate"
:
True
,
"return_dict_in_generate"
:
True
,
}
}
output_assisted
=
model
.
generate
(
input_ids
,
attention_mask
=
attention_mask
,
**
generation_kwargs
)
output_assisted
=
model
.
generate
(
input_ids
,
attention_mask
=
attention_mask
,
**
generation_kwargs
)
...
@@ -1326,7 +1347,7 @@ class GenerationTesterMixin:
...
@@ -1326,7 +1347,7 @@ class GenerationTesterMixin:
input_ids
,
input_ids
,
attention_mask
=
attention_mask
,
attention_mask
=
attention_mask
,
num_beams
=
1
,
num_beams
=
1
,
output_attentions
=
True
,
output_attentions
=
self
.
has_attentions
,
return_dict_in_generate
=
True
,
return_dict_in_generate
=
True
,
remove_invalid_values
=
True
,
remove_invalid_values
=
True
,
**
{
name
:
mask
},
**
{
name
:
mask
},
...
@@ -1344,6 +1365,10 @@ class GenerationTesterMixin:
...
@@ -1344,6 +1365,10 @@ class GenerationTesterMixin:
if
len
(
self
.
all_generative_model_classes
)
==
0
:
if
len
(
self
.
all_generative_model_classes
)
==
0
:
self
.
skipTest
(
reason
=
"No generative architecture available for this model."
)
self
.
skipTest
(
reason
=
"No generative architecture available for this model."
)
# - The model must support padding
if
not
self
.
has_attentions
:
self
.
skipTest
(
reason
=
"This model doesn't support padding."
)
# - The model must be a decoder-only architecture (encoder-based architectures use right-padding)
# - The model must be a decoder-only architecture (encoder-based architectures use right-padding)
decoder_only_classes
=
[]
decoder_only_classes
=
[]
for
model_class
in
self
.
all_generative_model_classes
:
for
model_class
in
self
.
all_generative_model_classes
:
...
@@ -1704,30 +1729,31 @@ class GenerationTesterMixin:
...
@@ -1704,30 +1729,31 @@ class GenerationTesterMixin:
self
.
_check_logits
(
num_sequences_in_output
,
output
.
logits
,
config
=
config
)
self
.
_check_logits
(
num_sequences_in_output
,
output
.
logits
,
config
=
config
)
# Attentions
# Attentions
if
config
.
is_encoder_decoder
:
if
self
.
has_attentions
:
# encoder
if
config
.
is_encoder_decoder
:
self
.
_check_encoder_attention_for_generate
(
output
.
encoder_attentions
,
batch_size
,
config
,
seq_length
)
# encoder
# decoder
self
.
_check_encoder_attention_for_generate
(
output
.
encoder_attentions
,
batch_size
,
config
,
seq_length
)
self
.
_check_attentions_for_generate
(
# decoder
num_sequences_in_output
,
self
.
_check_attentions_for_generate
(
output
.
decoder_attentions
,
num_sequences_in_output
,
min_length
=
1
,
output
.
decoder_attentions
,
max_length
=
output
.
sequences
.
shape
[
-
1
],
min_length
=
1
,
config
=
config
,
max_length
=
output
.
sequences
.
shape
[
-
1
],
use_cache
=
use_cache
,
config
=
config
,
)
use_cache
=
use_cache
,
else
:
)
# if use_cache first input is equal to no use_cache, so skip here
else
:
attentions
=
output
.
attentions
if
not
use_cache
else
output
.
attentions
[
1
:]
# if use_cache first input is equal to no use_cache, so skip here
min_length
=
seq_length
if
not
use_cache
else
seq_length
+
1
attentions
=
output
.
attentions
if
not
use_cache
else
output
.
attentions
[
1
:]
self
.
_check_attentions_for_generate
(
min_length
=
seq_length
if
not
use_cache
else
seq_length
+
1
num_sequences_in_output
,
self
.
_check_attentions_for_generate
(
attentions
=
attentions
,
num_sequences_in_output
,
min_length
=
min_length
,
attentions
=
attentions
,
max_length
=
output
.
sequences
.
shape
[
-
1
],
min_length
=
min_length
,
config
=
config
,
max_length
=
output
.
sequences
.
shape
[
-
1
],
use_cache
=
use_cache
,
config
=
config
,
)
use_cache
=
use_cache
,
)
# Hidden States
# Hidden States
if
config
.
is_encoder_decoder
:
if
config
.
is_encoder_decoder
:
...
@@ -1763,7 +1789,7 @@ class GenerationTesterMixin:
...
@@ -1763,7 +1789,7 @@ class GenerationTesterMixin:
# 2. Some old models still return `output.past_key_values` even without `use_cache=True`
# 2. Some old models still return `output.past_key_values` even without `use_cache=True`
# 3. TODO (joao): A few models have different formats/types, skipping those until the cache refactor is
# 3. TODO (joao): A few models have different formats/types, skipping those until the cache refactor is
# complete
# complete
models_without_standard_cache
=
(
"bloom"
,
"ctrl"
,
"fsmt"
,
"gptbigcode"
,
"mega"
,
"reformer"
,
"jamba"
)
models_without_standard_cache
=
(
"bloom"
,
"ctrl"
,
"fsmt"
,
"gptbigcode"
,
"mega"
,
"reformer"
,
"jamba"
,
"mamba"
)
has_standard_cache
=
not
any
(
has_standard_cache
=
not
any
(
model_name
in
config
.
__class__
.
__name__
.
lower
()
for
model_name
in
models_without_standard_cache
model_name
in
config
.
__class__
.
__name__
.
lower
()
for
model_name
in
models_without_standard_cache
)
)
...
...
tests/models/jamba/test_modeling_jamba.py
View file @
83259e40
...
@@ -503,10 +503,6 @@ class JambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
...
@@ -503,10 +503,6 @@ class JambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
# They should result in very similar logits
# They should result in very similar logits
self
.
assertTrue
(
torch
.
allclose
(
next_logits_wo_padding
,
next_logits_with_padding
,
atol
=
3e-3
))
self
.
assertTrue
(
torch
.
allclose
(
next_logits_wo_padding
,
next_logits_with_padding
,
atol
=
3e-3
))
@
unittest
.
skip
(
"Jamba has its own special cache type"
)
# FIXME: @gante
def
test_assisted_decoding_matches_greedy_search_0_random
(
self
):
pass
@
require_flash_attn
@
require_flash_attn
@
require_torch_gpu
@
require_torch_gpu
@
require_bitsandbytes
@
require_bitsandbytes
...
...
tests/models/mamba/test_modeling_mamba.py
View file @
83259e40
...
@@ -250,6 +250,8 @@ class MambaModelTester:
...
@@ -250,6 +250,8 @@ class MambaModelTester:
@
require_torch
@
require_torch
class
MambaModelTest
(
ModelTesterMixin
,
GenerationTesterMixin
,
PipelineTesterMixin
,
unittest
.
TestCase
):
class
MambaModelTest
(
ModelTesterMixin
,
GenerationTesterMixin
,
PipelineTesterMixin
,
unittest
.
TestCase
):
all_model_classes
=
(
MambaModel
,
MambaForCausalLM
)
if
is_torch_available
()
else
()
all_model_classes
=
(
MambaModel
,
MambaForCausalLM
)
if
is_torch_available
()
else
()
all_generative_model_classes
=
(
MambaForCausalLM
,)
if
is_torch_available
()
else
()
has_attentions
=
False
# Mamba does not support attentions
fx_compatible
=
False
# FIXME let's try to support this @ArthurZucker
fx_compatible
=
False
# FIXME let's try to support this @ArthurZucker
test_torchscript
=
False
# FIXME let's try to support this @ArthurZucker
test_torchscript
=
False
# FIXME let's try to support this @ArthurZucker
test_missing_keys
=
False
test_missing_keys
=
False
...
@@ -292,10 +294,6 @@ class MambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
...
@@ -292,10 +294,6 @@ class MambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
def
test_config
(
self
):
def
test_config
(
self
):
self
.
config_tester
.
run_common_tests
()
self
.
config_tester
.
run_common_tests
()
@
unittest
.
skip
(
"No attention in mamba"
)
def
test_retain_grad_hidden_states_attentions
(
self
):
pass
@
require_torch_multi_gpu
@
require_torch_multi_gpu
def
test_multi_gpu_data_parallel_forward
(
self
):
def
test_multi_gpu_data_parallel_forward
(
self
):
config
,
inputs_dict
=
self
.
model_tester
.
prepare_config_and_inputs_for_common
()
config
,
inputs_dict
=
self
.
model_tester
.
prepare_config_and_inputs_for_common
()
...
@@ -364,14 +362,6 @@ class MambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
...
@@ -364,14 +362,6 @@ class MambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
# check if it's a ones like
# check if it's a ones like
self
.
assertTrue
(
torch
.
allclose
(
param
.
data
,
torch
.
ones_like
(
param
.
data
),
atol
=
1e-5
,
rtol
=
1e-5
))
self
.
assertTrue
(
torch
.
allclose
(
param
.
data
,
torch
.
ones_like
(
param
.
data
),
atol
=
1e-5
,
rtol
=
1e-5
))
@
unittest
.
skip
(
"Mamba does not use attention"
)
def
test_attention_outputs
(
self
):
r
"""
Overriding the test_attention_outputs test as the attention outputs of Mamba are different from other models
it has a shape `batch_size, seq_len, hidden_size`.
"""
pass
@
slow
@
slow
def
test_model_from_pretrained
(
self
):
def
test_model_from_pretrained
(
self
):
model
=
MambaModel
.
from_pretrained
(
"hf-internal-testing/mamba-130m"
)
model
=
MambaModel
.
from_pretrained
(
"hf-internal-testing/mamba-130m"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment