Commit a75c64d8 authored by Lysandre

Black 20 release

parent e78c1103
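
Every hunk below is mechanical fallout from the "magic trailing comma" introduced in the Black 20.8 betas: when a call, signature, or literal already ends with a trailing comma, Black now keeps it exploded with one element per line, rather than packing the arguments onto a single line as the 19.x releases did. A minimal sketch of the rule, in the same diff notation as the hunks below, using a hypothetical function f (not from this repo):

-result = f(alpha, beta, gamma,)
+result = f(
+    alpha,
+    beta,
+    gamma,
+)

Deleting the trailing comma would instead let Black collapse the call back onto one line.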
@@ -32,7 +32,8 @@ if is_torch_available():
     class TransfoXLModelTester:
         def __init__(
-            self, parent,
+            self,
+            parent,
         ):
             self.parent = parent
             self.batch_size = 14

@@ -41,7 +41,8 @@ if is_torch_available():
     class XLMModelTester:
         def __init__(
-            self, parent,
+            self,
+            parent,
        ):
            self.parent = parent
            self.batch_size = 13

@@ -104,10 +104,20 @@ class XLNetModelTester:
         input_ids_q = ids_tensor([self.batch_size, self.seq_length + 1], self.vocab_size)
         perm_mask = torch.zeros(
-            self.batch_size, self.seq_length + 1, self.seq_length + 1, dtype=torch.float, device=torch_device,
+            self.batch_size,
+            self.seq_length + 1,
+            self.seq_length + 1,
+            dtype=torch.float,
+            device=torch_device,
         )
         perm_mask[:, :, -1] = 1.0  # Previous tokens don't see last token
-        target_mapping = torch.zeros(self.batch_size, 1, self.seq_length + 1, dtype=torch.float, device=torch_device,)
+        target_mapping = torch.zeros(
+            self.batch_size,
+            1,
+            self.seq_length + 1,
+            dtype=torch.float,
+            device=torch_device,
+        )
         target_mapping[:, 0, -1] = 1.0  # predict last token

         sequence_labels = None
@@ -217,7 +227,11 @@ class XLNetModelTester:
         # first forward pass
         causal_mask = torch.ones(
-            input_ids_1.shape[0], input_ids_1.shape[1], input_ids_1.shape[1], dtype=torch.float, device=torch_device,
+            input_ids_1.shape[0],
+            input_ids_1.shape[1],
+            input_ids_1.shape[1],
+            dtype=torch.float,
+            device=torch_device,
         )
         causal_mask = torch.triu(causal_mask, diagonal=0)
         outputs_cache = model(input_ids_1, use_cache=True, perm_mask=causal_mask)
@@ -363,7 +377,11 @@ class XLNetModelTester:
         total_loss, mems = result_with_labels.to_tuple()

-        result_with_labels = model(input_ids_1, start_positions=sequence_labels, end_positions=sequence_labels,)
+        result_with_labels = model(
+            input_ids_1,
+            start_positions=sequence_labels,
+            end_positions=sequence_labels,
+        )

         total_loss, mems = result_with_labels.to_tuple()

@@ -164,7 +164,8 @@ class MonoColumnInputTestCase(unittest.TestCase):
         for result, expect in zip(multi_result, expected_multi_result):
             for key in expected_check_keys or []:
                 self.assertEqual(
-                    set([o[key] for o in result]), set([o[key] for o in expect]),
+                    set([o[key] for o in result]),
+                    set([o[key] for o in expect]),
                 )

         if isinstance(multi_result[0], list):

@@ -214,7 +215,13 @@ class MonoColumnInputTestCase(unittest.TestCase):
             "This is"  # No mask_token is not supported
         ]
         for model_name in FILL_MASK_FINETUNED_MODELS:
-            nlp = pipeline(task="fill-mask", model=model_name, tokenizer=model_name, framework="pt", topk=2,)
+            nlp = pipeline(
+                task="fill-mask",
+                model=model_name,
+                tokenizer=model_name,
+                framework="pt",
+                topk=2,
+            )
             self._test_mono_column_pipeline(
                 nlp, valid_inputs, mandatory_keys, invalid_inputs, expected_check_keys=["sequence"]
             )

@@ -231,7 +238,13 @@ class MonoColumnInputTestCase(unittest.TestCase):
             "This is"  # No mask_token is not supported
         ]
         for model_name in FILL_MASK_FINETUNED_MODELS:
-            nlp = pipeline(task="fill-mask", model=model_name, tokenizer=model_name, framework="tf", topk=2,)
+            nlp = pipeline(
+                task="fill-mask",
+                model=model_name,
+                tokenizer=model_name,
+                framework="tf",
+                topk=2,
+            )
             self._test_mono_column_pipeline(
                 nlp, valid_inputs, mandatory_keys, invalid_inputs, expected_check_keys=["sequence"]
             )

@@ -274,7 +287,13 @@ class MonoColumnInputTestCase(unittest.TestCase):
         ]
         valid_targets = [" Patrick", " Clara"]
         for model_name in LARGE_FILL_MASK_FINETUNED_MODELS:
-            nlp = pipeline(task="fill-mask", model=model_name, tokenizer=model_name, framework="pt", topk=2,)
+            nlp = pipeline(
+                task="fill-mask",
+                model=model_name,
+                tokenizer=model_name,
+                framework="pt",
+                topk=2,
+            )
             self._test_mono_column_pipeline(
                 nlp,
                 valid_inputs,
@@ -343,7 +362,12 @@ class MonoColumnInputTestCase(unittest.TestCase):
         invalid_inputs = [4, "<mask>"]
         mandatory_keys = ["summary_text"]
         for model_name in TF_SUMMARIZATION_FINETUNED_MODELS:
-            nlp = pipeline(task="summarization", model=model_name, tokenizer=model_name, framework="tf",)
+            nlp = pipeline(
+                task="summarization",
+                model=model_name,
+                tokenizer=model_name,
+                framework="tf",
+            )
             self._test_mono_column_pipeline(
                 nlp, VALID_INPUTS, mandatory_keys, invalid_inputs=invalid_inputs, **SUMMARIZATION_KWARGS
             )

@@ -355,7 +379,10 @@ class MonoColumnInputTestCase(unittest.TestCase):
         for model_name, task in TRANSLATION_FINETUNED_MODELS:
             nlp = pipeline(task=task, model=model_name, tokenizer=model_name)
             self._test_mono_column_pipeline(
-                nlp, VALID_INPUTS, mandatory_keys, invalid_inputs,
+                nlp,
+                VALID_INPUTS,
+                mandatory_keys,
+                invalid_inputs,
             )

     @require_tf
@@ -655,7 +682,9 @@ class QAPipelineTests(unittest.TestCase):
 class NerPipelineTests(unittest.TestCase):
     def _test_ner_pipeline(
-        self, nlp: Pipeline, output_keys: Iterable[str],
+        self,
+        nlp: Pipeline,
+        output_keys: Iterable[str],
     ):
         ungrouped_ner_inputs = [

@@ -882,8 +882,7 @@ class TokenizerTesterMixin:
         assert encoded_sequence == padded_sequence_left

     def test_padding_to_max_length(self):
-        """ We keep this test for backward compatibility but it should be remove when `pad_to_max_length` will e deprecated
-        """
+        """We keep this test for backward compatibility but it should be remove when `pad_to_max_length` will e deprecated"""
         tokenizers = self.get_tokenizers(do_lower_case=False)
         for tokenizer in tokenizers:
             with self.subTest(f"{tokenizer.__class__.__name__}"):
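
The docstring rewrite above, like the three docstring hunks at the end of this diff, reflects a second Black 20 change: docstring normalization, which strips the space after the opening triple quotes and pulls the closing quotes of a one-line docstring up onto the same line.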
@@ -972,7 +971,11 @@ class TokenizerTesterMixin:
                 # Test 'longest' and 'no_padding' don't do anything
                 tokenizer.padding_side = "right"
-                not_padded_sequence = tokenizer.encode_plus(sequence, padding=True, return_special_tokens_mask=True,)
+                not_padded_sequence = tokenizer.encode_plus(
+                    sequence,
+                    padding=True,
+                    return_special_tokens_mask=True,
+                )
                 not_padded_input_ids = not_padded_sequence["input_ids"]
                 not_padded_special_tokens_mask = not_padded_sequence["special_tokens_mask"]

@@ -982,7 +985,11 @@ class TokenizerTesterMixin:
                 assert input_ids == not_padded_input_ids
                 assert special_tokens_mask == not_padded_special_tokens_mask

-                not_padded_sequence = tokenizer.encode_plus(sequence, padding=False, return_special_tokens_mask=True,)
+                not_padded_sequence = tokenizer.encode_plus(
+                    sequence,
+                    padding=False,
+                    return_special_tokens_mask=True,
+                )
                 not_padded_input_ids = not_padded_sequence["input_ids"]
                 not_padded_special_tokens_mask = not_padded_sequence["special_tokens_mask"]

@@ -1148,7 +1155,8 @@ class TokenizerTesterMixin:
                 )
                 for key in encoded_sequences_batch_padded_1.keys():
                     self.assertListEqual(
-                        encoded_sequences_batch_padded_1[key], encoded_sequences_batch_padded_2[key],
+                        encoded_sequences_batch_padded_1[key],
+                        encoded_sequences_batch_padded_2[key],
                     )

                 # check 'no_padding' is unsensitive to a max length
@@ -1158,7 +1166,8 @@ class TokenizerTesterMixin:
                 )
                 for key in encoded_sequences_batch_padded_1.keys():
                     self.assertListEqual(
-                        encoded_sequences_batch_padded_1[key], encoded_sequences_batch_padded_2[key],
+                        encoded_sequences_batch_padded_1[key],
+                        encoded_sequences_batch_padded_2[key],
                     )

     def test_added_token_serializable(self):

@@ -1361,10 +1370,18 @@ class TokenizerTesterMixin:
                 if tokenizer.pad_token_id is None:
                     self.assertRaises(
-                        ValueError, tokenizer.batch_encode_plus, sequences, padding=True, return_tensors="pt",
+                        ValueError,
+                        tokenizer.batch_encode_plus,
+                        sequences,
+                        padding=True,
+                        return_tensors="pt",
                     )
                     self.assertRaises(
-                        ValueError, tokenizer.batch_encode_plus, sequences, padding="longest", return_tensors="tf",
+                        ValueError,
+                        tokenizer.batch_encode_plus,
+                        sequences,
+                        padding="longest",
+                        return_tensors="tf",
                     )
                 else:
                     pytorch_tensor = tokenizer.batch_encode_plus(sequences, padding=True, return_tensors="pt")

@@ -228,7 +228,8 @@ class CommonFastTokenizerTest(unittest.TestCase):
     def assert_special_tokens_map_equal(self, tokenizer_r, tokenizer_p):
         # Assert the set of special tokens match.
         self.assertSequenceEqual(
-            tokenizer_p.special_tokens_map.items(), tokenizer_r.special_tokens_map.items(),
+            tokenizer_p.special_tokens_map.items(),
+            tokenizer_r.special_tokens_map.items(),
         )

     def assert_add_tokens(self, tokenizer_r):

@@ -544,18 +545,26 @@ class CommonFastTokenizerTest(unittest.TestCase):
         assert_batch_padded_input_match(input_r, input_p, max_length)

         input_r = tokenizer_r.batch_encode_plus(
-            ["This is a simple input 1", "This is a simple input 2"], max_length=max_length, padding="max_length",
+            ["This is a simple input 1", "This is a simple input 2"],
+            max_length=max_length,
+            padding="max_length",
         )
         input_p = tokenizer_p.batch_encode_plus(
-            ["This is a simple input 1", "This is a simple input 2"], max_length=max_length, padding="max_length",
+            ["This is a simple input 1", "This is a simple input 2"],
+            max_length=max_length,
+            padding="max_length",
         )
         assert_batch_padded_input_match(input_r, input_p, max_length)

         input_r = tokenizer_r.batch_encode_plus(
-            ["This is a simple input 1", "This is a simple input 2"], max_length=max_length, padding="longest",
+            ["This is a simple input 1", "This is a simple input 2"],
+            max_length=max_length,
+            padding="longest",
         )
         input_p = tokenizer_p.batch_encode_plus(
-            ["This is a simple input 1", "This is a simple input 2"], max_length=max_length, padding=True,
+            ["This is a simple input 1", "This is a simple input 2"],
+            max_length=max_length,
+            padding=True,
         )
         assert_batch_padded_input_match(input_r, input_p, len(input_r["input_ids"][0]))
@@ -865,7 +874,11 @@ class NoPaddingTokenFastTokenizerMatchingTest(CommonFastTokenizerTest):
         # Simple input
         self.assertRaises(
-            ValueError, tokenizer_r.batch_encode_plus, s2, max_length=max_length, padding="max_length",
+            ValueError,
+            tokenizer_r.batch_encode_plus,
+            s2,
+            max_length=max_length,
+            padding="max_length",
         )

         # Pair input
@@ -876,7 +889,11 @@ class NoPaddingTokenFastTokenizerMatchingTest(CommonFastTokenizerTest):
         # Pair input
         self.assertRaises(
-            ValueError, tokenizer_r.batch_encode_plus, p2, max_length=max_length, padding="max_length",
+            ValueError,
+            tokenizer_r.batch_encode_plus,
+            p2,
+            max_length=max_length,
+            padding="max_length",
         )

@@ -125,7 +125,9 @@ class MBartEnroIntegrationTest(unittest.TestCase):
     def test_enro_tokenizer_prepare_seq2seq_batch(self):
         batch = self.tokenizer.prepare_seq2seq_batch(
-            self.src_text, tgt_texts=self.tgt_text, max_length=len(self.expected_src_tokens),
+            self.src_text,
+            tgt_texts=self.tgt_text,
+            max_length=len(self.expected_src_tokens),
         )
         self.assertIsInstance(batch, BatchEncoding)

@@ -44,7 +44,8 @@ class ReformerTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
         self.assertListEqual(tokens, ["▁This", "▁is", "▁a", "▁t", "est"])

         self.assertListEqual(
-            tokenizer.convert_tokens_to_ids(tokens), [285, 46, 10, 170, 382],
+            tokenizer.convert_tokens_to_ids(tokens),
+            [285, 46, 10, 170, 382],
         )

         tokens = tokenizer.tokenize("I was born in 92000, and this is falsé.")
@@ -76,7 +77,8 @@ class ReformerTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
         )
         ids = tokenizer.convert_tokens_to_ids(tokens)
         self.assertListEqual(
-            ids, [8, 21, 84, 55, 24, 19, 7, 0, 602, 347, 347, 347, 3, 12, 66, 46, 72, 80, 6, 0, 4],
+            ids,
+            [8, 21, 84, 55, 24, 19, 7, 0, 602, 347, 347, 347, 3, 12, 66, 46, 72, 80, 6, 0, 4],
         )

         back_tokens = tokenizer.convert_ids_to_tokens(ids)

@@ -126,7 +126,11 @@ class T5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
             "Another summary.",
         ]
         expected_src_tokens = [71, 307, 8986, 21, 4505, 51, 52, 1707, 5, tokenizer.eos_token_id]
-        batch = tokenizer.prepare_seq2seq_batch(src_text, tgt_texts=tgt_text, return_tensors=FRAMEWORK,)
+        batch = tokenizer.prepare_seq2seq_batch(
+            src_text,
+            tgt_texts=tgt_text,
+            return_tensors=FRAMEWORK,
+        )
         self.assertIsInstance(batch, BatchEncoding)
         result = list(batch.input_ids.numpy()[0])
         self.assertListEqual(expected_src_tokens, result)

@@ -275,7 +275,9 @@ class TrainerIntegrationTest(unittest.TestCase):
         MODEL_ID = "distilroberta-base"
         tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
         dataset = LineByLineTextDataset(
-            tokenizer=tokenizer, file_path=PATH_SAMPLE_TEXT, block_size=tokenizer.max_len_single_sentence,
+            tokenizer=tokenizer,
+            file_path=PATH_SAMPLE_TEXT,
+            block_size=tokenizer.max_len_single_sentence,
         )
         self.assertEqual(len(dataset), 31)

@@ -18,7 +18,7 @@ S3_BUCKET_PREFIX = "https://s3.amazonaws.com/models.huggingface.co/bert"

 def list_python_files_in_repository():
-    """ List all python files in the repository.
+    """List all python files in the repository.

     This function assumes that the script is executed in the root folder.
     """
@@ -43,7 +43,7 @@ def find_all_links(file_paths):

 def scan_code_for_links(source):
-    """ Scans the file to find links using a regular expression.
+    """Scans the file to find links using a regular expression.

     Returns a list of links.
     """
     with open(source, "r") as content:
@@ -55,7 +55,7 @@ def scan_code_for_links(source):

 def check_all_links(links):
-    """ Check that the provided links are valid.
+    """Check that the provided links are valid.

     Links are considered valid if a HEAD request to the server
     returns a 200 status code.
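
To reproduce this formatting locally, one would install the matching release and run Black over the repository, e.g. `pip install black==20.8b1` followed by `black .` from the repository root; the exact version pin is an assumption, since this diff does not show the project's setup or CI configuration.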