Unverified Commit 65d74c49 authored by Patrick von Platen, committed by GitHub

Add preprocessing step for transfo-xl tokenization to avoid tokenizing words followed by punctuation to <unk> (#2987)

* add preprocessing to add space before punctuation for transfo_xl

* improve warning messages

* make style

* compile regex at instantiation of tokenizer object
parent a143d947
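
For context on the bug being fixed: Transformer-XL's WikiText-103 vocabulary was built from whitespace-delimited tokens, so a word fused with trailing punctuation is out of vocabulary and falls back to `<unk>`. A minimal sketch of the failure mode, assuming the pretrained `transfo-xl-wt103` vocabulary is available (the decoded output is illustrative):

    from transformers import TransfoXLTokenizer

    tokenizer = TransfoXLTokenizer.from_pretrained("transfo-xl-wt103")

    # "discovered." (word fused with '.') is not a vocabulary entry, so its
    # id falls back to the unknown token
    ids = tokenizer.encode("the remains are discovered.", add_special_tokens=False)
    print(tokenizer.decode(ids))  # e.g. 'the remains are <unk>'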
@@ -59,7 +59,7 @@ MODEL_CLASSES = {
 # Padding text to help Transformer-XL and XLNet with short prompts as proposed by Aman Rusia
 # in https://github.com/rusiaaman/XLNet-gen#methodology
 # and https://medium.com/@amanrusia/xlnet-speaks-comparison-to-gpt-2-ea1a4e9ba39e
-PADDING_TEXT = """ In 1991, the remains of Russian Tsar Nicholas II and his family
+PADDING_TEXT = """In 1991, the remains of Russian Tsar Nicholas II and his family
 (except for Alexei and Maria) are discovered.
 The voice of Nicholas's young son, Tsarevich Alexei Nikolaevich, narrates the
 remainder of the story. 1883 Western Siberia,
@@ -214,7 +214,9 @@ def main():
     if requires_preprocessing:
         prepare_input = PREPROCESSING_FUNCTIONS.get(args.model_type)
         preprocessed_prompt_text = prepare_input(args, model, tokenizer, prompt_text)
-        encoded_prompt = tokenizer.encode(preprocessed_prompt_text, add_special_tokens=False, return_tensors="pt")
+        encoded_prompt = tokenizer.encode(
+            preprocessed_prompt_text, add_special_tokens=False, return_tensors="pt", add_space_before_punct_symbol=True
+        )
     else:
         encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False, return_tensors="pt")
     encoded_prompt = encoded_prompt.to(args.device)
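
The new keyword reaches the tokenizer through the generic encoding path: keyword arguments that `encode()` does not consume are forwarded to `tokenize()`, which hands them to the `prepare_for_tokenization()` hook overridden further down in this commit. A simplified sketch of that call chain (illustrative class, not the actual library code):

    class SketchTokenizer:
        def prepare_for_tokenization(self, text, **kwargs):
            # subclasses such as TransfoXLTokenizer override this hook
            return text

        def tokenize(self, text, **kwargs):
            text = self.prepare_for_tokenization(text, **kwargs)
            return text.split()

        def encode(self, text, **kwargs):
            tokens = self.tokenize(text, **kwargs)
            return tokens  # the real implementation converts tokens to ids here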
@@ -22,6 +22,7 @@ import glob
 import logging
 import os
 import pickle
+import re
 from collections import Counter, OrderedDict
 from typing import List, Optional, Tuple, Union
@@ -114,6 +115,9 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
         self.delimiter = delimiter
         self.vocab_file = vocab_file
         self.never_split = never_split
+        self.punctuation_symbols = '!"#$%&()*+,-./\:;<=>?@[\\]^_`{|}~'  # noqa: W605
+        self.punction_without_space_before_pattern = re.compile(r"[^\s][{}]".format(self.punctuation_symbols))
+        self.punctuation_with_space_around_pattern = self._compile_space_around_punctuation_pattern()
         if pretrained_vocab_file is not None:
             # Hack because, honestly this tokenizer was not made to be used
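
The first compiled pattern only detects the problem: it matches any non-whitespace character immediately followed by a punctuation symbol, and is used further down to decide whether to warn. A standalone re-creation of the pattern for illustration (escaping cleaned up here, independent of the tokenizer attribute):

    import re

    punctuation_symbols = r'!"#$%&()*+,-./:;<=>?@[\]^_`{|}~'
    without_space_before = re.compile(r"[^\s][{}]".format(punctuation_symbols))

    print(bool(without_space_before.search("they were discovered .")))  # False: '.' has a space before it
    print(bool(without_space_before.search("they were discovered.")))   # True: 'd.' is fused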
@@ -126,6 +130,11 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
         if vocab_file is not None:
             self.build_vocab()
 
+    def _compile_space_around_punctuation_pattern(self):
+        look_ahead_for_special_token = "(?=[{}])".format(self.punctuation_symbols)
+        look_ahead_to_match_all_except_space = "(?=[^\s])"  # noqa: W605
+        return re.compile(r"" + look_ahead_for_special_token + look_ahead_to_match_all_except_space)
+
     def count_file(self, path, verbose=False, add_eos=False):
         if verbose:
             logger.info("counting file {} ...".format(path))
@@ -295,6 +304,19 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
         else:
             return symbols
 
+    def prepare_for_tokenization(self, text, **kwargs):
+        # add spaces before punctuation symbols, as expected by transfo-xl
+        if "add_space_before_punct_symbol" in kwargs and kwargs["add_space_before_punct_symbol"]:
+            text = self.punctuation_with_space_around_pattern.sub(r" ", text)
+        elif self.punction_without_space_before_pattern.search(text):
+            # warn on the first occurrence of a punctuation symbol without a preceding space
+            logger.warning(
+                "You might want to consider setting `add_space_before_punct_symbol=True` as an argument to "
+                "`tokenizer.encode()` to avoid tokenizing words with punctuation symbols to the `<unk>` token"
+            )
+        return text
+
+
 class _TransfoXLDelimiterLookupTokenizer(BaseTokenizer):
     def __init__(
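End-to-end, the flag makes `encode()` equivalent to manually spacing out punctuation beforehand. A usage sketch, again assuming the `transfo-xl-wt103` vocabulary and its default whitespace delimiter:

    from transformers import TransfoXLTokenizer

    tokenizer = TransfoXLTokenizer.from_pretrained("transfo-xl-wt103")

    with_flag = tokenizer.encode(
        "In 1991, the remains are discovered.",
        add_special_tokens=False,
        add_space_before_punct_symbol=True,
    )
    manual = tokenizer.encode("In 1991 , the remains are discovered .", add_special_tokens=False)
    assert with_flag == manual  # the kwarg reproduces the manual spacing

Without the flag the text is left untouched, so fused words still map to `<unk>`, but the warning added above now points the user at the kwarg.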