"vscode:/vscode.git/clone" did not exist on "b54ef78d0c30045bb3f9ecc8b178eab0dfdbeaec"
Unverified Commit 65d74c49 authored by Patrick von Platen, committed by GitHub

Add preprocessing step for transfo-xl tokenization to avoid tokenizing words followed by punctuation to <unk> (#2987)

* add preprocessing to add space before punctuation for transfo_xl

* improve warning messages

* make style

* compile regex at instantiation of tokenizer object
parent a143d947
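
Taken on its own, the new preprocessing boils down to two compiled regexes: a zero-width pattern whose substitution inserts a space before every punctuation symbol, and a plain pattern that only detects punctuation glued to a word so that a warning can be logged. A minimal standalone sketch (the pattern strings are copied from the diff below; the example sentence is illustrative):

import re

# Every symbol the tokenizer treats as punctuation (copied from the diff).
punctuation_symbols = '!"#$%&()*+,-./\\:;<=>?@[\\]^_`{|}~'

# Zero-width match at any position whose next character is a punctuation
# symbol (and not whitespace); substituting " " therefore inserts a space
# *before* the symbol without consuming it.
space_around_punct = re.compile(r"(?=[{}])(?=[^\s])".format(punctuation_symbols))

# A non-space character immediately followed by punctuation: used only to
# decide whether to warn the user that <unk> tokens are likely.
punct_without_space = re.compile(r"[^\s][{}]".format(punctuation_symbols))

text = "Hello, world!"
print(space_around_punct.sub(" ", text))       # Hello , world !
print(bool(punct_without_space.search(text)))  # True

Since the Transfo-XL vocabulary is word-level, "world" is a known token but "world!" is not; inserting the space is what keeps the word from being mapped to `<unk>`.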
@@ -59,7 +59,7 @@ MODEL_CLASSES = {
 # Padding text to help Transformer-XL and XLNet with short prompts as proposed by Aman Rusia
 # in https://github.com/rusiaaman/XLNet-gen#methodology
 # and https://medium.com/@amanrusia/xlnet-speaks-comparison-to-gpt-2-ea1a4e9ba39e
-PADDING_TEXT = """ In 1991, the remains of Russian Tsar Nicholas II and his family
+PADDING_TEXT = """In 1991, the remains of Russian Tsar Nicholas II and his family
 (except for Alexei and Maria) are discovered.
 The voice of Nicholas's young son, Tsarevich Alexei Nikolaevich, narrates the
 remainder of the story. 1883 Western Siberia,
@@ -214,7 +214,9 @@ def main():
     if requires_preprocessing:
         prepare_input = PREPROCESSING_FUNCTIONS.get(args.model_type)
         preprocessed_prompt_text = prepare_input(args, model, tokenizer, prompt_text)
-        encoded_prompt = tokenizer.encode(preprocessed_prompt_text, add_special_tokens=False, return_tensors="pt")
+        encoded_prompt = tokenizer.encode(
+            preprocessed_prompt_text, add_special_tokens=False, return_tensors="pt", add_space_before_punct_symbol=True
+        )
     else:
         encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False, return_tensors="pt")
     encoded_prompt = encoded_prompt.to(args.device)
...
@@ -22,6 +22,7 @@ import glob
 import logging
 import os
 import pickle
+import re
 from collections import Counter, OrderedDict
 from typing import List, Optional, Tuple, Union
@@ -114,6 +115,9 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
         self.delimiter = delimiter
         self.vocab_file = vocab_file
         self.never_split = never_split
+        self.punctuation_symbols = '!"#$%&()*+,-./\:;<=>?@[\\]^_`{|}~'  # noqa: W605
+        self.punction_without_space_before_pattern = re.compile(r"[^\s][{}]".format(self.punctuation_symbols))
+        self.punctuation_with_space_around_pattern = self._compile_space_around_punctuation_pattern()

         if pretrained_vocab_file is not None:
             # Hack because, honestly this tokenizer was not made to be used
@@ -126,6 +130,11 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
         if vocab_file is not None:
             self.build_vocab()

+    def _compile_space_around_punctuation_pattern(self):
+        look_ahead_for_special_token = "(?=[{}])".format(self.punctuation_symbols)
+        look_ahead_to_match_all_except_space = "(?=[^\s])"  # noqa: W605
+        return re.compile(r"" + look_ahead_for_special_token + look_ahead_to_match_all_except_space)
+
     def count_file(self, path, verbose=False, add_eos=False):
         if verbose:
             logger.info("counting file {} ...".format(path))
@@ -295,6 +304,19 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
         else:
             return symbols

+    def prepare_for_tokenization(self, text, **kwargs):
+        # add spaces before punctuation symbols, as Transfo-XL expects
+        if "add_space_before_punct_symbol" in kwargs and kwargs["add_space_before_punct_symbol"]:
+            text = self.punctuation_with_space_around_pattern.sub(r" ", text)
+        elif self.punction_without_space_before_pattern.search(text):
+            # search stops at the first occurrence of a punctuation symbol not preceded by a space
+            logger.warning(
+                "You might want to consider setting `add_space_before_punct_symbol=True` as an argument to `tokenizer.encode()` to avoid tokenizing words with punctuation symbols to the `<unk>` token"
+            )
+        return text
+

 class _TransfoXLDelimiterLookupTokenizer(BaseTokenizer):
     def __init__(
...
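
End to end, the flag travels from `tokenizer.encode()` through `**kwargs` into the new `prepare_for_tokenization()` above. A short usage sketch, assuming an environment with a transformers version that includes this commit (the prompt text is illustrative):

from transformers import TransfoXLTokenizer

tokenizer = TransfoXLTokenizer.from_pretrained("transfo-xl-wt103")

# Without the flag, "terribly!" is not in the word-level WikiText-103 vocab,
# so it is mapped to <unk> and the new warning is logged.
print(tokenizer.decode(tokenizer.encode("I miss you terribly!")))

# With the flag, a space is inserted before "!" so both tokens survive.
print(tokenizer.decode(tokenizer.encode("I miss you terribly!", add_space_before_punct_symbol=True)))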