"libai/evaluation/bleu_evaluator.py" did not exist on "fd158e88e82c3fa848017c62a7eccb49a5c64f78"
test.py 1.14 KB
Newer Older
1
2
3
4
5
6
7
import json
import os
import sys
import torch
from pathlib import Path

from nni.retiarii.experiment import RetiariiExperiment, RetiariiExeConfig
from nni.retiarii.strategies import TPEStrategy, RandomStrategy
from nni.retiarii.trainer import PyTorchImageClassificationTrainer

from darts_model import CNN

# NOTE(review): json, os, sys, torch and Path appear unused in this script;
# kept in case other tooling relies on them being imported here.


def main():
    """Build the DARTS search space and run a local Retiarii experiment.

    Each trial trains one sampled architecture on CIFAR-10 (downloaded to
    ``data/cifar10``) for a single epoch with batch size 32 and lr 1e-3.
    Up to 10 trials run, two at a time, each assigned one GPU from
    indices 1 and 2. The experiment is started on port 8081 in debug mode.
    """
    # Positional args are presumably (input_size, in_channels, channels,
    # n_classes, n_layers) as in NNI's DARTS example -- TODO confirm against
    # darts_model.CNN.
    base_model = CNN(32, 3, 16, 10, 8)
    trainer = PyTorchImageClassificationTrainer(
        base_model,
        dataset_cls="CIFAR10",
        dataset_kwargs={"root": "data/cifar10", "download": True},
        dataloader_kwargs={"batch_size": 32},
        optimizer_kwargs={"lr": 1e-3},
        trainer_kwargs={"max_epochs": 1},
    )

    # Random sampling of the search space. TPEStrategy() (imported above) is
    # the drop-in alternative that was previously left commented out here.
    search_strategy = RandomStrategy()

    # No explicitly applied mutators ([]); the search space presumably comes
    # from mutable layers defined inside CNN -- confirm in darts_model.
    exp = RetiariiExperiment(base_model, trainer, [], search_strategy)

    exp_config = RetiariiExeConfig('local')
    exp_config.experiment_name = 'darts_search'
    exp_config.trial_concurrency = 2
    exp_config.max_trial_number = 10
    exp_config.trial_gpu_number = 1
    # Allow trials to be scheduled on GPUs that other processes are using.
    exp_config.training_service.use_active_gpu = True
    # Two GPUs, matching trial_concurrency=2 with one GPU per trial.
    exp_config.training_service.gpu_indices = [1, 2]

    exp.run(exp_config, 8081, debug=True)


if __name__ == '__main__':
    main()