Unverified Commit 3efc59ee authored by QuanluZhang, committed by GitHub

improve PBT tuner (#2357)

parent 7e35d32e
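
Reviewer note: the central change in this PR is where the index/value bookkeeping for hyperparameters happens. Previously the tuner kept choice parameters as {'_index': i, '_value': v} dicts internally and called split_index at every point where parameters left the tuner; now split_index is applied once, right after json2parameter generates a configuration, so everything downstream (perturbation, checkpointing, import_data) works on plain values. A minimal sketch of the transformation, inferred from its call sites rather than copied from the real utility in nni.utils:

    # Illustrative sketch of split_index's effect, not NNI's exact implementation:
    # recursively unwrap {'_index': i, '_value': v} into the bare value v.
    def split_index(params):
        if isinstance(params, dict):
            if '_index' in params:
                return split_index(params['_value'])
            return {key: split_index(value) for key, value in params.items()}
        return params

    # Before: {'choice_str': {'_index': 1, '_value': 'dog'}}
    # After:  {'choice_str': 'dog'}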
@@ -74,18 +74,16 @@ def exploit_and_explore(bot_trial_info, top_trial_info, factor, resample_probabi
     top_hyper_parameters = top_trial_info.hyper_parameters
     hyper_parameters = copy.deepcopy(top_hyper_parameters)
     random_state = np.random.RandomState()
+    hyper_parameters['load_checkpoint_dir'] = hyper_parameters['save_checkpoint_dir']
+    hyper_parameters['save_checkpoint_dir'] = os.path.join(bot_checkpoint_dir, str(epoch))
     for key in hyper_parameters.keys():
         hyper_parameter = hyper_parameters[key]
-        if key == 'load_checkpoint_dir':
-            hyper_parameters[key] = hyper_parameters['save_checkpoint_dir']
-            continue
-        elif key == 'save_checkpoint_dir':
-            hyper_parameters[key] = os.path.join(bot_checkpoint_dir, str(epoch))
+        if key == 'load_checkpoint_dir' or key == 'save_checkpoint_dir':
             continue
         elif search_space[key]["_type"] == "choice":
             choices = search_space[key]["_value"]
-            ub, uv = len(choices) - 1, choices.index(hyper_parameter["_value"]) + 1
-            lb, lv = 0, choices.index(hyper_parameter["_value"]) - 1
+            ub, uv = len(choices) - 1, choices.index(hyper_parameter) + 1
+            lb, lv = 0, choices.index(hyper_parameter) - 1
         elif search_space[key]["_type"] == "randint":
             lb, ub = search_space[key]["_value"][:2]
             ub -= 1
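
A worked example of the bounds computed in the choice branch above, using the search space from the new unit test. Because configurations now hold plain values, the current value is looked up directly in the choices list:

    # Worked example for a "choice" parameter (choices taken from the new unit test):
    choices = ["cat", "dog", "elephant", "cow", "sheep", "panda"]
    hyper_parameter = "dog"  # plain value; before this PR it was {'_index': 1, '_value': 'dog'}

    ub, uv = len(choices) - 1, choices.index(hyper_parameter) + 1  # ub=5, uv=2
    lb, lv = 0, choices.index(hyper_parameter) - 1                 # lb=0, lv=0
    # uv/lv are the one-step-up/one-step-down candidate indices that
    # perturbation() chooses between (or it resamples), clamped to [lb, ub].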
@@ -132,10 +130,11 @@ def exploit_and_explore(bot_trial_info, top_trial_info, factor, resample_probabi
         else:
             logger.warning("Illegal type to perturb: %s", search_space[key]["_type"])
             continue
         if search_space[key]["_type"] == "choice":
             idx = perturbation(search_space[key]["_type"], search_space[key]["_value"],
                                resample_probability, uv, ub, lv, lb, random_state)
-            hyper_parameters[key] = {'_index': idx, '_value': choices[idx]}
+            hyper_parameters[key] = choices[idx]
         else:
             hyper_parameters[key] = perturbation(search_space[key]["_type"], search_space[key]["_value"],
                                                  resample_probability, uv, ub, lv, lb, random_state)
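
The two branches above differ only in what perturbation() returns: for "choice" it yields an index that must be mapped through the choices list, while for numeric types its return value is stored directly. A toy stand-in for the choice case, consistent with the call sites here but not the actual helper (its exact resample/step behavior is an assumption):

    import random

    # Toy stand-in for perturbation() on a "choice" parameter, assuming it either
    # resamples a fresh index or takes one clamped step from the current value.
    def toy_choice_perturbation(choices, resample_probability, uv, ub, lv, lb):
        if random.random() < resample_probability:
            return random.randrange(len(choices))   # resample uniformly
        # otherwise step up or down from the current index, clamped to bounds
        return min(ub, uv) if random.random() > 0.5 else max(lb, lv)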
@@ -231,6 +230,7 @@ class PBTTuner(Tuner):
         for i in range(self.population_size):
             hyper_parameters = json2parameter(
                 self.searchspace_json, is_rand, self.random_state)
+            hyper_parameters = split_index(hyper_parameters)
             checkpoint_dir = os.path.join(self.all_checkpoint_dir, str(i))
             hyper_parameters['load_checkpoint_dir'] = os.path.join(checkpoint_dir, str(self.epoch))
             hyper_parameters['save_checkpoint_dir'] = os.path.join(checkpoint_dir, str(self.epoch))
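
With split_index applied at generation time, each initial population member carries plain parameter values plus its two checkpoint paths, both pointing at <all_checkpoint_dir>/<member index>/<epoch>. An illustrative entry for member 0 at epoch 0 (the parameter value and base path are made up):

    # Illustrative initial configuration for population member i=0 at epoch 0:
    hyper_parameters = {
        "choice_str": "dog",                           # plain value after split_index
        "load_checkpoint_dir": "/tmp/checkpoint/0/0",  # <all_checkpoint_dir>/<i>/<epoch>
        "save_checkpoint_dir": "/tmp/checkpoint/0/0",
    }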
@@ -294,38 +294,19 @@ class PBTTuner(Tuner):
         trial_info.parameter_id = parameter_id
         self.running[parameter_id] = trial_info
         logger.info('Generate parameter : %s', trial_info.hyper_parameters)
-        return split_index(trial_info.hyper_parameters)
+        return trial_info.hyper_parameters

-    def receive_trial_result(self, parameter_id, parameters, value, **kwargs):
-        """
-        Receive trial's result. if the number of finished trials equals ``self.population_size``, start the next epoch to
-        train the model.
-
-        Parameters
-        ----------
-        parameter_id : int
-            Unique identifier of used hyper-parameters, same with :meth:`generate_parameters`.
-        parameters : dict
-            Hyper-parameters generated by :meth:`generate_parameters`.
-        value : dict
-            Result from trial (the return value of :func:`nni.report_final_result`).
-        """
-        logger.info('Get one trial result, id = %d, value = %s', parameter_id, value)
-        value = extract_scalar_reward(value)
-        if self.optimize_mode == OptimizeMode.Minimize:
-            value = -value
-        trial_info = self.running.pop(parameter_id, None)
-        trial_info.score = value
-        self.finished.append(trial_info)
-        self.finished_trials += 1
-        if self.finished_trials == self.population_size:
-            logger.info('Proceeding to next epoch')
-            self.epoch += 1
-            self.population = []
-            self.pos = -1
-            self.running = {}
-            #exploit and explore
-            self.finished = sorted(self.finished, key=lambda x: x.score, reverse=True)
+    def _proceed_next_epoch(self):
+        logger.info('Proceeding to next epoch')
+        self.epoch += 1
+        self.population = []
+        self.pos = -1
+        self.running = {}
+        # exploit and explore
+        reverse = True if self.optimize_mode == OptimizeMode.Maximize else False
+        self.finished = sorted(self.finished, key=lambda x: x.score, reverse=reverse)
         cutoff = int(np.ceil(self.fraction * len(self.finished)))
         tops = self.finished[:cutoff]
         bottoms = self.finished[self.finished_trials - cutoff:]
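
Note the fix folded into _proceed_next_epoch: instead of negating scores on minimize (which would break the float('inf') sentinel that the new trial_end assigns to failed trials), the sort direction itself now follows optimize_mode. A toy run of the cutoff arithmetic, assuming the tuner's default fraction of 0.2:

    import numpy as np

    # Toy illustration of the exploit/explore split (fraction=0.2 is an assumption).
    scores = [0.9, 0.7, 0.4, 0.1]                  # sorted best-first for a Maximize run
    fraction = 0.2
    cutoff = int(np.ceil(fraction * len(scores)))  # ceil(0.8) = 1
    tops = scores[:cutoff]                         # [0.9] -> checkpoints to copy from
    bottoms = scores[len(scores) - cutoff:]        # [0.1] -> trials to overwrite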
@@ -348,7 +329,117 @@ class PBTTuner(Tuner):
             trial_info = self.population[self.pos]
             trial_info.parameter_id = parameter_id
             self.running[parameter_id] = trial_info
-            self.send_trial_callback(parameter_id, split_index(trial_info.hyper_parameters))
+            self.send_trial_callback(parameter_id, trial_info.hyper_parameters)
+
+    def receive_trial_result(self, parameter_id, parameters, value, **kwargs):
+        """
+        Receive a trial's result. If the number of finished trials equals
+        ``self.population_size``, start the next epoch to train the model.
+
+        Parameters
+        ----------
+        parameter_id : int
+            Unique identifier of used hyper-parameters, same with :meth:`generate_parameters`.
+        parameters : dict
+            Hyper-parameters generated by :meth:`generate_parameters`.
+        value : dict
+            Result from trial (the return value of :func:`nni.report_final_result`).
+        """
+        logger.info('Get one trial result, id = %d, value = %s', parameter_id, value)
+        value = extract_scalar_reward(value)
+        trial_info = self.running.pop(parameter_id, None)
+        trial_info.score = value
+        self.finished.append(trial_info)
+        self.finished_trials += 1
+        if self.finished_trials == self.population_size:
+            self._proceed_next_epoch()
+
+    def trial_end(self, parameter_id, success, **kwargs):
+        """
+        Deal with trial failure.
+
+        Parameters
+        ----------
+        parameter_id : int
+            Unique identifier for hyper-parameters used by this trial.
+        success : bool
+            True if the trial successfully completed; False if failed or terminated.
+        **kwargs
+            Unstable parameters which should be ignored by normal users.
+        """
+        if success:
+            return
+        if self.optimize_mode == OptimizeMode.Minimize:
+            value = float('inf')
+        else:
+            value = float('-inf')
+        trial_info = self.running.pop(parameter_id, None)
+        trial_info.score = value
+        self.finished.append(trial_info)
+        self.finished_trials += 1
+        if self.finished_trials == self.population_size:
+            self._proceed_next_epoch()

     def import_data(self, data):
-        pass
+        """
+        Parameters
+        ----------
+        data : json obj
+            imported data records
+
+        Returns
+        -------
+        int
+            the start epoch number after data is imported; only used for unit tests
+        """
+        if self.running:
+            logger.warning("Do not support importing data in the middle of experiment")
+            return
+        # the following is for experiment resume
+        _completed_num = 0
+        epoch_data_dict = {}
+        for trial_info in data:
+            logger.info("Process data record %s / %s", _completed_num, len(data))
+            _completed_num += 1
+            # simply validate data format
+            _params = trial_info["parameter"]
+            _value = trial_info['value']
+            # assign a fake value to failed trials
+            if not _value:
+                logger.info("Useless trial data, value is %s, skip this trial data.", _value)
+                _value = float('inf') if self.optimize_mode == OptimizeMode.Minimize else float('-inf')
+            _value = extract_scalar_reward(_value)
+            if 'save_checkpoint_dir' not in _params:
+                logger.warning("Invalid data record: save_checkpoint_dir is missing, abandon data import.")
+                return
+            epoch_num = int(os.path.basename(_params['save_checkpoint_dir']))
+            if epoch_num not in epoch_data_dict:
+                epoch_data_dict[epoch_num] = []
+            epoch_data_dict[epoch_num].append((_params, _value))
+        if not epoch_data_dict:
+            logger.warning("No valid epochs, abandon data import.")
+            return
+        # figure out the start epoch for resume
+        max_epoch_num = max(epoch_data_dict, key=int)
+        if len(epoch_data_dict[max_epoch_num]) < self.population_size:
+            max_epoch_num -= 1
+        # if there is not a single complete epoch, there is no data to import; start from scratch
+        if max_epoch_num < 0:
+            logger.warning("No completed epoch, abandon data import.")
+            return
+        assert len(epoch_data_dict[max_epoch_num]) == self.population_size
+        # check existence of each trial's save checkpoint dir
+        for params, _ in epoch_data_dict[max_epoch_num]:
+            if not os.path.isdir(params['save_checkpoint_dir']):
+                logger.warning("save_checkpoint_dir %s does not exist, data will not be resumed", params['save_checkpoint_dir'])
+                return
+        # resume data
+        self.epoch = max_epoch_num
+        self.finished_trials = self.population_size
+        for params, value in epoch_data_dict[max_epoch_num]:
+            checkpoint_dir = os.path.dirname(params['save_checkpoint_dir'])
+            self.finished.append(TrialInfo(checkpoint_dir=checkpoint_dir, hyper_parameters=params, score=value))
+        self._proceed_next_epoch()
+        logger.info("Successfully import data to PBT tuner, total data: %d, imported data: %d.", len(data), self.population_size)
+        logger.info("Start from epoch %d ...", self.epoch)
+        return self.epoch  # return for test
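
The resume path above keys every imported record by the epoch number encoded in the last path component of save_checkpoint_dir, then restarts from the largest epoch that has a full population. A small sketch of that grouping step (paths are illustrative):

    import os
    from collections import defaultdict

    # Sketch of how import_data buckets records by epoch (paths are made up).
    records = [
        ({"save_checkpoint_dir": "/tmp/checkpoint/0/1"}, 1.1),
        ({"save_checkpoint_dir": "/tmp/checkpoint/1/1"}, 0.7),
    ]
    epoch_data = defaultdict(list)
    for params, value in records:
        epoch = int(os.path.basename(params["save_checkpoint_dir"]))  # trailing component
        epoch_data[epoch].append((params, value))
    # Only the largest epoch with population_size entries is eligible for resume.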
@@ -159,6 +159,62 @@ class BuiltinTunersTestCase(TestCase):
         logger.info("Full supported search space: %s", full_supported_search_space)
         self.search_space_test_one(tuner_factory, full_supported_search_space)
+
+    def import_data_test_for_pbt(self):
+        """
+        test1: import data with a complete epoch
+        test2: import data with an incomplete epoch
+        """
+        search_space = {
+            "choice_str": {
+                "_type": "choice",
+                "_value": ["cat", "dog", "elephant", "cow", "sheep", "panda"]
+            }
+        }
+        all_checkpoint_dir = os.path.expanduser("~/nni/checkpoint/test/")
+        population_size = 4
+        # ===import data at the beginning===
+        tuner = PBTTuner(
+            all_checkpoint_dir=all_checkpoint_dir,
+            population_size=population_size
+        )
+        self.assertIsInstance(tuner, Tuner)
+        tuner.update_search_space(search_space)
+        save_dirs = [os.path.join(all_checkpoint_dir, str(i), str(0)) for i in range(population_size)]
+        # create save checkpoint directories
+        for save_dir in save_dirs:
+            os.makedirs(save_dir, exist_ok=True)
+        # for simplicity, omit "load_checkpoint_dir"
+        data = [{"parameter": {"choice_str": "cat", "save_checkpoint_dir": save_dirs[0]}, "value": 1.1},
+                {"parameter": {"choice_str": "dog", "save_checkpoint_dir": save_dirs[1]}, "value": {"default": 1.2, "tmp": 2}},
+                {"parameter": {"choice_str": "cat", "save_checkpoint_dir": save_dirs[2]}, "value": 11},
+                {"parameter": {"choice_str": "cat", "save_checkpoint_dir": save_dirs[3]}, "value": 7}]
+        epoch = tuner.import_data(data)
+        self.assertEqual(epoch, 1)
+        logger.info("Imported data successfully at the beginning")
+        shutil.rmtree(all_checkpoint_dir)
+        # ===import data again at the beginning, covering the case of an incomplete epoch===
+        tuner = PBTTuner(
+            all_checkpoint_dir=all_checkpoint_dir,
+            population_size=population_size
+        )
+        self.assertIsInstance(tuner, Tuner)
+        tuner.update_search_space(search_space)
+        for i in range(population_size - 1):
+            save_dirs.append(os.path.join(all_checkpoint_dir, str(i), str(1)))
+        for save_dir in save_dirs:
+            os.makedirs(save_dir, exist_ok=True)
+        data = [{"parameter": {"choice_str": "cat", "save_checkpoint_dir": save_dirs[0]}, "value": 1.1},
+                {"parameter": {"choice_str": "dog", "save_checkpoint_dir": save_dirs[1]}, "value": {"default": 1.2, "tmp": 2}},
+                {"parameter": {"choice_str": "cat", "save_checkpoint_dir": save_dirs[2]}, "value": 11},
+                {"parameter": {"choice_str": "cat", "save_checkpoint_dir": save_dirs[3]}, "value": 7},
+                {"parameter": {"choice_str": "cat", "save_checkpoint_dir": save_dirs[4]}, "value": 1.1},
+                {"parameter": {"choice_str": "dog", "save_checkpoint_dir": save_dirs[5]}, "value": {"default": 1.2, "tmp": 2}},
+                {"parameter": {"choice_str": "cat", "save_checkpoint_dir": save_dirs[6]}, "value": 11}]
+        epoch = tuner.import_data(data)
+        self.assertEqual(epoch, 1)
+        logger.info("Imported data successfully at the beginning with incomplete epoch")
+        shutil.rmtree(all_checkpoint_dir)

     def import_data_test(self, tuner_factory, stype="choice_str"):
         """
         import data at the beginning with number value and dict value
@@ -297,6 +353,7 @@ class BuiltinTunersTestCase(TestCase):
             all_checkpoint_dir=os.path.expanduser("~/nni/checkpoint/test/"),
             population_size=100
         ))
+        self.import_data_test_for_pbt()

     def tearDown(self):
         file_list = glob.glob("smac3*") + ["param_config_space.pcs", "scenario.txt", "model_path"]