Commit bf503158 authored by LysandreJik

Sentence -> Sequence. Removed output_mask from the special token addition methods.

parent 8cba0572
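
In short: the sequence-builder helpers are renamed from add_special_tokens_single_sentence / add_special_tokens_sentences_pair to add_special_tokens_single_sequence / add_special_tokens_sequence_pair, and the output_mask argument (which made the pair builders also return a token-type mask) is dropped from the base class and from the BERT, RoBERTa, XLM and XLNet tokenizers; the DistilBERT override still carries it in this commit. A minimal usage sketch of the renamed API, assuming the pytorch_transformers package at this revision (the checkpoint name is illustrative):

    from pytorch_transformers import BertTokenizer

    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    # encode() without add_special_tokens=True returns bare token ids
    ids_a = tokenizer.encode("sequence builders")
    ids_b = tokenizer.encode("multi-sequence build")

    # Renamed builders; the pair builder no longer returns a (sequence, mask) tuple
    single = tokenizer.add_special_tokens_single_sequence(ids_a)       # [CLS] A [SEP]
    pair = tokenizer.add_special_tokens_sequence_pair(ids_a, ids_b)    # [CLS] A [SEP] B [SEP]

    assert single == [tokenizer.cls_token_id] + ids_a + [tokenizer.sep_token_id]

The test relying on output_mask (test_mask_output) is commented out rather than deleted, as is the corresponding mask handling in encode_plus.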
@@ -75,7 +75,7 @@ class TextDataset(Dataset):
         tokenized_text = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text))
 
         while len(tokenized_text) >= block_size: # Truncate in block of block_size
-            self.examples.append(tokenizer.add_special_tokens_single_sentence(tokenized_text[:block_size]))
+            self.examples.append(tokenizer.add_special_tokens_single_sequence(tokenized_text[:block_size]))
             tokenized_text = tokenized_text[block_size:]
         # Note that we are loosing the last truncated example here for the sake of simplicity (no padding)
         # If your dataset is small, first you should loook for a bigger one :-) and second you
...
@@ -131,8 +131,8 @@ class BertTokenizationTest(CommonTestCases.CommonTokenizerTester):
         text = tokenizer.encode("sequence builders")
         text_2 = tokenizer.encode("multi-sequence build")
 
-        encoded_sentence = tokenizer.add_special_tokens_single_sentence(text)
-        encoded_pair = tokenizer.add_special_tokens_sentences_pair(text, text_2)
+        encoded_sentence = tokenizer.add_special_tokens_single_sequence(text)
+        encoded_pair = tokenizer.add_special_tokens_sequence_pair(text, text_2)
 
         assert encoded_sentence == [101] + text + [102]
         assert encoded_pair == [101] + text + [102] + text_2 + [102]
...
@@ -36,8 +36,8 @@ class DistilBertTokenizationTest(BertTokenizationTest):
         text = tokenizer.encode("sequence builders")
         text_2 = tokenizer.encode("multi-sequence build")
 
-        encoded_sentence = tokenizer.add_special_tokens_single_sentence(text)
-        encoded_pair = tokenizer.add_special_tokens_sentences_pair(text, text_2)
+        encoded_sentence = tokenizer.add_special_tokens_single_sequence(text)
+        encoded_pair = tokenizer.add_special_tokens_sequence_pair(text, text_2)
 
         assert encoded_sentence == text
         assert encoded_pair == text + [102] + text_2
...
@@ -87,8 +87,8 @@ class RobertaTokenizationTest(CommonTestCases.CommonTokenizerTester):
         encoded_text_from_decode = tokenizer.encode("sequence builders", add_special_tokens=True)
         encoded_pair_from_decode = tokenizer.encode("sequence builders", "multi-sequence build", add_special_tokens=True)
 
-        encoded_sentence = tokenizer.add_special_tokens_single_sentence(text)
-        encoded_pair = tokenizer.add_special_tokens_sentences_pair(text, text_2)
+        encoded_sentence = tokenizer.add_special_tokens_single_sequence(text)
+        encoded_pair = tokenizer.add_special_tokens_sequence_pair(text, text_2)
 
         assert encoded_sentence == encoded_text_from_decode
         assert encoded_pair == encoded_pair_from_decode
...
@@ -187,18 +187,18 @@ class CommonTestCases:
                 for weights_list_2 in weights_lists_2:
                     self.assertListEqual(weights_list, weights_list_2)
 
-        def test_mask_output(self):
-            if sys.version_info <= (3, 0):
-                return
-
-            tokenizer = self.get_tokenizer()
-
-            if tokenizer.add_special_tokens_sentences_pair.__qualname__.split('.')[0] != "PreTrainedTokenizer":
-                seq_0 = "Test this method."
-                seq_1 = "With these inputs."
-                information = tokenizer.encode_plus(seq_0, seq_1, add_special_tokens=True, output_mask=True)
-                sequences, mask = information["sequence"], information["mask"]
-                assert len(sequences) == len(mask)
+        # def test_mask_output(self):
+        #     if sys.version_info <= (3, 0):
+        #         return
+        #
+        #     tokenizer = self.get_tokenizer()
+        #
+        #     if tokenizer.add_special_tokens_sequence_pair.__qualname__.split('.')[0] != "PreTrainedTokenizer":
+        #         seq_0 = "Test this method."
+        #         seq_1 = "With these inputs."
+        #         information = tokenizer.encode_plus(seq_0, seq_1, add_special_tokens=True, output_mask=True)
+        #         sequences, mask = information["sequence"], information["mask"]
+        #         assert len(sequences) == len(mask)
 
         def test_number_of_added_tokens(self):
             tokenizer = self.get_tokenizer()
@@ -228,7 +228,7 @@ class CommonTestCases:
 
             assert len(overflowing_tokens) == 2
             assert len(truncated_sequence) == total_length - 2
-            assert truncated_sequence == tokenizer.add_special_tokens_single_sentence(sequence[:-2])
+            assert truncated_sequence == tokenizer.add_special_tokens_single_sequence(sequence[:-2])
 
         def test_maximum_encoding_length_pair_input(self):
             tokenizer = self.get_tokenizer()
@@ -237,7 +237,7 @@ class CommonTestCases:
             seq_1 = "This is another sentence to be encoded."
             sequence = tokenizer.encode(seq_0, seq_1, add_special_tokens=True)
 
-            truncated_second_sequence = tokenizer.add_special_tokens_sentences_pair(
+            truncated_second_sequence = tokenizer.add_special_tokens_sequence_pair(
                 tokenizer.encode(seq_0),
                 tokenizer.encode(seq_1)[:-2]
             )
...
@@ -72,8 +72,8 @@ class XLMTokenizationTest(CommonTestCases.CommonTokenizerTester):
         text = tokenizer.encode("sequence builders")
         text_2 = tokenizer.encode("multi-sequence build")
 
-        encoded_sentence = tokenizer.add_special_tokens_single_sentence(text)
-        encoded_pair = tokenizer.add_special_tokens_sentences_pair(text, text_2)
+        encoded_sentence = tokenizer.add_special_tokens_single_sequence(text)
+        encoded_pair = tokenizer.add_special_tokens_sequence_pair(text, text_2)
 
         assert encoded_sentence == [1] + text + [1]
         assert encoded_pair == [1] + text + [1] + text_2 + [1]
...
@@ -95,8 +95,8 @@ class XLNetTokenizationTest(CommonTestCases.CommonTokenizerTester):
         text = tokenizer.encode("sequence builders")
         text_2 = tokenizer.encode("multi-sequence build")
 
-        encoded_sentence = tokenizer.add_special_tokens_single_sentence(text)
-        encoded_pair = tokenizer.add_special_tokens_sentences_pair(text, text_2)
+        encoded_sentence = tokenizer.add_special_tokens_single_sequence(text)
+        encoded_pair = tokenizer.add_special_tokens_sequence_pair(text, text_2)
 
         assert encoded_sentence == text + [4, 3]
         assert encoded_pair == text + [4] + text_2 + [4, 3]
...
@@ -187,26 +187,21 @@ class BertTokenizer(PreTrainedTokenizer):
         out_string = ' '.join(tokens).replace(' ##', '').strip()
         return out_string
 
-    def add_special_tokens_single_sentence(self, token_ids):
+    def add_special_tokens_single_sequence(self, token_ids):
         """
         Adds special tokens to the a sequence for sequence classification tasks.
         A BERT sequence has the following format: [CLS] X [SEP]
         """
         return [self.cls_token_id] + token_ids + [self.sep_token_id]
 
-    def add_special_tokens_sentences_pair(self, token_ids_0, token_ids_1, output_mask=False):
+    def add_special_tokens_sequence_pair(self, token_ids_0, token_ids_1):
         """
         Adds special tokens to a sequence pair for sequence classification tasks.
         A BERT sequence pair has the following format: [CLS] A [SEP] B [SEP]
         """
         sep = [self.sep_token_id]
         cls = [self.cls_token_id]
-        if output_mask:
-            return (
-                cls + token_ids_0 + sep + token_ids_1 + sep,
-                [0] * len(cls + token_ids_0 + sep) + [1] * len(token_ids_1 + sep)
-            )
-        else:
-            return cls + token_ids_0 + sep + token_ids_1 + sep
+        return cls + token_ids_0 + sep + token_ids_1 + sep
 
     def save_vocabulary(self, vocab_path):
...
@@ -61,10 +61,10 @@ class DistilBertTokenizer(BertTokenizer):
     pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
     max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
 
-    def add_special_tokens_single_sentence(self, token_ids):
+    def add_special_tokens_single_sequence(self, token_ids):
         return token_ids
 
-    def add_special_tokens_sentences_pair(self, token_ids_0, token_ids_1, output_mask=False):
+    def add_special_tokens_sequence_pair(self, token_ids_0, token_ids_1, output_mask=False):
         sep = [self.sep_token_id]
         if output_mask:
             return (
...
@@ -81,24 +81,18 @@ class RobertaTokenizer(GPT2Tokenizer):
                          sep_token=sep_token, cls_token=cls_token, pad_token=pad_token,
                          mask_token=mask_token, **kwargs)
 
-    def add_special_tokens_single_sentence(self, token_ids):
+    def add_special_tokens_single_sequence(self, token_ids):
         """
         Adds special tokens to a sequence for sequence classification tasks.
         A RoBERTa sequence has the following format: <s> X </s>
         """
         return [self.cls_token_id] + token_ids + [self.sep_token_id]
 
-    def add_special_tokens_sentences_pair(self, token_ids_0, token_ids_1, output_mask=False):
+    def add_special_tokens_sequence_pair(self, token_ids_0, token_ids_1):
        """
         Adds special tokens to a sequence pair for sequence classification tasks.
         A RoBERTa sequence pair has the following format: <s> A </s></s> B </s>
         """
         sep = [self.sep_token_id]
         cls = [self.cls_token_id]
-        if output_mask:
-            return (
-                cls + token_ids_0 + sep + sep + token_ids_1 + sep,
-                [0] * len(cls + token_ids_0 + sep + sep) + [1] * len(token_ids_1 + sep)
-            )
-        else:
-            return cls + token_ids_0 + sep + sep + token_ids_1 + sep
+        return cls + token_ids_0 + sep + sep + token_ids_1 + sep
@@ -708,7 +708,7 @@ class PreTrainedTokenizer(object):
         if text_pair is None:
             if add_special_tokens:
                 sequence_tokens = self.convert_tokens_to_ids(self.tokenize(text, **kwargs))
-                return self.add_special_tokens_single_sentence(sequence_tokens)
+                return self.add_special_tokens_single_sequence(sequence_tokens)
             else:
                 ids = self.convert_tokens_to_ids(self.tokenize(text, **kwargs))
                 return ids
@@ -717,7 +717,7 @@ class PreTrainedTokenizer(object):
         second_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text_pair, **kwargs)]
 
         if add_special_tokens:
-            return self.add_special_tokens_sentences_pair(first_sentence_tokens, second_sentence_tokens)
+            return self.add_special_tokens_sequence_pair(first_sentence_tokens, second_sentence_tokens)
         else:
             logger.warning("No special tokens were added. The two sequences have been concatenated.")
             return first_sentence_tokens + second_sentence_tokens
@@ -747,7 +747,7 @@ class PreTrainedTokenizer(object):
                 if max_length:
                     information["overflowing_tokens"] = sequence_tokens[max_length - n_added_tokens:]
                     sequence_tokens = sequence_tokens[:max_length - n_added_tokens]
-                sequence = self.add_special_tokens_single_sentence(sequence_tokens)
+                sequence = self.add_special_tokens_single_sequence(sequence_tokens)
             else:
                 sequence_tokens = self.convert_tokens_to_ids(self.tokenize(text, **kwargs))
                 if max_length:
@@ -774,16 +774,13 @@ class PreTrainedTokenizer(object):
                     information["overflowing_tokens"] = second_sentence_tokens[max_length - f_len - n_added_tokens:]
                     second_sentence_tokens = second_sentence_tokens[:max_length - f_len - n_added_tokens]
 
-                encoded_sequence = self.add_special_tokens_sentences_pair(
+                sequence = self.add_special_tokens_sequence_pair(
                     first_sentence_tokens,
-                    second_sentence_tokens,
-                    output_mask
+                    second_sentence_tokens
                 )
 
-                if output_mask:
-                    sequence, information["mask"] = encoded_sequence
-                else:
-                    sequence = encoded_sequence
+                # if output_mask:
+                #     sequence, information["mask"] = encoded_sequence
 
                 information["sequence"] = sequence
             else:
@@ -800,11 +797,11 @@ class PreTrainedTokenizer(object):
         return information
 
-    def add_special_tokens_single_sentence(self, token_ids):
+    def add_special_tokens_single_sequence(self, token_ids):
         logger.warning("This tokenizer does not make use of special tokens. The sequence has been returned with no modification.")
         return token_ids
 
-    def add_special_tokens_sentences_pair(self, token_ids_0, token_ids_1, output_mask=False):
+    def add_special_tokens_sequence_pair(self, token_ids_0, token_ids_1):
         logger.warning("This tokenizer does not make use of special tokens. The two sequences have been concatenated.")
         return token_ids_0 + token_ids_1
...
@@ -754,27 +754,20 @@ class XLMTokenizer(PreTrainedTokenizer):
         out_string = ''.join(tokens).replace('</w>', ' ').strip()
         return out_string
 
-    def add_special_tokens_single_sentence(self, token_ids):
+    def add_special_tokens_single_sequence(self, token_ids):
         """
         Adds special tokens to a sequence for sequence classification tasks.
         An XLM sequence has the following format: [CLS] X [SEP]
         """
         return [self.cls_token_id] + token_ids + [self.sep_token_id]
 
-    def add_special_tokens_sentences_pair(self, token_ids_0, token_ids_1, output_mask=False):
+    def add_special_tokens_sequence_pair(self, token_ids_0, token_ids_1):
         """
         Adds special tokens to a sequence pair for sequence classification tasks.
         An XLM sequence pair has the following format: [CLS] A [SEP] B [SEP]
         """
         sep = [self.sep_token_id]
         cls = [self.cls_token_id]
-        if output_mask:
-            return (
-                cls + token_ids_0 + sep + token_ids_1 + sep,
-                [0] * len(cls + token_ids_0 + sep) + [1] * len(token_ids_1 + sep)
-            )
-        else:
-            return cls + token_ids_0 + sep + token_ids_1 + sep
+        return cls + token_ids_0 + sep + token_ids_1 + sep
 
     def save_vocabulary(self, save_directory):
...
@@ -181,7 +181,7 @@ class XLNetTokenizer(PreTrainedTokenizer):
         out_string = ''.join(tokens).replace(SPIECE_UNDERLINE, ' ').strip()
         return out_string
 
-    def add_special_tokens_single_sentence(self, token_ids):
+    def add_special_tokens_single_sequence(self, token_ids):
         """
         Adds special tokens to a sequence pair for sequence classification tasks.
         An XLNet sequence pair has the following format: A [SEP] B [SEP][CLS]
@@ -190,7 +190,7 @@ class XLNetTokenizer(PreTrainedTokenizer):
         cls = [self.cls_token_id]
         return token_ids + sep + cls
 
-    def add_special_tokens_sentences_pair(self, token_ids_0, token_ids_1, output_mask=False):
+    def add_special_tokens_sequence_pair(self, token_ids_0, token_ids_1):
         """
         Adds special tokens to a sequence for sequence classification tasks.
         An XLNet sequence has the following format: X [SEP][CLS]
@@ -199,12 +199,6 @@ class XLNetTokenizer(PreTrainedTokenizer):
         sep = [self.sep_token_id]
         cls = [self.cls_token_id]
         cls_segment_ids = [2]
-        if output_mask:
-            return (
-                token_ids_0 + sep + token_ids_1 + sep + cls,
-                [0] * len(token_ids_0 + sep) + [1] * len(token_ids_1 + sep) + cls_segment_ids
-            )
-        else:
-            return token_ids_0 + sep + token_ids_1 + sep + cls
+        return token_ids_0 + sep + token_ids_1 + sep + cls
 
     def save_vocabulary(self, save_directory):
...