Unverified Commit 192a807b authored by QuanluZhang's avatar QuanluZhang Committed by GitHub
Browse files

[Retiarii] refactor based on the new launch approach (#3185)

parent 80394047
......@@ -9,6 +9,7 @@ import sys
from pathlib import Path
sys.path.append(str(Path(__file__).resolve().parents[2]))
import nni.retiarii.nn.pytorch as nn
from nni.retiarii import register_module
# Paper suggests 0.9997 momentum, for TensorFlow. Equivalent PyTorch momentum is
# 1.0 - tensorflow.
......@@ -109,7 +110,7 @@ def _get_depths(depths, alpha):
rather than down. """
return [_round_to_multiple_of(depth * alpha, 8) for depth in depths]
@register_module()
class MNASNet(nn.Module):
""" MNASNet, as described in https://arxiv.org/pdf/1807.11626.pdf. This
implements the B1 variant of the model.
......@@ -126,8 +127,7 @@ class MNASNet(nn.Module):
def __init__(self, alpha, depths, convops, kernel_sizes, num_layers,
skips, num_classes=1000, dropout=0.2):
super(MNASNet, self).__init__(alpha, depths, convops, kernel_sizes, num_layers,
skips, num_classes, dropout)
super(MNASNet, self).__init__()
assert alpha > 0.0
assert len(depths) == len(convops) == len(kernel_sizes) == len(num_layers) == len(skips) == 7
self.alpha = alpha
......
......@@ -3,20 +3,11 @@ import sys
import torch
from pathlib import Path
sys.path.append(str(Path(__file__).resolve().parents[2]))
from nni.retiarii.converter.graph_gen import convert_to_graph
from nni.retiarii.converter.visualize import visualize_model
from nni.retiarii.codegen.pytorch import model_to_pytorch_script
from nni.retiarii import nn
from nni.retiarii.trainer import PyTorchImageClassificationTrainer
from nni.retiarii.utils import TraceClassArguments
from base_mnasnet import MNASNet
from nni.experiment import RetiariiExperiment, RetiariiExpConfig
from nni.retiarii.experiment import RetiariiExperiment, RetiariiExeConfig
#from simple_strategy import SimpleStrategy
#from tpe_strategy import TPEStrategy
from nni.retiarii.strategies import TPEStrategy
from mutator import BlockMutator
......@@ -27,7 +18,6 @@ if __name__ == '__main__':
_DEFAULT_KERNEL_SIZES = [3, 3, 5, 5, 3, 5, 3]
_DEFAULT_NUM_LAYERS = [1, 3, 3, 3, 2, 4, 1]
with TraceClassArguments() as tca:
base_model = MNASNet(0.5, _DEFAULT_DEPTHS, _DEFAULT_CONVOPS, _DEFAULT_KERNEL_SIZES,
_DEFAULT_NUM_LAYERS, _DEFAULT_SKIPS)
trainer = PyTorchImageClassificationTrainer(base_model, dataset_cls="CIFAR10",
......@@ -36,15 +26,6 @@ if __name__ == '__main__':
optimizer_kwargs={"lr": 1e-3},
trainer_kwargs={"max_epochs": 1})
'''script_module = torch.jit.script(base_model)
model = convert_to_graph(script_module, base_model, tca.recorded_arguments)
code_script = model_to_pytorch_script(model)
print(code_script)
print("Model: ", model)
graph_ir = model._dump()
print(graph_ir)
visualize_model(graph_ir)'''
# new interface
applied_mutators = []
applied_mutators.append(BlockMutator('mutable_0'))
......@@ -52,11 +33,12 @@ if __name__ == '__main__':
simple_startegy = TPEStrategy()
exp = RetiariiExperiment(base_model, trainer, applied_mutators, simple_startegy, tca)
exp = RetiariiExperiment(base_model, trainer, applied_mutators, simple_startegy)
exp_config = RetiariiExpConfig.create_template('local')
exp_config = RetiariiExeConfig('local')
exp_config.experiment_name = 'mnasnet_search'
exp_config.trial_concurrency = 2
exp_config.max_trial_number = 10
exp_config.training_service.use_active_gpu = False
exp.run(exp_config, 8081, debug=True)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment