Unverified Commit 468917ca authored by QuanluZhang's avatar QuanluZhang Committed by GitHub
Browse files

Merge pull request #3155 from microsoft/dev-retiarii

[Do NOT Squash] Merge retiarii dev branch to master
parents f8424a9f d5a551c8
import json
import os
import sys
import torch
from pathlib import Path
from nni.retiarii.experiment import RetiariiExperiment, RetiariiExeConfig
from nni.retiarii.strategies import TPEStrategy
from nni.retiarii.trainer import PyTorchImageClassificationTrainer
from darts_model import CNN
if __name__ == '__main__':
base_model = CNN(32, 3, 16, 10, 8)
trainer = PyTorchImageClassificationTrainer(base_model, dataset_cls="CIFAR10",
dataset_kwargs={"root": "data/cifar10", "download": True},
dataloader_kwargs={"batch_size": 32},
optimizer_kwargs={"lr": 1e-3},
trainer_kwargs={"max_epochs": 1})
simple_startegy = TPEStrategy()
exp = RetiariiExperiment(base_model, trainer, [], simple_startegy)
exp_config = RetiariiExeConfig('local')
exp_config.experiment_name = 'darts_search'
exp_config.trial_concurrency = 2
exp_config.max_trial_number = 10
exp_config.trial_gpu_number = 1
exp_config.training_service.use_active_gpu = True
exp_config.training_service.gpu_indices = [1, 2]
exp.run(exp_config, 8081, debug=True)
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
import json
import logging
import random
import os
from nni.retiarii import Model, submit_models, wait_models
from nni.retiarii.strategy import BaseStrategy
from nni.retiarii import Sampler
_logger = logging.getLogger(__name__)
class RandomSampler(Sampler):
def choice(self, candidates, mutator, model, index):
return random.choice(candidates)
class SimpleStrategy(BaseStrategy):
def __init__(self):
self.name = ''
def run(self, base_model, applied_mutators, trainer):
try:
_logger.info('stargety start...')
while True:
model = base_model
_logger.info('apply mutators...')
_logger.info('mutators: {}'.format(applied_mutators))
random_sampler = RandomSampler()
for mutator in applied_mutators:
_logger.info('mutate model...')
mutator.bind_sampler(random_sampler)
model = mutator.apply(model)
# get and apply training approach
_logger.info('apply training approach...')
model.apply_trainer(trainer['modulename'], trainer['args'])
# run models
submit_models(model)
wait_models(model)
_logger.info('Strategy says:', model.metric)
except Exception as e:
_logger.error(logging.exception('message'))
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment