chenpangpang / transformers · commit 83259e40 (unverified)
Mamba: add generative tests (#31478)
Authored by Joao Gante on Jun 19, 2024; committed via GitHub on Jun 19, 2024
Parent: 7d683f7b
Showing 8 changed files with 83 additions and 56 deletions (+83 −56)
src/transformers/generation/utils.py (+11 −0)
src/transformers/modeling_utils.py (+1 −0)
src/transformers/models/jamba/modeling_jamba.py (+1 −0)
src/transformers/models/mamba/modeling_mamba.py (+1 −0)
src/transformers/models/rwkv/modeling_rwkv.py (+1 −0)
tests/generation/test_utils.py (+66 −40)
tests/models/jamba/test_modeling_jamba.py (+0 −4)
tests/models/mamba/test_modeling_mamba.py (+2 −12)
src/transformers/generation/utils.py

@@ -1830,6 +1830,12 @@ class GenerationMixin:
                raise ValueError("assisted generate requires `use_cache=True`")
            if generation_config.cache_implementation == "static":
                raise ValueError("assisted generate is not supported with `static_cache`")
+           if self._is_stateful:
+               # In assisted generation we need the ability to confirm whether the model would pick certain tokens,
+               # which is not possible with stateful models (they can't reset to a previous subset of generated text)
+               raise ValueError(
+                   f"assisted generation is not supported with stateful models, such as {self.__class__.__name__}"
+               )

            # 11. Get the candidate generator, given the parameterization
            candidate_generator = self._get_candidate_generator(

@@ -1867,6 +1873,11 @@ class GenerationMixin:
        elif generation_mode == GenerationMode.CONTRASTIVE_SEARCH:
            if not model_kwargs["use_cache"]:
                raise ValueError("Contrastive search requires `use_cache=True`")
+           if self._is_stateful:
+               # Just like assisted generation, we need to be able to rollback to a previous state (see comment above)
+               raise ValueError(
+                   f"contrastive search is not supported with stateful models, such as {self.__class__.__name__}"
+               )

            result = self._contrastive_search(
                input_ids,
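In practice, the new `_is_stateful` guard turns assisted generation and contrastive search into explicit errors for models whose cache cannot be rewound. A minimal sketch of the resulting behaviour, assuming the state-spaces/mamba-130m-hf checkpoint is available (the checkpoint name and printed message are illustrative, not part of this commit):

from transformers import AutoTokenizer, MambaForCausalLM

tok = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf")       # assumed checkpoint
model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf")  # MambaForCausalLM now sets _is_stateful = True

ids = tok("Hello", return_tensors="pt").input_ids
try:
    # penalty_alpha + top_k selects contrastive search, which needs to roll the cache back
    model.generate(ids, penalty_alpha=0.6, top_k=4, max_new_tokens=8)
except ValueError as err:
    print(err)  # "contrastive search is not supported with stateful models, such as MambaForCausalLM"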
src/transformers/modeling_utils.py

@@ -1281,6 +1281,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin
    is_parallelizable = False
    supports_gradient_checkpointing = False
+   _is_stateful = False

    # Flash Attention 2 support
    _supports_flash_attn_2 = False
src/transformers/models/jamba/modeling_jamba.py

@@ -1266,6 +1266,7 @@ class JambaPreTrainedModel(PreTrainedModel):
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_cache_class = True  # Note: only supports HybridMambaAttentionDynamicCache
+   _is_stateful = True

    def _init_weights(self, module):
        std = self.config.initializer_range
src/transformers/models/mamba/modeling_mamba.py

@@ -354,6 +354,7 @@ class MambaPreTrainedModel(PreTrainedModel):
    base_model_prefix = "backbone"
    _no_split_modules = ["MambaBlock"]
    supports_gradient_checkpointing = True
+   _is_stateful = True

    def _init_weights(self, module):
        """Initialize the weights."""
src/transformers/models/rwkv/modeling_rwkv.py

@@ -394,6 +394,7 @@ class RwkvPreTrainedModel(PreTrainedModel):
    _no_split_modules = ["RwkvBlock"]
    _keep_in_fp32_modules = ["time_decay", "time_first"]
    supports_gradient_checkpointing = True
+   _is_stateful = True

    def _init_weights(self, module):
        """Initialize the weights."""
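The flag defaults to False on PreTrainedModel, so only architectures whose recurrent state cannot be rolled back (Jamba, Mamba and RWKV above) flip it to True. A hypothetical custom architecture would opt in the same way; the class below is an invented skeleton for illustration, not code from this commit:

from transformers.modeling_utils import PreTrainedModel

class MyRecurrentPreTrainedModel(PreTrainedModel):
    """Skeleton only: a real model would also define config_class, _init_weights, forward, etc."""

    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    # Declaring the class stateful makes GenerationMixin raise a clear ValueError for
    # assisted generation and contrastive search instead of producing silently wrong output.
    _is_stateful = True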
tests/generation/test_utils.py
@@ -102,7 +102,11 @@ class GenerationTesterMixin:
        if isinstance(config.eos_token_id, int):
            config.eos_token_id = [config.eos_token_id]
        config.pad_token_id = config.eos_token_id[0]

+       if self.has_attentions:
+           attention_mask = torch.ones_like(input_ids, dtype=torch.long)
+       else:
+           attention_mask = None

        # It is important set set the eos_token_id to None to ensure that no sequences
        # shorter than `max_length` can be generated
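The config helper now only builds a mask for models that actually consume one; attention-free models such as Mamba get attention_mask=None. A standalone restatement of that branch (illustration only, not part of the diff):

import torch

def build_attention_mask(input_ids: torch.Tensor, has_attentions: bool):
    # Attention-based models get an all-ones mask; attention-free models get None.
    if has_attentions:
        return torch.ones_like(input_ids, dtype=torch.long)
    return None

ids = torch.tensor([[11, 12, 13]])
print(build_attention_mask(ids, has_attentions=True))   # tensor([[1, 1, 1]])
print(build_attention_mask(ids, has_attentions=False))  # None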
@@ -437,7 +441,7 @@ class GenerationTesterMixin:
            output_scores=True,
            output_logits=True,
            output_hidden_states=True,
-           output_attentions=True,
+           output_attentions=self.has_attentions,
            return_dict_in_generate=True,
        )

@@ -471,7 +475,7 @@ class GenerationTesterMixin:
            output_scores=True,
            output_logits=True,
            output_hidden_states=True,
-           output_attentions=True,
+           output_attentions=self.has_attentions,
            return_dict_in_generate=True,
        )

@@ -529,7 +533,7 @@ class GenerationTesterMixin:
            output_scores=True,
            output_logits=True,
            output_hidden_states=True,
-           output_attentions=True,
+           output_attentions=self.has_attentions,
            return_dict_in_generate=True,
        )

@@ -595,7 +599,7 @@ class GenerationTesterMixin:
            output_scores=True,
            output_logits=True,
            output_hidden_states=True,
-           output_attentions=True,
+           output_attentions=self.has_attentions,
            return_dict_in_generate=True,
        )

        if model.config.is_encoder_decoder:

@@ -642,7 +646,7 @@ class GenerationTesterMixin:
            output_scores=True,
            output_logits=True,
            output_hidden_states=True,
-           output_attentions=True,
+           output_attentions=self.has_attentions,
            return_dict_in_generate=True,
        )

@@ -733,7 +737,7 @@ class GenerationTesterMixin:
            output_scores=True,
            output_logits=True,
            output_hidden_states=True,
-           output_attentions=True,
+           output_attentions=self.has_attentions,
            return_dict_in_generate=True,
        )

@@ -834,7 +838,7 @@ class GenerationTesterMixin:
            output_scores=True,
            output_logits=True,
            output_hidden_states=True,
-           output_attentions=True,
+           output_attentions=self.has_attentions,
            return_dict_in_generate=True,
        )

        if model.config.is_encoder_decoder:

@@ -952,7 +956,7 @@ class GenerationTesterMixin:
            output_scores=True,
            output_logits=True,
            output_hidden_states=True,
-           output_attentions=True,
+           output_attentions=self.has_attentions,
            return_dict_in_generate=True,
        )
@@ -973,6 +977,9 @@ class GenerationTesterMixin:
    def test_contrastive_generate(self):
        for model_class in self.all_generative_model_classes:
+           if model_class._is_stateful:
+               self.skipTest("Stateful models don't support contrastive search generation")
+
            # won't fix: FSMT and Reformer have a different cache variable type (and format).
            if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
                self.skipTest("Won't fix: old model with different cache format")

@@ -997,6 +1004,9 @@ class GenerationTesterMixin:
    def test_contrastive_generate_dict_outputs_use_cache(self):
        for model_class in self.all_generative_model_classes:
+           if model_class._is_stateful:
+               self.skipTest("Stateful models don't support contrastive search generation")
+
            # won't fix: FSMT and Reformer have a different cache variable type (and format).
            if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
                self.skipTest("Won't fix: old model with different cache format")
@@ -1017,7 +1027,7 @@ class GenerationTesterMixin:
            output_scores=True,
            output_logits=True,
            output_hidden_states=True,
-           output_attentions=True,
+           output_attentions=self.has_attentions,
            return_dict_in_generate=True,
        )
@@ -1030,9 +1040,12 @@ class GenerationTesterMixin:
    def test_contrastive_generate_low_memory(self):
        # Check that choosing 'low_memory' does not change the model output
        for model_class in self.all_generative_model_classes:
+           if model_class._is_stateful:
+               self.skipTest("Stateful models don't support contrastive search generation")
+
            if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer", "speech2text"]):
                self.skipTest("Won't fix: old model with different cache format")

-           if any(model_name in model_class.__name__.lower() for model_name in ["gptbigcode", "jamba"]):
+           if any(model_name in model_class.__name__.lower() for model_name in ["gptbigcode"]):
                self.skipTest("TODO: fix me")

            config, input_ids, attention_mask = self._get_input_ids_and_config(batch_size=1)
@@ -1069,6 +1082,8 @@ class GenerationTesterMixin:
    def test_beam_search_low_memory(self):
        # Check that choosing 'low_memory' does not change the model output
        for model_class in self.all_generative_model_classes:
+           if model_class._is_stateful:
+               self.skipTest("May fix in the future: need custom cache handling")
            if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
                self.skipTest("Won't fix: old model with different cache format")
            if any(

@@ -1115,6 +1130,8 @@ class GenerationTesterMixin:
        # - assisted_decoding does not support `batch_size > 1`
        for model_class in self.all_generative_model_classes:
+           if model_class._is_stateful:
+               self.skipTest("Stateful models don't support assisted generation")
            if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
                self.skipTest("Won't fix: old model with different cache format")
            if any(
@@ -1156,7 +1173,7 @@ class GenerationTesterMixin:
            "output_scores": True,
            "output_logits": True,
            "output_hidden_states": True,
-           "output_attentions": True,
+           "output_attentions": self.has_attentions,
            "return_dict_in_generate": True,
        }
        output_greedy = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)
@@ -1184,6 +1201,8 @@ class GenerationTesterMixin:
        # This test is mostly a copy of test_assisted_decoding_matches_greedy_search
        for model_class in self.all_generative_model_classes:
+           if model_class._is_stateful:
+               self.skipTest("Stateful models don't support assisted generation")
            if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
                self.skipTest("Won't fix: old model with different cache format")
            if any(

@@ -1225,7 +1244,7 @@ class GenerationTesterMixin:
            "output_scores": True,
            "output_logits": True,
            "output_hidden_states": True,
-           "output_attentions": True,
+           "output_attentions": self.has_attentions,
            "return_dict_in_generate": True,
        }
@@ -1244,6 +1263,8 @@ class GenerationTesterMixin:
        # match sample for the same seed, as the forward pass does not return the exact same logits (due to matmul with
        # different shapes, see https://github.com/huggingface/transformers/issues/25420#issuecomment-1775317535).
        for model_class in self.all_generative_model_classes:
+           if model_class._is_stateful:
+               self.skipTest("Stateful models don't support assisted generation")
            if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
                self.skipTest("Won't fix: old model with different cache format")
            if any(

@@ -1289,7 +1310,7 @@ class GenerationTesterMixin:
            "output_scores": True,
            "output_logits": True,
            "output_hidden_states": True,
-           "output_attentions": True,
+           "output_attentions": self.has_attentions,
            "return_dict_in_generate": True,
        }
        output_assisted = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)
@@ -1326,7 +1347,7 @@ class GenerationTesterMixin:
                input_ids,
                attention_mask=attention_mask,
                num_beams=1,
-               output_attentions=True,
+               output_attentions=self.has_attentions,
                return_dict_in_generate=True,
                remove_invalid_values=True,
                **{name: mask},
@@ -1344,6 +1365,10 @@ class GenerationTesterMixin:
        if len(self.all_generative_model_classes) == 0:
            self.skipTest(reason="No generative architecture available for this model.")

+       # - The model must support padding
+       if not self.has_attentions:
+           self.skipTest(reason="This model doesn't support padding.")
+
        # - The model must be a decoder-only architecture (encoder-based architectures use right-padding)
        decoder_only_classes = []
        for model_class in self.all_generative_model_classes:
@@ -1704,6 +1729,7 @@ class GenerationTesterMixin:
        self._check_logits(num_sequences_in_output, output.logits, config=config)

        # Attentions
+       if self.has_attentions:
            if config.is_encoder_decoder:
                # encoder
                self._check_encoder_attention_for_generate(output.encoder_attentions, batch_size, config, seq_length)
@@ -1763,7 +1789,7 @@ class GenerationTesterMixin:
        # 2. Some old models still return `output.past_key_values` even without `use_cache=True`
        # 3. TODO (joao): A few models have different formats/types, skipping those until the cache refactor is
        #    complete
-       models_without_standard_cache = ("bloom", "ctrl", "fsmt", "gptbigcode", "mega", "reformer", "jamba")
+       models_without_standard_cache = ("bloom", "ctrl", "fsmt", "gptbigcode", "mega", "reformer", "jamba", "mamba")
        has_standard_cache = not any(
            model_name in config.__class__.__name__.lower() for model_name in models_without_standard_cache
        )
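Mamba joins the non-standard-cache list because its forward pass exposes the recurrent state as cache_params (a MambaCache) rather than the tuple-style past_key_values the common checks expect. A short hedged illustration, again assuming the state-spaces/mamba-130m-hf checkpoint (not referenced by this commit):

from transformers import AutoTokenizer, MambaForCausalLM

tok = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf")
model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf")

ids = tok("Hello", return_tensors="pt").input_ids
out = model(ids, use_cache=True)
print(type(out.cache_params).__name__)  # MambaCache: the SSM state, not key/value tensors
print("past_key_values" in out)         # False, so the standard cache checks do not apply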
tests/models/jamba/test_modeling_jamba.py

@@ -503,10 +503,6 @@ class JambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
        # They should result in very similar logits
        self.assertTrue(torch.allclose(next_logits_wo_padding, next_logits_with_padding, atol=3e-3))

-   @unittest.skip("Jamba has its own special cache type")  # FIXME: @gante
-   def test_assisted_decoding_matches_greedy_search_0_random(self):
-       pass
-
    @require_flash_attn
    @require_torch_gpu
    @require_bitsandbytes
tests/models/mamba/test_modeling_mamba.py
@@ -250,6 +250,8 @@ class MambaModelTester:
@require_torch
class MambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
    all_model_classes = (MambaModel, MambaForCausalLM) if is_torch_available() else ()
+   all_generative_model_classes = (MambaForCausalLM,) if is_torch_available() else ()
+   has_attentions = False  # Mamba does not support attentions
    fx_compatible = False  # FIXME let's try to support this @ArthurZucker
    test_torchscript = False  # FIXME let's try to support this @ArthurZucker
    test_missing_keys = False
@@ -292,10 +294,6 @@ class MambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
    def test_config(self):
        self.config_tester.run_common_tests()

-   @unittest.skip("No attention in mamba")
-   def test_retain_grad_hidden_states_attentions(self):
-       pass
-
    @require_torch_multi_gpu
    def test_multi_gpu_data_parallel_forward(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@@ -364,14 +362,6 @@ class MambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
            # check if it's a ones like
            self.assertTrue(torch.allclose(param.data, torch.ones_like(param.data), atol=1e-5, rtol=1e-5))

-   @unittest.skip("Mamba does not use attention")
-   def test_attention_outputs(self):
-       r"""
-       Overriding the test_attention_outputs test as the attention outputs of Mamba are different from other models
-       it has a shape `batch_size, seq_len, hidden_size`.
-       """
-       pass
-
    @slow
    def test_model_from_pretrained(self):
        model = MambaModel.from_pretrained("hf-internal-testing/mamba-130m")
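With the Mamba-specific skips removed and has_attentions = False in place, Mamba now runs the shared generation test suite directly. With a development install of transformers, the relevant tests can be exercised with standard pytest filtering, e.g. python -m pytest tests/models/mamba/test_modeling_mamba.py -k generate (a usage suggestion, not a command documented in this commit).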