"test/vscode:/vscode.git/clone" did not exist on "32ba4d9abab78ffa20c39068547ea0bf111e8057"
config.py 7.93 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import argparse
from functools import partial


def get_parser(name):
    """ make default formatted parser """
    parser = argparse.ArgumentParser(name, formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    # print default value always
    parser.add_argument = partial(parser.add_argument, help=' ')
    return parser


class BaseConfig(argparse.Namespace):
    def print_params(self, prtf=print):
        prtf("")
        prtf("Parameters:")
        for attr, value in sorted(vars(self).items()):
            prtf("{}={}".format(attr.upper(), value))
        prtf("")

    def as_markdown(self):
        """ Return configs as markdown format """
        text = "|name|value|  \n|-|-|  \n"
        for attr, value in sorted(vars(self).items()):
            text += "|{}|{}|  \n".format(attr, value)

        return text


class SearchConfig(BaseConfig):
    def build_parser(self):
        parser = get_parser("Search config")
        ########### basic settings ############
        parser.add_argument('--dataset', default='cifar10', choices=['cifar10', 'cifar100', 'imagenet'])
        parser.add_argument('--n_classes', type=int, default=10)
        parser.add_argument('--stem_multiplier', type=int, default=3)
        parser.add_argument('--init_channels', type=int, default=16)
        parser.add_argument('--data_dir', type=str, default='data/cifar', help='cifar dataset')
        parser.add_argument('--output_path', type=str, default='./outputs', help='')
        parser.add_argument('--batch_size', type=int, default=128, help='batch size')
        parser.add_argument('--log_frequency', type=int, default=10, help='print frequency')
        parser.add_argument('--seed', type=int, default=0, help='random seed')
        parser.add_argument('--workers', type=int, default=4, help='# of workers')
        parser.add_argument('--steps_per_epoch', type=int, default=None, help='how many steps per epoch, use None for one pass of dataset')

        ########### learning rate ############
        parser.add_argument('--w_lr', type=float, default=0.05, help='lr for weights')
        parser.add_argument('--w_momentum', type=float, default=0.9, help='momentum for weights')
        parser.add_argument('--w_weight_decay', type=float, default=3e-4, help='weight decay for weights')
        parser.add_argument('--grad_clip', type=float, default=5., help='gradient clipping for weights')
        parser.add_argument('--alpha_lr', type=float, default=6e-4, help='lr for alpha')
        parser.add_argument('--alpha_weight_decay', type=float, default=1e-3, help='weight decay for alpha')
        parser.add_argument('--nasnet_lr', type=float, default=0.1, help='lr of nasnet')

        ########### alternate training ############
        parser.add_argument('--epochs', type=int, default=32, help='# of search epochs')
        parser.add_argument('--warmup_epochs', type=int, default=2, help='# warmup epochs of super model')
        parser.add_argument('--loss_alpha', type=float, default=1, help='loss alpha')
        parser.add_argument('--loss_T', type=float, default=2, help='loss temperature')
        parser.add_argument('--interactive_type', type=str, default='kl', choices=['kl', 'smoothl1'])
        parser.add_argument('--sync_bn', action='store_true', default=False, help='whether to sync bn')
        parser.add_argument('--use_apex', action='store_true', default=False, help='whether to use apex')
        parser.add_argument('--regular_ratio', type=float, default=0.5, help='regular ratio')
        parser.add_argument('--regular_coeff', type=float, default=5, help='regular coefficient')
        parser.add_argument('--fix_head', action='store_true', default=False, help='whether to fix head')
        parser.add_argument('--share_module', action='store_true', default=False, help='whether to share stem and aux head')

        ########### data augument ############
        parser.add_argument('--aux_weight', type=float, default=0.4, help='auxiliary loss weight')
        parser.add_argument('--cutout_length', type=int, default=16, help='cutout length')
        parser.add_argument('--drop_path_prob', type=float, default=0.2, help='drop path prob')
        parser.add_argument('--use_aa', action='store_true', default=False, help='whether to use aa')
        parser.add_argument('--mixup_alpha', default=1., type=float, help='mixup interpolation coefficient (default: 1)')

        ########### distributed ############
        parser.add_argument("--local_rank", default=0, type=int)
        parser.add_argument("--world_size", default=1, type=int)
        parser.add_argument('--dist_url', default='tcp://127.0.0.1:23456', type=str, help='url used to set up distributed training')
        parser.add_argument('--distributed', action='store_true', help='run model distributed mode')

        return parser

    def __init__(self):
        parser = self.build_parser()
        args = parser.parse_args()
        super().__init__(**vars(args))


class RetrainConfig(BaseConfig):
    def build_parser(self):
        parser = get_parser("Retrain config")
        parser.add_argument('--dataset', default="cifar10", choices=['cifar10', 'cifar100', 'imagenet'])
        parser.add_argument('--data_dir', type=str, default='data/cifar', help='cifar dataset')
        parser.add_argument('--output_path', type=str, default='./outputs', help='')
        parser.add_argument("--arc_checkpoint", default="epoch_02.json")
        parser.add_argument('--log_frequency', type=int, default=10, help='print frequency')

        ########### model settings ############
        parser.add_argument('--n_classes', type=int, default=10)
        parser.add_argument('--input_channels', type=int, default=3)
        parser.add_argument('--stem_multiplier', type=int, default=3)
        parser.add_argument('--batch_size', type=int, default=128, help='batch size')
        parser.add_argument('--eval_batch_size', type=int, default=500, help='batch size for validation')
        parser.add_argument('--lr', type=float, default=0.025, help='lr for weights')
        parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
        parser.add_argument('--grad_clip', type=float, default=5., help='gradient clipping for weights')
        parser.add_argument('--weight_decay', type=float, default=5e-4, help='weight decay')
        parser.add_argument('--epochs', type=int, default=600, help='# of training epochs')
        parser.add_argument('--warmup_epochs', type=int, default=5, help='# warmup')
        parser.add_argument('--init_channels', type=int, default=36)
        parser.add_argument('--layers', type=int, default=20, help='# of layers')
        parser.add_argument('--seed', type=int, default=0, help='random seed')
        parser.add_argument('--workers', type=int, default=4, help='# of workers')
        parser.add_argument('--aux_weight', type=float, default=0.4, help='auxiliary loss weight')
        parser.add_argument('--cutout_length', type=int, default=16, help='cutout length')
        parser.add_argument('--label_smooth', type=float, default=0.1, help='label smoothing')
        parser.add_argument('--drop_path_prob', type=float, default=0.3, help='drop path prob')

        ########### data augmentation ############
        parser.add_argument('--use_aa', action='store_true', default=False, help='whether to use aa')
        parser.add_argument('--mixup_alpha', default=1., type=float, help='mixup interpolation coefficient')

        ########### distributed ############
        parser.add_argument("--local_rank", default=0, type=int)
        parser.add_argument("--world_size", default=1, type=int)
        parser.add_argument('--dist_url', default='tcp://127.0.0.1:23456', type=str, help='url used to set up distributed training')
        parser.add_argument('--distributed', action='store_true', help='run model distributed mode')

        return parser

    def __init__(self):
        parser = self.build_parser()
        args = parser.parse_args()
        super().__init__(**vars(args))