"research/object_detection/exporter.py" did not exist on "a4944a57ad2811e1f6a7a87589a9fc8a776e8d3c"
train.py 5.23 KB
Newer Older
pangjm's avatar
pangjm committed
1
from __future__ import division

import argparse
import logging
from collections import OrderedDict

import numpy as np
import torch
from mmcv import Config
from mmcv.torchpack import Runner, obj_from_dict

from mmdet import datasets, __version__
from mmdet.core import (init_dist, DistOptimizerHook, DistSamplerSeedHook,
                        MMDataParallel, MMDistributedDataParallel,
                        CocoDistEvalRecallHook, CocoDistEvalmAPHook)
from mmdet.datasets.loader import build_dataloader
from mmdet.models import build_detector, RPN
Kai Chen's avatar
Kai Chen committed
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39


def parse_losses(losses):
    """Reduce raw loss outputs to a total loss and scalar log values.

    Args:
        losses (dict): Maps a loss/metric name to either a tensor or a
            list of tensors (one per level/branch).

    Returns:
        tuple: ``(total, log_vars)`` where ``total`` is the tensor sum of
        every entry whose name contains ``'loss'`` (so plain metrics such
        as accuracies are logged but excluded from backprop), and
        ``log_vars`` is an ``OrderedDict`` of python floats that also
        holds the total under the key ``'loss'``.

    Raises:
        TypeError: If a value is neither a tensor nor a list of tensors.
    """
    log_vars = OrderedDict()
    for name, value in losses.items():
        if isinstance(value, torch.Tensor):
            log_vars[name] = value.mean()
        elif isinstance(value, list):
            # Multi-level losses are averaged per level, then summed.
            log_vars[name] = sum(item.mean() for item in value)
        else:
            raise TypeError(
                '{} is not a tensor or list of tensors'.format(name))

    # Only names containing 'loss' contribute to the optimized total.
    total = sum(value for key, value in log_vars.items() if 'loss' in key)
    log_vars['loss'] = total

    # Detach everything to plain floats for logging.
    for key, value in log_vars.items():
        log_vars[key] = value.item()

    return total, log_vars


Kai Chen's avatar
Kai Chen committed
40
def batch_processor(model, data, train_mode):
    """Run one forward pass and package the result for the Runner.

    Args:
        model: Detector (wrapped in a DataParallel variant); invoked as
            ``model(**data)`` and expected to return a dict of losses.
        data (dict): One batch from the data loader; must contain 'img'.
        train_mode (bool): Unused here; part of the Runner's
            batch-processor interface.

    Returns:
        dict: With keys ``loss`` (tensor to backprop), ``log_vars``
        (scalar values for logging) and ``num_samples`` (batch size).
    """
    raw_losses = model(**data)
    total_loss, log_vars = parse_losses(raw_losses)
    return dict(
        loss=total_loss,
        log_vars=log_vars,
        num_samples=len(data['img'].data))
pangjm's avatar
pangjm committed
48
49


50
51
52
53
54
55
56
def get_logger(log_level):
    """Configure the root logger via ``basicConfig`` and return it.

    Args:
        log_level: Logging level (e.g. ``logging.INFO``) applied to the
            root logger. Note ``basicConfig`` is a no-op if the root
            logger already has handlers.

    Returns:
        logging.Logger: The root logger.
    """
    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(message)s', level=log_level)
    return logging.getLogger()


Kai Chen's avatar
Kai Chen committed
57
58
59
60
61
62
def set_random_seed(seed):
    """Seed the numpy and torch (CPU plus all CUDA devices) RNGs.

    Args:
        seed (int): The seed applied to every generator.
    """
    for seeder in (np.random.seed, torch.manual_seed,
                   torch.cuda.manual_seed_all):
        seeder(seed)


pangjm's avatar
pangjm committed
63
def parse_args():
    """Define and parse the command-line interface of the train script.

    Returns:
        argparse.Namespace: Parsed options (config path, work dir,
        validation flag, gpu count, seed, launcher and local rank).
    """
    ap = argparse.ArgumentParser(description='Train a detector')
    ap.add_argument('config', help='train config file path')
    ap.add_argument('--work_dir', help='the dir to save logs and models')
    ap.add_argument(
        '--validate',
        action='store_true',
        help='whether to add a validate phase')
    ap.add_argument(
        '--gpus', type=int, default=1, help='number of gpus to use')
    ap.add_argument('--seed', type=int, help='random seed')
    ap.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')
    # Injected by torch.distributed.launch; accepted so argparse
    # does not reject it.
    ap.add_argument('--local_rank', type=int, default=0)
    return ap.parse_args()


def main():
    """Entry point: parse CLI options, build dataset/model/runner, train.

    Reads the config file given on the command line, optionally enables
    distributed training, constructs the data loaders and detector, and
    hands everything to an mmcv Runner.
    """
    args = parse_args()

    cfg = Config.fromfile(args.config)
    # CLI options take precedence over the config file.
    if args.work_dir is not None:
        cfg.work_dir = args.work_dir
    cfg.gpus = args.gpus
    # add mmdet version to checkpoint as meta data
    cfg.checkpoint_config.meta = dict(mmdet_version=__version__)

    logger = get_logger(cfg.log_level)

    # set random seed if specified
    if args.seed is not None:
        logger.info('Set random seed to {}'.format(args.seed))
        set_random_seed(args.seed)

    # init distributed environment if necessary
    if args.launcher == 'none':
        dist = False
        logger.info('Disabled distributed training.')
    else:
        dist = True
        init_dist(args.launcher, **cfg.dist_params)
        # Only rank 0 keeps full logging; other ranks stay quiet.
        if torch.distributed.get_rank() != 0:
            logger.setLevel('ERROR')
        logger.info('Enabled distributed training.')

    # prepare data loaders
    train_dataset = obj_from_dict(cfg.data.train, datasets)
    data_loaders = [
        build_dataloader(train_dataset, cfg.data.imgs_per_gpu,
                         cfg.data.workers_per_gpu, cfg.gpus, dist)
    ]
    # BUG FIX: the validation loader was previously appended twice (once
    # here and once again after building the model), which duplicated the
    # val loader in data_loaders and broke the workflow <-> loader
    # pairing in runner.run. Append it exactly once.
    if args.validate:
        val_dataset = obj_from_dict(cfg.data.test, datasets)
        data_loaders.append(
            build_dataloader(val_dataset, cfg.data.imgs_per_gpu,
                             cfg.data.workers_per_gpu, cfg.gpus, dist))

    # build model
    model = build_detector(
        cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
    if dist:
        model = MMDistributedDataParallel(model.cuda())
    else:
        model = MMDataParallel(model, device_ids=range(cfg.gpus)).cuda()

    # build runner
    runner = Runner(model, batch_processor, cfg.optimizer, cfg.work_dir,
                    cfg.log_level)

    # register hooks
    optimizer_config = DistOptimizerHook(
        **cfg.optimizer_config) if dist else cfg.optimizer_config
    runner.register_training_hooks(cfg.lr_config, optimizer_config,
                                   cfg.checkpoint_config, cfg.log_config)
    if dist:
        runner.register_hook(DistSamplerSeedHook())
        # register eval hooks
        if isinstance(model.module, RPN):
            runner.register_hook(CocoDistEvalRecallHook(cfg.data.val))
        elif cfg.data.val.type == 'CocoDataset':
            runner.register_hook(CocoDistEvalmAPHook(cfg.data.val))

    # Resume takes precedence over a plain checkpoint load.
    if cfg.resume_from:
        runner.resume(cfg.resume_from)
    elif cfg.load_from:
        runner.load_checkpoint(cfg.load_from)
    runner.run(data_loaders, cfg.workflow, cfg.total_epochs)
pangjm's avatar
pangjm committed
161
162


Kai Chen's avatar
Kai Chen committed
163
# Script entry point: `python train.py <config> [options]`.
if __name__ == '__main__':
    main()