Commit 9ce36e3e authored by samvelyan

Re-implemented tokenize() iteratively in PreTrainedTokenizer.

parent aaedfc35
@@ -428,7 +428,7 @@ class PreTrainedTokenizer(object):
        Parameters:
            special_tokens_dict: dict of string. Keys should be in the list of predefined special attributes: [``bos_token``, ``eos_token``, ``unk_token``, ``sep_token``, ``pad_token``, ``cls_token``, ``mask_token``, ``additional_special_tokens``].
                Tokens are only added if they are not already in the vocabulary (tested by checking if the tokenizer assign the index of the ``unk_token`` to them).
        Returns:
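For context, a minimal usage sketch of adding special tokens through this API (presumably ``add_special_tokens``); the package name, model name and token strings below are illustrative assumptions, not part of the commit:

    from pytorch_transformers import GPT2Tokenizer  # package name assumed for this era of the repo

    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
    # Register a classification token plus two extra special tokens; the call
    # returns how many of them were actually added to the vocabulary.
    num_added = tokenizer.add_special_tokens({'cls_token': '<CLS>',
                                              'additional_special_tokens': ['<SPK1>', '<SPK2>']})
    # If a model is attached, its embedding matrix would need resizing, e.g.
    # model.resize_token_embeddings(len(tokenizer)).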
@@ -472,15 +472,45 @@ class PreTrainedTokenizer(object):
            Take care of added tokens.
        """
+        def split_on_token(tok, text):
+            result = []
+            split_text = text.split(tok)
+            for i, sub_text in enumerate(split_text):
+                sub_text = sub_text.strip()
+                if i == 0 and not sub_text:
+                    result += [tok]
+                elif i == len(split_text) - 1:
+                    if sub_text:
+                        result += [sub_text]
+                    else:
+                        pass
+                else:
+                    if sub_text:
+                        result += [sub_text]
+                    result += [tok]
+            return result
+
        def split_on_tokens(tok_list, text):
            if not text:
                return []
            if not tok_list:
                return self._tokenize(text, **kwargs)
-            tok = tok_list[0]
-            split_text = text.split(tok)
-            return sum((split_on_tokens(tok_list[1:], sub_text.strip()) + [tok] \
-                        for sub_text in split_text), [])[:-1]
+
+            tokenized_text = []
+            text_list = [text]
+            for tok in tok_list:
+                tokenized_text = []
+                for sub_text in text_list:
+                    if sub_text not in self.added_tokens_encoder \
+                            and sub_text not in self.all_special_tokens:
+                        tokenized_text += split_on_token(tok, sub_text)
+                    else:
+                        tokenized_text += [sub_text]
+                text_list = tokenized_text
+
+            return sum((self._tokenize(token, **kwargs) if token not \
+                in self.added_tokens_encoder and token not in self.all_special_tokens \
+                else [token] for token in tokenized_text), [])

        added_tokens = list(self.added_tokens_encoder.keys()) + self.all_special_tokens
        tokenized_text = split_on_tokens(added_tokens, text)
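For reference, a minimal standalone sketch of the new iterative splitting, not taken from the repository: the whitespace ``_tokenize`` stand-in and the token strings are illustrative, and membership is checked against the plain ``tok_list`` instead of ``added_tokens_encoder`` / ``all_special_tokens``:

    def split_on_token(tok, text):
        """Split `text` around one added token, keeping the token itself."""
        result = []
        split_text = text.split(tok)
        for i, sub_text in enumerate(split_text):
            sub_text = sub_text.strip()
            if i == 0 and not sub_text:
                result += [tok]           # the text begins with the token
            elif i == len(split_text) - 1:
                if sub_text:
                    result += [sub_text]  # trailing remainder, no token after it
            else:
                if sub_text:
                    result += [sub_text]
                result += [tok]           # re-insert the token that followed this piece
        return result

    def split_on_tokens(tok_list, text, _tokenize=lambda t: t.split()):
        """Iteratively split on every added token, then tokenize the remaining pieces."""
        if not text:
            return []
        if not tok_list:
            return _tokenize(text)
        text_list = [text]
        for tok in tok_list:
            tokenized_text = []
            for sub_text in text_list:
                if sub_text not in tok_list:
                    tokenized_text += split_on_token(tok, sub_text)
                else:
                    tokenized_text += [sub_text]  # already an added token, keep as-is
            text_list = tokenized_text
        # Added tokens pass through untouched; everything else goes to the tokenizer.
        return sum(([token] if token in tok_list else _tokenize(token)
                    for token in text_list), [])

    print(split_on_tokens(['[CLS]', '[SEP]'], '[CLS] hello new token [SEP]'))
    # -> ['[CLS]', 'hello', 'new', 'token', '[SEP]']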
@@ -522,7 +552,7 @@ class PreTrainedTokenizer(object):
    def encode(self, text):
        """ Converts a string in a sequence of ids (integer), using the tokenizer and vocabulary.
            Same as doing ``self.convert_tokens_to_ids(self.tokenize(text))``.
        """
        return self.convert_tokens_to_ids(self.tokenize(text))
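As a usage note, the equivalence stated in the docstring can be checked directly; the package and model names below are assumptions for illustration:

    from pytorch_transformers import BertTokenizer  # package name assumed for this era of the repo

    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    text = "Hello world"
    # encode() is simply convert_tokens_to_ids applied to tokenize(), as returned above.
    assert tokenizer.encode(text) == tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text))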