Unverified Commit 5ac8b622 authored by Thomas Wolf, committed by GitHub

Merge pull request #1205 from maru0kun/patch-2

Fix typo
parents ed717635 5c6cac10
@@ -55,6 +55,22 @@ class CommonTestCases:

     def get_input_output_texts(self):
         raise NotImplementedError

+    def test_tokenizers_common_properties(self):
+        tokenizer = self.get_tokenizer()
+        attributes_list = ["bos_token", "eos_token", "unk_token", "sep_token",
+                           "pad_token", "cls_token", "mask_token"]
+        for attr in attributes_list:
+            self.assertTrue(hasattr(tokenizer, attr))
+            self.assertTrue(hasattr(tokenizer, attr + "_id"))
+        self.assertTrue(hasattr(tokenizer, "additional_special_tokens"))
+        self.assertTrue(hasattr(tokenizer, 'additional_special_tokens_ids'))
+        attributes_list = ["max_len", "init_inputs", "init_kwargs", "added_tokens_encoder",
+                           "added_tokens_decoder"]
+        for attr in attributes_list:
+            self.assertTrue(hasattr(tokenizer, attr))
+
     def test_save_and_load_tokenizer(self):
         # safety check on max_len default value so we are sure the test works
         tokenizer = self.get_tokenizer()
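The new test only asserts that the attributes exist in pairs. Here is a minimal sketch of the convention it relies on (a hypothetical stub, not the library's code): every special-token attribute `X` is mirrored by an `X_id` property that resolves the token string to a vocabulary id.

```python
# Hypothetical stub illustrating the X / X_id pairing the new test asserts.
class ToyTokenizer(object):
    def __init__(self):
        self.vocab = {"<unk>": 0, "<pad>": 1}
        self.unk_token = "<unk>"
        self.pad_token = "<pad>"

    def convert_tokens_to_ids(self, token):
        if token is None:  # mirrors the guard added later in this commit
            return None
        return self.vocab.get(token, self.vocab["<unk>"])

    @property
    def unk_token_id(self):
        return self.convert_tokens_to_ids(self.unk_token)

    @property
    def pad_token_id(self):
        return self.convert_tokens_to_ids(self.pad_token)


tok = ToyTokenizer()
for attr in ["unk_token", "pad_token"]:
    assert hasattr(tok, attr) and hasattr(tok, attr + "_id")
```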
@@ -162,58 +162,42 @@ class PreTrainedTokenizer(object):

     @property
     def bos_token_id(self):
         """ Id of the beginning of sentence token in the vocabulary. Log an error if used while not having been set. """
-        if self._bos_token is None:
-            logger.error("Using bos_token, but it is not set yet.")
-        return self.convert_tokens_to_ids(self._bos_token)
+        return self.convert_tokens_to_ids(self.bos_token)

     @property
     def eos_token_id(self):
         """ Id of the end of sentence token in the vocabulary. Log an error if used while not having been set. """
-        if self._eos_token is None:
-            logger.error("Using eos_token, but it is not set yet.")
-        return self.convert_tokens_to_ids(self._eos_token)
+        return self.convert_tokens_to_ids(self.eos_token)

     @property
-    def unk_token_is(self):
+    def unk_token_id(self):
         """ Id of the unknown token in the vocabulary. Log an error if used while not having been set. """
-        if self._unk_token is None:
-            logger.error("Using unk_token, but it is not set yet.")
-        return self.convert_tokens_to_ids(self._unk_token)
+        return self.convert_tokens_to_ids(self.unk_token)

     @property
     def sep_token_id(self):
         """ Id of the separation token in the vocabulary. E.g. separate context and query in an input sequence. Log an error if used while not having been set. """
-        if self._sep_token is None:
-            logger.error("Using sep_token, but it is not set yet.")
-        return self.convert_tokens_to_ids(self._sep_token)
+        return self.convert_tokens_to_ids(self.sep_token)

     @property
     def pad_token_id(self):
         """ Id of the padding token in the vocabulary. Log an error if used while not having been set. """
-        if self._pad_token is None:
-            logger.error("Using pad_token, but it is not set yet.")
-        return self.convert_tokens_to_ids(self._pad_token)
+        return self.convert_tokens_to_ids(self.pad_token)

     @property
     def cls_token_id(self):
         """ Id of the classification token in the vocabulary. E.g. to extract a summary of an input sequence leveraging self-attention along the full depth of the model. Log an error if used while not having been set. """
-        if self._cls_token is None:
-            logger.error("Using cls_token, but it is not set yet.")
-        return self.convert_tokens_to_ids(self._cls_token)
+        return self.convert_tokens_to_ids(self.cls_token)

     @property
     def mask_token_id(self):
         """ Id of the mask token in the vocabulary. E.g. when training a model with masked-language modeling. Log an error if used while not having been set. """
-        if self._mask_token is None:
-            logger.error("Using mask_token, but it is not set yet.")
-        return self.convert_tokens_to_ids(self._mask_token)
+        return self.convert_tokens_to_ids(self.mask_token)

     @property
     def additional_special_tokens_ids(self):
         """ Ids of all the additional special tokens in the vocabulary (list of integers). Log an error if used while not having been set. """
-        if self._additional_special_tokens is None:
-            logger.error("Using additional_special_tokens, but it is not set yet.")
-        return self.convert_tokens_to_ids(self._additional_special_tokens)
+        return self.convert_tokens_to_ids(self.additional_special_tokens)

     def __init__(self, max_len=None, **kwargs):
         self._bos_token = None
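The eight `*_token_id` properties above all lose their duplicated None-check-plus-logging and instead delegate to the public `*_token` properties. The getters for those properties are not shown in this diff; assuming they log the error and return None when the token is unset (as the unchanged docstrings imply), the behaviour is preserved, with the None case now handled once by `convert_tokens_to_ids` (patched below). A sketch of the delegation chain under that assumption:

```python
import logging

logger = logging.getLogger(__name__)


class SketchTokenizer(object):
    """Sketch of the post-commit delegation chain; not the library's code."""

    def __init__(self):
        self._bos_token = None
        self.vocab = {"<s>": 0}

    @property
    def bos_token(self):
        # Assumption: the real (unshown) getter logs like this when unset.
        if self._bos_token is None:
            logger.error("Using bos_token, but it is not set yet.")
            return None
        return self._bos_token

    @property
    def bos_token_id(self):
        # The "not set" handling now lives in one place (the token getter)
        # instead of being repeated in every *_token_id property.
        return self.convert_tokens_to_ids(self.bos_token)

    def convert_tokens_to_ids(self, tokens):
        if tokens is None:  # the guard added further down in this commit
            return None
        return self.vocab[tokens]


tok = SketchTokenizer()
print(tok.bos_token_id)  # logs the error once, returns None
tok._bos_token = "<s>"
print(tok.bos_token_id)  # 0
```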
@@ -653,6 +637,9 @@ class PreTrainedTokenizer(object):
         """ Converts a single token, or a sequence of tokens, (str/unicode) in a single integer id
             (resp. a sequence of ids), using the vocabulary.
         """
+        if tokens is None:
+            return None
+
         if isinstance(tokens, str) or (six.PY2 and isinstance(tokens, unicode)):
             return self._convert_token_to_id_with_added_voc(tokens)
@@ -666,6 +653,9 @@ class PreTrainedTokenizer(object):
         return ids

     def _convert_token_to_id_with_added_voc(self, token):
+        if token is None:
+            return None
+
         if token in self.added_tokens_encoder:
             return self.added_tokens_encoder[token]
         return self._convert_token_to_id(token)
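With both guards in place, asking for the id of an unset special token returns None instead of passing None into the vocabulary lookup. A usage sketch (assuming the pytorch_transformers-era API of this repository; the GPT-2 tokenizer is a convenient example because it ships without a pad token):

```python
from pytorch_transformers import GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

# GPT-2 defines no pad token. Before this commit, pad_token_id logged an
# error and then still fed None into convert_tokens_to_ids; now the lookup
# short-circuits and the property comes back as None, which callers can
# test for explicitly.
if tokenizer.pad_token_id is None:
    print("no pad token configured")
```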