Unverified Commit 24f16a33 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

Remove python2 compability code (#2033)

* remove sys.version_info == 2

* remove sys.version_info < 3

* remove from __future__ imports
parent 42b8d462
from __future__ import print_function
import datetime import datetime
import os import os
import time import time
import sys
import torch import torch
import torch.utils.data import torch.utils.data
...@@ -141,12 +139,9 @@ def load_data(traindir, valdir, cache_dataset, distributed): ...@@ -141,12 +139,9 @@ def load_data(traindir, valdir, cache_dataset, distributed):
def main(args): def main(args):
if args.apex: if args.apex and amp is None:
if sys.version_info < (3, 0): raise RuntimeError("Failed to import apex. Please install apex from https://www.github.com/nvidia/apex "
raise RuntimeError("Apex currently only supports Python 3. Aborting.") "to enable mixed-precision training.")
if amp is None:
raise RuntimeError("Failed to import apex. Please install apex from https://www.github.com/nvidia/apex "
"to enable mixed-precision training.")
if args.output_dir: if args.output_dir:
utils.mkdir(args.output_dir) utils.mkdir(args.output_dir)
......
from __future__ import print_function
import datetime import datetime
import os import os
import time import time
......
from __future__ import print_function
from collections import defaultdict, deque from collections import defaultdict, deque
import datetime import datetime
import time import time
......
from __future__ import print_function
from collections import defaultdict, deque from collections import defaultdict, deque
import datetime import datetime
import pickle import pickle
......
from __future__ import print_function
from collections import defaultdict, deque from collections import defaultdict, deque
import datetime import datetime
import math import math
......
from __future__ import print_function
import datetime import datetime
import os import os
import time import time
import sys
import torch import torch
import torch.utils.data import torch.utils.data
from torch.utils.data.dataloader import default_collate from torch.utils.data.dataloader import default_collate
...@@ -95,12 +92,9 @@ def collate_fn(batch): ...@@ -95,12 +92,9 @@ def collate_fn(batch):
def main(args): def main(args):
if args.apex: if args.apex and amp is None:
if sys.version_info < (3, 0): raise RuntimeError("Failed to import apex. Please install apex from https://www.github.com/nvidia/apex "
raise RuntimeError("Apex currently only supports Python 3. Aborting.") "to enable mixed-precision training.")
if amp is None:
raise RuntimeError("Failed to import apex. Please install apex from https://www.github.com/nvidia/apex "
"to enable mixed-precision training.")
if args.output_dir: if args.output_dir:
utils.mkdir(args.output_dir) utils.mkdir(args.output_dir)
......
from __future__ import print_function
from collections import defaultdict, deque from collections import defaultdict, deque
import datetime import datetime
import time import time
......
from __future__ import print_function
import os import os
import io import io
import re import re
......
import os import os
import sys
import contextlib import contextlib
import tarfile import tarfile
import json import json
...@@ -7,12 +6,7 @@ import numpy as np ...@@ -7,12 +6,7 @@ import numpy as np
import PIL import PIL
import torch import torch
from common_utils import get_tmp_dir from common_utils import get_tmp_dir
import pickle
PYTHON2 = sys.version_info[0] == 2
if PYTHON2:
import cPickle as pickle
else:
import pickle
@contextlib.contextmanager @contextlib.contextmanager
......
from __future__ import division
import torch import torch
from torch import Tensor from torch import Tensor
import torchvision.transforms as transforms import torchvision.transforms as transforms
......
from __future__ import division
import math import math
import unittest import unittest
......
from __future__ import division
import os import os
import mock import mock
import torch import torch
......
from __future__ import division
import torch import torch
import torchvision.transforms._transforms_video as transforms import torchvision.transforms._transforms_video as transforms
from torchvision.transforms import Compose from torchvision.transforms import Compose
......
import collections import collections
import math import math
import os import os
import sys
import time import time
import unittest import unittest
from fractions import Fraction from fractions import Fraction
...@@ -22,10 +21,7 @@ except ImportError: ...@@ -22,10 +21,7 @@ except ImportError:
av = None av = None
if sys.version_info < (3,): from urllib.error import URLError
from urllib2 import URLError
else:
from urllib.error import URLError
VIDEO_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "videos") VIDEO_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "videos")
......
from __future__ import print_function
from PIL import Image from PIL import Image
import os import os
import os.path import os.path
......
from __future__ import print_function
from PIL import Image from PIL import Image
import os import os
import os.path import os.path
import numpy as np import numpy as np
import sys import pickle
if sys.version_info[0] == 2:
import cPickle as pickle
else:
import pickle
from .vision import VisionDataset from .vision import VisionDataset
from .utils import check_integrity, download_and_extract_archive from .utils import check_integrity, download_and_extract_archive
...@@ -79,10 +73,7 @@ class CIFAR10(VisionDataset): ...@@ -79,10 +73,7 @@ class CIFAR10(VisionDataset):
for file_name, checksum in downloaded_list: for file_name, checksum in downloaded_list:
file_path = os.path.join(self.root, self.base_folder, file_name) file_path = os.path.join(self.root, self.base_folder, file_name)
with open(file_path, 'rb') as f: with open(file_path, 'rb') as f:
if sys.version_info[0] == 2: entry = pickle.load(f, encoding='latin1')
entry = pickle.load(f)
else:
entry = pickle.load(f, encoding='latin1')
self.data.append(entry['data']) self.data.append(entry['data'])
if 'labels' in entry: if 'labels' in entry:
self.targets.extend(entry['labels']) self.targets.extend(entry['labels'])
...@@ -100,10 +91,7 @@ class CIFAR10(VisionDataset): ...@@ -100,10 +91,7 @@ class CIFAR10(VisionDataset):
raise RuntimeError('Dataset metadata file not found or corrupted.' + raise RuntimeError('Dataset metadata file not found or corrupted.' +
' You can use download=True to download it') ' You can use download=True to download it')
with open(path, 'rb') as infile: with open(path, 'rb') as infile:
if sys.version_info[0] == 2: data = pickle.load(infile, encoding='latin1')
data = pickle.load(infile)
else:
data = pickle.load(infile, encoding='latin1')
self.classes = data[self.meta['key']] self.classes = data[self.meta['key']]
self.class_to_idx = {_class: i for i, _class in enumerate(self.classes)} self.class_to_idx = {_class: i for i, _class in enumerate(self.classes)}
......
...@@ -11,11 +11,7 @@ if sys.version_info < (3, 3): ...@@ -11,11 +11,7 @@ if sys.version_info < (3, 3):
else: else:
from collections.abc import Iterable from collections.abc import Iterable
if sys.version_info[0] == 2: import pickle
import cPickle as pickle
else:
import pickle
from .utils import verify_str_arg, iterable_to_str from .utils import verify_str_arg, iterable_to_str
......
from __future__ import print_function
from .vision import VisionDataset from .vision import VisionDataset
import warnings import warnings
from PIL import Image from PIL import Image
......
from __future__ import print_function
from PIL import Image from PIL import Image
from os.path import join from os.path import join
import os import os
......
from __future__ import print_function
from PIL import Image from PIL import Image
import os import os
import os.path import os.path
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment