Commit 60b2a7a3 authored by Yuge Zhang's avatar Yuge Zhang
Browse files

Merge branch 'dev-retiarii' of https://github.com/microsoft/nni into dev-retiarii

parents d6791c2b bcb7633e
""" """
Classes related to Graph IR, except `Operation`. Model representation.
""" """
from __future__ import annotations
import copy import copy
import json
from enum import Enum from enum import Enum
from typing import * import json
from typing import (Any, Dict, List, Optional, Tuple, overload)
from .operation import Cell, Operation, _PseudoOperation from .operation import Cell, Operation, _PseudoOperation
__all__ = ['Model', 'ModelStatus', 'Graph', 'Node', 'Edge', 'IllegalGraphError', 'MetricData'] __all__ = ['Model', 'ModelStatus', 'Graph', 'Node', 'Edge', 'IllegalGraphError', 'MetricData']
MetricData = NewType('MetricData', Any) MetricData = Any
""" """
Graph metrics like loss, accuracy, etc. Graph metrics like loss, accuracy, etc.
Maybe we can assume this is a single float number for first iteration. # Maybe we can assume this is a single float number for first iteration.
""" """
...@@ -36,7 +34,7 @@ class TrainingConfig: ...@@ -36,7 +34,7 @@ class TrainingConfig:
Trainer keyword arguments Trainer keyword arguments
""" """
def __init__(self, module: str, kwargs: Dict[str, any]): def __init__(self, module: str, kwargs: Dict[str, Any]):
self.module = module self.module = module
self.kwargs = kwargs self.kwargs = kwargs
...@@ -44,7 +42,7 @@ class TrainingConfig: ...@@ -44,7 +42,7 @@ class TrainingConfig:
return f'TrainingConfig(module={self.module}, kwargs={self.kwargs})' return f'TrainingConfig(module={self.module}, kwargs={self.kwargs})'
@staticmethod @staticmethod
def _load(ir: Any) -> TrainingConfig: def _load(ir: Any) -> 'TrainingConfig':
return TrainingConfig(ir['module'], ir.get('kwargs', {})) return TrainingConfig(ir['module'], ir.get('kwargs', {}))
def _dump(self) -> Any: def _dump(self) -> Any:
...@@ -56,15 +54,14 @@ class TrainingConfig: ...@@ -56,15 +54,14 @@ class TrainingConfig:
class Model: class Model:
""" """
Top-level structure of graph IR. Represents a neural network model.
In execution engine's perspective, this is a trainable neural network model.
In mutator's perspective, this is a sandbox for a round of mutation.
Once a round of mutation starts, a sandbox is created and all mutating operations will happen inside. During mutation, one `Model` object is created for each trainable snapshot.
When mutation is complete, the sandbox will be frozen to a trainable model. For example, consider a mutator that insert a node at an edge for each iteration.
Then the strategy will submit model to execution engine for training. In one iteration, the mutator invokes 4 primitives: add node, remove edge, add edge to head, add edge to tail.
The model will record its metrics once trained. These 4 primitives operates in one `Model` object.
When they are all done the model will be set to "frozen" (trainable) status and be submitted to execution engine.
And then a new iteration starts, and a new `Model` object is created by forking last model.
Attributes Attributes
---------- ----------
...@@ -104,17 +101,17 @@ class Model: ...@@ -104,17 +101,17 @@ class Model:
self.metric: Optional[MetricData] = None self.metric: Optional[MetricData] = None
self.intermediate_metrics: List[MetricData] = [] self.intermediate_metrics: List[MetricData] = []
self._last_uid: int = 0 self._last_uid: int = 0 # FIXME: this should be global, not model-wise
def __repr__(self): def __repr__(self):
return f'Model(model_id={self.model_id}, status={self.status}, graphs={list(self.graphs.keys())}, ' + \ return f'Model(model_id={self.model_id}, status={self.status}, graphs={list(self.graphs.keys())}, ' + \
f'training_config={self.training_config}, metric={self.metric}, intermediate_metrics={self.intermediate_metrics})' f'training_config={self.training_config}, metric={self.metric}, intermediate_metrics={self.intermediate_metrics})'
@property @property
def root_graph(self) -> Graph: def root_graph(self) -> 'Graph':
return self.graphs[self._root_graph_name] return self.graphs[self._root_graph_name]
def fork(self) -> Model: def fork(self) -> 'Model':
""" """
Create a new model which has same topology, names, and IDs to current one. Create a new model which has same topology, names, and IDs to current one.
...@@ -136,17 +133,17 @@ class Model: ...@@ -136,17 +133,17 @@ class Model:
return self._last_uid return self._last_uid
@staticmethod @staticmethod
def _load(ir: Any) -> Model: def _load(ir: Any) -> 'Model':
model = Model(_internal=True) model = Model(_internal=True)
for graph_name, graph_data in ir.items(): for graph_name, graph_data in ir.items():
if graph_name != '_training_config': if graph_name != '_training_config':
Graph._load(model, graph_name, graph_data)._register() Graph._load(model, graph_name, graph_data)._register()
#model.training_config = TrainingConfig._load(ir['_training_config']) model.training_config = TrainingConfig._load(ir['_training_config'])
return model return model
def _dump(self) -> Any: def _dump(self) -> Any:
ret = {name: graph._dump() for name, graph in self.graphs.items()} ret = {name: graph._dump() for name, graph in self.graphs.items()}
#ret['_training_config'] = self.training_config._dump() ret['_training_config'] = self.training_config._dump()
return ret return ret
...@@ -227,41 +224,45 @@ class Graph: ...@@ -227,41 +224,45 @@ class Graph:
f'output_names={self.output_names}, num_hidden_nodes={len(self.hidden_nodes)}, num_edges={len(self.edges)})' f'output_names={self.output_names}, num_hidden_nodes={len(self.hidden_nodes)}, num_edges={len(self.edges)})'
@property @property
def nodes(self) -> List[Node]: def nodes(self) -> List['Node']:
return [self.input_node, self.output_node] + self.hidden_nodes return [self.input_node, self.output_node] + self.hidden_nodes
# mutation # mutation
def add_node(self, type: Union[Operation, str], **parameters) -> Node: @overload
if isinstance(type, Operation): def add_node(self, operation: Operation) -> 'Node': ...
assert not parameters @overload
op = type def add_node(self, type_name: str, parameters: Dict[str, Any] = {}) -> 'Node': ...
def add_node(self, operation_or_type, parameters={}):
if isinstance(operation_or_type, Operation):
op = operation_or_type
else: else:
op = Operation.new(type, **parameters) op = Operation.new(operation_or_type, parameters)
return Node(self, self.model._uid(), None, op, _internal=True)._register() return Node(self, self.model._uid(), None, op, _internal=True)._register()
# mutation # mutation
def add_edge(self, head: Tuple[Node, Optional[int]], tail: Tuple[Node, Optional[int]]) -> Edge: def add_edge(self, head: Tuple['Node', Optional[int]], tail: Tuple['Node', Optional[int]]) -> 'Edge':
assert head[0].graph is self and tail[0].graph is self assert head[0].graph is self and tail[0].graph is self
return Edge(head, tail)._register() return Edge(head, tail)._register()
def get_node_by_name(self, name: str) -> Optional[Node]: def get_node_by_name(self, name: str) -> Optional['Node']:
""" """
Returns the node which has specified name; or returns `None` if no node has this name. Returns the node which has specified name; or returns `None` if no node has this name.
""" """
found = [node for node in self.nodes if node.name == name] found = [node for node in self.nodes if node.name == name]
return found[0] if found else None return found[0] if found else None
def get_nodes_by_type(self, operation_type: str) -> List[Node]: def get_nodes_by_type(self, operation_type: str) -> List['Node']:
""" """
Returns nodes whose operation is specified typed. Returns nodes whose operation is specified typed.
""" """
return [node for node in self.hidden_nodes if node.operation.type == operation_type] return [node for node in self.hidden_nodes if node.operation.type == operation_type]
def topo_sort(self) -> List[Node]: # TODO def topo_sort(self) -> List['Node']: # TODO
... ...
def fork(self) -> Graph: def fork(self) -> 'Graph':
""" """
Fork the model and returns corresponding graph in new model. Fork the model and returns corresponding graph in new model.
This shortcut might be helpful because many algorithms only cares about "stem" subgraph instead of whole model. This shortcut might be helpful because many algorithms only cares about "stem" subgraph instead of whole model.
...@@ -271,7 +272,7 @@ class Graph: ...@@ -271,7 +272,7 @@ class Graph:
def __eq__(self, other: object) -> bool: def __eq__(self, other: object) -> bool:
return self is other return self is other
def _fork_to(self, model: Model) -> Graph: def _fork_to(self, model: Model) -> 'Graph':
new_graph = Graph(model, self.id, self.name, _internal=True)._register() new_graph = Graph(model, self.id, self.name, _internal=True)._register()
new_graph.input_names = self.input_names new_graph.input_names = self.input_names
new_graph.output_names = self.output_names new_graph.output_names = self.output_names
...@@ -288,7 +289,7 @@ class Graph: ...@@ -288,7 +289,7 @@ class Graph:
return new_graph return new_graph
def _copy(self) -> Graph: def _copy(self) -> 'Graph':
# Copy this graph inside the model. # Copy this graph inside the model.
# The new graph will have identical topology, but its nodes' name and ID will be different. # The new graph will have identical topology, but its nodes' name and ID will be different.
new_graph = Graph(self.model, self.model._uid(), _internal=True)._register() new_graph = Graph(self.model, self.model._uid(), _internal=True)._register()
...@@ -308,12 +309,12 @@ class Graph: ...@@ -308,12 +309,12 @@ class Graph:
return new_graph return new_graph
def _register(self) -> Graph: def _register(self) -> 'Graph':
self.model.graphs[self.name] = self self.model.graphs[self.name] = self
return self return self
@staticmethod @staticmethod
def _load(model: Model, name: str, ir: Any) -> Graph: def _load(model: Model, name: str, ir: Any) -> 'Graph':
graph = Graph(model, model._uid(), name, _internal=True) graph = Graph(model, model._uid(), name, _internal=True)
graph.input_names = ir.get('inputs') graph.input_names = ir.get('inputs')
graph.output_names = ir.get('outputs') graph.output_names = ir.get('outputs')
...@@ -381,19 +382,19 @@ class Node: ...@@ -381,19 +382,19 @@ class Node:
return f'Node(id={self.id}, name={self.name}, operation={self.operation})' return f'Node(id={self.id}, name={self.name}, operation={self.operation})'
@property @property
def predecessors(self) -> List[Node]: def predecessors(self) -> List['Node']:
return sorted(set(edge.head for edge in self.incoming_edges), key=(lambda node: node.id)) return sorted(set(edge.head for edge in self.incoming_edges), key=(lambda node: node.id))
@property @property
def successors(self) -> List[Node]: def successors(self) -> List['Node']:
return sorted(set(edge.tail for edge in self.outgoing_edges), key=(lambda node: node.id)) return sorted(set(edge.tail for edge in self.outgoing_edges), key=(lambda node: node.id))
@property @property
def incoming_edges(self) -> List[Edge]: def incoming_edges(self) -> List['Edge']:
return [edge for edge in self.graph.edges if edge.tail is self] return [edge for edge in self.graph.edges if edge.tail is self]
@property @property
def outgoing_edges(self) -> List[Edge]: def outgoing_edges(self) -> List['Edge']:
return [edge for edge in self.graph.edges if edge.head is self] return [edge for edge in self.graph.edges if edge.head is self]
@property @property
...@@ -403,12 +404,16 @@ class Node: ...@@ -403,12 +404,16 @@ class Node:
# mutation # mutation
def update_operation(self, type: Union[Operation, str], **parameters) -> None: @overload
if isinstance(type, Operation): def update_operation(self, operation: Operation) -> None: ...
assert not parameters @overload
self.operation = type def update_operation(self, type_name: str, parameters: Dict[str, Any] = {}) -> None: ...
def update_operation(self, operation_or_type, parameters={}):
if isinstance(operation_or_type, Operation):
self.operation = operation_or_type
else: else:
self.operation = Operation.new(type, **parameters) self.operation = Operation.new(operation_or_type, parameters)
# mutation # mutation
def remove(self) -> None: def remove(self) -> None:
...@@ -422,26 +427,29 @@ class Node: ...@@ -422,26 +427,29 @@ class Node:
Duplicate the cell template and let this node reference to newly created copy. Duplicate the cell template and let this node reference to newly created copy.
""" """
new_cell = self.cell._copy()._register() new_cell = self.cell._copy()._register()
self.operation = Operation.new('_cell', cell=new_cell.name) self.operation = Cell(new_cell.name)
return new_cell return new_cell
def __eq__(self, other: object) -> bool: def __eq__(self, other: object) -> bool:
return self is other return self is other
def _register(self) -> Node: def _register(self) -> 'Node':
self.graph.hidden_nodes.append(self) self.graph.hidden_nodes.append(self)
return self return self
@staticmethod @staticmethod
def _load(graph: Graph, name: str, ir: Any) -> Node: def _load(graph: Graph, name: str, ir: Any) -> 'Node':
ir = dict(ir) if ir['type'] == '_cell':
if 'type' not in ir and 'cell' in ir: op = Cell(ir['cell'], ir.get('parameters', {}))
ir['type'] = '_cell' else:
op = Operation.new(**ir) op = Operation.new(ir['type'], ir.get('parameters', {}))
return Node(graph, graph.model._uid(), name, op) return Node(graph, graph.model._uid(), name, op)
def _dump(self) -> Any: def _dump(self) -> Any:
return {'type': self.operation.type, **self.operation.parameters} ret = {'type': self.operation.type, 'parameters': self.operation.parameters}
if isinstance(self.operation, Cell):
ret['cell'] = self.operation.cell_name
return ret
class Edge: class Edge:
...@@ -499,14 +507,15 @@ class Edge: ...@@ -499,14 +507,15 @@ class Edge:
def remove(self) -> None: def remove(self) -> None:
self.graph.edges.remove(self) self.graph.edges.remove(self)
def _register(self) -> Edge: def _register(self) -> 'Edge':
self.graph.edges.append(self) self.graph.edges.append(self)
return self return self
@staticmethod @staticmethod
def _load(graph: Graph, ir: Any) -> Edge: def _load(graph: Graph, ir: Any) -> 'Edge':
head = graph.get_node_by_name(ir['head'][0]) head = graph.get_node_by_name(ir['head'][0])
tail = graph.get_node_by_name(ir['tail'][0]) tail = graph.get_node_by_name(ir['tail'][0])
assert head is not None and tail is not None
return Edge((head, ir['head'][1]), (tail, ir['tail'][1]), _internal=True) return Edge((head, ir['head'][1]), (tail, ir['tail'][1]), _internal=True)
def _dump(self) -> Any: def _dump(self) -> Any:
......
from __future__ import annotations from typing import (Any, Iterable, List, Optional)
from typing import *
from .graph import * from .graph import Model
__all__ = ['Sampler', 'Mutator'] __all__ = ['Sampler', 'Mutator']
Choice = NewType('Choice', Any) Choice = Any
class Sampler: class Sampler:
""" """
Handles `Mutator.choice()` calls. Handles `Mutator.choice()` calls.
""" """
def choice(self, candidates: List[Choice], mutator: Mutator, model: Model, index: int) -> Choice: def choice(self, candidates: List[Choice], mutator: 'Mutator', model: Model, index: int) -> Choice:
raise NotImplementedError() raise NotImplementedError()
def mutation_start(self, mutator: Mutator, model: Model) -> None: def mutation_start(self, mutator: 'Mutator', model: Model) -> None:
pass pass
def mutation_end(self, mutator: Mutator, model: Model) -> None: def mutation_end(self, mutator: 'Mutator', model: Model) -> None:
pass pass
...@@ -44,11 +44,12 @@ class Mutator: ...@@ -44,11 +44,12 @@ class Mutator:
self._cur_model: Optional[Model] = None self._cur_model: Optional[Model] = None
self._cur_choice_idx: Optional[int] = None self._cur_choice_idx: Optional[int] = None
def bind_sampler(self, sampler: Sampler) -> Mutator: def bind_sampler(self, sampler: Sampler) -> 'Mutator':
""" """
Set the sampler which will handle `Mutator.choice` calls. Set the sampler which will handle `Mutator.choice` calls.
""" """
self.sampler = sampler self.sampler = sampler
return self
def apply(self, model: Model) -> Model: def apply(self, model: Model) -> Model:
""" """
...@@ -57,6 +58,7 @@ class Mutator: ...@@ -57,6 +58,7 @@ class Mutator:
The model will be copied before mutation and the original model will not be modified. The model will be copied before mutation and the original model will not be modified.
""" """
assert self.sampler is not None
copy = model.fork() copy = model.fork()
self._cur_model = copy self._cur_model = copy
self._cur_choice_idx = 0 self._cur_choice_idx = 0
...@@ -93,6 +95,7 @@ class Mutator: ...@@ -93,6 +95,7 @@ class Mutator:
""" """
Ask sampler to make a choice. Ask sampler to make a choice.
""" """
assert self.sampler is not None and self._cur_model is not None and self._cur_choice_idx is not None
ret = self.sampler.choice(list(candidates), self, self._cur_model, self._cur_choice_idx) ret = self.sampler.choice(list(candidates), self, self._cur_model, self._cur_choice_idx)
self._cur_choice_idx += 1 self._cur_choice_idx += 1
return ret return ret
......
from __future__ import annotations from typing import (Any, Dict)
from enum import Enum
from typing import *
from . import debug_configs from . import debug_configs
__all__ = ['Operation', 'Cell']
class Operation: class Operation:
""" """
...@@ -24,13 +24,9 @@ class Operation: ...@@ -24,13 +24,9 @@ class Operation:
Arbitrary key-value parameters (e.g. kernel_size). Arbitrary key-value parameters (e.g. kernel_size).
""" """
def __init__( def __init__(self, type_name: str, parameters: Dict[str, Any], _internal: bool = False):
self, assert _internal, '`Operation()` is private, use `Operation.new()` instead'
type: str, self.type: str = type_name
parameters: Dict[str, Any],
_internal_access: bool = False):
assert _internal_access, '`Operation()` is private, use `Operation.new()` instead'
self.type: str = type
self.parameters: Dict[str, Any] = parameters self.parameters: Dict[str, Any] = parameters
def to_init_code(self, field: str) -> str: def to_init_code(self, field: str) -> str:
...@@ -47,19 +43,19 @@ class Operation: ...@@ -47,19 +43,19 @@ class Operation:
return True return True
@staticmethod @staticmethod
def new(type: str, **parameters: Any) -> Operation: def new(type_name: str, parameters: Dict[str, Any] = {}) -> 'Operation':
if type == '_cell': if type_name == '_cell':
return Cell(parameters['cell']) return Cell(parameters['cell'])
else: else:
if debug_configs.framework.lower() in ('torch', 'pytorch'): if debug_configs.framework.lower() in ('torch', 'pytorch'):
from .operation_def import torch_op_def from .operation_def import torch_op_def # pylint: disable=unused-import
cls = PyTorchOperation._find_subclass(type) cls = PyTorchOperation._find_subclass(type)
elif debug_configs.framework.lower() in ('tf', 'tensorflow'): elif debug_configs.framework.lower() in ('tf', 'tensorflow'):
from .operation_def import tf_op_def from .operation_def import tf_op_def # pylint: disable=unused-import
cls = TensorFlowOperation._find_subclass(type) cls = TensorFlowOperation._find_subclass(type)
else: else:
raise ValueError(f'Unsupported framework: {debug_configs.framework}') raise ValueError(f'Unsupported framework: {debug_configs.framework}')
return cls(type, parameters, _internal_access=True) return cls(type_name, parameters, _internal=True)
@classmethod @classmethod
def _find_subclass(cls, subclass_name): def _find_subclass(cls, subclass_name):
...@@ -120,12 +116,13 @@ class Cell(Operation): ...@@ -120,12 +116,13 @@ class Cell(Operation):
framework framework
No real usage. Exists for compatibility with base class. No real usage. Exists for compatibility with base class.
""" """
def __init__(self, cell_name: str): def __init__(self, cell_name: str, parameters: Dict[str, Any] = {}):
self.type = '_cell' self.type = '_cell'
self.parameters = {'cell': cell_name} self.cell_name = cell_name
self.parameters = parameters
def to_init_code(self, field: str) -> str: def _to_class_name(self):
return f'self.{field} = {self.parameters["cell"]}()' return self.cell_name
class _PseudoOperation(Operation): class _PseudoOperation(Operation):
......
from ..operation import TensorFlowOperation from ..operation import TensorFlowOperation
class Conv2D(TensorFlowOperation): class Conv2D(TensorFlowOperation):
def to_init_code(self, field): def __init__(self, type_name, parameters, _internal):
parameters = {'padding': 'same', **parameters} if 'padding' not in parameters:
super().__init__(type, parameters, _internal_access) parameters['padding'] = 'same'
super().__init__(type_name, parameters, _internal)
from ..operation import PyTorchOperation from ..operation import PyTorchOperation
class relu(PyTorchOperation):
def to_init_code(self, field):
return ''
def to_forward_code(self, field, output, *inputs) -> str:
assert len(inputs) == 1
return f'{output} = nn.functional.relu({inputs[0]})'
class Flatten(PyTorchOperation): class Flatten(PyTorchOperation):
def to_init_code(self, field): def to_init_code(self, field):
......
...@@ -4,10 +4,10 @@ ...@@ -4,10 +4,10 @@
"outputs": ["metric"], "outputs": ["metric"],
"nodes": { "nodes": {
"stem": {"cell": "stem"}, "stem": {"type": "_cell", "cell": "stem"},
"flatten": {"type": "Flatten"}, "flatten": {"type": "Flatten"},
"fc1": {"type": "Dense", "units": 1024, "activation": "relu"}, "fc1": {"type": "Dense", "parameters": {"units": 1024, "activation": "relu"}},
"fc2": {"type": "Dense", "units": 10}, "fc2": {"type": "Dense", "parameters": {"units": 10}},
"softmax": {"type": "Softmax"} "softmax": {"type": "Softmax"}
}, },
...@@ -23,10 +23,10 @@ ...@@ -23,10 +23,10 @@
"stem": { "stem": {
"nodes": { "nodes": {
"conv1": {"type": "Conv2D", "filters": 32, "kernel_size": 5, "activation": "relu"}, "conv1": {"type": "Conv2D", "parameters": {"filters": 32, "kernel_size": 5, "activation": "relu"}},
"pool1": {"type": "MaxPool2D", "pool_size": 2}, "pool1": {"type": "MaxPool2D", "parameters": {"pool_size": 2}},
"conv2": {"type": "Conv2D", "filters": 64, "kernel_size": 5, "activation": "relu"}, "conv2": {"type": "Conv2D", "parameters": {"filters": 64, "kernel_size": 5, "activation": "relu"}},
"pool2": {"type": "MaxPool2D", "pool_size": 2} "pool2": {"type": "MaxPool2D", "parameters": {"pool_size": 2}}
}, },
"edges": [ "edges": [
...@@ -36,5 +36,10 @@ ...@@ -36,5 +36,10 @@
{"head": ["conv2", null], "tail": ["pool2", null]}, {"head": ["conv2", null], "tail": ["pool2", null]},
{"head": ["pool2", null], "tail": ["_outputs", 0]} {"head": ["pool2", null], "tail": ["_outputs", 0]}
] ]
},
"_training_config": {
"module": "_debug_no_trainer",
"kwargs": {}
} }
} }
...@@ -23,13 +23,19 @@ def _test_file(json_path): ...@@ -23,13 +23,19 @@ def _test_file(json_path):
# add default values to JSON, so we can compare with `==` # add default values to JSON, so we can compare with `==`
for graph_name, graph in orig_ir.items(): for graph_name, graph in orig_ir.items():
if graph_name == '_training_config':
continue
if 'inputs' not in graph: if 'inputs' not in graph:
graph['inputs'] = None graph['inputs'] = None
if 'outputs' not in graph: if 'outputs' not in graph:
graph['outputs'] = None graph['outputs'] = None
for node_name, node in graph['nodes'].items(): for node_name, node in graph['nodes'].items():
if 'type' not in node and 'cell' in node: if 'parameters' not in node:
node['type'] = '_cell' node['parameters'] = {}
# debug output
#json.dump(orig_ir, open('_orig.json', 'w'), indent=4)
#json.dump(dump_ir, open('_dump.json', 'w'), indent=4)
assert orig_ir == dump_ir assert orig_ir == dump_ir
......
...@@ -8,6 +8,10 @@ from nni.retiarii import * ...@@ -8,6 +8,10 @@ from nni.retiarii import *
import nni.retiarii.debug_configs import nni.retiarii.debug_configs
nni.retiarii.debug_configs.framework = 'tensorflow' nni.retiarii.debug_configs.framework = 'tensorflow'
max_pool = Operation.new('MaxPool2D', {'pool_size': 2})
avg_pool = Operation.new('AveragePooling2D', {'pool_size': 2})
global_pool = Operation.new('GlobalAveragePooling2D')
class DebugSampler(Sampler): class DebugSampler(Sampler):
def __init__(self): def __init__(self):
...@@ -22,9 +26,6 @@ class DebugSampler(Sampler): ...@@ -22,9 +26,6 @@ class DebugSampler(Sampler):
class DebugMutator(Mutator): class DebugMutator(Mutator):
def mutate(self, model): def mutate(self, model):
max_pool = Operation.new('MaxPool2D', pool_size = 2)
avg_pool = Operation.new('AveragePooling2D', pool_size=2)
global_pool = Operation.new('GlobalAveragePooling2D')
ops = [max_pool, avg_pool, global_pool] ops = [max_pool, avg_pool, global_pool]
pool1 = model.graphs['stem'].get_node_by_name('pool1') pool1 = model.graphs['stem'].get_node_by_name('pool1')
...@@ -67,10 +68,6 @@ def _get_pools(model): ...@@ -67,10 +68,6 @@ def _get_pools(model):
return pool1, pool2 return pool1, pool2
max_pool = Operation.new(type='MaxPool2D', pool_size=2)
avg_pool = Operation.new(type='AveragePooling2D', pool_size=2)
global_pool = Operation.new(type='GlobalAveragePooling2D')
if __name__ == '__main__': if __name__ == '__main__':
test_dry_run() test_dry_run()
test_mutation() test_mutation()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment