Commit c946bb51 authored by thomwolf

fix xlnet tokenizer and python2

parent 18107563
@@ -241,7 +241,7 @@ class XLNetTokenizer(object):
             )
         return ids
 
-    def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
+    def convert_ids_to_tokens(self, ids, return_unicode=True, skip_special_tokens=False):
         """Converts a sequence of ids in tokens."""
         tokens = []
         for i in ids:
@@ -250,6 +250,14 @@ class XLNetTokenizer(object):
                     tokens.append(self.special_tokens_decoder[i])
             else:
                 tokens.append(self.sp_model.IdToPiece(i))
+
+        if six.PY2 and return_unicode:
+            ret_pieces = []
+            for piece in tokens:
+                if isinstance(piece, str):
+                    piece = piece.decode('utf-8')
+                ret_pieces.append(piece)
+            tokens = ret_pieces
         return tokens
 
     def encode(self, text, sample=False):
...
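For context on the hunk above: under Python 2, sentencepiece's IdToPiece can hand back UTF-8 encoded str objects, so the new return_unicode flag decodes them and callers always receive unicode. A minimal standalone sketch of that branch (the helper name to_unicode_pieces is hypothetical, not part of the tokenizer):

    import six

    def to_unicode_pieces(pieces):
        # Mirrors the six.PY2 branch added above: decode UTF-8 encoded
        # Python 2 `str` pieces to `unicode`; on Python 3, str is already
        # unicode and the list passes through unchanged.
        if six.PY2:
            return [piece.decode('utf-8') if isinstance(piece, str) else piece
                    for piece in pieces]
        return pieces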
@@ -33,23 +33,24 @@ class XLNetTokenizationTest(unittest.TestCase):
 
     def test_full_tokenizer(self):
         tokenizer = XLNetTokenizer(SAMPLE_VOCAB)
 
-        tokens = tokenizer.tokenize('This is a test')
-        self.assertListEqual(tokens, ['▁This', '▁is', '▁a', '▁t', 'est'])
+        tokens = tokenizer.tokenize(u'This is a test')
+        self.assertListEqual(tokens, [u'▁This', u'▁is', u'▁a', u'▁t', u'est'])
 
         self.assertListEqual(
             tokenizer.convert_tokens_to_ids(tokens), [285, 46, 10, 170, 382])
 
-        vocab_path = "/tmp/"
+        vocab_path = u"/tmp/"
         vocab_file, special_tokens_file = tokenizer.save_vocabulary(vocab_path)
         tokenizer = tokenizer.from_pretrained(vocab_path,
                                               keep_accents=True)
         os.remove(vocab_file)
         os.remove(special_tokens_file)
 
-        tokens = tokenizer.tokenize("I was born in 92000, and this is falsé.")
-        self.assertListEqual(tokens, [SPIECE_UNDERLINE + 'I', SPIECE_UNDERLINE + 'was', SPIECE_UNDERLINE + 'b', 'or', 'n', SPIECE_UNDERLINE + 'in', SPIECE_UNDERLINE + '',
-                                      '9', '2', '0', '0', '0', ',', SPIECE_UNDERLINE + 'and', SPIECE_UNDERLINE + 'this',
-                                      SPIECE_UNDERLINE + 'is', SPIECE_UNDERLINE + 'f', 'al', 's', 'é', '.'])
+        tokens = tokenizer.tokenize(u"I was born in 92000, and this is falsé.")
+        self.assertListEqual(tokens, [SPIECE_UNDERLINE + u'I', SPIECE_UNDERLINE + u'was', SPIECE_UNDERLINE + u'b',
+                                      u'or', u'n', SPIECE_UNDERLINE + u'in', SPIECE_UNDERLINE + u'',
+                                      u'9', u'2', u'0', u'0', u'0', u',', SPIECE_UNDERLINE + u'and', SPIECE_UNDERLINE + u'this',
+                                      SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f', u'al', u's', u'é', u'.'])
 
         ids = tokenizer.convert_tokens_to_ids(tokens)
         self.assertListEqual(
             ids, [8, 21, 84, 55, 24, 19, 7, 0,
@@ -57,10 +58,12 @@ class XLNetTokenizationTest(unittest.TestCase):
               46, 72, 80, 6, 0, 4])
 
         back_tokens = tokenizer.convert_ids_to_tokens(ids)
-        self.assertListEqual(back_tokens, [SPIECE_UNDERLINE + 'I', SPIECE_UNDERLINE + 'was', SPIECE_UNDERLINE + 'b', 'or', 'n', SPIECE_UNDERLINE + 'in',
-                                           SPIECE_UNDERLINE + '', '<unk>', '2', '0', '0', '0', ',',
-                                           SPIECE_UNDERLINE + 'and', SPIECE_UNDERLINE + 'this', SPIECE_UNDERLINE + 'is', SPIECE_UNDERLINE + 'f', 'al', 's',
-                                           '<unk>', '.'])
+        self.assertListEqual(back_tokens, [SPIECE_UNDERLINE + u'I', SPIECE_UNDERLINE + u'was', SPIECE_UNDERLINE + u'b',
+                                           u'or', u'n', SPIECE_UNDERLINE + u'in',
+                                           SPIECE_UNDERLINE + u'', u'<unk>', u'2', u'0', u'0', u'0', u',',
+                                           SPIECE_UNDERLINE + u'and', SPIECE_UNDERLINE + u'this',
+                                           SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f', u'al', u's',
+                                           u'<unk>', u'.'])
 
     @pytest.mark.slow
     def test_tokenizer_from_pretrained(self):
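A hedged sketch of the round trip the hunks above test, using only names visible in this diff (SAMPLE_VOCAB is the test fixture; XLNetTokenizer is imported by the test module):

    tokenizer = XLNetTokenizer(SAMPLE_VOCAB)
    tokens = tokenizer.tokenize(u"I was born in 92000, and this is falsé.")
    ids = tokenizer.convert_tokens_to_ids(tokens)
    # Pieces missing from the sample vocab (u'9', u'é') map to id 0, so
    # the round trip brings them back as u'<unk>'. return_unicode=True is
    # the new default and yields unicode pieces on Python 2 as well.
    back_tokens = tokenizer.convert_ids_to_tokens(ids, return_unicode=True)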
@@ -73,17 +76,19 @@ class XLNetTokenizationTest(unittest.TestCase):
 
     def test_tokenizer_lower(self):
         tokenizer = XLNetTokenizer(SAMPLE_VOCAB, do_lower_case=True)
         tokens = tokenizer.tokenize(u"I was born in 92000, and this is falsé.")
-        self.assertListEqual(tokens, [SPIECE_UNDERLINE + '', 'i', SPIECE_UNDERLINE + 'was', SPIECE_UNDERLINE + 'b', 'or', 'n', SPIECE_UNDERLINE + 'in', SPIECE_UNDERLINE + '',
-                                      '9', '2', '0', '0', '0', ',', SPIECE_UNDERLINE + 'and', SPIECE_UNDERLINE + 'this',
-                                      SPIECE_UNDERLINE + 'is', SPIECE_UNDERLINE + 'f', 'al', 'se', '.'])
-        self.assertListEqual(tokenizer.tokenize(u"H\u00E9llo"), ["▁he", "ll", "o"])
+        self.assertListEqual(tokens, [SPIECE_UNDERLINE + u'', u'i', SPIECE_UNDERLINE + u'was', SPIECE_UNDERLINE + u'b',
+                                      u'or', u'n', SPIECE_UNDERLINE + u'in', SPIECE_UNDERLINE + u'',
+                                      u'9', u'2', u'0', u'0', u'0', u',', SPIECE_UNDERLINE + u'and', SPIECE_UNDERLINE + u'this',
+                                      SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f', u'al', u'se', u'.'])
+        self.assertListEqual(tokenizer.tokenize(u"H\u00E9llo"), [u"▁he", u"ll", u"o"])
 
     def test_tokenizer_no_lower(self):
         tokenizer = XLNetTokenizer(SAMPLE_VOCAB, do_lower_case=False)
         tokens = tokenizer.tokenize(u"I was born in 92000, and this is falsé.")
-        self.assertListEqual(tokens, [SPIECE_UNDERLINE + 'I', SPIECE_UNDERLINE + 'was', SPIECE_UNDERLINE + 'b', 'or', 'n', SPIECE_UNDERLINE + 'in', SPIECE_UNDERLINE + '',
-                                      '9', '2', '0', '0', '0', ',', SPIECE_UNDERLINE + 'and', SPIECE_UNDERLINE + 'this',
-                                      SPIECE_UNDERLINE + 'is', SPIECE_UNDERLINE + 'f', 'al', 'se', '.'])
+        self.assertListEqual(tokens, [SPIECE_UNDERLINE + u'I', SPIECE_UNDERLINE + u'was', SPIECE_UNDERLINE + u'b', u'or',
+                                      u'n', SPIECE_UNDERLINE + u'in', SPIECE_UNDERLINE + u'',
+                                      u'9', u'2', u'0', u'0', u'0', u',', SPIECE_UNDERLINE + u'and', SPIECE_UNDERLINE + u'this',
+                                      SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f', u'al', u'se', u'.'])
 
 if __name__ == '__main__':
...
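A short sketch of the casing behavior the last two tests pin down, again assuming the test module's names:

    lower = XLNetTokenizer(SAMPLE_VOCAB, do_lower_case=True)
    # do_lower_case=True lowercases the input before sentencepiece runs:
    # u'I' becomes u'i' and u"H\u00E9llo" tokenizes to [u'▁he', u'll', u'o'].
    print(lower.tokenize(u"H\u00E9llo"))

    cased = XLNetTokenizer(SAMPLE_VOCAB, do_lower_case=False)
    # do_lower_case=False keeps u'I' intact, yet the test still expects
    # u'se' for u'falsé': accents appear to be stripped by default unless
    # the tokenizer is built with keep_accents=True, as in test_full_tokenizer.
    print(cased.tokenize(u"I was born in 92000, and this is falsé."))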