Unverified commit 0663218b authored by SparkSnail, committed by GitHub

Merge pull request #163 from Microsoft/master

merge master
parents 6c9360a5 cf983800
...@@ -31,7 +31,7 @@ import { TrialConfigMetadataKey } from '../common/trialConfigMetadataKey';
import { LocalTrainingService } from '../local/localTrainingService';
// TODO: copy mockedTrial.py to local folder
const localCodeDir: string = tmp.dirSync().name.split('\\').join('\\\\');
const mockedTrialPath: string = './training_service/test/mockedTrial.py'
fs.copyFileSync(mockedTrialPath, localCodeDir + '/mockedTrial.py')
......
...@@ -95,3 +95,6 @@ class BatchTuner(Tuner):
def receive_trial_result(self, parameter_id, parameters, value):
pass
def import_data(self, data):
pass
...@@ -573,3 +573,35 @@ class BOHB(MsgDispatcherBase):
def handle_add_customized_trial(self, data):
pass
def handle_import_data(self, data):
"""Import additional data for tuning
Parameters
----------
data:
a list of dictionaries, each of which has at least two keys, 'parameter' and 'value'
Raises
------
AssertionError
data doesn't have the required keys 'parameter' and 'value'
"""
_completed_num = 0
for trial_info in data:
logger.info("Importing data, current processing progress %s / %s" %(_completed_num), len(data))
_completed_num += 1
assert "parameter" in trial_info
_params = trial_info["parameter"]
assert "value" in trial_info
_value = trial_info['value']
if _KEY not in _params:
_params[_KEY] = self.max_budget
logger.info("Set \"TRIAL_BUDGET\" value to %s (max budget)" %self.max_budget)
if self.optimize_mode is OptimizeMode.Maximize:
reward = -_value
else:
reward = _value
_budget = _params[_KEY]
self.cg.new_result(loss=reward, budget=_budget, parameters=_params, update_model=True)
logger.info("Successfully import tuning data to BOHB advisor.")
...@@ -19,29 +19,13 @@
# ==================================================================================================
from datetime import datetime
from io import TextIOBase
import logging
import sys
import time
log_level_map = {
'fatal': logging.FATAL,
'error': logging.ERROR,
'warning': logging.WARNING,
...@@ -49,7 +33,8 @@ logLevelMap = {
'debug': logging.DEBUG
}
_time_format = '%m/%d/%Y, %I:%M:%S %p'
class _LoggerFileWrapper(TextIOBase):
def __init__(self, logger_file):
self.file = logger_file
...@@ -61,21 +46,12 @@ class _LoggerFileWrapper(TextIOBase):
self.file.flush()
return len(s)
def init_logger(logger_file_path, log_level_name='info'):
"""Initialize root logger.
This will redirect anything from logging.getLogger() as well as stdout to the specified file.
logger_file_path: path of logger file (path-like object).
log_level_name: log level name, 'info' by default; unrecognized names fall back to INFO.
"""
log_level = log_level_map.get(log_level_name, logging.INFO)
logger_file = open(logger_file_path, 'w')
fmt = '[%(asctime)s] %(levelname)s (%(name)s/%(threadName)s) %(message)s'
logging.Formatter.converter = time.localtime
......
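A minimal usage sketch of the reworked signature; unrecognized level names fall back to INFO through log_level_map.get, and the module path nni.common is inferred from the relative imports elsewhere in this diff:

from nni.common import init_logger

init_logger('dispatcher.log', log_level_name='debug')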
# Copyright (c) Microsoft Corporation. All rights reserved.
#
# MIT License
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
# associated documentation files (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge, publish, distribute,
# sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all copies or
# substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT
# NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT
# OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
# ==================================================================================================
import os
from collections import namedtuple
_trial_env_var_names = [
'NNI_PLATFORM',
'NNI_TRIAL_JOB_ID',
'NNI_SYS_DIR',
'NNI_OUTPUT_DIR',
'NNI_TRIAL_SEQ_ID',
'MULTI_PHASE'
]
_dispatcher_env_var_names = [
'NNI_MODE',
'NNI_CHECKPOINT_DIRECTORY',
'NNI_LOG_DIRECTORY',
'NNI_LOG_LEVEL',
'NNI_INCLUDE_INTERMEDIATE_RESULTS'
]
def _load_env_vars(env_var_names):
env_var_dict = {k: os.environ.get(k) for k in env_var_names}
return namedtuple('EnvVars', env_var_names)(**env_var_dict)
trial_env_vars = _load_env_vars(_trial_env_var_names)
dispatcher_env_vars = _load_env_vars(_dispatcher_env_var_names)
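A short usage sketch of the new module (module path nni.env_vars inferred from the relative imports below): each namedtuple is a read-only snapshot taken at import time, and unset variables come back as None rather than raising KeyError.

from nni.env_vars import trial_env_vars, dispatcher_env_vars

if trial_env_vars.NNI_PLATFORM is None:
    print('running standalone')
if dispatcher_env_vars.NNI_MODE == 'resume':
    print('resuming a previous experiment')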
...@@ -34,7 +34,6 @@ from nni.tuner import Tuner
from nni.utils import extract_scalar_reward
from .. import parameter_expressions
@unique
class OptimizeMode(Enum):
"""Optimize Mode class
...@@ -299,3 +298,6 @@ class EvolutionTuner(Tuner):
indiv = Individual(config=params, result=reward)
self.population.append(indiv)
def import_data(self, data):
pass
...@@ -24,14 +24,17 @@ gridsearch_tuner.py including:
import copy
import numpy as np
import logging
import nni
from nni.tuner import Tuner
from nni.utils import convert_dict2tuple
TYPE = '_type'
CHOICE = 'choice'
VALUE = '_value'
logger = logging.getLogger('grid_search_AutoML')
class GridSearchTuner(Tuner):
'''
...@@ -51,6 +54,7 @@ class GridSearchTuner(Tuner):
def __init__(self):
self.count = -1
self.expanded_search_space = []
self.supplement_data = dict()
def json2paramater(self, ss_spec):
'''
...@@ -135,9 +139,31 @@ class GridSearchTuner(Tuner):
def generate_parameters(self, parameter_id):
self.count += 1
while (self.count <= len(self.expanded_search_space)-1):
_params_tuple = convert_dict2tuple(self.expanded_search_space[self.count])
if _params_tuple in self.supplement_data:
self.count += 1
else:
return self.expanded_search_space[self.count]
raise nni.NoMoreTrialError('no more parameters now.')
def receive_trial_result(self, parameter_id, parameters, value):
pass
def import_data(self, data):
"""Import additional data for tuning
Parameters
----------
data:
a list of dictionaries, each of which has at least two keys, 'parameter' and 'value'
"""
_completed_num = 0
for trial_info in data:
logger.info("Importing data, current processing progress %s / %s" %(_completed_num), len(data))
_completed_num += 1
assert "parameter" in trial_info
_params = trial_info["parameter"]
_params_tuple = convert_dict2tuple(_params)
self.supplement_data[_params_tuple] = True
logger.info("Successfully import data to grid search tuner.")
...@@ -31,7 +31,6 @@ import json_tricks
from nni.protocol import CommandType, send
from nni.msg_dispatcher_base import MsgDispatcherBase
from nni.utils import extract_scalar_reward
from .. import parameter_expressions
...@@ -420,3 +419,6 @@ class Hyperband(MsgDispatcherBase):
def handle_add_customized_trial(self, data):
pass
def handle_import_data(self, data):
pass
...@@ -172,6 +172,7 @@ class HyperoptTuner(Tuner):
self.json = None
self.total_data = {}
self.rval = None
self.supplement_data_num = 0
def _choose_tuner(self, algorithm_name):
"""
...@@ -353,3 +354,27 @@ class HyperoptTuner(Tuner):
# remove '_index' from json2parameter and save params-id
total_params = json2parameter(self.json, parameter)
return total_params
def import_data(self, data):
"""Import additional data for tuning
Parameters
----------
data:
a list of dictionaries, each of which has at least two keys, 'parameter' and 'value'
"""
_completed_num = 0
for trial_info in data:
logger.info("Importing data, current processing progress %s / %s" %(_completed_num), len(data))
_completed_num += 1
if self.algorithm_name == 'random_search':
return
assert "parameter" in trial_info
_params = trial_info["parameter"]
assert "value" in trial_info
_value = trial_info['value']
self.supplement_data_num += 1
_parameter_id = '_'.join(["ImportData", str(self.supplement_data_num)])
self.total_data[_parameter_id] = _params
self.receive_trial_result(parameter_id=_parameter_id, parameters=_params, value=_value)
logger.info("Successfully import data to TPE/Anneal tuner.")
...@@ -96,7 +96,7 @@ class MetisTuner(Tuner):
self.samples_x = []
self.samples_y = []
self.samples_y_aggregation = []
self.total_data = []
self.space = None
self.no_resampling = no_resampling
self.no_candidates = no_candidates
...@@ -107,6 +107,7 @@ class MetisTuner(Tuner):
self.exploration_probability = exploration_probability
self.minimize_constraints_fun = None
self.minimize_starting_points = None
self.supplement_data_num = 0
def update_search_space(self, search_space):
...@@ -392,15 +393,35 @@ class MetisTuner(Tuner):
# ===== STEP 7: If current optimal hyperparameter occurs in the history or exploration probability is less than the threshold, take next config as exploration step =====
outputs = self._pack_output(lm_current['hyperparameter'])
ap = random.uniform(0, 1)
if outputs in self.total_data or ap <= self.exploration_probability:
if next_candidate is not None:
outputs = self._pack_output(next_candidate['hyperparameter'])
else:
random_parameter = _rand_init(x_bounds, x_types, 1)[0]
outputs = self._pack_output(random_parameter)
self.total_data.append(outputs)
return outputs
def import_data(self, data):
"""Import additional data for tuning
Parameters
----------
data:
a list of dictionaries, each of which has at least two keys, 'parameter' and 'value'
"""
_completed_num = 0
for trial_info in data:
logger.info("Importing data, current processing progress %s / %s" %(_completed_num), len(data))
_completed_num += 1
assert "parameter" in trial_info
_params = trial_info["parameter"]
assert "value" in trial_info
_value = trial_info['value']
self.supplement_data_num += 1
_parameter_id = '_'.join(["ImportData", str(self.supplement_data_num)])
self.total_data.append(_params)
self.receive_trial_result(parameter_id=_parameter_id, parameters=_params, value=_value)
logger.info("Successfully import data to metis tuner.")
def _rand_with_constraints(x_bounds, x_types):
outputs = None
......
...@@ -27,6 +27,7 @@ from .protocol import CommandType, send
from .msg_dispatcher_base import MsgDispatcherBase
from .assessor import AssessResult
from .common import multi_thread_enabled
from .env_vars import dispatcher_env_vars
_logger = logging.getLogger(__name__)
...@@ -108,6 +109,12 @@ class MsgDispatcher(MsgDispatcherBase):
def handle_update_search_space(self, data):
self.tuner.update_search_space(data)
def handle_import_data(self, data):
"""Import additional data for tuning
data: a list of dictionaries, each of which has at least two keys, 'parameter' and 'value'
"""
self.tuner.import_data(data)
def handle_add_customized_trial(self, data):
# data: parameters
id_ = _create_parameter_id()
...@@ -116,7 +123,7 @@ class MsgDispatcher(MsgDispatcherBase):
def handle_report_metric_data(self, data):
"""
data: a dict received from nni_manager, which contains:
- 'parameter_id': id of the trial
- 'value': metric value reported by nni.report_final_result()
- 'type': report type, support {'FINAL', 'PERIODICAL'}
...@@ -134,9 +141,9 @@ class MsgDispatcher(MsgDispatcherBase):
def handle_trial_end(self, data):
"""
data: it has three keys: trial_job_id, event, hyper_params
- trial_job_id: the id generated by training service
- event: the job's state
- hyper_params: the hyperparameters generated and returned by tuner
"""
trial_job_id = data['trial_job_id']
_ended_trials.add(trial_job_id)
...@@ -190,8 +197,8 @@ class MsgDispatcher(MsgDispatcherBase):
_logger.debug('BAD, kill %s', trial_job_id)
send(CommandType.KillTrialJob, json_tricks.dumps(trial_job_id))
# notify tuner
_logger.debug('env var: NNI_INCLUDE_INTERMEDIATE_RESULTS: [%s]', dispatcher_env_vars.NNI_INCLUDE_INTERMEDIATE_RESULTS)
if dispatcher_env_vars.NNI_INCLUDE_INTERMEDIATE_RESULTS == 'true':
self._earlystop_notify_tuner(data)
else:
_logger.debug('GOOD')
......
...@@ -26,11 +26,14 @@ from multiprocessing.dummy import Pool as ThreadPool
from queue import Queue, Empty
import json_tricks
from .common import multi_thread_enabled
from .env_vars import dispatcher_env_vars
from .utils import init_dispatcher_logger
from .recoverable import Recoverable
from .protocol import CommandType, receive
init_dispatcher_logger()
_logger = logging.getLogger(__name__)
QUEUE_LEN_WARNING_MARK = 20
...@@ -56,8 +59,7 @@ class MsgDispatcherBase(Recoverable):
This function will never return unless it raises an exception.
"""
_logger.info('Start dispatcher')
if dispatcher_env_vars.NNI_MODE == 'resume':
self.load_checkpoint()
while True:
...@@ -142,6 +144,7 @@ class MsgDispatcherBase(Recoverable):
CommandType.Initialize: self.handle_initialize,
CommandType.RequestTrialJobs: self.handle_request_trial_jobs,
CommandType.UpdateSearchSpace: self.handle_update_search_space,
CommandType.ImportData: self.handle_import_data,
CommandType.AddCustomizedTrialJob: self.handle_add_customized_trial,
# Tuner/Assessor commands:
...@@ -166,6 +169,9 @@ class MsgDispatcherBase(Recoverable):
def handle_update_search_space(self, data):
raise NotImplementedError('handle_update_search_space not implemented')
def handle_import_data(self, data):
raise NotImplementedError('handle_import_data not implemented')
def handle_add_customized_trial(self, data):
raise NotImplementedError('handle_add_customized_trial not implemented')
......
...@@ -112,6 +112,13 @@ class MultiPhaseMsgDispatcher(MsgDispatcherBase):
self.tuner.update_search_space(data)
return True
def handle_import_data(self, data):
"""import additional data for tuning
data: a list of dictionarys, each of which has at least two keys, 'parameter' and 'value'
"""
self.tuner.import_data(data)
return True
def handle_add_customized_trial(self, data):
# data: parameters
id_ = _create_parameter_id()
...@@ -154,6 +161,9 @@ class MultiPhaseMsgDispatcher(MsgDispatcherBase):
self.tuner.trial_end(json_tricks.loads(data['hyper_params'])['parameter_id'], data['event'] == 'SUCCEEDED', trial_job_id)
return True
def _handle_intermediate_metric_data(self, data):
if data['type'] != 'PERIODICAL':
return True
......
...@@ -76,6 +76,12 @@ class MultiPhaseTuner(Recoverable):
"""
raise NotImplementedError('Tuner: update_search_space not implemented')
def import_data(self, data):
"""Import additional data for tuning
data: a list of dictionaries, each of which has at least two keys, 'parameter' and 'value'
"""
pass
def load_checkpoint(self):
"""Load the checkpoint of tuner.
path: checkpoint directory for tuner
...@@ -95,3 +101,6 @@ class MultiPhaseTuner(Recoverable):
def _on_error(self):
pass
...@@ -307,3 +307,7 @@ class NetworkMorphismTuner(Tuner):
if item["model_id"] == model_id:
return item["metric_value"]
return None
def import_data(self, data):
pass
...@@ -21,13 +21,13 @@
# pylint: disable=wildcard-import
from ..env_vars import trial_env_vars
if trial_env_vars.NNI_PLATFORM is None:
from .standalone import *
elif trial_env_vars.NNI_PLATFORM == 'unittest':
from .test import *
elif trial_env_vars.NNI_PLATFORM in ('local', 'remote', 'pai', 'kubeflow', 'frameworkcontroller'):
from .local import *
else:
raise RuntimeError('Unknown platform %s' % trial_env_vars.NNI_PLATFORM)
...@@ -19,34 +19,36 @@
# ==================================================================================================
import os
import sys
import json
import time
import subprocess
import json_tricks
from ..common import init_logger
from ..env_vars import trial_env_vars
_sysdir = trial_env_vars.NNI_SYS_DIR
if not os.path.exists(os.path.join(_sysdir, '.nni')):
os.makedirs(os.path.join(_sysdir, '.nni'))
_metric_file = open(os.path.join(_sysdir, '.nni', 'metrics'), 'wb')
_outputdir = trial_env_vars.NNI_OUTPUT_DIR
if not os.path.exists(_outputdir):
os.makedirs(_outputdir)
_nni_platform = trial_env_vars.NNI_PLATFORM
if _nni_platform == 'local':
_log_file_path = os.path.join(_outputdir, 'trial.log')
init_logger(_log_file_path)
_multiphase = trial_env_vars.MULTI_PHASE
_param_index = 0
def request_next_parameter():
metric = json_tricks.dumps({
'trial_job_id': trial_env_vars.NNI_TRIAL_JOB_ID,
'type': 'REQUEST_PARAMETER',
'sequence': 0,
'parameter_index': _param_index
...@@ -86,7 +88,11 @@ def send_metric(string):
assert len(data) < 1000000, 'Metric too long'
_metric_file.write(b'ME%06d%b' % (len(data), data))
_metric_file.flush()
if sys.platform == "win32":
file = open(_metric_file.name)
file.close()
else:
subprocess.run(['touch', _metric_file.name], check = True)
def get_sequence_id(): def get_sequence_id():
return trial_env_vars.NNI_TRIAL_SEQ_ID
\ No newline at end of file
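A portable alternative to the platform split above would be os.utime, which bumps the file's timestamps on every platform without opening the file or spawning a process. This is only a sketch of that option, not what the commit does:

import os
os.utime(_metric_file.name, None)  # portable equivalent of `touch`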
...@@ -30,6 +30,7 @@ class CommandType(Enum):
RequestTrialJobs = b'GE'
ReportMetricData = b'ME'
UpdateSearchSpace = b'SS'
ImportData = b'FD'
AddCustomizedTrialJob = b'AD'
TrialEnd = b'EN'
Terminate = b'TE'
......
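A sketch of how this command could be emitted over the wire protocol using the Python helpers in this diff (in NNI the manager normally sends it and the dispatcher receives it; the hyperparameter name is hypothetical):

from nni.protocol import CommandType, send
import json_tricks

send(CommandType.ImportData, json_tricks.dumps([
    {'parameter': {'lr': 0.01}, 'value': 0.95},
]))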
...@@ -24,7 +24,7 @@ class Recoverable:
def load_checkpoint(self):
pass
def save_checkpoint(self):
pass
def get_checkpoint_path(self):
......
...@@ -261,3 +261,6 @@ class SMACTuner(Tuner):
params.append(self.convert_loguniform_categorical(challenger.get_dictionary()))
cnt += 1
return params
def import_data(self, data):
pass