Commit 8fda532c authored by thomwolf

fix python 2 sentencepiece tokenization

parent ba10065c
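The core of the fix: under Python 2, SentencePiece's EncodeAsPieces and IdToPiece return byte strings (str) rather than unicode, so the tokenizer now decodes each piece with UTF-8 before handing it back. A minimal standalone sketch of that decode step (the helper name to_unicode is illustrative, not part of this commit):

import six

def to_unicode(pieces):
    # Illustrative helper: on Python 2 SentencePiece hands back byte
    # strings (str), so decode each piece to unicode; on Python 3 the
    # pieces are already unicode (str) and pass through unchanged.
    if not six.PY2:
        return pieces
    return [p.decode('utf-8') if isinstance(p, str) else p for p in pieces]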
@@ -18,7 +18,8 @@ import os
 import unittest
 import pytest

-from transformers.tokenization_t5 import (T5Tokenizer, SPIECE_UNDERLINE)
+from transformers.tokenization_t5 import (T5Tokenizer)
+from transformers.tokenization_xlnet import SPIECE_UNDERLINE

 from .tokenization_tests_commons import CommonTestCases
@@ -33,7 +34,7 @@ class T5TokenizationTest(CommonTestCases.CommonTokenizerTester):
         super(T5TokenizationTest, self).setUp()

         # We have a SentencePiece fixture for testing
-        tokenizer = T5Tokenizer(SAMPLE_VOCAB, keep_accents=True)
+        tokenizer = T5Tokenizer(SAMPLE_VOCAB)
         tokenizer.save_pretrained(self.tmpdirname)

     def get_tokenizer(self, **kwargs):
@@ -45,7 +46,7 @@ class T5TokenizationTest(CommonTestCases.CommonTokenizerTester):
        return input_text, output_text

    def test_full_tokenizer(self):
-        tokenizer = T5Tokenizer(SAMPLE_VOCAB, keep_accents=True)
+        tokenizer = T5Tokenizer(SAMPLE_VOCAB)

        tokens = tokenizer.tokenize(u'This is a test')
        self.assertListEqual(tokens, [u'▁This', u'▁is', u'▁a', u'▁t', u'est'])
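An aside on the assertion above: SPIECE_UNDERLINE is the u'▁' (U+2581) marker SentencePiece prefixes to word-initial pieces, which is why the tokenizer splits 'test' into '▁t' + 'est'. A small self-contained illustration (values copied from the test, not a new API):

SPIECE_UNDERLINE = u'\u2581'  # same constant the tests import
tokens = [u'\u2581This', u'\u2581is', u'\u2581a', u'\u2581t', u'est']
assert tokens[3].startswith(SPIECE_UNDERLINE)      # '▁t' starts the word 'test'
assert not tokens[4].startswith(SPIECE_UNDERLINE)  # 'est' continues it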
@@ -18,6 +18,7 @@ from __future__ import absolute_import, division, print_function, unicode_literals

 import logging
 import os
+import six
 from shutil import copyfile

 from .tokenization_utils import PreTrainedTokenizer
@@ -96,18 +97,35 @@ class T5Tokenizer(PreTrainedTokenizer):
         self.sp_model = spm.SentencePieceProcessor()
         self.sp_model.Load(self.vocab_file)

-    def _tokenize(self, text):
+    def _tokenize(self, text, return_unicode=True, sample=False):
         """ Take as input a string and return a list of strings (tokens) for words/sub-words
         """
-        return self.sp_model.EncodeAsPieces(text)
+        if not sample:
+            pieces = self.sp_model.EncodeAsPieces(text)
+        else:
+            pieces = self.sp_model.SampleEncodeAsPieces(text, 64, 0.1)
+
+        # convert back to unicode for py2
+        if six.PY2 and return_unicode:
+            ret_pieces = []
+            for piece in pieces:
+                if isinstance(piece, str):
+                    piece = piece.decode('utf-8')
+                ret_pieces.append(piece)
+            pieces = ret_pieces
+        return pieces

     def _convert_token_to_id(self, token):
         """ Converts a token (str/unicode) in an id using the vocab. """
         return self.sp_model.piece_to_id(token)

-    def _convert_id_to_token(self, index):
+    def _convert_id_to_token(self, index, return_unicode=True):
         """Converts an index (integer) in a token (string/unicode) using the vocab."""
-        return self.sp_model.id_to_piece(index)
+        token = self.sp_model.IdToPiece(index)
+        if six.PY2 and return_unicode and isinstance(token, str):
+            token = token.decode('utf-8')
+        return token

     def convert_tokens_to_string(self, tokens):
         """ Converts a sequence of tokens (string) in a single string. """
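For context, SampleEncodeAsPieces(text, 64, 0.1) is SentencePiece's subword-regularization sampling: it draws one segmentation from the 64-best list with smoothing parameter alpha=0.1, so repeated calls may tokenize the same string differently. A hedged usage sketch of the new sample flag (the model path is an assumption; any trained SentencePiece model file works):

from transformers.tokenization_t5 import T5Tokenizer

tokenizer = T5Tokenizer('spiece.model')  # assumed path to a SentencePiece model

# Deterministic best-path segmentation
print(tokenizer._tokenize(u'This is a test'))

# Stochastic segmentation: output may vary from call to call
print(tokenizer._tokenize(u'This is a test', sample=True))
print(tokenizer._tokenize(u'This is a test', sample=True))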