Unverified Commit 5e8c8eb5 authored by Aaron Gokaslan, committed by GitHub

Apply ruff flake8-comprehensions (#21694)

parent df06fb1f
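
The diff below is purely mechanical: it applies ruff's flake8-comprehensions rules (the C4xx family), which rewrite `dict(...)`/`set(...)` calls as literals, fold `map(lambda ...)` and identity comprehensions into plain comprehensions or constructor calls, and drop redundant `list()` wrappers inside `sorted()`. A minimal sketch of the patterns involved is below; all names in it are invented for illustration and do not come from this diff, and the equivalent command on a recent ruff release is roughly `ruff check --select C4 --fix .`.

    # Illustrative sketch only: names here are made up, not taken from the diff.
    tokens = [[1, 2], [3, 4]]
    decode = str  # stand-in for something like tokenizer.decode

    # C408: dict() called with keyword arguments -> dict literal
    kwargs_before = dict(num_beams=1, max_length=25)
    kwargs_after = {"num_beams": 1, "max_length": 25}

    # C417: map(lambda ...) -> list comprehension
    text_before = list(map(lambda x: decode(x), tokens))
    text_after = [decode(x) for x in tokens]

    # C414: redundant list() call inside sorted()
    keys_before = sorted(list({"b": 1, "a": 2}.keys()))
    keys_after = sorted({"b": 1, "a": 2}.keys())

    # C416: identity comprehension -> constructor call
    items_before = [x for x in tokens]
    items_after = list(tokens)

    # each "after" form is equivalent to its "before" form
    assert kwargs_before == kwargs_after and text_before == text_after
    assert keys_before == keys_after and items_before == items_after

The literal forms avoid a builtin name lookup and read more directly, which is why the rule exists; the hunks below are the same rewrites applied across the test suite.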
@@ -166,8 +166,8 @@ def make_task_cmds():
     # but need a tiny model for each
     #
     # should have "{model_type.upper()}_TINY" corresponding vars defined, e.g., T5_TINY, etc.
-    tasks2models = dict(
-        trans=[
+    tasks2models = {
+        "trans": [
             "bart",
             "fsmt",
             "m2m_100",
@@ -177,10 +177,10 @@ def make_task_cmds():
             "t5_v1",
             # "mt5", missing model files
         ],
-        sum=[
+        "sum": [
             "pegasus",
         ],
-        clm=[
+        "clm": [
             "big_bird",
             "bigbird_pegasus",
             "blenderbot",
@@ -192,7 +192,7 @@ def make_task_cmds():
             "prophetnet",
             # "camembert", missing model files
         ],
-        mlm=[
+        "mlm": [
             "albert",
             "deberta",
             "deberta-v2",
@@ -203,7 +203,7 @@ def make_task_cmds():
             "layoutlm",
             # "reformer", # multiple issues with either mlm/qa/clas
         ],
-        qa=[
+        "qa": [
             "led",
             "longformer",
             "mobilebert",
@@ -213,7 +213,7 @@ def make_task_cmds():
             # "convbert", # missing tokenizer files
             # "layoutlmv2", missing model files
         ],
-        clas=[
+        "clas": [
             "bert",
             "xlnet",
             # "hubert", # missing tokenizer files
@@ -223,54 +223,54 @@ def make_task_cmds():
             # "openai-gpt", missing model files
             # "tapas", multiple issues
         ],
-        img_clas=[
+        "img_clas": [
             "vit",
         ],
-    )
+    }
     scripts_dir = f"{ROOT_DIRECTORY}/examples/pytorch"
-    tasks = dict(
-        trans=f"""
+    tasks = {
+        "trans": f"""
         {scripts_dir}/translation/run_translation.py
         --train_file {data_dir_wmt}/train.json
         --source_lang en
         --target_lang ro
         """,
-        sum=f"""
+        "sum": f"""
         {scripts_dir}/summarization/run_summarization.py
         --train_file {data_dir_xsum}/sample.json
         --max_source_length 12
         --max_target_length 12
         --lang en
         """,
-        clm=f"""
+        "clm": f"""
         {scripts_dir}/language-modeling/run_clm.py
         --train_file {FIXTURE_DIRECTORY}/sample_text.txt
         --block_size 8
         """,
-        mlm=f"""
+        "mlm": f"""
         {scripts_dir}/language-modeling/run_mlm.py
         --train_file {FIXTURE_DIRECTORY}/sample_text.txt
         """,
-        qa=f"""
+        "qa": f"""
         {scripts_dir}/question-answering/run_qa.py
         --train_file {data_dir_samples}/SQUAD/sample.json
         """,
-        clas=f"""
+        "clas": f"""
         {scripts_dir}/text-classification/run_glue.py
         --train_file {data_dir_samples}/MRPC/train.csv
         --max_seq_length 12
         --task_name MRPC
         """,
-        img_clas=f"""
+        "img_clas": f"""
         {scripts_dir}/image-classification/run_image_classification.py
         --dataset_name hf-internal-testing/cats_vs_dogs_sample
         --remove_unused_columns False
         --max_steps 10
         --image_processor_name {DS_TESTS_DIRECTORY}/vit_feature_extractor.json
         """,
-    )
+    }
     launcher = get_launcher(distributed=True)
...
@@ -155,21 +155,21 @@ class TestTrainerExt(TestCasePlus):
     @require_torch_multi_gpu
     def test_trainer_log_level_replica(self, experiment_id):
         # as each sub-test is slow-ish split into multiple sub-tests to avoid CI timeout
-        experiments = dict(
+        experiments = {
             # test with the default log_level - should be info and thus log info once
-            base=dict(extra_args_str="", n_matches=1),
+            "base": {"extra_args_str": "", "n_matches": 1},
             # test with low log_level and log_level_replica - should be noisy on all processes
             # now the info string should appear twice on 2 processes
-            low=dict(extra_args_str="--log_level debug --log_level_replica debug", n_matches=2),
+            "low": {"extra_args_str": "--log_level debug --log_level_replica debug", "n_matches": 2},
             # test with high log_level and low log_level_replica
             # now the info string should appear once only on the replica
-            high=dict(extra_args_str="--log_level error --log_level_replica debug", n_matches=1),
+            "high": {"extra_args_str": "--log_level error --log_level_replica debug", "n_matches": 1},
             # test with high log_level and log_level_replica - should be quiet on all processes
-            mixed=dict(extra_args_str="--log_level error --log_level_replica error", n_matches=0),
-        )
+            "mixed": {"extra_args_str": "--log_level error --log_level_replica error", "n_matches": 0},
+        }
         data = experiments[experiment_id]
-        kwargs = dict(distributed=True, predict_with_generate=False, do_eval=False, do_predict=False)
+        kwargs = {"distributed": True, "predict_with_generate": False, "do_eval": False, "do_predict": False}
         log_info_string = "Running training"
         with CaptureStderr() as cl:
             self.run_seq2seq_quick(**kwargs, extra_args_str=data["extra_args_str"])
...
@@ -1480,7 +1480,7 @@ class GenerationTesterMixin:
             signature = inspect.signature(model.forward)
             # We want to test only models where encoder/decoder head masking is implemented
-            if not set(head_masking.keys()) < set([*signature.parameters.keys()]):
+            if not set(head_masking.keys()) < {*signature.parameters.keys()}:
                 continue
             for attn_name, (name, mask) in zip(attention_names, head_masking.items()):
...
@@ -939,7 +939,7 @@ class BartModelIntegrationTests(unittest.TestCase):
     def test_xsum_config_generation_params(self):
         config = BartConfig.from_pretrained("facebook/bart-large-xsum")
-        expected_params = dict(num_beams=6, do_sample=False, early_stopping=True, length_penalty=1.0)
+        expected_params = {"num_beams": 6, "do_sample": False, "early_stopping": True, "length_penalty": 1.0}
         config_params = {k: getattr(config, k, "MISSING") for k, v in expected_params.items()}
         self.assertDictEqual(expected_params, config_params)
...
@@ -299,8 +299,8 @@ class Blenderbot3BIntegrationTests(unittest.TestCase):
     @slow
     def test_generation_from_short_input_same_as_parlai_3B(self):
-        FASTER_GEN_KWARGS = dict(num_beams=1, early_stopping=True, min_length=15, max_length=25)
-        TOK_DECODE_KW = dict(skip_special_tokens=True, clean_up_tokenization_spaces=True)
+        FASTER_GEN_KWARGS = {"num_beams": 1, "early_stopping": True, "min_length": 15, "max_length": 25}
+        TOK_DECODE_KW = {"skip_special_tokens": True, "clean_up_tokenization_spaces": True}
         torch.cuda.empty_cache()
         model = BlenderbotForConditionalGeneration.from_pretrained(self.ckpt).half().to(torch_device)
...
@@ -402,8 +402,8 @@ class FlaxBlenderbotModelTest(FlaxModelTesterMixin, unittest.TestCase, FlaxGener
     @unittest.skipUnless(jax_device != "cpu", "3B test too slow on CPU.")
     @slow
     def test_generation_from_short_input_same_as_parlai_3B(self):
-        FASTER_GEN_KWARGS = dict(num_beams=1, early_stopping=True, min_length=15, max_length=25)
-        TOK_DECODE_KW = dict(skip_special_tokens=True, clean_up_tokenization_spaces=True)
+        FASTER_GEN_KWARGS = {"num_beams": 1, "early_stopping": True, "min_length": 15, "max_length": 25}
+        TOK_DECODE_KW = {"skip_special_tokens": True, "clean_up_tokenization_spaces": True}
         model = FlaxBlenderbotForConditionalGeneration.from_pretrained("facebook/blenderbot-3B", from_pt=True)
         tokenizer = BlenderbotTokenizer.from_pretrained("facebook/blenderbot-3B")
...
@@ -124,7 +124,7 @@ class BloomTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
         input_text = list(sample_data.values())
         output_tokens = list(map(tokenizer.encode, input_text))
-        predicted_text = list(map(lambda x: tokenizer.decode(x, clean_up_tokenization_spaces=False), output_tokens))
+        predicted_text = [tokenizer.decode(x, clean_up_tokenization_spaces=False) for x in output_tokens]
         self.assertListEqual(predicted_text, input_text)
     def test_pretrained_model_lists(self):
...
@@ -551,7 +551,7 @@ class TFCLIPModelTest(TFModelTesterMixin, unittest.TestCase):
         if self.__class__.__name__ == "TFCLIPModelTest":
             inputs_dict.pop("return_loss", None)
-        tf_main_layer_classes = set(
+        tf_main_layer_classes = {
             module_member
             for model_class in self.all_model_classes
             for module in (import_module(model_class.__module__),)
@@ -563,7 +563,7 @@ class TFCLIPModelTest(TFModelTesterMixin, unittest.TestCase):
             if isinstance(module_member, type)
             and tf.keras.layers.Layer in module_member.__bases__
             and getattr(module_member, "_keras_serializable", False)
-        )
+        }
         for main_layer_class in tf_main_layer_classes:
             # T5MainLayer needs an embed_tokens parameter when called without the inputs_embeds parameter
             if "T5" in main_layer_class.__name__:
...
@@ -398,7 +398,7 @@ class TFData2VecVisionModelTest(TFModelTesterMixin, unittest.TestCase):
         # The number of elements in the loss should be the same as the number of elements in the label
         _, prepared_for_class = self.model_tester.prepare_config_and_inputs_for_keras_fit()
         added_label = prepared_for_class[
-            sorted(list(prepared_for_class.keys() - inputs_dict.keys()), reverse=True)[0]
+            sorted(prepared_for_class.keys() - inputs_dict.keys(), reverse=True)[0]
         ]
         loss_size = tf.size(added_label)
...
@@ -628,7 +628,7 @@ class TFGroupViTModelTest(TFModelTesterMixin, unittest.TestCase):
         if self.__class__.__name__ == "TFGroupViTModelTest":
             inputs_dict.pop("return_loss", None)
-        tf_main_layer_classes = set(
+        tf_main_layer_classes = {
             module_member
             for model_class in self.all_model_classes
             for module in (import_module(model_class.__module__),)
@@ -640,7 +640,7 @@ class TFGroupViTModelTest(TFModelTesterMixin, unittest.TestCase):
             if isinstance(module_member, type)
             and tf.keras.layers.Layer in module_member.__bases__
             and getattr(module_member, "_keras_serializable", False)
-        )
+        }
         for main_layer_class in tf_main_layer_classes:
             # T5MainLayer needs an embed_tokens parameter when called without the inputs_embeds parameter
             if "T5" in main_layer_class.__name__:
...
@@ -30,10 +30,10 @@ if is_torch_available():
 class Jukebox1bModelTester(unittest.TestCase):
     all_model_classes = (JukeboxModel,) if is_torch_available() else ()
     model_id = "openai/jukebox-1b-lyrics"
-    metas = dict(
-        artist="Zac Brown Band",
-        genres="Country",
-        lyrics="""I met a traveller from an antique land,
+    metas = {
+        "artist": "Zac Brown Band",
+        "genres": "Country",
+        "lyrics": """I met a traveller from an antique land,
         Who said "Two vast and trunkless legs of stone
         Stand in the desert. . . . Near them, on the sand,
         Half sunk a shattered visage lies, whose frown,
@@ -48,7 +48,7 @@ class Jukebox1bModelTester(unittest.TestCase):
         Of that colossal Wreck, boundless and bare
         The lone and level sands stretch far away
         """,
-    )
+    }
     # fmt: off
     EXPECTED_OUTPUT_2 = [
         1864, 1536, 1213, 1870, 1357, 1536, 519, 880, 1323, 789, 1082, 534,
@@ -180,7 +180,7 @@ class Jukebox1bModelTester(unittest.TestCase):
         model = JukeboxModel.from_pretrained(self.model_id, min_duration=0).eval()
         set_seed(0)
         waveform = torch.rand((1, 5120, 1))
-        tokens = [i for i in self.prepare_inputs()]
+        tokens = list(self.prepare_inputs())
         zs = [model.vqvae.encode(waveform, start_level=2, bs_chunks=waveform.shape[0])[0], None, None]
         zs = model._sample(
@@ -220,10 +220,10 @@ class Jukebox1bModelTester(unittest.TestCase):
 class Jukebox5bModelTester(unittest.TestCase):
     all_model_classes = (JukeboxModel,) if is_torch_available() else ()
     model_id = "openai/jukebox-5b-lyrics"
-    metas = dict(
-        artist="Zac Brown Band",
-        genres="Country",
-        lyrics="""I met a traveller from an antique land,
+    metas = {
+        "artist": "Zac Brown Band",
+        "genres": "Country",
+        "lyrics": """I met a traveller from an antique land,
         Who said "Two vast and trunkless legs of stone
         Stand in the desert. . . . Near them, on the sand,
         Half sunk a shattered visage lies, whose frown,
@@ -238,7 +238,7 @@ class Jukebox5bModelTester(unittest.TestCase):
         Of that colossal Wreck, boundless and bare
         The lone and level sands stretch far away
         """,
-    )
+    }
     # fmt: off
     EXPECTED_OUTPUT_2 = [
...
@@ -21,10 +21,10 @@ from transformers.testing_utils import require_torch
 class JukeboxTokenizationTest(unittest.TestCase):
     tokenizer_class = JukeboxTokenizer
-    metas = dict(
-        artist="Zac Brown Band",
-        genres="Country",
-        lyrics="""I met a traveller from an antique land,
+    metas = {
+        "artist": "Zac Brown Band",
+        "genres": "Country",
+        "lyrics": """I met a traveller from an antique land,
         Who said "Two vast and trunkless legs of stone
         Stand in the desert. . . . Near them, on the sand,
         Half sunk a shattered visage lies, whose frown,
@@ -39,7 +39,7 @@ class JukeboxTokenizationTest(unittest.TestCase):
         Of that colossal Wreck, boundless and bare
         The lone and level sands stretch far away
         """,
-    )
+    }
     @require_torch
     def test_1b_lyrics_tokenizer(self):
...
@@ -233,7 +233,7 @@ class LayoutLMv2ProcessorIntegrationTests(unittest.TestCase):
         # verify keys
         expected_keys = ["attention_mask", "bbox", "image", "input_ids", "token_type_ids"]
-        actual_keys = sorted(list(input_processor.keys()))
+        actual_keys = sorted(input_processor.keys())
         self.assertListEqual(actual_keys, expected_keys)
         # verify image
@@ -253,7 +253,7 @@ class LayoutLMv2ProcessorIntegrationTests(unittest.TestCase):
         # verify keys
         expected_keys = ["attention_mask", "bbox", "image", "input_ids", "token_type_ids"]
-        actual_keys = sorted(list(input_processor.keys()))
+        actual_keys = sorted(input_processor.keys())
         self.assertListEqual(actual_keys, expected_keys)
         # verify images
@@ -301,7 +301,7 @@ class LayoutLMv2ProcessorIntegrationTests(unittest.TestCase):
         # verify keys
         expected_keys = ["attention_mask", "bbox", "image", "input_ids", "token_type_ids"]
-        actual_keys = sorted(list(input_processor.keys()))
+        actual_keys = sorted(input_processor.keys())
         self.assertListEqual(actual_keys, expected_keys)
         # verify input_ids
@@ -340,7 +340,7 @@ class LayoutLMv2ProcessorIntegrationTests(unittest.TestCase):
         # verify keys
         expected_keys = ["attention_mask", "bbox", "image", "input_ids", "labels", "token_type_ids"]
-        actual_keys = sorted(list(input_processor.keys()))
+        actual_keys = sorted(input_processor.keys())
         self.assertListEqual(actual_keys, expected_keys)
         # verify input_ids
@@ -362,7 +362,7 @@ class LayoutLMv2ProcessorIntegrationTests(unittest.TestCase):
         # verify keys
         expected_keys = ["attention_mask", "bbox", "image", "input_ids", "labels", "token_type_ids"]
-        actual_keys = sorted(list(input_processor.keys()))
+        actual_keys = sorted(input_processor.keys())
         self.assertListEqual(actual_keys, expected_keys)
         # verify input_ids
@@ -403,7 +403,7 @@ class LayoutLMv2ProcessorIntegrationTests(unittest.TestCase):
         # verify keys
         expected_keys = ["attention_mask", "bbox", "image", "input_ids", "token_type_ids"]
-        actual_keys = sorted(list(input_processor.keys()))
+        actual_keys = sorted(input_processor.keys())
         self.assertListEqual(actual_keys, expected_keys)
         # verify input_ids
@@ -422,7 +422,7 @@ class LayoutLMv2ProcessorIntegrationTests(unittest.TestCase):
         # verify keys
         expected_keys = ["attention_mask", "bbox", "image", "input_ids", "token_type_ids"]
-        actual_keys = sorted(list(input_processor.keys()))
+        actual_keys = sorted(input_processor.keys())
         self.assertListEqual(actual_keys, expected_keys)
         # verify input_ids
@@ -456,7 +456,7 @@ class LayoutLMv2ProcessorIntegrationTests(unittest.TestCase):
         # verify keys
         expected_keys = ["attention_mask", "bbox", "image", "input_ids", "token_type_ids"]
-        actual_keys = sorted(list(input_processor.keys()))
+        actual_keys = sorted(input_processor.keys())
         self.assertListEqual(actual_keys, expected_keys)
         # verify input_ids
@@ -472,7 +472,7 @@ class LayoutLMv2ProcessorIntegrationTests(unittest.TestCase):
         # verify keys
         expected_keys = ["attention_mask", "bbox", "image", "input_ids", "token_type_ids"]
-        actual_keys = sorted(list(input_processor.keys()))
+        actual_keys = sorted(input_processor.keys())
         self.assertListEqual(actual_keys, expected_keys)
         # verify input_ids
...
@@ -320,7 +320,7 @@ class TFLayoutLMv3ModelTest(TFModelTesterMixin, unittest.TestCase):
         # The number of elements in the loss should be the same as the number of elements in the label
         prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
         added_label = prepared_for_class[
-            sorted(list(prepared_for_class.keys() - inputs_dict.keys()), reverse=True)[0]
+            sorted(prepared_for_class.keys() - inputs_dict.keys(), reverse=True)[0]
         ]
         expected_loss_size = added_label.shape.as_list()[:1]
...
@@ -213,7 +213,7 @@ class LayoutLMv3ProcessorIntegrationTests(unittest.TestCase):
         # verify keys
         expected_keys = ["attention_mask", "bbox", "input_ids", "pixel_values"]
-        actual_keys = sorted(list(input_processor.keys()))
+        actual_keys = sorted(input_processor.keys())
         self.assertListEqual(actual_keys, expected_keys)
         # verify image
@@ -235,7 +235,7 @@ class LayoutLMv3ProcessorIntegrationTests(unittest.TestCase):
         # verify keys
         expected_keys = ["attention_mask", "bbox", "input_ids", "pixel_values"]
-        actual_keys = sorted(list(input_processor.keys()))
+        actual_keys = sorted(input_processor.keys())
         self.assertListEqual(actual_keys, expected_keys)
         # verify images
@@ -285,7 +285,7 @@ class LayoutLMv3ProcessorIntegrationTests(unittest.TestCase):
         # verify keys
         expected_keys = ["attention_mask", "bbox", "input_ids", "pixel_values"]
-        actual_keys = sorted(list(input_processor.keys()))
+        actual_keys = sorted(input_processor.keys())
         self.assertListEqual(actual_keys, expected_keys)
         # verify input_ids
@@ -324,7 +324,7 @@ class LayoutLMv3ProcessorIntegrationTests(unittest.TestCase):
         # verify keys
         expected_keys = ["attention_mask", "bbox", "input_ids", "labels", "pixel_values"]
-        actual_keys = sorted(list(input_processor.keys()))
+        actual_keys = sorted(input_processor.keys())
         self.assertListEqual(actual_keys, expected_keys)
         # verify input_ids
@@ -346,7 +346,7 @@ class LayoutLMv3ProcessorIntegrationTests(unittest.TestCase):
         # verify keys
         expected_keys = ["attention_mask", "bbox", "input_ids", "labels", "pixel_values"]
-        actual_keys = sorted(list(input_processor.keys()))
+        actual_keys = sorted(input_processor.keys())
         self.assertListEqual(actual_keys, expected_keys)
         # verify input_ids
@@ -387,7 +387,7 @@ class LayoutLMv3ProcessorIntegrationTests(unittest.TestCase):
         # verify keys
         expected_keys = ["attention_mask", "bbox", "input_ids", "pixel_values"]
-        actual_keys = sorted(list(input_processor.keys()))
+        actual_keys = sorted(input_processor.keys())
         self.assertListEqual(actual_keys, expected_keys)
         # verify input_ids
@@ -406,7 +406,7 @@ class LayoutLMv3ProcessorIntegrationTests(unittest.TestCase):
         # verify keys
         expected_keys = ["attention_mask", "bbox", "input_ids", "pixel_values"]
-        actual_keys = sorted(list(input_processor.keys()))
+        actual_keys = sorted(input_processor.keys())
         self.assertListEqual(actual_keys, expected_keys)
         # verify input_ids
@@ -440,7 +440,7 @@ class LayoutLMv3ProcessorIntegrationTests(unittest.TestCase):
         # verify keys
         expected_keys = ["attention_mask", "bbox", "input_ids", "pixel_values"]
-        actual_keys = sorted(list(input_processor.keys()))
+        actual_keys = sorted(input_processor.keys())
         self.assertListEqual(actual_keys, expected_keys)
         # verify input_ids
@@ -456,7 +456,7 @@ class LayoutLMv3ProcessorIntegrationTests(unittest.TestCase):
         # verify keys
         expected_keys = ["attention_mask", "bbox", "input_ids", "pixel_values"]
-        actual_keys = sorted(list(input_processor.keys()))
+        actual_keys = sorted(input_processor.keys())
         self.assertListEqual(actual_keys, expected_keys)
         # verify input_ids
...
@@ -228,7 +228,7 @@ class LayoutXLMProcessorIntegrationTests(unittest.TestCase):
         # verify keys
         expected_keys = ["attention_mask", "bbox", "image", "input_ids"]
-        actual_keys = sorted(list(input_processor.keys()))
+        actual_keys = sorted(input_processor.keys())
         self.assertListEqual(actual_keys, expected_keys)
         # verify image
@@ -250,7 +250,7 @@ class LayoutXLMProcessorIntegrationTests(unittest.TestCase):
         # verify keys
         expected_keys = ["attention_mask", "bbox", "image", "input_ids"]
-        actual_keys = sorted(list(input_processor.keys()))
+        actual_keys = sorted(input_processor.keys())
         self.assertListEqual(actual_keys, expected_keys)
         # verify images
@@ -300,7 +300,7 @@ class LayoutXLMProcessorIntegrationTests(unittest.TestCase):
         # verify keys
         expected_keys = ["attention_mask", "bbox", "image", "input_ids"]
-        actual_keys = sorted(list(input_processor.keys()))
+        actual_keys = sorted(input_processor.keys())
         self.assertListEqual(actual_keys, expected_keys)
         # verify input_ids
@@ -339,7 +339,7 @@ class LayoutXLMProcessorIntegrationTests(unittest.TestCase):
         # verify keys
         expected_keys = ["attention_mask", "bbox", "image", "input_ids", "labels"]
-        actual_keys = sorted(list(input_processor.keys()))
+        actual_keys = sorted(input_processor.keys())
         self.assertListEqual(actual_keys, expected_keys)
         # verify input_ids
@@ -361,7 +361,7 @@ class LayoutXLMProcessorIntegrationTests(unittest.TestCase):
         # verify keys
         expected_keys = ["attention_mask", "bbox", "image", "input_ids", "labels"]
-        actual_keys = sorted(list(input_processor.keys()))
+        actual_keys = sorted(input_processor.keys())
         self.assertListEqual(actual_keys, expected_keys)
         # verify input_ids
@@ -402,7 +402,7 @@ class LayoutXLMProcessorIntegrationTests(unittest.TestCase):
         # verify keys
         expected_keys = ["attention_mask", "bbox", "image", "input_ids"]
-        actual_keys = sorted(list(input_processor.keys()))
+        actual_keys = sorted(input_processor.keys())
         self.assertListEqual(actual_keys, expected_keys)
         # verify input_ids
@@ -421,7 +421,7 @@ class LayoutXLMProcessorIntegrationTests(unittest.TestCase):
         # verify keys
         expected_keys = ["attention_mask", "bbox", "image", "input_ids"]
-        actual_keys = sorted(list(input_processor.keys()))
+        actual_keys = sorted(input_processor.keys())
         self.assertListEqual(actual_keys, expected_keys)
         # verify input_ids
@@ -455,7 +455,7 @@ class LayoutXLMProcessorIntegrationTests(unittest.TestCase):
         # verify keys
         expected_keys = ["attention_mask", "bbox", "image", "input_ids"]
-        actual_keys = sorted(list(input_processor.keys()))
+        actual_keys = sorted(input_processor.keys())
         self.assertListEqual(actual_keys, expected_keys)
         # verify input_ids
@@ -471,7 +471,7 @@ class LayoutXLMProcessorIntegrationTests(unittest.TestCase):
         # verify keys
         expected_keys = ["attention_mask", "bbox", "image", "input_ids"]
-        actual_keys = sorted(list(input_processor.keys()))
+        actual_keys = sorted(input_processor.keys())
         self.assertListEqual(actual_keys, expected_keys)
         # verify input_ids
...
@@ -204,7 +204,7 @@ class MarkupLMProcessorIntegrationTests(unittest.TestCase):
         # verify keys
         expected_keys = ["attention_mask", "input_ids", "token_type_ids", "xpath_subs_seq", "xpath_tags_seq"]
-        actual_keys = sorted(list(inputs.keys()))
+        actual_keys = sorted(inputs.keys())
         self.assertListEqual(actual_keys, expected_keys)
         # verify input_ids
@@ -216,7 +216,7 @@ class MarkupLMProcessorIntegrationTests(unittest.TestCase):
         # verify keys
         expected_keys = ["attention_mask", "input_ids", "token_type_ids", "xpath_subs_seq", "xpath_tags_seq"]
-        actual_keys = sorted(list(inputs.keys()))
+        actual_keys = sorted(inputs.keys())
         self.assertListEqual(actual_keys, expected_keys)
         # verify input_ids
@@ -260,7 +260,7 @@ class MarkupLMProcessorIntegrationTests(unittest.TestCase):
         # verify keys
         expected_keys = ["attention_mask", "input_ids", "token_type_ids", "xpath_subs_seq", "xpath_tags_seq"]
-        actual_keys = sorted(list(inputs.keys()))
+        actual_keys = sorted(inputs.keys())
         self.assertListEqual(actual_keys, expected_keys)
         # verify input_ids
@@ -294,7 +294,7 @@ class MarkupLMProcessorIntegrationTests(unittest.TestCase):
             "xpath_subs_seq",
             "xpath_tags_seq",
         ]
-        actual_keys = sorted(list(inputs.keys()))
+        actual_keys = sorted(inputs.keys())
         self.assertListEqual(actual_keys, expected_keys)
         # verify input_ids
@@ -331,7 +331,7 @@ class MarkupLMProcessorIntegrationTests(unittest.TestCase):
             "xpath_subs_seq",
            "xpath_tags_seq",
         ]
-        actual_keys = sorted(list(inputs.keys()))
+        actual_keys = sorted(inputs.keys())
         self.assertListEqual(actual_keys, expected_keys)
         # verify input_ids
@@ -367,7 +367,7 @@ class MarkupLMProcessorIntegrationTests(unittest.TestCase):
         # verify keys
         expected_keys = ["attention_mask", "input_ids", "token_type_ids", "xpath_subs_seq", "xpath_tags_seq"]
-        actual_keys = sorted(list(inputs.keys()))
+        actual_keys = sorted(inputs.keys())
         self.assertListEqual(actual_keys, expected_keys)
         # verify input_ids
@@ -390,7 +390,7 @@ class MarkupLMProcessorIntegrationTests(unittest.TestCase):
         # verify keys
         expected_keys = ["attention_mask", "input_ids", "token_type_ids", "xpath_subs_seq", "xpath_tags_seq"]
-        actual_keys = sorted(list(inputs.keys()))
+        actual_keys = sorted(inputs.keys())
         self.assertListEqual(actual_keys, expected_keys)
         # verify input_ids
@@ -425,7 +425,7 @@ class MarkupLMProcessorIntegrationTests(unittest.TestCase):
         # verify keys
         expected_keys = ["attention_mask", "input_ids", "token_type_ids", "xpath_subs_seq", "xpath_tags_seq"]
-        actual_keys = sorted(list(inputs.keys()))
+        actual_keys = sorted(inputs.keys())
         self.assertListEqual(actual_keys, expected_keys)
         # verify input_ids
@@ -444,7 +444,7 @@ class MarkupLMProcessorIntegrationTests(unittest.TestCase):
         # verify keys
         expected_keys = ["attention_mask", "input_ids", "token_type_ids", "xpath_subs_seq", "xpath_tags_seq"]
-        actual_keys = sorted(list(inputs.keys()))
+        actual_keys = sorted(inputs.keys())
         self.assertListEqual(actual_keys, expected_keys)
         # verify input_ids
...
@@ -295,7 +295,7 @@ class MobileViTModelTest(TFModelTesterMixin, unittest.TestCase):
         # The number of elements in the loss should be the same as the number of elements in the label
         prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
         added_label = prepared_for_class[
-            sorted(list(prepared_for_class.keys() - inputs_dict.keys()), reverse=True)[0]
+            sorted(prepared_for_class.keys() - inputs_dict.keys(), reverse=True)[0]
         ]
         expected_loss_size = added_label.shape.as_list()[:1]
...
@@ -166,9 +166,11 @@ class PerceiverModelTester:
             audio = torch.randn(
                 (self.batch_size, self.num_frames * self.audio_samples_per_frame, 1), device=torch_device
             )
-            inputs = dict(
-                image=images, audio=audio, label=torch.zeros((self.batch_size, self.num_labels), device=torch_device)
-            )
+            inputs = {
+                "image": images,
+                "audio": audio,
+                "label": torch.zeros((self.batch_size, self.num_labels), device=torch_device),
+            }
         else:
             raise ValueError(f"Model class {model_class} not supported")
@@ -734,7 +736,7 @@ class PerceiverModelTest(ModelTesterMixin, unittest.TestCase):
                 continue
             config, inputs, input_mask, _, _ = self.model_tester.prepare_config_and_inputs(model_class=model_class)
-            inputs_dict = dict(inputs=inputs, attention_mask=input_mask)
+            inputs_dict = {"inputs": inputs, "attention_mask": input_mask}
             for problem_type in problem_types:
                 with self.subTest(msg=f"Testing {model_class} with {problem_type['title']}"):
...
@@ -44,8 +44,8 @@ class BertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
         super().setUp()
         vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]", "你", "好", "是", "谁", "a", "b", "c", "d"]
-        word_shape = dict()
-        word_pronunciation = dict()
+        word_shape = {}
+        word_pronunciation = {}
         for i, value in enumerate(vocab_tokens):
             word_shape[value] = i
             word_pronunciation[value] = i
...