Commit 798b3b38 authored by Aymeric Augustin

Remove sys.version_info[0] == 2 or 3.

parent 8af25b16
@@ -24,7 +24,6 @@ import glob
 import logging
 import os
 import random
-import sys

 import numpy as np
 import torch
@@ -104,12 +103,7 @@ class InputFeatures(object):

 def read_swag_examples(input_file, is_training=True):
     with open(input_file, "r", encoding="utf-8") as f:
-        reader = csv.reader(f)
-        lines = []
-        for line in reader:
-            if sys.version_info[0] == 2:
-                line = list(unicode(cell, "utf-8") for cell in line)  # noqa: F821
-            lines.append(line)
+        lines = list(csv.reader(f))

     if is_training and lines[0][-1] != "label":
         raise ValueError("For training, the input file must contain a label column.")
...
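Note (illustration, not part of the commit): the simplification above works because on Python 3 the open(..., encoding="utf-8") call already decodes the file, so csv.reader yields rows of str and the per-cell unicode() conversion was dead code. A minimal, self-contained sketch of the equivalence, with made-up sample data:

    import csv
    import io

    # csv.reader over an already-decoded text stream yields lists of str.
    data = io.StringIO("video-id,fold-ind,label\nabc,1,café\n")
    lines = list(csv.reader(data))

    assert lines[0][-1] == "label"
    assert isinstance(lines[1][2], str) and lines[1][2] == "café"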
@@ -21,7 +21,6 @@ import glob
 import json
 import logging
 import os
-import sys
 from io import open
 from typing import List
@@ -179,13 +178,7 @@ class SwagProcessor(DataProcessor):

     def _read_csv(self, input_file):
         with open(input_file, "r", encoding="utf-8") as f:
-            reader = csv.reader(f)
-            lines = []
-            for line in reader:
-                if sys.version_info[0] == 2:
-                    line = list(unicode(cell, "utf-8") for cell in line)  # noqa: F821
-                lines.append(line)
-            return lines
+            return list(csv.reader(f))

     def _create_examples(self, lines: List[List[str]], type: str):
         """Creates examples for the training and dev sets."""
...
@@ -18,6 +18,7 @@
 import argparse
 import logging
 import os
+import pickle
 import sys
 from io import open
@@ -34,12 +35,6 @@ from transformers import (
 from transformers.tokenization_transfo_xl import CORPUS_NAME, VOCAB_FILES_NAMES

-if sys.version_info[0] == 2:
-    import cPickle as pickle
-else:
-    import pickle
-
 logging.basicConfig(level=logging.INFO)

 # We do this to be able to load python 2 datasets pickles
...
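Note (illustration, not part of the commit): collapsing the conditional import is safe because Python 3's pickle transparently uses the C accelerator module, the successor of Python 2's cPickle, whenever it is available. For the "python 2 datasets pickles" the comment above refers to, loading under Python 3 typically needs an explicit encoding; a hedged sketch, with a hypothetical file name:

    import pickle

    # Python 2 pickles store byte strings; "latin1" decodes every byte
    # value 0-255 losslessly, the usual choice for such legacy files.
    with open("corpus-cached.pkl", "rb") as f:
        data = pickle.load(f, encoding="latin1")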
@@ -18,7 +18,6 @@ import copy
 import csv
 import json
 import logging
-import sys

 from ...file_utils import is_tf_available, is_torch_available
@@ -98,13 +97,7 @@ class DataProcessor(object):
     def _read_tsv(cls, input_file, quotechar=None):
         """Reads a tab separated value file."""
         with open(input_file, "r", encoding="utf-8-sig") as f:
-            reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
-            lines = []
-            for line in reader:
-                if sys.version_info[0] == 2:
-                    line = list(unicode(cell, "utf-8") for cell in line)  # noqa: F821
-                lines.append(line)
-            return lines
+            return list(csv.reader(f, delimiter="\t", quotechar=quotechar))


 class SingleSentenceClassificationProcessor(DataProcessor):
...
@@ -166,7 +166,7 @@ def filename_to_url(filename, cache_dir=None):
     """
     if cache_dir is None:
         cache_dir = TRANSFORMERS_CACHE
-    if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
+    if isinstance(cache_dir, Path):
         cache_dir = str(cache_dir)

     cache_path = os.path.join(cache_dir, filename)
@@ -201,9 +201,9 @@ def cached_path(
     """
     if cache_dir is None:
         cache_dir = TRANSFORMERS_CACHE
-    if sys.version_info[0] == 3 and isinstance(url_or_filename, Path):
+    if isinstance(url_or_filename, Path):
         url_or_filename = str(url_or_filename)
-    if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
+    if isinstance(cache_dir, Path):
         cache_dir = str(cache_dir)

     if is_remote_url(url_or_filename):
@@ -314,9 +314,7 @@ def get_from_cache(
     """
     if cache_dir is None:
         cache_dir = TRANSFORMERS_CACHE
-    if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
-        cache_dir = str(cache_dir)
-    if sys.version_info[0] == 2 and not isinstance(cache_dir, str):
+    if isinstance(cache_dir, Path):
         cache_dir = str(cache_dir)

     if not os.path.exists(cache_dir):
@@ -335,8 +333,6 @@ def get_from_cache(
         except (EnvironmentError, requests.exceptions.Timeout):
             etag = None

-    if sys.version_info[0] == 2 and etag is not None:
-        etag = etag.decode("utf-8")
     filename = url_to_filename(url, etag)

     # get cache path to put the file
@@ -400,9 +396,6 @@ def get_from_cache(
         meta = {"url": url, "etag": etag}
         meta_path = cache_path + ".json"
         with open(meta_path, "w") as meta_file:
-            output_string = json.dumps(meta)
-            if sys.version_info[0] == 2 and isinstance(output_string, str):
-                output_string = unicode(output_string, "utf-8")  # noqa: F821
-            meta_file.write(output_string)
+            json.dump(meta, meta_file)

     return cache_path
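Note (illustration, not part of the commit): dropping the sys.version_info[0] == 3 guard leaves only the isinstance(..., Path) check, which is all that ever mattered, since Python 2 code cannot pass a pathlib.Path in the first place. An alternative worth knowing, an assumption about style rather than what the commit does, is os.fspath, which normalizes str, Path, and any os.PathLike in one call:

    import os
    from pathlib import Path

    def to_str_path(p):
        # os.fspath returns str unchanged and converts Path to str.
        return os.fspath(p)

    assert to_str_path(Path("cache")) == "cache"
    assert to_str_path("cache") == "cache"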
@@ -19,7 +19,6 @@
 import logging
 import math
 import os
-import sys

 import torch
 from torch import nn
@@ -338,9 +337,7 @@ class BertIntermediate(nn.Module):
     def __init__(self, config):
         super(BertIntermediate, self).__init__()
         self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
-        if isinstance(config.hidden_act, str) or (
-            sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)  # noqa: F821
-        ):
+        if isinstance(config.hidden_act, str):
             self.intermediate_act_fn = ACT2FN[config.hidden_act]
         else:
             self.intermediate_act_fn = config.hidden_act
@@ -460,9 +457,7 @@ class BertPredictionHeadTransform(nn.Module):
     def __init__(self, config):
         super(BertPredictionHeadTransform, self).__init__()
         self.dense = nn.Linear(config.hidden_size, config.hidden_size)
-        if isinstance(config.hidden_act, str) or (
-            sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)  # noqa: F821
-        ):
+        if isinstance(config.hidden_act, str):
             self.transform_act_fn = ACT2FN[config.hidden_act]
         else:
             self.transform_act_fn = config.hidden_act
...
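Note (illustration, not part of the commit): the branch simplified here, and in the TF model files below, implements a small dispatch: config.hidden_act may be either the name of an activation (looked up in ACT2FN) or a callable used directly, and without Python 2's unicode type a single isinstance(..., str) covers the string case. A self-contained sketch of the pattern, with ACT2FN reduced to one illustrative entry:

    import math

    # Illustrative one-entry table; the real ACT2FN maps several names.
    ACT2FN = {"gelu": lambda x: 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))}

    def resolve_activation(hidden_act):
        # Strings name a table entry; callables pass through unchanged.
        return ACT2FN[hidden_act] if isinstance(hidden_act, str) else hidden_act

    assert resolve_activation("gelu")(0.0) == 0.0
    assert resolve_activation(abs) is abs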
@@ -17,7 +17,6 @@

 import logging
-import sys

 import tensorflow as tf
@@ -311,9 +310,7 @@ class TFAlbertLayer(tf.keras.layers.Layer):
             config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="ffn"
         )
-        if isinstance(config.hidden_act, str) or (
-            sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)  # noqa: F821
-        ):
+        if isinstance(config.hidden_act, str):
             self.activation = ACT2FN[config.hidden_act]
         else:
             self.activation = config.hidden_act
@@ -454,9 +451,7 @@ class TFAlbertMLMHead(tf.keras.layers.Layer):
         self.dense = tf.keras.layers.Dense(
             config.embedding_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
         )
-        if isinstance(config.hidden_act, str) or (
-            sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)  # noqa: F821
-        ):
+        if isinstance(config.hidden_act, str):
             self.activation = ACT2FN[config.hidden_act]
         else:
             self.activation = config.hidden_act
...
@@ -17,7 +17,6 @@

 import logging
-import sys

 import numpy as np
 import tensorflow as tf
@@ -310,9 +309,7 @@ class TFBertIntermediate(tf.keras.layers.Layer):
         self.dense = tf.keras.layers.Dense(
             config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
         )
-        if isinstance(config.hidden_act, str) or (
-            sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)  # noqa: F821
-        ):
+        if isinstance(config.hidden_act, str):
             self.intermediate_act_fn = ACT2FN[config.hidden_act]
         else:
             self.intermediate_act_fn = config.hidden_act
@@ -417,9 +414,7 @@ class TFBertPredictionHeadTransform(tf.keras.layers.Layer):
         self.dense = tf.keras.layers.Dense(
             config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
         )
-        if isinstance(config.hidden_act, str) or (
-            sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)  # noqa: F821
-        ):
+        if isinstance(config.hidden_act, str):
             self.transform_act_fn = ACT2FN[config.hidden_act]
         else:
             self.transform_act_fn = config.hidden_act
...
@@ -18,7 +18,6 @@

 import logging
-import sys

 import numpy as np
 import tensorflow as tf
@@ -290,9 +289,7 @@ class TFXLNetFeedForward(tf.keras.layers.Layer):
             config.d_model, kernel_initializer=get_initializer(config.initializer_range), name="layer_2"
         )
         self.dropout = tf.keras.layers.Dropout(config.dropout)
-        if isinstance(config.ff_activation, str) or (
-            sys.version_info[0] == 2 and isinstance(config.ff_activation, unicode)  # noqa: F821
-        ):
+        if isinstance(config.ff_activation, str):
             self.activation_function = ACT2FN[config.ff_activation]
         else:
             self.activation_function = config.ff_activation
...
@@ -19,7 +19,6 @@

 import logging
 import math
-import sys

 import torch
 from torch import nn
@@ -420,9 +419,7 @@ class XLNetFeedForward(nn.Module):
         self.layer_1 = nn.Linear(config.d_model, config.d_inner)
         self.layer_2 = nn.Linear(config.d_inner, config.d_model)
         self.dropout = nn.Dropout(config.dropout)
-        if isinstance(config.ff_activation, str) or (
-            sys.version_info[0] == 2 and isinstance(config.ff_activation, unicode)  # noqa: F821
-        ):
+        if isinstance(config.ff_activation, str):
             self.activation_function = ACT2FN[config.ff_activation]
         else:
             self.activation_function = config.ff_activation
...
@@ -18,7 +18,6 @@
 import json
 import logging
 import os
-import sys
 from io import open

 import regex as re
@@ -80,7 +79,6 @@ def bytes_to_unicode():
     This is a signficant percentage of your normal, say, 32K bpe vocab.
     To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
     """
-    _chr = unichr if sys.version_info[0] == 2 else chr  # noqa: F821
     bs = (
         list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
     )
@@ -91,7 +89,7 @@ def bytes_to_unicode():
             bs.append(b)
             cs.append(2 ** 8 + n)
             n += 1
-    cs = [_chr(n) for n in cs]
+    cs = [chr(n) for n in cs]
     return dict(zip(bs, cs))
@@ -212,14 +210,9 @@ class GPT2Tokenizer(PreTrainedTokenizer):
         bpe_tokens = []
         for token in re.findall(self.pat, text):
-            if sys.version_info[0] == 2:
-                token = "".join(
-                    self.byte_encoder[ord(b)] for b in token
-                )  # Maps all our bytes to unicode strings, avoiding controle tokens of the BPE (spaces in our case)
-            else:
-                token = "".join(
-                    self.byte_encoder[b] for b in token.encode("utf-8")
-                )  # Maps all our bytes to unicode strings, avoiding controle tokens of the BPE (spaces in our case)
+            token = "".join(
+                self.byte_encoder[b] for b in token.encode("utf-8")
+            )  # Maps all our bytes to unicode strings, avoiding controle tokens of the BPE (spaces in our case)
             bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(" "))
         return bpe_tokens
...
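Note (illustration, not part of the commit): with the unichr/ord branch gone, the byte-level BPE path is uniform. Encoding the token to UTF-8 yields ints when iterated in Python 3, and each int is mapped through byte_encoder. A tiny sketch with a stand-in encoder, the real table coming from bytes_to_unicode above:

    # Stand-in byte -> str table: printable ASCII maps to itself, every
    # other byte to a code point above 255, mirroring bytes_to_unicode().
    byte_encoder = {}
    offset = 0
    for b in range(256):
        if 33 <= b <= 126:
            byte_encoder[b] = chr(b)
        else:
            byte_encoder[b] = chr(256 + offset)
            offset += 1

    token = "café"
    mapped = "".join(byte_encoder[b] for b in token.encode("utf-8"))
    assert len(mapped) == len(token.encode("utf-8"))  # one symbol per byte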
@@ -21,7 +21,7 @@
 import glob
 import logging
 import os
-import sys
+import pickle
 from collections import Counter, OrderedDict
 from io import open
@@ -36,11 +36,6 @@ try:
 except ImportError:
     pass

-if sys.version_info[0] == 2:
-    import cPickle as pickle
-else:
-    import pickle
-
 logger = logging.getLogger(__name__)
...
@@ -16,8 +16,7 @@
 import json
 import os
+import tempfile

-from .test_tokenization_common import TemporaryDirectory

 class ConfigTester(object):
@@ -42,7 +41,7 @@ class ConfigTester(object):
     def create_and_test_config_to_json_file(self):
         config_first = self.config_class(**self.inputs_dict)

-        with TemporaryDirectory() as tmpdirname:
+        with tempfile.TemporaryDirectory() as tmpdirname:
             json_file_path = os.path.join(tmpdirname, "config.json")
             config_first.to_json_file(json_file_path)
             config_second = self.config_class.from_json_file(json_file_path)
@@ -52,7 +51,7 @@ class ConfigTester(object):
     def create_and_test_config_from_and_save_pretrained(self):
         config_first = self.config_class(**self.inputs_dict)

-        with TemporaryDirectory() as tmpdirname:
+        with tempfile.TemporaryDirectory() as tmpdirname:
             config_first.save_pretrained(tmpdirname)
             config_second = self.config_class.from_pretrained(tmpdirname)
...
@@ -16,12 +16,11 @@
 import json
 import os
+import tempfile
 import unittest

 from transformers.modelcard import ModelCard

-from .test_tokenization_common import TemporaryDirectory

 class ModelCardTester(unittest.TestCase):
     def setUp(self):
@@ -65,7 +64,7 @@ class ModelCardTester(unittest.TestCase):
     def test_model_card_to_json_file(self):
         model_card_first = ModelCard.from_dict(self.inputs_dict)

-        with TemporaryDirectory() as tmpdirname:
+        with tempfile.TemporaryDirectory() as tmpdirname:
             filename = os.path.join(tmpdirname, "modelcard.json")
             model_card_first.to_json_file(filename)
             model_card_second = ModelCard.from_json_file(filename)
@@ -75,7 +74,7 @@ class ModelCardTester(unittest.TestCase):
     def test_model_card_from_and_save_pretrained(self):
         model_card_first = ModelCard.from_dict(self.inputs_dict)

-        with TemporaryDirectory() as tmpdirname:
+        with tempfile.TemporaryDirectory() as tmpdirname:
             model_card_first.save_pretrained(tmpdirname)
             model_card_second = ModelCard.from_pretrained(tmpdirname)
...
@@ -19,8 +19,6 @@ import json
 import logging
 import os.path
 import random
-import shutil
-import sys
 import tempfile
 import unittest
 import uuid
@@ -43,23 +41,6 @@ if is_torch_available():
         BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
     )

-if sys.version_info[0] == 2:
-
-    class TemporaryDirectory(object):
-        """Context manager for tempfile.mkdtemp() so it's usable with "with" statement."""
-
-        def __enter__(self):
-            self.name = tempfile.mkdtemp()
-            return self.name
-
-        def __exit__(self, exc_type, exc_value, traceback):
-            shutil.rmtree(self.name)
-
-
-else:
-    TemporaryDirectory = tempfile.TemporaryDirectory
-    unicode = str
-

 def _config_zero_init(config):
     configs_no_init = copy.deepcopy(config)
@@ -92,7 +73,7 @@ class ModelTesterMixin:
             out_2 = outputs[0].numpy()
             out_2[np.isnan(out_2)] = 0

-            with TemporaryDirectory() as tmpdirname:
+            with tempfile.TemporaryDirectory() as tmpdirname:
                 model.save_pretrained(tmpdirname)
                 model = model_class.from_pretrained(tmpdirname)
                 model.to(torch_device)
@@ -238,7 +219,7 @@ class ModelTesterMixin:
             except RuntimeError:
                 self.fail("Couldn't trace module.")

-            with TemporaryDirectory() as tmp_dir_name:
+            with tempfile.TemporaryDirectory() as tmp_dir_name:
                 pt_file_name = os.path.join(tmp_dir_name, "traced_model.pt")

                 try:
@@ -366,7 +347,7 @@ class ModelTesterMixin:
             heads_to_prune = {0: list(range(1, self.model_tester.num_attention_heads)), -1: [0]}
             model.prune_heads(heads_to_prune)

-            with TemporaryDirectory() as temp_dir_name:
+            with tempfile.TemporaryDirectory() as temp_dir_name:
                 model.save_pretrained(temp_dir_name)
                 model = model_class.from_pretrained(temp_dir_name)
                 model.to(torch_device)
@@ -435,7 +416,7 @@ class ModelTesterMixin:
             self.assertEqual(attentions[2].shape[-3], self.model_tester.num_attention_heads)
             self.assertEqual(attentions[3].shape[-3], self.model_tester.num_attention_heads)

-            with TemporaryDirectory() as temp_dir_name:
+            with tempfile.TemporaryDirectory() as temp_dir_name:
                 model.save_pretrained(temp_dir_name)
                 model = model_class.from_pretrained(temp_dir_name)
                 model.to(torch_device)
...
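Note (illustration, not part of the commit): the deleted fallback class was a backport of tempfile.TemporaryDirectory, which has shipped with the standard library since Python 3.2, so the tests can use it directly. Minimal usage sketch:

    import os
    import tempfile

    # The directory and everything in it are deleted when the block
    # exits, even if an exception was raised inside it.
    with tempfile.TemporaryDirectory() as tmpdirname:
        path = os.path.join(tmpdirname, "config.json")
        with open(path, "w") as f:
            f.write("{}")
        assert os.path.exists(path)

    assert not os.path.exists(tmpdirname)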
@@ -17,8 +17,6 @@
 import copy
 import os
 import random
-import shutil
-import sys
 import tempfile

 from transformers import is_tf_available, is_torch_available
@@ -32,23 +30,6 @@ if is_tf_available():

 # from transformers.modeling_bert import BertModel, BertConfig, BERT_PRETRAINED_MODEL_ARCHIVE_MAP

-if sys.version_info[0] == 2:
-
-    class TemporaryDirectory(object):
-        """Context manager for tempfile.mkdtemp() so it's usable with "with" statement."""
-
-        def __enter__(self):
-            self.name = tempfile.mkdtemp()
-            return self.name
-
-        def __exit__(self, exc_type, exc_value, traceback):
-            shutil.rmtree(self.name)
-
-
-else:
-    TemporaryDirectory = tempfile.TemporaryDirectory
-    unicode = str
-

 def _config_zero_init(config):
     configs_no_init = copy.deepcopy(config)
@@ -87,7 +68,7 @@ class TFModelTesterMixin:
             model = model_class(config)
             outputs = model(inputs_dict)

-            with TemporaryDirectory() as tmpdirname:
+            with tempfile.TemporaryDirectory() as tmpdirname:
                 model.save_pretrained(tmpdirname)
                 model = model_class.from_pretrained(tmpdirname)
                 after_outputs = model(inputs_dict)
@@ -137,7 +118,7 @@ class TFModelTesterMixin:
             self.assertLessEqual(max_diff, 2e-2)

             # Check we can load pt model in tf and vice-versa with checkpoint => model functions
-            with TemporaryDirectory() as tmpdirname:
+            with tempfile.TemporaryDirectory() as tmpdirname:
                 pt_checkpoint_path = os.path.join(tmpdirname, "pt_model.bin")
                 torch.save(pt_model.state_dict(), pt_checkpoint_path)
                 tf_model = transformers.load_pytorch_checkpoint_in_tf2_model(tf_model, pt_checkpoint_path)
@@ -180,7 +161,7 @@ class TFModelTesterMixin:
             model = model_class(config)

             # Let's load it from the disk to be sure we can use pretrained weights
-            with TemporaryDirectory() as tmpdirname:
+            with tempfile.TemporaryDirectory() as tmpdirname:
                 outputs = model(inputs_dict)  # build the model
                 model.save_pretrained(tmpdirname)
                 model = model_class.from_pretrained(tmpdirname)
...
@@ -15,11 +15,11 @@
 import os
+import tempfile
 import unittest

 from transformers import is_torch_available

-from .test_tokenization_common import TemporaryDirectory
 from .utils import require_torch
@@ -50,7 +50,7 @@ def unwrap_and_save_reload_schedule(scheduler, num_steps=10):
         scheduler.step()
         lrs.append(scheduler.get_lr())
         if step == num_steps // 2:
-            with TemporaryDirectory() as tmpdirname:
+            with tempfile.TemporaryDirectory() as tmpdirname:
                 file_name = os.path.join(tmpdirname, "schedule.bin")
                 torch.save(scheduler.state_dict(), file_name)
...
@@ -15,33 +15,12 @@
 import os
+import pickle
 import shutil
-import sys
 import tempfile
 from io import open

-if sys.version_info[0] == 2:
-    import cPickle as pickle
-
-    class TemporaryDirectory(object):
-        """Context manager for tempfile.mkdtemp() so it's usable with "with" statement."""
-
-        def __enter__(self):
-            self.name = tempfile.mkdtemp()
-            return self.name
-
-        def __exit__(self, exc_type, exc_value, traceback):
-            shutil.rmtree(self.name)
-
-
-else:
-    import pickle
-
-    TemporaryDirectory = tempfile.TemporaryDirectory
-    unicode = str
-

 class TokenizerTesterMixin:

     tokenizer_class = None
@@ -90,7 +69,7 @@ class TokenizerTesterMixin:
         before_tokens = tokenizer.encode("He is very happy, UNwant\u00E9d,running", add_special_tokens=False)

-        with TemporaryDirectory() as tmpdirname:
+        with tempfile.TemporaryDirectory() as tmpdirname:
             tokenizer.save_pretrained(tmpdirname)
             tokenizer = self.tokenizer_class.from_pretrained(tmpdirname)
@@ -108,7 +87,7 @@ class TokenizerTesterMixin:
         text = "Munich and Berlin are nice cities"
         subwords = tokenizer.tokenize(text)

-        with TemporaryDirectory() as tmpdirname:
+        with tempfile.TemporaryDirectory() as tmpdirname:
             filename = os.path.join(tmpdirname, "tokenizer.bin")
             with open(filename, "wb") as handle:
@@ -246,7 +225,7 @@ class TokenizerTesterMixin:
         self.assertEqual(text_2, output_text)
         self.assertNotEqual(len(tokens_2), 0)
-        self.assertIsInstance(text_2, (str, unicode))
+        self.assertIsInstance(text_2, str)

     def test_encode_decode_with_spaces(self):
         tokenizer = self.get_tokenizer()
@@ -268,9 +247,6 @@ class TokenizerTesterMixin:
         self.assertListEqual(weights_list, weights_list_2)

     def test_mask_output(self):
-        if sys.version_info <= (3, 0):
-            return
-
         tokenizer = self.get_tokenizer()
         if tokenizer.build_inputs_with_special_tokens.__qualname__.split(".")[0] != "PreTrainedTokenizer":
...