Unverified commit ce50305e, authored by Aymeric Augustin, committed by GitHub

Merge pull request #2270 from aaugustin/remove-python-2

Remove support for Python 2
parents b6ea0f43 1a948d70
@@ -15,11 +15,10 @@
 # limitations under the License.
 """ PyTorch XLNet model.
 """
-from __future__ import absolute_import, division, print_function, unicode_literals
 import logging
 import math
-import sys
 import torch
 from torch import nn
@@ -420,9 +419,7 @@ class XLNetFeedForward(nn.Module):
         self.layer_1 = nn.Linear(config.d_model, config.d_inner)
         self.layer_2 = nn.Linear(config.d_inner, config.d_model)
         self.dropout = nn.Dropout(config.dropout)
-        if isinstance(config.ff_activation, str) or (
-            sys.version_info[0] == 2 and isinstance(config.ff_activation, unicode)  # noqa: F821
-        ):
+        if isinstance(config.ff_activation, str):
             self.activation_function = ACT2FN[config.ff_activation]
         else:
             self.activation_function = config.ff_activation
......
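The second hunk above collapses a two-branch type check into a single `isinstance(config.ff_activation, str)` test, since Python 3 has no separate `unicode` type. A minimal sketch of the same dispatch, using an assumed stand-in `ACT2FN` table rather than the library's real one:

```python
import torch
import torch.nn.functional as F

ACT2FN = {"relu": F.relu, "gelu": F.gelu}  # assumed subset, for illustration only

def resolve_activation(ff_activation):
    # A config string is looked up in the table; a callable is used directly.
    if isinstance(ff_activation, str):
        return ACT2FN[ff_activation]
    return ff_activation

print(resolve_activation("relu") is F.relu)          # True
print(resolve_activation(torch.tanh) is torch.tanh)  # True
```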
@@ -14,7 +14,6 @@
 # ==============================================================================
 """Functions and classes related to optimization (weight updates)."""
-from __future__ import absolute_import, division, print_function
 import re
......
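The same `from __future__` removal recurs in every file below; on Python 3 those imports are no-ops because each feature became mandatory in 3.0. A quick standard-library check (nothing transformers-specific):

```python
import __future__

# Each removed feature reports a mandatory release of (3, 0, ...), i.e. it is
# always enabled on Python 3 and the import changes nothing.
for name in ("absolute_import", "division", "print_function", "unicode_literals"):
    print(name, getattr(__future__, name).getMandatoryRelease())
```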
@@ -12,7 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from __future__ import absolute_import, division, print_function, unicode_literals
 import csv
 import json
@@ -26,7 +26,6 @@ from os.path import abspath, exists
 from typing import Dict, List, Optional, Tuple, Union
 import numpy as np
-import six
 from .configuration_auto import ALL_PRETRAINED_CONFIG_ARCHIVE_MAP, AutoConfig
 from .configuration_utils import PretrainedConfig
@@ -939,7 +938,7 @@ def pipeline(
     modelcard = config
     # Instantiate tokenizer if needed
-    if isinstance(tokenizer, six.string_types):
+    if isinstance(tokenizer, str):
         tokenizer = AutoTokenizer.from_pretrained(tokenizer)
     # Instantiate config if needed
......
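In `pipeline()`, `six.string_types` becomes a plain `str` check. A self-contained sketch of the string-or-object pattern, where `load_tokenizer` is a hypothetical stand-in for `AutoTokenizer.from_pretrained`:

```python
def ensure_tokenizer(tokenizer, load_tokenizer):
    if isinstance(tokenizer, str):        # e.g. "bert-base-uncased": load by name
        return load_tokenizer(tokenizer)
    return tokenizer                      # already an object: pass through unchanged

print(ensure_tokenizer("bert-base-uncased", lambda name: f"<tokenizer for {name}>"))
```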
@@ -13,15 +13,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """ Tokenization classes for ALBERT model."""
-from __future__ import absolute_import, division, print_function, unicode_literals
 import logging
 import os
 import unicodedata
 from shutil import copyfile
-import six
 from .tokenization_utils import PreTrainedTokenizer
@@ -139,9 +137,6 @@ class AlbertTokenizer(PreTrainedTokenizer):
             outputs = inputs
         outputs = outputs.replace("``", '"').replace("''", '"')
-        if six.PY2 and isinstance(outputs, str):
-            outputs = outputs.decode("utf-8")
         if not self.keep_accents:
             outputs = unicodedata.normalize("NFKD", outputs)
             outputs = "".join([c for c in outputs if not unicodedata.combining(c)])
@@ -150,14 +145,9 @@ class AlbertTokenizer(PreTrainedTokenizer):
         return outputs
-    def _tokenize(self, text, return_unicode=True, sample=False):
-        """ Tokenize a string.
-            return_unicode is used only for py2
-        """
+    def _tokenize(self, text, sample=False):
+        """ Tokenize a string. """
         text = self.preprocess_text(text)
-        # note(zhiliny): in some systems, sentencepiece only accepts str for py2
-        if six.PY2 and isinstance(text, unicode):  # noqa: F821
-            text = text.encode("utf-8")
         if not sample:
             pieces = self.sp_model.EncodeAsPieces(text)
@@ -177,27 +167,15 @@ class AlbertTokenizer(PreTrainedTokenizer):
             else:
                 new_pieces.append(piece)
-        # note(zhiliny): convert back to unicode for py2
-        if six.PY2 and return_unicode:
-            ret_pieces = []
-            for piece in new_pieces:
-                if isinstance(piece, str):
-                    piece = piece.decode("utf-8")
-                ret_pieces.append(piece)
-            new_pieces = ret_pieces
         return new_pieces
     def _convert_token_to_id(self, token):
-        """ Converts a token (str/unicode) in an id using the vocab. """
+        """ Converts a token (str) in an id using the vocab. """
         return self.sp_model.PieceToId(token)
-    def _convert_id_to_token(self, index, return_unicode=True):
-        """Converts an index (integer) in a token (string/unicode) using the vocab."""
-        token = self.sp_model.IdToPiece(index)
-        if six.PY2 and return_unicode and isinstance(token, str):
-            token = token.decode("utf-8")
-        return token
+    def _convert_id_to_token(self, index):
+        """Converts an index (integer) in a token (str) using the vocab."""
+        return self.sp_model.IdToPiece(index)
     def convert_tokens_to_string(self, tokens):
         """Converts a sequence of tokens (strings for sub-words) in a single string."""
......
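With the `six.PY2` shims gone, `_tokenize` is just the SentencePiece call: on Python 3 the library accepts and returns `str` directly, so no encode/decode step is needed. A hedged sketch of that path, assuming a local SentencePiece model file at `spiece.model` (not part of this diff):

```python
import sentencepiece as spm

sp_model = spm.SentencePieceProcessor()
sp_model.Load("spiece.model")  # assumed path to an ALBERT-style SentencePiece model

def tokenize(text, sample=False):
    if not sample:
        return sp_model.EncodeAsPieces(text)             # deterministic segmentation
    return sp_model.SampleEncodeAsPieces(text, 64, 0.1)  # sampled segmentation

pieces = tokenize("Hello world")
ids = [sp_model.PieceToId(p) for p in pieces]  # pieces are already str on Python 3
print(pieces, ids)
```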
@@ -14,7 +14,6 @@
 # limitations under the License.
 """ Auto Model class. """
-from __future__ import absolute_import, division, print_function, unicode_literals
 import logging
......
@@ -14,13 +14,11 @@
 # limitations under the License.
 """Tokenization classes."""
-from __future__ import absolute_import, division, print_function, unicode_literals
 import collections
 import logging
 import os
 import unicodedata
-from io import open
 from .tokenization_utils import PreTrainedTokenizer
@@ -203,11 +201,11 @@ class BertTokenizer(PreTrainedTokenizer):
         return split_tokens
     def _convert_token_to_id(self, token):
-        """ Converts a token (str/unicode) in an id using the vocab. """
+        """ Converts a token (str) in an id using the vocab. """
         return self.vocab.get(token, self.vocab.get(self.unk_token))
     def _convert_id_to_token(self, index):
-        """Converts an index (integer) in a token (string/unicode) using the vocab."""
+        """Converts an index (integer) in a token (str) using the vocab."""
         return self.ids_to_tokens.get(index, self.unk_token)
     def convert_tokens_to_string(self, tokens):
......
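The BERT conversion methods are plain dictionary lookups with an `[UNK]` fallback. A toy round trip with a made-up vocabulary, just to show the `dict.get` pattern above:

```python
vocab = {"[UNK]": 0, "hello": 7, "world": 8}        # toy vocabulary, not BERT's
ids_to_tokens = {v: k for k, v in vocab.items()}
unk_token = "[UNK]"

def convert_token_to_id(token):
    return vocab.get(token, vocab.get(unk_token))   # unknown tokens map to [UNK]'s id

def convert_id_to_token(index):
    return ids_to_tokens.get(index, unk_token)      # unknown ids map back to [UNK]

print(convert_token_to_id("hello"), convert_token_to_id("quux"))  # 7 0
print(convert_id_to_token(8), convert_id_to_token(99))            # world [UNK]
```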
@@ -14,15 +14,12 @@
 # limitations under the License.
 """Tokenization classes."""
-from __future__ import absolute_import, division, print_function, unicode_literals
 import collections
 import logging
 import os
 import unicodedata
-import six
 from .tokenization_bert import BasicTokenizer, BertTokenizer, WordpieceTokenizer, load_vocab
@@ -195,10 +192,7 @@ class MecabTokenizer(object):
         never_split = self.never_split + (never_split if never_split is not None else [])
         tokens = []
-        if six.PY2:
-            mecab_output = self.mecab.parse(text.encode("utf-8")).decode("utf-8")
-        else:
-            mecab_output = self.mecab.parse(text)
+        mecab_output = self.mecab.parse(text)
         cursor = 0
         for line in mecab_output.split("\n"):
......
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License
 """ Tokenization classes for Camembert model."""
-from __future__ import absolute_import, division, print_function, unicode_literals
 import logging
 import os
@@ -155,7 +155,7 @@ class CamembertTokenizer(PreTrainedTokenizer):
         return self.sp_model.EncodeAsPieces(text)
     def _convert_token_to_id(self, token):
-        """ Converts a token (str/unicode) in an id using the vocab. """
+        """ Converts a token (str) in an id using the vocab. """
         if token in self.fairseq_tokens_to_ids:
             return self.fairseq_tokens_to_ids[token]
         elif self.sp_model.PieceToId(token) == 0:
@@ -164,7 +164,7 @@ class CamembertTokenizer(PreTrainedTokenizer):
         return self.fairseq_offset + self.sp_model.PieceToId(token)
     def _convert_id_to_token(self, index):
-        """Converts an index (integer) in a token (string/unicode) using the vocab."""
+        """Converts an index (integer) in a token (str) using the vocab."""
         if index in self.fairseq_ids_to_tokens:
             return self.fairseq_ids_to_tokens[index]
         return self.sp_model.IdToPiece(index - self.fairseq_offset)
......
@@ -13,12 +13,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Tokenization classes for Salesforce CTRL."""
-from __future__ import absolute_import, division, print_function, unicode_literals
 import json
 import logging
 import os
-from io import open
 import regex as re
@@ -204,11 +203,11 @@ class CTRLTokenizer(PreTrainedTokenizer):
         return split_tokens
     def _convert_token_to_id(self, token):
-        """ Converts a token (str/unicode) in an id using the vocab. """
+        """ Converts a token (str) in an id using the vocab. """
         return self.encoder.get(token, self.encoder.get(self.unk_token))
     def _convert_id_to_token(self, index):
-        """Converts an index (integer) in a token (string/unicode) using the vocab."""
+        """Converts an index (integer) in a token (str) using the vocab."""
         return self.decoder.get(index, self.unk_token)
     def convert_tokens_to_string(self, tokens):
......
@@ -14,7 +14,6 @@
 # limitations under the License.
 """Tokenization classes for DistilBERT."""
-from __future__ import absolute_import, division, print_function, unicode_literals
 import logging
......
@@ -13,28 +13,18 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Tokenization classes for OpenAI GPT."""
-from __future__ import absolute_import, division, print_function, unicode_literals
 import json
 import logging
 import os
-import sys
-from io import open
+from functools import lru_cache
 import regex as re
 from .tokenization_utils import PreTrainedTokenizer
-try:
-    from functools import lru_cache
-except ImportError:
-    # Just a dummy decorator to get the checks to run on python2
-    # because honestly I don't want to support a byte-level unicode BPE tokenizer on python 2 right now.
-    def lru_cache():
-        return lambda func: func
 logger = logging.getLogger(__name__)
 VOCAB_FILES_NAMES = {
@@ -80,7 +70,6 @@ def bytes_to_unicode():
    This is a signficant percentage of your normal, say, 32K bpe vocab.
    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
    """
-    _chr = unichr if sys.version_info[0] == 2 else chr  # noqa: F821
    bs = (
        list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
    )
@@ -91,7 +80,7 @@ def bytes_to_unicode():
            bs.append(b)
            cs.append(2 ** 8 + n)
            n += 1
-    cs = [_chr(n) for n in cs]
+    cs = [chr(n) for n in cs]
    return dict(zip(bs, cs))
@@ -212,23 +201,18 @@ class GPT2Tokenizer(PreTrainedTokenizer):
         bpe_tokens = []
         for token in re.findall(self.pat, text):
-            if sys.version_info[0] == 2:
-                token = "".join(
-                    self.byte_encoder[ord(b)] for b in token
-                )  # Maps all our bytes to unicode strings, avoiding controle tokens of the BPE (spaces in our case)
-            else:
-                token = "".join(
-                    self.byte_encoder[b] for b in token.encode("utf-8")
-                )  # Maps all our bytes to unicode strings, avoiding controle tokens of the BPE (spaces in our case)
+            token = "".join(
+                self.byte_encoder[b] for b in token.encode("utf-8")
+            )  # Maps all our bytes to unicode strings, avoiding controle tokens of the BPE (spaces in our case)
             bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(" "))
         return bpe_tokens
     def _convert_token_to_id(self, token):
-        """ Converts a token (str/unicode) in an id using the vocab. """
+        """ Converts a token (str) in an id using the vocab. """
         return self.encoder.get(token, self.encoder.get(self.unk_token))
     def _convert_id_to_token(self, index):
-        """Converts an index (integer) in a token (string/unicode) using the vocab."""
+        """Converts an index (integer) in a token (str) using the vocab."""
         return self.decoder.get(index)
     def convert_tokens_to_string(self, tokens):
......
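The GPT-2 hunks drop the `unichr`/`ord` branches: on Python 3, `chr` covers the full Unicode range and `str.encode("utf-8")` already yields ints. For context, a sketch of the byte-to-unicode table and the surviving encoding path, reconstructed from the hunks above (consult the file itself for the canonical code):

```python
from functools import lru_cache

@lru_cache()
def bytes_to_unicode():
    # Printable bytes map to themselves; the rest are shifted into unused code points.
    bs = (
        list(range(ord("!"), ord("~") + 1))
        + list(range(ord("¡"), ord("¬") + 1))
        + list(range(ord("®"), ord("ÿ") + 1))
    )
    cs = bs[:]
    n = 0
    for b in range(2 ** 8):
        if b not in bs:
            bs.append(b)
            cs.append(2 ** 8 + n)
            n += 1
    cs = [chr(n) for n in cs]
    return dict(zip(bs, cs))

byte_encoder = bytes_to_unicode()
token = " Héllo"
# One encode() call replaces the old per-character ord() loop from the PY2 branch.
mapped = "".join(byte_encoder[b] for b in token.encode("utf-8"))
print(mapped)  # every UTF-8 byte rendered as a printable unicode character
```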
@@ -13,13 +13,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Tokenization classes for OpenAI GPT."""
-from __future__ import absolute_import, division, print_function, unicode_literals
 import json
 import logging
 import os
 import re
-from io import open
 from .tokenization_bert import BasicTokenizer
 from .tokenization_utils import PreTrainedTokenizer
@@ -177,7 +176,7 @@ class OpenAIGPTTokenizer(PreTrainedTokenizer):
         return split_tokens
     def _convert_token_to_id(self, token):
-        """ Converts a token (str/unicode) in an id using the vocab. """
+        """ Converts a token (str) in an id using the vocab. """
         return self.encoder.get(token, self.encoder.get(self.unk_token))
     def _convert_id_to_token(self, index):
......
@@ -13,22 +13,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Tokenization classes for RoBERTa."""
-from __future__ import absolute_import, division, print_function, unicode_literals
 import logging
 from .tokenization_gpt2 import GPT2Tokenizer
-try:
-    from functools import lru_cache
-except ImportError:
-    # Just a dummy decorator to get the checks to run on python2
-    # because honestly I don't want to support a byte-level unicode BPE tokenizer on python 2 right now.
-    def lru_cache():
-        return lambda func: func
 logger = logging.getLogger(__name__)
 VOCAB_FILES_NAMES = {
......
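The RoBERTa tokenizer drops the same `lru_cache` fallback as GPT-2: `functools.lru_cache` has been in the standard library since Python 3.2, so no ImportError shim is needed. Minimal usage sketch:

```python
from functools import lru_cache

@lru_cache()                 # the decorated function is computed once, then cached
def expensive_table():
    return {b: chr(b) for b in range(2 ** 8)}

expensive_table()
expensive_table()
print(expensive_table.cache_info())  # hits=1, misses=1: second call came from the cache
```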
@@ -14,15 +14,12 @@
 # limitations under the License.
 """ Tokenization class for model T5."""
-from __future__ import absolute_import, division, print_function, unicode_literals
 import logging
 import os
 import re
 from shutil import copyfile
-import six
 from .tokenization_utils import PreTrainedTokenizer
@@ -138,41 +135,29 @@ class T5Tokenizer(PreTrainedTokenizer):
         self.sp_model = spm.SentencePieceProcessor()
         self.sp_model.Load(self.vocab_file)
-    def _tokenize(self, text, return_unicode=True, sample=False):
+    def _tokenize(self, text, sample=False):
         """ Take as input a string and return a list of strings (tokens) for words/sub-words
         """
         if not sample:
             pieces = self.sp_model.EncodeAsPieces(text)
         else:
             pieces = self.sp_model.SampleEncodeAsPieces(text, 64, 0.1)
-        # convert back to unicode for py2
-        if six.PY2 and return_unicode:
-            ret_pieces = []
-            for piece in pieces:
-                if isinstance(piece, str):
-                    piece = piece.decode("utf-8")
-                ret_pieces.append(piece)
-            pieces = ret_pieces
         return pieces
     def _convert_token_to_id(self, token):
-        """ Converts a token (str/unicode) in an id using the vocab. """
+        """ Converts a token (str) in an id using the vocab. """
         if token.startswith("<extra_id_"):
             match = re.match(r"<extra_id_(\d+)>", token)
             num = int(match.group(1))
             return self.vocab_size - num - 1
         return self.sp_model.piece_to_id(token)
-    def _convert_id_to_token(self, index, return_unicode=True):
-        """Converts an index (integer) in a token (string/unicode) using the vocab."""
+    def _convert_id_to_token(self, index):
+        """Converts an index (integer) in a token (str) using the vocab."""
         if index < self.sp_model.get_piece_size():
             token = self.sp_model.IdToPiece(index)
         else:
             token = "<extra_id_{}>".format(self.vocab_size - 1 - index)
-        if six.PY2 and return_unicode and isinstance(token, str):
-            token = token.decode("utf-8")
         return token
     def convert_tokens_to_string(self, tokens):
......
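The T5 sentinel handling survives unchanged apart from the `six` cleanup: `<extra_id_N>` tokens are mapped to the top of the vocabulary. A toy check of that arithmetic, with a made-up `vocab_size`:

```python
import re

vocab_size = 32100  # example value, not taken from this diff

def extra_id_to_index(token):
    match = re.match(r"<extra_id_(\d+)>", token)
    num = int(match.group(1))
    return vocab_size - num - 1          # sentinel 0 gets the highest id

def index_to_extra_id(index):
    return "<extra_id_{}>".format(vocab_size - 1 - index)

print(extra_id_to_index("<extra_id_0>"))   # 32099
print(index_to_extra_id(32099))            # <extra_id_0>
```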
@@ -16,14 +16,13 @@
 """ Tokenization classes for Transformer XL model.
     Adapted from https://github.com/kimiyoung/transformer-xl.
 """
-from __future__ import absolute_import, division, print_function, unicode_literals
 import glob
 import logging
 import os
-import sys
+import pickle
 from collections import Counter, OrderedDict
-from io import open
 import numpy as np
@@ -36,11 +35,6 @@ try:
 except ImportError:
     pass
-if sys.version_info[0] == 2:
-    import cPickle as pickle
-else:
-    import pickle
 logger = logging.getLogger(__name__)
@@ -238,7 +232,7 @@ class TransfoXLTokenizer(PreTrainedTokenizer):
         return self.idx2sym[idx]
     def _convert_token_to_id(self, sym):
-        """ Converts a token (str/unicode) in an id using the vocab. """
+        """ Converts a token (str) in an id using the vocab. """
         if sym in self.sym2idx:
             return self.sym2idx[sym]
         else:
......
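With Python 2 gone, the `cPickle` fallback collapses to a plain `import pickle`; the stdlib module handles the `Counter`/`OrderedDict` vocabulary objects this tokenizer works with. Toy round trip:

```python
import pickle
from collections import Counter, OrderedDict

counter = Counter(["the", "the", "cat"])
sym2idx = OrderedDict((sym, idx) for idx, (sym, _) in enumerate(counter.most_common()))

blob = pickle.dumps(sym2idx)           # bytes, suitable for a file opened in "wb" mode
print(pickle.loads(blob) == sym2idx)   # True
```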
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Tokenization classes for OpenAI GPT."""
-from __future__ import absolute_import, division, print_function, unicode_literals
 import copy
 import itertools
@@ -21,9 +21,6 @@ import json
 import logging
 import os
 import re
-from io import open
-import six
 from .file_utils import cached_path, hf_bucket_url, is_remote_url, is_tf_available, is_torch_available
@@ -251,11 +248,9 @@ class PreTrainedTokenizer(object):
         for key, value in kwargs.items():
             if key in self.SPECIAL_TOKENS_ATTRIBUTES:
                 if key == "additional_special_tokens":
-                    assert isinstance(value, (list, tuple)) and all(
-                        isinstance(t, str) or (six.PY2 and isinstance(t, unicode)) for t in value  # noqa: F821
-                    )
+                    assert isinstance(value, (list, tuple)) and all(isinstance(t, str) for t in value)
                 else:
-                    assert isinstance(value, str) or (six.PY2 and isinstance(value, unicode))  # noqa: F821
+                    assert isinstance(value, str)
                 setattr(self, key, value)
     @classmethod
@@ -567,7 +562,7 @@ class PreTrainedTokenizer(object):
         to_add_tokens = []
         for token in new_tokens:
-            assert isinstance(token, str) or (six.PY2 and isinstance(token, unicode))  # noqa: F821
+            assert isinstance(token, str)
             if self.init_kwargs.get("do_lower_case", False) and token not in self.all_special_tokens:
                 token = token.lower()
             if (
@@ -649,12 +644,10 @@ class PreTrainedTokenizer(object):
         for key, value in special_tokens_dict.items():
             assert key in self.SPECIAL_TOKENS_ATTRIBUTES
             if key == "additional_special_tokens":
-                assert isinstance(value, (list, tuple)) and all(
-                    isinstance(t, str) or (six.PY2 and isinstance(t, unicode)) for t in value  # noqa: F821
-                )
+                assert isinstance(value, (list, tuple)) and all(isinstance(t, str) for t in value)
                 added_tokens += self.add_tokens(value)
             else:
-                assert isinstance(value, str) or (six.PY2 and isinstance(value, unicode))  # noqa: F821
+                assert isinstance(value, str)
                 added_tokens += self.add_tokens([value])
             logger.info("Assigning %s to the %s key of the tokenizer", value, key)
             setattr(self, key, value)
@@ -740,13 +733,13 @@ class PreTrainedTokenizer(object):
         raise NotImplementedError
     def convert_tokens_to_ids(self, tokens):
-        """ Converts a single token, or a sequence of tokens, (str/unicode) in a single integer id
+        """ Converts a single token, or a sequence of tokens, (str) in a single integer id
             (resp. a sequence of ids), using the vocabulary.
         """
         if tokens is None:
             return None
-        if isinstance(tokens, str) or (six.PY2 and isinstance(tokens, unicode)):  # noqa: F821
+        if isinstance(tokens, str):
             return self._convert_token_to_id_with_added_voc(tokens)
         ids = []
@@ -901,9 +894,9 @@ class PreTrainedTokenizer(object):
         """
         def get_input_ids(text):
-            if isinstance(text, six.string_types):
+            if isinstance(text, str):
                 return self.convert_tokens_to_ids(self.tokenize(text, **kwargs))
-            elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], six.string_types):
+            elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], str):
                 return self.convert_tokens_to_ids(text)
             elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], int):
                 return text
@@ -1297,7 +1290,7 @@ class PreTrainedTokenizer(object):
     def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
         """ Converts a single index or a sequence of indices (integers) in a token "
-            (resp.) a sequence of tokens (str/unicode), using the vocabulary and added tokens.
+            (resp.) a sequence of tokens (str), using the vocabulary and added tokens.
             Args:
                 skip_special_tokens: Don't decode special tokens (self.all_special_tokens). Default: False
......
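The `get_input_ids` helper keeps its three-way dispatch, now on plain `str`: raw text is tokenized, a list of tokens is converted, and a list of ints passes through. A self-contained sketch with stand-in tokenize/convert functions (not the tokenizer's real ones):

```python
def get_input_ids(text, tokenize, convert_tokens_to_ids):
    if isinstance(text, str):
        return convert_tokens_to_ids(tokenize(text))
    elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], str):
        return convert_tokens_to_ids(text)
    elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], int):
        return text  # already ids
    raise ValueError("Input must be a string, a list/tuple of tokens, or a list/tuple of ids")

toy_vocab = {"hello": 1, "world": 2}
tokenize = lambda s: s.lower().split()
convert = lambda tokens: [toy_vocab.get(t, 0) for t in tokens]

print(get_input_ids("Hello world", tokenize, convert))       # [1, 2]
print(get_input_ids(["hello", "world"], tokenize, convert))  # [1, 2]
print(get_input_ids([1, 2], tokenize, convert))              # [1, 2]
```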
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Tokenization classes for XLM."""
-from __future__ import absolute_import, division, print_function, unicode_literals
 import json
 import logging
@@ -21,7 +21,6 @@ import os
 import re
 import sys
 import unicodedata
-from io import open
 import sacremoses as sm
@@ -798,11 +797,11 @@ class XLMTokenizer(PreTrainedTokenizer):
         return split_tokens
     def _convert_token_to_id(self, token):
-        """ Converts a token (str/unicode) in an id using the vocab. """
+        """ Converts a token (str) in an id using the vocab. """
         return self.encoder.get(token, self.encoder.get(self.unk_token))
     def _convert_id_to_token(self, index):
-        """Converts an index (integer) in a token (string/unicode) using the vocab."""
+        """Converts an index (integer) in a token (str) using the vocab."""
         return self.decoder.get(index, self.unk_token)
     def convert_tokens_to_string(self, tokens):
......
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License
 """ Tokenization classes for XLM-RoBERTa model."""
-from __future__ import absolute_import, division, print_function, unicode_literals
 import logging
 import os
@@ -171,13 +171,13 @@ class XLMRobertaTokenizer(PreTrainedTokenizer):
         return self.sp_model.EncodeAsPieces(text)
     def _convert_token_to_id(self, token):
-        """ Converts a token (str/unicode) in an id using the vocab. """
+        """ Converts a token (str) in an id using the vocab. """
         if token in self.fairseq_tokens_to_ids:
             return self.fairseq_tokens_to_ids[token]
         return self.sp_model.PieceToId(token) + self.fairseq_offset
     def _convert_id_to_token(self, index):
-        """Converts an index (integer) in a token (string/unicode) using the vocab."""
+        """Converts an index (integer) in a token (str) using the vocab."""
         if index in self.fairseq_ids_to_tokens:
             return self.fairseq_ids_to_tokens[index]
         return self.sp_model.IdToPiece(index - self.fairseq_offset)
......
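The XLM-RoBERTa mapping reserves the lowest ids for fairseq special tokens and shifts every SentencePiece id by `fairseq_offset` so the two ranges do not collide. A toy illustration of that arithmetic; every value below is made up for the example:

```python
fairseq_tokens_to_ids = {"<s>": 0, "<pad>": 1, "</s>": 2, "<unk>": 3}
fairseq_ids_to_tokens = {v: k for k, v in fairseq_tokens_to_ids.items()}
fairseq_offset = len(fairseq_tokens_to_ids)          # 4 in this toy setup
sp_piece_to_id = {"▁Hello": 0, "▁world": 1}          # stand-in for sp_model.PieceToId
sp_id_to_piece = {v: k for k, v in sp_piece_to_id.items()}

def convert_token_to_id(token):
    if token in fairseq_tokens_to_ids:
        return fairseq_tokens_to_ids[token]
    return sp_piece_to_id[token] + fairseq_offset

def convert_id_to_token(index):
    if index in fairseq_ids_to_tokens:
        return fairseq_ids_to_tokens[index]
    return sp_id_to_piece[index - fairseq_offset]

print(convert_token_to_id("▁world"))    # 5
print(convert_id_to_token(5))           # ▁world
print(convert_id_to_token(2))           # </s>
```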
@@ -13,15 +13,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """ Tokenization classes for XLNet model."""
-from __future__ import absolute_import, division, print_function, unicode_literals
 import logging
 import os
 import unicodedata
 from shutil import copyfile
-import six
 from .tokenization_utils import PreTrainedTokenizer
@@ -139,9 +137,6 @@ class XLNetTokenizer(PreTrainedTokenizer):
             outputs = inputs
         outputs = outputs.replace("``", '"').replace("''", '"')
-        if six.PY2 and isinstance(outputs, str):
-            outputs = outputs.decode("utf-8")
         if not self.keep_accents:
             outputs = unicodedata.normalize("NFKD", outputs)
             outputs = "".join([c for c in outputs if not unicodedata.combining(c)])
@@ -150,14 +145,9 @@ class XLNetTokenizer(PreTrainedTokenizer):
         return outputs
-    def _tokenize(self, text, return_unicode=True, sample=False):
-        """ Tokenize a string.
-            return_unicode is used only for py2
-        """
+    def _tokenize(self, text, sample=False):
+        """ Tokenize a string. """
         text = self.preprocess_text(text)
-        # note(zhiliny): in some systems, sentencepiece only accepts str for py2
-        if six.PY2 and isinstance(text, unicode):  # noqa: F821
-            text = text.encode("utf-8")
         if not sample:
             pieces = self.sp_model.EncodeAsPieces(text)
@@ -177,27 +167,15 @@ class XLNetTokenizer(PreTrainedTokenizer):
             else:
                 new_pieces.append(piece)
-        # note(zhiliny): convert back to unicode for py2
-        if six.PY2 and return_unicode:
-            ret_pieces = []
-            for piece in new_pieces:
-                if isinstance(piece, str):
-                    piece = piece.decode("utf-8")
-                ret_pieces.append(piece)
-            new_pieces = ret_pieces
         return new_pieces
     def _convert_token_to_id(self, token):
-        """ Converts a token (str/unicode) in an id using the vocab. """
+        """ Converts a token (str) in an id using the vocab. """
        return self.sp_model.PieceToId(token)
-    def _convert_id_to_token(self, index, return_unicode=True):
-        """Converts an index (integer) in a token (string/unicode) using the vocab."""
-        token = self.sp_model.IdToPiece(index)
-        if six.PY2 and return_unicode and isinstance(token, str):
-            token = token.decode("utf-8")
-        return token
+    def _convert_id_to_token(self, index):
+        """Converts an index (integer) in a token (str) using the vocab."""
+        return self.sp_model.IdToPiece(index)
     def convert_tokens_to_string(self, tokens):
         """Converts a sequence of tokens (strings for sub-words) in a single string."""
......
@@ -14,7 +14,6 @@
 # limitations under the License.
 """ Finetuning the library models for task XXX."""
-from __future__ import absolute_import, division, print_function
 import argparse
 import glob
@@ -156,7 +155,7 @@ def train(args, train_dataset, model, tokenizer):
     tr_loss, logging_loss = 0.0, 0.0
     model.zero_grad()
     train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0])
-    set_seed(args)  # Added here for reproductibility (even between python 2 and 3)
+    set_seed(args)  # Added here for reproductibility
     for _ in train_iterator:
         epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
         for step, batch in enumerate(epoch_iterator):
......
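The final hunk only trims the comment: seeding still happens once per training run, right before the epoch loop. A hedged sketch of a typical `set_seed` helper for such a script (the example's real helper may differ in detail):

```python
import random

import numpy as np
import torch

def set_seed(seed, n_gpu=0):
    random.seed(seed)              # Python's own RNG
    np.random.seed(seed)           # NumPy (data shuffling, sampling)
    torch.manual_seed(seed)        # CPU and default CUDA generator
    if n_gpu > 0:
        torch.cuda.manual_seed_all(seed)  # all visible GPUs

set_seed(42)
```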