Commit 7eb02d29 authored by Kai Chen's avatar Kai Chen
Browse files

Merge branch 'dev' into single-stage

parents 20e75c22 01a03aab
import functools
import torch
def assert_tensor_type(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
if not isinstance(args[0].data, torch.Tensor):
raise AttributeError('{} has no attribute {} for type {}'.format(
args[0].__class__.__name__, func.__name__, args[0].datatype))
return func(*args, **kwargs)
return wrapper
class DataContainer(object):
def __init__(self, data, stack=False, padding_value=0, cpu_only=False):
self._data = data
self._cpu_only = cpu_only
self._stack = stack
self._padding_value = padding_value
def __repr__(self):
return '{}({})'.format(self.__class__.__name__, repr(self.data))
@property
def data(self):
return self._data
@property
def datatype(self):
if isinstance(self.data, torch.Tensor):
return self.data.type()
else:
return type(self.data)
@property
def cpu_only(self):
return self._cpu_only
@property
def stack(self):
return self._stack
@property
def padding_value(self):
return self._padding_value
@assert_tensor_type
def size(self, *args, **kwargs):
return self.data.size(*args, **kwargs)
@assert_tensor_type
def dim(self):
return self.data.dim()
......@@ -2,4 +2,4 @@
PYTHON=${PYTHON:-"python"}
$PYTHON -m torch.distributed.launch --nproc_per_node=$2 train.py $1 --launcher pytorch ${@:3}
$PYTHON -m torch.distributed.launch --nproc_per_node=$2 $(dirname "$0")/train.py $1 --launcher pytorch ${@:3}
......@@ -3,9 +3,10 @@ import argparse
import torch
import mmcv
from mmcv.runner import load_checkpoint, parallel_test, obj_from_dict
from mmcv.parallel import scatter, MMDataParallel
from mmdet import datasets
from mmdet.core import scatter, MMDataParallel, results2json, coco_eval
from mmdet.core import results2json, coco_eval
from mmdet.datasets import collate, build_dataloader
from mmdet.models import build_detector, detectors
......@@ -44,17 +45,16 @@ def parse_args():
'--eval',
type=str,
nargs='+',
choices=['proposal', 'bbox', 'segm', 'keypoints'],
choices=['proposal', 'proposal_fast', 'bbox', 'segm', 'keypoints'],
help='eval types')
parser.add_argument('--show', action='store_true', help='show results')
args = parser.parse_args()
return args
args = parse_args()
def main():
args = parse_args()
cfg = mmcv.Config.fromfile(args.config)
cfg.model.pretrained = None
cfg.data.test.test_mode = True
......
......@@ -2,17 +2,18 @@ from __future__ import division
import argparse
import logging
import random
from collections import OrderedDict
import numpy as np
import torch
from mmcv import Config
from mmcv.runner import Runner, obj_from_dict
from mmcv.runner import Runner, obj_from_dict, DistSamplerSeedHook
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmdet import datasets, __version__
from mmdet.core import (init_dist, DistOptimizerHook, DistSamplerSeedHook,
MMDataParallel, MMDistributedDataParallel,
CocoDistEvalRecallHook, CocoDistEvalmAPHook)
from mmdet.core import (init_dist, DistOptimizerHook, CocoDistEvalRecallHook,
CocoDistEvalmAPHook)
from mmdet.datasets import build_dataloader
from mmdet.models import build_detector, RPN
......@@ -55,6 +56,7 @@ def get_logger(log_level):
def set_random_seed(seed):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
......@@ -89,7 +91,7 @@ def main():
if args.work_dir is not None:
cfg.work_dir = args.work_dir
cfg.gpus = args.gpus
# add mmdet version to checkpoint as meta data
# save mmdet version in checkpoint as meta data
cfg.checkpoint_config.meta = dict(
mmdet_version=__version__, config=cfg.text)
......@@ -103,13 +105,13 @@ def main():
# init distributed environment if necessary
if args.launcher == 'none':
dist = False
logger.info('Disabled distributed training.')
logger.info('Non-distributed training.')
else:
dist = True
init_dist(args.launcher, **cfg.dist_params)
if torch.distributed.get_rank() != 0:
logger.setLevel('ERROR')
logger.info('Enabled distributed training.')
logger.info('Distributed training.')
# prepare data loaders
train_dataset = obj_from_dict(cfg.data.train, datasets)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment