Commit e391d473 authored by LysandreJik

Tokenizers' encode function can output binary masks

parent 0d1dad6d
@@ -194,14 +194,20 @@ class BertTokenizer(PreTrainedTokenizer):
         """
         return [self.cls_token_id] + token_ids + [self.sep_token_id]

-    def add_special_tokens_sentences_pair(self, token_ids_0, token_ids_1):
+    def add_special_tokens_sentences_pair(self, token_ids_0, token_ids_1, output_mask=False):
         """
         Adds special tokens to a sequence pair for sequence classification tasks.
         A BERT sequence pair has the following format: [CLS] A [SEP] B [SEP]
         """
         sep = [self.sep_token_id]
         cls = [self.cls_token_id]
-        return cls + token_ids_0 + sep + token_ids_1 + sep
+        if output_mask:
+            return (
+                cls + token_ids_0 + sep + token_ids_1 + sep,
+                [0] * len(cls + token_ids_0 + sep) + [1] * len(token_ids_1 + sep)
+            )
+        else:
+            return cls + token_ids_0 + sep + token_ids_1 + sep

     def save_vocabulary(self, vocab_path):
         """Save the tokenizer vocabulary to a directory or file."""
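For orientation (not part of the commit), here is what the new return value of BertTokenizer.add_special_tokens_sentences_pair looks like when output_mask=True, traced with hypothetical token ids:

# Hypothetical ids for illustration only; real values come from the vocab.
cls, sep = [101], [102]                       # [CLS], [SEP]
a, b = [10, 11], [20]                         # token ids of sequences A and B
ids  = cls + a + sep + b + sep                # [CLS] A [SEP] B [SEP]
mask = [0] * len(cls + a + sep) + [1] * len(b + sep)
assert ids  == [101, 10, 11, 102, 20, 102]
assert mask == [0, 0, 0, 0, 1, 1]             # 0 over [CLS] A [SEP], 1 over B [SEP]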
@@ -88,11 +88,17 @@ class RobertaTokenizer(GPT2Tokenizer):
         """
         return [self.cls_token_id] + token_ids + [self.sep_token_id]

-    def add_special_tokens_sentences_pair(self, token_ids_0, token_ids_1):
+    def add_special_tokens_sentences_pair(self, token_ids_0, token_ids_1, output_mask=False):
         """
         Adds special tokens to a sequence pair for sequence classification tasks.
         A RoBERTa sequence pair has the following format: <s> A </s></s> B </s>
         """
         sep = [self.sep_token_id]
         cls = [self.cls_token_id]
-        return cls + token_ids_0 + sep + sep + token_ids_1 + sep
+        if output_mask:
+            return (
+                cls + token_ids_0 + sep + sep + token_ids_1 + sep,
+                [0] * len(cls + token_ids_0 + sep) + [1] * len(sep + token_ids_1 + sep)
+            )
+        else:
+            return cls + token_ids_0 + sep + sep + token_ids_1 + sep
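One detail worth noting in the RoBERTa hunk: of the two back-to-back </s> tokens between the sequences, the first is counted toward the 0-segment and the second toward the 1-segment. A toy trace with hypothetical ids:

cls, sep = [0], [2]                           # hypothetical <s>, </s> ids
a, b = [10, 11], [20]
ids  = cls + a + sep + sep + b + sep          # <s> A </s></s> B </s>
mask = [0] * len(cls + a + sep) + [1] * len(sep + b + sep)
assert ids  == [0, 10, 11, 2, 2, 20, 2]
assert mask == [0, 0, 0, 0, 1, 1, 1]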
@@ -663,7 +663,7 @@ class PreTrainedTokenizer(object):
     def _convert_token_to_id(self, token):
         raise NotImplementedError

-    def encode(self, text, text_pair=None, add_special_tokens=False, **kwargs):
+    def encode(self, text, text_pair=None, add_special_tokens=False, output_mask=False, **kwargs):
         """
         Converts a string in a sequence of ids (integer), using the tokenizer and vocabulary.
@@ -674,6 +674,8 @@ class PreTrainedTokenizer(object):
             text_pair: Optional second sequence to be encoded.
             add_special_tokens: if set to ``True``, the sequences will be encoded with the special tokens relative
                 to their model.
+            output_mask: if set to ``True``, returns the text pair corresponding mask with 0 for the first sequence,
+                and 1 for the second.
             **kwargs: passed to the `self.tokenize()` method
         """
         if text_pair is None:
@@ -686,7 +688,7 @@ class PreTrainedTokenizer(object):
         second_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text_pair, **kwargs)]

         if add_special_tokens:
-            return self.add_special_tokens_sentences_pair(first_sentence_tokens, second_sentence_tokens)
+            return self.add_special_tokens_sentences_pair(first_sentence_tokens, second_sentence_tokens, output_mask)
         else:
             return first_sentence_tokens, second_sentence_tokens
@@ -694,7 +696,7 @@ class PreTrainedTokenizer(object):
         logger.warning("This tokenizer does not make use of special tokens. The sequence has been returned with no modification.")
         return token_ids

-    def add_special_tokens_sentences_pair(self, token_ids_0, token_ids_1):
+    def add_special_tokens_sentences_pair(self, token_ids_0, token_ids_1, output_mask=False):
        logger.warning("This tokenizer does not make use of special tokens. The two sequences have been concatenated.")
        return token_ids_0 + token_ids_1
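A minimal usage sketch of the extended encode signature, not part of the commit. It assumes the pytorch_transformers package of this era and the bert-base-uncased checkpoint; both names are illustrative:

from pytorch_transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

# With output_mask=True (and a text pair plus add_special_tokens=True),
# encode returns a (token ids, binary mask) pair in which the mask is
# 0 over [CLS] A [SEP] and 1 over B [SEP], aligned token for token.
ids, mask = tokenizer.encode(
    "Hello world", "How are you?",
    add_special_tokens=True, output_mask=True,
)
assert len(ids) == len(mask)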
@@ -761,14 +761,21 @@ class XLMTokenizer(PreTrainedTokenizer):
         """
         return [self.cls_token_id] + token_ids + [self.sep_token_id]

-    def add_special_tokens_sentences_pair(self, token_ids_0, token_ids_1):
+    def add_special_tokens_sentences_pair(self, token_ids_0, token_ids_1, output_mask=False):
         """
         Adds special tokens to a sequence pair for sequence classification tasks.
         An XLM sequence pair has the following format: [CLS] A [SEP] B [SEP]
         """
         sep = [self.sep_token_id]
         cls = [self.cls_token_id]
-        return cls + token_ids_0 + sep + token_ids_1 + sep
+
+        if output_mask:
+            return (
+                cls + token_ids_0 + sep + token_ids_1 + sep,
+                [0] * len(cls + token_ids_0 + sep) + [1] * len(token_ids_1 + sep)
+            )
+        else:
+            return cls + token_ids_0 + sep + token_ids_1 + sep

     def save_vocabulary(self, save_directory):
         """Save the tokenizer vocabulary and merge files to a directory."""
@@ -190,14 +190,21 @@ class XLNetTokenizer(PreTrainedTokenizer):
         cls = [self.cls_token_id]
         return token_ids + sep + cls

-    def add_special_tokens_sentences_pair(self, token_ids_0, token_ids_1):
+    def add_special_tokens_sentences_pair(self, token_ids_0, token_ids_1, output_mask=False):
         """
         Adds special tokens to a sequence for sequence classification tasks.
         An XLNet sequence has the following format: X [SEP][CLS]
         """
         sep = [self.sep_token_id]
         cls = [self.cls_token_id]
-        return token_ids_0 + sep + token_ids_1 + sep + cls
+
+        if output_mask:
+            return (
+                token_ids_0 + sep + token_ids_1 + sep + cls,
+                [0] * len(token_ids_0 + sep) + [1] * len(token_ids_1 + sep + cls)
+            )
+        else:
+            return token_ids_0 + sep + token_ids_1 + sep + cls

     def save_vocabulary(self, save_directory):
         """ Save the sentencepiece vocabulary (copy original file) and special tokens file
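Finally, XLNet places its special tokens at the end of the pair, so the trailing [SEP][CLS] is counted toward the 1-segment. The same kind of toy trace as above, with hypothetical ids:

sep, cls = [4], [3]                           # hypothetical XLNet sep/cls ids
a, b = [10, 11], [20]
ids  = a + sep + b + sep + cls                # A [SEP] B [SEP][CLS]
mask = [0] * len(a + sep) + [1] * len(b + sep + cls)
assert ids  == [10, 11, 4, 20, 4, 3]
assert mask == [0, 0, 0, 1, 1, 1]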