Unverified Commit cb276b41 authored by RafaelWO, committed by GitHub

Transformer-XL: Improved tokenization with sacremoses (#6322)



* Improved tokenization with sacremoses

 * The TransfoXLTokenizer is now using sacremoses for tokenization
 * Added tokenization of comma-separated and floating point numbers.
 * Removed prepare_for_tokenization() from tokenization_transfo_xl.py because punctuation is handled by sacremoses
 * Added corresponding tests
 * Removed test comparing TransfoXLTokenizer and TransfoXLTokenizerFast
 * Added deprecation warning to TransfoXLTokenizerFast

* isort change
Co-authored-by: Teven <teven.lescao@gmail.com>
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
parent 930153e7
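
For a quick sense of the new behavior, here is a rough usage sketch (an editor's illustration, not part of the commit; the expected outputs are taken from the docstrings and tests in the diff below):

from transformers import TransfoXLTokenizer

tokenizer = TransfoXLTokenizer.from_pretrained("transfo-xl-wt103")

# Moses punctuation normalization + tokenization, then number splitting around ',' and '.'
tokens = tokenizer.moses_pipeline("23,000 people are 1.80 m tall")
# ['23', '@,@', '000', 'people', 'are', '1', '@.@', '80', 'm', 'tall']

# convert_tokens_to_string() now detokenizes with Moses and undoes the number splitting,
# so this should round-trip back to "23,000 people are 1.80 m tall".
text = tokenizer.convert_tokens_to_string(tokens)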
@@ -22,11 +22,13 @@ import glob
 import os
 import pickle
 import re
+import warnings
 from collections import Counter, OrderedDict
-from typing import Optional
+from typing import List, Optional
 
 import numpy as np
+import sacremoses as sm
 from tokenizers import Tokenizer
 from tokenizers.implementations import BaseTokenizer
 from tokenizers.models import WordLevel
@@ -70,6 +72,47 @@ PRETRAINED_CORPUS_ARCHIVE_MAP = {
 }
 CORPUS_NAME = "corpus.bin"
 
+MATCH_NUMBERS = r"(?<=\d)[,.](?=\d)", r" @\g<0>@ "
+DETOKENIZE_NUMBERS = [(r" @\,@ ", r","), (r" @\.@ ", r".")]
+
+
+def tokenize_numbers(text_array: List[str]) -> List[str]:
+    """
+    Splits large comma-separated numbers and floating point values.
+    This is done by replacing commas with ' @,@ ' and dots with ' @.@ '.
+    Args:
+        text_array: An already tokenized text as list
+    Returns:
+        A list of strings with tokenized numbers
+    Example::
+        >>> tokenize_numbers(["$", "5,000", "1.73", "m"])
+        ["$", "5", "@,@", "000", "1", "@.@", "73", "m"]
+    """
+    tokenized = []
+    for i in range(len(text_array)):
+        reg, sub = MATCH_NUMBERS
+        replaced = re.sub(reg, sub, text_array[i]).split()
+        tokenized.extend(replaced)
+
+    return tokenized
+
+
+def detokenize_numbers(text: str) -> str:
+    """
+    Inverts the operation of `tokenize_numbers`.
+    This is done by replacing ' @,@ ' and ' @.@ ' with ',' and '.'.
+    Args:
+        text: A string where the number should be detokenized
+    Returns:
+        A detokenized string
+    Example::
+        >>> detokenize_numbers("$ 5 @,@ 000 1 @.@ 73 m")
+        "$ 5,000 1.73 m"
+    """
+    for reg, sub in DETOKENIZE_NUMBERS:
+        text = re.sub(reg, sub, text)
+    return text
+
+
 class TransfoXLTokenizer(PreTrainedTokenizer):
     """
@@ -97,6 +140,7 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
         unk_token="<unk>",
         eos_token="<eos>",
         additional_special_tokens=["<formula>"],
+        language="en",
         **kwargs
     ):
         super().__init__(
@@ -118,6 +162,10 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
         self.punctuation_symbols = '!"#$%&()*+,-./\\:;<=>?@[\\]^_`{|}~'
         self.punction_without_space_before_pattern = re.compile(r"[^\s][{}]".format(self.punctuation_symbols))
         self.punctuation_with_space_around_pattern = self._compile_space_around_punctuation_pattern()
+        self.language = language
+        self.moses_punct_normalizer = sm.MosesPunctNormalizer(language)
+        self.moses_tokenizer = sm.MosesTokenizer(language)
+        self.moses_detokenizer = sm.MosesDetokenizer(language)
 
         try:
             if pretrained_vocab_file is not None:
@@ -300,6 +348,34 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
         del self.added_tokens_decoder[old_index]
         del self.added_tokens_encoder[token]
 
+    def moses_punct_norm(self, text):
+        return self.moses_punct_normalizer.normalize(text)
+
+    def moses_tokenize(self, text):
+        return self.moses_tokenizer.tokenize(
+            text, aggressive_dash_splits=True, return_str=False, escape=False, protected_patterns=self.never_split
+        )
+
+    def moses_pipeline(self, text: str) -> List[str]:
+        """
+        Does basic tokenization using :class:`sacremoses.MosesPunctNormalizer` and :class:`sacremoses.MosesTokenizer`
+        with `aggressive_dash_splits=True` (see :func:`sacremoses.tokenize.MosesTokenizer.tokenize`).
+        Additionally, large comma-separated numbers and floating point values are split.
+        E.g. "23,000 people are 1.80m tall" -> "23 @,@ 000 people are 1 @.@ 80m tall".
+        Args:
+            text: Text to be tokenized
+        Returns:
+            A list of tokenized strings
+        Example::
+            >>> tokenizer = TransfoXLTokenizer.from_pretrained("transfo-xl-wt103")
+            >>> tokenizer.moses_pipeline("23,000 people are 1.80 m tall")
+            ['23', '@,@', '000', 'people', 'are', '1', '@.@', '80', 'm', 'tall']
+        """
+        text = self.moses_punct_norm(text)
+        text = self.moses_tokenize(text)
+        text = tokenize_numbers(text)
+        return text
+
     def _convert_id_to_token(self, idx):
         """Converts an id in a token (BPE) using the vocab."""
         assert 0 <= idx < len(self), "Index {} out of vocabulary range".format(idx)
@@ -323,9 +399,12 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
                 raise ValueError("Token not in vocabulary and no <unk> token in vocabulary for replacement")
 
     def convert_tokens_to_string(self, tokens):
-        """ Converts a sequence of tokens (string) in a single string. """
-        out_string = " ".join(tokens).strip()
-        return out_string
+        """
+        Converts a sequence of tokens (string) in a single string.
+        Additionally, the split numbers are converted back into their original form.
+        """
+        out_string = self.moses_detokenizer.detokenize(tokens)
+        return detokenize_numbers(out_string).strip()
 
     def convert_to_tensor(self, symbols):
         return torch.LongTensor(self.convert_tokens_to_ids(symbols))
@@ -347,7 +426,7 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
         if self.delimiter == "":
             symbols = line
         else:
-            symbols = line.split(self.delimiter)
+            symbols = self.moses_pipeline(line)
 
         if add_double_eos:  # lm1b
             return ["<S>"] + symbols + ["<S>"]
@@ -356,19 +435,6 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
         else:
             return symbols
 
-    def prepare_for_tokenization(self, text, is_pretokenized=False, **kwargs):
-        # add spaces before punctuation symbols as should be done in transfo-xl
-        add_space_before_punct_symbol = kwargs.pop("add_space_before_punct_symbol", False)
-        if add_space_before_punct_symbol:
-            text = self.punctuation_with_space_around_pattern.sub(r" ", text)
-        elif self.punction_without_space_before_pattern.search(text):
-            # searches until the first occurrence of a punctuation symbol without surrounding spaces
-            logger.warning(
-                "You might want to consider setting `add_space_before_punct_symbol=True` as an argument to the `tokenizer.encode()` to avoid tokenizing words with punctuation symbols to the `<unk>` token"
-            )
-
-        return (text, kwargs)
-
 
 class _TransfoXLDelimiterLookupTokenizer(BaseTokenizer):
     def __init__(
@@ -484,6 +550,11 @@ class TransfoXLTokenizerFast(PreTrainedTokenizerFast):
             **kwargs,
         )
 
+        warnings.warn(
+            "The class `TransfoXLTokenizerFast` is deprecated and will be removed in a future version. Please use `TransfoXLTokenizer` with its enhanced tokenization instead.",
+            FutureWarning,
+        )
+
     def save_pretrained(self, save_directory):
         logger.warning(
             "Please note you will not be able to load the vocabulary in"
...
@@ -12,14 +12,12 @@ from transformers import (
     OpenAIGPTTokenizer,
     PreTrainedTokenizer,
     RobertaTokenizer,
-    TransfoXLTokenizer,
     is_torch_available,
 )
 from transformers.testing_utils import get_tests_dir, require_torch
 from transformers.tokenization_distilbert import DistilBertTokenizerFast
 from transformers.tokenization_openai import OpenAIGPTTokenizerFast
 from transformers.tokenization_roberta import RobertaTokenizerFast
-from transformers.tokenization_transfo_xl import TransfoXLTokenizerFast
 
 logger = logging.getLogger(__name__)
@@ -895,17 +893,3 @@ class NoPaddingTokenFastTokenizerMatchingTest(CommonFastTokenizerTest):
                 max_length=max_length,
                 padding="max_length",
             )
-
-
-class TransfoXLFastTokenizerTest(NoPaddingTokenFastTokenizerMatchingTest):
-    TOKENIZERS_CLASSES = frozenset(
-        [Tokenizer("TransfoXL", TransfoXLTokenizerFast, TransfoXLTokenizer, "pretrained_vocab_file", None, None)]
-    )
-
-    @require_torch
-    def test_all_tokenizers(self):
-        super().test_all_tokenizers()
-
-    @require_torch
-    def test_pretokenized_tokenizers(self):
-        super().test_pretokenized_tokenizers()
@@ -83,6 +83,44 @@ class TransfoXLTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
            tokenizer.tokenize(" \tHeLLo ! how \n Are yoU ? "), ["HeLLo", "!", "how", "Are", "yoU", "?"]
         )
 
+    def test_full_tokenizer_moses_numbers(self):
+        tokenizer = TransfoXLTokenizer(lower_case=False)
+        text_in = "Hello (bracket) and side-scrolled [and] Henry's $5,000 with 3.34 m. What's up!?"
+        tokens_out = [
+            "Hello",
+            "(",
+            "bracket",
+            ")",
+            "and",
+            "side",
+            "@-@",
+            "scrolled",
+            "[",
+            "and",
+            "]",
+            "Henry",
+            "'s",
+            "$",
+            "5",
+            "@,@",
+            "000",
+            "with",
+            "3",
+            "@.@",
+            "34",
+            "m",
+            ".",
+            "What",
+            "'s",
+            "up",
+            "!",
+            "?",
+        ]
+
+        self.assertListEqual(tokenizer.tokenize(text_in), tokens_out)
+
+        self.assertEqual(tokenizer.convert_tokens_to_string(tokens_out), text_in)
+
     def test_move_added_token(self):
         tokenizer = self.get_tokenizer()
         original_len = len(tokenizer)
...