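"""test.py: multi-trial architecture search with NNI Retiarii.

Samples architectures at random from the ``CNN`` search space defined in
``darts_model`` and trains each sampled model briefly (one epoch on a fraction
of the training batches) on CIFAR-10.
"""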
import json
import os
import sys
from pathlib import Path

import nni.retiarii.evaluator.pytorch.lightning as pl
import nni.retiarii.strategy as strategy
import torch
from nni.retiarii import serialize
from nni.retiarii.experiment.pytorch import RetiariiExperiment, RetiariiExeConfig
from torchvision import transforms
from torchvision.datasets import CIFAR10

from darts_model import CNN

if __name__ == '__main__':
    # Base model whose search space Retiarii explores.
    # NOTE(review): positional args are presumably (input_size=32, in_channels=3,
    # channels=16, n_classes=10, n_layers=8); confirm against darts_model.CNN.
    base_model = CNN(32, 3, 16, 10, 8)

    # Per-channel normalization constants, shared by both pipelines below.
    cifar_mean = (0.4914, 0.4822, 0.4465)
    cifar_std = (0.2023, 0.1994, 0.2010)

    # Training pipeline adds random crop + horizontal flip augmentation;
    # the validation pipeline only converts and normalizes.
    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(cifar_mean, cifar_std),
    ])
    valid_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(cifar_mean, cifar_std),
    ])

    # Datasets are built through nni.retiarii.serialize(cls, **kwargs) rather than
    # by calling CIFAR10 directly. Presumably this is so the evaluator can be
    # recreated inside each trial; confirm against the NNI version in use.
    train_dataset = serialize(CIFAR10, root='data/cifar10', train=True, download=True, transform=train_transform)
    test_dataset = serialize(CIFAR10, root='data/cifar10', train=False, download=True, transform=valid_transform)

    # Each sampled model trains for one epoch on 20% of the training batches
    # and is validated on the CIFAR-10 test split (train=False).
    # NOTE(review): the training loader is not shuffled. Assuming the usual
    # DataLoader default (shuffle=False), every trial therefore trains on the
    # same leading 20% of batches.
    trainer = pl.Classification(train_dataloader=pl.DataLoader(train_dataset, batch_size=100),
                                val_dataloaders=pl.DataLoader(test_dataset, batch_size=100),
                                max_epochs=1, limit_train_batches=0.2,
                                progress_bar_refresh_rate=0)

    # Sample architectures at random from the search space.
    simple_strategy = strategy.Random()

    # The third argument is an empty list. Presumably this means no extra
    # mutators beyond whatever base_model itself declares.
    exp = RetiariiExperiment(base_model, trainer, [], simple_strategy)

    exp_config = RetiariiExeConfig('local')
    exp_config.experiment_name = 'darts_search'
    exp_config.trial_concurrency = 2
    exp_config.max_trial_number = 10
    # One GPU per trial. With two concurrent trials this lines up with the two
    # GPU indices listed below.
    exp_config.trial_gpu_number = 1
    exp_config.training_service.use_active_gpu = True
    # NOTE(review): hard-coded for a machine that has GPUs 1 and 2 available.
    exp_config.training_service.gpu_indices = [1, 2]

    # Launch the experiment on port 8081.
    exp.run(exp_config, 8081)