"...composable_kernel_onnxruntime.git" did not exist on "823657ed120144943b7db87c07fe3e647128db56"
Unverified Commit d165905d authored by QuanluZhang, committed by GitHub

[Retiarii] end2end (#3122)

parent 7d1acfbd
@@ -339,13 +339,21 @@ class Graph:
while curr_nodes:
curr_node = curr_nodes.pop(0)
sorted_nodes.append(curr_node)
- for successor in curr_node.successors:
# use successor_slots because a node may connect to another node multiple times
# to different slots
for successor_slot in curr_node.successor_slots:
successor = successor_slot[0]
node_to_fanin[successor] -= 1
if node_to_fanin[successor] == 0:
curr_nodes.append(successor)
for key in node_to_fanin:
- assert node_to_fanin[key] == 0
assert node_to_fanin[key] == 0, '{}, fanin: {}, predecessor: {}, edges: {}, fanin: {}, keys: {}'.format(key,
node_to_fanin[key],
key.predecessors[0],
self.edges,
node_to_fanin.values(),
node_to_fanin.keys())
return sorted_nodes
@@ -485,6 +493,10 @@ class Node:
def successors(self) -> List['Node']:
return sorted(set(edge.tail for edge in self.outgoing_edges), key=(lambda node: node.id))
@property
def successor_slots(self) -> List[Tuple['Node', Union[int, None]]]:
return set((edge.tail, edge.tail_slot) for edge in self.outgoing_edges)
@property
def incoming_edges(self) -> List['Edge']:
return [edge for edge in self.graph.edges if edge.tail is self]
...
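The change above matters when a node feeds another node more than once (into different input slots): iterating over the deduplicated `successors` set would decrement the fan-in counter only once per successor, and the topological sort would never drain it. A minimal, self-contained sketch of the edge-wise (per-slot) bookkeeping, using a toy multigraph rather than the NNI `Graph`/`Node` classes:

from collections import defaultdict

# hypothetical multigraph: node "a" feeds node "b" twice (two input slots)
edges = [("a", "b"), ("a", "b"), ("b", "c")]

nodes = {n for e in edges for n in e}
fanin = defaultdict(int)
for _, tail in edges:
    fanin[tail] += 1            # "b" has fan-in 2, "c" has fan-in 1

queue = [n for n in nodes if fanin[n] == 0]
order = []
while queue:
    node = queue.pop(0)
    order.append(node)
    # walk edges (slots), not set(successors): decrement once per edge
    for head, tail in edges:
        if head == node:
            fanin[tail] -= 1
            if fanin[tail] == 0:
                queue.append(tail)

assert all(v == 0 for v in fanin.values())
print(order)  # ['a', 'b', 'c']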
@@ -44,7 +44,7 @@ class RetiariiAdvisor(MsgDispatcherBase):
final_metric_callback
"""
- def __init__(self, strategy: Union[str, Callable]):
def __init__(self):
super(RetiariiAdvisor, self).__init__()
register_advisor(self)  # register the current advisor as the "global only" advisor
self.search_space = None
@@ -55,11 +55,7 @@ class RetiariiAdvisor(MsgDispatcherBase):
self.intermediate_metric_callback: Callable[[int, MetricData], None] = None
self.final_metric_callback: Callable[[int, MetricData], None] = None
- self.strategy = utils.import_(strategy) if isinstance(strategy, str) else strategy
self.parameters_count = 0
- _logger.info('Starting strategy...')
- threading.Thread(target=self.strategy).start()
- _logger.info('Strategy started!')
def handle_initialize(self, data):
"""callback for initializing the advisor
@@ -125,6 +121,7 @@ class RetiariiAdvisor(MsgDispatcherBase):
@staticmethod
def _process_value(value) -> Any:  # hopefully a float
value = json_tricks.loads(value)
if isinstance(value, dict):
if 'default' in value:
return value['default']
...
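For reference, the metric handling added above boils down to: decode the reported value with `json_tricks`, and if it is a dict carrying a `default` entry, collapse it to that entry. A condensed restatement (not the exact advisor code):

import json_tricks

def process_value(value):
    # decode the JSON-encoded metric; a dict with a 'default' key collapses to that value
    value = json_tricks.loads(value)
    if isinstance(value, dict) and 'default' in value:
        return value['default']
    return value

print(process_value('{"default": 0.93, "loss": 0.21}'))  # 0.93
print(process_value('0.93'))                             # 0.93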
@@ -73,9 +73,9 @@ class Mutator:
sampler_backup = self.sampler
recorder = _RecorderSampler()
self.sampler = recorder
- self.apply(model)
new_model = self.apply(model)
self.sampler = sampler_backup
- return recorder.recorded_candidates
return recorder.recorded_candidates, new_model
def mutate(self, model: Model) -> None:
...
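The `dry_run` change means a dry run now returns both the recorded candidate lists and the model produced while recording, so a later mutator can dry-run on top of the previous one's output. A hedged sketch of the pattern with toy stand-in classes (not the NNI `Mutator`):

class _RecorderSampler:
    """Sampler that only records the candidate lists it is asked about."""
    def __init__(self):
        self.recorded_candidates = []

    def choice(self, candidates):
        self.recorded_candidates.append(candidates)
        return candidates[0]  # any deterministic pick works for a dry run


class ToyMutator:
    def __init__(self):
        self.sampler = None

    def apply(self, model):
        # a real mutator rewrites the model graph; here the "model" is just a list
        return model + [self.sampler.choice(['conv3x3', 'conv5x5'])]

    def dry_run(self, model):
        backup = self.sampler
        recorder = _RecorderSampler()
        self.sampler = recorder
        new_model = self.apply(model)   # mutate once with the recorder in place
        self.sampler = backup
        return recorder.recorded_candidates, new_model


candidates, model = ToyMutator().dry_run([])
print(candidates)  # [['conv3x3', 'conv5x5']]
print(model)       # ['conv3x3']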
import inspect
import logging
import torch
import torch.nn as nn
from typing import (Any, Tuple, List, Optional)
_logger = logging.getLogger(__name__)
_logger.setLevel(logging.INFO)
- #consoleHandler = logging.StreamHandler()
- #consoleHandler.setLevel(logging.INFO)
- #_logger.addHandler(consoleHandler)
_records = None
def enable_record_args():
@@ -26,6 +23,42 @@ def get_records():
global _records
return _records
def add_record(name, value):
global _records
if _records is not None:
assert name not in _records, '{} already in _records'.format(name)
_records[name] = value
class LayerChoice(nn.Module):
def __init__(self, candidate_ops: List, label: str = None):
super(LayerChoice, self).__init__()
self.candidate_ops = candidate_ops
self.label = label
def forward(self, x):
return x
class InputChoice(nn.Module):
def __init__(self, n_chosen: int = 1, reduction: str = 'sum', label: str = None):
super(InputChoice, self).__init__()
self.n_chosen = n_chosen
self.reduction = reduction
self.label = label
def forward(self, candidate_inputs: List['Tensor']) -> 'Tensor':
# fake return
return torch.tensor(candidate_inputs)
class ValueChoice:
"""
The instance of this class can only be used as input argument,
when instantiating a pytorch module.
TODO: can also be used in training approach
"""
def __init__(self, candidate_values: List[Any]):
self.candidate_values = candidate_values
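As an aside before the rest of this file's diff: these mutation primitives are declared inside an ordinary `nn.Module`, and their `forward` bodies above are only placeholders consumed during graph conversion, so the sketch below illustrates how a search space is declared, not how it executes (`ToyCell` is hypothetical):

import torch.nn as nn

# assumes LayerChoice, InputChoice, ValueChoice defined above are in scope
class ToyCell(nn.Module):
    def __init__(self):
        super(ToyCell, self).__init__()
        # candidate operators; a strategy later fixes which one is kept
        self.op = LayerChoice([nn.Conv2d(3, 16, 3, padding=1),
                               nn.Conv2d(3, 16, 5, padding=2)], label='op1')
        # pick one of several candidate inputs and sum them
        self.skip = InputChoice(n_chosen=1, reduction='sum', label='skip1')
        # a searchable constructor argument, resolved when the graph is rewritten
        self.hidden_size = ValueChoice([64, 128])

    def forward(self, x):
        out = self.op(x)
        return self.skip([out, x])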
class Placeholder(nn.Module):
def __init__(self, label, related_info):
@@ -45,8 +78,13 @@ class Module(nn.Module):
# TODO: users have to pass init's arguments to super init's arguments
global _records
if _records is not None:
- # TODO: change tuple to dict
- _records[id(self)] = (args, kwargs)
assert not kwargs
argname_list = list(inspect.signature(self.__class__).parameters.keys())
assert len(argname_list) == len(args), 'Error: {} did not pass its input arguments to super().__init__'.format(self.__class__)
full_args = {}
for i, arg_value in enumerate(args):
full_args[argname_list[i]] = args[i]
_records[id(self)] = full_args
#print('my module: ', id(self), args, kwargs)
super(Module, self).__init__()
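The recording logic above maps positional constructor arguments to parameter names via `inspect.signature`, which is why every subclass must forward all of its `__init__` arguments to `super().__init__`. A standalone illustration of the same trick (toy classes, not the NNI `Module`):

import inspect

class Base:
    def __init__(self, *args, **kwargs):
        assert not kwargs, 'only positional arguments are recorded in this sketch'
        # signature(self.__class__) sees the subclass __init__, minus `self`
        names = list(inspect.signature(self.__class__).parameters.keys())
        assert len(names) == len(args), '{} did not forward all of its arguments'.format(self.__class__)
        self.recorded = dict(zip(names, args))

class Conv(Base):
    def __init__(self, in_ch, out_ch, kernel):
        super().__init__(in_ch, out_ch, kernel)  # must forward every argument

print(Conv(3, 16, 3).recorded)  # {'in_ch': 3, 'out_ch': 16, 'kernel': 3}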
@@ -57,6 +95,13 @@ class Sequential(nn.Sequential):
_records[id(self)] = {}  # no args need to be recorded
super(Sequential, self).__init__(*args)
class ModuleList(nn.ModuleList):
def __init__(self, *args):
global _records
if _records is not None:
_records[id(self)] = {}  # no args need to be recorded
super(ModuleList, self).__init__(*args)
def wrap_module(original_class):
orig_init = original_class.__init__
argname_list = list(inspect.signature(original_class).parameters.keys())
@@ -80,3 +125,7 @@ BatchNorm2d = wrap_module(nn.BatchNorm2d)
ReLU = wrap_module(nn.ReLU)
Dropout = wrap_module(nn.Dropout)
Linear = wrap_module(nn.Linear)
MaxPool2d = wrap_module(nn.MaxPool2d)
AvgPool2d = wrap_module(nn.AvgPool2d)
Identity = wrap_module(nn.Identity)
AdaptiveAvgPool2d = wrap_module(nn.AdaptiveAvgPool2d)
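`wrap_module` itself is only partially visible in this diff; the sketch below is a hypothetical reimplementation of the idea it names, namely wrapping a stock `torch.nn` class so that its constructor arguments are recorded by name:

import inspect
import torch.nn as nn

_records = {}  # hypothetical record store, mirroring the module-level dict above

def wrap_module(original_class):
    orig_init = original_class.__init__
    argnames = list(inspect.signature(original_class).parameters.keys())

    def new_init(self, *args, **kwargs):
        # map positional args to parameter names, then merge keyword args
        full_args = dict(zip(argnames, args))
        full_args.update(kwargs)
        _records[id(self)] = full_args
        orig_init(self, *args, **kwargs)

    # return a subclass with the recording __init__ instead of patching nn in place
    return type(original_class.__name__, (original_class,), {'__init__': new_init})

Conv2d = wrap_module(nn.Conv2d)
conv = Conv2d(3, 16, kernel_size=3)
print(_records[id(conv)])  # {'in_channels': 3, 'out_channels': 16, 'kernel_size': 3}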
@@ -105,6 +105,7 @@ class PyTorchOperation(Operation):
return None
def to_forward_code(self, field: str, output: str, inputs: List[str]) -> str:
from .converter.op_types import Type
if self._to_class_name() is not None:
return f'{output} = self.{field}({", ".join(inputs)})'
elif self.type.startswith('Function.'):
@@ -120,10 +121,34 @@ class PyTorchOperation(Operation):
return f'{output} = [{", ".join(inputs)}]'
elif self.type == 'aten::mean':
return f'{output} = torch.mean({inputs[0]}, {", ".join(inputs[1:-1])}, out={inputs[-1]})'
elif self.type == 'aten::__getitem__':
assert len(inputs) == 2
return f'{output} = {inputs[0]}[{inputs[1]}]'
elif self.type == 'aten::append':
assert len(inputs) == 2
return f'_, {output} = {inputs[0]}.append({inputs[1]}), {inputs[0]}'
elif self.type == 'aten::cat':
assert len(inputs) == 2
return f'{output} = torch.cat({inputs[0]}, dim={inputs[1]})'
elif self.type == 'aten::add':
assert len(inputs) == 2
return f'{output} = {inputs[0]} + {inputs[1]}'
elif self.type == Type.MergedSlice:
assert (len(inputs) - 1) % 4 == 0
slices = []
dim = int((len(inputs) - 1) / 4)
for i in range(dim):
slices.append(f'{inputs[i*4+2]}:{inputs[i*4+3]}:{inputs[i*4+4]}')
slice_str = ','.join(slices)
return f'{output} = {inputs[0]}[{slice_str}]'
elif self.type == 'aten::size':
assert len(inputs) == 2
return f'{output} = {inputs[0]}.size({inputs[1]})'
elif self.type == 'aten::view':
assert len(inputs) == 2
return f'{output} = {inputs[0]}.view({inputs[1]})'
elif self.type == 'aten::slice':
raise RuntimeError('not supposed to have aten::slice operation')
else:
raise RuntimeError(f'unsupported operation type: {self.type} ? {self._to_class_name()}')
...
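To make the intent of `to_forward_code` concrete: given an op type, an output variable name, and already-resolved input variable names, it emits one line of the generated `forward` body. A trimmed, standalone version covering a few of the aten ops handled above:

def emit(op_type, output, inputs):
    # each branch mirrors one of the cases added above
    if op_type == 'aten::cat':
        return f'{output} = torch.cat({inputs[0]}, dim={inputs[1]})'
    if op_type == 'aten::add':
        return f'{output} = {inputs[0]} + {inputs[1]}'
    if op_type == 'aten::__getitem__':
        return f'{output} = {inputs[0]}[{inputs[1]}]'
    raise RuntimeError(f'unsupported operation type: {op_type}')

print(emit('aten::cat', 'out1', ['tensor_list', '1']))
# out1 = torch.cat(tensor_list, dim=1)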
from .tpe_strategy import TPEStrategy
import abc
from typing import List
class BaseStrategy(abc.ABC):
@abc.abstractmethod
def run(self, base_model: 'Model', applied_mutators: List['Mutator'], trainer: 'BaseTrainer') -> None:
pass
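Any strategy only has to implement `run` against this interface. A hedged sketch of a minimal random-search strategy, mirroring the imports and the submit/wait pattern of the TPE strategy below; the class names and the `budget` argument are illustrative, not part of the commit:

import random

from .. import submit_models, wait_models   # same package-relative imports as tpe_strategy.py below
from .strategy import BaseStrategy


class RandomSampler:
    def choice(self, candidates, mutator, model, index):
        return random.choice(candidates)


class RandomStrategy(BaseStrategy):
    def run(self, base_model, applied_mutators, trainer, budget=10):
        sampler = RandomSampler()
        for _ in range(budget):
            model = base_model
            for mutator in applied_mutators:
                mutator.bind_sampler(sampler)
                model = mutator.apply(model)
            model.apply_trainer(trainer['modulename'], trainer['args'])
            submit_models(model)
            wait_models(model)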
import json
import logging
import random
import os
from .. import Model, submit_models, wait_models
from .. import Sampler
from .strategy import BaseStrategy
from ...algorithms.hpo.hyperopt_tuner.hyperopt_tuner import HyperoptTuner
_logger = logging.getLogger(__name__)
class TPESampler(Sampler):
def __init__(self, optimize_mode='minimize'):
self.tpe_tuner = HyperoptTuner('tpe', optimize_mode)
self.cur_sample = None
self.index = None
self.total_parameters = {}
def update_sample_space(self, sample_space):
search_space = {}
for i, each in enumerate(sample_space):
search_space[str(i)] = {'_type': 'choice', '_value': each}
self.tpe_tuner.update_search_space(search_space)
def generate_samples(self, model_id):
self.cur_sample = self.tpe_tuner.generate_parameters(model_id)
self.total_parameters[model_id] = self.cur_sample
self.index = 0
def receive_result(self, model_id, result):
self.tpe_tuner.receive_trial_result(model_id, self.total_parameters[model_id], result)
def choice(self, candidates, mutator, model, index):
chosen = self.cur_sample[str(self.index)]
self.index += 1
return chosen
class TPEStrategy(BaseStrategy):
def __init__(self):
self.tpe_sampler = TPESampler()
self.model_id = 0
def run(self, base_model, applied_mutators, trainer):
sample_space = []
new_model = base_model
for mutator in applied_mutators:
recorded_candidates, new_model = mutator.dry_run(new_model)
sample_space.extend(recorded_candidates)
self.tpe_sampler.update_sample_space(sample_space)
try:
_logger.info('strategy start...')
while True:
model = base_model
_logger.info('apply mutators...')
_logger.info('mutators: {}'.format(applied_mutators))
self.tpe_sampler.generate_samples(self.model_id)
for mutator in applied_mutators:
_logger.info('mutate model...')
mutator.bind_sampler(self.tpe_sampler)
model = mutator.apply(model)
# get and apply training approach
_logger.info('apply training approach...')
model.apply_trainer(trainer['modulename'], trainer['args'])
# run models
submit_models(model)
wait_models(model)
self.tpe_sampler.receive_result(self.model_id, model.metric)
self.model_id += 1
_logger.info('Strategy says: %s', model.metric)
except Exception as e:
_logger.exception(e)
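The glue between the graph mutations and the HPO tuner is `update_sample_space`: every candidate list recorded during the dry runs becomes one categorical (`choice`) dimension of an ordinary NNI search space, keyed by its position. A small illustration of that translation (the candidate lists are made up):

sample_space = [['conv3x3', 'conv5x5'], [0, 1, 2]]

search_space = {str(i): {'_type': 'choice', '_value': each}
                for i, each in enumerate(sample_space)}
print(search_space)
# {'0': {'_type': 'choice', '_value': ['conv3x3', 'conv5x5']},
#  '1': {'_type': 'choice', '_value': [0, 1, 2]}}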
import abc
import inspect
from ..nn.pytorch import add_record
from typing import *
@@ -17,6 +19,23 @@ class BaseTrainer(abc.ABC):
Trainer has a ``fit`` function with no return value. Intermediate results and final results should be
directly sent via ``nni.report_intermediate_result()`` and ``nni.report_final_result()`` functions.
"""
def __init__(self, *args, **kwargs):
module = self.__class__.__module__
if module is None or module == str.__class__.__module__:
full_class_name = self.__class__.__name__
else:
full_class_name = module + '.' + self.__class__.__name__
assert not kwargs
argname_list = list(inspect.signature(self.__class__).parameters.keys())
assert len(argname_list) == len(args), 'Error: {} did not pass its input arguments to super().__init__'.format(self.__class__)
full_args = {}
for i, arg_value in enumerate(args):
if argname_list[i] == 'model':
assert i == 0
continue
full_args[argname_list[i]] = args[i]
add_record(id(self), {'modulename': full_class_name, 'args': full_args})
@abc.abstractmethod
def fit(self) -> None:
...
@@ -78,6 +78,9 @@ class PyTorchImageClassificationTrainer(BaseTrainer):
Keyword arguments passed to trainer. Will be passed to Trainer class in future. Currently,
only the key ``max_epochs`` is useful.
"""
super(PyTorchImageClassificationTrainer, self).__init__(model,
dataset_cls, dataset_kwargs, dataloader_kwargs,
optimizer_cls, optimizer_kwargs, trainer_kwargs)
self._use_cuda = torch.cuda.is_available()
self.model = model
if self._use_cuda:
...
import traceback
from .nn.pytorch import enable_record_args, get_records, disable_record_args
def import_(target: str, allow_none: bool = False) -> 'Any':
if target is None:
return None
path, identifier = target.rsplit('.', 1)
module = __import__(path, globals(), locals(), [identifier])
return getattr(module, identifier)
class TraceClassArguments:
def __init__(self):
self.recorded_arguments = None
def __enter__(self):
enable_record_args()
return self
def __exit__(self, exc_type, exc_value, tb):
if exc_type is not None:
traceback.print_exception(exc_type, exc_value, tb)
# return False # uncomment to pass exception through
self.recorded_arguments = get_records()
disable_record_args()
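Typical use of this context manager is to wrap model construction so that every wrapped module logs its constructor arguments while recording is enabled. A hedged usage sketch; the import paths and `ToyModel` are assumptions for illustration, not shown in this diff:

from nni.retiarii.utils import TraceClassArguments        # assumed module path
from nni.retiarii.nn.pytorch import Module, Linear        # assumed module path

class ToyModel(Module):
    def __init__(self, hidden):
        super(ToyModel, self).__init__(hidden)   # forward args so they get recorded
        self.fc = Linear(hidden, 10)

with TraceClassArguments() as tca:
    model = ToyModel(64)

# roughly: {id(model): {'hidden': 64}, id(model.fc): {'in_features': 64, 'out_features': 10}, ...}
print(tca.recorded_arguments)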
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from datetime import datetime
from io import TextIOBase
import logging
import os
import sys
import time
log_level_map = {
'fatal': logging.FATAL,
'error': logging.ERROR,
'warning': logging.WARNING,
'info': logging.INFO,
'debug': logging.DEBUG
}
_time_format = '%m/%d/%Y, %I:%M:%S %p'
# FIXME
# This hotfix the bug that querying installed tuners with `package_utils` will activate dispatcher logger.
# This behavior depends on underlying implementation of `nnictl` and is likely to break in future.
_logger_initialized = False
class _LoggerFileWrapper(TextIOBase):
def __init__(self, logger_file):
self.file = logger_file
def write(self, s):
if s != '\n':
cur_time = datetime.now().strftime(_time_format)
self.file.write('[{}] PRINT '.format(cur_time) + s + '\n')
self.file.flush()
return len(s)
def init_logger(logger_file_path, log_level_name='info'):
"""Initialize root logger.
This will redirect anything from logging.getLogger() as well as stdout to specified file.
logger_file_path: path of logger file (path-like object).
"""
global _logger_initialized
if _logger_initialized:
return
_logger_initialized = True
if os.environ.get('NNI_PLATFORM') == 'unittest':
return # fixme: launching logic needs refactor
log_level = log_level_map.get(log_level_name, logging.INFO)
logger_file = open(logger_file_path, 'w')
fmt = '[%(asctime)s] %(levelname)s (%(name)s/%(threadName)s) %(message)s'
logging.Formatter.converter = time.localtime
formatter = logging.Formatter(fmt, _time_format)
handler = logging.StreamHandler(logger_file)
handler.setFormatter(formatter)
root_logger = logging.getLogger()
root_logger.addHandler(handler)
root_logger.setLevel(log_level)
# these modules are too verbose
logging.getLogger('matplotlib').setLevel(log_level)
sys.stdout = _LoggerFileWrapper(logger_file)
def init_standalone_logger():
"""
Initialize root logger for standalone mode.
This will set NNI's log level to INFO and print its log to stdout.
"""
global _logger_initialized
if _logger_initialized:
return
_logger_initialized = True
fmt = '[%(asctime)s] %(levelname)s (%(name)s) %(message)s'
formatter = logging.Formatter(fmt, _time_format)
handler = logging.StreamHandler(sys.stdout)
handler.setFormatter(formatter)
nni_logger = logging.getLogger('nni')
nni_logger.addHandler(handler)
nni_logger.setLevel(logging.INFO)
nni_logger.propagate = False
# Following line does not affect NNI loggers, but without this user's logger won't be able to
# print logs even if its level is set to INFO, so we do it for user's convenience.
# If this causes any issue in future, remove it and use `logging.info` instead of
# `logging.getLogger('xxx')` in all examples.
logging.basicConfig()
_multi_thread = False
_multi_phase = False
...
from datetime import datetime
from io import TextIOBase
import logging
from logging import FileHandler, Formatter, Handler, StreamHandler
from pathlib import Path
import sys
from typing import Optional
from .env_vars import dispatcher_env_vars, trial_env_vars
def init_logger() -> None:
"""
This function will (and should only) get invoked on the first time of importing nni (no matter which submodule).
It will try to detect the running environment and setup logger accordingly.
The detection should work in most cases, but `nnictl` and `nni.experiment`
will be identified as "standalone" mode and must configure the logger by themselves.
"""
if dispatcher_env_vars.SDK_PROCESS == 'dispatcher':
_init_logger_dispatcher()
return
trial_platform = trial_env_vars.NNI_PLATFORM
if trial_platform == 'unittest':
return
if trial_platform:
_init_logger_trial()
return
_init_logger_standalone()
time_format = '%Y-%m-%d %H:%M:%S'
formatter = Formatter(
'[%(asctime)s] %(levelname)s (%(name)s/%(threadName)s) %(message)s',
time_format
)
def _init_logger_dispatcher() -> None:
log_level_map = {
'fatal': logging.CRITICAL,
'error': logging.ERROR,
'warning': logging.WARNING,
'info': logging.INFO,
'debug': logging.DEBUG
}
log_path = _prepare_log_dir(dispatcher_env_vars.NNI_LOG_DIRECTORY) / 'dispatcher.log'
log_level = log_level_map.get(dispatcher_env_vars.NNI_LOG_LEVEL, logging.INFO)
_setup_root_logger(FileHandler(log_path), log_level)
def _init_logger_trial() -> None:
log_path = _prepare_log_dir(trial_env_vars.NNI_OUTPUT_DIR) / 'trial.log'
log_file = open(log_path, 'w')
_setup_root_logger(StreamHandler(log_file), logging.INFO)
sys.stdout = _LogFileWrapper(log_file)
def _init_logger_standalone() -> None:
_setup_nni_logger(StreamHandler(sys.stdout), logging.INFO)
# Following line does not affect NNI loggers, but without this the user's own logger won't
# print logs even if its level is set to INFO, so we do it for the user's convenience.
# If this causes any issue in future, remove it and use `logging.info()` instead of
# `logging.getLogger('xxx').info()` in all examples.
logging.basicConfig()
def _prepare_log_dir(path: Optional[str]) -> Path:
if path is None:
return Path()
ret = Path(path)
ret.mkdir(parents=True, exist_ok=True)
return ret
def _setup_root_logger(handler: Handler, level: int) -> None:
_setup_logger('', handler, level)
def _setup_nni_logger(handler: Handler, level: int) -> None:
_setup_logger('nni', handler, level)
def _setup_logger(name: str, handler: Handler, level: int) -> None:
handler.setFormatter(formatter)
logger = logging.getLogger(name)
logger.addHandler(handler)
logger.setLevel(level)
logger.propagate = False
class _LogFileWrapper(TextIOBase):
# wrap the logger file so that anything written to it will automatically get formatted
def __init__(self, log_file: TextIOBase):
self.file: TextIOBase = log_file
self.line_buffer: Optional[str] = None
self.line_start_time: Optional[datetime] = None
def write(self, s: str) -> int:
cur_time = datetime.now()
if self.line_buffer and (cur_time - self.line_start_time).total_seconds() > 0.1:
self.flush()
if self.line_buffer:
self.line_buffer += s
else:
self.line_buffer = s
self.line_start_time = cur_time
if '\n' not in s:
return len(s)
time_str = cur_time.strftime(time_format)
lines = self.line_buffer.split('\n')
for line in lines[:-1]:
self.file.write(f'[{time_str}] PRINT {line}\n')
self.file.flush()
self.line_buffer = lines[-1]
self.line_start_time = cur_time
return len(s)
def flush(self) -> None:
if self.line_buffer:
time_str = self.line_start_time.strftime(time_format)
self.file.write(f'[{time_str}] PRINT {self.line_buffer}\n')
self.file.flush()
self.line_buffer = None
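With the class above in scope, redirecting `print` output through the wrapper looks like this; `io.StringIO` stands in for the real trial log file:

import io
import sys

log_file = io.StringIO()
sys.stdout = _LogFileWrapper(log_file)
print('hello from the trial')        # buffered until the trailing newline arrives
sys.stdout = sys.__stdout__          # restore the real stdout

print(log_file.getvalue())           # [<timestamp>] PRINT hello from the trial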
@@ -9,11 +9,9 @@ import json_tricks
from .common import multi_thread_enabled
from .env_vars import dispatcher_env_vars
- from ..utils import init_dispatcher_logger
from ..recoverable import Recoverable
from .protocol import CommandType, receive
- init_dispatcher_logger()
_logger = logging.getLogger(__name__)
...
@@ -7,7 +7,6 @@ import json
import time
import subprocess
- from ..common import init_logger
from ..env_vars import trial_env_vars
from nni.utils import to_json
@@ -21,9 +20,6 @@ if not os.path.exists(_outputdir):
os.makedirs(_outputdir)
_nni_platform = trial_env_vars.NNI_PLATFORM
- if _nni_platform == 'local':
- _log_file_path = os.path.join(_outputdir, 'trial.log')
- init_logger(_log_file_path)
_multiphase = trial_env_vars.MULTI_PHASE
...
@@ -4,8 +4,6 @@
import logging
import json_tricks
- from ..common import init_standalone_logger
__all__ = [
'get_next_parameter',
'get_experiment_id',
@@ -14,7 +12,6 @@ __all__ = [
'send_metric',
]
- init_standalone_logger()
_logger = logging.getLogger('nni')
...
@@ -32,8 +32,7 @@ try:
_in_file = open(3, 'rb')
_out_file = open(4, 'wb')
except OSError:
- _msg = 'IPC pipeline not exists, maybe you are importing tuner/assessor from trial code?'
- logging.getLogger(__name__).warning(_msg)
pass
def send(command, data):
...
@@ -85,6 +85,7 @@ def start_rest_server(port, platform, mode, experiment_id, foreground=False, log
log_header = LOG_HEADER % str(time_now)
stdout_file.write(log_header)
stderr_file.write(log_header)
print('## [nnictl] cmds:', cmds)
if sys.platform == 'win32':
from subprocess import CREATE_NEW_PROCESS_GROUP
if foreground:
@@ -387,6 +388,8 @@ def set_experiment(experiment_config, mode, port, config_file_name):
{'key': 'aml_config', 'value': experiment_config['amlConfig']})
request_data['clusterMetaData'].append(
{'key': 'trial_config', 'value': experiment_config['trial']})
print('## experiment config:')
print(request_data)
response = rest_post(experiment_url(port), json.dumps(request_data), REST_TIME_OUT, show_error=True)
if check_response(response):
return response
...