# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from __future__ import absolute_import, division, print_function

import argparse
import logging
import os
import torch
import torchvision

import numpy as np

from datasets import PFLDDatasets
from lib.builder import search_space
from lib.ops import PRIMITIVES
from lib.trainer import PFLDTrainer
from lib.utils import PFLDLoss
from nni.algorithms.nas.pytorch.fbnet import LookUpTable, NASConfig
from torch.utils.data import DataLoader

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def main(args):
    """The main function for supernet pre-training and subnet fine-tuning.

    Sets up logging, seeds RNGs, builds the supernet + auxiliary net,
    the optimizer, the train/val dataloaders, and runs the PFLD trainer.

    Args:
        args: parsed command-line arguments (see ``parse_args``).
    """
    logging.basicConfig(
        format="[%(asctime)s] [p%(process)s] [%(pathname)s\
            :%(lineno)d] [%(levelname)s] %(message)s",
        level=logging.INFO,
        handlers=[
            logging.FileHandler(args.log_file, mode="w"),
            logging.StreamHandler(),
        ],
    )

    # log every argument value for reproducibility (lazy %-formatting)
    for arg in vars(args):
        logging.info("%s: %s", arg, getattr(args, arg))

    # for 106 landmarks
    num_points = 106
    # list of device ids, and the number of workers for data loading
    device_ids = [int(id) for id in args.dev_id.split(",")]
    dev_num = len(device_ids)
    num_workers = 4 * dev_num

    # fixed random seed for reproducible search/training runs
    manual_seed = 1
    np.random.seed(manual_seed)
    torch.manual_seed(manual_seed)
    torch.cuda.manual_seed_all(manual_seed)

    # import supernet for block-wise DNAS pre-training
    from lib.supernet import PFLDInference, AuxiliaryNet

    # the configuration for training control
    nas_config = NASConfig(
        perf_metric=args.perf_metric,
        lut_load=args.lut_load,
        model_dir=args.snapshot,
        nas_lr=args.theta_lr,
        mode=args.mode,
        alpha=args.alpha,
        beta=args.beta,
        search_space=search_space,
        start_epoch=args.start_epoch,
    )
    # look-up table with information of search space, flops per block, etc.
    lookup_table = LookUpTable(config=nas_config, primitives=PRIMITIVES)

    # create supernet
    pfld_backbone = PFLDInference(lookup_table, num_points)
    # the auxiliary-net of PFLD to predict the pose angle
    auxiliarynet = AuxiliaryNet()

    # main task loss
    criterion = PFLDLoss()

    # optimizer for weight train
    if args.opt == "adam":
        optimizer = torch.optim.AdamW(
            [
                {"params": pfld_backbone.parameters()},
                {"params": auxiliarynet.parameters()},
            ],
            lr=args.base_lr,
            weight_decay=args.weight_decay,
        )
    elif args.opt == "rms":
        optimizer = torch.optim.RMSprop(
            [
                {"params": pfld_backbone.parameters()},
                {"params": auxiliarynet.parameters()},
            ],
            lr=args.base_lr,
            momentum=0.0,
            weight_decay=args.weight_decay,
        )
    else:
        # fail fast: the original code left `optimizer` unbound here and
        # crashed later with a confusing NameError
        raise ValueError("unsupported optimizer: {}".format(args.opt))

    # data augmentation and dataloader
    transform = torchvision.transforms.Compose(
        [torchvision.transforms.ToTensor()]
    )
    # the landmark dataset with 106 points is default used
    train_dataset = PFLDDatasets(
        os.path.join(args.data_root, "train_data/list.txt"),
        transform,
        data_root=args.data_root,
        img_size=args.img_size,
    )
    dataloader = DataLoader(
        train_dataset,
        batch_size=args.train_batchsize,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=False,
    )

    val_dataset = PFLDDatasets(
        os.path.join(args.data_root, "test_data/list.txt"),
        transform,
        data_root=args.data_root,
        img_size=args.img_size,
    )
    val_dataloader = DataLoader(
        val_dataset,
        batch_size=args.val_batchsize,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
    )

    # create the trainer, then search/finetune
    trainer = PFLDTrainer(
        pfld_backbone,
        auxiliarynet,
        optimizer,
        criterion,
        device,
        device_ids,
        nas_config,
        lookup_table,
        dataloader,
        val_dataloader,
        n_epochs=args.end_epoch,
        logger=logging,
    )
    trainer.train()


def parse_args():
    """Parse the user arguments for supernet training.

    Returns:
        argparse.Namespace: parsed arguments, with ``snapshot`` and
        ``log_file`` redirected into the ``<snapshot>/supernet`` directory
        (created if missing).
    """

    def str2bool(s):
        # argparse-friendly boolean conversion: accepts yes/no, true/false,
        # t/f, y/n, 1/0 (case-insensitive)
        if isinstance(s, bool):
            return s
        if s.lower() in ('yes', 'true', 't', 'y', '1'):
            return True
        if s.lower() in ('no', 'false', 'f', 'n', '0'):
            return False
        raise argparse.ArgumentTypeError('Boolean value expected.')

    parser = argparse.ArgumentParser(description="FBNet for PFLD")
    parser.add_argument("--dev_id", dest="dev_id", default="0", type=str)
    parser.add_argument("--opt", default="rms", type=str)
    # BUG FIX: was ``type=int`` — that made any explicit ``--base_lr`` value
    # like 0.0001 fail to parse; the learning rate is a float
    parser.add_argument("--base_lr", default=0.0001, type=float)
    parser.add_argument("--weight-decay", "--wd", default=1e-6, type=float)
    parser.add_argument("--img_size", default=112, type=int)
    parser.add_argument("--theta-lr", "--tlr", default=0.01, type=float)
    parser.add_argument(
        "--mode", default="mul", type=str, choices=["mul", "add"]
    )
    parser.add_argument("--alpha", default=0.25, type=float)
    parser.add_argument("--beta", default=0.6, type=float)
    parser.add_argument("--start_epoch", default=50, type=int)
    parser.add_argument("--end_epoch", default=300, type=int)
    parser.add_argument(
        "--snapshot", default="models", type=str, metavar="PATH"
    )
    parser.add_argument("--log_file", default="train.log", type=str)
    parser.add_argument(
        "--data_root", default="/dataset", type=str, metavar="PATH"
    )
    parser.add_argument("--train_batchsize", default=256, type=int)
    parser.add_argument("--val_batchsize", default=128, type=int)
    parser.add_argument(
        "--perf_metric", default="flops", type=str, choices=["flops", "latency"]
    )
    parser.add_argument(
        "--lut_load", type=str2bool, default=False
    )
    parser.add_argument(
        "--lut_load_format", default="json", type=str, choices=["json", "numpy"]
    )

    args = parser.parse_args()
    # redirect snapshots and the log file into the 'supernet' sub-directory
    args.snapshot = os.path.join(args.snapshot, 'supernet')
    args.log_file = os.path.join(args.snapshot, "{}.log".format('supernet'))
    os.makedirs(args.snapshot, exist_ok=True)
    return args


if __name__ == "__main__":
    # script entry point: parse CLI arguments, then train
    main(parse_args())