Commit cc412edd authored by LysandreJik

Supports already existing special tokens

parent 2f259b22
@@ -321,4 +321,16 @@ class CommonTestCases:
        filtered_sequence = [x for x in filtered_sequence if x is not None]
        assert encoded_sequence == filtered_sequence
        # Testing with already existing special tokens
        if tokenizer.cls_token_id == tokenizer.unk_token_id and tokenizer.sep_token_id == tokenizer.unk_token_id:
            tokenizer.add_special_tokens({'cls_token': '</s>', 'sep_token': '<s>'})
        encoded_sequence_dict = tokenizer.encode_plus(sequence_0, add_special_tokens=True)
        encoded_sequence_w_special = encoded_sequence_dict["input_ids"]
        sequence_ids_orig = encoded_sequence_dict["sequence_ids"]
        sequence_ids = tokenizer.get_sequence_ids(encoded_sequence_w_special, special_tokens_present=True)
        assert len(sequence_ids) == len(encoded_sequence_w_special)
        assert sequence_ids_orig == sequence_ids
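
A minimal sketch of what this test exercises, assuming a BERT-style vocabulary (the checkpoint name and the exact token ids are illustrative, not part of the commit):

from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
ids = tokenizer.encode("hello world", add_special_tokens=True)
# ids now covers [CLS] hello world [SEP], e.g. [101, 7592, 2088, 102]
mask = tokenizer.get_sequence_ids(ids, special_tokens_present=True)
# 0 marks the added [CLS]/[SEP] positions, 1 marks sequence tokens
assert mask == [0, 1, 1, 0]
assert len(mask) == len(ids)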
@@ -204,7 +204,7 @@ class BertTokenizer(PreTrainedTokenizer):
        return cls + token_ids_0 + sep + token_ids_1 + sep

    def get_sequence_ids(self, token_ids_0, token_ids_1=None, special_tokens_present=False):
        """
        Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
        special tokens using the tokenizer ``prepare_for_model`` or ``encode_plus`` methods.
@@ -217,6 +217,10 @@ class BertTokenizer(PreTrainedTokenizer):
        Returns:
            A list of integers in the range [0, 1]: 0 for a special token, 1 for a sequence token.
        """
        if special_tokens_present:
            return [0 if x in (self.sep_token_id, self.cls_token_id) else 1 for x in token_ids_0]
        if token_ids_1:
            return [0] + ([1] * len(token_ids_0)) + [0] + ([1] * len(token_ids_1)) + [0]
        else:
...
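
With special_tokens_present=True the BERT implementation above scans the ids for sep/cls instead of assuming a layout. A short sketch of both paths, with made-up token ids and the elided single-sequence branch reconstructed from BERT's cls + sequence + sep layout:

cls_id, sep_id = 101, 102                         # illustrative BERT ids
ids_with_special = [101, 7592, 2088, 102]         # [CLS] hello world [SEP]
scanned = [0 if x in (sep_id, cls_id) else 1 for x in ids_with_special]
assert scanned == [0, 1, 1, 0]                    # special_tokens_present=True path

ids_plain = [7592, 2088]                          # raw sequence, specials not yet added
assumed = [0] + [1] * len(ids_plain) + [0]        # layout-based single-sequence path
assert assumed == scanned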
@@ -100,7 +100,7 @@ class RobertaTokenizer(GPT2Tokenizer):
        cls = [self.cls_token_id]
        return cls + token_ids_0 + sep + sep + token_ids_1 + sep

    def get_sequence_ids(self, token_ids_0, token_ids_1=None, special_tokens_present=False):
        """
        Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
        special tokens using the tokenizer ``prepare_for_model`` or ``encode_plus`` methods.
@@ -113,6 +113,10 @@ class RobertaTokenizer(GPT2Tokenizer):
        Returns:
            A list of integers in the range [0, 1]: 0 for a special token, 1 for a sequence token.
        """
        if special_tokens_present:
            return [0 if x in (self.sep_token_id, self.cls_token_id) else 1 for x in token_ids_0]
        if token_ids_1:
            return [0] + ([1] * len(token_ids_0)) + [0, 0] + ([1] * len(token_ids_1)) + [0]
        else:
...
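
The only difference from BERT is the pair layout: RoBERTa separates the two sequences with a double </s></s>, which is why the mask above has [0, 0] in the middle. A sketch for sequences of length 3 and 2:

len_a, len_b = 3, 2                               # <s> A A A </s></s> B B </s>
mask = [0] + [1] * len_a + [0, 0] + [1] * len_b + [0]
assert mask == [0, 1, 1, 1, 0, 0, 1, 1, 0]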
@@ -908,7 +908,7 @@ class PreTrainedTokenizer(object):
        logger.warning("This tokenizer does not make use of special tokens. The two sequences have been concatenated.")
        return token_ids_0 + token_ids_1

    def get_sequence_ids(self, token_ids_0, token_ids_1=None, special_tokens_present=False):
        return [1] * ((len(token_ids_1) if token_ids_1 else 0) + len(token_ids_0))

    def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
...
@@ -770,7 +770,7 @@ class XLMTokenizer(PreTrainedTokenizer):
        cls = [self.cls_token_id]
        return cls + token_ids_0 + sep + token_ids_1 + sep

    def get_sequence_ids(self, token_ids_0, token_ids_1=None, special_tokens_present=False):
        """
        Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
        special tokens using the tokenizer ``prepare_for_model`` or ``encode_plus`` methods.
@@ -783,6 +783,10 @@ class XLMTokenizer(PreTrainedTokenizer):
        Returns:
            A list of integers in the range [0, 1]: 0 for a special token, 1 for a sequence token.
        """
        if special_tokens_present:
            return [0 if x in (self.sep_token_id, self.cls_token_id) else 1 for x in token_ids_0]
        if token_ids_1:
            return [0] + ([1] * len(token_ids_0)) + [0] + ([1] * len(token_ids_1)) + [0]
        else:
...
@@ -200,7 +200,7 @@ class XLNetTokenizer(PreTrainedTokenizer):
        cls = [self.cls_token_id]
        return token_ids_0 + sep + token_ids_1 + sep + cls

    def get_sequence_ids(self, token_ids_0, token_ids_1=None, special_tokens_present=False):
        """
        Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
        special tokens using the tokenizer ``prepare_for_model`` or ``encode_plus`` methods.
@@ -213,6 +213,10 @@ class XLNetTokenizer(PreTrainedTokenizer):
        Returns:
            A list of integers in the range [0, 1]: 0 for a special token, 1 for a sequence token.
        """
        if special_tokens_present:
            return [0 if x in (self.sep_token_id, self.cls_token_id) else 1 for x in token_ids_0]
        if token_ids_1:
            return ([1] * len(token_ids_0)) + [0] + ([1] * len(token_ids_1)) + [0, 0]
        else:
...
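
XLNet appends its special tokens instead of surrounding the sequences (A <sep> B <sep> <cls>), so the pair mask above ends with [0, 0] rather than starting with a 0. A sketch for sequences of length 3 and 2:

len_a, len_b = 3, 2                               # A A A <sep> B B <sep> <cls>
mask = [1] * len_a + [0] + [1] * len_b + [0, 0]
assert mask == [1, 1, 1, 0, 1, 1, 0, 0]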