Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
5e8c8eb5
Unverified
Commit
5e8c8eb5
authored
Feb 22, 2023
by
Aaron Gokaslan
Committed by
GitHub
Feb 22, 2023
Browse files
Apply ruff flake8-comprehensions (#21694)
parent
df06fb1f
Changes
230
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
99 additions
and
97 deletions
+99
-97
tests/deepspeed/test_model_zoo.py
tests/deepspeed/test_model_zoo.py
+18
-18
tests/extended/test_trainer_ext.py
tests/extended/test_trainer_ext.py
+7
-7
tests/generation/test_utils.py
tests/generation/test_utils.py
+1
-1
tests/models/bart/test_modeling_bart.py
tests/models/bart/test_modeling_bart.py
+1
-1
tests/models/blenderbot/test_modeling_blenderbot.py
tests/models/blenderbot/test_modeling_blenderbot.py
+2
-2
tests/models/blenderbot/test_modeling_flax_blenderbot.py
tests/models/blenderbot/test_modeling_flax_blenderbot.py
+2
-2
tests/models/bloom/test_tokenization_bloom.py
tests/models/bloom/test_tokenization_bloom.py
+1
-1
tests/models/clip/test_modeling_tf_clip.py
tests/models/clip/test_modeling_tf_clip.py
+2
-2
tests/models/data2vec/test_modeling_tf_data2vec_vision.py
tests/models/data2vec/test_modeling_tf_data2vec_vision.py
+1
-1
tests/models/groupvit/test_modeling_tf_groupvit.py
tests/models/groupvit/test_modeling_tf_groupvit.py
+2
-2
tests/models/jukebox/test_modeling_jukebox.py
tests/models/jukebox/test_modeling_jukebox.py
+11
-11
tests/models/jukebox/test_tokenization_jukebox.py
tests/models/jukebox/test_tokenization_jukebox.py
+5
-5
tests/models/layoutlmv2/test_processor_layoutlmv2.py
tests/models/layoutlmv2/test_processor_layoutlmv2.py
+9
-9
tests/models/layoutlmv3/test_modeling_tf_layoutlmv3.py
tests/models/layoutlmv3/test_modeling_tf_layoutlmv3.py
+1
-1
tests/models/layoutlmv3/test_processor_layoutlmv3.py
tests/models/layoutlmv3/test_processor_layoutlmv3.py
+9
-9
tests/models/layoutxlm/test_processor_layoutxlm.py
tests/models/layoutxlm/test_processor_layoutxlm.py
+9
-9
tests/models/markuplm/test_processor_markuplm.py
tests/models/markuplm/test_processor_markuplm.py
+9
-9
tests/models/mobilevit/test_modeling_tf_mobilevit.py
tests/models/mobilevit/test_modeling_tf_mobilevit.py
+1
-1
tests/models/perceiver/test_modeling_perceiver.py
tests/models/perceiver/test_modeling_perceiver.py
+6
-4
tests/models/roc_bert/test_tokenization_roc_bert.py
tests/models/roc_bert/test_tokenization_roc_bert.py
+2
-2
No files found.
tests/deepspeed/test_model_zoo.py
View file @
5e8c8eb5
...
@@ -166,8 +166,8 @@ def make_task_cmds():
...
@@ -166,8 +166,8 @@ def make_task_cmds():
# but need a tiny model for each
# but need a tiny model for each
#
#
# should have "{model_type.upper()}_TINY" corresponding vars defined, e.g., T5_TINY, etc.
# should have "{model_type.upper()}_TINY" corresponding vars defined, e.g., T5_TINY, etc.
tasks2models
=
dict
(
tasks2models
=
{
trans
=
[
"
trans
"
:
[
"bart"
,
"bart"
,
"fsmt"
,
"fsmt"
,
"m2m_100"
,
"m2m_100"
,
...
@@ -177,10 +177,10 @@ def make_task_cmds():
...
@@ -177,10 +177,10 @@ def make_task_cmds():
"t5_v1"
,
"t5_v1"
,
# "mt5", missing model files
# "mt5", missing model files
],
],
sum
=
[
"
sum
"
:
[
"pegasus"
,
"pegasus"
,
],
],
clm
=
[
"
clm
"
:
[
"big_bird"
,
"big_bird"
,
"bigbird_pegasus"
,
"bigbird_pegasus"
,
"blenderbot"
,
"blenderbot"
,
...
@@ -192,7 +192,7 @@ def make_task_cmds():
...
@@ -192,7 +192,7 @@ def make_task_cmds():
"prophetnet"
,
"prophetnet"
,
# "camembert", missing model files
# "camembert", missing model files
],
],
mlm
=
[
"
mlm
"
:
[
"albert"
,
"albert"
,
"deberta"
,
"deberta"
,
"deberta-v2"
,
"deberta-v2"
,
...
@@ -203,7 +203,7 @@ def make_task_cmds():
...
@@ -203,7 +203,7 @@ def make_task_cmds():
"layoutlm"
,
"layoutlm"
,
# "reformer", # multiple issues with either mlm/qa/clas
# "reformer", # multiple issues with either mlm/qa/clas
],
],
qa
=
[
"
qa
"
:
[
"led"
,
"led"
,
"longformer"
,
"longformer"
,
"mobilebert"
,
"mobilebert"
,
...
@@ -213,7 +213,7 @@ def make_task_cmds():
...
@@ -213,7 +213,7 @@ def make_task_cmds():
# "convbert", # missing tokenizer files
# "convbert", # missing tokenizer files
# "layoutlmv2", missing model files
# "layoutlmv2", missing model files
],
],
clas
=
[
"
clas
"
:
[
"bert"
,
"bert"
,
"xlnet"
,
"xlnet"
,
# "hubert", # missing tokenizer files
# "hubert", # missing tokenizer files
...
@@ -223,54 +223,54 @@ def make_task_cmds():
...
@@ -223,54 +223,54 @@ def make_task_cmds():
# "openai-gpt", missing model files
# "openai-gpt", missing model files
# "tapas", multiple issues
# "tapas", multiple issues
],
],
img_clas
=
[
"
img_clas
"
:
[
"vit"
,
"vit"
,
],
],
)
}
scripts_dir
=
f
"
{
ROOT_DIRECTORY
}
/examples/pytorch"
scripts_dir
=
f
"
{
ROOT_DIRECTORY
}
/examples/pytorch"
tasks
=
dict
(
tasks
=
{
trans
=
f
"""
"
trans
"
:
f
"""
{
scripts_dir
}
/translation/run_translation.py
{
scripts_dir
}
/translation/run_translation.py
--train_file
{
data_dir_wmt
}
/train.json
--train_file
{
data_dir_wmt
}
/train.json
--source_lang en
--source_lang en
--target_lang ro
--target_lang ro
"""
,
"""
,
sum
=
f
"""
"
sum
"
:
f
"""
{
scripts_dir
}
/summarization/run_summarization.py
{
scripts_dir
}
/summarization/run_summarization.py
--train_file
{
data_dir_xsum
}
/sample.json
--train_file
{
data_dir_xsum
}
/sample.json
--max_source_length 12
--max_source_length 12
--max_target_length 12
--max_target_length 12
--lang en
--lang en
"""
,
"""
,
clm
=
f
"""
"
clm
"
:
f
"""
{
scripts_dir
}
/language-modeling/run_clm.py
{
scripts_dir
}
/language-modeling/run_clm.py
--train_file
{
FIXTURE_DIRECTORY
}
/sample_text.txt
--train_file
{
FIXTURE_DIRECTORY
}
/sample_text.txt
--block_size 8
--block_size 8
"""
,
"""
,
mlm
=
f
"""
"
mlm
"
:
f
"""
{
scripts_dir
}
/language-modeling/run_mlm.py
{
scripts_dir
}
/language-modeling/run_mlm.py
--train_file
{
FIXTURE_DIRECTORY
}
/sample_text.txt
--train_file
{
FIXTURE_DIRECTORY
}
/sample_text.txt
"""
,
"""
,
qa
=
f
"""
"
qa
"
:
f
"""
{
scripts_dir
}
/question-answering/run_qa.py
{
scripts_dir
}
/question-answering/run_qa.py
--train_file
{
data_dir_samples
}
/SQUAD/sample.json
--train_file
{
data_dir_samples
}
/SQUAD/sample.json
"""
,
"""
,
clas
=
f
"""
"
clas
"
:
f
"""
{
scripts_dir
}
/text-classification/run_glue.py
{
scripts_dir
}
/text-classification/run_glue.py
--train_file
{
data_dir_samples
}
/MRPC/train.csv
--train_file
{
data_dir_samples
}
/MRPC/train.csv
--max_seq_length 12
--max_seq_length 12
--task_name MRPC
--task_name MRPC
"""
,
"""
,
img_clas
=
f
"""
"
img_clas
"
:
f
"""
{
scripts_dir
}
/image-classification/run_image_classification.py
{
scripts_dir
}
/image-classification/run_image_classification.py
--dataset_name hf-internal-testing/cats_vs_dogs_sample
--dataset_name hf-internal-testing/cats_vs_dogs_sample
--remove_unused_columns False
--remove_unused_columns False
--max_steps 10
--max_steps 10
--image_processor_name
{
DS_TESTS_DIRECTORY
}
/vit_feature_extractor.json
--image_processor_name
{
DS_TESTS_DIRECTORY
}
/vit_feature_extractor.json
"""
,
"""
,
)
}
launcher
=
get_launcher
(
distributed
=
True
)
launcher
=
get_launcher
(
distributed
=
True
)
...
...
tests/extended/test_trainer_ext.py
View file @
5e8c8eb5
...
@@ -155,21 +155,21 @@ class TestTrainerExt(TestCasePlus):
...
@@ -155,21 +155,21 @@ class TestTrainerExt(TestCasePlus):
@
require_torch_multi_gpu
@
require_torch_multi_gpu
def
test_trainer_log_level_replica
(
self
,
experiment_id
):
def
test_trainer_log_level_replica
(
self
,
experiment_id
):
# as each sub-test is slow-ish split into multiple sub-tests to avoid CI timeout
# as each sub-test is slow-ish split into multiple sub-tests to avoid CI timeout
experiments
=
dict
(
experiments
=
{
# test with the default log_level - should be info and thus log info once
# test with the default log_level - should be info and thus log info once
base
=
dict
(
extra_args_str
=
""
,
n_matches
=
1
)
,
"
base
"
:
{
"
extra_args_str
"
:
""
,
"
n_matches
"
:
1
}
,
# test with low log_level and log_level_replica - should be noisy on all processes
# test with low log_level and log_level_replica - should be noisy on all processes
# now the info string should appear twice on 2 processes
# now the info string should appear twice on 2 processes
low
=
dict
(
extra_args_str
=
"--log_level debug --log_level_replica debug"
,
n_matches
=
2
)
,
"
low
"
:
{
"
extra_args_str
"
:
"--log_level debug --log_level_replica debug"
,
"
n_matches
"
:
2
}
,
# test with high log_level and low log_level_replica
# test with high log_level and low log_level_replica
# now the info string should appear once only on the replica
# now the info string should appear once only on the replica
high
=
dict
(
extra_args_str
=
"--log_level error --log_level_replica debug"
,
n_matches
=
1
)
,
"
high
"
:
{
"
extra_args_str
"
:
"--log_level error --log_level_replica debug"
,
"
n_matches
"
:
1
}
,
# test with high log_level and log_level_replica - should be quiet on all processes
# test with high log_level and log_level_replica - should be quiet on all processes
mixed
=
dict
(
extra_args_str
=
"--log_level error --log_level_replica error"
,
n_matches
=
0
)
,
"
mixed
"
:
{
"
extra_args_str
"
:
"--log_level error --log_level_replica error"
,
"
n_matches
"
:
0
}
,
)
}
data
=
experiments
[
experiment_id
]
data
=
experiments
[
experiment_id
]
kwargs
=
dict
(
distributed
=
True
,
predict_with_generate
=
False
,
do_eval
=
False
,
do_predict
=
False
)
kwargs
=
{
"
distributed
"
:
True
,
"
predict_with_generate
"
:
False
,
"
do_eval
"
:
False
,
"
do_predict
"
:
False
}
log_info_string
=
"Running training"
log_info_string
=
"Running training"
with
CaptureStderr
()
as
cl
:
with
CaptureStderr
()
as
cl
:
self
.
run_seq2seq_quick
(
**
kwargs
,
extra_args_str
=
data
[
"extra_args_str"
])
self
.
run_seq2seq_quick
(
**
kwargs
,
extra_args_str
=
data
[
"extra_args_str"
])
...
...
tests/generation/test_utils.py
View file @
5e8c8eb5
...
@@ -1480,7 +1480,7 @@ class GenerationTesterMixin:
...
@@ -1480,7 +1480,7 @@ class GenerationTesterMixin:
signature
=
inspect
.
signature
(
model
.
forward
)
signature
=
inspect
.
signature
(
model
.
forward
)
# We want to test only models where encoder/decoder head masking is implemented
# We want to test only models where encoder/decoder head masking is implemented
if
not
set
(
head_masking
.
keys
())
<
set
([
*
signature
.
parameters
.
keys
()
])
:
if
not
set
(
head_masking
.
keys
())
<
{
*
signature
.
parameters
.
keys
()
}
:
continue
continue
for
attn_name
,
(
name
,
mask
)
in
zip
(
attention_names
,
head_masking
.
items
()):
for
attn_name
,
(
name
,
mask
)
in
zip
(
attention_names
,
head_masking
.
items
()):
...
...
tests/models/bart/test_modeling_bart.py
View file @
5e8c8eb5
...
@@ -939,7 +939,7 @@ class BartModelIntegrationTests(unittest.TestCase):
...
@@ -939,7 +939,7 @@ class BartModelIntegrationTests(unittest.TestCase):
def
test_xsum_config_generation_params
(
self
):
def
test_xsum_config_generation_params
(
self
):
config
=
BartConfig
.
from_pretrained
(
"facebook/bart-large-xsum"
)
config
=
BartConfig
.
from_pretrained
(
"facebook/bart-large-xsum"
)
expected_params
=
dict
(
num_beams
=
6
,
do_sample
=
False
,
early_stopping
=
True
,
length_penalty
=
1.0
)
expected_params
=
{
"
num_beams
"
:
6
,
"
do_sample
"
:
False
,
"
early_stopping
"
:
True
,
"
length_penalty
"
:
1.0
}
config_params
=
{
k
:
getattr
(
config
,
k
,
"MISSING"
)
for
k
,
v
in
expected_params
.
items
()}
config_params
=
{
k
:
getattr
(
config
,
k
,
"MISSING"
)
for
k
,
v
in
expected_params
.
items
()}
self
.
assertDictEqual
(
expected_params
,
config_params
)
self
.
assertDictEqual
(
expected_params
,
config_params
)
...
...
tests/models/blenderbot/test_modeling_blenderbot.py
View file @
5e8c8eb5
...
@@ -299,8 +299,8 @@ class Blenderbot3BIntegrationTests(unittest.TestCase):
...
@@ -299,8 +299,8 @@ class Blenderbot3BIntegrationTests(unittest.TestCase):
@
slow
@
slow
def
test_generation_from_short_input_same_as_parlai_3B
(
self
):
def
test_generation_from_short_input_same_as_parlai_3B
(
self
):
FASTER_GEN_KWARGS
=
dict
(
num_beams
=
1
,
early_stopping
=
True
,
min_length
=
15
,
max_length
=
25
)
FASTER_GEN_KWARGS
=
{
"
num_beams
"
:
1
,
"
early_stopping
"
:
True
,
"
min_length
"
:
15
,
"
max_length
"
:
25
}
TOK_DECODE_KW
=
dict
(
skip_special_tokens
=
True
,
clean_up_tokenization_spaces
=
True
)
TOK_DECODE_KW
=
{
"
skip_special_tokens
"
:
True
,
"
clean_up_tokenization_spaces
"
:
True
}
torch
.
cuda
.
empty_cache
()
torch
.
cuda
.
empty_cache
()
model
=
BlenderbotForConditionalGeneration
.
from_pretrained
(
self
.
ckpt
).
half
().
to
(
torch_device
)
model
=
BlenderbotForConditionalGeneration
.
from_pretrained
(
self
.
ckpt
).
half
().
to
(
torch_device
)
...
...
tests/models/blenderbot/test_modeling_flax_blenderbot.py
View file @
5e8c8eb5
...
@@ -402,8 +402,8 @@ class FlaxBlenderbotModelTest(FlaxModelTesterMixin, unittest.TestCase, FlaxGener
...
@@ -402,8 +402,8 @@ class FlaxBlenderbotModelTest(FlaxModelTesterMixin, unittest.TestCase, FlaxGener
@
unittest
.
skipUnless
(
jax_device
!=
"cpu"
,
"3B test too slow on CPU."
)
@
unittest
.
skipUnless
(
jax_device
!=
"cpu"
,
"3B test too slow on CPU."
)
@
slow
@
slow
def
test_generation_from_short_input_same_as_parlai_3B
(
self
):
def
test_generation_from_short_input_same_as_parlai_3B
(
self
):
FASTER_GEN_KWARGS
=
dict
(
num_beams
=
1
,
early_stopping
=
True
,
min_length
=
15
,
max_length
=
25
)
FASTER_GEN_KWARGS
=
{
"
num_beams
"
:
1
,
"
early_stopping
"
:
True
,
"
min_length
"
:
15
,
"
max_length
"
:
25
}
TOK_DECODE_KW
=
dict
(
skip_special_tokens
=
True
,
clean_up_tokenization_spaces
=
True
)
TOK_DECODE_KW
=
{
"
skip_special_tokens
"
:
True
,
"
clean_up_tokenization_spaces
"
:
True
}
model
=
FlaxBlenderbotForConditionalGeneration
.
from_pretrained
(
"facebook/blenderbot-3B"
,
from_pt
=
True
)
model
=
FlaxBlenderbotForConditionalGeneration
.
from_pretrained
(
"facebook/blenderbot-3B"
,
from_pt
=
True
)
tokenizer
=
BlenderbotTokenizer
.
from_pretrained
(
"facebook/blenderbot-3B"
)
tokenizer
=
BlenderbotTokenizer
.
from_pretrained
(
"facebook/blenderbot-3B"
)
...
...
tests/models/bloom/test_tokenization_bloom.py
View file @
5e8c8eb5
...
@@ -124,7 +124,7 @@ class BloomTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
...
@@ -124,7 +124,7 @@ class BloomTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
input_text
=
list
(
sample_data
.
values
())
input_text
=
list
(
sample_data
.
values
())
output_tokens
=
list
(
map
(
tokenizer
.
encode
,
input_text
))
output_tokens
=
list
(
map
(
tokenizer
.
encode
,
input_text
))
predicted_text
=
list
(
map
(
lambda
x
:
tokenizer
.
decode
(
x
,
clean_up_tokenization_spaces
=
False
)
,
output_tokens
))
predicted_text
=
[
tokenizer
.
decode
(
x
,
clean_up_tokenization_spaces
=
False
)
for
x
in
output_tokens
]
self
.
assertListEqual
(
predicted_text
,
input_text
)
self
.
assertListEqual
(
predicted_text
,
input_text
)
def
test_pretrained_model_lists
(
self
):
def
test_pretrained_model_lists
(
self
):
...
...
tests/models/clip/test_modeling_tf_clip.py
View file @
5e8c8eb5
...
@@ -551,7 +551,7 @@ class TFCLIPModelTest(TFModelTesterMixin, unittest.TestCase):
...
@@ -551,7 +551,7 @@ class TFCLIPModelTest(TFModelTesterMixin, unittest.TestCase):
if
self
.
__class__
.
__name__
==
"TFCLIPModelTest"
:
if
self
.
__class__
.
__name__
==
"TFCLIPModelTest"
:
inputs_dict
.
pop
(
"return_loss"
,
None
)
inputs_dict
.
pop
(
"return_loss"
,
None
)
tf_main_layer_classes
=
set
(
tf_main_layer_classes
=
{
module_member
module_member
for
model_class
in
self
.
all_model_classes
for
model_class
in
self
.
all_model_classes
for
module
in
(
import_module
(
model_class
.
__module__
),)
for
module
in
(
import_module
(
model_class
.
__module__
),)
...
@@ -563,7 +563,7 @@ class TFCLIPModelTest(TFModelTesterMixin, unittest.TestCase):
...
@@ -563,7 +563,7 @@ class TFCLIPModelTest(TFModelTesterMixin, unittest.TestCase):
if
isinstance
(
module_member
,
type
)
if
isinstance
(
module_member
,
type
)
and
tf
.
keras
.
layers
.
Layer
in
module_member
.
__bases__
and
tf
.
keras
.
layers
.
Layer
in
module_member
.
__bases__
and
getattr
(
module_member
,
"_keras_serializable"
,
False
)
and
getattr
(
module_member
,
"_keras_serializable"
,
False
)
)
}
for
main_layer_class
in
tf_main_layer_classes
:
for
main_layer_class
in
tf_main_layer_classes
:
# T5MainLayer needs an embed_tokens parameter when called without the inputs_embeds parameter
# T5MainLayer needs an embed_tokens parameter when called without the inputs_embeds parameter
if
"T5"
in
main_layer_class
.
__name__
:
if
"T5"
in
main_layer_class
.
__name__
:
...
...
tests/models/data2vec/test_modeling_tf_data2vec_vision.py
View file @
5e8c8eb5
...
@@ -398,7 +398,7 @@ class TFData2VecVisionModelTest(TFModelTesterMixin, unittest.TestCase):
...
@@ -398,7 +398,7 @@ class TFData2VecVisionModelTest(TFModelTesterMixin, unittest.TestCase):
# The number of elements in the loss should be the same as the number of elements in the label
# The number of elements in the loss should be the same as the number of elements in the label
_
,
prepared_for_class
=
self
.
model_tester
.
prepare_config_and_inputs_for_keras_fit
()
_
,
prepared_for_class
=
self
.
model_tester
.
prepare_config_and_inputs_for_keras_fit
()
added_label
=
prepared_for_class
[
added_label
=
prepared_for_class
[
sorted
(
list
(
prepared_for_class
.
keys
()
-
inputs_dict
.
keys
()
)
,
reverse
=
True
)[
0
]
sorted
(
prepared_for_class
.
keys
()
-
inputs_dict
.
keys
(),
reverse
=
True
)[
0
]
]
]
loss_size
=
tf
.
size
(
added_label
)
loss_size
=
tf
.
size
(
added_label
)
...
...
tests/models/groupvit/test_modeling_tf_groupvit.py
View file @
5e8c8eb5
...
@@ -628,7 +628,7 @@ class TFGroupViTModelTest(TFModelTesterMixin, unittest.TestCase):
...
@@ -628,7 +628,7 @@ class TFGroupViTModelTest(TFModelTesterMixin, unittest.TestCase):
if
self
.
__class__
.
__name__
==
"TFGroupViTModelTest"
:
if
self
.
__class__
.
__name__
==
"TFGroupViTModelTest"
:
inputs_dict
.
pop
(
"return_loss"
,
None
)
inputs_dict
.
pop
(
"return_loss"
,
None
)
tf_main_layer_classes
=
set
(
tf_main_layer_classes
=
{
module_member
module_member
for
model_class
in
self
.
all_model_classes
for
model_class
in
self
.
all_model_classes
for
module
in
(
import_module
(
model_class
.
__module__
),)
for
module
in
(
import_module
(
model_class
.
__module__
),)
...
@@ -640,7 +640,7 @@ class TFGroupViTModelTest(TFModelTesterMixin, unittest.TestCase):
...
@@ -640,7 +640,7 @@ class TFGroupViTModelTest(TFModelTesterMixin, unittest.TestCase):
if
isinstance
(
module_member
,
type
)
if
isinstance
(
module_member
,
type
)
and
tf
.
keras
.
layers
.
Layer
in
module_member
.
__bases__
and
tf
.
keras
.
layers
.
Layer
in
module_member
.
__bases__
and
getattr
(
module_member
,
"_keras_serializable"
,
False
)
and
getattr
(
module_member
,
"_keras_serializable"
,
False
)
)
}
for
main_layer_class
in
tf_main_layer_classes
:
for
main_layer_class
in
tf_main_layer_classes
:
# T5MainLayer needs an embed_tokens parameter when called without the inputs_embeds parameter
# T5MainLayer needs an embed_tokens parameter when called without the inputs_embeds parameter
if
"T5"
in
main_layer_class
.
__name__
:
if
"T5"
in
main_layer_class
.
__name__
:
...
...
tests/models/jukebox/test_modeling_jukebox.py
View file @
5e8c8eb5
...
@@ -30,10 +30,10 @@ if is_torch_available():
...
@@ -30,10 +30,10 @@ if is_torch_available():
class
Jukebox1bModelTester
(
unittest
.
TestCase
):
class
Jukebox1bModelTester
(
unittest
.
TestCase
):
all_model_classes
=
(
JukeboxModel
,)
if
is_torch_available
()
else
()
all_model_classes
=
(
JukeboxModel
,)
if
is_torch_available
()
else
()
model_id
=
"openai/jukebox-1b-lyrics"
model_id
=
"openai/jukebox-1b-lyrics"
metas
=
dict
(
metas
=
{
artist
=
"Zac Brown Band"
,
"
artist
"
:
"Zac Brown Band"
,
genres
=
"Country"
,
"
genres
"
:
"Country"
,
lyrics
=
"""I met a traveller from an antique land,
"
lyrics
"
:
"""I met a traveller from an antique land,
Who said "Two vast and trunkless legs of stone
Who said "Two vast and trunkless legs of stone
Stand in the desert. . . . Near them, on the sand,
Stand in the desert. . . . Near them, on the sand,
Half sunk a shattered visage lies, whose frown,
Half sunk a shattered visage lies, whose frown,
...
@@ -48,7 +48,7 @@ class Jukebox1bModelTester(unittest.TestCase):
...
@@ -48,7 +48,7 @@ class Jukebox1bModelTester(unittest.TestCase):
Of that colossal Wreck, boundless and bare
Of that colossal Wreck, boundless and bare
The lone and level sands stretch far away
The lone and level sands stretch far away
"""
,
"""
,
)
}
# fmt: off
# fmt: off
EXPECTED_OUTPUT_2
=
[
EXPECTED_OUTPUT_2
=
[
1864
,
1536
,
1213
,
1870
,
1357
,
1536
,
519
,
880
,
1323
,
789
,
1082
,
534
,
1864
,
1536
,
1213
,
1870
,
1357
,
1536
,
519
,
880
,
1323
,
789
,
1082
,
534
,
...
@@ -180,7 +180,7 @@ class Jukebox1bModelTester(unittest.TestCase):
...
@@ -180,7 +180,7 @@ class Jukebox1bModelTester(unittest.TestCase):
model
=
JukeboxModel
.
from_pretrained
(
self
.
model_id
,
min_duration
=
0
).
eval
()
model
=
JukeboxModel
.
from_pretrained
(
self
.
model_id
,
min_duration
=
0
).
eval
()
set_seed
(
0
)
set_seed
(
0
)
waveform
=
torch
.
rand
((
1
,
5120
,
1
))
waveform
=
torch
.
rand
((
1
,
5120
,
1
))
tokens
=
[
i
for
i
in
self
.
prepare_inputs
()
]
tokens
=
list
(
self
.
prepare_inputs
()
)
zs
=
[
model
.
vqvae
.
encode
(
waveform
,
start_level
=
2
,
bs_chunks
=
waveform
.
shape
[
0
])[
0
],
None
,
None
]
zs
=
[
model
.
vqvae
.
encode
(
waveform
,
start_level
=
2
,
bs_chunks
=
waveform
.
shape
[
0
])[
0
],
None
,
None
]
zs
=
model
.
_sample
(
zs
=
model
.
_sample
(
...
@@ -220,10 +220,10 @@ class Jukebox1bModelTester(unittest.TestCase):
...
@@ -220,10 +220,10 @@ class Jukebox1bModelTester(unittest.TestCase):
class
Jukebox5bModelTester
(
unittest
.
TestCase
):
class
Jukebox5bModelTester
(
unittest
.
TestCase
):
all_model_classes
=
(
JukeboxModel
,)
if
is_torch_available
()
else
()
all_model_classes
=
(
JukeboxModel
,)
if
is_torch_available
()
else
()
model_id
=
"openai/jukebox-5b-lyrics"
model_id
=
"openai/jukebox-5b-lyrics"
metas
=
dict
(
metas
=
{
artist
=
"Zac Brown Band"
,
"
artist
"
:
"Zac Brown Band"
,
genres
=
"Country"
,
"
genres
"
:
"Country"
,
lyrics
=
"""I met a traveller from an antique land,
"
lyrics
"
:
"""I met a traveller from an antique land,
Who said "Two vast and trunkless legs of stone
Who said "Two vast and trunkless legs of stone
Stand in the desert. . . . Near them, on the sand,
Stand in the desert. . . . Near them, on the sand,
Half sunk a shattered visage lies, whose frown,
Half sunk a shattered visage lies, whose frown,
...
@@ -238,7 +238,7 @@ class Jukebox5bModelTester(unittest.TestCase):
...
@@ -238,7 +238,7 @@ class Jukebox5bModelTester(unittest.TestCase):
Of that colossal Wreck, boundless and bare
Of that colossal Wreck, boundless and bare
The lone and level sands stretch far away
The lone and level sands stretch far away
"""
,
"""
,
)
}
# fmt: off
# fmt: off
EXPECTED_OUTPUT_2
=
[
EXPECTED_OUTPUT_2
=
[
...
...
tests/models/jukebox/test_tokenization_jukebox.py
View file @
5e8c8eb5
...
@@ -21,10 +21,10 @@ from transformers.testing_utils import require_torch
...
@@ -21,10 +21,10 @@ from transformers.testing_utils import require_torch
class
JukeboxTokenizationTest
(
unittest
.
TestCase
):
class
JukeboxTokenizationTest
(
unittest
.
TestCase
):
tokenizer_class
=
JukeboxTokenizer
tokenizer_class
=
JukeboxTokenizer
metas
=
dict
(
metas
=
{
artist
=
"Zac Brown Band"
,
"
artist
"
:
"Zac Brown Band"
,
genres
=
"Country"
,
"
genres
"
:
"Country"
,
lyrics
=
"""I met a traveller from an antique land,
"
lyrics
"
:
"""I met a traveller from an antique land,
Who said "Two vast and trunkless legs of stone
Who said "Two vast and trunkless legs of stone
Stand in the desert. . . . Near them, on the sand,
Stand in the desert. . . . Near them, on the sand,
Half sunk a shattered visage lies, whose frown,
Half sunk a shattered visage lies, whose frown,
...
@@ -39,7 +39,7 @@ class JukeboxTokenizationTest(unittest.TestCase):
...
@@ -39,7 +39,7 @@ class JukeboxTokenizationTest(unittest.TestCase):
Of that colossal Wreck, boundless and bare
Of that colossal Wreck, boundless and bare
The lone and level sands stretch far away
The lone and level sands stretch far away
"""
,
"""
,
)
}
@
require_torch
@
require_torch
def
test_1b_lyrics_tokenizer
(
self
):
def
test_1b_lyrics_tokenizer
(
self
):
...
...
tests/models/layoutlmv2/test_processor_layoutlmv2.py
View file @
5e8c8eb5
...
@@ -233,7 +233,7 @@ class LayoutLMv2ProcessorIntegrationTests(unittest.TestCase):
...
@@ -233,7 +233,7 @@ class LayoutLMv2ProcessorIntegrationTests(unittest.TestCase):
# verify keys
# verify keys
expected_keys
=
[
"attention_mask"
,
"bbox"
,
"image"
,
"input_ids"
,
"token_type_ids"
]
expected_keys
=
[
"attention_mask"
,
"bbox"
,
"image"
,
"input_ids"
,
"token_type_ids"
]
actual_keys
=
sorted
(
list
(
input_processor
.
keys
())
)
actual_keys
=
sorted
(
input_processor
.
keys
())
self
.
assertListEqual
(
actual_keys
,
expected_keys
)
self
.
assertListEqual
(
actual_keys
,
expected_keys
)
# verify image
# verify image
...
@@ -253,7 +253,7 @@ class LayoutLMv2ProcessorIntegrationTests(unittest.TestCase):
...
@@ -253,7 +253,7 @@ class LayoutLMv2ProcessorIntegrationTests(unittest.TestCase):
# verify keys
# verify keys
expected_keys
=
[
"attention_mask"
,
"bbox"
,
"image"
,
"input_ids"
,
"token_type_ids"
]
expected_keys
=
[
"attention_mask"
,
"bbox"
,
"image"
,
"input_ids"
,
"token_type_ids"
]
actual_keys
=
sorted
(
list
(
input_processor
.
keys
())
)
actual_keys
=
sorted
(
input_processor
.
keys
())
self
.
assertListEqual
(
actual_keys
,
expected_keys
)
self
.
assertListEqual
(
actual_keys
,
expected_keys
)
# verify images
# verify images
...
@@ -301,7 +301,7 @@ class LayoutLMv2ProcessorIntegrationTests(unittest.TestCase):
...
@@ -301,7 +301,7 @@ class LayoutLMv2ProcessorIntegrationTests(unittest.TestCase):
# verify keys
# verify keys
expected_keys
=
[
"attention_mask"
,
"bbox"
,
"image"
,
"input_ids"
,
"token_type_ids"
]
expected_keys
=
[
"attention_mask"
,
"bbox"
,
"image"
,
"input_ids"
,
"token_type_ids"
]
actual_keys
=
sorted
(
list
(
input_processor
.
keys
())
)
actual_keys
=
sorted
(
input_processor
.
keys
())
self
.
assertListEqual
(
actual_keys
,
expected_keys
)
self
.
assertListEqual
(
actual_keys
,
expected_keys
)
# verify input_ids
# verify input_ids
...
@@ -340,7 +340,7 @@ class LayoutLMv2ProcessorIntegrationTests(unittest.TestCase):
...
@@ -340,7 +340,7 @@ class LayoutLMv2ProcessorIntegrationTests(unittest.TestCase):
# verify keys
# verify keys
expected_keys
=
[
"attention_mask"
,
"bbox"
,
"image"
,
"input_ids"
,
"labels"
,
"token_type_ids"
]
expected_keys
=
[
"attention_mask"
,
"bbox"
,
"image"
,
"input_ids"
,
"labels"
,
"token_type_ids"
]
actual_keys
=
sorted
(
list
(
input_processor
.
keys
())
)
actual_keys
=
sorted
(
input_processor
.
keys
())
self
.
assertListEqual
(
actual_keys
,
expected_keys
)
self
.
assertListEqual
(
actual_keys
,
expected_keys
)
# verify input_ids
# verify input_ids
...
@@ -362,7 +362,7 @@ class LayoutLMv2ProcessorIntegrationTests(unittest.TestCase):
...
@@ -362,7 +362,7 @@ class LayoutLMv2ProcessorIntegrationTests(unittest.TestCase):
# verify keys
# verify keys
expected_keys
=
[
"attention_mask"
,
"bbox"
,
"image"
,
"input_ids"
,
"labels"
,
"token_type_ids"
]
expected_keys
=
[
"attention_mask"
,
"bbox"
,
"image"
,
"input_ids"
,
"labels"
,
"token_type_ids"
]
actual_keys
=
sorted
(
list
(
input_processor
.
keys
())
)
actual_keys
=
sorted
(
input_processor
.
keys
())
self
.
assertListEqual
(
actual_keys
,
expected_keys
)
self
.
assertListEqual
(
actual_keys
,
expected_keys
)
# verify input_ids
# verify input_ids
...
@@ -403,7 +403,7 @@ class LayoutLMv2ProcessorIntegrationTests(unittest.TestCase):
...
@@ -403,7 +403,7 @@ class LayoutLMv2ProcessorIntegrationTests(unittest.TestCase):
# verify keys
# verify keys
expected_keys
=
[
"attention_mask"
,
"bbox"
,
"image"
,
"input_ids"
,
"token_type_ids"
]
expected_keys
=
[
"attention_mask"
,
"bbox"
,
"image"
,
"input_ids"
,
"token_type_ids"
]
actual_keys
=
sorted
(
list
(
input_processor
.
keys
())
)
actual_keys
=
sorted
(
input_processor
.
keys
())
self
.
assertListEqual
(
actual_keys
,
expected_keys
)
self
.
assertListEqual
(
actual_keys
,
expected_keys
)
# verify input_ids
# verify input_ids
...
@@ -422,7 +422,7 @@ class LayoutLMv2ProcessorIntegrationTests(unittest.TestCase):
...
@@ -422,7 +422,7 @@ class LayoutLMv2ProcessorIntegrationTests(unittest.TestCase):
# verify keys
# verify keys
expected_keys
=
[
"attention_mask"
,
"bbox"
,
"image"
,
"input_ids"
,
"token_type_ids"
]
expected_keys
=
[
"attention_mask"
,
"bbox"
,
"image"
,
"input_ids"
,
"token_type_ids"
]
actual_keys
=
sorted
(
list
(
input_processor
.
keys
())
)
actual_keys
=
sorted
(
input_processor
.
keys
())
self
.
assertListEqual
(
actual_keys
,
expected_keys
)
self
.
assertListEqual
(
actual_keys
,
expected_keys
)
# verify input_ids
# verify input_ids
...
@@ -456,7 +456,7 @@ class LayoutLMv2ProcessorIntegrationTests(unittest.TestCase):
...
@@ -456,7 +456,7 @@ class LayoutLMv2ProcessorIntegrationTests(unittest.TestCase):
# verify keys
# verify keys
expected_keys
=
[
"attention_mask"
,
"bbox"
,
"image"
,
"input_ids"
,
"token_type_ids"
]
expected_keys
=
[
"attention_mask"
,
"bbox"
,
"image"
,
"input_ids"
,
"token_type_ids"
]
actual_keys
=
sorted
(
list
(
input_processor
.
keys
())
)
actual_keys
=
sorted
(
input_processor
.
keys
())
self
.
assertListEqual
(
actual_keys
,
expected_keys
)
self
.
assertListEqual
(
actual_keys
,
expected_keys
)
# verify input_ids
# verify input_ids
...
@@ -472,7 +472,7 @@ class LayoutLMv2ProcessorIntegrationTests(unittest.TestCase):
...
@@ -472,7 +472,7 @@ class LayoutLMv2ProcessorIntegrationTests(unittest.TestCase):
# verify keys
# verify keys
expected_keys
=
[
"attention_mask"
,
"bbox"
,
"image"
,
"input_ids"
,
"token_type_ids"
]
expected_keys
=
[
"attention_mask"
,
"bbox"
,
"image"
,
"input_ids"
,
"token_type_ids"
]
actual_keys
=
sorted
(
list
(
input_processor
.
keys
())
)
actual_keys
=
sorted
(
input_processor
.
keys
())
self
.
assertListEqual
(
actual_keys
,
expected_keys
)
self
.
assertListEqual
(
actual_keys
,
expected_keys
)
# verify input_ids
# verify input_ids
...
...
tests/models/layoutlmv3/test_modeling_tf_layoutlmv3.py
View file @
5e8c8eb5
...
@@ -320,7 +320,7 @@ class TFLayoutLMv3ModelTest(TFModelTesterMixin, unittest.TestCase):
...
@@ -320,7 +320,7 @@ class TFLayoutLMv3ModelTest(TFModelTesterMixin, unittest.TestCase):
# The number of elements in the loss should be the same as the number of elements in the label
# The number of elements in the loss should be the same as the number of elements in the label
prepared_for_class
=
self
.
_prepare_for_class
(
inputs_dict
.
copy
(),
model_class
,
return_labels
=
True
)
prepared_for_class
=
self
.
_prepare_for_class
(
inputs_dict
.
copy
(),
model_class
,
return_labels
=
True
)
added_label
=
prepared_for_class
[
added_label
=
prepared_for_class
[
sorted
(
list
(
prepared_for_class
.
keys
()
-
inputs_dict
.
keys
()
)
,
reverse
=
True
)[
0
]
sorted
(
prepared_for_class
.
keys
()
-
inputs_dict
.
keys
(),
reverse
=
True
)[
0
]
]
]
expected_loss_size
=
added_label
.
shape
.
as_list
()[:
1
]
expected_loss_size
=
added_label
.
shape
.
as_list
()[:
1
]
...
...
tests/models/layoutlmv3/test_processor_layoutlmv3.py
View file @
5e8c8eb5
...
@@ -213,7 +213,7 @@ class LayoutLMv3ProcessorIntegrationTests(unittest.TestCase):
...
@@ -213,7 +213,7 @@ class LayoutLMv3ProcessorIntegrationTests(unittest.TestCase):
# verify keys
# verify keys
expected_keys
=
[
"attention_mask"
,
"bbox"
,
"input_ids"
,
"pixel_values"
]
expected_keys
=
[
"attention_mask"
,
"bbox"
,
"input_ids"
,
"pixel_values"
]
actual_keys
=
sorted
(
list
(
input_processor
.
keys
())
)
actual_keys
=
sorted
(
input_processor
.
keys
())
self
.
assertListEqual
(
actual_keys
,
expected_keys
)
self
.
assertListEqual
(
actual_keys
,
expected_keys
)
# verify image
# verify image
...
@@ -235,7 +235,7 @@ class LayoutLMv3ProcessorIntegrationTests(unittest.TestCase):
...
@@ -235,7 +235,7 @@ class LayoutLMv3ProcessorIntegrationTests(unittest.TestCase):
# verify keys
# verify keys
expected_keys
=
[
"attention_mask"
,
"bbox"
,
"input_ids"
,
"pixel_values"
]
expected_keys
=
[
"attention_mask"
,
"bbox"
,
"input_ids"
,
"pixel_values"
]
actual_keys
=
sorted
(
list
(
input_processor
.
keys
())
)
actual_keys
=
sorted
(
input_processor
.
keys
())
self
.
assertListEqual
(
actual_keys
,
expected_keys
)
self
.
assertListEqual
(
actual_keys
,
expected_keys
)
# verify images
# verify images
...
@@ -285,7 +285,7 @@ class LayoutLMv3ProcessorIntegrationTests(unittest.TestCase):
...
@@ -285,7 +285,7 @@ class LayoutLMv3ProcessorIntegrationTests(unittest.TestCase):
# verify keys
# verify keys
expected_keys
=
[
"attention_mask"
,
"bbox"
,
"input_ids"
,
"pixel_values"
]
expected_keys
=
[
"attention_mask"
,
"bbox"
,
"input_ids"
,
"pixel_values"
]
actual_keys
=
sorted
(
list
(
input_processor
.
keys
())
)
actual_keys
=
sorted
(
input_processor
.
keys
())
self
.
assertListEqual
(
actual_keys
,
expected_keys
)
self
.
assertListEqual
(
actual_keys
,
expected_keys
)
# verify input_ids
# verify input_ids
...
@@ -324,7 +324,7 @@ class LayoutLMv3ProcessorIntegrationTests(unittest.TestCase):
...
@@ -324,7 +324,7 @@ class LayoutLMv3ProcessorIntegrationTests(unittest.TestCase):
# verify keys
# verify keys
expected_keys
=
[
"attention_mask"
,
"bbox"
,
"input_ids"
,
"labels"
,
"pixel_values"
]
expected_keys
=
[
"attention_mask"
,
"bbox"
,
"input_ids"
,
"labels"
,
"pixel_values"
]
actual_keys
=
sorted
(
list
(
input_processor
.
keys
())
)
actual_keys
=
sorted
(
input_processor
.
keys
())
self
.
assertListEqual
(
actual_keys
,
expected_keys
)
self
.
assertListEqual
(
actual_keys
,
expected_keys
)
# verify input_ids
# verify input_ids
...
@@ -346,7 +346,7 @@ class LayoutLMv3ProcessorIntegrationTests(unittest.TestCase):
...
@@ -346,7 +346,7 @@ class LayoutLMv3ProcessorIntegrationTests(unittest.TestCase):
# verify keys
# verify keys
expected_keys
=
[
"attention_mask"
,
"bbox"
,
"input_ids"
,
"labels"
,
"pixel_values"
]
expected_keys
=
[
"attention_mask"
,
"bbox"
,
"input_ids"
,
"labels"
,
"pixel_values"
]
actual_keys
=
sorted
(
list
(
input_processor
.
keys
())
)
actual_keys
=
sorted
(
input_processor
.
keys
())
self
.
assertListEqual
(
actual_keys
,
expected_keys
)
self
.
assertListEqual
(
actual_keys
,
expected_keys
)
# verify input_ids
# verify input_ids
...
@@ -387,7 +387,7 @@ class LayoutLMv3ProcessorIntegrationTests(unittest.TestCase):
...
@@ -387,7 +387,7 @@ class LayoutLMv3ProcessorIntegrationTests(unittest.TestCase):
# verify keys
# verify keys
expected_keys
=
[
"attention_mask"
,
"bbox"
,
"input_ids"
,
"pixel_values"
]
expected_keys
=
[
"attention_mask"
,
"bbox"
,
"input_ids"
,
"pixel_values"
]
actual_keys
=
sorted
(
list
(
input_processor
.
keys
())
)
actual_keys
=
sorted
(
input_processor
.
keys
())
self
.
assertListEqual
(
actual_keys
,
expected_keys
)
self
.
assertListEqual
(
actual_keys
,
expected_keys
)
# verify input_ids
# verify input_ids
...
@@ -406,7 +406,7 @@ class LayoutLMv3ProcessorIntegrationTests(unittest.TestCase):
...
@@ -406,7 +406,7 @@ class LayoutLMv3ProcessorIntegrationTests(unittest.TestCase):
# verify keys
# verify keys
expected_keys
=
[
"attention_mask"
,
"bbox"
,
"input_ids"
,
"pixel_values"
]
expected_keys
=
[
"attention_mask"
,
"bbox"
,
"input_ids"
,
"pixel_values"
]
actual_keys
=
sorted
(
list
(
input_processor
.
keys
())
)
actual_keys
=
sorted
(
input_processor
.
keys
())
self
.
assertListEqual
(
actual_keys
,
expected_keys
)
self
.
assertListEqual
(
actual_keys
,
expected_keys
)
# verify input_ids
# verify input_ids
...
@@ -440,7 +440,7 @@ class LayoutLMv3ProcessorIntegrationTests(unittest.TestCase):
...
@@ -440,7 +440,7 @@ class LayoutLMv3ProcessorIntegrationTests(unittest.TestCase):
# verify keys
# verify keys
expected_keys
=
[
"attention_mask"
,
"bbox"
,
"input_ids"
,
"pixel_values"
]
expected_keys
=
[
"attention_mask"
,
"bbox"
,
"input_ids"
,
"pixel_values"
]
actual_keys
=
sorted
(
list
(
input_processor
.
keys
())
)
actual_keys
=
sorted
(
input_processor
.
keys
())
self
.
assertListEqual
(
actual_keys
,
expected_keys
)
self
.
assertListEqual
(
actual_keys
,
expected_keys
)
# verify input_ids
# verify input_ids
...
@@ -456,7 +456,7 @@ class LayoutLMv3ProcessorIntegrationTests(unittest.TestCase):
...
@@ -456,7 +456,7 @@ class LayoutLMv3ProcessorIntegrationTests(unittest.TestCase):
# verify keys
# verify keys
expected_keys
=
[
"attention_mask"
,
"bbox"
,
"input_ids"
,
"pixel_values"
]
expected_keys
=
[
"attention_mask"
,
"bbox"
,
"input_ids"
,
"pixel_values"
]
actual_keys
=
sorted
(
list
(
input_processor
.
keys
())
)
actual_keys
=
sorted
(
input_processor
.
keys
())
self
.
assertListEqual
(
actual_keys
,
expected_keys
)
self
.
assertListEqual
(
actual_keys
,
expected_keys
)
# verify input_ids
# verify input_ids
...
...
tests/models/layoutxlm/test_processor_layoutxlm.py
View file @
5e8c8eb5
...
@@ -228,7 +228,7 @@ class LayoutXLMProcessorIntegrationTests(unittest.TestCase):
...
@@ -228,7 +228,7 @@ class LayoutXLMProcessorIntegrationTests(unittest.TestCase):
# verify keys
# verify keys
expected_keys
=
[
"attention_mask"
,
"bbox"
,
"image"
,
"input_ids"
]
expected_keys
=
[
"attention_mask"
,
"bbox"
,
"image"
,
"input_ids"
]
actual_keys
=
sorted
(
list
(
input_processor
.
keys
())
)
actual_keys
=
sorted
(
input_processor
.
keys
())
self
.
assertListEqual
(
actual_keys
,
expected_keys
)
self
.
assertListEqual
(
actual_keys
,
expected_keys
)
# verify image
# verify image
...
@@ -250,7 +250,7 @@ class LayoutXLMProcessorIntegrationTests(unittest.TestCase):
...
@@ -250,7 +250,7 @@ class LayoutXLMProcessorIntegrationTests(unittest.TestCase):
# verify keys
# verify keys
expected_keys
=
[
"attention_mask"
,
"bbox"
,
"image"
,
"input_ids"
]
expected_keys
=
[
"attention_mask"
,
"bbox"
,
"image"
,
"input_ids"
]
actual_keys
=
sorted
(
list
(
input_processor
.
keys
())
)
actual_keys
=
sorted
(
input_processor
.
keys
())
self
.
assertListEqual
(
actual_keys
,
expected_keys
)
self
.
assertListEqual
(
actual_keys
,
expected_keys
)
# verify images
# verify images
...
@@ -300,7 +300,7 @@ class LayoutXLMProcessorIntegrationTests(unittest.TestCase):
...
@@ -300,7 +300,7 @@ class LayoutXLMProcessorIntegrationTests(unittest.TestCase):
# verify keys
# verify keys
expected_keys
=
[
"attention_mask"
,
"bbox"
,
"image"
,
"input_ids"
]
expected_keys
=
[
"attention_mask"
,
"bbox"
,
"image"
,
"input_ids"
]
actual_keys
=
sorted
(
list
(
input_processor
.
keys
())
)
actual_keys
=
sorted
(
input_processor
.
keys
())
self
.
assertListEqual
(
actual_keys
,
expected_keys
)
self
.
assertListEqual
(
actual_keys
,
expected_keys
)
# verify input_ids
# verify input_ids
...
@@ -339,7 +339,7 @@ class LayoutXLMProcessorIntegrationTests(unittest.TestCase):
...
@@ -339,7 +339,7 @@ class LayoutXLMProcessorIntegrationTests(unittest.TestCase):
# verify keys
# verify keys
expected_keys
=
[
"attention_mask"
,
"bbox"
,
"image"
,
"input_ids"
,
"labels"
]
expected_keys
=
[
"attention_mask"
,
"bbox"
,
"image"
,
"input_ids"
,
"labels"
]
actual_keys
=
sorted
(
list
(
input_processor
.
keys
())
)
actual_keys
=
sorted
(
input_processor
.
keys
())
self
.
assertListEqual
(
actual_keys
,
expected_keys
)
self
.
assertListEqual
(
actual_keys
,
expected_keys
)
# verify input_ids
# verify input_ids
...
@@ -361,7 +361,7 @@ class LayoutXLMProcessorIntegrationTests(unittest.TestCase):
...
@@ -361,7 +361,7 @@ class LayoutXLMProcessorIntegrationTests(unittest.TestCase):
# verify keys
# verify keys
expected_keys
=
[
"attention_mask"
,
"bbox"
,
"image"
,
"input_ids"
,
"labels"
]
expected_keys
=
[
"attention_mask"
,
"bbox"
,
"image"
,
"input_ids"
,
"labels"
]
actual_keys
=
sorted
(
list
(
input_processor
.
keys
())
)
actual_keys
=
sorted
(
input_processor
.
keys
())
self
.
assertListEqual
(
actual_keys
,
expected_keys
)
self
.
assertListEqual
(
actual_keys
,
expected_keys
)
# verify input_ids
# verify input_ids
...
@@ -402,7 +402,7 @@ class LayoutXLMProcessorIntegrationTests(unittest.TestCase):
...
@@ -402,7 +402,7 @@ class LayoutXLMProcessorIntegrationTests(unittest.TestCase):
# verify keys
# verify keys
expected_keys
=
[
"attention_mask"
,
"bbox"
,
"image"
,
"input_ids"
]
expected_keys
=
[
"attention_mask"
,
"bbox"
,
"image"
,
"input_ids"
]
actual_keys
=
sorted
(
list
(
input_processor
.
keys
())
)
actual_keys
=
sorted
(
input_processor
.
keys
())
self
.
assertListEqual
(
actual_keys
,
expected_keys
)
self
.
assertListEqual
(
actual_keys
,
expected_keys
)
# verify input_ids
# verify input_ids
...
@@ -421,7 +421,7 @@ class LayoutXLMProcessorIntegrationTests(unittest.TestCase):
...
@@ -421,7 +421,7 @@ class LayoutXLMProcessorIntegrationTests(unittest.TestCase):
# verify keys
# verify keys
expected_keys
=
[
"attention_mask"
,
"bbox"
,
"image"
,
"input_ids"
]
expected_keys
=
[
"attention_mask"
,
"bbox"
,
"image"
,
"input_ids"
]
actual_keys
=
sorted
(
list
(
input_processor
.
keys
())
)
actual_keys
=
sorted
(
input_processor
.
keys
())
self
.
assertListEqual
(
actual_keys
,
expected_keys
)
self
.
assertListEqual
(
actual_keys
,
expected_keys
)
# verify input_ids
# verify input_ids
...
@@ -455,7 +455,7 @@ class LayoutXLMProcessorIntegrationTests(unittest.TestCase):
...
@@ -455,7 +455,7 @@ class LayoutXLMProcessorIntegrationTests(unittest.TestCase):
# verify keys
# verify keys
expected_keys
=
[
"attention_mask"
,
"bbox"
,
"image"
,
"input_ids"
]
expected_keys
=
[
"attention_mask"
,
"bbox"
,
"image"
,
"input_ids"
]
actual_keys
=
sorted
(
list
(
input_processor
.
keys
())
)
actual_keys
=
sorted
(
input_processor
.
keys
())
self
.
assertListEqual
(
actual_keys
,
expected_keys
)
self
.
assertListEqual
(
actual_keys
,
expected_keys
)
# verify input_ids
# verify input_ids
...
@@ -471,7 +471,7 @@ class LayoutXLMProcessorIntegrationTests(unittest.TestCase):
...
@@ -471,7 +471,7 @@ class LayoutXLMProcessorIntegrationTests(unittest.TestCase):
# verify keys
# verify keys
expected_keys
=
[
"attention_mask"
,
"bbox"
,
"image"
,
"input_ids"
]
expected_keys
=
[
"attention_mask"
,
"bbox"
,
"image"
,
"input_ids"
]
actual_keys
=
sorted
(
list
(
input_processor
.
keys
())
)
actual_keys
=
sorted
(
input_processor
.
keys
())
self
.
assertListEqual
(
actual_keys
,
expected_keys
)
self
.
assertListEqual
(
actual_keys
,
expected_keys
)
# verify input_ids
# verify input_ids
...
...
tests/models/markuplm/test_processor_markuplm.py
View file @
5e8c8eb5
...
@@ -204,7 +204,7 @@ class MarkupLMProcessorIntegrationTests(unittest.TestCase):
...
@@ -204,7 +204,7 @@ class MarkupLMProcessorIntegrationTests(unittest.TestCase):
# verify keys
# verify keys
expected_keys
=
[
"attention_mask"
,
"input_ids"
,
"token_type_ids"
,
"xpath_subs_seq"
,
"xpath_tags_seq"
]
expected_keys
=
[
"attention_mask"
,
"input_ids"
,
"token_type_ids"
,
"xpath_subs_seq"
,
"xpath_tags_seq"
]
actual_keys
=
sorted
(
list
(
inputs
.
keys
())
)
actual_keys
=
sorted
(
inputs
.
keys
())
self
.
assertListEqual
(
actual_keys
,
expected_keys
)
self
.
assertListEqual
(
actual_keys
,
expected_keys
)
# verify input_ids
# verify input_ids
...
@@ -216,7 +216,7 @@ class MarkupLMProcessorIntegrationTests(unittest.TestCase):
...
@@ -216,7 +216,7 @@ class MarkupLMProcessorIntegrationTests(unittest.TestCase):
# verify keys
# verify keys
expected_keys
=
[
"attention_mask"
,
"input_ids"
,
"token_type_ids"
,
"xpath_subs_seq"
,
"xpath_tags_seq"
]
expected_keys
=
[
"attention_mask"
,
"input_ids"
,
"token_type_ids"
,
"xpath_subs_seq"
,
"xpath_tags_seq"
]
actual_keys
=
sorted
(
list
(
inputs
.
keys
())
)
actual_keys
=
sorted
(
inputs
.
keys
())
self
.
assertListEqual
(
actual_keys
,
expected_keys
)
self
.
assertListEqual
(
actual_keys
,
expected_keys
)
# verify input_ids
# verify input_ids
...
@@ -260,7 +260,7 @@ class MarkupLMProcessorIntegrationTests(unittest.TestCase):
...
@@ -260,7 +260,7 @@ class MarkupLMProcessorIntegrationTests(unittest.TestCase):
# verify keys
# verify keys
expected_keys
=
[
"attention_mask"
,
"input_ids"
,
"token_type_ids"
,
"xpath_subs_seq"
,
"xpath_tags_seq"
]
expected_keys
=
[
"attention_mask"
,
"input_ids"
,
"token_type_ids"
,
"xpath_subs_seq"
,
"xpath_tags_seq"
]
actual_keys
=
sorted
(
list
(
inputs
.
keys
())
)
actual_keys
=
sorted
(
inputs
.
keys
())
self
.
assertListEqual
(
actual_keys
,
expected_keys
)
self
.
assertListEqual
(
actual_keys
,
expected_keys
)
# verify input_ids
# verify input_ids
...
@@ -294,7 +294,7 @@ class MarkupLMProcessorIntegrationTests(unittest.TestCase):
...
@@ -294,7 +294,7 @@ class MarkupLMProcessorIntegrationTests(unittest.TestCase):
"xpath_subs_seq"
,
"xpath_subs_seq"
,
"xpath_tags_seq"
,
"xpath_tags_seq"
,
]
]
actual_keys
=
sorted
(
list
(
inputs
.
keys
())
)
actual_keys
=
sorted
(
inputs
.
keys
())
self
.
assertListEqual
(
actual_keys
,
expected_keys
)
self
.
assertListEqual
(
actual_keys
,
expected_keys
)
# verify input_ids
# verify input_ids
...
@@ -331,7 +331,7 @@ class MarkupLMProcessorIntegrationTests(unittest.TestCase):
...
@@ -331,7 +331,7 @@ class MarkupLMProcessorIntegrationTests(unittest.TestCase):
"xpath_subs_seq"
,
"xpath_subs_seq"
,
"xpath_tags_seq"
,
"xpath_tags_seq"
,
]
]
actual_keys
=
sorted
(
list
(
inputs
.
keys
())
)
actual_keys
=
sorted
(
inputs
.
keys
())
self
.
assertListEqual
(
actual_keys
,
expected_keys
)
self
.
assertListEqual
(
actual_keys
,
expected_keys
)
# verify input_ids
# verify input_ids
...
@@ -367,7 +367,7 @@ class MarkupLMProcessorIntegrationTests(unittest.TestCase):
...
@@ -367,7 +367,7 @@ class MarkupLMProcessorIntegrationTests(unittest.TestCase):
# verify keys
# verify keys
expected_keys
=
[
"attention_mask"
,
"input_ids"
,
"token_type_ids"
,
"xpath_subs_seq"
,
"xpath_tags_seq"
]
expected_keys
=
[
"attention_mask"
,
"input_ids"
,
"token_type_ids"
,
"xpath_subs_seq"
,
"xpath_tags_seq"
]
actual_keys
=
sorted
(
list
(
inputs
.
keys
())
)
actual_keys
=
sorted
(
inputs
.
keys
())
self
.
assertListEqual
(
actual_keys
,
expected_keys
)
self
.
assertListEqual
(
actual_keys
,
expected_keys
)
# verify input_ids
# verify input_ids
...
@@ -390,7 +390,7 @@ class MarkupLMProcessorIntegrationTests(unittest.TestCase):
...
@@ -390,7 +390,7 @@ class MarkupLMProcessorIntegrationTests(unittest.TestCase):
# verify keys
# verify keys
expected_keys
=
[
"attention_mask"
,
"input_ids"
,
"token_type_ids"
,
"xpath_subs_seq"
,
"xpath_tags_seq"
]
expected_keys
=
[
"attention_mask"
,
"input_ids"
,
"token_type_ids"
,
"xpath_subs_seq"
,
"xpath_tags_seq"
]
actual_keys
=
sorted
(
list
(
inputs
.
keys
())
)
actual_keys
=
sorted
(
inputs
.
keys
())
self
.
assertListEqual
(
actual_keys
,
expected_keys
)
self
.
assertListEqual
(
actual_keys
,
expected_keys
)
# verify input_ids
# verify input_ids
...
@@ -425,7 +425,7 @@ class MarkupLMProcessorIntegrationTests(unittest.TestCase):
...
@@ -425,7 +425,7 @@ class MarkupLMProcessorIntegrationTests(unittest.TestCase):
# verify keys
# verify keys
expected_keys
=
[
"attention_mask"
,
"input_ids"
,
"token_type_ids"
,
"xpath_subs_seq"
,
"xpath_tags_seq"
]
expected_keys
=
[
"attention_mask"
,
"input_ids"
,
"token_type_ids"
,
"xpath_subs_seq"
,
"xpath_tags_seq"
]
actual_keys
=
sorted
(
list
(
inputs
.
keys
())
)
actual_keys
=
sorted
(
inputs
.
keys
())
self
.
assertListEqual
(
actual_keys
,
expected_keys
)
self
.
assertListEqual
(
actual_keys
,
expected_keys
)
# verify input_ids
# verify input_ids
...
@@ -444,7 +444,7 @@ class MarkupLMProcessorIntegrationTests(unittest.TestCase):
...
@@ -444,7 +444,7 @@ class MarkupLMProcessorIntegrationTests(unittest.TestCase):
# verify keys
# verify keys
expected_keys
=
[
"attention_mask"
,
"input_ids"
,
"token_type_ids"
,
"xpath_subs_seq"
,
"xpath_tags_seq"
]
expected_keys
=
[
"attention_mask"
,
"input_ids"
,
"token_type_ids"
,
"xpath_subs_seq"
,
"xpath_tags_seq"
]
actual_keys
=
sorted
(
list
(
inputs
.
keys
())
)
actual_keys
=
sorted
(
inputs
.
keys
())
self
.
assertListEqual
(
actual_keys
,
expected_keys
)
self
.
assertListEqual
(
actual_keys
,
expected_keys
)
# verify input_ids
# verify input_ids
...
...
tests/models/mobilevit/test_modeling_tf_mobilevit.py
View file @
5e8c8eb5
...
@@ -295,7 +295,7 @@ class MobileViTModelTest(TFModelTesterMixin, unittest.TestCase):
...
@@ -295,7 +295,7 @@ class MobileViTModelTest(TFModelTesterMixin, unittest.TestCase):
# The number of elements in the loss should be the same as the number of elements in the label
# The number of elements in the loss should be the same as the number of elements in the label
prepared_for_class
=
self
.
_prepare_for_class
(
inputs_dict
.
copy
(),
model_class
,
return_labels
=
True
)
prepared_for_class
=
self
.
_prepare_for_class
(
inputs_dict
.
copy
(),
model_class
,
return_labels
=
True
)
added_label
=
prepared_for_class
[
added_label
=
prepared_for_class
[
sorted
(
list
(
prepared_for_class
.
keys
()
-
inputs_dict
.
keys
()
)
,
reverse
=
True
)[
0
]
sorted
(
prepared_for_class
.
keys
()
-
inputs_dict
.
keys
(),
reverse
=
True
)[
0
]
]
]
expected_loss_size
=
added_label
.
shape
.
as_list
()[:
1
]
expected_loss_size
=
added_label
.
shape
.
as_list
()[:
1
]
...
...
tests/models/perceiver/test_modeling_perceiver.py
View file @
5e8c8eb5
...
@@ -166,9 +166,11 @@ class PerceiverModelTester:
...
@@ -166,9 +166,11 @@ class PerceiverModelTester:
audio
=
torch
.
randn
(
audio
=
torch
.
randn
(
(
self
.
batch_size
,
self
.
num_frames
*
self
.
audio_samples_per_frame
,
1
),
device
=
torch_device
(
self
.
batch_size
,
self
.
num_frames
*
self
.
audio_samples_per_frame
,
1
),
device
=
torch_device
)
)
inputs
=
dict
(
inputs
=
{
image
=
images
,
audio
=
audio
,
label
=
torch
.
zeros
((
self
.
batch_size
,
self
.
num_labels
),
device
=
torch_device
)
"image"
:
images
,
)
"audio"
:
audio
,
"label"
:
torch
.
zeros
((
self
.
batch_size
,
self
.
num_labels
),
device
=
torch_device
),
}
else
:
else
:
raise
ValueError
(
f
"Model class
{
model_class
}
not supported"
)
raise
ValueError
(
f
"Model class
{
model_class
}
not supported"
)
...
@@ -734,7 +736,7 @@ class PerceiverModelTest(ModelTesterMixin, unittest.TestCase):
...
@@ -734,7 +736,7 @@ class PerceiverModelTest(ModelTesterMixin, unittest.TestCase):
continue
continue
config
,
inputs
,
input_mask
,
_
,
_
=
self
.
model_tester
.
prepare_config_and_inputs
(
model_class
=
model_class
)
config
,
inputs
,
input_mask
,
_
,
_
=
self
.
model_tester
.
prepare_config_and_inputs
(
model_class
=
model_class
)
inputs_dict
=
dict
(
inputs
=
inputs
,
attention_mask
=
input_mask
)
inputs_dict
=
{
"
inputs
"
:
inputs
,
"
attention_mask
"
:
input_mask
}
for
problem_type
in
problem_types
:
for
problem_type
in
problem_types
:
with
self
.
subTest
(
msg
=
f
"Testing
{
model_class
}
with
{
problem_type
[
'title'
]
}
"
):
with
self
.
subTest
(
msg
=
f
"Testing
{
model_class
}
with
{
problem_type
[
'title'
]
}
"
):
...
...
tests/models/roc_bert/test_tokenization_roc_bert.py
View file @
5e8c8eb5
...
@@ -44,8 +44,8 @@ class BertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
...
@@ -44,8 +44,8 @@ class BertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
super
().
setUp
()
super
().
setUp
()
vocab_tokens
=
[
"[UNK]"
,
"[CLS]"
,
"[SEP]"
,
"[PAD]"
,
"[MASK]"
,
"你"
,
"好"
,
"是"
,
"谁"
,
"a"
,
"b"
,
"c"
,
"d"
]
vocab_tokens
=
[
"[UNK]"
,
"[CLS]"
,
"[SEP]"
,
"[PAD]"
,
"[MASK]"
,
"你"
,
"好"
,
"是"
,
"谁"
,
"a"
,
"b"
,
"c"
,
"d"
]
word_shape
=
dict
()
word_shape
=
{}
word_pronunciation
=
dict
()
word_pronunciation
=
{}
for
i
,
value
in
enumerate
(
vocab_tokens
):
for
i
,
value
in
enumerate
(
vocab_tokens
):
word_shape
[
value
]
=
i
word_shape
[
value
]
=
i
word_pronunciation
[
value
]
=
i
word_pronunciation
[
value
]
=
i
...
...
Prev
1
…
6
7
8
9
10
11
12
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment