test_training_service.py 2.3 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import json
from nni.common.device import GPUDevice
import os
import sys
import torch
from pathlib import Path

import nni.retiarii.evaluator.pytorch.lightning as pl
import nni.retiarii.strategy as strategy
from nni.experiment import RemoteMachineConfig
from nni.retiarii import serialize
from nni.retiarii.experiment.pytorch import RetiariiExperiment, RetiariiExeConfig
from torchvision import transforms
from torchvision.datasets import CIFAR10

from darts_model import CNN

if __name__ == '__main__':
    # Script entry point: run a short random-strategy search over the DARTS
    # CNN model on CIFAR-10 using the 'remote' training service.

    # Base model handed to the experiment; arguments are forwarded to CNN as-is.
    base_model = CNN(32, 3, 16, 10, 8)

    # Normalization statistics shared by the training and validation pipelines.
    channel_mean = (0.4914, 0.4822, 0.4465)
    channel_std = (0.2023, 0.1994, 0.2010)

    # Training pipeline adds random crop/flip augmentation before normalizing.
    train_transform = transforms.Compose(
        [
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(channel_mean, channel_std),
        ]
    )
    # Validation pipeline only converts to tensor and normalizes.
    valid_transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize(channel_mean, channel_std),
        ]
    )

    # Datasets are wrapped with `serialize` rather than instantiated directly.
    train_dataset = serialize(
        CIFAR10,
        root='data/cifar10',
        train=True,
        download=True,
        transform=train_transform,
    )
    test_dataset = serialize(
        CIFAR10,
        root='data/cifar10',
        train=False,
        download=True,
        transform=valid_transform,
    )

    # Evaluator: one epoch on 20% of the training batches per trial.
    batch_size = 100
    train_loader = pl.DataLoader(train_dataset, batch_size=batch_size)
    valid_loader = pl.DataLoader(test_dataset, batch_size=batch_size)
    trainer = pl.Classification(
        train_dataloader=train_loader,
        val_dataloaders=valid_loader,
        max_epochs=1,
        limit_train_batches=0.2,
    )

    simple_strategy = strategy.Random()

    # No mutators are applied (empty list).
    exp = RetiariiExperiment(base_model, trainer, [], simple_strategy)

    # Description of the single remote worker (placeholder credentials).
    rm_conf = RemoteMachineConfig()
    rm_conf.host = '127.0.0.1'
    rm_conf.user = 'xxx'
    rm_conf.password = 'xxx'
    rm_conf.port = 22
    rm_conf.python_path = '/home/xxx/py38/bin'
    rm_conf.gpu_indices = [0, 1, 2]
    rm_conf.use_active_gpu = True
    rm_conf.max_trial_number_per_gpu = 3

    # Experiment-level settings for the remote platform.
    exp_config = RetiariiExeConfig('remote')
    exp_config.experiment_name = 'darts_search'
    exp_config.trial_concurrency = 2
    exp_config.max_trial_number = 10
    exp_config.trial_gpu_number = 1

    # Training-service settings, including the machine list defined above.
    remote_service = exp_config.training_service
    remote_service.use_active_gpu = True
    remote_service.reuse_mode = True
    remote_service.gpu_indices = [0, 1, 2]
    remote_service.machine_list = [rm_conf]

    exp_config.execution_engine = 'py'

    # Start the experiment on port 8081.
    exp.run(exp_config, 8081)