Unverified Commit 3efc59ee authored by QuanluZhang's avatar QuanluZhang Committed by GitHub
Browse files

improve PBT tuner (#2357)

parent 7e35d32e
...@@ -74,18 +74,16 @@ def exploit_and_explore(bot_trial_info, top_trial_info, factor, resample_probabi ...@@ -74,18 +74,16 @@ def exploit_and_explore(bot_trial_info, top_trial_info, factor, resample_probabi
top_hyper_parameters = top_trial_info.hyper_parameters top_hyper_parameters = top_trial_info.hyper_parameters
hyper_parameters = copy.deepcopy(top_hyper_parameters) hyper_parameters = copy.deepcopy(top_hyper_parameters)
random_state = np.random.RandomState() random_state = np.random.RandomState()
hyper_parameters['load_checkpoint_dir'] = hyper_parameters['save_checkpoint_dir']
hyper_parameters['save_checkpoint_dir'] = os.path.join(bot_checkpoint_dir, str(epoch))
for key in hyper_parameters.keys(): for key in hyper_parameters.keys():
hyper_parameter = hyper_parameters[key] hyper_parameter = hyper_parameters[key]
if key == 'load_checkpoint_dir': if key == 'load_checkpoint_dir' or key == 'save_checkpoint_dir':
hyper_parameters[key] = hyper_parameters['save_checkpoint_dir']
continue
elif key == 'save_checkpoint_dir':
hyper_parameters[key] = os.path.join(bot_checkpoint_dir, str(epoch))
continue continue
elif search_space[key]["_type"] == "choice": elif search_space[key]["_type"] == "choice":
choices = search_space[key]["_value"] choices = search_space[key]["_value"]
ub, uv = len(choices) - 1, choices.index(hyper_parameter["_value"]) + 1 ub, uv = len(choices) - 1, choices.index(hyper_parameter) + 1
lb, lv = 0, choices.index(hyper_parameter["_value"]) - 1 lb, lv = 0, choices.index(hyper_parameter) - 1
elif search_space[key]["_type"] == "randint": elif search_space[key]["_type"] == "randint":
lb, ub = search_space[key]["_value"][:2] lb, ub = search_space[key]["_value"][:2]
ub -= 1 ub -= 1
...@@ -132,10 +130,11 @@ def exploit_and_explore(bot_trial_info, top_trial_info, factor, resample_probabi ...@@ -132,10 +130,11 @@ def exploit_and_explore(bot_trial_info, top_trial_info, factor, resample_probabi
else: else:
logger.warning("Illegal type to perturb: %s", search_space[key]["_type"]) logger.warning("Illegal type to perturb: %s", search_space[key]["_type"])
continue continue
if search_space[key]["_type"] == "choice": if search_space[key]["_type"] == "choice":
idx = perturbation(search_space[key]["_type"], search_space[key]["_value"], idx = perturbation(search_space[key]["_type"], search_space[key]["_value"],
resample_probability, uv, ub, lv, lb, random_state) resample_probability, uv, ub, lv, lb, random_state)
hyper_parameters[key] = {'_index': idx, '_value': choices[idx]} hyper_parameters[key] = choices[idx]
else: else:
hyper_parameters[key] = perturbation(search_space[key]["_type"], search_space[key]["_value"], hyper_parameters[key] = perturbation(search_space[key]["_type"], search_space[key]["_value"],
resample_probability, uv, ub, lv, lb, random_state) resample_probability, uv, ub, lv, lb, random_state)
...@@ -231,6 +230,7 @@ class PBTTuner(Tuner): ...@@ -231,6 +230,7 @@ class PBTTuner(Tuner):
for i in range(self.population_size): for i in range(self.population_size):
hyper_parameters = json2parameter( hyper_parameters = json2parameter(
self.searchspace_json, is_rand, self.random_state) self.searchspace_json, is_rand, self.random_state)
hyper_parameters = split_index(hyper_parameters)
checkpoint_dir = os.path.join(self.all_checkpoint_dir, str(i)) checkpoint_dir = os.path.join(self.all_checkpoint_dir, str(i))
hyper_parameters['load_checkpoint_dir'] = os.path.join(checkpoint_dir, str(self.epoch)) hyper_parameters['load_checkpoint_dir'] = os.path.join(checkpoint_dir, str(self.epoch))
hyper_parameters['save_checkpoint_dir'] = os.path.join(checkpoint_dir, str(self.epoch)) hyper_parameters['save_checkpoint_dir'] = os.path.join(checkpoint_dir, str(self.epoch))
...@@ -294,7 +294,42 @@ class PBTTuner(Tuner): ...@@ -294,7 +294,42 @@ class PBTTuner(Tuner):
trial_info.parameter_id = parameter_id trial_info.parameter_id = parameter_id
self.running[parameter_id] = trial_info self.running[parameter_id] = trial_info
logger.info('Generate parameter : %s', trial_info.hyper_parameters) logger.info('Generate parameter : %s', trial_info.hyper_parameters)
return split_index(trial_info.hyper_parameters) return trial_info.hyper_parameters
def _proceed_next_epoch(self):
    """
    Advance the tuner to the next PBT epoch.

    Sorts finished trials by score, lets each bottom-fraction trial
    exploit-and-explore a randomly chosen top-fraction trial, refreshes
    checkpoint paths for the surviving trials, rebuilds the population,
    and serves any parameter requests queued while the epoch finished.
    """
    logger.info('Proceeding to next epoch')
    self.epoch += 1
    self.population = []
    self.pos = -1
    self.running = {}
    # Exploit and explore: best trials come first when maximizing.
    reverse = self.optimize_mode == OptimizeMode.Maximize
    self.finished = sorted(self.finished, key=lambda x: x.score, reverse=reverse)
    cutoff = int(np.ceil(self.fraction * len(self.finished)))
    tops = self.finished[:cutoff]
    bottoms = self.finished[self.finished_trials - cutoff:]
    for bottom in bottoms:
        top = np.random.choice(tops)
        exploit_and_explore(bottom, top, self.factor, self.resample_probability, self.epoch, self.searchspace_json)
    for trial in self.finished:
        if trial not in bottoms:
            trial.clean_id()
            # Surviving trials resume from their own previous checkpoint.
            trial.hyper_parameters['load_checkpoint_dir'] = trial.hyper_parameters['save_checkpoint_dir']
            trial.hyper_parameters['save_checkpoint_dir'] = os.path.join(trial.checkpoint_dir, str(self.epoch))
    self.finished_trials = 0
    for _ in range(self.population_size):
        trial_info = self.finished.pop()
        self.population.append(trial_info)
    # Satisfy parameter requests that arrived while the previous epoch was
    # still completing (each pending request holds one unit of credit).
    while self.credit > 0 and self.pos + 1 < len(self.population):
        self.credit -= 1
        self.pos += 1
        parameter_id = self.param_ids.pop()
        trial_info = self.population[self.pos]
        trial_info.parameter_id = parameter_id
        self.running[parameter_id] = trial_info
        self.send_trial_callback(parameter_id, trial_info.hyper_parameters)
def receive_trial_result(self, parameter_id, parameters, value, **kwargs):
    """
    Receive the final metric of a trial and record its score.

    Parameters
    ----------
    parameter_id : int
        Unique identifier for hyper-parameters used by this trial.
    parameters : dict
        Hyper-parameters generated for this trial.
    value : dict/float
        Final metric reported by the trial; a dict is reduced to its
        default scalar by ``extract_scalar_reward``.
    **kwargs
        Unstable parameters which should be ignored by normal users.
    """
    logger.info('Get one trial result, id = %d, value = %s', parameter_id, value)
    value = extract_scalar_reward(value)
    trial_info = self.running.pop(parameter_id, None)
    if trial_info is None:
        # Unknown id: previously this fell through to an AttributeError;
        # warn explicitly instead of crashing on None.
        logger.warning('Received result for unknown parameter id %s, ignored', parameter_id)
        return
    trial_info.score = value
    self.finished.append(trial_info)
    self.finished_trials += 1
    # When every member of the population has reported, start a new epoch.
    if self.finished_trials == self.population_size:
        self._proceed_next_epoch()
def trial_end(self, parameter_id, success, **kwargs):
    """
    Deal with trial failure

    Parameters
    ----------
    parameter_id : int
        Unique identifier for hyper-parameters used by this trial.
    success : bool
        True if the trial successfully completed; False if failed or terminated.
    **kwargs
        Unstable parameters which should be ignored by normal users.
    """
    if success:
        return
    # A failed trial gets the worst possible score, so it will always be
    # replaced (exploited) at the end of the epoch.
    value = float('inf') if self.optimize_mode == OptimizeMode.Minimize else float('-inf')
    trial_info = self.running.pop(parameter_id, None)
    if trial_info is None:
        # Unknown id: previously this fell through to an AttributeError;
        # warn explicitly instead of crashing on None.
        logger.warning('Trial end for unknown parameter id %s, ignored', parameter_id)
        return
    trial_info.score = value
    self.finished.append(trial_info)
    self.finished_trials += 1
    # When every member of the population has reported, start a new epoch.
    if self.finished_trials == self.population_size:
        self._proceed_next_epoch()
self.epoch += 1
self.population = []
self.pos = -1
self.running = {}
#exploit and explore
self.finished = sorted(self.finished, key=lambda x: x.score, reverse=True)
cutoff = int(np.ceil(self.fraction * len(self.finished)))
tops = self.finished[:cutoff]
bottoms = self.finished[self.finished_trials - cutoff:]
for bottom in bottoms:
top = np.random.choice(tops)
exploit_and_explore(bottom, top, self.factor, self.resample_probability, self.epoch, self.searchspace_json)
for trial in self.finished:
if trial not in bottoms:
trial.clean_id()
trial.hyper_parameters['load_checkpoint_dir'] = trial.hyper_parameters['save_checkpoint_dir']
trial.hyper_parameters['save_checkpoint_dir'] = os.path.join(trial.checkpoint_dir, str(self.epoch))
self.finished_trials = 0
for _ in range(self.population_size):
trial_info = self.finished.pop()
self.population.append(trial_info)
while self.credit > 0 and self.pos + 1 < len(self.population):
self.credit -= 1
self.pos += 1
parameter_id = self.param_ids.pop()
trial_info = self.population[self.pos]
trial_info.parameter_id = parameter_id
self.running[parameter_id] = trial_info
self.send_trial_callback(parameter_id, split_index(trial_info.hyper_parameters))
def import_data(self, data):
    """
    Import data records, mainly used for experiment resume.

    Groups records by epoch (taken from the basename of each record's
    ``save_checkpoint_dir``), finds the latest fully-populated epoch,
    verifies its checkpoint directories exist on disk, then resumes
    from that epoch via ``_proceed_next_epoch``.

    Parameters
    ----------
    data : json obj
        imported data records

    Returns
    -------
    int or None
        the start epoch number after data imported, only used for unittest;
        None when the data cannot be imported.
    """
    if self.running:
        logger.warning("Do not support importing data in the middle of experiment")
        return
    # the following is for experiment resume
    _completed_num = 0
    epoch_data_dict = {}
    for trial_info in data:
        logger.info("Process data record %s / %s", _completed_num, len(data))
        _completed_num += 1
        # simply validate data format
        _params = trial_info["parameter"]
        _value = trial_info['value']
        # assign fake (worst) value for failed trials; NOTE(review): a
        # legitimate metric of 0 is also treated as failed here — confirm.
        if not _value:
            logger.info("Useless trial data, value is %s, skip this trial data.", _value)
            _value = float('inf') if self.optimize_mode == OptimizeMode.Minimize else float('-inf')
        _value = extract_scalar_reward(_value)
        if 'save_checkpoint_dir' not in _params:
            logger.warning("Invalid data record: save_checkpoint_dir is missing, abandon data import.")
            return
        # the checkpoint path ends with the epoch number
        epoch_num = int(os.path.basename(_params['save_checkpoint_dir']))
        epoch_data_dict.setdefault(epoch_num, []).append((_params, _value))
    if not epoch_data_dict:
        logger.warning("No valid epochs, abandon data import.")
        return
    # figure out start epoch for resume (keys are ints, no key= needed)
    max_epoch_num = max(epoch_data_dict)
    if len(epoch_data_dict[max_epoch_num]) < self.population_size:
        max_epoch_num -= 1
    # If there is no single complete epoch, no data to import, start from scratch
    if max_epoch_num < 0:
        logger.warning("No completed epoch, abandon data import.")
        return
    # Explicit check instead of `assert` (asserts are stripped under -O);
    # this also covers the case where the decremented epoch has no records.
    if len(epoch_data_dict.get(max_epoch_num, [])) != self.population_size:
        logger.warning("Epoch %s is not complete, abandon data import.", max_epoch_num)
        return
    # check existence of trial save checkpoint dir
    for params, _ in epoch_data_dict[max_epoch_num]:
        if not os.path.isdir(params['save_checkpoint_dir']):
            logger.warning("save_checkpoint_dir %s does not exist, data will not be resumed", params['save_checkpoint_dir'])
            return
    # resume data
    self.epoch = max_epoch_num
    self.finished_trials = self.population_size
    for params, value in epoch_data_dict[max_epoch_num]:
        checkpoint_dir = os.path.dirname(params['save_checkpoint_dir'])
        self.finished.append(TrialInfo(checkpoint_dir=checkpoint_dir, hyper_parameters=params, score=value))
    self._proceed_next_epoch()
    logger.info("Successfully import data to PBT tuner, total data: %d, imported data: %d.", len(data), self.population_size)
    logger.info("Start from epoch %d ...", self.epoch)
    return self.epoch  # return for test
...@@ -159,6 +159,62 @@ class BuiltinTunersTestCase(TestCase): ...@@ -159,6 +159,62 @@ class BuiltinTunersTestCase(TestCase):
logger.info("Full supported search space: %s", full_supported_search_space) logger.info("Full supported search space: %s", full_supported_search_space)
self.search_space_test_one(tuner_factory, full_supported_search_space) self.search_space_test_one(tuner_factory, full_supported_search_space)
def import_data_test_for_pbt(self):
    """
    test1: import data with complete epoch
    test2: import data with incomplete epoch
    """
    search_space = {
        "choice_str": {
            "_type": "choice",
            "_value": ["cat", "dog", "elephant", "cow", "sheep", "panda"]
        }
    }
    all_checkpoint_dir = os.path.expanduser("~/nni/checkpoint/test/")
    population_size = 4
    # One epoch worth of (choice, value) pairs; values cover scalar and dict metrics.
    records = [("cat", 1.1),
               ("dog", {"default": 1.2, "tmp": 2}),
               ("cat", 11),
               ("cat", 7)]

    def make_tuner():
        # Fresh tuner with the search space applied.
        tuner = PBTTuner(
            all_checkpoint_dir=all_checkpoint_dir,
            population_size=population_size
        )
        self.assertIsInstance(tuner, Tuner)
        tuner.update_search_space(search_space)
        return tuner

    def build_data(dirs):
        # for simplicity, omit "load_checkpoint_dir"; the records cycle
        # so a 7-dir list repeats the first three records for epoch 1.
        pairs = records + records[:len(dirs) - len(records)]
        return [{"parameter": {"choice_str": choice, "save_checkpoint_dir": directory}, "value": value}
                for (choice, value), directory in zip(pairs, dirs)]

    # ===import data at the beginning===
    tuner = make_tuner()
    save_dirs = [os.path.join(all_checkpoint_dir, str(i), str(0)) for i in range(population_size)]
    # create save checkpoint directory
    for directory in save_dirs:
        os.makedirs(directory, exist_ok=True)
    self.assertEqual(tuner.import_data(build_data(save_dirs)), 1)
    logger.info("Imported data successfully at the beginning")
    shutil.rmtree(all_checkpoint_dir)
    # ===import another data at the beginning, test the case when there is an incompleted epoch===
    tuner = make_tuner()
    save_dirs.extend(os.path.join(all_checkpoint_dir, str(i), str(1)) for i in range(population_size - 1))
    for directory in save_dirs:
        os.makedirs(directory, exist_ok=True)
    self.assertEqual(tuner.import_data(build_data(save_dirs)), 1)
    logger.info("Imported data successfully at the beginning with incomplete epoch")
    shutil.rmtree(all_checkpoint_dir)
def import_data_test(self, tuner_factory, stype="choice_str"): def import_data_test(self, tuner_factory, stype="choice_str"):
""" """
import data at the beginning with number value and dict value import data at the beginning with number value and dict value
...@@ -297,6 +353,7 @@ class BuiltinTunersTestCase(TestCase): ...@@ -297,6 +353,7 @@ class BuiltinTunersTestCase(TestCase):
all_checkpoint_dir=os.path.expanduser("~/nni/checkpoint/test/"), all_checkpoint_dir=os.path.expanduser("~/nni/checkpoint/test/"),
population_size=100 population_size=100
)) ))
self.import_data_test_for_pbt()
def tearDown(self): def tearDown(self):
file_list = glob.glob("smac3*") + ["param_config_space.pcs", "scenario.txt", "model_path"] file_list = glob.glob("smac3*") + ["param_config_space.pcs", "scenario.txt", "model_path"]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment