"...git@developer.sourcefind.cn:OpenDAS/autoawq_kernels.git" did not exist on "9f21782542d13cd05c2c9c257a7ad0ea87b97cc0"
Commit f24a228a authored by Lysandre

Speed up tokenization process

parent c8ed1c82
@@ -116,7 +116,7 @@ def squad_convert_examples_to_features(examples, tokenizer, max_seq_length,
     unique_id = 1000000000
     features = []
-    for (example_index, example) in enumerate(tqdm(examples)):
+    for (example_index, example) in enumerate(tqdm(examples, desc="Converting examples to features")):
         if is_training and not example.is_impossible:
             # Get start and end position
             start_position = example.start_position
...
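The first hunk only labels the SQuAD feature-conversion progress bar via tqdm's desc argument; the speedup itself is in the tokenizer change below. A minimal sketch of what desc does (the iterable here is a stand-in, not the SQuAD examples):

from tqdm import tqdm

examples = range(3)  # stand-in for the real SQuAD examples

# Without desc the bar shows only counts and rate; desc prefixes a short
# label, matching the loop in the hunk above.
for example_index, example in enumerate(tqdm(examples, desc="Converting examples to features")):
    pass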
@@ -637,9 +637,11 @@ class PreTrainedTokenizer(object):
             text: The sequence to be encoded.
             **kwargs: passed to the child `self.tokenize()` method
         """
+        all_special_tokens = self.all_special_tokens
+
         def lowercase_text(t):
             # convert non-special tokens to lowercase
-            escaped_special_toks = [re.escape(s_tok) for s_tok in self.all_special_tokens]
+            escaped_special_toks = [re.escape(s_tok) for s_tok in all_special_tokens]
             pattern = r'(^' + r'|'.join(escaped_special_toks) + r')|' + \
                 r'(.+?)'
             return re.sub(
@@ -680,17 +682,17 @@ class PreTrainedTokenizer(object):
                 tokenized_text = []
                 for sub_text in text_list:
                     if sub_text not in self.added_tokens_encoder \
-                            and sub_text not in self.all_special_tokens:
+                            and sub_text not in all_special_tokens:
                         tokenized_text += split_on_token(tok, sub_text)
                     else:
                         tokenized_text += [sub_text]
                 text_list = tokenized_text

             return list(itertools.chain.from_iterable((self._tokenize(token, **kwargs) if token not \
-                in self.added_tokens_encoder and token not in self.all_special_tokens \
+                in self.added_tokens_encoder and token not in all_special_tokens \
                 else [token] for token in tokenized_text)))

-        added_tokens = list(self.added_tokens_encoder.keys()) + self.all_special_tokens
+        added_tokens = list(self.added_tokens_encoder.keys()) + all_special_tokens
         tokenized_text = split_on_tokens(added_tokens, text)
         return tokenized_text
...
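The speedup comes from this second change: `all_special_tokens` is a property that rebuilds its list each time it is read, and `tokenize()` was reading it repeatedly inside its splitting loops. Caching it in a local `all_special_tokens` at the top of the call does that work once. A hypothetical, self-contained sketch of the same pattern (the `Toy` class and its tokens are illustrative, not the transformers code):

import time


class Toy:
    special_tokens_map = {"cls": "[CLS]", "sep": "[SEP]", "pad": "[PAD]"}

    @property
    def all_special_tokens(self):
        # Rebuilt on every access, like the property in transformers.
        return list(self.special_tokens_map.values())

    def tokenize_slow(self, tokens):
        # Property re-evaluated for every token (pre-commit behaviour).
        return [t for t in tokens if t not in self.all_special_tokens]

    def tokenize_fast(self, tokens):
        # Property evaluated once per call, as in the commit above.
        all_special_tokens = self.all_special_tokens
        return [t for t in tokens if t not in all_special_tokens]


toy = Toy()
tokens = ["hello", "[SEP]", "world"] * 100_000

start = time.perf_counter()
toy.tokenize_slow(tokens)
print("per-token property lookup:", time.perf_counter() - start)

start = time.perf_counter()
toy.tokenize_fast(tokens)
print("cached local variable:    ", time.perf_counter() - start)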