Unverified commit 59cd3982, authored by Yuge Zhang, committed by GitHub

[Retiarii] Coding style improvements for pylint and flake8 (#3190)

parent 593a275c
import logging
-from typing import *
+from typing import List

from ..graph import IllegalGraphError, Edge, Graph, Node, Model
-from ..operation import Operation, Cell

_logger = logging.getLogger(__name__)


-def model_to_pytorch_script(model: Model, placement = None) -> str:
+def model_to_pytorch_script(model: Model, placement=None) -> str:
    graphs = []
    total_pkgs = set()
    for name, cell in model.graphs.items():
-        import_pkgs, graph_code = graph_to_pytorch_model(name, cell, placement = placement)
+        import_pkgs, graph_code = graph_to_pytorch_model(name, cell, placement=placement)
        graphs.append(graph_code)
        total_pkgs.update(import_pkgs)
    pkgs_code = '\n'.join(['import {}'.format(pkg) for pkg in total_pkgs])
    return _PyTorchScriptTemplate.format(pkgs_code, '\n\n'.join(graphs)).strip()


def _sorted_incoming_edges(node: Node) -> List[Edge]:
    edges = [edge for edge in node.graph.edges if edge.tail is node]
-    _logger.info('sorted_incoming_edges: {}'.format(edges))
+    _logger.info('sorted_incoming_edges: %s', str(edges))
    if not edges:
        return []
-    _logger.info(f'all tail_slots are None: {[edge.tail_slot for edge in edges]}')
+    _logger.info('all tail_slots are None: %s', str([edge.tail_slot for edge in edges]))
    if all(edge.tail_slot is None for edge in edges):
        return edges
    if all(isinstance(edge.tail_slot, int) for edge in edges):
@@ -32,6 +31,7 @@ def _sorted_incoming_edges(node: Node) -> List[Edge]:
        return edges
    raise IllegalGraphError(node.graph, 'Node {} has bad inputs'.format(node.name))


def _format_inputs(node: Node) -> List[str]:
    edges = _sorted_incoming_edges(node)
    inputs = []
@@ -53,6 +53,7 @@ def _format_inputs(node: Node) -> List[str]:
        inputs.append('{}[{}]'.format(edge.head.name, edge.head_slot))
    return inputs


def _remove_prefix(names, graph_name):
    """
    variables name (full name space) is too long,
@@ -69,14 +70,14 @@ def _remove_prefix(names, graph_name):
    else:
        return names[len(graph_name):] if names.startswith(graph_name) else names


-def graph_to_pytorch_model(graph_name: str, graph: Graph, placement = None) -> str:
+def graph_to_pytorch_model(graph_name: str, graph: Graph, placement=None) -> str:
    nodes = graph.topo_sort()
    # handle module node and function node differently
    # only need to generate code for module here
    import_pkgs = set()
    node_codes = []
-    placement_codes = []
    for node in nodes:
        if node.operation:
            pkg_name = node.operation.get_import_pkg()
......
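The logger changes above are the commit's most common fix: pylint's logging-format-interpolation check wants lazy %-style arguments instead of eager str.format or f-string interpolation, so the message is only built when the record is actually emitted. A minimal sketch of the difference (the Expensive class is illustrative, not from this codebase):

import logging

logging.basicConfig(level=logging.WARNING)
_logger = logging.getLogger(__name__)


class Expensive:
    def __repr__(self):
        # pretend this repr is costly to compute
        return 'Expensive<...>'


obj = Expensive()

# Eager: the f-string calls repr(obj) even though INFO is filtered out here.
_logger.info(f'value: {obj}')

# Lazy: formatting happens only if the INFO record would actually be emitted.
_logger.info('value: %s', obj)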
from .graph_gen import convert_to_graph
-from .visualize import visualize_model
\ No newline at end of file
+from .visualize import visualize_model
-import json_tricks
import logging
import re

import torch

-from ..graph import Graph, Node, Edge, Model
-from ..operation import Cell, Operation
-from ..nn.pytorch import Placeholder, LayerChoice, InputChoice
-from .op_types import MODULE_EXCEPT_LIST, OpTypeName, BasicOpsPT
-from .utils import build_full_name, _convert_name
+from ..graph import Graph, Model, Node
+from ..nn.pytorch import InputChoice, LayerChoice, Placeholder
+from ..operation import Cell
+from .op_types import MODULE_EXCEPT_LIST, BasicOpsPT, OpTypeName
+from .utils import _convert_name, build_full_name

_logger = logging.getLogger(__name__)
@@ -16,6 +15,7 @@ global_seq = 0
global_graph_id = 0
modules_arg = None


def _add_edge(ir_graph, node, graph_inputs, node_index, new_node, output_remap, ignore_first=False):
    """
    Parameters
@@ -76,6 +76,7 @@ def _add_edge(ir_graph, node, graph_inputs, node_index, new_node, output_remap,
            new_node_input_idx += 1


def create_prim_constant_node(ir_graph, node, module_name):
    global global_seq
    attrs = {}
@@ -86,14 +87,17 @@ create_prim_constant_node(ir_graph, node, module_name):
                                 node.kind(), attrs)
    return new_node


def handle_prim_attr_node(node):
    assert node.hasAttribute('name')
    attrs = {'name': node.s('name'), 'input': node.inputsAt(0).debugName()}
    return node.kind(), attrs


def _remove_mangle(module_type_str):
    return re.sub('\\.___torch_mangle_\\d+', '', module_type_str)


def remove_unconnected_nodes(ir_graph, targeted_type=None):
    """
    Parameters
@@ -122,6 +126,7 @@ def remove_unconnected_nodes(ir_graph, targeted_type=None):
    for hidden_node in to_removes:
        hidden_node.remove()


def handle_graph_nodes(script_module, sm_graph, module, module_name, ir_model, ir_graph):
    """
    Convert torch script node to our node ir, and build our graph ir
@@ -156,7 +161,7 @@ def handle_graph_nodes(script_module, sm_graph, module, module_name, ir_model, i
        # TODO: add scope name
        ir_graph._add_input(_convert_name(_input.debugName()))

    node_index = {}  # graph node to graph ir node
    # some node does not have output but it modifies a variable, for example aten::append
    # %17 : Tensor[] = aten::append(%out.1, %16)
@@ -248,13 +253,14 @@ def handle_graph_nodes(script_module, sm_graph, module, module_name, ir_model, i
                # therefore, we do this check for a module. example below:
                # %25 : __torch__.xxx = prim::GetAttr[name="input_switch"](%self)
                # %27 : Tensor = prim::CallMethod[name="forward"](%25, %out.1)
-                assert submodule_name in script_module._modules, "submodule_name: {} not in script_module {}".format(submodule_name, script_module._modules.keys())
+                assert submodule_name in script_module._modules, "submodule_name: {} not in script_module {}".format(
+                    submodule_name, script_module._modules.keys())
                submodule_full_name = build_full_name(module_name, submodule_name)
                submodule_obj = getattr(module, submodule_name)
                subgraph, sub_m_attrs = convert_module(script_module._modules[submodule_name],
                                                       submodule_obj,
                                                       submodule_full_name, ir_model)
            else:
                # %8 : __torch__.nni.retiarii.model_apis.nn.___torch_mangle_37.ModuleList = prim::GetAttr[name="cells"](%self)
                # %10 : __torch__.darts_model.Cell = prim::GetAttr[name="0"](%8)
@@ -271,7 +277,7 @@ def handle_graph_nodes(script_module, sm_graph, module, module_name, ir_model, i
                    predecessor_obj = getattr(module, predecessor_name)
                    submodule_obj = getattr(predecessor_obj, submodule_name)
                    subgraph, sub_m_attrs = convert_module(script_module._modules[predecessor_name]._modules[submodule_name],
                                                           submodule_obj, submodule_full_name, ir_model)
                else:
                    raise RuntimeError('Unsupported module case: {}'.format(submodule.inputsAt(0).type().str()))
@@ -329,7 +335,7 @@ def handle_graph_nodes(script_module, sm_graph, module, module_name, ir_model, i
            node_type, attrs = handle_prim_attr_node(node)
            global_seq += 1
            new_node = ir_graph.add_node(build_full_name(module_name, OpTypeName.Attr, global_seq),
                                         node_type, attrs)
            node_index[node] = new_node
        elif node.kind() == 'prim::min':
            print('zql: ', sm_graph)
@@ -350,6 +356,7 @@ def handle_graph_nodes(script_module, sm_graph, module, module_name, ir_model, i
    return node_index


def merge_aten_slices(ir_graph):
    """
    if there is aten::slice node, merge the consecutive ones together.
@@ -367,7 +374,7 @@ merge_aten_slices(ir_graph):
                break
    if has_slice_node:
        assert head_slice_nodes
        for head_node in head_slice_nodes:
            slot = 0
            new_slice_node = ir_graph.add_node(build_full_name(head_node.name, 'merged'), OpTypeName.MergedSlice)
@@ -391,11 +398,11 @@ merge_aten_slices(ir_graph):
            slot += 4
            ir_graph.hidden_nodes.remove(node)
            node = suc_node
        for edge in node.outgoing_edges:
            edge.head = new_slice_node
        ir_graph.hidden_nodes.remove(node)


def refine_graph(ir_graph):
    """
@@ -408,13 +415,14 @@ def refine_graph(ir_graph):
    remove_unconnected_nodes(ir_graph, targeted_type='prim::GetAttr')
    merge_aten_slices(ir_graph)


def _handle_layerchoice(module):
    global modules_arg
    m_attrs = {}
    candidates = module.candidate_ops
    choices = []
-    for i, cand in enumerate(candidates):
+    for cand in candidates:
        assert id(cand) in modules_arg, 'id not exist: {}'.format(id(cand))
        assert isinstance(modules_arg[id(cand)], dict)
        cand_type = '__torch__.' + cand.__class__.__module__ + '.' + cand.__class__.__name__
@@ -423,6 +431,7 @@ def _handle_layerchoice(module):
    m_attrs['label'] = module.label
    return m_attrs


def _handle_inputchoice(module):
    m_attrs = {}
    m_attrs['n_chosen'] = module.n_chosen
@@ -430,6 +439,7 @@ def _handle_inputchoice(module):
    m_attrs['label'] = module.label
    return m_attrs


def convert_module(script_module, module, module_name, ir_model):
    """
    Convert a module to its graph ir (i.e., Graph) along with its input arguments
@@ -503,10 +513,11 @@ def convert_module(script_module, module, module_name, ir_model):
        # TODO: if we parse this module, it means we will create a graph (module class)
        # for this module. Then it is not necessary to record this module's arguments
        # return ir_graph, modules_arg[id(module)].
        # That is, we can refactor this part, to allow users to annotate which module
        # should not be parsed further.
        return ir_graph, {}


def convert_to_graph(script_module, module, recorded_modules_arg):
    """
    Convert module to our graph ir, i.e., build a ```Model``` type
......
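graph_gen.py walks the TorchScript graph of a scripted module, dispatching on node.kind() strings such as prim::GetAttr and aten::add_, and strips TorchScript's name mangling with the _remove_mangle regex shown above. A rough, self-contained sketch of those traversal primitives (the toy MyModule is an assumption for illustration):

import re

import torch
import torch.nn as nn


def _remove_mangle(module_type_str):
    # '__torch__.___torch_mangle_37.Cell' -> '__torch__.Cell'
    return re.sub('\\.___torch_mangle_\\d+', '', module_type_str)


class MyModule(nn.Module):
    def forward(self, x):
        return torch.relu(x + 1)


script_module = torch.jit.script(MyModule())
for node in script_module.graph.nodes():
    # kinds look like 'prim::Constant', 'aten::add', 'aten::relu'
    print(node.kind())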
@@ -16,6 +16,7 @@ class OpTypeName(str, Enum):
    Placeholder = 'Placeholder'
    MergedSlice = 'MergedSlice'


# deal with aten op
BasicOpsPT = {
    'aten::mean': 'Mean',
@@ -29,7 +30,7 @@ BasicOpsPT = {
    'aten::size': 'Size',
    'aten::view': 'View',
    'aten::eq': 'Eq',
    'aten::add_': 'Add_'  # %out.3 : Tensor = aten::add_(%out.1, %connection.1, %4)
}

-BasicOpsTF = {}
\ No newline at end of file
+BasicOpsTF = {}
@@ -6,6 +6,7 @@ def build_full_name(prefix, name, seq=None):
    else:
        return '{}__{}{}'.format(prefix, name, str(seq))


def _convert_name(name: str) -> str:
    """
    Convert the names using separator '.' to valid variable name in code
......
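BasicOpsPT above maps raw aten:: operator kinds to the readable type names used in the graph IR, and BasicOpsTF is an empty placeholder for a TensorFlow backend. A small illustration of the lookup, using only the entries visible in the diff and a hypothetical fallback to the raw kind:

BasicOpsPT = {
    'aten::mean': 'Mean',
    'aten::size': 'Size',
    'aten::view': 'View',
    'aten::eq': 'Eq',
    'aten::add_': 'Add_',
}


def op_type_name(kind):
    # fall back to the raw kind for operators the table does not cover
    return BasicOpsPT.get(kind, kind)


assert op_type_name('aten::view') == 'View'
assert op_type_name('aten::relu') == 'aten::relu'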
import graphviz


def convert_to_visualize(graph_ir, vgraph):
    for name, graph in graph_ir.items():
        if name == '_training_config':
@@ -33,7 +34,8 @@ def convert_to_visualize(graph_ir, vgraph):
            dst = cell_node[dst][0]
            subgraph.edge(src, dst)


def visualize_model(graph_ir):
    vgraph = graphviz.Digraph('G', filename='vgraph', format='jpg')
    convert_to_visualize(graph_ir, vgraph)
-    vgraph.render()
\ No newline at end of file
+    vgraph.render()
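visualize.py renders one graphviz cluster per cell and writes the whole model to vgraph.jpg. A self-contained sketch of the graphviz calls it relies on (node names here are invented; rendering requires the Graphviz binaries to be installed):

import graphviz

vgraph = graphviz.Digraph('G', filename='vgraph', format='jpg')
# subgraphs named 'cluster_*' are drawn as boxes around their nodes
with vgraph.subgraph(name='cluster_cell0') as sub:
    sub.node('conv1')
    sub.node('conv2')
    sub.edge('conv1', 'conv2')
vgraph.edge('conv2', 'output')
vgraph.render()  # writes the dot source 'vgraph' plus 'vgraph.jpg'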
-import time
-import os
-import importlib.util
-from typing import *
+import os
+import importlib.util
+from typing import List

from ..graph import Model, ModelStatus
from .base import BaseExecutionEngine
from .cgo_engine import CGOExecutionEngine
-from .interface import *
+from .interface import AbstractExecutionEngine, WorkerInfo
from .listener import DefaultListener

_execution_engine = None
......
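api.py keeps a single module-level _execution_engine. A minimal sketch of that lazy-singleton pattern; the get_execution_engine accessor and the default engine choice are assumptions about the elided part of the file, not shown in the diff:

class BaseExecutionEngine:
    # stand-in for the real engine class imported above
    pass


_execution_engine = None


def get_execution_engine():
    # create the default engine on first use, then reuse it
    global _execution_engine
    if _execution_engine is None:
        _execution_engine = BaseExecutionEngine()
    return _execution_engine


assert get_execution_engine() is get_execution_engine()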
import logging
-from typing import *
+from typing import Dict, Any, List

from .interface import AbstractExecutionEngine, AbstractGraphListener, WorkerInfo
from .. import codegen, utils
@@ -61,16 +61,16 @@ class BaseExecutionEngine(AbstractExecutionEngine):
    def _send_trial_callback(self, paramater: dict) -> None:
        for listener in self._listeners:
-            _logger.warning('resources: {}'.format(listener.resources))
+            _logger.warning('resources: %s', listener.resources)
            if not listener.has_available_resource():
                _logger.warning('There is no available resource, but trial is submitted.')
            listener.on_resource_used(1)
-            _logger.warning('on_resource_used: {}'.format(listener.resources))
+            _logger.warning('on_resource_used: %s', listener.resources)

    def _request_trial_jobs_callback(self, num_trials: int) -> None:
        for listener in self._listeners:
            listener.on_resource_available(1 * num_trials)
-            _logger.warning('on_resource_available: {}'.format(listener.resources))
+            _logger.warning('on_resource_available: %s', listener.resources)

    def _trial_end_callback(self, trial_id: int, success: bool) -> None:
        model = self._running_models[trial_id]
......
import logging
-import json
-from typing import *
+from typing import List, Dict, Tuple

from .interface import AbstractExecutionEngine, AbstractGraphListener, WorkerInfo
from .. import codegen, utils
@@ -12,8 +11,10 @@ from .logical_optimizer.opt_dedup_input import DedupInputOptimizer
from .base import BaseGraphData

_logger = logging.getLogger(__name__)


class CGOExecutionEngine(AbstractExecutionEngine):
-    def __init__(self, n_model_per_graph = 4) -> None:
+    def __init__(self, n_model_per_graph=4) -> None:
        self._listeners: List[AbstractGraphListener] = []
        self._running_models: Dict[int, Model] = dict()
        self.logical_plan_counter = 0
@@ -30,38 +31,37 @@ class CGOExecutionEngine(AbstractExecutionEngine):
        advisor.intermediate_metric_callback = self._intermediate_metric_callback
        advisor.final_metric_callback = self._final_metric_callback

    def add_optimizer(self, opt):
        self._optimizers.append(opt)

    def submit_models(self, *models: List[Model]) -> None:
-        _logger.info(f'{len(models)} Models are submitted')
+        _logger.info('%d models are submitted', len(models))
        logical = self._build_logical(models)
        for opt in self._optimizers:
            opt.convert(logical)
        phy_models_and_placements = self._assemble(logical)
        for model, placement, grouped_models in phy_models_and_placements:
            data = BaseGraphData(codegen.model_to_pytorch_script(model, placement=placement),
                                 model.training_config.module, model.training_config.kwargs)
            for m in grouped_models:
                self._original_models[m.model_id] = m
                self._original_model_to_multi_model[m.model_id] = model
            self._running_models[send_trial(data.dump())] = model
        # for model in models:
        #     data = BaseGraphData(codegen.model_to_pytorch_script(model),
        #                          model.config['trainer_module'], model.config['trainer_kwargs'])
        #     self._running_models[send_trial(data.dump())] = model

-    def _assemble(self, logical_plan : LogicalPlan) -> List[Tuple[Model, PhysicalDevice]]:
+    def _assemble(self, logical_plan: LogicalPlan) -> List[Tuple[Model, PhysicalDevice]]:
        # unique_models = set()
        # for node in logical_plan.graph.nodes:
        #     if node.graph.model not in unique_models:
        #         unique_models.add(node.graph.model)
        # return [m for m in unique_models]
-        grouped_models : List[Dict[Model, PhysicalDevice]] = AssemblePolicy().group(logical_plan)
+        grouped_models: List[Dict[Model, PhysicalDevice]] = AssemblePolicy().group(logical_plan)
        phy_models_and_placements = []
        for multi_model in grouped_models:
            model, model_placement = logical_plan.assemble(multi_model)
@@ -69,7 +69,7 @@ class CGOExecutionEngine(AbstractExecutionEngine):
        return phy_models_and_placements

    def _build_logical(self, models: List[Model]) -> LogicalPlan:
-        logical_plan = LogicalPlan(id = self.logical_plan_counter)
+        logical_plan = LogicalPlan(plan_id=self.logical_plan_counter)
        for model in models:
            logical_plan.add_model(model)
        self.logical_plan_counter += 1
@@ -108,7 +108,7 @@ class CGOExecutionEngine(AbstractExecutionEngine):
        for model_id in merged_metrics:
            int_model_id = int(model_id)
            self._original_models[int_model_id].intermediate_metrics.append(merged_metrics[model_id])
-            #model.intermediate_metrics.append(metrics)
+            # model.intermediate_metrics.append(metrics)
            for listener in self._listeners:
                listener.on_intermediate_metric(self._original_models[int_model_id], merged_metrics[model_id])
@@ -117,10 +117,9 @@ class CGOExecutionEngine(AbstractExecutionEngine):
        for model_id in merged_metrics:
            int_model_id = int(model_id)
            self._original_models[int_model_id].intermediate_metrics.append(merged_metrics[model_id])
-            #model.intermediate_metrics.append(metrics)
+            # model.intermediate_metrics.append(metrics)
            for listener in self._listeners:
                listener.on_metric(self._original_models[int_model_id], merged_metrics[model_id])

    def query_available_resource(self) -> List[WorkerInfo]:
        raise NotImplementedError  # move the method from listener to here?
@@ -141,6 +140,7 @@ class CGOExecutionEngine(AbstractExecutionEngine):
        trainer_instance = trainer_cls(model_cls(), graph_data.training_kwargs)
        trainer_instance.fit()


class AssemblePolicy:
    @staticmethod
    def group(logical_plan):
@@ -148,4 +148,3 @@ class AssemblePolicy:
        for idx, m in enumerate(logical_plan.models):
            group_model[m] = PhysicalDevice('server', f'cuda:{idx}')
-        return [group_model]
\ No newline at end of file
+        return [group_model]
-from abc import *
-from typing import *
+from abc import ABC, abstractmethod, abstractclassmethod
+from typing import Any, NewType, List

from ..graph import Model, MetricData
......
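Replacing from abc import * and from typing import * with explicit names lets pylint verify every symbol the interface module uses. A sketch of how the imported pieces typically fit together (the MetricData NewType and the method shown are assumptions based on the import list and the engine methods elsewhere in this diff):

from abc import ABC, abstractmethod
from typing import Any, List, NewType

MetricData = NewType('MetricData', Any)


class AbstractExecutionEngine(ABC):
    @abstractmethod
    def submit_models(self, *models) -> None:
        raise NotImplementedError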
-from typing import *
-from ..graph import *
-from .interface import *
+from ..graph import Model, ModelStatus
+from .interface import MetricData, AbstractGraphListener


class DefaultListener(AbstractGraphListener):
......
-from abc import *
-from typing import *
+from abc import ABC

from .logical_plan import LogicalPlan


class AbstractOptimizer(ABC):
    def __init__(self) -> None:
        pass
......
-from nni.retiarii.operation import Operation
-from nni.retiarii.graph import Model, Graph, Edge, Node, Cell
-from typing import *
-import logging
-from nni.retiarii.operation import _IOPseudoOperation
import copy
+from typing import Dict, Tuple, List, Any
+
+from ...graph import Cell, Edge, Graph, Model, Node
+from ...operation import Operation, _IOPseudoOperation


class PhysicalDevice:
@@ -108,11 +107,11 @@ class OriginNode(AbstractLogicalNode):

class LogicalPlan:
-    def __init__(self, id=0) -> None:
+    def __init__(self, plan_id=0) -> None:
        self.lp_model = Model(_internal=True)
-        self.id = id
+        self.id = plan_id
        self.logical_graph = LogicalGraph(
-            self.lp_model, id, name=f'{id}', _internal=True)._register()
+            self.lp_model, self.id, name=f'{self.id}', _internal=True)._register()
        self.lp_model._root_graph_name = self.logical_graph.name
        self.models = []
@@ -148,7 +147,7 @@ class LogicalPlan:
        phy_model.training_config.kwargs['is_multi_model'] = True
        phy_model.training_config.kwargs['model_cls'] = phy_graph.name
        phy_model.training_config.kwargs['model_kwargs'] = []
-        #FIXME: allow user to specify
+        # FIXME: allow user to specify
        phy_model.training_config.module = 'nni.retiarii.trainer.PyTorchMultiModelTrainer'

        # merge sub-graphs
@@ -158,10 +157,9 @@ class LogicalPlan:
            model.graphs[graph_name]._fork_to(
                phy_model, name_prefix=f'M_{model.model_id}_')

        # When replace logical nodes, merge the training configs when
        # input/output nodes are replaced.
        training_config_slot = {}  # Model ID -> Slot ID
        input_slot_mapping = {}
        output_slot_mapping = {}
        # Replace all logical nodes to executable physical nodes
@@ -230,7 +228,7 @@ class LogicalPlan:
            to_node = copied_op[(edge.head, tail_placement)]
        else:
            to_operation = Operation.new(
-                'ToDevice', {"device":tail_placement.device})
+                'ToDevice', {"device": tail_placement.device})
            to_node = Node(phy_graph, phy_model._uid(),
                           edge.head.name+"_to_"+edge.tail.name, to_operation)._register()
            Edge((edge.head, edge.head_slot),
@@ -249,19 +247,18 @@ class LogicalPlan:
            if edge.head in input_nodes:
                edge.head_slot = input_slot_mapping[edge.head]
                edge.head = phy_graph.input_node

        # merge all output nodes into one with multiple slots
        output_nodes = []
        for node in phy_graph.hidden_nodes:
            if isinstance(node.operation, _IOPseudoOperation) and node.operation.type == '_outputs':
                output_nodes.append(node)

        for edge in phy_graph.edges:
            if edge.tail in output_nodes:
                edge.tail_slot = output_slot_mapping[edge.tail]
                edge.tail = phy_graph.output_node

        for node in input_nodes:
            node.remove()
        for node in output_nodes:
......
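The id=0 -> plan_id=0 rename in LogicalPlan.__init__ fixes pylint's redefined-builtin warning: a parameter named id shadows the built-in function for the whole method body, including the f'{id}' graph name above. A contrived illustration of the hazard:

def bad(id=0):
    # the built-in id() is shadowed, so this raises TypeError
    return id(object())


def good(plan_id=0):
    # the built-in is still reachable
    return id(object()) + plan_id


good()
try:
    bad()
except TypeError:
    pass  # 'int' object is not callable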
-from .base_optimizer import BaseOptimizer
-from .logical_plan import LogicalPlan
-
-
-class BatchingOptimizer(BaseOptimizer):
-    def __init__(self) -> None:
-        pass
-
-    def convert(self, logical_plan: LogicalPlan) -> None:
-        pass
-from .interface import AbstractOptimizer
-from .logical_plan import LogicalPlan, AbstractLogicalNode, LogicalGraph, OriginNode, PhysicalDevice
-from nni.retiarii import Graph, Node, Model
-from typing import *
-from nni.retiarii.operation import _IOPseudoOperation
+from typing import List, Dict, Tuple
+
+from ...graph import Graph, Model, Node
+from .interface import AbstractOptimizer
+from .logical_plan import (AbstractLogicalNode, LogicalGraph, LogicalPlan,
+                           OriginNode, PhysicalDevice)

_supported_training_modules = ['nni.retiarii.trainer.PyTorchImageClassificationTrainer']


class DedupInputNode(AbstractLogicalNode):
-    def __init__(self, logical_graph : LogicalGraph, id : int, \
-                 nodes_to_dedup : List[Node], _internal=False):
-        super().__init__(logical_graph, id, \
-                         "Dedup_"+nodes_to_dedup[0].name, \
-                         nodes_to_dedup[0].operation)
-        self.origin_nodes : List[OriginNode] = nodes_to_dedup.copy()
+    def __init__(self, logical_graph: LogicalGraph, node_id: int,
+                 nodes_to_dedup: List[Node], _internal=False):
+        super().__init__(logical_graph, node_id,
+                         "Dedup_"+nodes_to_dedup[0].name,
+                         nodes_to_dedup[0].operation)
+        self.origin_nodes: List[OriginNode] = nodes_to_dedup.copy()

    def assemble(self, multi_model_placement: Dict[Model, PhysicalDevice]) -> Tuple[Node, PhysicalDevice]:
        for node in self.origin_nodes:
            if node.original_graph.model in multi_model_placement:
-                new_node = Node(node.original_graph, node.id, \
-                                f'M_{node.original_graph.model.model_id}_{node.name}', \
-                                node.operation)
+                new_node = Node(node.original_graph, node.id,
+                                f'M_{node.original_graph.model.model_id}_{node.name}',
+                                node.operation)
                return new_node, multi_model_placement[node.original_graph.model]
        raise ValueError(f'DedupInputNode {self.name} does not contain nodes from multi_model')

    def _fork_to(self, graph: Graph):
        DedupInputNode(graph, self.id, self.origin_nodes)._register()

    def __repr__(self) -> str:
        return f'DedupNode(id={self.id}, name={self.name}, \
            len(nodes_to_dedup)={len(self.origin_nodes)}'
@@ -35,6 +36,7 @@ class DedupInputNode(AbstractLogicalNode):

class DedupInputOptimizer(AbstractOptimizer):
    def __init__(self) -> None:
        pass

    def _check_deduplicate_by_node(self, root_node, node_to_check):
        if root_node == node_to_check:
            return True
@@ -50,13 +52,12 @@ class DedupInputOptimizer(AbstractOptimizer):
            return False
        else:
            return False

    def convert(self, logical_plan: LogicalPlan) -> None:
        nodes_to_skip = set()
        while True:  # repeat until the logical_graph converges
            input_nodes = logical_plan.logical_graph.get_nodes_by_type("_inputs")
-            #_PseudoOperation(type_name="_inputs"))
+            # _PseudoOperation(type_name="_inputs"))
            root_node = None
            for node in input_nodes:
                if node in nodes_to_skip:
@@ -64,21 +65,21 @@ class DedupInputOptimizer(AbstractOptimizer):
                    root_node = node
                    break
            if root_node == None:
                break  # end of convert
            else:
                nodes_to_dedup = []
                for node in input_nodes:
                    if node in nodes_to_skip:
                        continue
                    if self._check_deduplicate_by_node(root_node, node):
                        nodes_to_dedup.append(node)
                assert(len(nodes_to_dedup) >= 1)
                if len(nodes_to_dedup) == 1:
                    assert(nodes_to_dedup[0] == root_node)
                    nodes_to_skip.add(root_node)
                else:
-                    dedup_node = DedupInputNode(logical_plan.logical_graph, \
-                                                logical_plan.lp_model._uid(), nodes_to_dedup)._register()
+                    dedup_node = DedupInputNode(logical_plan.logical_graph,
+                                                logical_plan.lp_model._uid(), nodes_to_dedup)._register()
                    for edge in logical_plan.logical_graph.edges:
                        if edge.head in nodes_to_dedup:
                            edge.head = dedup_node
......
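Another pattern in this file: trailing-backslash line continuations are dropped in favor of the implicit continuation that already exists inside an open parenthesis, which is what flake8 recommends (a backslash silently breaks if trailing whitespace follows it). For example:

def build_name(prefix, name, seq):
    return '{}__{}{}'.format(prefix, name, seq)


# backslash continuation: legal but fragile
full = build_name('cell', \
                  'conv', 0)

# preferred: the open parenthesis already continues the line
full = build_name('cell',
                  'conv', 0)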
-from .base_optimizer import BaseOptimizer
-from .logical_plan import LogicalPlan
-
-
-class WeightSharingOptimizer(BaseOptimizer):
-    def __init__(self) -> None:
-        pass
-
-    def convert(self, logical_plan: LogicalPlan) -> None:
-        pass
-import dataclasses
import logging
import time
from dataclasses import dataclass
from pathlib import Path
-from subprocess import Popen
from threading import Thread
-from typing import Any, List, Optional
+from typing import Any, Optional

-from ..experiment import Experiment, TrainingServiceConfig
-from ..experiment import launcher, rest
+from ..experiment import Experiment, TrainingServiceConfig, launcher, rest
from ..experiment.config.base import ConfigBase, PathLike
from ..experiment.config import util
-from ..experiment.pipe import Pipe
+from .graph import Model
from .utils import get_records
from .integration import RetiariiAdvisor
-from .converter.graph_gen import convert_to_graph
+from .converter import convert_to_graph
-from .mutator import LayerChoiceMutator, InputChoiceMutator
+from .mutator import Mutator, LayerChoiceMutator, InputChoiceMutator
+from .trainer.interface import BaseTrainer
+from .strategies.strategy import BaseStrategy

_logger = logging.getLogger(__name__)


@dataclass(init=False)
class RetiariiExeConfig(ConfigBase):
    experiment_name: Optional[str] = None
    search_space: Any = ''  # TODO: remove
    trial_command: str = 'python3 -m nni.retiarii.trial_entry'
    trial_code_directory: PathLike = '.'
    trial_concurrency: int
@@ -52,6 +56,7 @@ class RetiariiExeConfig(ConfigBase):
    def _validation_rules(self):
        return _validation_rules


_canonical_rules = {
    'trial_code_directory': util.canonical_path,
    'max_experiment_duration': lambda value: f'{util.parse_time(value)}s' if value is not None else None,
@@ -70,8 +75,8 @@ _validation_rules = {

class RetiariiExperiment(Experiment):
-    def __init__(self, base_model: 'nn.Module', trainer: 'BaseTrainer',
-                 applied_mutators: List['Mutator'], strategy: 'BaseStrategy'):
+    def __init__(self, base_model: Model, trainer: BaseTrainer,
+                 applied_mutators: Mutator, strategy: BaseStrategy):
        self.config: RetiariiExeConfig = None
        self.port: Optional[int] = None
@@ -139,7 +144,7 @@ class RetiariiExperiment(Experiment):
        debug
            Whether to start in debug mode.
        """
        # FIXME:
        if debug:
            logging.getLogger('nni').setLevel(logging.DEBUG)
@@ -189,4 +194,4 @@ class RetiariiExperiment(Experiment):
        if self.port is None:
            raise RuntimeError('Experiment is not running')
        resp = rest.get(self.port, '/check-status')
-        return resp['status']
\ No newline at end of file
+        return resp['status']
@@ -5,7 +5,6 @@ Model representation.
import copy
from enum import Enum
import json
-from collections import defaultdict
from typing import (Any, Dict, List, Optional, Tuple, Union, overload)

from .operation import Cell, Operation, _IOPseudoOperation
@@ -329,12 +328,12 @@ class Graph:
        Returns nodes whose operation is specified typed.
        """
        return [node for node in self.hidden_nodes if node.operation.type == operation_type]

-    def get_node_by_id(self, id: int) -> Optional['Node']:
+    def get_node_by_id(self, node_id: int) -> Optional['Node']:
        """
        Returns the node which has specified name; or returns `None` if no node has this name.
        """
-        found = [node for node in self.nodes if node.id == id]
+        found = [node for node in self.nodes if node.id == node_id]
        return found[0] if found else None

    def get_nodes_by_label(self, label: str) -> List['Node']:
@@ -365,7 +364,8 @@ class Graph:
                    curr_nodes.append(successor)

        for key in node_to_fanin:
-            assert node_to_fanin[key] == 0, '{}, fanin: {}, predecessor: {}, edges: {}, fanin: {}, keys: {}'.format(key,
+            assert node_to_fanin[key] == 0, '{}, fanin: {}, predecessor: {}, edges: {}, fanin: {}, keys: {}'.format(
+                key,
                node_to_fanin[key],
                key.predecessors[0],
                self.edges,
@@ -587,6 +587,7 @@ class Node:
        ret['label'] = self.label
        return ret


class Edge:
    """
    A tensor, or "data flow", between two nodes.
......
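The assertion reformatted above lives in Graph.topo_sort(), which, judging by node_to_fanin and the successor queue, is a Kahn-style topological sort: emit nodes with fan-in zero, decrement their successors' counters, and treat any nonzero counter left at the end as a cycle. A generic sketch of that scheme, not the class's actual code:

from collections import deque


def topo_sort(nodes, edges):
    # edges are (head, tail) pairs; fan-in = number of incoming edges
    fanin = {n: 0 for n in nodes}
    for _, tail in edges:
        fanin[tail] += 1
    queue = deque(n for n in nodes if fanin[n] == 0)
    order = []
    while queue:
        node = queue.popleft()
        order.append(node)
        for head, tail in edges:
            if head == node:
                fanin[tail] -= 1
                if fanin[tail] == 0:
                    queue.append(tail)
    assert all(v == 0 for v in fanin.values()), 'graph has a cycle'
    return order


print(topo_sort(['a', 'b', 'c'], [('a', 'b'), ('b', 'c'), ('a', 'c')]))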
import logging
-import threading
-from typing import *
+from typing import Any, Callable

import json_tricks

import nni
from nni.runtime.msg_dispatcher_base import MsgDispatcherBase
-from nni.runtime.protocol import send, CommandType
+from nni.runtime.protocol import CommandType, send
from nni.utils import MetricType

-from . import utils
from .graph import MetricData

_logger = logging.getLogger('nni.msg_dispatcher_base')
@@ -44,6 +41,7 @@ class RetiariiAdvisor(MsgDispatcherBase):
    final_metric_callback
    """

    def __init__(self):
        super(RetiariiAdvisor, self).__init__()
        register_advisor(self)  # register the current advisor as the "global only" advisor
@@ -88,28 +86,28 @@ class RetiariiAdvisor(MsgDispatcherBase):
            'parameters': parameters,
            'parameter_source': 'algorithm'
        }
-        _logger.info('New trial sent: {}'.format(new_trial))
+        _logger.info('New trial sent: %s', new_trial)
        send(CommandType.NewTrialJob, json_tricks.dumps(new_trial))
        if self.send_trial_callback is not None:
            self.send_trial_callback(parameters)  # pylint: disable=not-callable
        return self.parameters_count

    def handle_request_trial_jobs(self, num_trials):
-        _logger.info('Request trial jobs: {}'.format(num_trials))
+        _logger.info('Request trial jobs: %s', num_trials)
        if self.request_trial_jobs_callback is not None:
            self.request_trial_jobs_callback(num_trials)  # pylint: disable=not-callable

    def handle_update_search_space(self, data):
-        _logger.info('Received search space: {}'.format(data))
+        _logger.info('Received search space: %s', data)
        self.search_space = data

    def handle_trial_end(self, data):
-        _logger.info('Trial end: {}'.format(data))  # do nothing
+        _logger.info('Trial end: %s', data)
        self.trial_end_callback(json_tricks.loads(data['hyper_params'])['parameter_id'],  # pylint: disable=not-callable
                                data['event'] == 'SUCCEEDED')

    def handle_report_metric_data(self, data):
-        _logger.info('Metric reported: {}'.format(data))
+        _logger.info('Metric reported: %s', data)
        if data['type'] == MetricType.REQUEST_PARAMETER:
            raise ValueError('Request parameter not supported')
        elif data['type'] == MetricType.PERIODICAL:
......
@@ -13,6 +13,7 @@ class Sampler:
    """
    Handles `Mutator.choice()` calls.
    """

    def choice(self, candidates: List[Choice], mutator: 'Mutator', model: Model, index: int) -> Choice:
        raise NotImplementedError()
@@ -35,6 +36,7 @@ class Mutator:
    For certain mutator subclasses, strategy or sampler can use `Mutator.dry_run()` to predict choice candidates.
    # Method names are open for discussion.
    """

    def __init__(self, sampler: Optional[Sampler] = None):
        self.sampler: Optional[Sampler] = sampler
        self._cur_model: Optional[Model] = None
@@ -77,7 +79,6 @@ class Mutator:
        self.sampler = sampler_backup
        return recorder.recorded_candidates, new_model

    def mutate(self, model: Model) -> None:
        """
        Abstract method to be implemented by subclass.
@@ -105,6 +106,7 @@ class _RecorderSampler(Sampler):

# the following is for inline mutation
class LayerChoiceMutator(Mutator):
    def __init__(self, node_name: str, candidates: List):
        super().__init__()
@@ -118,6 +120,7 @@ class LayerChoiceMutator(Mutator):
        chosen_cand = self.candidates[chosen_index]
        target.update_operation(chosen_cand['type'], chosen_cand['parameters'])


class InputChoiceMutator(Mutator):
    def __init__(self, node_name: str, n_chosen: int):
        super().__init__()
@@ -129,4 +132,4 @@ class InputChoiceMutator(Mutator):
        candidates = [i for i in range(self.n_chosen)]
        chosen = self.choice(candidates)
        target.update_operation('__torch__.nni.retiarii.nn.pytorch.nn.ChosenInputs',
                                {'chosen': chosen})
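Sampler.choice() is the hook a search strategy implements to drive mutations: whenever Mutator.choice() is called during mutate(), the mutator delegates the pick to its sampler. A toy sampler under that contract, with the signature simplified (the real one also receives the mutator, model, and index):

import random


class Sampler:
    def choice(self, candidates):
        raise NotImplementedError()


class RandomSampler(Sampler):
    def choice(self, candidates):
        # pick uniformly among the mutation candidates
        return random.choice(candidates)


sampler = RandomSampler()
assert sampler.choice([0, 1, 2]) in (0, 1, 2)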