search.py
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import os

import torch
import torch.nn as nn

import utils
from config import SearchConfig
from datasets.cifar import get_search_datasets
from model import Model
from nni.nas.pytorch.cdarts import CdartsTrainer

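# A minimal launch sketch, assuming SearchConfig reads the --local_rank
# argument that torch.distributed.launch passes to each worker process
# (single-node, 4-GPU example; the remaining flags are defined by SearchConfig):
#   python -m torch.distributed.launch --nproc_per_node=4 search.py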
if __name__ == "__main__":
    config = SearchConfig()
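    # Rank 0 (or the single process in non-distributed runs) handles
    # filesystem setup and parameter logging.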
    main_proc = not config.distributed or config.local_rank == 0
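    # Pin this process to its GPU and join the NCCL process group.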
    if config.distributed:
        torch.cuda.set_device(config.local_rank)
        torch.distributed.init_process_group(backend='nccl', init_method=config.dist_url,
                                             rank=config.local_rank, world_size=config.world_size)
    if main_proc:
        os.makedirs(config.output_path, exist_ok=True)
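    # Barrier so every rank sees the output directory before opening the log file.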
    if config.distributed:
        torch.distributed.barrier()
    logger = utils.get_logger(os.path.join(config.output_path, 'search.log'))
    if main_proc:
        config.print_params(logger.info)
    utils.reset_seed(config.seed)

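    # Data loaders and (distributed) samplers for the search-phase train/val split.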
    loaders, samplers = get_search_datasets(config)
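    # CDARTS searches with a small 8-cell network while a large 20-cell network
    # evaluates the discovered cells; optionally the two share their stem modules.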
    model_small = Model(config.dataset, 8).cuda()
    if config.share_module:
        model_large = Model(config.dataset, 20, shared_modules=model_small.shared_modules).cuda()
    else:
        model_large = Model(config.dataset, 20).cuda()

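    # The trainer jointly optimizes weights and architecture parameters;
    # loss_alpha and loss_T configure the soft, distillation-style loss that
    # couples the small and large networks (see the CDARTS paper for details).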
    criterion = nn.CrossEntropyLoss()
    trainer = CdartsTrainer(model_small, model_large, criterion, loaders, samplers, logger,
                            config.regular_coeff, config.regular_ratio, config.warmup_epochs, config.fix_head,
                            config.epochs, config.steps_per_epoch, config.loss_alpha, config.loss_T, config.distributed,
                            config.log_frequency, config.grad_clip, config.interactive_type, config.output_path,
                            config.w_lr, config.w_momentum, config.w_weight_decay, config.alpha_lr, config.alpha_weight_decay,
                            config.nasnet_lr, config.local_rank, config.share_module)
    trainer.train()