chenpangpang / transformers / Commits / 4692d261

Unverified commit 4692d261, authored Aug 11, 2023 by Joao Gante, committed by GitHub on Aug 11, 2023

Switch Transformers: remove overwritten beam sample test (#25458)
parent 41d56ea6

Changes: 1 changed file with 0 additions and 96 deletions

tests/models/switch_transformers/test_modeling_switch_transformers.py (view file @ 4692d261)  +0 −96
@@ -37,7 +37,6 @@ if is_torch_available():
         SwitchTransformersModel,
         SwitchTransformersTop1Router,
     )
-    from transformers.generation import BeamSampleDecoderOnlyOutput, BeamSampleEncoderDecoderOutput
     from transformers.models.switch_transformers.modeling_switch_transformers import (
         SWITCH_TRANSFORMERS_PRETRAINED_MODEL_ARCHIVE_LIST,
         load_balancing_loss_func,
@@ -613,101 +612,6 @@ class SwitchTransformersModelTest(ModelTesterMixin, GenerationTesterMixin, Pipel
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_decoder_model_attention_mask_past(*config_and_inputs)

-    @slow
-    def test_beam_sample_generate_dict_output(self):
-        r"""
-        This test needs to be overriden with a larger model since it fails for very small models due to precision issues.
-        """
-        for model_class in self.all_generative_model_classes:
-            config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
-
-            # disable cache
-            config.use_cache = False
-
-            # It is important set set the eos_token_id to None to ensure that no sequences
-            # shorter than `max_length` can be generated which could lead to flaky circle ci
-            # failures if the top `num_return_sequences` beams are all shorter than the longest beam
-            config.eos_token_id = None
-            config.forced_eos_token_id = None
-            model = model_class.from_pretrained("google/switch-base-8").to(torch_device).eval()
-            logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=2)
-
-            num_return_sequences = 2
-            if model.config.is_encoder_decoder:
-                max_length = 4
-            beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(
-                input_ids.shape[0] * num_return_sequences, max_length
-            )
-            beam_kwargs["num_return_sequences"] = num_return_sequences
-
-            output_beam_sample, output_generate = self._beam_sample_generate(
-                model=model,
-                input_ids=input_ids,
-                attention_mask=attention_mask,
-                max_length=max_length,
-                num_return_sequences=num_return_sequences,
-                beam_scorer=beam_scorer,
-                beam_kwargs=beam_kwargs,
-                logits_warper=logits_warper,
-                logits_warper_kwargs=logits_warper_kwargs,
-                output_scores=True,
-                output_hidden_states=True,
-                output_attentions=True,
-                return_dict_in_generate=True,
-            )
-
-            if model.config.is_encoder_decoder:
-                self.assertIsInstance(output_beam_sample, BeamSampleEncoderDecoderOutput)
-                self.assertIsInstance(output_generate, BeamSampleEncoderDecoderOutput)
-            else:
-                self.assertIsInstance(output_beam_sample, BeamSampleDecoderOnlyOutput)
-                self.assertIsInstance(output_generate, BeamSampleDecoderOnlyOutput)
-
-            self.assertListEqual(output_generate.sequences.tolist(), output_beam_sample.sequences.tolist())
-
-    @slow
-    def test_beam_sample_generate(self):
-        r"""
-        This test needs to be overriden with a larger model since it fails for very small models due to precision issues.
-        """
-        for model_class in self.all_generative_model_classes:
-            config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
-
-            # It is important set set the eos_token_id to None to ensure that no sequences
-            # shorter than `max_length` can be generated which could lead to flaky circle ci
-            # failures if the top `num_return_sequences` beams are all shorter than the longest beam
-            config.eos_token_id = None
-            config.forced_eos_token_id = None
-            logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=2)
-            model = model_class.from_pretrained("google/switch-base-8").to(torch_device).eval()
-
-            # check `generate()` and `beam_search()` are equal
-            # change `num_return_sequences = 2` but not for `beam_scorer`
-            num_return_sequences = 2
-            if model.config.is_encoder_decoder:
-                max_length = 4
-            beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(
-                input_ids.shape[0] * num_return_sequences, max_length
-            )
-            beam_kwargs["num_return_sequences"] = num_return_sequences
-
-            output_generate, output_beam_sample = self._beam_sample_generate(
-                model=model,
-                input_ids=input_ids,
-                attention_mask=attention_mask,
-                max_length=max_length,
-                num_return_sequences=num_return_sequences,
-                beam_scorer=beam_scorer,
-                beam_kwargs=beam_kwargs,
-                logits_warper=logits_warper,
-                logits_warper_kwargs=logits_warper_kwargs,
-            )
-
-            self.assertListEqual(output_generate.tolist(), output_beam_sample.tolist())
-
     def test_decoder_model_past_with_3d_attn_mask(self):
         (
             config,
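For context, the overridden tests removed above exercised beam-sample decoding (sampling with num_beams > 1) on google/switch-base-8 and compared the low-level beam-sample helper against generate(). Below is a minimal sketch of the same generation mode through the public generate() API; the checkpoint name is taken from the removed test, while the prompt and the exact generation arguments are illustrative assumptions, not part of this commit.

# Minimal sketch (not part of this commit): beam-sample decoding via generate(),
# i.e. the mode the removed overrides exercised through the test helpers.
# The checkpoint comes from the removed test; prompt and arguments are assumptions.
import torch
from transformers import AutoTokenizer, SwitchTransformersForConditionalGeneration

device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained("google/switch-base-8")
model = SwitchTransformersForConditionalGeneration.from_pretrained("google/switch-base-8").to(device).eval()

inputs = tokenizer("summarize: The Switch Transformer routes tokens to experts.", return_tensors="pt").to(device)

# `do_sample=True` combined with `num_beams > 1` selects beam-sample decoding.
outputs = model.generate(
    **inputs,
    do_sample=True,
    num_beams=2,
    num_return_sequences=2,
    max_length=20,
    output_scores=True,
    return_dict_in_generate=True,  # for this encoder-decoder model, returns a BeamSampleEncoderDecoderOutput
)

print(outputs.sequences.shape)  # (num_return_sequences, sequence_length <= max_length)
print(tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True))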