Commit 436ce072 authored by Shijie Wu

Tokenization behaves the same as the original XLM preprocessing for most languages...

Tokenization behaves the same as the original XLM preprocessing for most languages except zh, ja and th; change the API to allow specifying the language in `tokenize`
parent df9d6eff
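A minimal usage sketch of the changed API (the checkpoint name is illustrative, and this assumes the base `tokenize` forwards the `lang` keyword to `_tokenize`, as the message above describes):

from pytorch_transformers import XLMTokenizer

tokenizer = XLMTokenizer.from_pretrained('xlm-mlm-enfr-1024')  # illustrative checkpoint choice

# English goes through replace_unicode_punct, Moses punctuation normalization,
# non-printing-character removal and Moses tokenization before BPE.
en_tokens = tokenizer.tokenize("Hello, world!", lang='en')

# A different `lang` selects (and caches) language-specific sacremoses objects.
fr_tokens = tokenizer.tokenize("Bonjour le monde !", lang='fr')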
......@@ -20,8 +20,11 @@ import json
import logging
import os
import re
import unicodedata
from io import open
import sacremoses as sm
from .tokenization_utils import PreTrainedTokenizer
from .tokenization_bert import BasicTokenizer
......@@ -95,6 +98,93 @@ def text_standardize(text):
text = re.sub(r'[^\S\n]+', ' ', text)
return text.strip()
def lowercase_and_remove_accent(text):
"""
Lowercase and strips accents from a piece of text based on
https://github.com/facebookresearch/XLM/blob/master/tools/lowercase_and_remove_accent.py
"""
text = text.lower()
text = unicodedata.normalize("NFD", text)
output = []
for char in text:
cat = unicodedata.category(char)
if cat == "Mn":
continue
output.append(char)
return "".join(output).lower()
def replace_unicode_punct(text):
'''
Port of https://github.com/moses-smt/mosesdecoder/blob/master/scripts/tokenizer/replace-unicode-punctuation.perl
'''
text = text.replace('，', ',')
text = text.replace('。 *', '. ')
text = text.replace('、', ',')
text = text.replace('”', '"')
text = text.replace('“', '"')
text = text.replace('∶', ':')
text = text.replace('：', ':')
text = text.replace('？', '?')
text = text.replace('《', '"')
text = text.replace('》', '"')
text = text.replace('）', ')')
text = text.replace('！', '!')
text = text.replace('（', '(')
text = text.replace('；', ';')
text = text.replace('１', '"')
text = text.replace('」', '"')
text = text.replace('「', '"')
text = text.replace('０', '0')
text = text.replace('３', '3')
text = text.replace('２', '2')
text = text.replace('５', '5')
text = text.replace('６', '6')
text = text.replace('９', '9')
text = text.replace('７', '7')
text = text.replace('８', '8')
text = text.replace('４', '4')
text = re.sub(r'．\s*', '. ', text)
text = text.replace('～', '~')
text = text.replace('’', '\'')
text = text.replace('…', '...')
text = text.replace('━', '-')
text = text.replace('〈', '<')
text = text.replace('〉', '>')
text = text.replace('【', '[')
text = text.replace('】', ']')
text = text.replace('％', '%')
return text
def remove_non_printing_char(text):
'''
Port of https://github.com/moses-smt/mosesdecoder/blob/master/scripts/tokenizer/remove-non-printing-char.perl
'''
output = []
for char in text:
cat = unicodedata.category(char)
if cat.startswith('C'):
continue
output.append(char)
return "".join(output)
def romanian_preprocessing(text):
'''Sennrich's WMT16 scripts for Romanian preprocessing, used by model `xlm-mlm-enro-1024`'''
# https://github.com/rsennrich/wmt16-scripts/blob/master/preprocess/normalise-romanian.py
text = text.replace("\u015e", "\u0218").replace("\u015f", "\u0219")
text = text.replace("\u0162", "\u021a").replace("\u0163", "\u021b")
# https://github.com/rsennrich/wmt16-scripts/blob/master/preprocess/remove-diacritics.py
text = text.replace("\u0218", "S").replace("\u0219", "s") #s-comma
text = text.replace("\u021a", "T").replace("\u021b", "t") #t-comma
text = text.replace("\u0102", "A").replace("\u0103", "a")
text = text.replace("\u00C2", "A").replace("\u00E2", "a")
text = text.replace("\u00CE", "I").replace("\u00EE", "i")
return text
class XLMTokenizer(PreTrainedTokenizer):
"""
BPE tokenizer for XLM, adapted from OpenAI BPE tokenizer. Peculiarities:
......@@ -122,16 +212,14 @@ class XLMTokenizer(PreTrainedTokenizer):
cls_token=cls_token, mask_token=mask_token,
additional_special_tokens=additional_special_tokens,
**kwargs)
try:
import ftfy
from spacy.lang.en import English
_nlp = English()
self.nlp = _nlp.Defaults.create_tokenizer(_nlp)
self.fix_text = ftfy.fix_text
except ImportError:
logger.warning("ftfy or spacy is not installed using BERT BasicTokenizer instead of SpaCy & ftfy.")
self.nlp = BasicTokenizer(do_lower_case=True)
self.fix_text = None
# cache of sm.MosesPunctNormalizer instance
self.cache_moses_punct_normalizer = dict()
# cache of sm.MosesTokenizer instance
self.cache_moses_tokenizer = dict()
self.lang_with_custom_tokenizer = set(['zh', 'th', 'ja'])
# True for current supported model (v1.2.0), False for XLM-17 & 100
self.do_lowercase_and_remove_accent = True
self.encoder = json.load(open(vocab_file, encoding="utf-8"))
self.decoder = {v:k for k,v in self.encoder.items()}
......@@ -140,6 +228,28 @@ class XLMTokenizer(PreTrainedTokenizer):
self.bpe_ranks = dict(zip(merges, range(len(merges))))
self.cache = {}
def moses_punct_norm(self, text, lang):
if lang not in self.cache_moses_punct_normalizer:
punct_normalizer = sm.MosesPunctNormalizer(lang=lang)
self.cache_moses_punct_normalizer[lang] = punct_normalizer
else:
punct_normalizer = self.cache_moses_punct_normalizer[lang]
return punct_normalizer.normalize(text)
def moses_tokenize(self, text, lang):
if lang not in self.cache_moses_tokenizer:
moses_tokenizer = sm.MosesTokenizer(lang=lang)
self.cache_moses_tokenizer[lang] = moses_tokenizer
else:
moses_tokenizer = self.cache_moses_tokenizer[lang]
return moses_tokenizer.tokenize(text, return_str=False, escape=False)
def moses_pipeline(self, text, lang):
text = replace_unicode_punct(text)
text = self.moses_punct_norm(text, lang)
text = remove_non_printing_char(text)
return text
@property
def vocab_size(self):
return len(self.encoder)
......@@ -187,19 +297,21 @@ class XLMTokenizer(PreTrainedTokenizer):
self.cache[token] = word
return word
def _tokenize(self, text):
def _tokenize(self, text, lang='en'):
""" Tokenize a string. """
split_tokens = []
if self.fix_text is None:
# Using BERT's BasicTokenizer
text = self.nlp.tokenize(text)
if self.do_lowercase_and_remove_accent:
text = lowercase_and_remove_accent(text)
if lang not in self.lang_with_custom_tokenizer:
text = self.moses_pipeline(text, lang=lang)
# TODO: make sure we are using `xlm-mlm-enro-1024`, since XLM-100 doesn't have this step
if lang == 'ro':
text = romanian_preprocessing(text)
text = self.moses_tokenize(text, lang=lang)
for token in text:
split_tokens.extend([t for t in self.bpe(token).split(' ')])
else:
# Using SpaCy & ftfy (original tokenization process of OpenAI GPT)
text = self.nlp(text_standardize(self.fix_text(text)))
for token in text:
split_tokens.extend([t for t in self.bpe(token.text.lower()).split(' ')])
raise ValueError
return split_tokens
def _convert_token_to_id(self, token):
......
......@@ -9,4 +9,6 @@ requests
# For OpenAI GPT
regex
# For XLNet
sentencepiece
\ No newline at end of file
sentencepiece
# For XLM
sacremoses
\ No newline at end of file
......@@ -55,7 +55,8 @@ setup(
'requests',
'tqdm',
'regex',
'sentencepiece'],
'sentencepiece',
'sacremoses'],
entry_points={
'console_scripts': [
"pytorch_transformers=pytorch_transformers.__main__:main",
......
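For reviewers, a standalone sketch of the per-language Moses preprocessing the new helpers wire in, calling sacremoses directly (the function name `moses_preprocess` is illustrative, and the replace_unicode_punct / remove_non_printing_char steps are skipped for brevity):

import sacremoses as sm

# One normalizer and one tokenizer per language, mirroring the caches kept on the tokenizer.
_punct_normalizers = {}
_moses_tokenizers = {}

def moses_preprocess(text, lang='en'):
    if lang not in _punct_normalizers:
        _punct_normalizers[lang] = sm.MosesPunctNormalizer(lang=lang)
        _moses_tokenizers[lang] = sm.MosesTokenizer(lang=lang)
    text = _punct_normalizers[lang].normalize(text)
    return _moses_tokenizers[lang].tokenize(text, return_str=False, escape=False)

# Prints a token list roughly like ['Hello', ',', 'world', '!', 'This', 'is', 'a', 'test', '.']
print(moses_preprocess("Hello, world! This is a test.", lang='en'))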