chenpangpang / transformers / Commits / 5e8c8eb5
"docs/zh_cn/advanced_guides/customize_runtime.md" did not exist on "2f6baaee5db2641711a85f745e9e0a57a4049a1f"
Commit 5e8c8eb5 (unverified), authored Feb 22, 2023 by Aaron Gokaslan, committed by GitHub on Feb 22, 2023

Apply ruff flake8-comprehensions (#21694)

parent df06fb1f
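Ruff's flake8-comprehensions rules (the C4 family) flag dict(), list(), set() and map() usages that can be written as literals or comprehensions; the hunks below apply those rewrites across the test suite. A change set like this is typically produced by running ruff with the C4 rules selected and autofix enabled (for example, something along the lines of ruff check . --select C4 --fix) and committing the result. The snippet below is a minimal, self-contained sketch of the main rewrite patterns seen in this commit; the variable names and the rule codes mentioned in the comments are illustrative assumptions, not taken from the diff.

# Sketch of the flake8-comprehensions rewrites; all names are made up.
params = {"num_beams": 6, "do_sample": False, "early_stopping": True}

# unnecessary dict() call -> dict literal (roughly rule C408)
# before: expected = dict(num_beams=6, do_sample=False, early_stopping=True)
expected = {"num_beams": 6, "do_sample": False, "early_stopping": True}

# unnecessary list literal inside set() -> set literal (roughly rule C405)
# before: key_set = set([*params.keys()])
key_set = {*params.keys()}

# unnecessary list() call inside sorted() -> drop the cast (roughly rule C414)
# before: ordered = sorted(list(params.keys()))
ordered = sorted(params.keys())

# unnecessary map(lambda ...) -> comprehension (roughly rule C417)
# before: upper = list(map(lambda k: k.upper(), params))
upper = [k.upper() for k in params]

print(expected == params, key_set, ordered, upper)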
Changes: 230. Showing 20 changed files on this page, with 99 additions and 97 deletions (+99 −97).
tests/deepspeed/test_model_zoo.py (+18 −18)
tests/extended/test_trainer_ext.py (+7 −7)
tests/generation/test_utils.py (+1 −1)
tests/models/bart/test_modeling_bart.py (+1 −1)
tests/models/blenderbot/test_modeling_blenderbot.py (+2 −2)
tests/models/blenderbot/test_modeling_flax_blenderbot.py (+2 −2)
tests/models/bloom/test_tokenization_bloom.py (+1 −1)
tests/models/clip/test_modeling_tf_clip.py (+2 −2)
tests/models/data2vec/test_modeling_tf_data2vec_vision.py (+1 −1)
tests/models/groupvit/test_modeling_tf_groupvit.py (+2 −2)
tests/models/jukebox/test_modeling_jukebox.py (+11 −11)
tests/models/jukebox/test_tokenization_jukebox.py (+5 −5)
tests/models/layoutlmv2/test_processor_layoutlmv2.py (+9 −9)
tests/models/layoutlmv3/test_modeling_tf_layoutlmv3.py (+1 −1)
tests/models/layoutlmv3/test_processor_layoutlmv3.py (+9 −9)
tests/models/layoutxlm/test_processor_layoutxlm.py (+9 −9)
tests/models/markuplm/test_processor_markuplm.py (+9 −9)
tests/models/mobilevit/test_modeling_tf_mobilevit.py (+1 −1)
tests/models/perceiver/test_modeling_perceiver.py (+6 −4)
tests/models/roc_bert/test_tokenization_roc_bert.py (+2 −2)
tests/deepspeed/test_model_zoo.py

@@ -166,8 +166,8 @@ def make_task_cmds():
     # but need a tiny model for each
     #
     # should have "{model_type.upper()}_TINY" corresponding vars defined, e.g., T5_TINY, etc.
-    tasks2models = dict(
-        trans=[
+    tasks2models = {
+        "trans": [
             "bart",
             "fsmt",
             "m2m_100",
@@ -177,10 +177,10 @@ def make_task_cmds():
             "t5_v1",
             # "mt5", missing model files
         ],
-        sum=[
+        "sum": [
             "pegasus",
         ],
-        clm=[
+        "clm": [
             "big_bird",
             "bigbird_pegasus",
             "blenderbot",
@@ -192,7 +192,7 @@ def make_task_cmds():
             "prophetnet",
             # "camembert", missing model files
         ],
-        mlm=[
+        "mlm": [
             "albert",
             "deberta",
             "deberta-v2",
@@ -203,7 +203,7 @@ def make_task_cmds():
             "layoutlm",
             # "reformer", # multiple issues with either mlm/qa/clas
         ],
-        qa=[
+        "qa": [
             "led",
             "longformer",
             "mobilebert",
@@ -213,7 +213,7 @@ def make_task_cmds():
             # "convbert", # missing tokenizer files
             # "layoutlmv2", missing model files
         ],
-        clas=[
+        "clas": [
             "bert",
             "xlnet",
             # "hubert", # missing tokenizer files
@@ -223,54 +223,54 @@ def make_task_cmds():
             # "openai-gpt", missing model files
             # "tapas", multiple issues
         ],
-        img_clas=[
+        "img_clas": [
             "vit",
         ],
-    )
+    }

     scripts_dir = f"{ROOT_DIRECTORY}/examples/pytorch"

-    tasks = dict(
-        trans=f"""
+    tasks = {
+        "trans": f"""
         {scripts_dir}/translation/run_translation.py
         --train_file {data_dir_wmt}/train.json
         --source_lang en
         --target_lang ro
         """,
-        sum=f"""
+        "sum": f"""
         {scripts_dir}/summarization/run_summarization.py
         --train_file {data_dir_xsum}/sample.json
         --max_source_length 12
         --max_target_length 12
         --lang en
         """,
-        clm=f"""
+        "clm": f"""
         {scripts_dir}/language-modeling/run_clm.py
         --train_file {FIXTURE_DIRECTORY}/sample_text.txt
         --block_size 8
         """,
-        mlm=f"""
+        "mlm": f"""
         {scripts_dir}/language-modeling/run_mlm.py
         --train_file {FIXTURE_DIRECTORY}/sample_text.txt
         """,
-        qa=f"""
+        "qa": f"""
         {scripts_dir}/question-answering/run_qa.py
         --train_file {data_dir_samples}/SQUAD/sample.json
         """,
-        clas=f"""
+        "clas": f"""
         {scripts_dir}/text-classification/run_glue.py
         --train_file {data_dir_samples}/MRPC/train.csv
         --max_seq_length 12
         --task_name MRPC
         """,
-        img_clas=f"""
+        "img_clas": f"""
         {scripts_dir}/image-classification/run_image_classification.py
         --dataset_name hf-internal-testing/cats_vs_dogs_sample
         --remove_unused_columns False
         --max_steps 10
         --image_processor_name {DS_TESTS_DIRECTORY}/vit_feature_extractor.json
         """,
-    )
+    }

     launcher = get_launcher(distributed=True)
tests/extended/test_trainer_ext.py

@@ -155,21 +155,21 @@ class TestTrainerExt(TestCasePlus):
     @require_torch_multi_gpu
     def test_trainer_log_level_replica(self, experiment_id):
         # as each sub-test is slow-ish split into multiple sub-tests to avoid CI timeout
-        experiments = dict(
+        experiments = {
             # test with the default log_level - should be info and thus log info once
-            base=dict(extra_args_str="", n_matches=1),
+            "base": {"extra_args_str": "", "n_matches": 1},
             # test with low log_level and log_level_replica - should be noisy on all processes
             # now the info string should appear twice on 2 processes
-            low=dict(extra_args_str="--log_level debug --log_level_replica debug", n_matches=2),
+            "low": {"extra_args_str": "--log_level debug --log_level_replica debug", "n_matches": 2},
             # test with high log_level and low log_level_replica
             # now the info string should appear once only on the replica
-            high=dict(extra_args_str="--log_level error --log_level_replica debug", n_matches=1),
+            "high": {"extra_args_str": "--log_level error --log_level_replica debug", "n_matches": 1},
             # test with high log_level and log_level_replica - should be quiet on all processes
-            mixed=dict(extra_args_str="--log_level error --log_level_replica error", n_matches=0),
-        )
+            "mixed": {"extra_args_str": "--log_level error --log_level_replica error", "n_matches": 0},
+        }

         data = experiments[experiment_id]

-        kwargs = dict(distributed=True, predict_with_generate=False, do_eval=False, do_predict=False)
+        kwargs = {"distributed": True, "predict_with_generate": False, "do_eval": False, "do_predict": False}
         log_info_string = "Running training"
         with CaptureStderr() as cl:
             self.run_seq2seq_quick(**kwargs, extra_args_str=data["extra_args_str"])
tests/generation/test_utils.py

@@ -1480,7 +1480,7 @@ class GenerationTesterMixin:
             signature = inspect.signature(model.forward)
             # We want to test only models where encoder/decoder head masking is implemented
-            if not set(head_masking.keys()) < set([*signature.parameters.keys()]):
+            if not set(head_masking.keys()) < {*signature.parameters.keys()}:
                 continue

             for attn_name, (name, mask) in zip(attention_names, head_masking.items()):
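For context on the one-line change above: between sets, "<" is the strict-subset test, and {*iterable} builds a set by unpacking, so the list literal that the old code wrapped in set() was redundant. A small self-contained illustration with invented names (not the test's real data):

masking_keys = {"head_mask", "decoder_head_mask"}
forward_params = ["input_ids", "attention_mask", "head_mask", "decoder_head_mask", "cross_attn_head_mask"]

# {*iterable} unpacks straight into a set literal, with no intermediate list
param_set = {*forward_params}

# "<" is the strict (proper) subset test: True only when every masking key
# is accepted by forward() and forward() takes other parameters as well
print(masking_keys < param_set)  # True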
tests/models/bart/test_modeling_bart.py

@@ -939,7 +939,7 @@ class BartModelIntegrationTests(unittest.TestCase):
     def test_xsum_config_generation_params(self):
         config = BartConfig.from_pretrained("facebook/bart-large-xsum")
-        expected_params = dict(num_beams=6, do_sample=False, early_stopping=True, length_penalty=1.0)
+        expected_params = {"num_beams": 6, "do_sample": False, "early_stopping": True, "length_penalty": 1.0}
         config_params = {k: getattr(config, k, "MISSING") for k, v in expected_params.items()}
         self.assertDictEqual(expected_params, config_params)
tests/models/blenderbot/test_modeling_blenderbot.py

@@ -299,8 +299,8 @@ class Blenderbot3BIntegrationTests(unittest.TestCase):
     @slow
     def test_generation_from_short_input_same_as_parlai_3B(self):
-        FASTER_GEN_KWARGS = dict(num_beams=1, early_stopping=True, min_length=15, max_length=25)
-        TOK_DECODE_KW = dict(skip_special_tokens=True, clean_up_tokenization_spaces=True)
+        FASTER_GEN_KWARGS = {"num_beams": 1, "early_stopping": True, "min_length": 15, "max_length": 25}
+        TOK_DECODE_KW = {"skip_special_tokens": True, "clean_up_tokenization_spaces": True}

         torch.cuda.empty_cache()
         model = BlenderbotForConditionalGeneration.from_pretrained(self.ckpt).half().to(torch_device)
tests/models/blenderbot/test_modeling_flax_blenderbot.py

@@ -402,8 +402,8 @@ class FlaxBlenderbotModelTest(FlaxModelTesterMixin, unittest.TestCase, FlaxGener
     @unittest.skipUnless(jax_device != "cpu", "3B test too slow on CPU.")
     @slow
     def test_generation_from_short_input_same_as_parlai_3B(self):
-        FASTER_GEN_KWARGS = dict(num_beams=1, early_stopping=True, min_length=15, max_length=25)
-        TOK_DECODE_KW = dict(skip_special_tokens=True, clean_up_tokenization_spaces=True)
+        FASTER_GEN_KWARGS = {"num_beams": 1, "early_stopping": True, "min_length": 15, "max_length": 25}
+        TOK_DECODE_KW = {"skip_special_tokens": True, "clean_up_tokenization_spaces": True}

         model = FlaxBlenderbotForConditionalGeneration.from_pretrained("facebook/blenderbot-3B", from_pt=True)
         tokenizer = BlenderbotTokenizer.from_pretrained("facebook/blenderbot-3B")
tests/models/bloom/test_tokenization_bloom.py

@@ -124,7 +124,7 @@ class BloomTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
         input_text = list(sample_data.values())

         output_tokens = list(map(tokenizer.encode, input_text))
-        predicted_text = list(map(lambda x: tokenizer.decode(x, clean_up_tokenization_spaces=False), output_tokens))
+        predicted_text = [tokenizer.decode(x, clean_up_tokenization_spaces=False) for x in output_tokens]
         self.assertListEqual(predicted_text, input_text)

     def test_pretrained_model_lists(self):
tests/models/clip/test_modeling_tf_clip.py

@@ -551,7 +551,7 @@ class TFCLIPModelTest(TFModelTesterMixin, unittest.TestCase):
         if self.__class__.__name__ == "TFCLIPModelTest":
             inputs_dict.pop("return_loss", None)

-        tf_main_layer_classes = set(
+        tf_main_layer_classes = {
             module_member
             for model_class in self.all_model_classes
             for module in (import_module(model_class.__module__),)
@@ -563,7 +563,7 @@ class TFCLIPModelTest(TFModelTesterMixin, unittest.TestCase):
             if isinstance(module_member, type)
             and tf.keras.layers.Layer in module_member.__bases__
             and getattr(module_member, "_keras_serializable", False)
-        )
+        }

         for main_layer_class in tf_main_layer_classes:
             # T5MainLayer needs an embed_tokens parameter when called without the inputs_embeds parameter
             if "T5" in main_layer_class.__name__:
tests/models/data2vec/test_modeling_tf_data2vec_vision.py

@@ -398,7 +398,7 @@ class TFData2VecVisionModelTest(TFModelTesterMixin, unittest.TestCase):
             # The number of elements in the loss should be the same as the number of elements in the label
             _, prepared_for_class = self.model_tester.prepare_config_and_inputs_for_keras_fit()
             added_label = prepared_for_class[
-                sorted(list(prepared_for_class.keys() - inputs_dict.keys()), reverse=True)[0]
+                sorted(prepared_for_class.keys() - inputs_dict.keys(), reverse=True)[0]
             ]
             loss_size = tf.size(added_label)
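The rewrite above (and the matching hunks in the layoutlmv3 and mobilevit tests further down) works because dict .keys() returns a set-like view: the "-" difference can be taken directly between two views, and sorted() accepts any iterable, so the intermediate list() added nothing. A short standalone illustration; the dictionaries are invented stand-ins, not the tester's real inputs:

prepared_for_class = {"pixel_values": 1, "bool_masked_pos": 2, "labels": 3}
inputs_dict = {"pixel_values": 1, "bool_masked_pos": 2}

# dict.keys() views support set operations such as difference
extra_keys = prepared_for_class.keys() - inputs_dict.keys()  # {"labels"}

# sorted() accepts any iterable, so no list() cast is needed
added_label_key = sorted(extra_keys, reverse=True)[0]
print(added_label_key)  # labels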
tests/models/groupvit/test_modeling_tf_groupvit.py

@@ -628,7 +628,7 @@ class TFGroupViTModelTest(TFModelTesterMixin, unittest.TestCase):
         if self.__class__.__name__ == "TFGroupViTModelTest":
             inputs_dict.pop("return_loss", None)

-        tf_main_layer_classes = set(
+        tf_main_layer_classes = {
             module_member
             for model_class in self.all_model_classes
             for module in (import_module(model_class.__module__),)
@@ -640,7 +640,7 @@ class TFGroupViTModelTest(TFModelTesterMixin, unittest.TestCase):
             if isinstance(module_member, type)
             and tf.keras.layers.Layer in module_member.__bases__
             and getattr(module_member, "_keras_serializable", False)
-        )
+        }

         for main_layer_class in tf_main_layer_classes:
             # T5MainLayer needs an embed_tokens parameter when called without the inputs_embeds parameter
             if "T5" in main_layer_class.__name__:
tests/models/jukebox/test_modeling_jukebox.py

@@ -30,10 +30,10 @@ if is_torch_available():
 class Jukebox1bModelTester(unittest.TestCase):
     all_model_classes = (JukeboxModel,) if is_torch_available() else ()
     model_id = "openai/jukebox-1b-lyrics"
-    metas = dict(
-        artist="Zac Brown Band",
-        genres="Country",
-        lyrics="""I met a traveller from an antique land,
+    metas = {
+        "artist": "Zac Brown Band",
+        "genres": "Country",
+        "lyrics": """I met a traveller from an antique land,
 Who said "Two vast and trunkless legs of stone
 Stand in the desert. . . . Near them, on the sand,
 Half sunk a shattered visage lies, whose frown,
@@ -48,7 +48,7 @@ class Jukebox1bModelTester(unittest.TestCase):
 Of that colossal Wreck, boundless and bare
 The lone and level sands stretch far away
 """,
-    )
+    }
     # fmt: off
     EXPECTED_OUTPUT_2 = [
         1864, 1536, 1213, 1870, 1357, 1536, 519, 880, 1323, 789, 1082, 534,
@@ -180,7 +180,7 @@ class Jukebox1bModelTester(unittest.TestCase):
         model = JukeboxModel.from_pretrained(self.model_id, min_duration=0).eval()

         set_seed(0)
         waveform = torch.rand((1, 5120, 1))
-        tokens = [i for i in self.prepare_inputs()]
+        tokens = list(self.prepare_inputs())
         zs = [model.vqvae.encode(waveform, start_level=2, bs_chunks=waveform.shape[0])[0], None, None]
         zs = model._sample(
@@ -220,10 +220,10 @@ class Jukebox1bModelTester(unittest.TestCase):
 class Jukebox5bModelTester(unittest.TestCase):
     all_model_classes = (JukeboxModel,) if is_torch_available() else ()
     model_id = "openai/jukebox-5b-lyrics"
-    metas = dict(
-        artist="Zac Brown Band",
-        genres="Country",
-        lyrics="""I met a traveller from an antique land,
+    metas = {
+        "artist": "Zac Brown Band",
+        "genres": "Country",
+        "lyrics": """I met a traveller from an antique land,
 Who said "Two vast and trunkless legs of stone
 Stand in the desert. . . . Near them, on the sand,
 Half sunk a shattered visage lies, whose frown,
@@ -238,7 +238,7 @@ class Jukebox5bModelTester(unittest.TestCase):
 Of that colossal Wreck, boundless and bare
 The lone and level sands stretch far away
 """,
-    )
+    }
     # fmt: off
     EXPECTED_OUTPUT_2 = [
tests/models/jukebox/test_tokenization_jukebox.py

@@ -21,10 +21,10 @@ from transformers.testing_utils import require_torch
 class JukeboxTokenizationTest(unittest.TestCase):
     tokenizer_class = JukeboxTokenizer
-    metas = dict(
-        artist="Zac Brown Band",
-        genres="Country",
-        lyrics="""I met a traveller from an antique land,
+    metas = {
+        "artist": "Zac Brown Band",
+        "genres": "Country",
+        "lyrics": """I met a traveller from an antique land,
 Who said "Two vast and trunkless legs of stone
 Stand in the desert. . . . Near them, on the sand,
 Half sunk a shattered visage lies, whose frown,
@@ -39,7 +39,7 @@ class JukeboxTokenizationTest(unittest.TestCase):
 Of that colossal Wreck, boundless and bare
 The lone and level sands stretch far away
 """,
-    )
+    }

     @require_torch
     def test_1b_lyrics_tokenizer(self):
tests/models/layoutlmv2/test_processor_layoutlmv2.py

@@ -233,7 +233,7 @@ class LayoutLMv2ProcessorIntegrationTests(unittest.TestCase):
         # verify keys
         expected_keys = ["attention_mask", "bbox", "image", "input_ids", "token_type_ids"]
-        actual_keys = sorted(list(input_processor.keys()))
+        actual_keys = sorted(input_processor.keys())
         self.assertListEqual(actual_keys, expected_keys)

         # verify image
@@ -253,7 +253,7 @@ class LayoutLMv2ProcessorIntegrationTests(unittest.TestCase):
         # verify keys
         expected_keys = ["attention_mask", "bbox", "image", "input_ids", "token_type_ids"]
-        actual_keys = sorted(list(input_processor.keys()))
+        actual_keys = sorted(input_processor.keys())
         self.assertListEqual(actual_keys, expected_keys)

         # verify images
@@ -301,7 +301,7 @@ class LayoutLMv2ProcessorIntegrationTests(unittest.TestCase):
         # verify keys
         expected_keys = ["attention_mask", "bbox", "image", "input_ids", "token_type_ids"]
-        actual_keys = sorted(list(input_processor.keys()))
+        actual_keys = sorted(input_processor.keys())
         self.assertListEqual(actual_keys, expected_keys)

         # verify input_ids
@@ -340,7 +340,7 @@ class LayoutLMv2ProcessorIntegrationTests(unittest.TestCase):
         # verify keys
         expected_keys = ["attention_mask", "bbox", "image", "input_ids", "labels", "token_type_ids"]
-        actual_keys = sorted(list(input_processor.keys()))
+        actual_keys = sorted(input_processor.keys())
         self.assertListEqual(actual_keys, expected_keys)

         # verify input_ids
@@ -362,7 +362,7 @@ class LayoutLMv2ProcessorIntegrationTests(unittest.TestCase):
         # verify keys
         expected_keys = ["attention_mask", "bbox", "image", "input_ids", "labels", "token_type_ids"]
-        actual_keys = sorted(list(input_processor.keys()))
+        actual_keys = sorted(input_processor.keys())
         self.assertListEqual(actual_keys, expected_keys)

         # verify input_ids
@@ -403,7 +403,7 @@ class LayoutLMv2ProcessorIntegrationTests(unittest.TestCase):
         # verify keys
         expected_keys = ["attention_mask", "bbox", "image", "input_ids", "token_type_ids"]
-        actual_keys = sorted(list(input_processor.keys()))
+        actual_keys = sorted(input_processor.keys())
         self.assertListEqual(actual_keys, expected_keys)

         # verify input_ids
@@ -422,7 +422,7 @@ class LayoutLMv2ProcessorIntegrationTests(unittest.TestCase):
         # verify keys
         expected_keys = ["attention_mask", "bbox", "image", "input_ids", "token_type_ids"]
-        actual_keys = sorted(list(input_processor.keys()))
+        actual_keys = sorted(input_processor.keys())
         self.assertListEqual(actual_keys, expected_keys)

         # verify input_ids
@@ -456,7 +456,7 @@ class LayoutLMv2ProcessorIntegrationTests(unittest.TestCase):
         # verify keys
         expected_keys = ["attention_mask", "bbox", "image", "input_ids", "token_type_ids"]
-        actual_keys = sorted(list(input_processor.keys()))
+        actual_keys = sorted(input_processor.keys())
         self.assertListEqual(actual_keys, expected_keys)

         # verify input_ids
@@ -472,7 +472,7 @@ class LayoutLMv2ProcessorIntegrationTests(unittest.TestCase):
         # verify keys
         expected_keys = ["attention_mask", "bbox", "image", "input_ids", "token_type_ids"]
-        actual_keys = sorted(list(input_processor.keys()))
+        actual_keys = sorted(input_processor.keys())
         self.assertListEqual(actual_keys, expected_keys)

         # verify input_ids
tests/models/layoutlmv3/test_modeling_tf_layoutlmv3.py

@@ -320,7 +320,7 @@ class TFLayoutLMv3ModelTest(TFModelTesterMixin, unittest.TestCase):
             # The number of elements in the loss should be the same as the number of elements in the label
             prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
             added_label = prepared_for_class[
-                sorted(list(prepared_for_class.keys() - inputs_dict.keys()), reverse=True)[0]
+                sorted(prepared_for_class.keys() - inputs_dict.keys(), reverse=True)[0]
             ]
             expected_loss_size = added_label.shape.as_list()[:1]
tests/models/layoutlmv3/test_processor_layoutlmv3.py

@@ -213,7 +213,7 @@ class LayoutLMv3ProcessorIntegrationTests(unittest.TestCase):
         # verify keys
         expected_keys = ["attention_mask", "bbox", "input_ids", "pixel_values"]
-        actual_keys = sorted(list(input_processor.keys()))
+        actual_keys = sorted(input_processor.keys())
         self.assertListEqual(actual_keys, expected_keys)

         # verify image
@@ -235,7 +235,7 @@ class LayoutLMv3ProcessorIntegrationTests(unittest.TestCase):
         # verify keys
         expected_keys = ["attention_mask", "bbox", "input_ids", "pixel_values"]
-        actual_keys = sorted(list(input_processor.keys()))
+        actual_keys = sorted(input_processor.keys())
         self.assertListEqual(actual_keys, expected_keys)

         # verify images
@@ -285,7 +285,7 @@ class LayoutLMv3ProcessorIntegrationTests(unittest.TestCase):
         # verify keys
         expected_keys = ["attention_mask", "bbox", "input_ids", "pixel_values"]
-        actual_keys = sorted(list(input_processor.keys()))
+        actual_keys = sorted(input_processor.keys())
         self.assertListEqual(actual_keys, expected_keys)

         # verify input_ids
@@ -324,7 +324,7 @@ class LayoutLMv3ProcessorIntegrationTests(unittest.TestCase):
         # verify keys
         expected_keys = ["attention_mask", "bbox", "input_ids", "labels", "pixel_values"]
-        actual_keys = sorted(list(input_processor.keys()))
+        actual_keys = sorted(input_processor.keys())
         self.assertListEqual(actual_keys, expected_keys)

         # verify input_ids
@@ -346,7 +346,7 @@ class LayoutLMv3ProcessorIntegrationTests(unittest.TestCase):
         # verify keys
         expected_keys = ["attention_mask", "bbox", "input_ids", "labels", "pixel_values"]
-        actual_keys = sorted(list(input_processor.keys()))
+        actual_keys = sorted(input_processor.keys())
         self.assertListEqual(actual_keys, expected_keys)

         # verify input_ids
@@ -387,7 +387,7 @@ class LayoutLMv3ProcessorIntegrationTests(unittest.TestCase):
         # verify keys
         expected_keys = ["attention_mask", "bbox", "input_ids", "pixel_values"]
-        actual_keys = sorted(list(input_processor.keys()))
+        actual_keys = sorted(input_processor.keys())
         self.assertListEqual(actual_keys, expected_keys)

         # verify input_ids
@@ -406,7 +406,7 @@ class LayoutLMv3ProcessorIntegrationTests(unittest.TestCase):
         # verify keys
         expected_keys = ["attention_mask", "bbox", "input_ids", "pixel_values"]
-        actual_keys = sorted(list(input_processor.keys()))
+        actual_keys = sorted(input_processor.keys())
         self.assertListEqual(actual_keys, expected_keys)

         # verify input_ids
@@ -440,7 +440,7 @@ class LayoutLMv3ProcessorIntegrationTests(unittest.TestCase):
         # verify keys
         expected_keys = ["attention_mask", "bbox", "input_ids", "pixel_values"]
-        actual_keys = sorted(list(input_processor.keys()))
+        actual_keys = sorted(input_processor.keys())
         self.assertListEqual(actual_keys, expected_keys)

         # verify input_ids
@@ -456,7 +456,7 @@ class LayoutLMv3ProcessorIntegrationTests(unittest.TestCase):
         # verify keys
         expected_keys = ["attention_mask", "bbox", "input_ids", "pixel_values"]
-        actual_keys = sorted(list(input_processor.keys()))
+        actual_keys = sorted(input_processor.keys())
         self.assertListEqual(actual_keys, expected_keys)

         # verify input_ids
tests/models/layoutxlm/test_processor_layoutxlm.py

@@ -228,7 +228,7 @@ class LayoutXLMProcessorIntegrationTests(unittest.TestCase):
         # verify keys
         expected_keys = ["attention_mask", "bbox", "image", "input_ids"]
-        actual_keys = sorted(list(input_processor.keys()))
+        actual_keys = sorted(input_processor.keys())
         self.assertListEqual(actual_keys, expected_keys)

         # verify image
@@ -250,7 +250,7 @@ class LayoutXLMProcessorIntegrationTests(unittest.TestCase):
         # verify keys
         expected_keys = ["attention_mask", "bbox", "image", "input_ids"]
-        actual_keys = sorted(list(input_processor.keys()))
+        actual_keys = sorted(input_processor.keys())
         self.assertListEqual(actual_keys, expected_keys)

         # verify images
@@ -300,7 +300,7 @@ class LayoutXLMProcessorIntegrationTests(unittest.TestCase):
         # verify keys
         expected_keys = ["attention_mask", "bbox", "image", "input_ids"]
-        actual_keys = sorted(list(input_processor.keys()))
+        actual_keys = sorted(input_processor.keys())
         self.assertListEqual(actual_keys, expected_keys)

         # verify input_ids
@@ -339,7 +339,7 @@ class LayoutXLMProcessorIntegrationTests(unittest.TestCase):
         # verify keys
         expected_keys = ["attention_mask", "bbox", "image", "input_ids", "labels"]
-        actual_keys = sorted(list(input_processor.keys()))
+        actual_keys = sorted(input_processor.keys())
         self.assertListEqual(actual_keys, expected_keys)

         # verify input_ids
@@ -361,7 +361,7 @@ class LayoutXLMProcessorIntegrationTests(unittest.TestCase):
         # verify keys
         expected_keys = ["attention_mask", "bbox", "image", "input_ids", "labels"]
-        actual_keys = sorted(list(input_processor.keys()))
+        actual_keys = sorted(input_processor.keys())
         self.assertListEqual(actual_keys, expected_keys)

         # verify input_ids
@@ -402,7 +402,7 @@ class LayoutXLMProcessorIntegrationTests(unittest.TestCase):
         # verify keys
         expected_keys = ["attention_mask", "bbox", "image", "input_ids"]
-        actual_keys = sorted(list(input_processor.keys()))
+        actual_keys = sorted(input_processor.keys())
         self.assertListEqual(actual_keys, expected_keys)

         # verify input_ids
@@ -421,7 +421,7 @@ class LayoutXLMProcessorIntegrationTests(unittest.TestCase):
         # verify keys
         expected_keys = ["attention_mask", "bbox", "image", "input_ids"]
-        actual_keys = sorted(list(input_processor.keys()))
+        actual_keys = sorted(input_processor.keys())
         self.assertListEqual(actual_keys, expected_keys)

         # verify input_ids
@@ -455,7 +455,7 @@ class LayoutXLMProcessorIntegrationTests(unittest.TestCase):
         # verify keys
         expected_keys = ["attention_mask", "bbox", "image", "input_ids"]
-        actual_keys = sorted(list(input_processor.keys()))
+        actual_keys = sorted(input_processor.keys())
         self.assertListEqual(actual_keys, expected_keys)

         # verify input_ids
@@ -471,7 +471,7 @@ class LayoutXLMProcessorIntegrationTests(unittest.TestCase):
         # verify keys
         expected_keys = ["attention_mask", "bbox", "image", "input_ids"]
-        actual_keys = sorted(list(input_processor.keys()))
+        actual_keys = sorted(input_processor.keys())
         self.assertListEqual(actual_keys, expected_keys)

         # verify input_ids
tests/models/markuplm/test_processor_markuplm.py

@@ -204,7 +204,7 @@ class MarkupLMProcessorIntegrationTests(unittest.TestCase):
         # verify keys
         expected_keys = ["attention_mask", "input_ids", "token_type_ids", "xpath_subs_seq", "xpath_tags_seq"]
-        actual_keys = sorted(list(inputs.keys()))
+        actual_keys = sorted(inputs.keys())
         self.assertListEqual(actual_keys, expected_keys)

         # verify input_ids
@@ -216,7 +216,7 @@ class MarkupLMProcessorIntegrationTests(unittest.TestCase):
         # verify keys
         expected_keys = ["attention_mask", "input_ids", "token_type_ids", "xpath_subs_seq", "xpath_tags_seq"]
-        actual_keys = sorted(list(inputs.keys()))
+        actual_keys = sorted(inputs.keys())
         self.assertListEqual(actual_keys, expected_keys)

         # verify input_ids
@@ -260,7 +260,7 @@ class MarkupLMProcessorIntegrationTests(unittest.TestCase):
         # verify keys
         expected_keys = ["attention_mask", "input_ids", "token_type_ids", "xpath_subs_seq", "xpath_tags_seq"]
-        actual_keys = sorted(list(inputs.keys()))
+        actual_keys = sorted(inputs.keys())
         self.assertListEqual(actual_keys, expected_keys)

         # verify input_ids
@@ -294,7 +294,7 @@ class MarkupLMProcessorIntegrationTests(unittest.TestCase):
             "xpath_subs_seq",
             "xpath_tags_seq",
         ]
-        actual_keys = sorted(list(inputs.keys()))
+        actual_keys = sorted(inputs.keys())
         self.assertListEqual(actual_keys, expected_keys)

         # verify input_ids
@@ -331,7 +331,7 @@ class MarkupLMProcessorIntegrationTests(unittest.TestCase):
             "xpath_subs_seq",
             "xpath_tags_seq",
         ]
-        actual_keys = sorted(list(inputs.keys()))
+        actual_keys = sorted(inputs.keys())
         self.assertListEqual(actual_keys, expected_keys)

         # verify input_ids
@@ -367,7 +367,7 @@ class MarkupLMProcessorIntegrationTests(unittest.TestCase):
         # verify keys
         expected_keys = ["attention_mask", "input_ids", "token_type_ids", "xpath_subs_seq", "xpath_tags_seq"]
-        actual_keys = sorted(list(inputs.keys()))
+        actual_keys = sorted(inputs.keys())
         self.assertListEqual(actual_keys, expected_keys)

         # verify input_ids
@@ -390,7 +390,7 @@ class MarkupLMProcessorIntegrationTests(unittest.TestCase):
         # verify keys
         expected_keys = ["attention_mask", "input_ids", "token_type_ids", "xpath_subs_seq", "xpath_tags_seq"]
-        actual_keys = sorted(list(inputs.keys()))
+        actual_keys = sorted(inputs.keys())
         self.assertListEqual(actual_keys, expected_keys)

         # verify input_ids
@@ -425,7 +425,7 @@ class MarkupLMProcessorIntegrationTests(unittest.TestCase):
         # verify keys
         expected_keys = ["attention_mask", "input_ids", "token_type_ids", "xpath_subs_seq", "xpath_tags_seq"]
-        actual_keys = sorted(list(inputs.keys()))
+        actual_keys = sorted(inputs.keys())
         self.assertListEqual(actual_keys, expected_keys)

         # verify input_ids
@@ -444,7 +444,7 @@ class MarkupLMProcessorIntegrationTests(unittest.TestCase):
         # verify keys
         expected_keys = ["attention_mask", "input_ids", "token_type_ids", "xpath_subs_seq", "xpath_tags_seq"]
-        actual_keys = sorted(list(inputs.keys()))
+        actual_keys = sorted(inputs.keys())
         self.assertListEqual(actual_keys, expected_keys)

         # verify input_ids
tests/models/mobilevit/test_modeling_tf_mobilevit.py

@@ -295,7 +295,7 @@ class MobileViTModelTest(TFModelTesterMixin, unittest.TestCase):
             # The number of elements in the loss should be the same as the number of elements in the label
             prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
             added_label = prepared_for_class[
-                sorted(list(prepared_for_class.keys() - inputs_dict.keys()), reverse=True)[0]
+                sorted(prepared_for_class.keys() - inputs_dict.keys(), reverse=True)[0]
             ]
             expected_loss_size = added_label.shape.as_list()[:1]
tests/models/perceiver/test_modeling_perceiver.py

@@ -166,9 +166,11 @@ class PerceiverModelTester:
             audio = torch.randn(
                 (self.batch_size, self.num_frames * self.audio_samples_per_frame, 1), device=torch_device
             )
-            inputs = dict(
-                image=images, audio=audio, label=torch.zeros((self.batch_size, self.num_labels), device=torch_device)
-            )
+            inputs = {
+                "image": images,
+                "audio": audio,
+                "label": torch.zeros((self.batch_size, self.num_labels), device=torch_device),
+            }
         else:
             raise ValueError(f"Model class {model_class} not supported")
@@ -734,7 +736,7 @@ class PerceiverModelTest(ModelTesterMixin, unittest.TestCase):
                 continue

             config, inputs, input_mask, _, _ = self.model_tester.prepare_config_and_inputs(model_class=model_class)
-            inputs_dict = dict(inputs=inputs, attention_mask=input_mask)
+            inputs_dict = {"inputs": inputs, "attention_mask": input_mask}

             for problem_type in problem_types:
                 with self.subTest(msg=f"Testing {model_class} with {problem_type['title']}"):
tests/models/roc_bert/test_tokenization_roc_bert.py

@@ -44,8 +44,8 @@ class BertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
         super().setUp()

         vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]", "你", "好", "是", "谁", "a", "b", "c", "d"]
-        word_shape = dict()
-        word_pronunciation = dict()
+        word_shape = {}
+        word_pronunciation = {}
         for i, value in enumerate(vocab_tokens):
             word_shape[value] = i
             word_pronunciation[value] = i