"git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "0d5a6cd0275b9b52a162f1bea43d4fd006171e1c"
Unverified Commit 192a807b authored by QuanluZhang's avatar QuanluZhang Committed by GitHub
Browse files

[Retiarii] refactor based on the new launch approach (#3185)

parent 80394047
...@@ -9,6 +9,7 @@ import sys ...@@ -9,6 +9,7 @@ import sys
from pathlib import Path from pathlib import Path
sys.path.append(str(Path(__file__).resolve().parents[2])) sys.path.append(str(Path(__file__).resolve().parents[2]))
import nni.retiarii.nn.pytorch as nn import nni.retiarii.nn.pytorch as nn
from nni.retiarii import register_module
# Paper suggests 0.9997 momentum, for TensorFlow. Equivalent PyTorch momentum is # Paper suggests 0.9997 momentum, for TensorFlow. Equivalent PyTorch momentum is
# 1.0 - tensorflow. # 1.0 - tensorflow.
...@@ -109,7 +110,7 @@ def _get_depths(depths, alpha): ...@@ -109,7 +110,7 @@ def _get_depths(depths, alpha):
rather than down. """ rather than down. """
return [_round_to_multiple_of(depth * alpha, 8) for depth in depths] return [_round_to_multiple_of(depth * alpha, 8) for depth in depths]
@register_module()
class MNASNet(nn.Module): class MNASNet(nn.Module):
""" MNASNet, as described in https://arxiv.org/pdf/1807.11626.pdf. This """ MNASNet, as described in https://arxiv.org/pdf/1807.11626.pdf. This
implements the B1 variant of the model. implements the B1 variant of the model.
...@@ -126,8 +127,7 @@ class MNASNet(nn.Module): ...@@ -126,8 +127,7 @@ class MNASNet(nn.Module):
def __init__(self, alpha, depths, convops, kernel_sizes, num_layers, def __init__(self, alpha, depths, convops, kernel_sizes, num_layers,
skips, num_classes=1000, dropout=0.2): skips, num_classes=1000, dropout=0.2):
super(MNASNet, self).__init__(alpha, depths, convops, kernel_sizes, num_layers, super(MNASNet, self).__init__()
skips, num_classes, dropout)
assert alpha > 0.0 assert alpha > 0.0
assert len(depths) == len(convops) == len(kernel_sizes) == len(num_layers) == len(skips) == 7 assert len(depths) == len(convops) == len(kernel_sizes) == len(num_layers) == len(skips) == 7
self.alpha = alpha self.alpha = alpha
......
...@@ -3,20 +3,11 @@ import sys ...@@ -3,20 +3,11 @@ import sys
import torch import torch
from pathlib import Path from pathlib import Path
sys.path.append(str(Path(__file__).resolve().parents[2]))
from nni.retiarii.converter.graph_gen import convert_to_graph
from nni.retiarii.converter.visualize import visualize_model
from nni.retiarii.codegen.pytorch import model_to_pytorch_script
from nni.retiarii import nn
from nni.retiarii.trainer import PyTorchImageClassificationTrainer from nni.retiarii.trainer import PyTorchImageClassificationTrainer
from nni.retiarii.utils import TraceClassArguments
from base_mnasnet import MNASNet from base_mnasnet import MNASNet
from nni.experiment import RetiariiExperiment, RetiariiExpConfig from nni.retiarii.experiment import RetiariiExperiment, RetiariiExeConfig
#from simple_strategy import SimpleStrategy
#from tpe_strategy import TPEStrategy
from nni.retiarii.strategies import TPEStrategy from nni.retiarii.strategies import TPEStrategy
from mutator import BlockMutator from mutator import BlockMutator
...@@ -27,23 +18,13 @@ if __name__ == '__main__': ...@@ -27,23 +18,13 @@ if __name__ == '__main__':
_DEFAULT_KERNEL_SIZES = [3, 3, 5, 5, 3, 5, 3] _DEFAULT_KERNEL_SIZES = [3, 3, 5, 5, 3, 5, 3]
_DEFAULT_NUM_LAYERS = [1, 3, 3, 3, 2, 4, 1] _DEFAULT_NUM_LAYERS = [1, 3, 3, 3, 2, 4, 1]
with TraceClassArguments() as tca: base_model = MNASNet(0.5, _DEFAULT_DEPTHS, _DEFAULT_CONVOPS, _DEFAULT_KERNEL_SIZES,
base_model = MNASNet(0.5, _DEFAULT_DEPTHS, _DEFAULT_CONVOPS, _DEFAULT_KERNEL_SIZES, _DEFAULT_NUM_LAYERS, _DEFAULT_SKIPS)
_DEFAULT_NUM_LAYERS, _DEFAULT_SKIPS) trainer = PyTorchImageClassificationTrainer(base_model, dataset_cls="CIFAR10",
trainer = PyTorchImageClassificationTrainer(base_model, dataset_cls="CIFAR10", dataset_kwargs={"root": "data/cifar10", "download": True},
dataset_kwargs={"root": "data/cifar10", "download": True}, dataloader_kwargs={"batch_size": 32},
dataloader_kwargs={"batch_size": 32}, optimizer_kwargs={"lr": 1e-3},
optimizer_kwargs={"lr": 1e-3}, trainer_kwargs={"max_epochs": 1})
trainer_kwargs={"max_epochs": 1})
'''script_module = torch.jit.script(base_model)
model = convert_to_graph(script_module, base_model, tca.recorded_arguments)
code_script = model_to_pytorch_script(model)
print(code_script)
print("Model: ", model)
graph_ir = model._dump()
print(graph_ir)
visualize_model(graph_ir)'''
# new interface # new interface
applied_mutators = [] applied_mutators = []
...@@ -52,11 +33,12 @@ if __name__ == '__main__': ...@@ -52,11 +33,12 @@ if __name__ == '__main__':
simple_startegy = TPEStrategy() simple_startegy = TPEStrategy()
exp = RetiariiExperiment(base_model, trainer, applied_mutators, simple_startegy, tca) exp = RetiariiExperiment(base_model, trainer, applied_mutators, simple_startegy)
exp_config = RetiariiExpConfig.create_template('local') exp_config = RetiariiExeConfig('local')
exp_config.experiment_name = 'mnasnet_search' exp_config.experiment_name = 'mnasnet_search'
exp_config.trial_concurrency = 2 exp_config.trial_concurrency = 2
exp_config.max_trial_number = 10 exp_config.max_trial_number = 10
exp_config.training_service.use_active_gpu = False
exp.run(exp_config, 8081, debug=True) exp.run(exp_config, 8081, debug=True)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment