"docs/vscode:/vscode.git/clone" did not exist on "d28aa8a9cced3158e724585d5e6839947ca5c449"
Unverified Commit 2f3c3951 authored by Yuge Zhang's avatar Yuge Zhang Committed by GitHub
Browse files

[Retiarii] Pure-python execution engine (#3605)

parent 122b5b89
Advanced Tutorial
=================
This document includes two parts. The first part explains the design decision of ``@basic_unit`` and ``serializer``. The second part is the tutorial of how to write a model space with mutators.
Pure-python execution engine (experimental)
-------------------------------------------
If you are experiencing issues with TorchScript, or the generated model code by Retiarii, there is another execution engine called Pure-python execution engine which doesn't need the code-graph conversion. This should generally not affect models and strategies in most cases, but customized mutation might not be supported.
This will come as the default execution engine in future version of Retiarii.
Two steps are needed to enable this engine now.
1. Add ``@nni.retiarii.model_wrapper`` decorator outside the whole PyTorch model.
2. Add ``config.execution_engine = 'py'`` to ``RetiariiExeConfig``.
``@basic_unit`` and ``serializer``
----------------------------------
......
......@@ -5,4 +5,4 @@ from .operation import Operation
from .graph import *
from .execution import *
from .mutator import *
from .serializer import basic_unit, json_dump, json_dumps, json_load, json_loads, serialize, serialize_cls
from .serializer import basic_unit, json_dump, json_dumps, json_load, json_loads, serialize, serialize_cls, model_wrapper
......@@ -15,19 +15,18 @@ __all__ = ['get_execution_engine', 'get_and_register_default_listener',
'list_models', 'submit_models', 'wait_models', 'query_available_resources',
'set_execution_engine', 'is_stopped_exec', 'budget_exhausted']
def set_execution_engine(engine) -> None:
def set_execution_engine(engine: AbstractExecutionEngine) -> None:
global _execution_engine
if _execution_engine is None:
_execution_engine = engine
else:
raise RuntimeError('execution engine is already set')
raise RuntimeError('Execution engine is already set.')
def get_execution_engine() -> AbstractExecutionEngine:
"""
Currently we assume the default execution engine is BaseExecutionEngine.
"""
global _execution_engine
assert _execution_engine is not None, 'You need to set execution engine, before using it.'
return _execution_engine
......
......@@ -5,7 +5,7 @@ import logging
import os
import random
import string
from typing import Dict, Iterable, List
from typing import Any, Dict, Iterable, List
from .interface import AbstractExecutionEngine, AbstractGraphListener
from .. import codegen, utils
......@@ -59,7 +59,7 @@ class BaseExecutionEngine(AbstractExecutionEngine):
def submit_models(self, *models: Model) -> None:
for model in models:
data = BaseGraphData(codegen.model_to_pytorch_script(model), model.evaluator)
data = self.pack_model_data(model)
self._running_models[send_trial(data.dump())] = model
self._history.append(model)
......@@ -108,6 +108,10 @@ class BaseExecutionEngine(AbstractExecutionEngine):
advisor = get_advisor()
return advisor.stopping
@classmethod
def pack_model_data(cls, model: Model) -> Any:
return BaseGraphData(codegen.model_to_pytorch_script(model), model.evaluator)
@classmethod
def trial_execute_graph(cls) -> None:
"""
......
from typing import Dict, Any, List
from ..graph import Evaluator, Model
from ..integration_api import receive_trial_parameters
from ..utils import ContextStack, import_, get_importable_name
from .base import BaseExecutionEngine
class PythonGraphData:
def __init__(self, class_name: str, init_parameters: Dict[str, Any],
mutation: Dict[str, Any], evaluator: Evaluator) -> None:
self.class_name = class_name
self.init_parameters = init_parameters
self.mutation = mutation
self.evaluator = evaluator
def dump(self) -> dict:
return {
'class_name': self.class_name,
'init_parameters': self.init_parameters,
'mutation': self.mutation,
'evaluator': self.evaluator
}
@staticmethod
def load(data) -> 'PythonGraphData':
return PythonGraphData(data['class_name'], data['init_parameters'], data['mutation'], data['evaluator'])
class PurePythonExecutionEngine(BaseExecutionEngine):
@classmethod
def pack_model_data(cls, model: Model) -> Any:
mutation = {mut.mutator.label: _unpack_if_only_one(mut.samples) for mut in model.history}
graph_data = PythonGraphData(get_importable_name(model.python_class, relocate_module=True),
model.python_init_params, mutation, model.evaluator)
return graph_data
@classmethod
def trial_execute_graph(cls) -> None:
graph_data = PythonGraphData.load(receive_trial_parameters())
class _model(import_(graph_data.class_name)):
def __init__(self):
super().__init__(**graph_data.init_parameters)
with ContextStack('fixed', graph_data.mutation):
graph_data.evaluator._execute(_model)
def _unpack_if_only_one(ele: List[Any]):
if len(ele) == 1:
return ele[0]
return ele
......@@ -28,11 +28,11 @@ from nni.tools.nnictl.command_utils import kill_command
from ..codegen import model_to_pytorch_script
from ..converter import convert_to_graph
from ..execution import list_models
from ..execution import list_models, set_execution_engine
from ..graph import Model, Evaluator
from ..integration import RetiariiAdvisor
from ..mutator import Mutator
from ..nn.pytorch.mutator import process_inline_mutation
from ..nn.pytorch.mutator import process_inline_mutation, extract_mutation_from_pt_module
from ..strategy import BaseStrategy
from ..oneshot.interface import BaseOneShotTrainer
......@@ -43,7 +43,7 @@ _logger = logging.getLogger(__name__)
class RetiariiExeConfig(ConfigBase):
experiment_name: Optional[str] = None
search_space: Any = '' # TODO: remove
trial_command: str = 'python3 -m nni.retiarii.trial_entry'
trial_command: str = '_reserved'
trial_code_directory: PathLike = '.'
trial_concurrency: int
trial_gpu_number: int = 0
......@@ -55,21 +55,26 @@ class RetiariiExeConfig(ConfigBase):
experiment_working_directory: PathLike = '~/nni-experiments'
# remove configuration of tuner/assessor/advisor
training_service: TrainingServiceConfig
execution_engine: str = 'base'
def __init__(self, training_service_platform: Optional[str] = None, **kwargs):
super().__init__(**kwargs)
if training_service_platform is not None:
assert 'training_service' not in kwargs
self.training_service = util.training_service_config_factory(platform = training_service_platform)
self.__dict__['trial_command'] = 'python3 -m nni.retiarii.trial_entry base'
def __setattr__(self, key, value):
fixed_attrs = {'search_space': '',
'trial_command': 'python3 -m nni.retiarii.trial_entry'}
'trial_command': '_reserved'}
if key in fixed_attrs and fixed_attrs[key] != value:
raise AttributeError(f'{key} is not supposed to be set in Retiarii mode by users!')
# 'trial_code_directory' is handled differently because the path will be converted to absolute path by us
if key == 'trial_code_directory' and not (value == Path('.') or os.path.isabs(value)):
raise AttributeError(f'{key} is not supposed to be set in Retiarii mode by users!')
if key == 'execution_engine':
assert value in ['base', 'py', 'cgo'], f'The specified execution engine "{value}" is not supported.'
self.__dict__['trial_command'] = 'python3 -m nni.retiarii.trial_entry ' + value
self.__dict__[key] = value
def validate(self, initialized_tuner: bool = False) -> None:
......@@ -100,23 +105,27 @@ _validation_rules = {
'training_service': lambda value: (type(value) is not TrainingServiceConfig, 'cannot be abstract base class')
}
def preprocess_model(base_model, trainer, applied_mutators):
def preprocess_model(base_model, trainer, applied_mutators, full_ir=True):
# TODO: this logic might need to be refactored into execution engine
if full_ir:
try:
script_module = torch.jit.script(base_model)
except Exception as e:
_logger.error('Your base model cannot be parsed by torch.jit.script, please fix the following error:')
raise e
base_model_ir = convert_to_graph(script_module, base_model)
base_model_ir.evaluator = trainer
# handle inline mutations
mutators = process_inline_mutation(base_model_ir)
if mutators is not None and applied_mutators:
raise RuntimeError('Have not supported mixed usage of LayerChoice/InputChoice and mutators, '
'do not use mutators when you use LayerChoice/InputChoice')
if mutators is not None:
applied_mutators = mutators
return base_model_ir, applied_mutators
else:
base_model_ir, mutators = extract_mutation_from_pt_module(base_model)
base_model_ir.evaluator = trainer
if mutators is not None and applied_mutators:
raise RuntimeError('Have not supported mixed usage of LayerChoice/InputChoice and mutators, '
'do not use mutators when you use LayerChoice/InputChoice')
if mutators is not None:
applied_mutators = mutators
return base_model_ir, applied_mutators
def debug_mutated_model(base_model, trainer, applied_mutators):
"""
......@@ -160,7 +169,8 @@ class RetiariiExperiment(Experiment):
self._pipe: Optional[Pipe] = None
def _start_strategy(self):
base_model_ir, self.applied_mutators = preprocess_model(self.base_model, self.trainer, self.applied_mutators)
base_model_ir, self.applied_mutators = preprocess_model(
self.base_model, self.trainer, self.applied_mutators, full_ir=self.config.execution_engine != 'py')
_logger.info('Start strategy...')
self.strategy.run(base_model_ir, self.applied_mutators)
......@@ -182,6 +192,18 @@ class RetiariiExperiment(Experiment):
"""
atexit.register(self.stop)
# we will probably need a execution engine factory to make this clean and elegant
if self.config.execution_engine == 'base':
from ..execution.base import BaseExecutionEngine
engine = BaseExecutionEngine()
elif self.config.execution_engine == 'cgo':
from ..execution.cgo_engine import CGOExecutionEngine
engine = CGOExecutionEngine()
elif self.config.execution_engine == 'py':
from ..execution.python import PurePythonExecutionEngine
engine = PurePythonExecutionEngine()
set_execution_engine(engine)
self.id = management.generate_experiment_id()
if self.config.experiment_working_directory is not None:
......
......@@ -9,12 +9,12 @@ import abc
import copy
import json
from enum import Enum
from typing import (Any, Dict, Iterable, List, Optional, Tuple, Union, overload)
from typing import (Any, Dict, Iterable, List, Optional, Tuple, Type, Union, overload)
from .operation import Cell, Operation, _IOPseudoOperation
from .utils import get_importable_name, import_, uid
__all__ = ['Model', 'ModelStatus', 'Graph', 'Node', 'Edge', 'IllegalGraphError', 'MetricData']
__all__ = ['Model', 'ModelStatus', 'Graph', 'Node', 'Edge', 'Mutation', 'IllegalGraphError', 'MetricData']
MetricData = Any
......@@ -80,6 +80,10 @@ class Model:
Attributes
----------
python_class
Python class that base model is converted from.
python_init_params
Initialization parameters of python class.
status
See `ModelStatus`.
root_graph
......@@ -102,6 +106,8 @@ class Model:
def __init__(self, _internal=False):
assert _internal, '`Model()` is private, use `model.fork()` instead'
self.model_id: int = uid('model')
self.python_class: Optional[Type] = None
self.python_init_params: Optional[Dict[str, Any]] = None
self.status: ModelStatus = ModelStatus.Mutating
......@@ -116,7 +122,8 @@ class Model:
def __repr__(self):
return f'Model(model_id={self.model_id}, status={self.status}, graphs={list(self.graphs.keys())}, ' + \
f'evaluator={self.evaluator}, metric={self.metric}, intermediate_metrics={self.intermediate_metrics})'
f'evaluator={self.evaluator}, metric={self.metric}, intermediate_metrics={self.intermediate_metrics}, ' + \
f'python_class={self.python_class})'
@property
def root_graph(self) -> 'Graph':
......@@ -133,9 +140,12 @@ class Model:
"""
new_model = Model(_internal=True)
new_model._root_graph_name = self._root_graph_name
new_model.python_class = self.python_class
new_model.python_init_params = self.python_init_params
new_model.graphs = {name: graph._fork_to(new_model) for name, graph in self.graphs.items()}
new_model.evaluator = copy.deepcopy(self.evaluator) # TODO this may be a problem when evaluator is large
new_model.history = self.history + [self]
new_model.history = [*self.history]
# Note: the history is not updated. It will be updated when the model is changed, that is in mutator.
return new_model
@staticmethod
......@@ -167,8 +177,8 @@ class Model:
def get_nodes_by_label(self, label: str) -> List['Node']:
"""
Traverse all the nodes to find the matched node(s) with the given name.
There could be multiple nodes with the same name. Name space name can uniquely
Traverse all the nodes to find the matched node(s) with the given label.
There could be multiple nodes with the same label. Name space name can uniquely
identify a graph or node.
NOTE: the implementation does not support the class abstration
......@@ -493,6 +503,8 @@ class Node:
If two models have nodes with same ID, they are semantically the same node.
name
Mnemonic name. It should have an one-to-one mapping with ID.
label
Optional. If two nodes have the same label, they are considered same by the mutator.
operation
...
cell
......@@ -515,7 +527,7 @@ class Node:
# TODO: the operation is likely to be considered editable by end-user and it will be hard to debug
# maybe we should copy it here or make Operation class immutable, in next release
self.operation: Operation = operation
self.label: str = None
self.label: Optional[str] = None
def __repr__(self):
return f'Node(id={self.id}, name={self.name}, label={self.label}, operation={self.operation})'
......@@ -673,6 +685,37 @@ class Edge:
}
class Mutation:
"""
An execution of mutation, which consists of four parts: a mutator, a list of decisions (choices),
the model that it comes from, and the model that it becomes.
In general cases, the mutation logs are not reliable and should not be replayed as the mutators can
be arbitrarily complex. However, for inline mutations, the labels correspond to mutator labels here,
this can be useful for metadata visualization and python execution mode.
Attributes
----------
mutator
Mutator.
samples
Decisions/choices.
from_
Model that is comes from.
to
Model that it becomes.
"""
def __init__(self, mutator: 'Mutator', samples: List[Any], from_: Model, to: Model): # noqa: F821
self.mutator: 'Mutator' = mutator # noqa: F821
self.samples: List[Any] = samples
self.from_: Model = from_
self.to: Model = to
def __repr__(self):
return f'Edge(mutator={self.mutator}, samples={self.samples}, from={self.from_}, to={self.to})'
class IllegalGraphError(ValueError):
def __init__(self, graph, *args):
self._debug_dump_graph(graph)
......
......@@ -2,7 +2,6 @@
# Licensed under the MIT license.
import logging
import os
from typing import Any, Callable
from nni.runtime.msg_dispatcher_base import MsgDispatcherBase
......@@ -10,9 +9,6 @@ from nni.runtime.protocol import CommandType, send
from nni.utils import MetricType
from .graph import MetricData
from .execution.base import BaseExecutionEngine
from .execution.cgo_engine import CGOExecutionEngine
from .execution.api import set_execution_engine
from .integration_api import register_advisor
from .serializer import json_dumps, json_loads
......@@ -62,15 +58,6 @@ class RetiariiAdvisor(MsgDispatcherBase):
self.parameters_count = 0
engine = self._create_execution_engine()
set_execution_engine(engine)
def _create_execution_engine(self):
if os.environ.get('CGO') == 'true':
return CGOExecutionEngine()
else:
return BaseExecutionEngine()
def handle_initialize(self, data):
"""callback for initializing the advisor
Parameters
......
......@@ -3,7 +3,7 @@
from typing import (Any, Iterable, List, Optional)
from .graph import Model
from .graph import Model, Mutation, ModelStatus
__all__ = ['Sampler', 'Mutator']
......@@ -40,10 +40,13 @@ class Mutator:
and then use `Mutator.apply()` to mutate model.
For certain mutator subclasses, strategy or sampler can use `Mutator.dry_run()` to predict choice candidates.
# Method names are open for discussion.
If mutator has a label, in most cases, it means that this mutator is applied to nodes with this label.
"""
def __init__(self, sampler: Optional[Sampler] = None):
def __init__(self, sampler: Optional[Sampler] = None, label: Optional[str] = None):
self.sampler: Optional[Sampler] = sampler
self.label: Optional[str] = label
self._cur_model: Optional[Model] = None
self._cur_choice_idx: Optional[int] = None
......@@ -64,9 +67,12 @@ class Mutator:
copy = model.fork()
self._cur_model = copy
self._cur_choice_idx = 0
self._cur_samples = []
self.sampler.mutation_start(self, copy)
self.mutate(copy)
self.sampler.mutation_end(self, copy)
copy.history.append(Mutation(self, self._cur_samples, model, copy))
copy.status = ModelStatus.Frozen
self._cur_model = None
self._cur_choice_idx = None
return copy
......@@ -97,6 +103,7 @@ class Mutator:
"""
assert self.sampler is not None and self._cur_model is not None and self._cur_choice_idx is not None
ret = self.sampler.choice(list(candidates), self, self._cur_model, self._cur_choice_idx)
self._cur_samples.append(ret)
self._cur_choice_idx += 1
return ret
......
......@@ -4,18 +4,32 @@
import copy
import warnings
from collections import OrderedDict
from typing import Any, List, Union, Dict
from typing import Any, List, Union, Dict, Optional
import torch
import torch.nn as nn
from ...serializer import Translatable, basic_unit
from ...utils import uid
from ...utils import uid, get_current_context
__all__ = ['LayerChoice', 'InputChoice', 'ValueChoice', 'Placeholder', 'ChosenInputs']
def _generate_new_label(label: Optional[str]):
if label is None:
return '_mutation_' + str(uid('mutation'))
return label
def _get_fixed_value(label: str):
ret = get_current_context('fixed')
try:
return ret[_generate_new_label(label)]
except KeyError:
raise KeyError(f'Fixed context with {label} not found. Existing values are: {ret}')
class LayerChoice(nn.Module):
"""
Layer choice selects one of the ``candidates``, then apply it on inputs and return results.
......@@ -55,6 +69,16 @@ class LayerChoice(nn.Module):
``self.op_choice[1] = nn.Conv3d(...)``. Adding more choices is not supported yet.
"""
def __new__(cls, candidates: Union[Dict[str, nn.Module], List[nn.Module]], label: str = None, **kwargs):
try:
chosen = _get_fixed_value(label)
if isinstance(candidates, list):
return candidates[int(chosen)]
else:
return candidates[chosen]
except AssertionError:
return super().__new__(cls)
def __init__(self, candidates: Union[Dict[str, nn.Module], List[nn.Module]], label: str = None, **kwargs):
super(LayerChoice, self).__init__()
if 'key' in kwargs:
......@@ -65,7 +89,7 @@ class LayerChoice(nn.Module):
if 'reduction' in kwargs:
warnings.warn(f'"reduction" is deprecated. Ignoring...')
self.candidates = candidates
self._label = label if label is not None else f'layerchoice_{uid()}'
self._label = _generate_new_label(label)
self.names = []
if isinstance(candidates, OrderedDict):
......@@ -163,6 +187,12 @@ class InputChoice(nn.Module):
Identifier of the input choice.
"""
def __new__(cls, n_candidates: int, n_chosen: int = 1, reduction: str = 'sum', label: str = None, **kwargs):
try:
return ChosenInputs(_get_fixed_value(label), reduction=reduction)
except AssertionError:
return super().__new__(cls)
def __init__(self, n_candidates: int, n_chosen: int = 1, reduction: str = 'sum', label: str = None, **kwargs):
super(InputChoice, self).__init__()
if 'key' in kwargs:
......@@ -176,7 +206,7 @@ class InputChoice(nn.Module):
self.n_chosen = n_chosen
self.reduction = reduction
assert self.reduction in ['mean', 'concat', 'sum', 'none']
self._label = label if label is not None else f'inputchoice_{uid()}'
self._label = _generate_new_label(label)
@property
def key(self):
......@@ -265,10 +295,16 @@ class ValueChoice(Translatable, nn.Module):
Identifier of the value choice.
"""
def __new__(cls, candidates: List[Any], label: str = None):
try:
return _get_fixed_value(label)
except AssertionError:
return super().__new__(cls)
def __init__(self, candidates: List[Any], label: str = None):
super().__init__()
self.candidates = candidates
self._label = label if label is not None else f'valuechoice_{uid()}'
self._label = _generate_new_label(label)
self._accessor = []
@property
......@@ -297,6 +333,14 @@ class ValueChoice(Translatable, nn.Module):
raise KeyError(''.join([f'[{a}]' for a in self._accessor]) + f' does not work on {value}')
return v
def __copy__(self):
return self
def __deepcopy__(self, memo):
new_item = ValueChoice(self.candidates, self.label)
new_item._accessor = [*self._accessor]
return new_item
def __getitem__(self, item):
"""
Get a sub-element of value choice.
......@@ -331,9 +375,9 @@ class ChosenInputs(nn.Module):
The already-chosen version of InputChoice.
"""
def __init__(self, chosen: List[int], reduction: str):
def __init__(self, chosen: Union[List[int], int], reduction: str):
super().__init__()
self.chosen = chosen
self.chosen = chosen if isinstance(chosen, list) else [chosen]
self.reduction = reduction
def forward(self, candidate_inputs):
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import inspect
from typing import Any, List, Optional, Tuple
import torch.nn as nn
from ...mutator import Mutator
from ...graph import Cell, Model, Node
from .api import ValueChoice
from ...graph import Cell, Graph, Model, ModelStatus, Node
from ...utils import uid
from .api import LayerChoice, InputChoice, ValueChoice, Placeholder
class LayerChoiceMutator(Mutator):
......@@ -40,7 +44,7 @@ class InputChoiceMutator(Mutator):
def mutate(self, model):
n_candidates = self.nodes[0].operation.parameters['n_candidates']
n_chosen = self.nodes[0].operation.parameters['n_chosen']
n_chosen = self.nodes[0].operation.parameters['n_chosen']
candidates = list(range(n_candidates))
chosen = [self.choice(candidates) for _ in range(n_chosen)]
for node in self.nodes:
......@@ -116,12 +120,96 @@ def process_inline_mutation(model: Model) -> Optional[List[Mutator]]:
mutator = LayerChoiceMutator(node_list)
applied_mutators.append(mutator)
if applied_mutators:
return applied_mutators
return None
# The following are written for pure-python mode
class ManyChooseManyMutator(Mutator):
"""
Choose based on labels. Will not affect the model itself.
"""
def __init__(self, label: Optional[str]):
super().__init__(label=label)
@staticmethod
def candidates(node):
if 'n_candidates' in node.operation.parameters:
return list(range(node.operation.parameters['n_candidates']))
else:
return node.operation.parameters['candidates']
@staticmethod
def number_of_chosen(node):
if 'n_chosen' in node.operation.parameters:
return node.operation.parameters['n_chosen']
return 1
def mutate(self, model: Model):
# this mutate does not have any effect, but it is recorded in the mutation history
for node in model.get_nodes_by_label(self.label):
for _ in range(self.number_of_chosen(node)):
self.choice(self.candidates(node))
break
def extract_mutation_from_pt_module(pytorch_model: nn.Module) -> Tuple[Model, Optional[List[Mutator]]]:
model = Model(_internal=True)
graph = Graph(model, uid(), '_model', _internal=True)._register()
model.python_class = pytorch_model.__class__
if len(inspect.signature(model.python_class.__init__).parameters) > 1:
if not hasattr(pytorch_model, '_init_parameters'):
raise ValueError('Please annotate the model with @serialize decorator in python execution mode '
'if your model has init parameters.')
model.python_init_params = pytorch_model._init_parameters
else:
model.python_init_params = {}
for name, module in pytorch_model.named_modules():
# tricky case: value choice that serves as parameters are stored in _init_parameters
if hasattr(module, '_init_parameters'):
for key, value in module._init_parameters.items():
if isinstance(value, ValueChoice):
node = graph.add_node(name + '.init.' + key, 'ValueChoice', {'candidates': value.candidates})
node.label = value.label
if isinstance(module, (LayerChoice, InputChoice, ValueChoice)):
# TODO: check the label of module and warn if it's auto-generated
pass
if isinstance(module, LayerChoice):
node = graph.add_node(name, 'LayerChoice', {'candidates': module.names})
node.label = module.label
if isinstance(module, InputChoice):
node = graph.add_node(name, 'InputChoice',
{'n_candidates': module.n_candidates, 'n_chosen': module.n_chosen})
node.label = module.label
if isinstance(module, ValueChoice):
node = graph.add_node(name, 'ValueChoice', {'candidates': module.candidates})
node.label = module.label
if isinstance(module, Placeholder):
raise NotImplementedError('Placeholder is not supported in python execution mode.')
model.status = ModelStatus.Frozen
if not graph.hidden_nodes:
return model, None
mutators = []
for nodes in _group_by_label_and_type(graph.hidden_nodes):
assert _is_all_equal(map(lambda n: n.operation.type, nodes)), \
f'Node with label "{nodes[0].label}" does not all have the same type.'
assert _is_all_equal(map(lambda n: n.operation.parameters, nodes)), \
f'Node with label "{nodes[0].label}" does not agree on parameters.'
mutators.append(ManyChooseManyMutator(nodes[0].label))
return model, mutators
# utility functions
def _is_all_equal(lst):
last = None
for x in lst:
......@@ -131,6 +219,16 @@ def _is_all_equal(lst):
return True
def _group_by_label_and_type(nodes: List[Node]) -> List[List[Node]]:
result = {}
for node in nodes:
key = (node.label, node.operation.type)
if key not in result:
result[key] = []
result[key].append(node)
return list(result.values())
def _group_by_label(nodes: List[Node]) -> List[List[Node]]:
result = {}
for node in nodes:
......
......@@ -9,7 +9,7 @@ from typing import Any
import json_tricks
from .utils import get_importable_name, get_module_name, import_
from .utils import get_importable_name, get_module_name, import_, reset_uid
def get_init_parameters_or_fail(obj, silently=False):
......@@ -83,9 +83,11 @@ class Translatable(abc.ABC):
pass
def _create_wrapper_cls(cls, store_init_parameters=True):
def _create_wrapper_cls(cls, store_init_parameters=True, reset_mutation_uid=False):
class wrapper(cls):
def __init__(self, *args, **kwargs):
if reset_mutation_uid:
reset_uid('mutation')
if store_init_parameters:
argname_list = list(inspect.signature(cls.__init__).parameters.keys())[1:]
full_args = {}
......@@ -149,3 +151,15 @@ def basic_unit(cls):
import torch.nn as nn
assert issubclass(cls, nn.Module), 'When using @basic_unit, the class must be a subclass of nn.Module.'
return serialize_cls(cls)
def model_wrapper(cls):
"""
Wrap the model if you are using pure-python execution engine.
The wrapper serves two purposes:
1. Capture the init parameters of python class so that it can be re-instantiated in another process.
2. Reset uid in `mutation` namespace so that each model counts from zero. Can be useful in unittest and other multi-model scenarios.
"""
return _create_wrapper_cls(cls, reset_mutation_uid=True)
......@@ -6,13 +6,20 @@ Entrypoint for trials.
Assuming execution engine is BaseExecutionEngine.
"""
import os
import argparse
from .execution.base import BaseExecutionEngine
from .execution.cgo_engine import CGOExecutionEngine
if __name__ == '__main__':
if os.environ.get('CGO') == 'true':
CGOExecutionEngine.trial_execute_graph()
else:
BaseExecutionEngine.trial_execute_graph()
parser = argparse.ArgumentParser()
parser.add_argument('exec', choices=['base', 'py', 'cgo'])
args = parser.parse_args()
if args.exec == 'base':
from .execution.base import BaseExecutionEngine
engine = BaseExecutionEngine
elif args.exec == 'cgo':
from .execution.cgo_engine import CGOExecutionEngine
engine = CGOExecutionEngine
elif args.exec == 'py':
from .execution.python import PurePythonExecutionEngine
engine = PurePythonExecutionEngine
engine.trial_execute_graph()
......@@ -4,7 +4,7 @@
import inspect
import warnings
from collections import defaultdict
from typing import Any
from typing import Any, List, Dict
from pathlib import Path
......@@ -31,6 +31,10 @@ def uid(namespace: str = 'default') -> int:
return _last_uid[namespace]
def reset_uid(namespace: str = 'default') -> None:
_last_uid[namespace] = 0
def get_module_name(cls_or_func):
module_name = cls_or_func.__module__
if module_name == '__main__':
......@@ -61,3 +65,42 @@ def get_module_name(cls_or_func):
def get_importable_name(cls, relocate_module=False):
module_name = get_module_name(cls) if relocate_module else cls.__module__
return module_name + '.' + cls.__name__
class ContextStack:
"""
This is to maintain a globally-accessible context envinronment that is visible to everywhere.
Use ``with ContextStack(namespace, value):`` to initiate, and use ``get_current_context(namespace)`` to
get the corresponding value in the namespace.
"""
_stack: Dict[str, List[Any]] = defaultdict(list)
def __init__(self, key: str, value: Any):
self.key = key
self.value = value
def __enter__(self):
self.push(self.key, self.value)
return self
def __exit__(self, *args, **kwargs):
self.pop(self.key)
@classmethod
def push(cls, key: str, value: Any):
cls._stack[key].append(value)
@classmethod
def pop(cls, key: str) -> None:
cls._stack[key].pop()
@classmethod
def top(cls, key: str) -> Any:
assert cls._stack[key], 'Context is empty.'
return cls._stack[key][-1]
def get_current_context(key: str) -> Any:
return ContextStack.top(key)
......@@ -3,22 +3,24 @@ import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch
class _model(nn.Module):
def __init__(self):
super().__init__()
self.stem = stem()
self.fc1 = nn.Linear(1024, 256)
self.fc2 = nn.Linear(256, 10)
self.flatten = torch.nn.Flatten()
self.fc1 = torch.nn.Linear(out_features=256, in_features=1024)
self.fc2 = torch.nn.Linear(out_features=10, in_features=256)
self.softmax = torch.nn.Softmax()
def forward(self, image):
stem = self.stem(image)
flatten = stem.view(stem.size(0), -1)
flatten = self.flatten(stem)
fc1 = self.fc1(flatten)
fc2 = self.fc2(fc1)
softmax = F.softmax(fc2, -1)
softmax = self.softmax(fc2)
return softmax
......@@ -26,10 +28,10 @@ class _model(nn.Module):
class stem(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(out_channels=32, in_channels=1, kernel_size=5)
self.pool1 = nn.MaxPool2d(kernel_size=2)
self.conv2 = nn.Conv2d(out_channels=64, in_channels=32, kernel_size=5)
self.pool2 = nn.MaxPool2d(kernel_size=2)
self.conv1 = torch.nn.Conv2d(out_channels=32, in_channels=1, kernel_size=5)
self.pool1 = torch.nn.MaxPool2d(kernel_size=2)
self.conv2 = torch.nn.Conv2d(out_channels=64, in_channels=32, kernel_size=5)
self.pool2 = torch.nn.MaxPool2d(kernel_size=2)
def forward(self, *_inputs):
conv1 = self.conv1(_inputs[0])
......
......@@ -5,10 +5,10 @@
"nodes": {
"stem": {"operation": {"type": "_cell", "cell_name": "stem"}},
"flatten": {"operation": {"type": "Flatten"}},
"fc1": {"operation": {"type": "Dense", "parameters": {"out_features": 256, "in_features": 1024}}},
"fc2": {"operation": {"type": "Dense", "parameters": {"out_features": 10, "in_features": 256}}},
"softmax": {"operation": {"type": "Softmax"}}
"flatten": {"operation": {"type": "__torch__.torch.nn.Flatten"}},
"fc1": {"operation": {"type": "__torch__.torch.nn.Linear", "parameters": {"out_features": 256, "in_features": 1024}}},
"fc2": {"operation": {"type": "__torch__.torch.nn.Linear", "parameters": {"out_features": 10, "in_features": 256}}},
"softmax": {"operation": {"type": "__torch__.torch.nn.Softmax"}}
},
"edges": [
......@@ -23,10 +23,10 @@
"stem": {
"nodes": {
"conv1": {"operation": {"type": "__torch__.Conv2d", "parameters": {"out_channels": 32, "in_channels": 1, "kernel_size": 5}}},
"pool1": {"operation": {"type": "__torch__.MaxPool2d", "parameters": {"kernel_size": 2}}},
"conv2": {"operation": {"type": "__torch__.Conv2d", "parameters": {"out_channels": 64, "in_channels": 32, "kernel_size": 5}}},
"pool2": {"operation": {"type": "__torch__.MaxPool2d", "parameters": {"kernel_size": 2}}}
"conv1": {"operation": {"type": "__torch__.torch.nn.Conv2d", "parameters": {"out_channels": 32, "in_channels": 1, "kernel_size": 5}}},
"pool1": {"operation": {"type": "__torch__.torch.nn.MaxPool2d", "parameters": {"kernel_size": 2}}},
"conv2": {"operation": {"type": "__torch__.torch.nn.Conv2d", "parameters": {"out_channels": 64, "in_channels": 32, "kernel_size": 5}}},
"pool2": {"operation": {"type": "__torch__.torch.nn.MaxPool2d", "parameters": {"kernel_size": 2}}}
},
"edges": [
......@@ -36,26 +36,5 @@
{"head": ["conv2", null], "tail": ["pool2", null]},
{"head": ["pool2", null], "tail": ["_outputs", 0]}
]
},
"_evaluator": {
"module": "nni.retiarii.trainer.PyTorchImageClassificationTrainer",
"kwargs": {
"dataset_cls": "MNIST",
"dataset_kwargs": {
"root": "data/mnist",
"download": true
},
"dataloader_kwargs": {
"batch_size": 32
},
"optimizer_cls" : "SGD",
"optimizer_kwargs": {
"lr": 1e-3
},
"trainer_kwargs": {
"max_epochs": 1
}
}
}
}
import json
import os
import sys
import threading
import unittest
from pathlib import Path
import nni
import nni.retiarii
from nni.retiarii import Model, submit_models
from nni.retiarii.codegen import model_to_pytorch_script
from nni.retiarii.integration import RetiariiAdvisor, register_advisor
from nni.retiarii.evaluator.pytorch import PyTorchImageClassificationTrainer
from nni.retiarii.utils import import_
from nni.retiarii.execution import set_execution_engine
from nni.retiarii.execution.base import BaseExecutionEngine
from nni.retiarii.execution.python import PurePythonExecutionEngine
from nni.retiarii.integration import RetiariiAdvisor
@unittest.skip('Skipped in this version')
class CodeGenTest(unittest.TestCase):
def test_mnist_example_pytorch(self):
with open('mnist_pytorch.json') as f:
class EngineTest(unittest.TestCase):
def test_codegen(self):
with open(self.enclosing_dir / 'mnist_pytorch.json') as f:
model = Model._load(json.load(f))
script = model_to_pytorch_script(model)
with open('debug_mnist_pytorch.py') as f:
with open(self.enclosing_dir / 'debug_mnist_pytorch.py') as f:
reference_script = f.read()
self.assertEqual(script.strip(), reference_script.strip())
def test_base_execution_engine(self):
advisor = RetiariiAdvisor()
set_execution_engine(BaseExecutionEngine())
with open(self.enclosing_dir / 'mnist_pytorch.json') as f:
model = Model._load(json.load(f))
submit_models(model, model)
@unittest.skip('Skipped in this version')
class TrainerTest(unittest.TestCase):
def test_trainer(self):
sys.path.insert(0, Path(__file__).parent.as_posix())
Model = import_('debug_mnist_pytorch._model')
trainer = PyTorchImageClassificationTrainer(
Model(),
dataset_kwargs={'root': (Path(__file__).parent / 'data' / 'mnist').as_posix(), 'download': True},
dataloader_kwargs={'batch_size': 32},
optimizer_kwargs={'lr': 1e-3},
trainer_kwargs={'max_epochs': 1}
)
trainer.fit()
@unittest.skip('Skipped in this version')
class EngineTest(unittest.TestCase):
advisor.stopping = True
advisor.default_worker.join()
advisor.assessor_worker.join()
def test_submit_models(self):
os.makedirs('generated', exist_ok=True)
from nni.runtime import protocol
protocol._out_file = open(Path(__file__).parent / 'generated/debug_protocol_out_file.py', 'wb')
def test_py_execution_engine(self):
advisor = RetiariiAdvisor()
with open('mnist_pytorch.json') as f:
model = Model._load(json.load(f))
set_execution_engine(PurePythonExecutionEngine())
model = Model._load({
'_model': {
'inputs': None,
'outputs': None,
'nodes': {
'layerchoice_1': {
'operation': {'type': 'LayerChoice', 'parameters': {'candidates': ['0', '1']}}
}
},
'edges': []
}
})
model.python_class = object
submit_models(model, model)
advisor.stopping = True
advisor.default_worker.join()
advisor.assessor_worker.join()
def test_execution_engine(self):
pass
def setUp(self) -> None:
self.enclosing_dir = Path(__file__).parent
os.makedirs(self.enclosing_dir / 'generated', exist_ok=True)
from nni.runtime import protocol
protocol._out_file = open(self.enclosing_dir / 'generated/debug_protocol_out_file.py', 'wb')
def tearDown(self) -> None:
from nni.runtime import protocol
protocol._out_file.close()
nni.retiarii.execution.api._execution_engine = None
nni.retiarii.integration_api._advisor = None
......@@ -8,7 +8,10 @@ import torch.nn.functional as F
from nni.retiarii import Sampler, basic_unit
from nni.retiarii.converter import convert_to_graph
from nni.retiarii.codegen import model_to_pytorch_script
from nni.retiarii.nn.pytorch.mutator import process_inline_mutation
from nni.retiarii.execution.python import _unpack_if_only_one
from nni.retiarii.nn.pytorch.mutator import process_inline_mutation, extract_mutation_from_pt_module
from nni.retiarii.serializer import model_wrapper
from nni.retiarii.utils import ContextStack
class EnumerateSampler(Sampler):
......@@ -44,7 +47,7 @@ class MutableConv(nn.Module):
return self.conv2(x)
class TestHighLevelAPI(unittest.TestCase):
class GraphIR(unittest.TestCase):
def _convert_to_ir(self, model):
script_module = torch.jit.script(model)
......@@ -56,7 +59,19 @@ class TestHighLevelAPI(unittest.TestCase):
exec(model_code + '\n\nconverted_model = _model()', exec_vars)
return exec_vars['converted_model']
def _get_model_with_mutators(self, pytorch_model):
model = self._convert_to_ir(pytorch_model)
mutators = process_inline_mutation(model)
return model, mutators
def get_serializer(self):
def dummy(cls):
return cls
return dummy
def test_layer_choice(self):
@self.get_serializer()
class Net(nn.Module):
def __init__(self):
super().__init__()
......@@ -68,8 +83,7 @@ class TestHighLevelAPI(unittest.TestCase):
def forward(self, x):
return self.module(x)
model = self._convert_to_ir(Net())
mutators = process_inline_mutation(model)
model, mutators = self._get_model_with_mutators(Net())
self.assertEqual(len(mutators), 1)
mutator = mutators[0].bind_sampler(EnumerateSampler())
model1 = mutator.apply(model)
......@@ -80,6 +94,7 @@ class TestHighLevelAPI(unittest.TestCase):
torch.Size([1, 5, 3, 3]))
def test_input_choice(self):
@self.get_serializer()
class Net(nn.Module):
def __init__(self):
super().__init__()
......@@ -92,8 +107,7 @@ class TestHighLevelAPI(unittest.TestCase):
x2 = self.conv2(x)
return self.input([x1, x2])
model = self._convert_to_ir(Net())
mutators = process_inline_mutation(model)
model, mutators = self._get_model_with_mutators(Net())
self.assertEqual(len(mutators), 1)
mutator = mutators[0].bind_sampler(EnumerateSampler())
model1 = mutator.apply(model)
......@@ -104,6 +118,7 @@ class TestHighLevelAPI(unittest.TestCase):
torch.Size([1, 5, 3, 3]))
def test_chosen_inputs(self):
@self.get_serializer()
class Net(nn.Module):
def __init__(self, reduction):
super().__init__()
......@@ -117,8 +132,7 @@ class TestHighLevelAPI(unittest.TestCase):
return self.input([x1, x2])
for reduction in ['none', 'sum', 'mean', 'concat']:
model = self._convert_to_ir(Net(reduction))
mutators = process_inline_mutation(model)
model, mutators = self._get_model_with_mutators(Net(reduction))
self.assertEqual(len(mutators), 1)
mutator = mutators[0].bind_sampler(EnumerateSampler())
model = mutator.apply(model)
......@@ -133,6 +147,7 @@ class TestHighLevelAPI(unittest.TestCase):
self.assertEqual(result.size(), torch.Size([1, 3, 3, 3]))
def test_value_choice(self):
@self.get_serializer()
class Net(nn.Module):
def __init__(self):
super().__init__()
......@@ -142,8 +157,7 @@ class TestHighLevelAPI(unittest.TestCase):
def forward(self, x):
return self.conv(x, self.index())
model = self._convert_to_ir(Net())
mutators = process_inline_mutation(model)
model, mutators = self._get_model_with_mutators(Net())
self.assertEqual(len(mutators), 1)
mutator = mutators[0].bind_sampler(EnumerateSampler())
model1 = mutator.apply(model)
......@@ -154,6 +168,7 @@ class TestHighLevelAPI(unittest.TestCase):
torch.Size([1, 5, 3, 3]))
def test_value_choice_as_parameter(self):
@self.get_serializer()
class Net(nn.Module):
def __init__(self):
super().__init__()
......@@ -162,8 +177,7 @@ class TestHighLevelAPI(unittest.TestCase):
def forward(self, x):
return self.conv(x)
model = self._convert_to_ir(Net())
mutators = process_inline_mutation(model)
model, mutators = self._get_model_with_mutators(Net())
self.assertEqual(len(mutators), 1)
mutator = mutators[0].bind_sampler(EnumerateSampler())
model1 = mutator.apply(model)
......@@ -174,6 +188,7 @@ class TestHighLevelAPI(unittest.TestCase):
torch.Size([1, 5, 1, 1]))
def test_value_choice_as_parameter(self):
@self.get_serializer()
class Net(nn.Module):
def __init__(self):
super().__init__()
......@@ -182,8 +197,7 @@ class TestHighLevelAPI(unittest.TestCase):
def forward(self, x):
return self.conv(x)
model = self._convert_to_ir(Net())
mutators = process_inline_mutation(model)
model, mutators = self._get_model_with_mutators(Net())
self.assertEqual(len(mutators), 1)
mutator = mutators[0].bind_sampler(EnumerateSampler())
model1 = mutator.apply(model)
......@@ -194,6 +208,7 @@ class TestHighLevelAPI(unittest.TestCase):
torch.Size([1, 5, 1, 1]))
def test_value_choice_as_parameter(self):
@self.get_serializer()
class Net(nn.Module):
def __init__(self):
super().__init__()
......@@ -202,8 +217,7 @@ class TestHighLevelAPI(unittest.TestCase):
def forward(self, x):
return self.conv(x)
model = self._convert_to_ir(Net())
mutators = process_inline_mutation(model)
model, mutators = self._get_model_with_mutators(Net())
self.assertEqual(len(mutators), 2)
mutators[0].bind_sampler(EnumerateSampler())
mutators[1].bind_sampler(EnumerateSampler())
......@@ -214,6 +228,7 @@ class TestHighLevelAPI(unittest.TestCase):
torch.Size([1, 8, 1, 1]))
def test_value_choice_as_parameter_shared(self):
@self.get_serializer()
class Net(nn.Module):
def __init__(self):
super().__init__()
......@@ -223,8 +238,7 @@ class TestHighLevelAPI(unittest.TestCase):
def forward(self, x):
return self.conv1(x) + self.conv2(x)
model = self._convert_to_ir(Net())
mutators = process_inline_mutation(model)
model, mutators = self._get_model_with_mutators(Net())
self.assertEqual(len(mutators), 1)
mutator = mutators[0].bind_sampler(EnumerateSampler())
model1 = mutator.apply(model)
......@@ -235,6 +249,7 @@ class TestHighLevelAPI(unittest.TestCase):
torch.Size([1, 8, 5, 5]))
def test_value_choice_in_functional(self):
@self.get_serializer()
class Net(nn.Module):
def __init__(self):
super().__init__()
......@@ -243,8 +258,7 @@ class TestHighLevelAPI(unittest.TestCase):
def forward(self, x):
return F.dropout(x, self.dropout_rate())
model = self._convert_to_ir(Net())
mutators = process_inline_mutation(model)
model, mutators = self._get_model_with_mutators(Net())
self.assertEqual(len(mutators), 1)
mutator = mutators[0].bind_sampler(EnumerateSampler())
model1 = mutator.apply(model)
......@@ -254,6 +268,7 @@ class TestHighLevelAPI(unittest.TestCase):
self.assertAlmostEqual(self._get_converted_pytorch_model(model2)(torch.randn(1, 3, 3, 3)).abs().sum().item(), 0)
def test_value_choice_in_layer_choice(self):
@self.get_serializer()
class Net(nn.Module):
def __init__(self):
super().__init__()
......@@ -265,8 +280,7 @@ class TestHighLevelAPI(unittest.TestCase):
def forward(self, x):
return self.linear(x)
model = self._convert_to_ir(Net())
mutators = process_inline_mutation(model)
model, mutators = self._get_model_with_mutators(Net())
self.assertEqual(len(mutators), 3)
sz_counter = Counter()
sampler = RandomSampler()
......@@ -278,6 +292,7 @@ class TestHighLevelAPI(unittest.TestCase):
self.assertEqual(len(sz_counter), 4)
def test_shared(self):
@self.get_serializer()
class Net(nn.Module):
def __init__(self, shared=True):
super().__init__()
......@@ -294,16 +309,14 @@ class TestHighLevelAPI(unittest.TestCase):
def forward(self, x):
return self.module1(x) + self.module2(x)
model = self._convert_to_ir(Net())
mutators = process_inline_mutation(model)
model, mutators = self._get_model_with_mutators(Net())
self.assertEqual(len(mutators), 1)
sampler = RandomSampler()
mutator = mutators[0].bind_sampler(sampler)
self.assertEqual(self._get_converted_pytorch_model(mutator.apply(model))(torch.randn(1, 3, 3, 3)).size(0), 1)
self.assertEqual(sampler.counter, 1)
model = self._convert_to_ir(Net(shared=False))
mutators = process_inline_mutation(model)
model, mutators = self._get_model_with_mutators(Net(shared=False))
self.assertEqual(len(mutators), 2)
sampler = RandomSampler()
# repeat test. Expectation: sometimes succeeds, sometimes fails.
......@@ -321,6 +334,7 @@ class TestHighLevelAPI(unittest.TestCase):
self.assertLess(failed_count, 30)
def test_valuechoice_access(self):
@self.get_serializer()
class Net(nn.Module):
def __init__(self):
super().__init__()
......@@ -330,8 +344,7 @@ class TestHighLevelAPI(unittest.TestCase):
def forward(self, x):
return self.conv(x)
model = self._convert_to_ir(Net())
mutators = process_inline_mutation(model)
model, mutators = self._get_model_with_mutators(Net())
self.assertEqual(len(mutators), 1)
mutators[0].bind_sampler(EnumerateSampler())
input = torch.randn(1, 3, 5, 5)
......@@ -340,6 +353,7 @@ class TestHighLevelAPI(unittest.TestCase):
self.assertEqual(self._get_converted_pytorch_model(mutators[0].apply(model))(input).size(),
torch.Size([1, 8, 1, 1]))
@self.get_serializer()
class Net2(nn.Module):
def __init__(self):
super().__init__()
......@@ -354,14 +368,14 @@ class TestHighLevelAPI(unittest.TestCase):
x = self.conv(x)
return self.conv1(torch.cat((x, x), 1))
model = self._convert_to_ir(Net2())
mutators = process_inline_mutation(model)
model, mutators = self._get_model_with_mutators(Net2())
self.assertEqual(len(mutators), 1)
mutators[0].bind_sampler(EnumerateSampler())
input = torch.randn(1, 3, 5, 5)
self._get_converted_pytorch_model(mutators[0].apply(model))(input)
def test_valuechoice_access_functional(self):
@self.get_serializer()
class Net(nn.Module):
def __init__(self):
super().__init__()
......@@ -370,8 +384,7 @@ class TestHighLevelAPI(unittest.TestCase):
def forward(self, x):
return F.dropout(x, self.dropout_rate()[0])
model = self._convert_to_ir(Net())
mutators = process_inline_mutation(model)
model, mutators = self._get_model_with_mutators(Net())
self.assertEqual(len(mutators), 1)
mutator = mutators[0].bind_sampler(EnumerateSampler())
model1 = mutator.apply(model)
......@@ -381,6 +394,7 @@ class TestHighLevelAPI(unittest.TestCase):
self.assertAlmostEqual(self._get_converted_pytorch_model(model2)(torch.randn(1, 3, 3, 3)).abs().sum().item(), 0)
def test_valuechoice_access_functional_expression(self):
@self.get_serializer()
class Net(nn.Module):
def __init__(self):
super().__init__()
......@@ -391,8 +405,7 @@ class TestHighLevelAPI(unittest.TestCase):
# ValueError: dropout probability has to be between 0 and 1, but got 1.05
return F.dropout(x, self.dropout_rate()[0] - .1)
model = self._convert_to_ir(Net())
mutators = process_inline_mutation(model)
model, mutators = self._get_model_with_mutators(Net())
self.assertEqual(len(mutators), 1)
mutator = mutators[0].bind_sampler(EnumerateSampler())
model1 = mutator.apply(model)
......@@ -400,3 +413,29 @@ class TestHighLevelAPI(unittest.TestCase):
self._get_converted_pytorch_model(model1)(torch.randn(1, 3, 3, 3))
self.assertEqual(self._get_converted_pytorch_model(model1)(torch.randn(1, 3, 3, 3)).size(), torch.Size([1, 3, 3, 3]))
self.assertAlmostEqual(self._get_converted_pytorch_model(model2)(torch.randn(1, 3, 3, 3)).abs().sum().item(), 0)
class Python(GraphIR):
def _get_converted_pytorch_model(self, model_ir):
mutation = {mut.mutator.label: _unpack_if_only_one(mut.samples) for mut in model_ir.history}
with ContextStack('fixed', mutation):
model = model_ir.python_class(**model_ir.python_init_params)
return model
def _get_model_with_mutators(self, pytorch_model):
return extract_mutation_from_pt_module(pytorch_model)
def get_serializer(self):
return model_wrapper
@unittest.skip
def test_value_choice(self): ...
@unittest.skip
def test_value_choice_in_functional(self): ...
@unittest.skip
def test_valuechoice_access_functional(self): ...
@unittest.skip
def test_valuechoice_access_functional_expression(self): ...
......@@ -60,7 +60,14 @@ def test_mutation():
model2 = mutator.apply(model1)
assert _get_pools(model2) == (global_pool, max_pool)
assert model2.history == [model0, model1]
assert len(model2.history) == 2
assert model2.history[0].from_ == model0
assert model2.history[0].to == model1
assert model2.history[1].from_ == model1
assert model2.history[1].to == model2
assert model2.history[0].mutator == mutator
assert model2.history[1].mutator == mutator
assert _get_pools(model0) == (max_pool, max_pool)
assert _get_pools(model1) == (avg_pool, global_pool)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment