Commit 7eb02d29 authored by Kai Chen's avatar Kai Chen
Browse files

Merge branch 'dev' into single-stage

parents 20e75c22 01a03aab
import functools
import torch
def assert_tensor_type(func):
    """Decorator for ``DataContainer`` methods that only make sense on tensors.

    The wrapped method is forwarded unchanged when the container's ``data``
    attribute is a ``torch.Tensor``; otherwise an ``AttributeError`` is
    raised naming the container class, the method, and the data type.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        container = args[0]
        if isinstance(container.data, torch.Tensor):
            return func(*args, **kwargs)
        raise AttributeError('{} has no attribute {} for type {}'.format(
            container.__class__.__name__, func.__name__, container.datatype))

    return wrapper
class DataContainer(object):
    """Container wrapping arbitrary data together with collate/scatter hints.

    Args:
        data: the payload (commonly a ``torch.Tensor``, but any object is
            accepted).
        stack (bool): whether the payload should be stacked when batching.
        padding_value (int): value used to pad tensors when batching.
        cpu_only (bool): whether the payload must stay on the CPU.
    """

    def __init__(self, data, stack=False, padding_value=0, cpu_only=False):
        self._data = data
        self._cpu_only = cpu_only
        self._stack = stack
        self._padding_value = padding_value

    def __repr__(self):
        return '{}({})'.format(self.__class__.__name__, repr(self.data))

    @property
    def data(self):
        # The wrapped payload, exposed read-only.
        return self._data

    @property
    def datatype(self):
        # Tensors report their full torch type string; anything else reports
        # its Python type.
        payload = self._data
        return payload.type() if isinstance(payload, torch.Tensor) else type(payload)

    @property
    def cpu_only(self):
        return self._cpu_only

    @property
    def stack(self):
        return self._stack

    @property
    def padding_value(self):
        return self._padding_value

    @assert_tensor_type
    def size(self, *args, **kwargs):
        # Tensor-only: delegates to torch.Tensor.size().
        return self.data.size(*args, **kwargs)

    @assert_tensor_type
    def dim(self):
        # Tensor-only: delegates to torch.Tensor.dim().
        return self.data.dim()
...@@ -2,4 +2,4 @@ ...@@ -2,4 +2,4 @@
PYTHON=${PYTHON:-"python"} PYTHON=${PYTHON:-"python"}
$PYTHON -m torch.distributed.launch --nproc_per_node=$2 train.py $1 --launcher pytorch ${@:3} $PYTHON -m torch.distributed.launch --nproc_per_node=$2 $(dirname "$0")/train.py $1 --launcher pytorch ${@:3}
...@@ -3,9 +3,10 @@ import argparse ...@@ -3,9 +3,10 @@ import argparse
import torch import torch
import mmcv import mmcv
from mmcv.runner import load_checkpoint, parallel_test, obj_from_dict from mmcv.runner import load_checkpoint, parallel_test, obj_from_dict
from mmcv.parallel import scatter, MMDataParallel
from mmdet import datasets from mmdet import datasets
from mmdet.core import scatter, MMDataParallel, results2json, coco_eval from mmdet.core import results2json, coco_eval
from mmdet.datasets import collate, build_dataloader from mmdet.datasets import collate, build_dataloader
from mmdet.models import build_detector, detectors from mmdet.models import build_detector, detectors
...@@ -44,17 +45,16 @@ def parse_args(): ...@@ -44,17 +45,16 @@ def parse_args():
'--eval', '--eval',
type=str, type=str,
nargs='+', nargs='+',
choices=['proposal', 'bbox', 'segm', 'keypoints'], choices=['proposal', 'proposal_fast', 'bbox', 'segm', 'keypoints'],
help='eval types') help='eval types')
parser.add_argument('--show', action='store_true', help='show results') parser.add_argument('--show', action='store_true', help='show results')
args = parser.parse_args() args = parser.parse_args()
return args return args
args = parse_args()
def main(): def main():
args = parse_args()
cfg = mmcv.Config.fromfile(args.config) cfg = mmcv.Config.fromfile(args.config)
cfg.model.pretrained = None cfg.model.pretrained = None
cfg.data.test.test_mode = True cfg.data.test.test_mode = True
......
...@@ -2,17 +2,18 @@ from __future__ import division ...@@ -2,17 +2,18 @@ from __future__ import division
import argparse import argparse
import logging import logging
import random
from collections import OrderedDict from collections import OrderedDict
import numpy as np import numpy as np
import torch import torch
from mmcv import Config from mmcv import Config
from mmcv.runner import Runner, obj_from_dict from mmcv.runner import Runner, obj_from_dict, DistSamplerSeedHook
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmdet import datasets, __version__ from mmdet import datasets, __version__
from mmdet.core import (init_dist, DistOptimizerHook, DistSamplerSeedHook, from mmdet.core import (init_dist, DistOptimizerHook, CocoDistEvalRecallHook,
MMDataParallel, MMDistributedDataParallel, CocoDistEvalmAPHook)
CocoDistEvalRecallHook, CocoDistEvalmAPHook)
from mmdet.datasets import build_dataloader from mmdet.datasets import build_dataloader
from mmdet.models import build_detector, RPN from mmdet.models import build_detector, RPN
...@@ -55,6 +56,7 @@ def get_logger(log_level): ...@@ -55,6 +56,7 @@ def get_logger(log_level):
def set_random_seed(seed): def set_random_seed(seed):
random.seed(seed)
np.random.seed(seed) np.random.seed(seed)
torch.manual_seed(seed) torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed) torch.cuda.manual_seed_all(seed)
...@@ -89,7 +91,7 @@ def main(): ...@@ -89,7 +91,7 @@ def main():
if args.work_dir is not None: if args.work_dir is not None:
cfg.work_dir = args.work_dir cfg.work_dir = args.work_dir
cfg.gpus = args.gpus cfg.gpus = args.gpus
# add mmdet version to checkpoint as meta data # save mmdet version in checkpoint as meta data
cfg.checkpoint_config.meta = dict( cfg.checkpoint_config.meta = dict(
mmdet_version=__version__, config=cfg.text) mmdet_version=__version__, config=cfg.text)
...@@ -103,13 +105,13 @@ def main(): ...@@ -103,13 +105,13 @@ def main():
# init distributed environment if necessary # init distributed environment if necessary
if args.launcher == 'none': if args.launcher == 'none':
dist = False dist = False
logger.info('Disabled distributed training.') logger.info('Non-distributed training.')
else: else:
dist = True dist = True
init_dist(args.launcher, **cfg.dist_params) init_dist(args.launcher, **cfg.dist_params)
if torch.distributed.get_rank() != 0: if torch.distributed.get_rank() != 0:
logger.setLevel('ERROR') logger.setLevel('ERROR')
logger.info('Enabled distributed training.') logger.info('Distributed training.')
# prepare data loaders # prepare data loaders
train_dataset = obj_from_dict(cfg.data.train, datasets) train_dataset = obj_from_dict(cfg.data.train, datasets)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment