Unverified Commit 5e8c8eb5 authored by Aaron Gokaslan, committed by GitHub

Apply ruff flake8-comprehensions (#21694)

parent df06fb1f
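
The hunks below are a mechanical application of ruff's flake8-comprehensions checks, which flag collection constructors that can be written more directly as literals or comprehensions. Most of them correspond to the rule for unnecessary dict() calls (C408 in flake8-comprehensions numbering): a dict() call with keyword arguments, or a bare dict(), becomes a dict literal. A minimal runnable sketch of that rewrite, reusing a kwargs line from this diff (the _old/_new names are illustrative, not from the repo):

    # Before (flagged by the linter): dict() with keyword arguments, and a bare dict()
    kwargs_old = dict(distributed=True, predict_with_generate=False, do_eval=False, do_predict=False)
    shape_old = dict()

    # After (what this commit switches to): the equivalent literals
    kwargs_new = {"distributed": True, "predict_with_generate": False, "do_eval": False, "do_predict": False}
    shape_new = {}

    # The change is purely syntactic; the resulting objects are identical.
    assert kwargs_old == kwargs_new and shape_old == shape_new

The same pattern accounts for the tasks2models, tasks, experiments, expected_params, FASTER_GEN_KWARGS, TOK_DECODE_KW, metas, and word_shape/word_pronunciation hunks.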
......@@ -166,8 +166,8 @@ def make_task_cmds():
# but need a tiny model for each
#
# should have "{model_type.upper()}_TINY" corresponding vars defined, e.g., T5_TINY, etc.
tasks2models = dict(
trans=[
tasks2models = {
"trans": [
"bart",
"fsmt",
"m2m_100",
......@@ -177,10 +177,10 @@ def make_task_cmds():
"t5_v1",
# "mt5", missing model files
],
sum=[
"sum": [
"pegasus",
],
clm=[
"clm": [
"big_bird",
"bigbird_pegasus",
"blenderbot",
......@@ -192,7 +192,7 @@ def make_task_cmds():
"prophetnet",
# "camembert", missing model files
],
mlm=[
"mlm": [
"albert",
"deberta",
"deberta-v2",
......@@ -203,7 +203,7 @@ def make_task_cmds():
"layoutlm",
# "reformer", # multiple issues with either mlm/qa/clas
],
qa=[
"qa": [
"led",
"longformer",
"mobilebert",
......@@ -213,7 +213,7 @@ def make_task_cmds():
# "convbert", # missing tokenizer files
# "layoutlmv2", missing model files
],
clas=[
"clas": [
"bert",
"xlnet",
# "hubert", # missing tokenizer files
......@@ -223,54 +223,54 @@ def make_task_cmds():
# "openai-gpt", missing model files
# "tapas", multiple issues
],
img_clas=[
"img_clas": [
"vit",
],
)
}
scripts_dir = f"{ROOT_DIRECTORY}/examples/pytorch"
tasks = dict(
trans=f"""
tasks = {
"trans": f"""
{scripts_dir}/translation/run_translation.py
--train_file {data_dir_wmt}/train.json
--source_lang en
--target_lang ro
""",
sum=f"""
"sum": f"""
{scripts_dir}/summarization/run_summarization.py
--train_file {data_dir_xsum}/sample.json
--max_source_length 12
--max_target_length 12
--lang en
""",
clm=f"""
"clm": f"""
{scripts_dir}/language-modeling/run_clm.py
--train_file {FIXTURE_DIRECTORY}/sample_text.txt
--block_size 8
""",
mlm=f"""
"mlm": f"""
{scripts_dir}/language-modeling/run_mlm.py
--train_file {FIXTURE_DIRECTORY}/sample_text.txt
""",
qa=f"""
"qa": f"""
{scripts_dir}/question-answering/run_qa.py
--train_file {data_dir_samples}/SQUAD/sample.json
""",
clas=f"""
"clas": f"""
{scripts_dir}/text-classification/run_glue.py
--train_file {data_dir_samples}/MRPC/train.csv
--max_seq_length 12
--task_name MRPC
""",
img_clas=f"""
"img_clas": f"""
{scripts_dir}/image-classification/run_image_classification.py
--dataset_name hf-internal-testing/cats_vs_dogs_sample
--remove_unused_columns False
--max_steps 10
--image_processor_name {DS_TESTS_DIRECTORY}/vit_feature_extractor.json
""",
)
}
launcher = get_launcher(distributed=True)
......
......@@ -155,21 +155,21 @@ class TestTrainerExt(TestCasePlus):
@require_torch_multi_gpu
def test_trainer_log_level_replica(self, experiment_id):
# as each sub-test is slow-ish split into multiple sub-tests to avoid CI timeout
experiments = dict(
experiments = {
# test with the default log_level - should be info and thus log info once
base=dict(extra_args_str="", n_matches=1),
"base": {"extra_args_str": "", "n_matches": 1},
# test with low log_level and log_level_replica - should be noisy on all processes
# now the info string should appear twice on 2 processes
low=dict(extra_args_str="--log_level debug --log_level_replica debug", n_matches=2),
"low": {"extra_args_str": "--log_level debug --log_level_replica debug", "n_matches": 2},
# test with high log_level and low log_level_replica
# now the info string should appear once only on the replica
high=dict(extra_args_str="--log_level error --log_level_replica debug", n_matches=1),
"high": {"extra_args_str": "--log_level error --log_level_replica debug", "n_matches": 1},
# test with high log_level and log_level_replica - should be quiet on all processes
mixed=dict(extra_args_str="--log_level error --log_level_replica error", n_matches=0),
)
"mixed": {"extra_args_str": "--log_level error --log_level_replica error", "n_matches": 0},
}
data = experiments[experiment_id]
kwargs = dict(distributed=True, predict_with_generate=False, do_eval=False, do_predict=False)
kwargs = {"distributed": True, "predict_with_generate": False, "do_eval": False, "do_predict": False}
log_info_string = "Running training"
with CaptureStderr() as cl:
self.run_seq2seq_quick(**kwargs, extra_args_str=data["extra_args_str"])
......
......@@ -1480,7 +1480,7 @@ class GenerationTesterMixin:
signature = inspect.signature(model.forward)
# We want to test only models where encoder/decoder head masking is implemented
if not set(head_masking.keys()) < set([*signature.parameters.keys()]):
if not set(head_masking.keys()) < {*signature.parameters.keys()}:
continue
for attn_name, (name, mask) in zip(attention_names, head_masking.items()):
......
......@@ -939,7 +939,7 @@ class BartModelIntegrationTests(unittest.TestCase):
def test_xsum_config_generation_params(self):
config = BartConfig.from_pretrained("facebook/bart-large-xsum")
expected_params = dict(num_beams=6, do_sample=False, early_stopping=True, length_penalty=1.0)
expected_params = {"num_beams": 6, "do_sample": False, "early_stopping": True, "length_penalty": 1.0}
config_params = {k: getattr(config, k, "MISSING") for k, v in expected_params.items()}
self.assertDictEqual(expected_params, config_params)
......
......@@ -299,8 +299,8 @@ class Blenderbot3BIntegrationTests(unittest.TestCase):
@slow
def test_generation_from_short_input_same_as_parlai_3B(self):
FASTER_GEN_KWARGS = dict(num_beams=1, early_stopping=True, min_length=15, max_length=25)
TOK_DECODE_KW = dict(skip_special_tokens=True, clean_up_tokenization_spaces=True)
FASTER_GEN_KWARGS = {"num_beams": 1, "early_stopping": True, "min_length": 15, "max_length": 25}
TOK_DECODE_KW = {"skip_special_tokens": True, "clean_up_tokenization_spaces": True}
torch.cuda.empty_cache()
model = BlenderbotForConditionalGeneration.from_pretrained(self.ckpt).half().to(torch_device)
......
......@@ -402,8 +402,8 @@ class FlaxBlenderbotModelTest(FlaxModelTesterMixin, unittest.TestCase, FlaxGener
@unittest.skipUnless(jax_device != "cpu", "3B test too slow on CPU.")
@slow
def test_generation_from_short_input_same_as_parlai_3B(self):
FASTER_GEN_KWARGS = dict(num_beams=1, early_stopping=True, min_length=15, max_length=25)
TOK_DECODE_KW = dict(skip_special_tokens=True, clean_up_tokenization_spaces=True)
FASTER_GEN_KWARGS = {"num_beams": 1, "early_stopping": True, "min_length": 15, "max_length": 25}
TOK_DECODE_KW = {"skip_special_tokens": True, "clean_up_tokenization_spaces": True}
model = FlaxBlenderbotForConditionalGeneration.from_pretrained("facebook/blenderbot-3B", from_pt=True)
tokenizer = BlenderbotTokenizer.from_pretrained("facebook/blenderbot-3B")
......
......@@ -124,7 +124,7 @@ class BloomTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
input_text = list(sample_data.values())
output_tokens = list(map(tokenizer.encode, input_text))
predicted_text = list(map(lambda x: tokenizer.decode(x, clean_up_tokenization_spaces=False), output_tokens))
predicted_text = [tokenizer.decode(x, clean_up_tokenization_spaces=False) for x in output_tokens]
self.assertListEqual(predicted_text, input_text)
def test_pretrained_model_lists(self):
......
......@@ -551,7 +551,7 @@ class TFCLIPModelTest(TFModelTesterMixin, unittest.TestCase):
if self.__class__.__name__ == "TFCLIPModelTest":
inputs_dict.pop("return_loss", None)
tf_main_layer_classes = set(
tf_main_layer_classes = {
module_member
for model_class in self.all_model_classes
for module in (import_module(model_class.__module__),)
......@@ -563,7 +563,7 @@ class TFCLIPModelTest(TFModelTesterMixin, unittest.TestCase):
if isinstance(module_member, type)
and tf.keras.layers.Layer in module_member.__bases__
and getattr(module_member, "_keras_serializable", False)
)
}
for main_layer_class in tf_main_layer_classes:
# T5MainLayer needs an embed_tokens parameter when called without the inputs_embeds parameter
if "T5" in main_layer_class.__name__:
......
......@@ -398,7 +398,7 @@ class TFData2VecVisionModelTest(TFModelTesterMixin, unittest.TestCase):
# The number of elements in the loss should be the same as the number of elements in the label
_, prepared_for_class = self.model_tester.prepare_config_and_inputs_for_keras_fit()
added_label = prepared_for_class[
sorted(list(prepared_for_class.keys() - inputs_dict.keys()), reverse=True)[0]
sorted(prepared_for_class.keys() - inputs_dict.keys(), reverse=True)[0]
]
loss_size = tf.size(added_label)
......
......@@ -628,7 +628,7 @@ class TFGroupViTModelTest(TFModelTesterMixin, unittest.TestCase):
if self.__class__.__name__ == "TFGroupViTModelTest":
inputs_dict.pop("return_loss", None)
tf_main_layer_classes = set(
tf_main_layer_classes = {
module_member
for model_class in self.all_model_classes
for module in (import_module(model_class.__module__),)
......@@ -640,7 +640,7 @@ class TFGroupViTModelTest(TFModelTesterMixin, unittest.TestCase):
if isinstance(module_member, type)
and tf.keras.layers.Layer in module_member.__bases__
and getattr(module_member, "_keras_serializable", False)
)
}
for main_layer_class in tf_main_layer_classes:
# T5MainLayer needs an embed_tokens parameter when called without the inputs_embeds parameter
if "T5" in main_layer_class.__name__:
......
......@@ -30,10 +30,10 @@ if is_torch_available():
class Jukebox1bModelTester(unittest.TestCase):
all_model_classes = (JukeboxModel,) if is_torch_available() else ()
model_id = "openai/jukebox-1b-lyrics"
metas = dict(
artist="Zac Brown Band",
genres="Country",
lyrics="""I met a traveller from an antique land,
metas = {
"artist": "Zac Brown Band",
"genres": "Country",
"lyrics": """I met a traveller from an antique land,
Who said "Two vast and trunkless legs of stone
Stand in the desert. . . . Near them, on the sand,
Half sunk a shattered visage lies, whose frown,
......@@ -48,7 +48,7 @@ class Jukebox1bModelTester(unittest.TestCase):
Of that colossal Wreck, boundless and bare
The lone and level sands stretch far away
""",
)
}
# fmt: off
EXPECTED_OUTPUT_2 = [
1864, 1536, 1213, 1870, 1357, 1536, 519, 880, 1323, 789, 1082, 534,
......@@ -180,7 +180,7 @@ class Jukebox1bModelTester(unittest.TestCase):
model = JukeboxModel.from_pretrained(self.model_id, min_duration=0).eval()
set_seed(0)
waveform = torch.rand((1, 5120, 1))
tokens = [i for i in self.prepare_inputs()]
tokens = list(self.prepare_inputs())
zs = [model.vqvae.encode(waveform, start_level=2, bs_chunks=waveform.shape[0])[0], None, None]
zs = model._sample(
......@@ -220,10 +220,10 @@ class Jukebox1bModelTester(unittest.TestCase):
class Jukebox5bModelTester(unittest.TestCase):
all_model_classes = (JukeboxModel,) if is_torch_available() else ()
model_id = "openai/jukebox-5b-lyrics"
metas = dict(
artist="Zac Brown Band",
genres="Country",
lyrics="""I met a traveller from an antique land,
metas = {
"artist": "Zac Brown Band",
"genres": "Country",
"lyrics": """I met a traveller from an antique land,
Who said "Two vast and trunkless legs of stone
Stand in the desert. . . . Near them, on the sand,
Half sunk a shattered visage lies, whose frown,
......@@ -238,7 +238,7 @@ class Jukebox5bModelTester(unittest.TestCase):
Of that colossal Wreck, boundless and bare
The lone and level sands stretch far away
""",
)
}
# fmt: off
EXPECTED_OUTPUT_2 = [
......
......@@ -21,10 +21,10 @@ from transformers.testing_utils import require_torch
class JukeboxTokenizationTest(unittest.TestCase):
tokenizer_class = JukeboxTokenizer
metas = dict(
artist="Zac Brown Band",
genres="Country",
lyrics="""I met a traveller from an antique land,
metas = {
"artist": "Zac Brown Band",
"genres": "Country",
"lyrics": """I met a traveller from an antique land,
Who said "Two vast and trunkless legs of stone
Stand in the desert. . . . Near them, on the sand,
Half sunk a shattered visage lies, whose frown,
......@@ -39,7 +39,7 @@ class JukeboxTokenizationTest(unittest.TestCase):
Of that colossal Wreck, boundless and bare
The lone and level sands stretch far away
""",
)
}
@require_torch
def test_1b_lyrics_tokenizer(self):
......
......@@ -233,7 +233,7 @@ class LayoutLMv2ProcessorIntegrationTests(unittest.TestCase):
# verify keys
expected_keys = ["attention_mask", "bbox", "image", "input_ids", "token_type_ids"]
actual_keys = sorted(list(input_processor.keys()))
actual_keys = sorted(input_processor.keys())
self.assertListEqual(actual_keys, expected_keys)
# verify image
......@@ -253,7 +253,7 @@ class LayoutLMv2ProcessorIntegrationTests(unittest.TestCase):
# verify keys
expected_keys = ["attention_mask", "bbox", "image", "input_ids", "token_type_ids"]
actual_keys = sorted(list(input_processor.keys()))
actual_keys = sorted(input_processor.keys())
self.assertListEqual(actual_keys, expected_keys)
# verify images
......@@ -301,7 +301,7 @@ class LayoutLMv2ProcessorIntegrationTests(unittest.TestCase):
# verify keys
expected_keys = ["attention_mask", "bbox", "image", "input_ids", "token_type_ids"]
actual_keys = sorted(list(input_processor.keys()))
actual_keys = sorted(input_processor.keys())
self.assertListEqual(actual_keys, expected_keys)
# verify input_ids
......@@ -340,7 +340,7 @@ class LayoutLMv2ProcessorIntegrationTests(unittest.TestCase):
# verify keys
expected_keys = ["attention_mask", "bbox", "image", "input_ids", "labels", "token_type_ids"]
actual_keys = sorted(list(input_processor.keys()))
actual_keys = sorted(input_processor.keys())
self.assertListEqual(actual_keys, expected_keys)
# verify input_ids
......@@ -362,7 +362,7 @@ class LayoutLMv2ProcessorIntegrationTests(unittest.TestCase):
# verify keys
expected_keys = ["attention_mask", "bbox", "image", "input_ids", "labels", "token_type_ids"]
actual_keys = sorted(list(input_processor.keys()))
actual_keys = sorted(input_processor.keys())
self.assertListEqual(actual_keys, expected_keys)
# verify input_ids
......@@ -403,7 +403,7 @@ class LayoutLMv2ProcessorIntegrationTests(unittest.TestCase):
# verify keys
expected_keys = ["attention_mask", "bbox", "image", "input_ids", "token_type_ids"]
actual_keys = sorted(list(input_processor.keys()))
actual_keys = sorted(input_processor.keys())
self.assertListEqual(actual_keys, expected_keys)
# verify input_ids
......@@ -422,7 +422,7 @@ class LayoutLMv2ProcessorIntegrationTests(unittest.TestCase):
# verify keys
expected_keys = ["attention_mask", "bbox", "image", "input_ids", "token_type_ids"]
actual_keys = sorted(list(input_processor.keys()))
actual_keys = sorted(input_processor.keys())
self.assertListEqual(actual_keys, expected_keys)
# verify input_ids
......@@ -456,7 +456,7 @@ class LayoutLMv2ProcessorIntegrationTests(unittest.TestCase):
# verify keys
expected_keys = ["attention_mask", "bbox", "image", "input_ids", "token_type_ids"]
actual_keys = sorted(list(input_processor.keys()))
actual_keys = sorted(input_processor.keys())
self.assertListEqual(actual_keys, expected_keys)
# verify input_ids
......@@ -472,7 +472,7 @@ class LayoutLMv2ProcessorIntegrationTests(unittest.TestCase):
# verify keys
expected_keys = ["attention_mask", "bbox", "image", "input_ids", "token_type_ids"]
actual_keys = sorted(list(input_processor.keys()))
actual_keys = sorted(input_processor.keys())
self.assertListEqual(actual_keys, expected_keys)
# verify input_ids
......
......@@ -320,7 +320,7 @@ class TFLayoutLMv3ModelTest(TFModelTesterMixin, unittest.TestCase):
# The number of elements in the loss should be the same as the number of elements in the label
prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
added_label = prepared_for_class[
sorted(list(prepared_for_class.keys() - inputs_dict.keys()), reverse=True)[0]
sorted(prepared_for_class.keys() - inputs_dict.keys(), reverse=True)[0]
]
expected_loss_size = added_label.shape.as_list()[:1]
......
......@@ -213,7 +213,7 @@ class LayoutLMv3ProcessorIntegrationTests(unittest.TestCase):
# verify keys
expected_keys = ["attention_mask", "bbox", "input_ids", "pixel_values"]
actual_keys = sorted(list(input_processor.keys()))
actual_keys = sorted(input_processor.keys())
self.assertListEqual(actual_keys, expected_keys)
# verify image
......@@ -235,7 +235,7 @@ class LayoutLMv3ProcessorIntegrationTests(unittest.TestCase):
# verify keys
expected_keys = ["attention_mask", "bbox", "input_ids", "pixel_values"]
actual_keys = sorted(list(input_processor.keys()))
actual_keys = sorted(input_processor.keys())
self.assertListEqual(actual_keys, expected_keys)
# verify images
......@@ -285,7 +285,7 @@ class LayoutLMv3ProcessorIntegrationTests(unittest.TestCase):
# verify keys
expected_keys = ["attention_mask", "bbox", "input_ids", "pixel_values"]
actual_keys = sorted(list(input_processor.keys()))
actual_keys = sorted(input_processor.keys())
self.assertListEqual(actual_keys, expected_keys)
# verify input_ids
......@@ -324,7 +324,7 @@ class LayoutLMv3ProcessorIntegrationTests(unittest.TestCase):
# verify keys
expected_keys = ["attention_mask", "bbox", "input_ids", "labels", "pixel_values"]
actual_keys = sorted(list(input_processor.keys()))
actual_keys = sorted(input_processor.keys())
self.assertListEqual(actual_keys, expected_keys)
# verify input_ids
......@@ -346,7 +346,7 @@ class LayoutLMv3ProcessorIntegrationTests(unittest.TestCase):
# verify keys
expected_keys = ["attention_mask", "bbox", "input_ids", "labels", "pixel_values"]
actual_keys = sorted(list(input_processor.keys()))
actual_keys = sorted(input_processor.keys())
self.assertListEqual(actual_keys, expected_keys)
# verify input_ids
......@@ -387,7 +387,7 @@ class LayoutLMv3ProcessorIntegrationTests(unittest.TestCase):
# verify keys
expected_keys = ["attention_mask", "bbox", "input_ids", "pixel_values"]
actual_keys = sorted(list(input_processor.keys()))
actual_keys = sorted(input_processor.keys())
self.assertListEqual(actual_keys, expected_keys)
# verify input_ids
......@@ -406,7 +406,7 @@ class LayoutLMv3ProcessorIntegrationTests(unittest.TestCase):
# verify keys
expected_keys = ["attention_mask", "bbox", "input_ids", "pixel_values"]
actual_keys = sorted(list(input_processor.keys()))
actual_keys = sorted(input_processor.keys())
self.assertListEqual(actual_keys, expected_keys)
# verify input_ids
......@@ -440,7 +440,7 @@ class LayoutLMv3ProcessorIntegrationTests(unittest.TestCase):
# verify keys
expected_keys = ["attention_mask", "bbox", "input_ids", "pixel_values"]
actual_keys = sorted(list(input_processor.keys()))
actual_keys = sorted(input_processor.keys())
self.assertListEqual(actual_keys, expected_keys)
# verify input_ids
......@@ -456,7 +456,7 @@ class LayoutLMv3ProcessorIntegrationTests(unittest.TestCase):
# verify keys
expected_keys = ["attention_mask", "bbox", "input_ids", "pixel_values"]
actual_keys = sorted(list(input_processor.keys()))
actual_keys = sorted(input_processor.keys())
self.assertListEqual(actual_keys, expected_keys)
# verify input_ids
......
......@@ -228,7 +228,7 @@ class LayoutXLMProcessorIntegrationTests(unittest.TestCase):
# verify keys
expected_keys = ["attention_mask", "bbox", "image", "input_ids"]
actual_keys = sorted(list(input_processor.keys()))
actual_keys = sorted(input_processor.keys())
self.assertListEqual(actual_keys, expected_keys)
# verify image
......@@ -250,7 +250,7 @@ class LayoutXLMProcessorIntegrationTests(unittest.TestCase):
# verify keys
expected_keys = ["attention_mask", "bbox", "image", "input_ids"]
actual_keys = sorted(list(input_processor.keys()))
actual_keys = sorted(input_processor.keys())
self.assertListEqual(actual_keys, expected_keys)
# verify images
......@@ -300,7 +300,7 @@ class LayoutXLMProcessorIntegrationTests(unittest.TestCase):
# verify keys
expected_keys = ["attention_mask", "bbox", "image", "input_ids"]
actual_keys = sorted(list(input_processor.keys()))
actual_keys = sorted(input_processor.keys())
self.assertListEqual(actual_keys, expected_keys)
# verify input_ids
......@@ -339,7 +339,7 @@ class LayoutXLMProcessorIntegrationTests(unittest.TestCase):
# verify keys
expected_keys = ["attention_mask", "bbox", "image", "input_ids", "labels"]
actual_keys = sorted(list(input_processor.keys()))
actual_keys = sorted(input_processor.keys())
self.assertListEqual(actual_keys, expected_keys)
# verify input_ids
......@@ -361,7 +361,7 @@ class LayoutXLMProcessorIntegrationTests(unittest.TestCase):
# verify keys
expected_keys = ["attention_mask", "bbox", "image", "input_ids", "labels"]
actual_keys = sorted(list(input_processor.keys()))
actual_keys = sorted(input_processor.keys())
self.assertListEqual(actual_keys, expected_keys)
# verify input_ids
......@@ -402,7 +402,7 @@ class LayoutXLMProcessorIntegrationTests(unittest.TestCase):
# verify keys
expected_keys = ["attention_mask", "bbox", "image", "input_ids"]
actual_keys = sorted(list(input_processor.keys()))
actual_keys = sorted(input_processor.keys())
self.assertListEqual(actual_keys, expected_keys)
# verify input_ids
......@@ -421,7 +421,7 @@ class LayoutXLMProcessorIntegrationTests(unittest.TestCase):
# verify keys
expected_keys = ["attention_mask", "bbox", "image", "input_ids"]
actual_keys = sorted(list(input_processor.keys()))
actual_keys = sorted(input_processor.keys())
self.assertListEqual(actual_keys, expected_keys)
# verify input_ids
......@@ -455,7 +455,7 @@ class LayoutXLMProcessorIntegrationTests(unittest.TestCase):
# verify keys
expected_keys = ["attention_mask", "bbox", "image", "input_ids"]
actual_keys = sorted(list(input_processor.keys()))
actual_keys = sorted(input_processor.keys())
self.assertListEqual(actual_keys, expected_keys)
# verify input_ids
......@@ -471,7 +471,7 @@ class LayoutXLMProcessorIntegrationTests(unittest.TestCase):
# verify keys
expected_keys = ["attention_mask", "bbox", "image", "input_ids"]
actual_keys = sorted(list(input_processor.keys()))
actual_keys = sorted(input_processor.keys())
self.assertListEqual(actual_keys, expected_keys)
# verify input_ids
......
......@@ -204,7 +204,7 @@ class MarkupLMProcessorIntegrationTests(unittest.TestCase):
# verify keys
expected_keys = ["attention_mask", "input_ids", "token_type_ids", "xpath_subs_seq", "xpath_tags_seq"]
actual_keys = sorted(list(inputs.keys()))
actual_keys = sorted(inputs.keys())
self.assertListEqual(actual_keys, expected_keys)
# verify input_ids
......@@ -216,7 +216,7 @@ class MarkupLMProcessorIntegrationTests(unittest.TestCase):
# verify keys
expected_keys = ["attention_mask", "input_ids", "token_type_ids", "xpath_subs_seq", "xpath_tags_seq"]
actual_keys = sorted(list(inputs.keys()))
actual_keys = sorted(inputs.keys())
self.assertListEqual(actual_keys, expected_keys)
# verify input_ids
......@@ -260,7 +260,7 @@ class MarkupLMProcessorIntegrationTests(unittest.TestCase):
# verify keys
expected_keys = ["attention_mask", "input_ids", "token_type_ids", "xpath_subs_seq", "xpath_tags_seq"]
actual_keys = sorted(list(inputs.keys()))
actual_keys = sorted(inputs.keys())
self.assertListEqual(actual_keys, expected_keys)
# verify input_ids
......@@ -294,7 +294,7 @@ class MarkupLMProcessorIntegrationTests(unittest.TestCase):
"xpath_subs_seq",
"xpath_tags_seq",
]
actual_keys = sorted(list(inputs.keys()))
actual_keys = sorted(inputs.keys())
self.assertListEqual(actual_keys, expected_keys)
# verify input_ids
......@@ -331,7 +331,7 @@ class MarkupLMProcessorIntegrationTests(unittest.TestCase):
"xpath_subs_seq",
"xpath_tags_seq",
]
actual_keys = sorted(list(inputs.keys()))
actual_keys = sorted(inputs.keys())
self.assertListEqual(actual_keys, expected_keys)
# verify input_ids
......@@ -367,7 +367,7 @@ class MarkupLMProcessorIntegrationTests(unittest.TestCase):
# verify keys
expected_keys = ["attention_mask", "input_ids", "token_type_ids", "xpath_subs_seq", "xpath_tags_seq"]
actual_keys = sorted(list(inputs.keys()))
actual_keys = sorted(inputs.keys())
self.assertListEqual(actual_keys, expected_keys)
# verify input_ids
......@@ -390,7 +390,7 @@ class MarkupLMProcessorIntegrationTests(unittest.TestCase):
# verify keys
expected_keys = ["attention_mask", "input_ids", "token_type_ids", "xpath_subs_seq", "xpath_tags_seq"]
actual_keys = sorted(list(inputs.keys()))
actual_keys = sorted(inputs.keys())
self.assertListEqual(actual_keys, expected_keys)
# verify input_ids
......@@ -425,7 +425,7 @@ class MarkupLMProcessorIntegrationTests(unittest.TestCase):
# verify keys
expected_keys = ["attention_mask", "input_ids", "token_type_ids", "xpath_subs_seq", "xpath_tags_seq"]
actual_keys = sorted(list(inputs.keys()))
actual_keys = sorted(inputs.keys())
self.assertListEqual(actual_keys, expected_keys)
# verify input_ids
......@@ -444,7 +444,7 @@ class MarkupLMProcessorIntegrationTests(unittest.TestCase):
# verify keys
expected_keys = ["attention_mask", "input_ids", "token_type_ids", "xpath_subs_seq", "xpath_tags_seq"]
actual_keys = sorted(list(inputs.keys()))
actual_keys = sorted(inputs.keys())
self.assertListEqual(actual_keys, expected_keys)
# verify input_ids
......
......@@ -295,7 +295,7 @@ class MobileViTModelTest(TFModelTesterMixin, unittest.TestCase):
# The number of elements in the loss should be the same as the number of elements in the label
prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
added_label = prepared_for_class[
sorted(list(prepared_for_class.keys() - inputs_dict.keys()), reverse=True)[0]
sorted(prepared_for_class.keys() - inputs_dict.keys(), reverse=True)[0]
]
expected_loss_size = added_label.shape.as_list()[:1]
......
......@@ -166,9 +166,11 @@ class PerceiverModelTester:
audio = torch.randn(
(self.batch_size, self.num_frames * self.audio_samples_per_frame, 1), device=torch_device
)
inputs = dict(
image=images, audio=audio, label=torch.zeros((self.batch_size, self.num_labels), device=torch_device)
)
inputs = {
"image": images,
"audio": audio,
"label": torch.zeros((self.batch_size, self.num_labels), device=torch_device),
}
else:
raise ValueError(f"Model class {model_class} not supported")
......@@ -734,7 +736,7 @@ class PerceiverModelTest(ModelTesterMixin, unittest.TestCase):
continue
config, inputs, input_mask, _, _ = self.model_tester.prepare_config_and_inputs(model_class=model_class)
inputs_dict = dict(inputs=inputs, attention_mask=input_mask)
inputs_dict = {"inputs": inputs, "attention_mask": input_mask}
for problem_type in problem_types:
with self.subTest(msg=f"Testing {model_class} with {problem_type['title']}"):
......
......@@ -44,8 +44,8 @@ class BertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
super().setUp()
vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]", "你", "好", "是", "谁", "a", "b", "c", "d"]
word_shape = dict()
word_pronunciation = dict()
word_shape = {}
word_pronunciation = {}
for i, value in enumerate(vocab_tokens):
word_shape[value] = i
word_pronunciation[value] = i
......