Unverified Commit eae81c1e authored by Kai Chen's avatar Kai Chen Committed by GitHub
Browse files

Add yapf and isort to travis (#96)

* add yapf and isort to travis

* minor formatting

* remove blank lines

* skip unit tests for progressbar when python2

* update travis to ubuntu 16.04

* use a newer version ffmpeg

* add -y to add-apt-repository
parent 3e1e297d
dist: trusty dist: xenial
sudo: required sudo: required
language: python language: python
before_install: before_install:
- sudo add-apt-repository -y ppa:mc3man/trusty-media - sudo add-apt-repository -y ppa:mc3man/xerus-media
- sudo apt-get update - sudo apt-get update
- sudo apt-get install -y ffmpeg - sudo apt-get install -y ffmpeg
install: install:
- pip install opencv-python pyyaml codecov flake8 Cython - pip install Cython opencv-python pyyaml codecov flake8 yapf isort
cache: cache:
pip: true pip: true
...@@ -19,11 +19,14 @@ env: ...@@ -19,11 +19,14 @@ env:
python: python:
- "2.7" - "2.7"
- "3.4"
- "3.5" - "3.5"
- "3.6" - "3.6"
- "3.7"
before_script: flake8 before_script:
- flake8
- isort -rc --diff mmcv/ tests/ examples/
- yapf -r -d mmcv/ tests/ examples/
script: coverage run --source=mmcv setup.py test script: coverage run --source=mmcv setup.py test
......
...@@ -3,18 +3,18 @@ import os ...@@ -3,18 +3,18 @@ import os
from argparse import ArgumentParser from argparse import ArgumentParser
from collections import OrderedDict from collections import OrderedDict
import resnet_cifar
import torch import torch
import torch.distributed as dist import torch.distributed as dist
import torch.multiprocessing as mp import torch.multiprocessing as mp
import torch.nn.functional as F import torch.nn.functional as F
from mmcv import Config
from mmcv.runner import Runner, DistSamplerSeedHook
from torch.nn.parallel import DataParallel, DistributedDataParallel from torch.nn.parallel import DataParallel, DistributedDataParallel
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler from torch.utils.data.distributed import DistributedSampler
from torchvision import datasets, transforms from torchvision import datasets, transforms
import resnet_cifar from mmcv import Config
from mmcv.runner import DistSamplerSeedHook, Runner
def accuracy(output, target, topk=(1, )): def accuracy(output, target, topk=(1, )):
......
...@@ -50,7 +50,7 @@ def dequantize(arr, min_val, max_val, levels, dtype=np.float64): ...@@ -50,7 +50,7 @@ def dequantize(arr, min_val, max_val, levels, dtype=np.float64):
'min_val ({}) must be smaller than max_val ({})'.format( 'min_val ({}) must be smaller than max_val ({})'.format(
min_val, max_val)) min_val, max_val))
dequantized_arr = (arr + 0.5).astype(dtype) * ( dequantized_arr = (arr + 0.5).astype(dtype) * (max_val -
max_val - min_val) / levels + min_val min_val) / levels + min_val
return dequantized_arr return dequantized_arr
...@@ -3,8 +3,8 @@ import logging ...@@ -3,8 +3,8 @@ import logging
import torch.nn as nn import torch.nn as nn
import torch.utils.checkpoint as cp import torch.utils.checkpoint as cp
from .weight_init import constant_init, kaiming_init
from ..runner import load_checkpoint from ..runner import load_checkpoint
from .weight_init import constant_init, kaiming_init
def conv3x3(in_planes, out_planes, stride=1, dilation=1): def conv3x3(in_planes, out_planes, stride=1, dilation=1):
......
...@@ -2,8 +2,8 @@ import logging ...@@ -2,8 +2,8 @@ import logging
import torch.nn as nn import torch.nn as nn
from .weight_init import constant_init, normal_init, kaiming_init
from ..runner import load_checkpoint from ..runner import load_checkpoint
from .weight_init import constant_init, kaiming_init, normal_init
def conv3x3(in_planes, out_planes, dilation=1): def conv3x3(in_planes, out_planes, dilation=1):
......
import yaml import yaml
try: try:
from yaml import CLoader as Loader, CDumper as Dumper from yaml import CLoader as Loader, CDumper as Dumper
except ImportError: except ImportError:
from yaml import Loader, Dumper from yaml import Loader, Dumper
from .base import BaseFileHandler from .base import BaseFileHandler # isort:skip
class YamlHandler(BaseFileHandler): class YamlHandler(BaseFileHandler):
......
from ..utils import is_list_of, is_str
from .handlers import BaseFileHandler, JsonHandler, PickleHandler, YamlHandler from .handlers import BaseFileHandler, JsonHandler, PickleHandler, YamlHandler
from ..utils import is_str, is_list_of
file_handlers = { file_handlers = {
'json': JsonHandler(), 'json': JsonHandler(),
......
...@@ -3,8 +3,8 @@ import os.path as osp ...@@ -3,8 +3,8 @@ import os.path as osp
import cv2 import cv2
import numpy as np import numpy as np
from mmcv.utils import is_str, check_file_exist, mkdir_or_exist
from mmcv.opencv_info import USE_OPENCV2 from mmcv.opencv_info import USE_OPENCV2
from mmcv.utils import check_file_exist, is_str, mkdir_or_exist
if not USE_OPENCV2: if not USE_OPENCV2:
from cv2 import IMREAD_COLOR, IMREAD_GRAYSCALE, IMREAD_UNCHANGED from cv2 import IMREAD_COLOR, IMREAD_GRAYSCALE, IMREAD_UNCHANGED
......
import torch import torch
import torch.distributed as dist import torch.distributed as dist
import torch.nn as nn import torch.nn as nn
from torch._utils import (_flatten_dense_tensors, _unflatten_dense_tensors, from torch._utils import (_flatten_dense_tensors, _take_tensors,
_take_tensors) _unflatten_dense_tensors)
from .scatter_gather import scatter_kwargs from .scatter_gather import scatter_kwargs
......
...@@ -6,14 +6,13 @@ import warnings ...@@ -6,14 +6,13 @@ import warnings
from collections import OrderedDict from collections import OrderedDict
from importlib import import_module from importlib import import_module
import mmcv
import torch import torch
import torchvision import torchvision
from torch.utils import model_zoo from torch.utils import model_zoo
import mmcv
from .utils import get_dist_info from .utils import get_dist_info
open_mmlab_model_urls = { open_mmlab_model_urls = {
'vgg16_caffe': 'https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/vgg16_caffe-292e1171.pth', # noqa: E501 'vgg16_caffe': 'https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/vgg16_caffe-292e1171.pth', # noqa: E501
'resnet50_caffe': 'https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/resnet50_caffe-788b5fa3.pth', # noqa: E501 'resnet50_caffe': 'https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/resnet50_caffe-788b5fa3.pth', # noqa: E501
......
from .hook import Hook
from ..utils import master_only from ..utils import master_only
from .hook import Hook
class CheckpointHook(Hook): class CheckpointHook(Hook):
......
from __future__ import print_function from __future__ import print_function
import logging import logging
import os import os
import os.path as osp import os.path as osp
...@@ -10,8 +9,8 @@ from threading import Thread ...@@ -10,8 +9,8 @@ from threading import Thread
import requests import requests
from six.moves.queue import Empty, Queue from six.moves.queue import Empty, Queue
from ...utils import get_host_info, master_only
from .base import LoggerHook from .base import LoggerHook
from ...utils import master_only, get_host_info
class PaviClient(object): class PaviClient(object):
......
import os.path as osp import os.path as osp
from .base import LoggerHook
from ...utils import master_only from ...utils import master_only
from .base import LoggerHook
class TensorboardLoggerHook(LoggerHook): class TensorboardLoggerHook(LoggerHook):
......
...@@ -46,8 +46,9 @@ class TextLoggerHook(LoggerHook): ...@@ -46,8 +46,9 @@ class TextLoggerHook(LoggerHook):
log_dict['time'], log_dict['data_time'])) log_dict['time'], log_dict['data_time']))
log_str += 'memory: {}, '.format(log_dict['memory']) log_str += 'memory: {}, '.format(log_dict['memory'])
else: else:
log_str = 'Epoch({}) [{}][{}]\t'.format( log_str = 'Epoch({}) [{}][{}]\t'.format(log_dict['mode'],
log_dict['mode'], log_dict['epoch'] - 1, log_dict['iter']) log_dict['epoch'] - 1,
log_dict['iter'])
log_items = [] log_items = []
for name, val in log_dict.items(): for name, val in log_dict.items():
# TODO: resolve this hack # TODO: resolve this hack
......
from __future__ import division from __future__ import division
from math import cos, pi from math import cos, pi
from .hook import Hook from .hook import Hook
......
...@@ -2,14 +2,14 @@ import logging ...@@ -2,14 +2,14 @@ import logging
import os.path as osp import os.path as osp
import time import time
import mmcv
import torch import torch
import mmcv
from . import hooks from . import hooks
from .log_buffer import LogBuffer
from .hooks import (Hook, LrUpdaterHook, CheckpointHook, IterTimerHook,
OptimizerHook, lr_updater)
from .checkpoint import load_checkpoint, save_checkpoint from .checkpoint import load_checkpoint, save_checkpoint
from .hooks import (CheckpointHook, Hook, IterTimerHook, LrUpdaterHook,
OptimizerHook, lr_updater)
from .log_buffer import LogBuffer
from .priority import get_priority from .priority import get_priority
from .utils import get_dist_info, get_host_info, get_time_str, obj_from_dict from .utils import get_dist_info, get_host_info, get_time_str, obj_from_dict
...@@ -139,8 +139,8 @@ class Runner(object): ...@@ -139,8 +139,8 @@ class Runner(object):
<class 'torch.optim.sgd.SGD'> <class 'torch.optim.sgd.SGD'>
""" """
if isinstance(optimizer, dict): if isinstance(optimizer, dict):
optimizer = obj_from_dict( optimizer = obj_from_dict(optimizer, torch.optim,
optimizer, torch.optim, dict(params=self.model.parameters())) dict(params=self.model.parameters()))
elif not isinstance(optimizer, torch.optim.Optimizer): elif not isinstance(optimizer, torch.optim.Optimizer):
raise TypeError( raise TypeError(
'optimizer must be either an Optimizer object or a dict, ' 'optimizer must be either an Optimizer object or a dict, '
......
...@@ -4,10 +4,11 @@ import time ...@@ -4,10 +4,11 @@ import time
from getpass import getuser from getpass import getuser
from socket import gethostname from socket import gethostname
import mmcv
import torch import torch
import torch.distributed as dist import torch.distributed as dist
import mmcv
def get_host_info(): def get_host_info():
return '{}@{}'.format(getuser(), gethostname()) return '{}@{}'.format(getuser(), gethostname())
......
...@@ -3,6 +3,9 @@ import functools ...@@ -3,6 +3,9 @@ import functools
import itertools import itertools
import subprocess import subprocess
from importlib import import_module from importlib import import_module
import six
# ABCs from collections will be deprecated in python 3.8+, # ABCs from collections will be deprecated in python 3.8+,
# while collections.abc is not available in python 2.7 # while collections.abc is not available in python 2.7
try: try:
...@@ -10,8 +13,6 @@ try: ...@@ -10,8 +13,6 @@ try:
except ImportError: except ImportError:
import collections as collections_abc import collections as collections_abc
import six
def is_str(x): def is_str(x):
"""Whether the input is an string instance.""" """Whether the input is an string instance."""
......
...@@ -11,8 +11,8 @@ class ProgressBar(object): ...@@ -11,8 +11,8 @@ class ProgressBar(object):
def __init__(self, task_num=0, bar_width=50, start=True): def __init__(self, task_num=0, bar_width=50, start=True):
self.task_num = task_num self.task_num = task_num
max_bar_width = self._get_max_bar_width() max_bar_width = self._get_max_bar_width()
self.bar_width = (bar_width self.bar_width = (
if bar_width <= max_bar_width else max_bar_width) bar_width if bar_width <= max_bar_width else max_bar_width)
self.completed = 0 self.completed = 0
if start: if start:
self.start() self.start()
......
from .io import Cache, VideoReader, frames2video from .io import Cache, VideoReader, frames2video
from .processing import convert_video, resize_video, cut_video, concat_video from .processing import convert_video, resize_video, cut_video, concat_video
from .optflow import (flowread, flowwrite, quantize_flow, from .optflow import (flowread, flowwrite, quantize_flow, dequantize_flow,
dequantize_flow, flow_warp) flow_warp)
__all__ = [ __all__ = [
'Cache', 'VideoReader', 'frames2video', 'convert_video', 'resize_video', 'Cache', 'VideoReader', 'frames2video', 'convert_video', 'resize_video',
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment