Unverified Commit d5bc32ce authored by Philip May, committed by GitHub

Add strip_accents to basic BertTokenizer. (#6280)

* Add strip_accents to basic tokenizer

* Add tests for strip_accents.

* fix style with black

* Fix strip_accents test

* empty commit to trigger CI

* Improved strip_accents check

* Add code quality with is not False
parent 31da35cc
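For orientation (not part of the diff below), a minimal sketch of how the new option is meant to be used from the public API; it assumes, as usual in `transformers`, that extra keyword arguments passed to `from_pretrained` are forwarded to the tokenizer constructor:

```python
from transformers import BertTokenizer

# Default: strip_accents=None, so accent stripping follows do_lower_case
# (the original BERT behaviour).
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

# Lowercase, but keep accents.
tokenizer_keep = BertTokenizer.from_pretrained("bert-base-uncased", strip_accents=False)

# Keep case, but strip accents.
tokenizer_strip = BertTokenizer.from_pretrained(
    "bert-base-cased", do_lower_case=False, strip_accents=True
)
```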
@@ -154,6 +154,9 @@ class BertTokenizer(PreTrainedTokenizer):
Whether to tokenize Chinese characters.
This should likely be deactivated for Japanese:
see: https://github.com/huggingface/transformers/issues/328
strip_accents: (:obj:`bool`, `optional`, defaults to :obj:`None`):
Whether to strip all accents. If this option is not specified (i.e. is :obj:`None`),
it will be determined by the value of :obj:`do_lower_case` (as in the original BERT).
"""
vocab_files_names = VOCAB_FILES_NAMES
@@ -173,6 +176,7 @@ class BertTokenizer(PreTrainedTokenizer):
cls_token="[CLS]",
mask_token="[MASK]",
tokenize_chinese_chars=True,
strip_accents=None,
**kwargs
):
super().__init__(
@@ -194,7 +198,10 @@ class BertTokenizer(PreTrainedTokenizer):
self.do_basic_tokenize = do_basic_tokenize
if do_basic_tokenize:
self.basic_tokenizer = BasicTokenizer(
-                do_lower_case=do_lower_case, never_split=never_split, tokenize_chinese_chars=tokenize_chinese_chars
+                do_lower_case=do_lower_case,
+                never_split=never_split,
+                tokenize_chinese_chars=tokenize_chinese_chars,
+                strip_accents=strip_accents,
)
self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab, unk_token=self.unk_token)
@@ -351,7 +358,7 @@ class BertTokenizer(PreTrainedTokenizer):
class BasicTokenizer(object):
"""Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
-    def __init__(self, do_lower_case=True, never_split=None, tokenize_chinese_chars=True):
+    def __init__(self, do_lower_case=True, never_split=None, tokenize_chinese_chars=True, strip_accents=None):
""" Constructs a BasicTokenizer.
Args:
@@ -364,12 +371,16 @@ class BasicTokenizer(object):
Whether to tokenize Chinese characters.
This should likely be deactivated for Japanese:
see: https://github.com/huggingface/pytorch-pretrained-BERT/issues/328
**strip_accents**: (`optional`) boolean (default None)
Whether to strip all accents. If this option is not specified (i.e. is `None`),
it will be determined by the value of `do_lower_case` (as in the original BERT).
"""
if never_split is None:
never_split = []
self.do_lower_case = do_lower_case
self.never_split = set(never_split)
self.tokenize_chinese_chars = tokenize_chinese_chars
self.strip_accents = strip_accents
def tokenize(self, text, never_split=None):
""" Basic Tokenization of a piece of text.
@@ -395,9 +406,13 @@ class BasicTokenizer(object):
orig_tokens = whitespace_tokenize(text)
split_tokens = []
for token in orig_tokens:
-            if self.do_lower_case and token not in never_split:
-                token = token.lower()
-                token = self._run_strip_accents(token)
+            if token not in never_split:
+                if self.do_lower_case:
+                    token = token.lower()
+                    if self.strip_accents is not False:
+                        token = self._run_strip_accents(token)
+                elif self.strip_accents:
+                    token = self._run_strip_accents(token)
split_tokens.extend(self._run_split_on_punc(token, never_split))
output_tokens = whitespace_tokenize(" ".join(split_tokens))
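The hunk above changes the lowercasing loop so that accent stripping can be controlled independently of `do_lower_case`. A rough sketch of the resulting behaviour, mirroring the tests added below (the import path `transformers.tokenization_bert` is assumed):

```python
from transformers.tokenization_bert import BasicTokenizer

# strip_accents=None (default): stripping follows do_lower_case.
BasicTokenizer(do_lower_case=True).tokenize("H\u00E9llo")
# -> ["hello"]

# strip_accents=False: lowercase, but keep the accent.
BasicTokenizer(do_lower_case=True, strip_accents=False).tokenize("H\u00E9llo")
# -> ["h\u00e9llo"]

# strip_accents=True: strip accents even without lowercasing.
BasicTokenizer(do_lower_case=False, strip_accents=True).tokenize("H\u00E9LLo")
# -> ["HeLLo"]
```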
@@ -130,6 +130,30 @@ class BertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
)
self.assertListEqual(tokenizer.tokenize("H\u00E9llo"), ["hello"])
def test_basic_tokenizer_lower_strip_accents_false(self):
tokenizer = BasicTokenizer(do_lower_case=True, strip_accents=False)
self.assertListEqual(
tokenizer.tokenize(" \tHäLLo!how \n Are yoU? "), ["hällo", "!", "how", "are", "you", "?"]
)
self.assertListEqual(tokenizer.tokenize("H\u00E9llo"), ["h\u00E9llo"])
def test_basic_tokenizer_lower_strip_accents_true(self):
tokenizer = BasicTokenizer(do_lower_case=True, strip_accents=True)
self.assertListEqual(
tokenizer.tokenize(" \tHäLLo!how \n Are yoU? "), ["hallo", "!", "how", "are", "you", "?"]
)
self.assertListEqual(tokenizer.tokenize("H\u00E9llo"), ["hello"])
def test_basic_tokenizer_lower_strip_accents_default(self):
tokenizer = BasicTokenizer(do_lower_case=True)
self.assertListEqual(
tokenizer.tokenize(" \tHäLLo!how \n Are yoU? "), ["hallo", "!", "how", "are", "you", "?"]
)
self.assertListEqual(tokenizer.tokenize("H\u00E9llo"), ["hello"])
def test_basic_tokenizer_no_lower(self):
tokenizer = BasicTokenizer(do_lower_case=False)
@@ -137,6 +161,20 @@ class BertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
tokenizer.tokenize(" \tHeLLo!how \n Are yoU? "), ["HeLLo", "!", "how", "Are", "yoU", "?"]
)
def test_basic_tokenizer_no_lower_strip_accents_false(self):
tokenizer = BasicTokenizer(do_lower_case=False, strip_accents=False)
self.assertListEqual(
tokenizer.tokenize(" \tHäLLo!how \n Are yoU? "), ["HäLLo", "!", "how", "Are", "yoU", "?"]
)
def test_basic_tokenizer_no_lower_strip_accents_true(self):
tokenizer = BasicTokenizer(do_lower_case=False, strip_accents=True)
self.assertListEqual(
tokenizer.tokenize(" \tHäLLo!how \n Are yoU? "), ["HaLLo", "!", "how", "Are", "yoU", "?"]
)
def test_basic_tokenizer_respects_never_split_tokens(self):
tokenizer = BasicTokenizer(do_lower_case=False, never_split=["[UNK]"])
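Assuming the surrounding test module lives at `tests/test_tokenization_bert.py` (its usual location at the time), the new cases can be run with something like `python -m pytest tests/test_tokenization_bert.py -k strip_accents`.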