"driver/driver.hip.cpp" did not exist on "1de6fd07535833877019634a95eafd329406be4c"
Unverified Commit 1a3c019a authored by Yuge Zhang, committed by GitHub

Bug fix of mutating architectures and hparams simultaneously (#4739)

parent 84b9c9b2
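
The change below lets one search mix architecture decisions with evaluator hyperparameter decisions. A minimal sketch of the scenario, assuming the NNI 2.7-era Retiarii API that this diff touches (train_and_evaluate is a hypothetical training function; exact import paths may differ by version):

    import nni.retiarii.nn.pytorch as nn
    from nni.retiarii import model_wrapper
    from nni.retiarii.evaluator import FunctionalEvaluator
    from nni.retiarii.nn.pytorch.api import ValueChoice

    @model_wrapper
    class Net(nn.Module):
        def __init__(self):
            super().__init__()
            # Architecture decision: the kernel size is searched.
            self.conv = nn.Conv2d(3, 5, kernel_size=nn.ValueChoice([3, 5]))

        def forward(self, x):
            return self.conv(x)

    def train_and_evaluate(model_cls, learning_rate):
        # Hypothetical training function; `learning_rate` is searched
        # together with the kernel size above.
        ...

    # Hyperparameter decision: a ValueChoice passed as an evaluator argument.
    evaluator = FunctionalEvaluator(train_and_evaluate, learning_rate=ValueChoice([0.001, 0.01]))

The first two hunks below move the process_evaluator_mutations call from RetiariiExperiment into preprocess_model, so both kinds of mutators are collected in one place.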
@@ -144,6 +144,10 @@ def preprocess_model(base_model, trainer, applied_mutators, full_ir=True, dummy_
                            'do not use mutators when you use LayerChoice/InputChoice')
     if mutators is not None:
         applied_mutators = mutators
+
+    # Add mutations on evaluators
+    applied_mutators += process_evaluator_mutations(trainer, applied_mutators)
+
     return base_model_ir, applied_mutators
@@ -203,7 +207,6 @@ class RetiariiExperiment(Experiment):
             full_ir=self.config.execution_engine not in ['py', 'benchmark'],
             dummy_input=self.config.dummy_input
         )
-        self.applied_mutators += process_evaluator_mutations(self.trainer, self.applied_mutators)
         _logger.info('Start strategy...')
         search_space = dry_run_for_formatted_search_space(base_model_ir, self.applied_mutators)
...
@@ -755,7 +755,7 @@ class ValueChoice(ValueChoiceX, Mutable):
       (i.e., modules in ``nni.retiarii.nn.pytorch`` and user-defined modules decorated with ``@basic_unit``).
     * Used as input arguments of evaluator (*new in v2.7*).
 
-    It can be used in parameters of operators: ::
+    It can be used in parameters of operators (i.e., a sub-module of the model): ::
 
         class Net(nn.Module):
             def __init__(self):
@@ -765,7 +765,8 @@ class ValueChoice(ValueChoiceX, Mutable):
             def forward(self, x):
                 return self.conv(x)
 
-    Or evaluator: ::
+    Or evaluator (only if the evaluator is :doc:`traceable </nas/serialization>`, e.g.,
+    :class:`FunctionalEvaluator <nni.retiarii.evaluator.FunctionalEvaluator>`): ::
 
         def train_and_evaluate(model_cls, learning_rate):
             ...
...
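
Beyond a single choice per argument, evaluator arguments may also be arithmetic compositions of value choices; value choices that reuse a label within the evaluator share one decision, while sharing a label between model and evaluator is rejected (see process_evaluator_mutations below). A hedged sketch under the same API assumptions as above (train_and_evaluate and its parameters are hypothetical):

    from nni.retiarii.evaluator import FunctionalEvaluator
    from nni.retiarii.nn.pytorch.api import ValueChoice

    def train_and_evaluate(model_cls, batch_size, total):
        ...

    evaluator = FunctionalEvaluator(
        train_and_evaluate,
        # Independent decision with an explicit label.
        batch_size=ValueChoice([16, 32, 64], label='bs'),
        # A composition: the same 'bs' decision plus another, auto-labelled choice.
        # Leaf choices are sampled first, then the expression is evaluated.
        total=ValueChoice([16, 32, 64], label='bs') + ValueChoice([100, 200]),
    )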
@@ -2,8 +2,7 @@
 # Licensed under the MIT license.
 
 import inspect
-from collections import defaultdict
-from typing import Any, List, Optional, Tuple, Dict
+from typing import Any, List, Optional, Tuple, Dict, Iterator
 
 import torch.nn as nn
@@ -361,26 +360,41 @@ class EvaluatorValueChoiceMutator(Mutator):
     # works in the same way as `ParameterChoiceMutator`
     # we only need one such mutator for one model/evaluator
 
-    def mutate(self, model: Model):
-        # make a copy to mutate the evaluator
-        model.evaluator = model.evaluator.trace_copy()
-        value_choice_decisions = {}
-        for mutation in model.history:
-            if isinstance(mutation.mutator, EvaluatorValueChoiceLeafMutator):
-                value_choice_decisions[mutation.mutator.label] = mutation.samples[0]
+    def _mutate_traceable_object(self, obj: Any, value_choice_decisions: Dict[str, Any]) -> Any:
+        if not is_traceable(obj):
+            return obj
+        if not any(isinstance(value, ValueChoiceX) for value in obj.trace_kwargs.values()):
+            # No valuechoice, not interesting
+            return obj
+
+        # Make a copy
+        obj = obj.trace_copy()
 
         result = {}
-        # for each argument that is a composition of value choice
+        # For each argument that is a composition of value choice
         # we find all the leaf-value-choice in the mutation
         # and compute the final result
-        for key, param in model.evaluator.trace_kwargs.items():
+        for key, param in obj.trace_kwargs.items():
             if isinstance(param, ValueChoiceX):
                 leaf_node_values = [value_choice_decisions[choice.label] for choice in param.inner_choices()]
                 result[key] = param.evaluate(leaf_node_values)
+            elif is_traceable(param):
+                # Recursively
+                result[key] = self._mutate_traceable_object(param, value_choice_decisions)
+
+        obj.trace_kwargs.update(result)
+        return obj
+
+    def mutate(self, model: Model):
+        value_choice_decisions = {}
+        for mutation in model.history:
+            if isinstance(mutation.mutator, EvaluatorValueChoiceLeafMutator):
+                value_choice_decisions[mutation.mutator.label] = mutation.samples[0]
 
-        model.evaluator.trace_kwargs.update(result)
+        model.evaluator = self._mutate_traceable_object(model.evaluator, value_choice_decisions)
@@ -389,27 +403,25 @@ def process_evaluator_mutations(evaluator: Evaluator, existing_mutators: List[Mu
     if not is_traceable(evaluator):
         return []
 
     mutator_candidates = {}
-    mutator_keys = defaultdict(list)
-    for key, param in evaluator.trace_kwargs.items():
+    for param in _expand_nested_trace_kwargs(evaluator):
         if isinstance(param, ValueChoiceX):
             for choice in param.inner_choices():
                 # merge duplicate labels
                 for mutator in existing_mutators:
-                    if mutator.name == choice.label:
+                    if mutator.label == choice.label:
                         raise ValueError(
                             f'Found duplicated labels “{choice.label}”. When two value choices have the same name, '
-                            'they would share choices. However, sharing choices between model and evaluator is not yet supported.'
+                            'they would share choices. However, sharing choices between model and evaluator is not supported.'
                         )
                 if choice.label in mutator_candidates and mutator_candidates[choice.label] != choice.candidates:
                     raise ValueError(
                         f'Duplicate labels for evaluator ValueChoice {choice.label}. They should share choices.'
                         f'But their candidate list is not equal: {mutator_candidates[choice.label][1]} vs. {choice.candidates}'
                     )
-                mutator_keys[choice.label].append(key)
                 mutator_candidates[choice.label] = choice.candidates
 
     mutators = []
-    for label in mutator_keys:
-        mutators.append(EvaluatorValueChoiceLeafMutator(mutator_candidates[label], label))
+    for label, candidates in mutator_candidates.items():
+        mutators.append(EvaluatorValueChoiceLeafMutator(candidates, label))
     if mutators:
         # one last mutator to actually apply the mutations
         mutators.append(EvaluatorValueChoiceMutator())
@@ -446,3 +458,15 @@ def _group_by_label(nodes: List[Node]) -> List[List[Node]]:
             result[label] = []
         result[label].append(node)
     return list(result.values())
+
+
+def _expand_nested_trace_kwargs(obj: Any) -> Iterator[Any]:
+    # Get items from `trace_kwargs`.
+    # If some item is traceable itself, get items recursively.
+    if not is_traceable(obj):
+        return
+    for param in obj.trace_kwargs.values():
+        yield param
+        yield from _expand_nested_trace_kwargs(param)
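
From the user's side, the nested case handled by _expand_nested_trace_kwargs and _mutate_traceable_object is a traced object (created with nni.trace) that itself carries a ValueChoice and is then passed to the evaluator. A minimal sketch, with DataConfig and train_and_evaluate as hypothetical names:

    import nni
    from nni.retiarii.evaluator import FunctionalEvaluator
    from nni.retiarii.nn.pytorch.api import ValueChoice

    @nni.trace
    class DataConfig:
        # `nni.trace` records the init arguments in `trace_kwargs`,
        # which is what the recursive expansion walks.
        def __init__(self, batch_size):
            self.batch_size = batch_size

    def train_and_evaluate(model_cls, data):
        ...

    # The ValueChoice sits one level deep (evaluator -> DataConfig -> batch_size)
    # and is still discovered, because the expansion recurses into traceable arguments.
    evaluator = FunctionalEvaluator(train_and_evaluate, data=DataConfig(batch_size=ValueChoice([16, 32])))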
@@ -5,6 +5,7 @@ from collections import Counter
 import pytest
+import nni
 import nni.retiarii.nn.pytorch as nn
 import torch
 import torch.nn.functional as F
@@ -13,6 +14,7 @@ from nni.retiarii.converter import convert_to_graph
 from nni.retiarii.codegen import model_to_pytorch_script
 from nni.retiarii.evaluator import FunctionalEvaluator
 from nni.retiarii.execution.utils import _unpack_if_only_one
+from nni.retiarii.experiment.pytorch import preprocess_model
 from nni.retiarii.graph import Model
 from nni.retiarii.nn.pytorch.api import ValueChoice
 from nni.retiarii.nn.pytorch.mutator import process_evaluator_mutations, process_inline_mutation, extract_mutation_from_pt_module
@@ -68,6 +70,8 @@ class GraphIR(unittest.TestCase):
     value_choice_incr = 1
     # graph engine has an extra mutator to apply the depth choice to nodes
     repeat_incr = 1
+    # graph engine parse the model into graph
+    graph_engine = True
 
     def _convert_to_ir(self, model):
         script_module = torch.jit.script(model)
@@ -565,6 +569,48 @@ class GraphIR(unittest.TestCase):
         with pytest.raises(AssertionError):
             self._get_model_with_mutators(Net())
 
+    def test_valuechoice_hybrid_arch_hparams(self):
+        @model_wrapper
+        class Net(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.conv = nn.Conv2d(3, 5, kernel_size=nn.ValueChoice([3, 5]))
+
+            def forward(self, x):
+                return self.conv(x)
+
+        def foo():
+            pass
+
+        evaluator = FunctionalEvaluator(foo, t=1, x=ValueChoice([1, 2]), y=ValueChoice([3, 4]))
+        model, mutators = preprocess_model(Net(), evaluator, [], full_ir=self.graph_engine)
+        samplers = [EnumerateSampler() for _ in range(len(mutators))]
+        model1 = _apply_all_mutators(model, mutators, samplers)
+        model2 = _apply_all_mutators(model, mutators, samplers)
+
+        self.assertEqual(self._get_converted_pytorch_model(model1)(torch.randn(1, 3, 5, 5)).size(),
+                         torch.Size([1, 5, 3, 3]))
+        self.assertEqual(model1.evaluator.trace_kwargs['x'], 1)
+        self.assertEqual(self._get_converted_pytorch_model(model2)(torch.randn(1, 3, 5, 5)).size(),
+                         torch.Size([1, 5, 1, 1]))
+        self.assertEqual(model2.evaluator.trace_kwargs['y'], 4)
+
+    def test_valuechoice_hybrid_arch_hparams_conflict_label(self):
+        @model_wrapper
+        class Net(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.conv = nn.Conv2d(3, 5, kernel_size=nn.ValueChoice([3, 5], label='123'))
+
+            def forward(self, x):
+                return self.conv(x)
+
+        def foo():
+            pass
+
+        evaluator = FunctionalEvaluator(foo, t=1, x=ValueChoice([3, 5], label='123'))
+        with pytest.raises(ValueError, match='share'):
+            preprocess_model(Net(), evaluator, [], full_ir=self.graph_engine)
+
     def test_repeat(self):
         class AddOne(nn.Module):
             def forward(self, x):
@@ -848,6 +894,7 @@ class Python(GraphIR):
     # Python engine doesn't have the extra mutator
     value_choice_incr = 0
     repeat_incr = 0
+    graph_engine = False
 
     def _get_converted_pytorch_model(self, model_ir):
         mutation = {mut.mutator.label: _unpack_if_only_one(mut.samples) for mut in model_ir.history}
@@ -1136,6 +1183,30 @@ class Shared(unittest.TestCase):
             model = _apply_all_mutators(init_model, mutators, sampler)
             assert (model.evaluator.trace_kwargs['x'], model.evaluator.trace_kwargs['y']) in [(1, 2), (3, 4)]
 
+    def test_valuechoice_in_evaluator_nested(self):
+        @nni.trace
+        class FooClass:
+            def __init__(self, a):
+                self.a = a
+
+        obj = FooClass(ValueChoice([1, 2, 3], label='t'))
+
+        def foo():
+            pass
+
+        evaluator = FunctionalEvaluator(foo, t=obj, v=ValueChoice([1, 2, 3], label='t') + ValueChoice([10, 20, 30]))
+        mutators = process_evaluator_mutations(evaluator, [])
+        assert len(mutators) == 3
+
+        init_model = Model(_internal=True)
+        init_model.evaluator = evaluator
+
+        samplers = [RandomSampler() for _ in range(3)]
+        for _ in range(10):
+            model = _apply_all_mutators(init_model, mutators, samplers)
+            a, v = model.evaluator.trace_kwargs['t'].trace_kwargs['a'], model.evaluator.trace_kwargs['v']
+            assert v % 10 == a
+            assert a in [1, 2, 3]
+            assert v // 10 in [1, 2, 3]
+
     def test_retiarii_nn_import(self):
         dummy = torch.zeros(1, 16, 32, 24)
         nn.init.uniform_(dummy)
...