Unverified Commit 4ac1c3c5 authored by QuanluZhang's avatar QuanluZhang Committed by GitHub
Browse files

Fix smac import data (#1093)

fix smac import data
parent d135d184
......@@ -22,11 +22,7 @@ batch_tuner.py including:
class BatchTuner
"""
import copy
from enum import Enum, unique
import random
import numpy as np
import logging
import nni
from nni.tuner import Tuner
......@@ -35,6 +31,7 @@ TYPE = '_type'
CHOICE = 'choice'
VALUE = '_value'
logger = logging.getLogger('batch_tuner_AutoML')
class BatchTuner(Tuner):
"""
......@@ -46,7 +43,7 @@ class BatchTuner(Tuner):
}
}
"""
def __init__(self):
self.count = -1
self.values = []
......@@ -54,14 +51,14 @@ class BatchTuner(Tuner):
def is_valid(self, search_space):
"""
Check the search space is valid: only contains 'choice' type
Parameters
----------
search_space : dict
"""
if not len(search_space) == 1:
raise RuntimeError('BatchTuner only supprt one combined-paramreters key.')
for param in search_space:
param_type = search_space[param][TYPE]
if not param_type == CHOICE:
......@@ -73,8 +70,8 @@ class BatchTuner(Tuner):
return None
def update_search_space(self, search_space):
"""Update the search space
"""Update the search space
Parameters
----------
search_space : dict
......@@ -88,8 +85,8 @@ class BatchTuner(Tuner):
----------
parameter_id : int
"""
self.count +=1
if self.count>len(self.values)-1:
self.count += 1
if self.count > len(self.values) - 1:
raise nni.NoMoreTrialError('no more parameters now.')
return self.values[self.count]
......@@ -97,4 +94,31 @@ class BatchTuner(Tuner):
pass
def import_data(self, data):
pass
"""Import additional data for tuning
Parameters
----------
data:
a list of dictionarys, each of which has at least two keys, 'parameter' and 'value'
"""
if len(self.values) == 0:
logger.info("Search space has not been initialized, skip this data import")
return
self.values = self.values[(self.count+1):]
self.count = -1
_completed_num = 0
for trial_info in data:
logger.info("Importing data, current processing progress %s / %s", _completed_num, len(data))
# simply validate data format
assert "parameter" in trial_info
_params = trial_info["parameter"]
assert "value" in trial_info
_value = trial_info['value']
if not _value:
logger.info("Useless trial data, value is %s, skip this trial data.", _value)
continue
_completed_num += 1
if _params in self.values:
self.values.remove(_params)
logger.info("Successfully import data to batch tuner, total data: %d, imported data: %d.", len(data), _completed_num)
......@@ -21,21 +21,21 @@
smac_tuner.py
"""
from nni.tuner import Tuner
from nni.utils import OptimizeMode, extract_scalar_reward
import sys
import logging
import numpy as np
import json_tricks
from enum import Enum, unique
from .convert_ss_to_scenario import generate_scenario
from nni.tuner import Tuner
from nni.utils import OptimizeMode, extract_scalar_reward
from smac.utils.io.cmd_reader import CMDReader
from smac.scenario.scenario import Scenario
from smac.facade.smac_facade import SMAC
from smac.facade.roar_facade import ROAR
from smac.facade.epils_facade import EPILS
from ConfigSpaceNNI import Configuration
from .convert_ss_to_scenario import generate_scenario
class SMACTuner(Tuner):
......@@ -57,6 +57,7 @@ class SMACTuner(Tuner):
self.update_ss_done = False
self.loguniform_key = set()
self.categorical_dict = {}
self.cs = None
def _main_cli(self):
"""Main function of SMAC for CLI interface
......@@ -66,7 +67,7 @@ class SMACTuner(Tuner):
instance
optimizer
"""
self.logger.info("SMAC call: %s" % (" ".join(sys.argv)))
self.logger.info("SMAC call: %s", " ".join(sys.argv))
cmd_reader = CMDReader()
args, _ = cmd_reader.read_cmd()
......@@ -95,6 +96,7 @@ class SMACTuner(Tuner):
# Create scenario-object
scen = Scenario(args.scenario_file, [])
self.cs = scen.cs
if args.mode == "SMAC":
optimizer = SMAC(
......@@ -258,4 +260,45 @@ class SMACTuner(Tuner):
return params
def import_data(self, data):
pass
"""Import additional data for tuning
Parameters
----------
data:
a list of dictionarys, each of which has at least two keys, 'parameter' and 'value'
"""
_completed_num = 0
for trial_info in data:
self.logger.info("Importing data, current processing progress %s / %s", _completed_num, len(data))
# simply validate data format
assert "parameter" in trial_info
_params = trial_info["parameter"]
assert "value" in trial_info
_value = trial_info['value']
if not _value:
self.logger.info("Useless trial data, value is %s, skip this trial data.", _value)
continue
# convert the keys in loguniform and categorical types
valid_entry = True
for key, value in _params.items():
if key in self.loguniform_key:
_params[key] = np.log(value)
elif key in self.categorical_dict:
if value in self.categorical_dict[key]:
_params[key] = self.categorical_dict[key].index(value)
else:
self.logger.info("The value %s of key %s is not in search space.", str(value), key)
valid_entry = False
break
if not valid_entry:
continue
# start import this data entry
_completed_num += 1
config = Configuration(self.cs, values=_params)
if self.optimize_mode is OptimizeMode.Maximize:
_value = -_value
if self.first_one:
self.smbo_solver.nni_smac_receive_first_run(config, _value)
self.first_one = False
else:
self.smbo_solver.nni_smac_receive_runs(config, _value)
self.logger.info("Successfully import data to smac tuner, total data: %d, imported data: %d.", len(data), _completed_num)
......@@ -86,12 +86,13 @@ TUNERS_SUPPORTING_IMPORT_DATA = {
'Anneal',
'GridSearch',
'MetisTuner',
'BOHB'
'BOHB',
'SMAC',
'BatchTuner'
}
TUNERS_NO_NEED_TO_IMPORT_DATA = {
'Random',
'Batch_tuner',
'Hyperband'
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment