#!/usr/bin/env python
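"""Sweep grid for training a big bidirectional transformer language model
(`bi_transformer_lm_big`) with the odd-one-out (OOO) auxiliary objective,
sized for 64 x 16GB Volta GPUs."""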

try:
    # Prefer the Chronos-aware sweep module when it is available.
    import sweep_chronos as sweep
    from sweep_chronos import hyperparam
except ImportError:
    # Fall back to the generic sweep module.
    import sweep
    from sweep import hyperparam

from sweep_lm_data import set_data_based_on_shortname


def get_grid(args):
    grid = []

    enable_ooo = True

    max_update = 250000
    num_data_loaders = 4
    ddp_backend = 'no_c10d'

    # 64 x 16GB Volta
    arch = 'bi_transformer_lm_big'
    max_sentences = 16
    update_freq = 2
    peak_lr = 2.5e-4
    lr_scheduler = 'inverse_sqrt'
    #lr_scheduler = 'polynomial'
    #grid += [hyperparam('--decoder-learned-pos', save_dir_key=lambda val: 'learnpos')]

    # 550 tokens per sentence leaves --max-tokens some headroom above
    # --tokens-per-sample (512) when batching max_sentences samples.
    max_tokens = 550 * max_sentences

    set_data_based_on_shortname(args)

    # batch size
    grid += [
        hyperparam('--tokens-per-sample', 512, save_dir_key=lambda val: f'st{val}'),
        hyperparam('--max-sentences', max_sentences, save_dir_key=lambda val: f'ms{val}'),
        hyperparam('--max-tokens', max_tokens, save_dir_key=lambda val: f'mt{val}'),
        hyperparam('--update-freq', update_freq, save_dir_key=lambda val: f'uf{val}'),
    ]

    # task settings
    grid += [
        hyperparam('--task', 'odd_one_out_lm'),
    ]

    # model settings
    grid += [
        hyperparam('--arch', arch, save_dir_key=lambda val: val),
        hyperparam('--activation-fn', 'gelu_fast', save_dir_key=lambda val: val),
        hyperparam('--share-decoder-input-output-embed'),
    ]

    # regularization
    grid += [
        hyperparam('--dropout', 0.1, save_dir_key=lambda val: f'dr{val}'),
        hyperparam('--attention-dropout', 0.0, save_dir_key=lambda val: f'atdr{val}'),
        hyperparam('--weight-decay', 0.01, save_dir_key=lambda val: f'wd{val}'),
    ]

    # optimization settings
    grid += [
        hyperparam('--optimizer', 'adam', save_dir_key=lambda val: val),
        hyperparam('--adam-betas', '(0.9, 0.999)', save_dir_key=lambda val: 'beta2_999'),
        hyperparam('--clip-norm', 0.0, save_dir_key=lambda val: f'clip{val}'),
    ]

    # lr scheduler
    if lr_scheduler == 'inverse_sqrt':
        grid += [
            hyperparam('--lr-scheduler', 'inverse_sqrt'),
            hyperparam('--lr', peak_lr, save_dir_key=lambda val: f'lr{val}'),
            hyperparam('--warmup-init-lr', 0),
            hyperparam('--warmup-updates', 16000, save_dir_key=lambda val: f'warm{val}'),
        ]
    elif lr_scheduler == 'polynomial':
        grid += [
            hyperparam('--lr-scheduler', 'polynomial_decay'),
            hyperparam('--lr', peak_lr, save_dir_key=lambda val: f'lr{val}'),
            hyperparam('--total-num-update', max_update),
            hyperparam('--warmup-updates', 16000, save_dir_key=lambda val: f'warm{val}'),
        ]

    # Odd-one-out auxiliary objective: train with the odd_one_out criterion,
    # its loss weight, and the probability of sampling shortened items.
    if enable_ooo:
        grid += [
            hyperparam('--criterion', 'odd_one_out'),
            hyperparam('--ooo-weight', [1.0], save_dir_key=lambda val: str(val)),
            hyperparam('--short-item-prob', 0.1, save_dir_key=lambda val: f'short{val}'),
        ]

    # FP16 + distributed settings
    grid += [
        hyperparam('--ddp-backend', ddp_backend),

        hyperparam('--fp16', save_dir_key=lambda val: 'fp16'),
        #hyperparam('--memory-efficient-fp16', save_dir_key=lambda val: 'me_fp16'),
        hyperparam('--fp16-init-scale', 4),
        hyperparam('--threshold-loss-scale', 1),
        hyperparam('--fp16-scale-window', 128),
    ]

    # data loading settings
    grid += [
        hyperparam('--dataset-impl', 'mmap'),
        hyperparam('--num-workers', num_data_loaders),
    ]

    # validation and checkpoint settings
    grid += [
        hyperparam('--save-interval-updates', 2000),
        hyperparam('--no-epoch-checkpoints'),
        hyperparam('--max-update', max_update, save_dir_key=lambda val: f'mu{val}'),
    ]

    # logging settings
    grid += [
        hyperparam('--log-format', 'json'),
        hyperparam('--log-interval', 25),
    ]

    # random seed
    grid += [
        hyperparam('--seed', [1], save_dir_key=lambda val: f'seed{val}'),
    ]

    if args.local:
        grid += [
            hyperparam('--log-format', 'json'),
            hyperparam('--log-interval', 1),
        ]

    return grid


def postprocess_hyperparams(args, config):
    """Postprocess a given hyperparameter configuration."""
    # if config['--seq-beam'].current_value <= 8:
    #    config['--max-tokens'].current_value = 400
    # else:
    #    config['--max-tokens'].current_value = 300
    pass


if __name__ == '__main__':
    sweep.main(get_grid, postprocess_hyperparams)
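
# Example invocation (a sketch; the launcher flags are defined by the local
# sweep module's argument parser, so adjust to match it -- only --local is
# confirmed by this script's use of args.local):
#
#   python sweep_ooo_bi_transformer_big.py --local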