# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
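
"""Train the final architecture exported by SPOS (Single Path One-Shot)
search from scratch on ImageNet, using DALI-based data loading."""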

import argparse
import logging
import random

import numpy as np
import torch
import torch.nn as nn
from dataloader import get_imagenet_iter_dali
from nni.nas.pytorch.fixed import apply_fixed_architecture
from nni.nas.pytorch.utils import AverageMeterGroup
from torch.utils.tensorboard import SummaryWriter

from network import ShuffleNetV2OneShot
from utils import CrossEntropyLabelSmooth, accuracy

logger = logging.getLogger("nni.spos.scratch")


def train(epoch, model, criterion, optimizer, loader, writer, args):
    model.train()
    meters = AverageMeterGroup()
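    # record the learning rate at the start of the epoch for TensorBoard logging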
    cur_lr = optimizer.param_groups[0]["lr"]

    for step, (x, y) in enumerate(loader):
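        # global step across epochs, used as the TensorBoard x-axis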
        cur_step = len(loader) * epoch + step
        optimizer.zero_grad()
        logits = model(x)
        loss = criterion(logits, y)
        loss.backward()
        optimizer.step()

        metrics = accuracy(logits, y)
        metrics["loss"] = loss.item()
        meters.update(metrics)

        writer.add_scalar("lr", cur_lr, global_step=cur_step)
        writer.add_scalar("loss/train", loss.item(), global_step=cur_step)
        writer.add_scalar("acc1/train", metrics["acc1"], global_step=cur_step)
        writer.add_scalar("acc5/train", metrics["acc5"], global_step=cur_step)

        if step % args.log_frequency == 0 or step + 1 == len(loader):
            logger.info("Epoch [%d/%d] Step [%d/%d]  %s", epoch + 1,
                        args.epochs, step + 1, len(loader), meters)

    logger.info("Epoch %d training summary: %s", epoch + 1, meters)


def validate(epoch, model, criterion, loader, writer, args):
    model.eval()
    meters = AverageMeterGroup()
    with torch.no_grad():
        for step, (x, y) in enumerate(loader):
            logits = model(x)
            loss = criterion(logits, y)
            metrics = accuracy(logits, y)
            metrics["loss"] = loss.item()
            meters.update(metrics)

            if step % args.log_frequency == 0 or step + 1 == len(loader):
                logger.info("Epoch [%d/%d] Validation Step [%d/%d]  %s", epoch + 1,
                            args.epochs, step + 1, len(loader), meters)

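    # validation metrics are logged once per epoch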
    writer.add_scalar("loss/test", meters.loss.avg, global_step=epoch)
    writer.add_scalar("acc1/test", meters.acc1.avg, global_step=epoch)
    writer.add_scalar("acc5/test", meters.acc5.avg, global_step=epoch)

    logger.info("Epoch %d validation: top1 = %f, top5 = %f", epoch + 1, meters.acc1.avg, meters.acc5.avg)


if __name__ == "__main__":
    parser = argparse.ArgumentParser("SPOS Training From Scratch")
    parser.add_argument("--imagenet-dir", type=str, default="./data/imagenet")
    parser.add_argument("--tb-dir", type=str, default="runs")
    parser.add_argument("--architecture", type=str, default="architecture_final.json")
    parser.add_argument("--workers", type=int, default=12)
    parser.add_argument("--batch-size", type=int, default=1024)
    parser.add_argument("--epochs", type=int, default=240)
    parser.add_argument("--learning-rate", type=float, default=0.5)
    parser.add_argument("--momentum", type=float, default=0.9)
    parser.add_argument("--weight-decay", type=float, default=4E-5)
    parser.add_argument("--label-smooth", type=float, default=0.1)
    parser.add_argument("--log-frequency", type=int, default=10)
    parser.add_argument("--lr-decay", type=str, default="linear")
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--spos-preprocessing", default=False, action="store_true")
    parser.add_argument("--label-smoothing", type=float, default=0.1)

    args = parser.parse_args()

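    # seed every RNG source; cudnn.deterministic trades some speed for reproducibility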
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)
    torch.backends.cudnn.deterministic = True

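    # instantiate the one-shot supernet, then fix it to the searched architecture in the JSON file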
    model = ShuffleNetV2OneShot()
    model.cuda()
    apply_fixed_architecture(model, args.architecture)
    if torch.cuda.device_count() > 1:  # reserve the last GPU for data preprocessing on GPU
        model = nn.DataParallel(model, device_ids=list(range(torch.cuda.device_count() - 1)))
    criterion = CrossEntropyLabelSmooth(1000, args.label_smoothing)
    optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate,
                                momentum=args.momentum, weight_decay=args.weight_decay)
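    # learning-rate schedule: linear decay to zero over the run, or cosine annealing down to 1e-3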
    if args.lr_decay == "linear":
        scheduler = torch.optim.lr_scheduler.LambdaLR(
            optimizer, lambda step: max(0.0, 1.0 - step / args.epochs), last_epoch=-1)
    elif args.lr_decay == "cosine":
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs, 1E-3)
    else:
        raise ValueError("'%s' not supported." % args.lr_decay)
    writer = SummaryWriter(log_dir=args.tb_dir)

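    # DALI-based ImageNet iterators (spos_preprocessing toggles the SPOS-style augmentation)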
    train_loader = get_imagenet_iter_dali("train", args.imagenet_dir, args.batch_size, args.workers,
                                          spos_preprocessing=args.spos_preprocessing)
    val_loader = get_imagenet_iter_dali("val", args.imagenet_dir, args.batch_size, args.workers,
                                        spos_preprocessing=args.spos_preprocessing)

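    # one epoch of training followed by validation; the scheduler steps per epoch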
    for epoch in range(args.epochs):
        train(epoch, model, criterion, optimizer, train_loader, writer, args)
        validate(epoch, model, criterion, val_loader, writer, args)
        scheduler.step()

    writer.close()