Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
83259e40
Unverified
Commit
83259e40
authored
Jun 19, 2024
by
Joao Gante
Committed by
GitHub
Jun 19, 2024
Browse files
Mamba: add generative tests (#31478)
parent
7d683f7b
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
83 additions
and
56 deletions
+83
-56
src/transformers/generation/utils.py
src/transformers/generation/utils.py
+11
-0
src/transformers/modeling_utils.py
src/transformers/modeling_utils.py
+1
-0
src/transformers/models/jamba/modeling_jamba.py
src/transformers/models/jamba/modeling_jamba.py
+1
-0
src/transformers/models/mamba/modeling_mamba.py
src/transformers/models/mamba/modeling_mamba.py
+1
-0
src/transformers/models/rwkv/modeling_rwkv.py
src/transformers/models/rwkv/modeling_rwkv.py
+1
-0
tests/generation/test_utils.py
tests/generation/test_utils.py
+66
-40
tests/models/jamba/test_modeling_jamba.py
tests/models/jamba/test_modeling_jamba.py
+0
-4
tests/models/mamba/test_modeling_mamba.py
tests/models/mamba/test_modeling_mamba.py
+2
-12
No files found.
src/transformers/generation/utils.py
View file @
83259e40
...
...
@@ -1830,6 +1830,12 @@ class GenerationMixin:
raise
ValueError
(
"assisted generate requires `use_cache=True`"
)
if
generation_config
.
cache_implementation
==
"static"
:
raise
ValueError
(
"assisted generate is not supported with `static_cache`"
)
if
self
.
_is_stateful
:
# In assisted generation we need the ability to confirm whether the model would pick certain tokens,
# which is not possible with stateful models (they can't reset to a previous subset of generated text)
raise
ValueError
(
f
"assisted generation is not supported with stateful models, such as
{
self
.
__class__
.
__name__
}
"
)
# 11. Get the candidate generator, given the parameterization
candidate_generator
=
self
.
_get_candidate_generator
(
...
...
@@ -1867,6 +1873,11 @@ class GenerationMixin:
elif
generation_mode
==
GenerationMode
.
CONTRASTIVE_SEARCH
:
if
not
model_kwargs
[
"use_cache"
]:
raise
ValueError
(
"Contrastive search requires `use_cache=True`"
)
if
self
.
_is_stateful
:
# Just like assisted generation, we need to be able to rollback to a previous state (see comment above)
raise
ValueError
(
f
"contrastive search is not supported with stateful models, such as
{
self
.
__class__
.
__name__
}
"
)
result
=
self
.
_contrastive_search
(
input_ids
,
...
...
src/transformers/modeling_utils.py
View file @
83259e40
...
...
@@ -1281,6 +1281,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
is_parallelizable
=
False
supports_gradient_checkpointing
=
False
_is_stateful
=
False
# Flash Attention 2 support
_supports_flash_attn_2
=
False
...
...
src/transformers/models/jamba/modeling_jamba.py
View file @
83259e40
...
...
@@ -1266,6 +1266,7 @@ class JambaPreTrainedModel(PreTrainedModel):
_supports_flash_attn_2
=
True
_supports_sdpa
=
True
_supports_cache_class
=
True
# Note: only supports HybridMambaAttentionDynamicCache
_is_stateful
=
True
def
_init_weights
(
self
,
module
):
std
=
self
.
config
.
initializer_range
...
...
src/transformers/models/mamba/modeling_mamba.py
View file @
83259e40
...
...
@@ -354,6 +354,7 @@ class MambaPreTrainedModel(PreTrainedModel):
base_model_prefix
=
"backbone"
_no_split_modules
=
[
"MambaBlock"
]
supports_gradient_checkpointing
=
True
_is_stateful
=
True
def
_init_weights
(
self
,
module
):
"""Initialize the weights."""
...
...
src/transformers/models/rwkv/modeling_rwkv.py
View file @
83259e40
...
...
@@ -394,6 +394,7 @@ class RwkvPreTrainedModel(PreTrainedModel):
_no_split_modules
=
[
"RwkvBlock"
]
_keep_in_fp32_modules
=
[
"time_decay"
,
"time_first"
]
supports_gradient_checkpointing
=
True
_is_stateful
=
True
def
_init_weights
(
self
,
module
):
"""Initialize the weights."""
...
...
tests/generation/test_utils.py
View file @
83259e40
...
...
@@ -102,7 +102,11 @@ class GenerationTesterMixin:
if
isinstance
(
config
.
eos_token_id
,
int
):
config
.
eos_token_id
=
[
config
.
eos_token_id
]
config
.
pad_token_id
=
config
.
eos_token_id
[
0
]
attention_mask
=
torch
.
ones_like
(
input_ids
,
dtype
=
torch
.
long
)
if
self
.
has_attentions
:
attention_mask
=
torch
.
ones_like
(
input_ids
,
dtype
=
torch
.
long
)
else
:
attention_mask
=
None
# It is important to set the eos_token_id to None to ensure that no sequences
# shorter than `max_length` can be generated
...
...
@@ -437,7 +441,7 @@ class GenerationTesterMixin:
output_scores
=
True
,
output_logits
=
True
,
output_hidden_states
=
True
,
output_attentions
=
True
,
output_attentions
=
self
.
has_attentions
,
return_dict_in_generate
=
True
,
)
...
...
@@ -471,7 +475,7 @@ class GenerationTesterMixin:
output_scores
=
True
,
output_logits
=
True
,
output_hidden_states
=
True
,
output_attentions
=
True
,
output_attentions
=
self
.
has_attentions
,
return_dict_in_generate
=
True
,
)
...
...
@@ -529,7 +533,7 @@ class GenerationTesterMixin:
output_scores
=
True
,
output_logits
=
True
,
output_hidden_states
=
True
,
output_attentions
=
True
,
output_attentions
=
self
.
has_attentions
,
return_dict_in_generate
=
True
,
)
...
...
@@ -595,7 +599,7 @@ class GenerationTesterMixin:
output_scores
=
True
,
output_logits
=
True
,
output_hidden_states
=
True
,
output_attentions
=
True
,
output_attentions
=
self
.
has_attentions
,
return_dict_in_generate
=
True
,
)
if
model
.
config
.
is_encoder_decoder
:
...
...
@@ -642,7 +646,7 @@ class GenerationTesterMixin:
output_scores
=
True
,
output_logits
=
True
,
output_hidden_states
=
True
,
output_attentions
=
True
,
output_attentions
=
self
.
has_attentions
,
return_dict_in_generate
=
True
,
)
...
...
@@ -733,7 +737,7 @@ class GenerationTesterMixin:
output_scores
=
True
,
output_logits
=
True
,
output_hidden_states
=
True
,
output_attentions
=
True
,
output_attentions
=
self
.
has_attentions
,
return_dict_in_generate
=
True
,
)
...
...
@@ -834,7 +838,7 @@ class GenerationTesterMixin:
output_scores
=
True
,
output_logits
=
True
,
output_hidden_states
=
True
,
output_attentions
=
True
,
output_attentions
=
self
.
has_attentions
,
return_dict_in_generate
=
True
,
)
if
model
.
config
.
is_encoder_decoder
:
...
...
@@ -952,7 +956,7 @@ class GenerationTesterMixin:
output_scores
=
True
,
output_logits
=
True
,
output_hidden_states
=
True
,
output_attentions
=
True
,
output_attentions
=
self
.
has_attentions
,
return_dict_in_generate
=
True
,
)
...
...
@@ -973,6 +977,9 @@ class GenerationTesterMixin:
def
test_contrastive_generate
(
self
):
for
model_class
in
self
.
all_generative_model_classes
:
if
model_class
.
_is_stateful
:
self
.
skipTest
(
"Stateful models don't support contrastive search generation"
)
# won't fix: FSMT and Reformer have a different cache variable type (and format).
if
any
(
model_name
in
model_class
.
__name__
.
lower
()
for
model_name
in
[
"fsmt"
,
"reformer"
]):
self
.
skipTest
(
"Won't fix: old model with different cache format"
)
...
...
@@ -997,6 +1004,9 @@ class GenerationTesterMixin:
def
test_contrastive_generate_dict_outputs_use_cache
(
self
):
for
model_class
in
self
.
all_generative_model_classes
:
if
model_class
.
_is_stateful
:
self
.
skipTest
(
"Stateful models don't support contrastive search generation"
)
# won't fix: FSMT and Reformer have a different cache variable type (and format).
if
any
(
model_name
in
model_class
.
__name__
.
lower
()
for
model_name
in
[
"fsmt"
,
"reformer"
]):
self
.
skipTest
(
"Won't fix: old model with different cache format"
)
...
...
@@ -1017,7 +1027,7 @@ class GenerationTesterMixin:
output_scores
=
True
,
output_logits
=
True
,
output_hidden_states
=
True
,
output_attentions
=
True
,
output_attentions
=
self
.
has_attentions
,
return_dict_in_generate
=
True
,
)
...
...
@@ -1030,9 +1040,12 @@ class GenerationTesterMixin:
def
test_contrastive_generate_low_memory
(
self
):
# Check that choosing 'low_memory' does not change the model output
for
model_class
in
self
.
all_generative_model_classes
:
if
model_class
.
_is_stateful
:
self
.
skipTest
(
"Stateful models don't support contrastive search generation"
)
if
any
(
model_name
in
model_class
.
__name__
.
lower
()
for
model_name
in
[
"fsmt"
,
"reformer"
,
"speech2text"
]):
self
.
skipTest
(
"Won't fix: old model with different cache format"
)
if
any
(
model_name
in
model_class
.
__name__
.
lower
()
for
model_name
in
[
"gptbigcode"
,
"jamba"
]):
if
any
(
model_name
in
model_class
.
__name__
.
lower
()
for
model_name
in
[
"gptbigcode"
]):
self
.
skipTest
(
"TODO: fix me"
)
config
,
input_ids
,
attention_mask
=
self
.
_get_input_ids_and_config
(
batch_size
=
1
)
...
...
@@ -1069,6 +1082,8 @@ class GenerationTesterMixin:
def
test_beam_search_low_memory
(
self
):
# Check that choosing 'low_memory' does not change the model output
for
model_class
in
self
.
all_generative_model_classes
:
if
model_class
.
_is_stateful
:
self
.
skipTest
(
"May fix in the future: need custom cache handling"
)
if
any
(
model_name
in
model_class
.
__name__
.
lower
()
for
model_name
in
[
"fsmt"
,
"reformer"
]):
self
.
skipTest
(
"Won't fix: old model with different cache format"
)
if
any
(
...
...
@@ -1115,6 +1130,8 @@ class GenerationTesterMixin:
# - assisted_decoding does not support `batch_size > 1`
for
model_class
in
self
.
all_generative_model_classes
:
if
model_class
.
_is_stateful
:
self
.
skipTest
(
"Stateful models don't support assisted generation"
)
if
any
(
model_name
in
model_class
.
__name__
.
lower
()
for
model_name
in
[
"fsmt"
,
"reformer"
]):
self
.
skipTest
(
"Won't fix: old model with different cache format"
)
if
any
(
...
...
@@ -1156,7 +1173,7 @@ class GenerationTesterMixin:
"output_scores"
:
True
,
"output_logits"
:
True
,
"output_hidden_states"
:
True
,
"output_attentions"
:
True
,
"output_attentions"
:
self
.
has_attentions
,
"return_dict_in_generate"
:
True
,
}
output_greedy
=
model
.
generate
(
input_ids
,
attention_mask
=
attention_mask
,
**
generation_kwargs
)
...
...
@@ -1184,6 +1201,8 @@ class GenerationTesterMixin:
# This test is mostly a copy of test_assisted_decoding_matches_greedy_search
for
model_class
in
self
.
all_generative_model_classes
:
if
model_class
.
_is_stateful
:
self
.
skipTest
(
"Stateful models don't support assisted generation"
)
if
any
(
model_name
in
model_class
.
__name__
.
lower
()
for
model_name
in
[
"fsmt"
,
"reformer"
]):
self
.
skipTest
(
"Won't fix: old model with different cache format"
)
if
any
(
...
...
@@ -1225,7 +1244,7 @@ class GenerationTesterMixin:
"output_scores"
:
True
,
"output_logits"
:
True
,
"output_hidden_states"
:
True
,
"output_attentions"
:
True
,
"output_attentions"
:
self
.
has_attentions
,
"return_dict_in_generate"
:
True
,
}
...
...
@@ -1244,6 +1263,8 @@ class GenerationTesterMixin:
# match sample for the same seed, as the forward pass does not return the exact same logits (due to matmul with
# different shapes, see https://github.com/huggingface/transformers/issues/25420#issuecomment-1775317535).
for
model_class
in
self
.
all_generative_model_classes
:
if
model_class
.
_is_stateful
:
self
.
skipTest
(
"Stateful models don't support assisted generation"
)
if
any
(
model_name
in
model_class
.
__name__
.
lower
()
for
model_name
in
[
"fsmt"
,
"reformer"
]):
self
.
skipTest
(
"Won't fix: old model with different cache format"
)
if
any
(
...
...
@@ -1289,7 +1310,7 @@ class GenerationTesterMixin:
"output_scores"
:
True
,
"output_logits"
:
True
,
"output_hidden_states"
:
True
,
"output_attentions"
:
True
,
"output_attentions"
:
self
.
has_attentions
,
"return_dict_in_generate"
:
True
,
}
output_assisted
=
model
.
generate
(
input_ids
,
attention_mask
=
attention_mask
,
**
generation_kwargs
)
...
...
@@ -1326,7 +1347,7 @@ class GenerationTesterMixin:
input_ids
,
attention_mask
=
attention_mask
,
num_beams
=
1
,
output_attentions
=
True
,
output_attentions
=
self
.
has_attentions
,
return_dict_in_generate
=
True
,
remove_invalid_values
=
True
,
**
{
name
:
mask
},
...
...
@@ -1344,6 +1365,10 @@ class GenerationTesterMixin:
if
len
(
self
.
all_generative_model_classes
)
==
0
:
self
.
skipTest
(
reason
=
"No generative architecture available for this model."
)
# - The model must support padding
if
not
self
.
has_attentions
:
self
.
skipTest
(
reason
=
"This model doesn't support padding."
)
# - The model must be a decoder-only architecture (encoder-based architectures use right-padding)
decoder_only_classes
=
[]
for
model_class
in
self
.
all_generative_model_classes
:
...
...
@@ -1704,30 +1729,31 @@ class GenerationTesterMixin:
self
.
_check_logits
(
num_sequences_in_output
,
output
.
logits
,
config
=
config
)
# Attentions
if
config
.
is_encoder_decoder
:
# encoder
self
.
_check_encoder_attention_for_generate
(
output
.
encoder_attentions
,
batch_size
,
config
,
seq_length
)
# decoder
self
.
_check_attentions_for_generate
(
num_sequences_in_output
,
output
.
decoder_attentions
,
min_length
=
1
,
max_length
=
output
.
sequences
.
shape
[
-
1
],
config
=
config
,
use_cache
=
use_cache
,
)
else
:
# if use_cache first input is equal to no use_cache, so skip here
attentions
=
output
.
attentions
if
not
use_cache
else
output
.
attentions
[
1
:]
min_length
=
seq_length
if
not
use_cache
else
seq_length
+
1
self
.
_check_attentions_for_generate
(
num_sequences_in_output
,
attentions
=
attentions
,
min_length
=
min_length
,
max_length
=
output
.
sequences
.
shape
[
-
1
],
config
=
config
,
use_cache
=
use_cache
,
)
if
self
.
has_attentions
:
if
config
.
is_encoder_decoder
:
# encoder
self
.
_check_encoder_attention_for_generate
(
output
.
encoder_attentions
,
batch_size
,
config
,
seq_length
)
# decoder
self
.
_check_attentions_for_generate
(
num_sequences_in_output
,
output
.
decoder_attentions
,
min_length
=
1
,
max_length
=
output
.
sequences
.
shape
[
-
1
],
config
=
config
,
use_cache
=
use_cache
,
)
else
:
# if use_cache first input is equal to no use_cache, so skip here
attentions
=
output
.
attentions
if
not
use_cache
else
output
.
attentions
[
1
:]
min_length
=
seq_length
if
not
use_cache
else
seq_length
+
1
self
.
_check_attentions_for_generate
(
num_sequences_in_output
,
attentions
=
attentions
,
min_length
=
min_length
,
max_length
=
output
.
sequences
.
shape
[
-
1
],
config
=
config
,
use_cache
=
use_cache
,
)
# Hidden States
if
config
.
is_encoder_decoder
:
...
...
@@ -1763,7 +1789,7 @@ class GenerationTesterMixin:
# 2. Some old models still return `output.past_key_values` even without `use_cache=True`
# 3. TODO (joao): A few models have different formats/types, skipping those until the cache refactor is
# complete
models_without_standard_cache
=
(
"bloom"
,
"ctrl"
,
"fsmt"
,
"gptbigcode"
,
"mega"
,
"reformer"
,
"jamba"
)
models_without_standard_cache
=
(
"bloom"
,
"ctrl"
,
"fsmt"
,
"gptbigcode"
,
"mega"
,
"reformer"
,
"jamba"
,
"mamba"
)
has_standard_cache
=
not
any
(
model_name
in
config
.
__class__
.
__name__
.
lower
()
for
model_name
in
models_without_standard_cache
)
...
...
tests/models/jamba/test_modeling_jamba.py
View file @
83259e40
...
...
@@ -503,10 +503,6 @@ class JambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
# They should result in very similar logits
self
.
assertTrue
(
torch
.
allclose
(
next_logits_wo_padding
,
next_logits_with_padding
,
atol
=
3e-3
))
@
unittest
.
skip
(
"Jamba has its own special cache type"
)
# FIXME: @gante
def
test_assisted_decoding_matches_greedy_search_0_random
(
self
):
pass
@
require_flash_attn
@
require_torch_gpu
@
require_bitsandbytes
...
...
tests/models/mamba/test_modeling_mamba.py
View file @
83259e40
...
...
@@ -250,6 +250,8 @@ class MambaModelTester:
@
require_torch
class
MambaModelTest
(
ModelTesterMixin
,
GenerationTesterMixin
,
PipelineTesterMixin
,
unittest
.
TestCase
):
all_model_classes
=
(
MambaModel
,
MambaForCausalLM
)
if
is_torch_available
()
else
()
all_generative_model_classes
=
(
MambaForCausalLM
,)
if
is_torch_available
()
else
()
has_attentions
=
False
# Mamba does not support attentions
fx_compatible
=
False
# FIXME let's try to support this @ArthurZucker
test_torchscript
=
False
# FIXME let's try to support this @ArthurZucker
test_missing_keys
=
False
...
...
@@ -292,10 +294,6 @@ class MambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
def
test_config
(
self
):
self
.
config_tester
.
run_common_tests
()
@
unittest
.
skip
(
"No attention in mamba"
)
def
test_retain_grad_hidden_states_attentions
(
self
):
pass
@
require_torch_multi_gpu
def
test_multi_gpu_data_parallel_forward
(
self
):
config
,
inputs_dict
=
self
.
model_tester
.
prepare_config_and_inputs_for_common
()
...
...
@@ -364,14 +362,6 @@ class MambaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
# check if it's a ones like
self
.
assertTrue
(
torch
.
allclose
(
param
.
data
,
torch
.
ones_like
(
param
.
data
),
atol
=
1e-5
,
rtol
=
1e-5
))
@
unittest
.
skip
(
"Mamba does not use attention"
)
def
test_attention_outputs
(
self
):
r
"""
Overriding the test_attention_outputs test as the attention outputs of Mamba are different from other models
it has a shape `batch_size, seq_len, hidden_size`.
"""
pass
@
slow
def
test_model_from_pretrained
(
self
):
model
=
MambaModel
.
from_pretrained
(
"hf-internal-testing/mamba-130m"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment