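"""Train a CIFAR-10 classifier with the mmcv Runner.

Single-machine (non-distributed) training:

    python train_cifar10.py <config.py>

Distributed training launches one process per GPU, e.g. via the
torch.distributed launcher (illustrative invocation, adjust to your setup):

    python -m torch.distributed.launch --nproc_per_node=<num_gpus> \
        train_cifar10.py <config.py> --launcher pytorch
"""
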
import logging
import os
from argparse import ArgumentParser
from collections import OrderedDict

import resnet_cifar
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn.functional as F
from torch.nn.parallel import DataParallel, DistributedDataParallel
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torchvision import datasets, transforms

from mmcv import Config
from mmcv.runner import DistSamplerSeedHook, Runner


def accuracy(output, target, topk=(1, )):
    """Computes the precision@k for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res


def batch_processor(model, data, train_mode):
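    """Run one batch through the model.

    Returns the dict that the mmcv Runner expects: ``loss`` (the tensor to
    back-propagate), ``log_vars`` (scalar values to log) and ``num_samples``
    (the batch size, used when averaging the logged values). The unused
    ``train_mode`` flag is part of the batch_processor interface.
    """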
    img, label = data
    label = label.cuda(non_blocking=True)
    pred = model(img)
    loss = F.cross_entropy(pred, label)
    acc_top1, acc_top5 = accuracy(pred, label, topk=(1, 5))
    log_vars = OrderedDict()
    log_vars['loss'] = loss.item()
    log_vars['acc_top1'] = acc_top1.item()
    log_vars['acc_top5'] = acc_top5.item()
    outputs = dict(loss=loss, log_vars=log_vars, num_samples=img.size(0))
    return outputs


def get_logger(log_level):
    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(message)s', level=log_level)
    logger = logging.getLogger()
    return logger


def init_dist(backend='nccl', **kwargs):
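    """Initialize the default process group.

    Assumes the launcher (e.g. ``torch.distributed.launch``) has set the
    ``RANK`` environment variable; each process is pinned to one GPU by
    rank modulo the number of visible GPUs.
    """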
    if mp.get_start_method(allow_none=True) is None:
        mp.set_start_method('spawn')
    rank = int(os.environ['RANK'])
    num_gpus = torch.cuda.device_count()
    torch.cuda.set_device(rank % num_gpus)
    dist.init_process_group(backend=backend, **kwargs)


def parse_args():
    parser = ArgumentParser(description='Train CIFAR-10 classification')
    parser.add_argument('config', help='train config file path')
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch'],
        default='none',
        help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    return parser.parse_args()


def main():
    args = parse_args()

    cfg = Config.fromfile(args.config)
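    # The config is expected to define the attributes used below: model,
    # data_root, mean, std, batch_size, data_workers, gpus, optimizer,
    # optimizer_config, lr_config, checkpoint_config, log_config, log_level,
    # work_dir, dist_params, workflow, total_epochs and, optionally,
    # resume_from / load_from.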

    logger = get_logger(cfg.log_level)

    # init distributed environment if necessary
    if args.launcher == 'none':
        distributed = False
        logger.info('Disabled distributed training.')
    else:
        distributed = True
        init_dist(**cfg.dist_params)
        world_size = torch.distributed.get_world_size()
        rank = torch.distributed.get_rank()
        if rank != 0:
            logger.setLevel('ERROR')
        logger.info('Enabled distributed training.')

    # build datasets and dataloaders
    normalize = transforms.Normalize(mean=cfg.mean, std=cfg.std)
    train_dataset = datasets.CIFAR10(
        root=cfg.data_root,
        train=True,
        transform=transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))
    val_dataset = datasets.CIFAR10(
        root=cfg.data_root,
        train=False,
        transform=transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ]))
    if distributed:
        num_workers = cfg.data_workers
        assert cfg.batch_size % world_size == 0
        batch_size = cfg.batch_size // world_size
        train_sampler = DistributedSampler(train_dataset, world_size, rank)
        val_sampler = DistributedSampler(val_dataset, world_size, rank)
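        # DistributedSampler already shuffles; DistSamplerSeedHook (registered
        # below) re-seeds it every epoch, so the loader itself must not shuffle.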
        shuffle = False
    else:
        num_workers = cfg.data_workers * len(cfg.gpus)
        batch_size = cfg.batch_size
        train_sampler = None
        val_sampler = None
        shuffle = True
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        sampler=train_sampler,
        num_workers=num_workers)
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        sampler=val_sampler,
        num_workers=num_workers)

    # build model
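    # cfg.model names a network constructor defined in resnet_cifar
    # (e.g. 'resnet18'; illustrative value, the actual choices depend on
    # what resnet_cifar provides).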
    model = getattr(resnet_cifar, cfg.model)()
    if distributed:
        model = DistributedDataParallel(
            model.cuda(), device_ids=[torch.cuda.current_device()])
    else:
        model = DataParallel(model, device_ids=cfg.gpus).cuda()

    # build runner and register hooks
    runner = Runner(
        model,
        batch_processor,
        cfg.optimizer,
        cfg.work_dir,
        log_level=cfg.log_level)
    runner.register_training_hooks(
        lr_config=cfg.lr_config,
        optimizer_config=cfg.optimizer_config,
        checkpoint_config=cfg.checkpoint_config,
        log_config=cfg.log_config)
    if distributed:
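        # re-seed the DistributedSampler at the start of each epoch so the
        # shuffling order differs between epochs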
        runner.register_hook(DistSamplerSeedHook())

    # load param (if necessary) and run
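    # runner.resume restores model weights, optimizer state and the epoch
    # counter, while load_checkpoint only loads the model weights.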
    if cfg.get('resume_from') is not None:
        runner.resume(cfg.resume_from)
    elif cfg.get('load_from') is not None:
        runner.load_checkpoint(cfg.load_from)

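    # cfg.workflow is a list of (phase, epochs) pairs, e.g.
    # [('train', 1), ('val', 1)] runs one validation epoch after every
    # training epoch (illustrative value, not taken from this file).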
    runner.run([train_loader, val_loader], cfg.workflow, cfg.total_epochs)


if __name__ == '__main__':
    main()