Unverified commit df2aab9b authored by Kai Chen, committed by GitHub

Merge pull request #16 from myownskyW7/dev

add high level api
parents 3c51dcc4 724abbca
from .env import init_dist, get_root_logger, set_random_seed
from .train import train_detector
from .inference import inference_detector
__all__ = [
'init_dist', 'get_root_logger', 'set_random_seed', 'train_detector',
'inference_detector'
]
import logging
import os
import random
import numpy as np
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from mmcv.runner import get_dist_info
def init_dist(launcher, backend='nccl', **kwargs):
if mp.get_start_method(allow_none=True) is None:
mp.set_start_method('spawn')
if launcher == 'pytorch':
_init_dist_pytorch(backend, **kwargs)
elif launcher == 'mpi':
_init_dist_mpi(backend, **kwargs)
elif launcher == 'slurm':
_init_dist_slurm(backend, **kwargs)
else:
raise ValueError('Invalid launcher type: {}'.format(launcher))
def _init_dist_pytorch(backend, **kwargs):
# TODO: use local_rank instead of rank % num_gpus
rank = int(os.environ['RANK'])
num_gpus = torch.cuda.device_count()
torch.cuda.set_device(rank % num_gpus)
dist.init_process_group(backend=backend, **kwargs)
def _init_dist_mpi(backend, **kwargs):
raise NotImplementedError
def _init_dist_slurm(backend, **kwargs):
raise NotImplementedError
def set_random_seed(seed):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
def get_root_logger(log_level=logging.INFO):
logger = logging.getLogger()
if not logger.hasHandlers():
logging.basicConfig(
format='%(asctime)s - %(levelname)s - %(message)s',
level=log_level)
rank, _ = get_dist_info()
if rank != 0:
logger.setLevel('ERROR')
return logger
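# Illustrative usage sketch (assumptions: a hypothetical config file
# 'configs/example.py' that defines `log_level`; a single-process run, so
# init_dist is not called). Shows how the helpers above fit together.
def _example_env_setup():
    from mmcv import Config
    cfg = Config.fromfile('configs/example.py')
    set_random_seed(0)  # seeds python's random, numpy and torch (all GPUs)
    # configured level on rank 0, ERROR on every other rank
    logger = get_root_logger(cfg.log_level)
    logger.info('training environment ready')
    return cfg, logger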
import mmcv
import numpy as np
import torch
from mmdet.datasets import to_tensor
from mmdet.datasets.transforms import ImageTransform
from mmdet.core import get_classes
def _prepare_data(img, img_transform, cfg, device):
    # wrap a single image into the one-image batch (img + img_meta) format
    # that the detector's test-mode forward expects
ori_shape = img.shape
img, img_shape, pad_shape, scale_factor = img_transform(
img, scale=cfg.data.test.img_scale)
img = to_tensor(img).to(device).unsqueeze(0)
img_meta = [
dict(
ori_shape=ori_shape,
img_shape=img_shape,
pad_shape=pad_shape,
scale_factor=scale_factor,
flip=False)
]
return dict(img=[img], img_meta=[img_meta])
def inference_detector(model, imgs, cfg, device='cuda:0'):
    # accepts a single image (filename or ndarray) or a list of images and
    # yields one detection result per image (this function is a generator)
imgs = imgs if isinstance(imgs, list) else [imgs]
img_transform = ImageTransform(
size_divisor=cfg.data.test.size_divisor, **cfg.img_norm_cfg)
model = model.to(device)
model.eval()
for img in imgs:
img = mmcv.imread(img)
data = _prepare_data(img, img_transform, cfg, device)
with torch.no_grad():
result = model(return_loss=False, rescale=True, **data)
yield result
def show_result(img, result, dataset='coco', score_thr=0.3):
class_names = get_classes(dataset)
labels = [
np.full(bbox.shape[0], i, dtype=np.int32)
for i, bbox in enumerate(result)
]
labels = np.concatenate(labels)
bboxes = np.vstack(result)
mmcv.imshow_det_bboxes(
img.copy(),
bboxes,
labels,
class_names=class_names,
score_thr=score_thr)
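# Illustrative usage sketch (assumptions: hypothetical config/checkpoint/image
# paths; load_checkpoint is the standard mmcv.runner helper and build_detector
# the mmdet factory, neither is defined in this file).
def _example_inference(config_file='configs/example.py',
                       checkpoint_file='example.pth',
                       img_file='demo.jpg'):
    from mmcv import Config
    from mmcv.runner import load_checkpoint
    from mmdet.models import build_detector

    cfg = Config.fromfile(config_file)
    cfg.model.pretrained = None  # weights come from the checkpoint instead
    model = build_detector(cfg.model, test_cfg=cfg.test_cfg)
    load_checkpoint(model, checkpoint_file)

    img = mmcv.imread(img_file)
    # inference_detector is a generator: it yields one result per input image
    for result in inference_detector(model, [img], cfg, device='cuda:0'):
        show_result(img, result, dataset='coco', score_thr=0.3)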
from __future__ import division
from collections import OrderedDict
import torch
from mmcv.runner import Runner, DistSamplerSeedHook
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmdet.core import (DistOptimizerHook, CocoDistEvalRecallHook,
CocoDistEvalmAPHook)
from mmdet.datasets import build_dataloader
from mmdet.models import RPN
from .env import get_root_logger
def parse_losses(losses):
log_vars = OrderedDict()
for loss_name, loss_value in losses.items():
if isinstance(loss_value, torch.Tensor):
log_vars[loss_name] = loss_value.mean()
elif isinstance(loss_value, list):
log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value)
else:
raise TypeError(
'{} is not a tensor or list of tensors'.format(loss_name))
loss = sum(_value for _key, _value in log_vars.items() if 'loss' in _key)
log_vars['loss'] = loss
for name in log_vars:
log_vars[name] = log_vars[name].item()
return loss, log_vars
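# Worked example (illustrative): parse_losses averages each tensor, sums list
# entries, and only keys containing 'loss' count toward the total; other keys
# (e.g. accuracies) are logged but excluded from it.
def _example_parse_losses():
    losses = dict(
        loss_cls=torch.tensor(0.5),
        loss_bbox=[torch.tensor(0.2), torch.tensor(0.1)],  # list -> summed
        acc=torch.tensor(0.9))                             # no 'loss' in key
    loss, log_vars = parse_losses(losses)
    # loss is a scalar tensor (~0.8); log_vars holds plain floats, including
    # log_vars['loss'] ~ 0.8 and log_vars['acc'] ~ 0.9
    return loss, log_vars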
def batch_processor(model, data, train_mode):
losses = model(**data)
loss, log_vars = parse_losses(losses)
outputs = dict(
loss=loss, log_vars=log_vars, num_samples=len(data['img'].data))
return outputs
def train_detector(model,
dataset,
cfg,
distributed=False,
validate=False,
logger=None):
if logger is None:
logger = get_root_logger(cfg.log_level)
# start training
if distributed:
_dist_train(model, dataset, cfg, validate=validate)
else:
_non_dist_train(model, dataset, cfg, validate=validate)
def _dist_train(model, dataset, cfg, validate=False):
# prepare data loaders
data_loaders = [
build_dataloader(
dataset,
cfg.data.imgs_per_gpu,
cfg.data.workers_per_gpu,
dist=True)
]
# put model on gpus
model = MMDistributedDataParallel(model.cuda())
# build runner
runner = Runner(model, batch_processor, cfg.optimizer, cfg.work_dir,
cfg.log_level)
# register hooks
optimizer_config = DistOptimizerHook(**cfg.optimizer_config)
runner.register_training_hooks(cfg.lr_config, optimizer_config,
cfg.checkpoint_config, cfg.log_config)
runner.register_hook(DistSamplerSeedHook())
# register eval hooks
if validate:
if isinstance(model.module, RPN):
runner.register_hook(CocoDistEvalRecallHook(cfg.data.val))
elif cfg.data.val.type == 'CocoDataset':
runner.register_hook(CocoDistEvalmAPHook(cfg.data.val))
if cfg.resume_from:
runner.resume(cfg.resume_from)
elif cfg.load_from:
runner.load_checkpoint(cfg.load_from)
runner.run(data_loaders, cfg.workflow, cfg.total_epochs)
def _non_dist_train(model, dataset, cfg, validate=False):
# prepare data loaders
data_loaders = [
build_dataloader(
dataset,
cfg.data.imgs_per_gpu,
cfg.data.workers_per_gpu,
cfg.gpus,
dist=False)
]
# put model on gpus
model = MMDataParallel(model, device_ids=range(cfg.gpus)).cuda()
# build runner
runner = Runner(model, batch_processor, cfg.optimizer, cfg.work_dir,
cfg.log_level)
runner.register_training_hooks(cfg.lr_config, cfg.optimizer_config,
cfg.checkpoint_config, cfg.log_config)
if cfg.resume_from:
runner.resume(cfg.resume_from)
elif cfg.load_from:
runner.load_checkpoint(cfg.load_from)
runner.run(data_loaders, cfg.workflow, cfg.total_epochs)
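# Illustrative usage sketch (assumptions: a hypothetical config file
# 'configs/example.py'; single-machine, non-distributed training). It mirrors
# the call pattern that train_detector is designed for.
def _example_train():
    from mmcv import Config
    from mmcv.runner import obj_from_dict
    from mmdet import datasets
    from mmdet.models import build_detector

    cfg = Config.fromfile('configs/example.py')
    cfg.gpus = 1
    model = build_detector(
        cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
    train_dataset = obj_from_dict(cfg.data.train, datasets)
    # distributed=False -> _non_dist_train: MMDataParallel over cfg.gpus GPUs;
    # evaluation hooks are only registered on the distributed path
    train_detector(model, train_dataset, cfg, distributed=False, validate=False)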
-from .dist_utils import init_dist, allreduce_grads, DistOptimizerHook
+from .dist_utils import allreduce_grads, DistOptimizerHook
 from .misc import tensor2imgs, unmap, multi_apply
 __all__ = [
-    'init_dist', 'allreduce_grads', 'DistOptimizerHook', 'tensor2imgs',
-    'unmap', 'multi_apply'
+    'allreduce_grads', 'DistOptimizerHook', 'tensor2imgs', 'unmap',
+    'multi_apply'
 ]
-import os
 from collections import OrderedDict
-import torch
-import torch.multiprocessing as mp
 import torch.distributed as dist
 from torch._utils import (_flatten_dense_tensors, _unflatten_dense_tensors,
                           _take_tensors)
 from mmcv.runner import OptimizerHook
-def init_dist(launcher, backend='nccl', **kwargs):
-    if mp.get_start_method(allow_none=True) is None:
-        mp.set_start_method('spawn')
-    if launcher == 'pytorch':
-        _init_dist_pytorch(backend, **kwargs)
-    elif launcher == 'mpi':
-        _init_dist_mpi(backend, **kwargs)
-    elif launcher == 'slurm':
-        _init_dist_slurm(backend, **kwargs)
-    else:
-        raise ValueError('Invalid launcher type: {}'.format(launcher))
-def _init_dist_pytorch(backend, **kwargs):
-    # TODO: use local_rank instead of rank % num_gpus
-    rank = int(os.environ['RANK'])
-    num_gpus = torch.cuda.device_count()
-    torch.cuda.set_device(rank % num_gpus)
-    dist.init_process_group(backend=backend, **kwargs)
-def _init_dist_mpi(backend, **kwargs):
-    raise NotImplementedError
-def _init_dist_slurm(backend, **kwargs):
-    raise NotImplementedError
 def _allreduce_coalesced(tensors, world_size, bucket_size_mb=-1):
     if bucket_size_mb > 0:
         bucket_size_bytes = bucket_size_mb * 1024 * 1024
...
@@ -15,7 +15,7 @@ resource.setrlimit(resource.RLIMIT_NOFILE, (4096, rlimit[1]))
 def build_dataloader(dataset,
                      imgs_per_gpu,
                      workers_per_gpu,
-                     num_gpus,
+                     num_gpus=1,
                      dist=True,
                      **kwargs):
     if dist:
...
 from __future__ import division
 import argparse
-import logging
-import random
-from collections import OrderedDict
-import numpy as np
-import torch
 from mmcv import Config
-from mmcv.runner import Runner, obj_from_dict, DistSamplerSeedHook
-from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
+from mmcv.runner import obj_from_dict
 from mmdet import datasets, __version__
-from mmdet.core import (init_dist, DistOptimizerHook, CocoDistEvalRecallHook,
-                        CocoDistEvalmAPHook)
-from mmdet.datasets import build_dataloader
-from mmdet.models import build_detector, RPN
+from mmdet.api import (train_detector, init_dist, get_root_logger,
+                       set_random_seed)
+from mmdet.models import build_detector
-def parse_losses(losses):
-    log_vars = OrderedDict()
-    for loss_name, loss_value in losses.items():
-        if isinstance(loss_value, torch.Tensor):
-            log_vars[loss_name] = loss_value.mean()
-        elif isinstance(loss_value, list):
-            log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value)
-        else:
-            raise TypeError(
-                '{} is not a tensor or list of tensors'.format(loss_name))
-    loss = sum(_value for _key, _value in log_vars.items() if 'loss' in _key)
-    log_vars['loss'] = loss
-    for name in log_vars:
-        log_vars[name] = log_vars[name].item()
-    return loss, log_vars
-def batch_processor(model, data, train_mode):
-    losses = model(**data)
-    loss, log_vars = parse_losses(losses)
-    outputs = dict(
-        loss=loss, log_vars=log_vars, num_samples=len(data['img'].data))
-    return outputs
-def get_logger(log_level):
-    logging.basicConfig(
-        format='%(asctime)s - %(levelname)s - %(message)s', level=log_level)
-    logger = logging.getLogger()
-    return logger
-def set_random_seed(seed):
-    random.seed(seed)
-    np.random.seed(seed)
-    torch.manual_seed(seed)
-    torch.cuda.manual_seed_all(seed)
 def parse_args():
@@ -69,10 +17,14 @@ def parse_args():
     parser.add_argument(
         '--validate',
         action='store_true',
-        help='whether to add a validate phase')
+        help='whether to evaluate the checkpoint during training')
     parser.add_argument(
-        '--gpus', type=int, default=1, help='number of gpus to use')
-    parser.add_argument('--seed', type=int, help='random seed')
+        '--gpus',
+        type=int,
+        default=1,
+        help='number of gpus to use '
+        '(only applicable to non-distributed training)')
+    parser.add_argument('--seed', type=int, default=None, help='random seed')
     parser.add_argument(
         '--launcher',
         choices=['none', 'pytorch', 'slurm', 'mpi'],
@@ -88,69 +40,41 @@ def main():
     args = parse_args()
     cfg = Config.fromfile(args.config)
+    # update configs according to CLI args
     if args.work_dir is not None:
         cfg.work_dir = args.work_dir
     cfg.gpus = args.gpus
-    # save mmdet version in checkpoint as meta data
-    cfg.checkpoint_config.meta = dict(
-        mmdet_version=__version__, config=cfg.text)
-    logger = get_logger(cfg.log_level)
-    # set random seed if specified
-    if args.seed is not None:
-        logger.info('Set random seed to {}'.format(args.seed))
-        set_random_seed(args.seed)
-    # init distributed environment if necessary
+    if cfg.checkpoint_config is not None:
+        # save mmdet version in checkpoints as meta data
+        cfg.checkpoint_config.meta = dict(
+            mmdet_version=__version__, config=cfg.text)
+    # init distributed env first, since logger depends on the dist info.
     if args.launcher == 'none':
-        dist = False
-        logger.info('Non-distributed training.')
+        distributed = False
     else:
-        dist = True
+        distributed = True
         init_dist(args.launcher, **cfg.dist_params)
-        if torch.distributed.get_rank() != 0:
-            logger.setLevel('ERROR')
-        logger.info('Distributed training.')
-    # prepare data loaders
-    train_dataset = obj_from_dict(cfg.data.train, datasets)
-    data_loaders = [
-        build_dataloader(train_dataset, cfg.data.imgs_per_gpu,
-                         cfg.data.workers_per_gpu, cfg.gpus, dist)
-    ]
-    # build model
+    # init logger before other steps
+    logger = get_root_logger(cfg.log_level)
+    logger.info('Distributed training: {}'.format(distributed))
+    # set random seeds
+    if args.seed is not None:
+        logger.info('Set random seed to {}'.format(args.seed))
+        set_random_seed(args.seed)
     model = build_detector(
         cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
-    if dist:
-        model = MMDistributedDataParallel(model.cuda())
-    else:
-        model = MMDataParallel(model, device_ids=range(cfg.gpus)).cuda()
-    # build runner
-    runner = Runner(model, batch_processor, cfg.optimizer, cfg.work_dir,
-                    cfg.log_level)
-    # register hooks
-    optimizer_config = DistOptimizerHook(
-        **cfg.optimizer_config) if dist else cfg.optimizer_config
-    runner.register_training_hooks(cfg.lr_config, optimizer_config,
-                                   cfg.checkpoint_config, cfg.log_config)
-    if dist:
-        runner.register_hook(DistSamplerSeedHook())
-    # register eval hooks
-    if args.validate:
-        if isinstance(model.module, RPN):
-            runner.register_hook(CocoDistEvalRecallHook(cfg.data.val))
-        elif cfg.data.val.type == 'CocoDataset':
-            runner.register_hook(CocoDistEvalmAPHook(cfg.data.val))
-    if cfg.resume_from:
-        runner.resume(cfg.resume_from)
-    elif cfg.load_from:
-        runner.load_checkpoint(cfg.load_from)
-    runner.run(data_loaders, cfg.workflow, cfg.total_epochs)
+    train_dataset = obj_from_dict(cfg.data.train, datasets)
+    train_detector(
+        model,
+        train_dataset,
+        cfg,
+        distributed=distributed,
+        validate=args.validate,
+        logger=logger)
 if __name__ == '__main__':
...