Unverified commit 59cd3982, authored by Yuge Zhang, committed by GitHub

[Retiarii] Coding style improvements for pylint and flake8 (#3190)

parent 593a275c
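
The hunks below are mechanical pylint/flake8 fixes of a handful of recurring kinds: explicit imports instead of wildcards, lazy %-interpolation in logging calls, PEP 8 spacing, parameters renamed so they no longer shadow builtins, and implicit line continuations instead of backslashes. As a rough before/after sketch (illustrative code, not taken from the patch):

    import logging
    logger = logging.getLogger(__name__)

    # before: typical warnings
    from typing import *                     # flake8 F403: wildcard import
    def run(model, placement = None):        # flake8 E251: spaces around keyword '='
        logger.info('run {}'.format(model))  # pylint W1202: format() inside a logging call

    # after
    from typing import Optional
    def run(model, placement=None):
        logger.info('run %s', model)         # interpolated only if the record is emitted
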
import logging
-from typing import *
+from typing import List
from ..graph import IllegalGraphError, Edge, Graph, Node, Model
from ..operation import Operation, Cell
_logger = logging.getLogger(__name__)
-def model_to_pytorch_script(model: Model, placement = None) -> str:
+def model_to_pytorch_script(model: Model, placement=None) -> str:
graphs = []
total_pkgs = set()
for name, cell in model.graphs.items():
-import_pkgs, graph_code = graph_to_pytorch_model(name, cell, placement = placement)
+import_pkgs, graph_code = graph_to_pytorch_model(name, cell, placement=placement)
graphs.append(graph_code)
total_pkgs.update(import_pkgs)
pkgs_code = '\n'.join(['import {}'.format(pkg) for pkg in total_pkgs])
return _PyTorchScriptTemplate.format(pkgs_code, '\n\n'.join(graphs)).strip()
def _sorted_incoming_edges(node: Node) -> List[Edge]:
edges = [edge for edge in node.graph.edges if edge.tail is node]
-_logger.info('sorted_incoming_edges: {}'.format(edges))
+_logger.info('sorted_incoming_edges: %s', str(edges))
if not edges:
return []
-_logger.info(f'all tail_slots are None: {[edge.tail_slot for edge in edges]}')
+_logger.info('all tail_slots are None: %s', str([edge.tail_slot for edge in edges]))
if all(edge.tail_slot is None for edge in edges):
return edges
if all(isinstance(edge.tail_slot, int) for edge in edges):
@@ -32,6 +31,7 @@ def _sorted_incoming_edges(node: Node) -> List[Edge]:
return edges
raise IllegalGraphError(node.graph, 'Node {} has bad inputs'.format(node.name))
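
The logging rewrites here and throughout the patch hand the arguments to the logging framework instead of pre-formatting the message. A minimal sketch of the difference, using only the stdlib:

    import logging
    _logger = logging.getLogger(__name__)
    edges = [0, 1, None]

    # eager: the string is built even when INFO is disabled (pylint W1202/W1203)
    _logger.info('sorted_incoming_edges: {}'.format(edges))

    # lazy: %-interpolation runs only if the record is actually emitted
    _logger.info('sorted_incoming_edges: %s', edges)
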
def _format_inputs(node: Node) -> List[str]:
edges = _sorted_incoming_edges(node)
inputs = []
@@ -53,6 +53,7 @@ def _format_inputs(node: Node) -> List[str]:
inputs.append('{}[{}]'.format(edge.head.name, edge.head_slot))
return inputs
def _remove_prefix(names, graph_name):
"""
variables name (full name space) is too long,
@@ -69,14 +70,14 @@ def _remove_prefix(names, graph_name):
else:
return names[len(graph_name):] if names.startswith(graph_name) else names
-def graph_to_pytorch_model(graph_name: str, graph: Graph, placement = None) -> str:
+def graph_to_pytorch_model(graph_name: str, graph: Graph, placement=None) -> str:
nodes = graph.topo_sort()
# handle module node and function node differently
# only need to generate code for module here
import_pkgs = set()
node_codes = []
placement_codes = []
for node in nodes:
if node.operation:
pkg_name = node.operation.get_import_pkg()
......
from .graph_gen import convert_to_graph
from .visualize import visualize_model
\ No newline at end of file
import json_tricks
import logging
import re
+import torch
-from ..graph import Graph, Node, Edge, Model
-from ..operation import Cell, Operation
-from ..nn.pytorch import Placeholder, LayerChoice, InputChoice
-import torch
-from .op_types import MODULE_EXCEPT_LIST, OpTypeName, BasicOpsPT
-from .utils import build_full_name, _convert_name
+from ..graph import Graph, Model, Node
+from ..nn.pytorch import InputChoice, LayerChoice, Placeholder
+from ..operation import Cell
+from .op_types import MODULE_EXCEPT_LIST, BasicOpsPT, OpTypeName
+from .utils import _convert_name, build_full_name
_logger = logging.getLogger(__name__)
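
The reordered import block above follows the convention that pylint (C0411/C0412) and isort enforce: standard library first, then third-party, then local imports, alphabetized within each group. A hedged sketch of the target layout (the grouping is the convention, not the exact file contents):

    import logging      # standard library
    import re

    import torch        # third-party

    from .utils import build_full_name   # local package last
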
@@ -16,6 +15,7 @@ global_seq = 0
global_graph_id = 0
modules_arg = None
def _add_edge(ir_graph, node, graph_inputs, node_index, new_node, output_remap, ignore_first=False):
"""
Parameters
@@ -76,6 +76,7 @@ def _add_edge(ir_graph, node, graph_inputs, node_index, new_node, output_remap,
new_node_input_idx += 1
def create_prim_constant_node(ir_graph, node, module_name):
global global_seq
attrs = {}
@@ -86,14 +87,17 @@ def create_prim_constant_node(ir_graph, node, module_name):
node.kind(), attrs)
return new_node
def handle_prim_attr_node(node):
assert node.hasAttribute('name')
attrs = {'name': node.s('name'), 'input': node.inputsAt(0).debugName()}
return node.kind(), attrs
def _remove_mangle(module_type_str):
return re.sub('\\.___torch_mangle_\\d+', '', module_type_str)
def remove_unconnected_nodes(ir_graph, targeted_type=None):
"""
Parameters
@@ -122,6 +126,7 @@ def remove_unconnected_nodes(ir_graph, targeted_type=None):
for hidden_node in to_removes:
hidden_node.remove()
def handle_graph_nodes(script_module, sm_graph, module, module_name, ir_model, ir_graph):
"""
Convert torch script node to our node ir, and build our graph ir
@@ -248,7 +253,8 @@ def handle_graph_nodes(script_module, sm_graph, module, module_name, ir_model, ir_graph):
# therefore, we do this check for a module. example below:
# %25 : __torch__.xxx = prim::GetAttr[name="input_switch"](%self)
# %27 : Tensor = prim::CallMethod[name="forward"](%25, %out.1)
-assert submodule_name in script_module._modules, "submodule_name: {} not in script_module {}".format(submodule_name, script_module._modules.keys())
+assert submodule_name in script_module._modules, "submodule_name: {} not in script_module {}".format(
+submodule_name, script_module._modules.keys())
submodule_full_name = build_full_name(module_name, submodule_name)
submodule_obj = getattr(module, submodule_name)
@@ -350,6 +356,7 @@ def handle_graph_nodes(script_module, sm_graph, module, module_name, ir_model, ir_graph):
return node_index
def merge_aten_slices(ir_graph):
"""
if there is aten::slice node, merge the consecutive ones together.
@@ -408,13 +415,14 @@ def refine_graph(ir_graph):
remove_unconnected_nodes(ir_graph, targeted_type='prim::GetAttr')
merge_aten_slices(ir_graph)
def _handle_layerchoice(module):
global modules_arg
m_attrs = {}
candidates = module.candidate_ops
choices = []
-for i, cand in enumerate(candidates):
+for cand in candidates:
assert id(cand) in modules_arg, 'id not exist: {}'.format(id(cand))
assert isinstance(modules_arg[id(cand)], dict)
cand_type = '__torch__.' + cand.__class__.__module__ + '.' + cand.__class__.__name__
@@ -423,6 +431,7 @@ def _handle_layerchoice(module):
m_attrs['label'] = module.label
return m_attrs
def _handle_inputchoice(module):
m_attrs = {}
m_attrs['n_chosen'] = module.n_chosen
@@ -430,6 +439,7 @@ def _handle_inputchoice(module):
m_attrs['label'] = module.label
return m_attrs
def convert_module(script_module, module, module_name, ir_model):
"""
Convert a module to its graph ir (i.e., Graph) along with its input arguments
@@ -507,6 +517,7 @@ def convert_module(script_module, module, module_name, ir_model):
# should not be parsed further.
return ir_graph, {}
def convert_to_graph(script_module, module, recorded_modules_arg):
"""
Convert module to our graph ir, i.e., build a ```Model``` type
......
@@ -16,6 +16,7 @@ class OpTypeName(str, Enum):
Placeholder = 'Placeholder'
MergedSlice = 'MergedSlice'
# deal with aten op
BasicOpsPT = {
'aten::mean': 'Mean',
......
@@ -6,6 +6,7 @@ def build_full_name(prefix, name, seq=None):
else:
return '{}__{}{}'.format(prefix, name, str(seq))
def _convert_name(name: str) -> str:
"""
Convert the names using separator '.' to valid variable name in code
......
import graphviz
def convert_to_visualize(graph_ir, vgraph):
for name, graph in graph_ir.items():
if name == '_training_config':
@@ -33,6 +34,7 @@ def convert_to_visualize(graph_ir, vgraph):
dst = cell_node[dst][0]
subgraph.edge(src, dst)
def visualize_model(graph_ir):
vgraph = graphviz.Digraph('G', filename='vgraph', format='jpg')
convert_to_visualize(graph_ir, vgraph)
......
import time
import os
import importlib.util
-from typing import *
+from typing import List
from ..graph import Model, ModelStatus
from .base import BaseExecutionEngine
from .cgo_engine import CGOExecutionEngine
-from .interface import *
+from .interface import AbstractExecutionEngine, WorkerInfo
from .listener import DefaultListener
_execution_engine = None
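
Replacing `from typing import *` and `from .interface import *` with explicit names is the other bulk change across these files. Wildcard imports trip flake8 F403/F405 because neither the reader nor the linter can tell where a name came from or whether it is unused; a short sketch:

    # F403: which names does this actually bring in?
    from typing import *

    # explicit: unused names now surface as F401 and can be deleted
    from typing import List
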
......
import logging
-from typing import *
+from typing import Dict, Any, List
from .interface import AbstractExecutionEngine, AbstractGraphListener, WorkerInfo
from .. import codegen, utils
@@ -61,16 +61,16 @@ class BaseExecutionEngine(AbstractExecutionEngine):
def _send_trial_callback(self, paramater: dict) -> None:
for listener in self._listeners:
-_logger.warning('resources: {}'.format(listener.resources))
+_logger.warning('resources: %s', listener.resources)
if not listener.has_available_resource():
_logger.warning('There is no available resource, but trial is submitted.')
listener.on_resource_used(1)
-_logger.warning('on_resource_used: {}'.format(listener.resources))
+_logger.warning('on_resource_used: %s', listener.resources)
def _request_trial_jobs_callback(self, num_trials: int) -> None:
for listener in self._listeners:
listener.on_resource_available(1 * num_trials)
-_logger.warning('on_resource_available: {}'.format(listener.resources))
+_logger.warning('on_resource_available: %s', listener.resources)
def _trial_end_callback(self, trial_id: int, success: bool) -> None:
model = self._running_models[trial_id]
......
import logging
import json
-from typing import *
+from typing import List, Dict, Tuple
from .interface import AbstractExecutionEngine, AbstractGraphListener, WorkerInfo
from .. import codegen, utils
@@ -12,8 +11,10 @@ from .logical_optimizer.opt_dedup_input import DedupInputOptimizer
from .base import BaseGraphData
_logger = logging.getLogger(__name__)
class CGOExecutionEngine(AbstractExecutionEngine):
-def __init__(self, n_model_per_graph = 4) -> None:
+def __init__(self, n_model_per_graph=4) -> None:
self._listeners: List[AbstractGraphListener] = []
self._running_models: Dict[int, Model] = dict()
self.logical_plan_counter = 0
@@ -30,12 +31,11 @@ class CGOExecutionEngine(AbstractExecutionEngine):
advisor.intermediate_metric_callback = self._intermediate_metric_callback
advisor.final_metric_callback = self._final_metric_callback
def add_optimizer(self, opt):
self._optimizers.append(opt)
def submit_models(self, *models: List[Model]) -> None:
-_logger.info(f'{len(models)} Models are submitted')
+_logger.info('%d models are submitted', len(models))
logical = self._build_logical(models)
for opt in self._optimizers:
@@ -55,13 +55,13 @@ class CGOExecutionEngine(AbstractExecutionEngine):
# model.config['trainer_module'], model.config['trainer_kwargs'])
# self._running_models[send_trial(data.dump())] = model
-def _assemble(self, logical_plan : LogicalPlan) -> List[Tuple[Model, PhysicalDevice]]:
+def _assemble(self, logical_plan: LogicalPlan) -> List[Tuple[Model, PhysicalDevice]]:
# unique_models = set()
# for node in logical_plan.graph.nodes:
# if node.graph.model not in unique_models:
# unique_models.add(node.graph.model)
# return [m for m in unique_models]
-grouped_models : List[Dict[Model, PhysicalDevice]] = AssemblePolicy().group(logical_plan)
+grouped_models: List[Dict[Model, PhysicalDevice]] = AssemblePolicy().group(logical_plan)
phy_models_and_placements = []
for multi_model in grouped_models:
model, model_placement = logical_plan.assemble(multi_model)
@@ -69,7 +69,7 @@ class CGOExecutionEngine(AbstractExecutionEngine):
return phy_models_and_placements
def _build_logical(self, models: List[Model]) -> LogicalPlan:
-logical_plan = LogicalPlan(id = self.logical_plan_counter)
+logical_plan = LogicalPlan(plan_id=self.logical_plan_counter)
for model in models:
logical_plan.add_model(model)
self.logical_plan_counter += 1
@@ -108,7 +108,7 @@ class CGOExecutionEngine(AbstractExecutionEngine):
for model_id in merged_metrics:
int_model_id = int(model_id)
self._original_models[int_model_id].intermediate_metrics.append(merged_metrics[model_id])
-#model.intermediate_metrics.append(metrics)
+# model.intermediate_metrics.append(metrics)
for listener in self._listeners:
listener.on_intermediate_metric(self._original_models[int_model_id], merged_metrics[model_id])
@@ -117,11 +117,10 @@ class CGOExecutionEngine(AbstractExecutionEngine):
for model_id in merged_metrics:
int_model_id = int(model_id)
self._original_models[int_model_id].intermediate_metrics.append(merged_metrics[model_id])
-#model.intermediate_metrics.append(metrics)
+# model.intermediate_metrics.append(metrics)
for listener in self._listeners:
listener.on_metric(self._original_models[int_model_id], merged_metrics[model_id])
def query_available_resource(self) -> List[WorkerInfo]:
raise NotImplementedError # move the method from listener to here?
@@ -141,6 +140,7 @@ class CGOExecutionEngine(AbstractExecutionEngine):
trainer_instance = trainer_cls(model_cls(), graph_data.training_kwargs)
trainer_instance.fit()
class AssemblePolicy:
@staticmethod
def group(logical_plan):
@@ -148,4 +148,3 @@ class AssemblePolicy:
for idx, m in enumerate(logical_plan.models):
group_model[m] = PhysicalDevice('server', f'cuda:{idx}')
return [group_model]
\ No newline at end of file
-from abc import *
-from typing import *
+from abc import ABC, abstractmethod, abstractclassmethod
+from typing import Any, NewType, List
from ..graph import Model, MetricData
......
-from typing import *
-from ..graph import *
-from .interface import *
+from ..graph import Model, ModelStatus
+from .interface import MetricData, AbstractGraphListener
class DefaultListener(AbstractGraphListener):
......
-from abc import *
-from typing import *
+from abc import ABC
from .logical_plan import LogicalPlan
class AbstractOptimizer(ABC):
def __init__(self) -> None:
pass
......
-from nni.retiarii.operation import Operation
-from nni.retiarii.graph import Model, Graph, Edge, Node, Cell
-from typing import *
-import logging
-from nni.retiarii.operation import _IOPseudoOperation
import copy
+from typing import Dict, Tuple, List, Any
+from ...graph import Cell, Edge, Graph, Model, Node
+from ...operation import Operation, _IOPseudoOperation
class PhysicalDevice:
@@ -108,11 +107,11 @@ class OriginNode(AbstractLogicalNode):
class LogicalPlan:
-def __init__(self, id=0) -> None:
+def __init__(self, plan_id=0) -> None:
self.lp_model = Model(_internal=True)
-self.id = id
+self.id = plan_id
self.logical_graph = LogicalGraph(
-self.lp_model, id, name=f'{id}', _internal=True)._register()
+self.lp_model, self.id, name=f'{self.id}', _internal=True)._register()
self.lp_model._root_graph_name = self.logical_graph.name
self.models = []
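
The `id` -> `plan_id` rename in this hunk (and `node_id` elsewhere) clears pylint W0622, redefined-builtin: a parameter named `id` shadows the built-in `id()` for the whole function body. A small hypothetical illustration of the hazard:

    def lookup(id):            # W0622: 'id' shadows the builtin
        return id(object())    # calls the argument, not the builtin -> TypeError

    def lookup_fixed(plan_id):
        return id(object())    # the builtin id() is still reachable
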
@@ -148,7 +147,7 @@ class LogicalPlan:
phy_model.training_config.kwargs['is_multi_model'] = True
phy_model.training_config.kwargs['model_cls'] = phy_graph.name
phy_model.training_config.kwargs['model_kwargs'] = []
-#FIXME: allow user to specify
+# FIXME: allow user to specify
phy_model.training_config.module = 'nni.retiarii.trainer.PyTorchMultiModelTrainer'
# merge sub-graphs
@@ -158,7 +157,6 @@ class LogicalPlan:
model.graphs[graph_name]._fork_to(
phy_model, name_prefix=f'M_{model.model_id}_')
# When replace logical nodes, merge the training configs when
# input/output nodes are replaced.
training_config_slot = {} # Model ID -> Slot ID
@@ -230,7 +228,7 @@ class LogicalPlan:
to_node = copied_op[(edge.head, tail_placement)]
else:
to_operation = Operation.new(
-'ToDevice', {"device":tail_placement.device})
+'ToDevice', {"device": tail_placement.device})
to_node = Node(phy_graph, phy_model._uid(),
edge.head.name+"_to_"+edge.tail.name, to_operation)._register()
Edge((edge.head, edge.head_slot),
@@ -250,7 +248,6 @@ class LogicalPlan:
edge.head_slot = input_slot_mapping[edge.head]
edge.head = phy_graph.input_node
# merge all output nodes into one with multiple slots
output_nodes = []
for node in phy_graph.hidden_nodes:
......
from .base_optimizer import BaseOptimizer
from .logical_plan import LogicalPlan
class BatchingOptimizer(BaseOptimizer):
def __init__(self) -> None:
pass
def convert(self, logical_plan: LogicalPlan) -> None:
pass
-from .interface import AbstractOptimizer
-from .logical_plan import LogicalPlan, AbstractLogicalNode, LogicalGraph, OriginNode, PhysicalDevice
-from nni.retiarii import Graph, Node, Model
-from typing import *
-from nni.retiarii.operation import _IOPseudoOperation
+from typing import List, Dict, Tuple
+from ...graph import Graph, Model, Node
+from .interface import AbstractOptimizer
+from .logical_plan import (AbstractLogicalNode, LogicalGraph, LogicalPlan,
+OriginNode, PhysicalDevice)
_supported_training_modules = ['nni.retiarii.trainer.PyTorchImageClassificationTrainer']
class DedupInputNode(AbstractLogicalNode):
-def __init__(self, logical_graph : LogicalGraph, id : int, \
-nodes_to_dedup : List[Node], _internal=False):
-super().__init__(logical_graph, id, \
-"Dedup_"+nodes_to_dedup[0].name, \
+def __init__(self, logical_graph: LogicalGraph, node_id: int,
+nodes_to_dedup: List[Node], _internal=False):
+super().__init__(logical_graph, node_id,
+"Dedup_"+nodes_to_dedup[0].name,
nodes_to_dedup[0].operation)
-self.origin_nodes : List[OriginNode] = nodes_to_dedup.copy()
+self.origin_nodes: List[OriginNode] = nodes_to_dedup.copy()
def assemble(self, multi_model_placement: Dict[Model, PhysicalDevice]) -> Tuple[Node, PhysicalDevice]:
for node in self.origin_nodes:
if node.original_graph.model in multi_model_placement:
-new_node = Node(node.original_graph, node.id, \
-f'M_{node.original_graph.model.model_id}_{node.name}', \
+new_node = Node(node.original_graph, node.id,
+f'M_{node.original_graph.model.model_id}_{node.name}',
node.operation)
return new_node, multi_model_placement[node.original_graph.model]
raise ValueError(f'DedupInputNode {self.name} does not contain nodes from multi_model')
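
The dropped trailing backslashes in this class rely on Python's implicit line joining inside open parentheses and brackets, which PEP 8 prefers over explicit `\` continuations (a stray space after a backslash is a SyntaxError). Illustrative sketch:

    a, b = 1, 2

    # explicit continuation: fragile
    total = a + \
            b

    # implicit continuation inside parentheses: preferred
    total = (a +
             b)
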
@@ -26,7 +28,6 @@ class DedupInputNode(AbstractLogicalNode):
def _fork_to(self, graph: Graph):
DedupInputNode(graph, self.id, self.origin_nodes)._register()
def __repr__(self) -> str:
return f'DedupNode(id={self.id}, name={self.name}, \
len(nodes_to_dedup)={len(self.origin_nodes)}'
@@ -35,6 +36,7 @@ class DedupInputNode(AbstractLogicalNode):
class DedupInputOptimizer(AbstractOptimizer):
def __init__(self) -> None:
pass
def _check_deduplicate_by_node(self, root_node, node_to_check):
if root_node == node_to_check:
return True
@@ -51,12 +53,11 @@ class DedupInputOptimizer(AbstractOptimizer):
else:
return False
def convert(self, logical_plan: LogicalPlan) -> None:
nodes_to_skip = set()
while True: # repeat until the logical_graph converges
input_nodes = logical_plan.logical_graph.get_nodes_by_type("_inputs")
-#_PseudoOperation(type_name="_inputs"))
+# _PseudoOperation(type_name="_inputs"))
root_node = None
for node in input_nodes:
if node in nodes_to_skip:
@@ -77,7 +78,7 @@ class DedupInputOptimizer(AbstractOptimizer):
assert(nodes_to_dedup[0] == root_node)
nodes_to_skip.add(root_node)
else:
-dedup_node = DedupInputNode(logical_plan.logical_graph, \
+dedup_node = DedupInputNode(logical_plan.logical_graph,
logical_plan.lp_model._uid(), nodes_to_dedup)._register()
for edge in logical_plan.logical_graph.edges:
if edge.head in nodes_to_dedup:
......
from .base_optimizer import BaseOptimizer
from .logical_plan import LogicalPlan
class WeightSharingOptimizer(BaseOptimizer):
def __init__(self) -> None:
pass
def convert(self, logical_plan: LogicalPlan) -> None:
pass
import dataclasses
import logging
import time
from dataclasses import dataclass
from pathlib import Path
from subprocess import Popen
from threading import Thread
-from typing import Any, List, Optional
+from typing import Any, Optional
-from ..experiment import Experiment, TrainingServiceConfig
-from ..experiment import launcher, rest
+from ..experiment import Experiment, TrainingServiceConfig, launcher, rest
from ..experiment.config.base import ConfigBase, PathLike
from ..experiment.config import util
from ..experiment.pipe import Pipe
from .graph import Model
from .utils import get_records
from .integration import RetiariiAdvisor
-from .converter.graph_gen import convert_to_graph
-from .mutator import LayerChoiceMutator, InputChoiceMutator
+from .converter import convert_to_graph
+from .mutator import Mutator, LayerChoiceMutator, InputChoiceMutator
from .trainer.interface import BaseTrainer
from .strategies.strategy import BaseStrategy
_logger = logging.getLogger(__name__)
@dataclass(init=False)
class RetiariiExeConfig(ConfigBase):
experiment_name: Optional[str] = None
@@ -52,6 +56,7 @@ class RetiariiExeConfig(ConfigBase):
def _validation_rules(self):
return _validation_rules
_canonical_rules = {
'trial_code_directory': util.canonical_path,
'max_experiment_duration': lambda value: f'{util.parse_time(value)}s' if value is not None else None,
@@ -70,8 +75,8 @@ _validation_rules = {
class RetiariiExperiment(Experiment):
-def __init__(self, base_model: 'nn.Module', trainer: 'BaseTrainer',
-applied_mutators: List['Mutator'], strategy: 'BaseStrategy'):
+def __init__(self, base_model: Model, trainer: BaseTrainer,
+applied_mutators: Mutator, strategy: BaseStrategy):
self.config: RetiariiExeConfig = None
self.port: Optional[int] = None
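
This constructor change also swaps the quoted forward references ('nn.Module', 'BaseTrainer') for the imported classes themselves; since those names are already imported at the top of the module, string annotations only hide typos from tooling. A hypothetical sketch:

    class BaseTrainer: ...

    def fit(trainer: 'BaseTrainr'): ...   # misspelled string goes unnoticed
    def fit2(trainer: BaseTrainer): ...   # a misspelled name fails at import time
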
......
@@ -5,7 +5,6 @@ Model representation.
import copy
from enum import Enum
import json
from collections import defaultdict
from typing import (Any, Dict, List, Optional, Tuple, Union, overload)
from .operation import Cell, Operation, _IOPseudoOperation
@@ -330,11 +329,11 @@ class Graph:
"""
return [node for node in self.hidden_nodes if node.operation.type == operation_type]
-def get_node_by_id(self, id: int) -> Optional['Node']:
+def get_node_by_id(self, node_id: int) -> Optional['Node']:
"""
Returns the node which has specified name; or returns `None` if no node has this name.
"""
-found = [node for node in self.nodes if node.id == id]
+found = [node for node in self.nodes if node.id == node_id]
return found[0] if found else None
def get_nodes_by_label(self, label: str) -> List['Node']:
@@ -365,7 +364,8 @@ class Graph:
curr_nodes.append(successor)
for key in node_to_fanin:
-assert node_to_fanin[key] == 0, '{}, fanin: {}, predecessor: {}, edges: {}, fanin: {}, keys: {}'.format(key,
+assert node_to_fanin[key] == 0, '{}, fanin: {}, predecessor: {}, edges: {}, fanin: {}, keys: {}'.format(
+key,
node_to_fanin[key],
key.predecessors[0],
self.edges,
@@ -587,6 +587,7 @@ class Node:
ret['label'] = self.label
return ret
class Edge:
"""
A tensor, or "data flow", between two nodes.
......
import logging
import threading
-from typing import *
+from typing import Any, Callable
import json_tricks
import nni
from nni.runtime.msg_dispatcher_base import MsgDispatcherBase
-from nni.runtime.protocol import send, CommandType
+from nni.runtime.protocol import CommandType, send
from nni.utils import MetricType
from . import utils
from .graph import MetricData
_logger = logging.getLogger('nni.msg_dispatcher_base')
@@ -44,6 +41,7 @@ class RetiariiAdvisor(MsgDispatcherBase):
final_metric_callback
"""
def __init__(self):
super(RetiariiAdvisor, self).__init__()
register_advisor(self) # register the current advisor as the "global only" advisor
@@ -88,28 +86,28 @@ class RetiariiAdvisor(MsgDispatcherBase):
'parameters': parameters,
'parameter_source': 'algorithm'
}
-_logger.info('New trial sent: {}'.format(new_trial))
+_logger.info('New trial sent: %s', new_trial)
send(CommandType.NewTrialJob, json_tricks.dumps(new_trial))
if self.send_trial_callback is not None:
self.send_trial_callback(parameters) # pylint: disable=not-callable
return self.parameters_count
def handle_request_trial_jobs(self, num_trials):
-_logger.info('Request trial jobs: {}'.format(num_trials))
+_logger.info('Request trial jobs: %s', num_trials)
if self.request_trial_jobs_callback is not None:
self.request_trial_jobs_callback(num_trials) # pylint: disable=not-callable
def handle_update_search_space(self, data):
-_logger.info('Received search space: {}'.format(data))
+_logger.info('Received search space: %s', data)
self.search_space = data
def handle_trial_end(self, data):
-_logger.info('Trial end: {}'.format(data)) # do nothing
+_logger.info('Trial end: %s', data)
self.trial_end_callback(json_tricks.loads(data['hyper_params'])['parameter_id'], # pylint: disable=not-callable
data['event'] == 'SUCCEEDED')
def handle_report_metric_data(self, data):
-_logger.info('Metric reported: {}'.format(data))
+_logger.info('Metric reported: %s', data)
if data['type'] == MetricType.REQUEST_PARAMETER:
raise ValueError('Request parameter not supported')
elif data['type'] == MetricType.PERIODICAL:
......
@@ -13,6 +13,7 @@ class Sampler:
"""
Handles `Mutator.choice()` calls.
"""
def choice(self, candidates: List[Choice], mutator: 'Mutator', model: Model, index: int) -> Choice:
raise NotImplementedError()
@@ -35,6 +36,7 @@ class Mutator:
For certain mutator subclasses, strategy or sampler can use `Mutator.dry_run()` to predict choice candidates.
# Method names are open for discussion.
"""
def __init__(self, sampler: Optional[Sampler] = None):
self.sampler: Optional[Sampler] = sampler
self._cur_model: Optional[Model] = None
@@ -77,7 +79,6 @@ class Mutator:
self.sampler = sampler_backup
return recorder.recorded_candidates, new_model
def mutate(self, model: Model) -> None:
"""
Abstract method to be implemented by subclass.
@@ -105,6 +106,7 @@ class _RecorderSampler(Sampler):
# the following is for inline mutation
class LayerChoiceMutator(Mutator):
def __init__(self, node_name: str, candidates: List):
super().__init__()
@@ -118,6 +120,7 @@ class LayerChoiceMutator(Mutator):
chosen_cand = self.candidates[chosen_index]
target.update_operation(chosen_cand['type'], chosen_cand['parameters'])
class InputChoiceMutator(Mutator):
def __init__(self, node_name: str, n_chosen: int):
super().__init__()
......