# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import logging
import os
import random
from argparse import ArgumentParser
from itertools import cycle

import numpy as np
import torch
import torch.nn as nn

from nni.algorithms.nas.pytorch.enas import EnasMutator, EnasTrainer
from nni.nas.pytorch.callbacks import LRSchedulerCallback

from dataloader import read_data_sst
from model import Model
from utils import accuracy


logger = logging.getLogger("nni.textnas")


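# EnasTrainer normally constructs its own dataloaders inside init_dataloader().
# This subclass accepts pre-built loaders and turns init_dataloader() into a
# no-op, so the SST pipeline below keeps control of batching and shuffling.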
class TextNASTrainer(EnasTrainer):
    def __init__(self, *args, train_loader=None, valid_loader=None, test_loader=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.train_loader = train_loader
        self.valid_loader = valid_loader
        self.test_loader = test_loader

    def init_dataloader(self):
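        # Intentionally a no-op: the loaders were injected via the constructor.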
        pass


if __name__ == "__main__":
    parser = ArgumentParser("textnas")
    parser.add_argument("--batch-size", default=128, type=int)
    parser.add_argument("--log-frequency", default=50, type=int)
    parser.add_argument("--seed", default=1234, type=int)
    parser.add_argument("--epochs", default=10, type=int)
    parser.add_argument("--lr", default=5e-3, type=float)
    args = parser.parse_args()

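    # Seed every RNG in play (PyTorch, CUDA, NumPy, stdlib random) and force
    # deterministic cuDNN kernels so search runs are reproducible.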
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)
    torch.backends.cudnn.deterministic = True

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    train_dataset, valid_dataset, test_dataset, embedding = read_data_sst("data")
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, num_workers=4, shuffle=True)
    valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=args.batch_size, num_workers=4, shuffle=True)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=args.batch_size, num_workers=4)
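    # The ENAS trainer draws a fixed number of batches per epoch (child_steps
    # for the model, mutator_steps for the controller) via next(), so wrap the
    # train and valid loaders in cycle() to make them inexhaustible.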
    train_loader, valid_loader = cycle(train_loader), cycle(valid_loader)
    model = Model(embedding)

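    # The ENAS controller: an RNN trained with REINFORCE that samples candidate
    # architectures from the search space defined by Model.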
    mutator = EnasMutator(model, temperature=None, tanh_constant=None, entropy_reduction="mean")

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, eps=1e-3, weight_decay=2e-6)
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs, eta_min=1e-5)

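    # dataset_train/dataset_valid are None because the loaders are injected
    # above; test_arc_per_epoch sets how many sampled architectures are
    # evaluated on the test set after each epoch.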
    trainer = TextNASTrainer(model,
                             loss=criterion,
                             metrics=lambda output, target: {"acc": accuracy(output, target)},
                             reward_function=accuracy,
                             optimizer=optimizer,
                             callbacks=[LRSchedulerCallback(lr_scheduler)],
                             batch_size=args.batch_size,
                             num_epochs=args.epochs,
                             dataset_train=None,
                             dataset_valid=None,
                             train_loader=train_loader,
                             valid_loader=valid_loader,
                             test_loader=test_loader,
                             log_frequency=args.log_frequency,
                             mutator=mutator,
                             mutator_lr=2e-3,
                             mutator_steps=500,
                             mutator_steps_aggregate=1,
                             child_steps=3000,
                             baseline_decay=0.99,
                             test_arc_per_epoch=10)
    trainer.train()
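    # Sample and export 20 architectures from the trained controller; each JSON
    # file describes one candidate for final standalone training.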
    os.makedirs("checkpoints", exist_ok=True)
    for i in range(20):
        trainer.export(os.path.join("checkpoints", "architecture_%02d.json" % i))