Unverified Commit d2c57770 authored by RayMeng8's avatar RayMeng8 Committed by GitHub
Browse files

Add supported data types for PBT tuner (#2271)

parent c61700f3
...@@ -155,8 +155,8 @@ def get_params(): ...@@ -155,8 +155,8 @@ def get_params():
help='learning rate (default: 0.01)') help='learning rate (default: 0.01)')
parser.add_argument('--momentum', type=float, default=0.5, metavar='M', parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
help='SGD momentum (default: 0.5)') help='SGD momentum (default: 0.5)')
parser.add_argument('--epochs', type=int, default=10, metavar='N', parser.add_argument('--epochs', type=int, default=1, metavar='N',
help='number of epochs to train (default: 10)') help='number of epochs to train (default: 1)')
parser.add_argument('--seed', type=int, default=1, metavar='S', parser.add_argument('--seed', type=int, default=1, metavar='S',
help='random seed (default: 1)') help='random seed (default: 1)')
parser.add_argument('--no_cuda', action='store_true', default=False, parser.add_argument('--no_cuda', action='store_true', default=False,
......
...@@ -4,9 +4,11 @@ ...@@ -4,9 +4,11 @@
import copy import copy
import logging import logging
import os import os
import random
import numpy as np import numpy as np
import nni import nni
import nni.parameter_expressions
from nni.tuner import Tuner from nni.tuner import Tuner
from nni.utils import OptimizeMode, extract_scalar_reward, split_index, json2parameter, json2space from nni.utils import OptimizeMode, extract_scalar_reward, split_index, json2parameter, json2space
...@@ -14,7 +16,42 @@ from nni.utils import OptimizeMode, extract_scalar_reward, split_index, json2par ...@@ -14,7 +16,42 @@ from nni.utils import OptimizeMode, extract_scalar_reward, split_index, json2par
logger = logging.getLogger('pbt_tuner_AutoML') logger = logging.getLogger('pbt_tuner_AutoML')
def exploit_and_explore(bot_trial_info, top_trial_info, factors, epoch, search_space): def perturbation(hyperparameter_type, value, resample_probablity, uv, ub, lv, lb, random_state):
"""
Perturbation for hyperparameters
Parameters
----------
hyperparameter_type : str
type of hyperparameter
value : list
parameters for sampling hyperparameter
resample_probability : float
probability for resampling
uv : float/int
upper value after perturbation
ub : float/int
upper bound
lv : float/int
lower value after perturbation
lb : float/int
lower bound
random_state : RandomState
random state
"""
if random.random() < resample_probablity:
if hyperparameter_type == "choice":
return value.index(nni.parameter_expressions.choice(value, random_state))
else:
return getattr(nni.parameter_expressions, hyperparameter_type)(*(value + [random_state]))
else:
if random.random() > 0.5:
return min(uv, ub)
else:
return max(lv, lb)
def exploit_and_explore(bot_trial_info, top_trial_info, factor, resample_probability, epoch, search_space):
""" """
Replace checkpoint of bot_trial with top, and perturb hyperparameters Replace checkpoint of bot_trial with top, and perturb hyperparameters
...@@ -24,8 +61,10 @@ def exploit_and_explore(bot_trial_info, top_trial_info, factors, epoch, search_s ...@@ -24,8 +61,10 @@ def exploit_and_explore(bot_trial_info, top_trial_info, factors, epoch, search_s
bottom model whose parameters should be replaced bottom model whose parameters should be replaced
top_trial_info : TrialInfo top_trial_info : TrialInfo
better model better model
factors : float factor : float
factors for perturbation factor for perturbation
resample_probability : float
probability for resampling
epoch : int epoch : int
step of PBTTuner step of PBTTuner
search_space : dict search_space : dict
...@@ -34,21 +73,72 @@ def exploit_and_explore(bot_trial_info, top_trial_info, factors, epoch, search_s ...@@ -34,21 +73,72 @@ def exploit_and_explore(bot_trial_info, top_trial_info, factors, epoch, search_s
bot_checkpoint_dir = bot_trial_info.checkpoint_dir bot_checkpoint_dir = bot_trial_info.checkpoint_dir
top_hyper_parameters = top_trial_info.hyper_parameters top_hyper_parameters = top_trial_info.hyper_parameters
hyper_parameters = copy.deepcopy(top_hyper_parameters) hyper_parameters = copy.deepcopy(top_hyper_parameters)
# TODO think about different type of hyperparameters for 1.perturbation 2.within search space random_state = np.random.RandomState()
for key in hyper_parameters.keys(): for key in hyper_parameters.keys():
hyper_parameter = hyper_parameters[key]
if key == 'load_checkpoint_dir': if key == 'load_checkpoint_dir':
hyper_parameters[key] = hyper_parameters['save_checkpoint_dir'] hyper_parameters[key] = hyper_parameters['save_checkpoint_dir']
continue
elif key == 'save_checkpoint_dir': elif key == 'save_checkpoint_dir':
hyper_parameters[key] = os.path.join(bot_checkpoint_dir, str(epoch)) hyper_parameters[key] = os.path.join(bot_checkpoint_dir, str(epoch))
elif isinstance(hyper_parameters[key], float): continue
perturb = np.random.choice(factors) elif search_space[key]["_type"] == "choice":
val = hyper_parameters[key] * perturb choices = search_space[key]["_value"]
ub, uv = len(choices) - 1, choices.index(hyper_parameter["_value"]) + 1
lb, lv = 0, choices.index(hyper_parameter["_value"]) - 1
elif search_space[key]["_type"] == "randint":
lb, ub = search_space[key]["_value"][:2] lb, ub = search_space[key]["_value"][:2]
if search_space[key]["_type"] in ("uniform", "normal"): ub -= 1
val = np.clip(val, lb, ub).item() uv = hyper_parameter + 1
hyper_parameters[key] = val lv = hyper_parameter - 1
elif search_space[key]["_type"] == "uniform":
lb, ub = search_space[key]["_value"][:2]
perturb = (ub - lb) * factor
uv = hyper_parameter + perturb
lv = hyper_parameter - perturb
elif search_space[key]["_type"] == "quniform":
lb, ub, q = search_space[key]["_value"][:3]
multi = round(hyper_parameter / q)
uv = (multi + 1) * q
lv = (multi - 1) * q
elif search_space[key]["_type"] == "loguniform":
lb, ub = search_space[key]["_value"][:2]
perturb = (np.log(ub) - np.log(lb)) * factor
uv = np.exp(min(np.log(hyper_parameter) + perturb, np.log(ub)))
lv = np.exp(max(np.log(hyper_parameter) - perturb, np.log(lb)))
elif search_space[key]["_type"] == "qloguniform":
lb, ub, q = search_space[key]["_value"][:3]
multi = round(hyper_parameter / q)
uv = (multi + 1) * q
lv = (multi - 1) * q
elif search_space[key]["_type"] == "normal":
sigma = search_space[key]["_value"][1]
perturb = sigma * factor
uv = ub = hyper_parameter + perturb
lv = lb = hyper_parameter - perturb
elif search_space[key]["_type"] == "qnormal":
q = search_space[key]["_value"][2]
uv = ub = hyper_parameter + q
lv = lb = hyper_parameter - q
elif search_space[key]["_type"] == "lognormal":
sigma = search_space[key]["_value"][1]
perturb = sigma * factor
uv = ub = np.exp(np.log(hyper_parameter) + perturb)
lv = lb = np.exp(np.log(hyper_parameter) - perturb)
elif search_space[key]["_type"] == "qlognormal":
q = search_space[key]["_value"][2]
uv = ub = hyper_parameter + q
lv, lb = hyper_parameter - q, 1E-10
else: else:
logger.warning("Illegal type to perturb: %s", search_space[key]["_type"])
continue continue
if search_space[key]["_type"] == "choice":
idx = perturbation(search_space[key]["_type"], search_space[key]["_value"],
resample_probability, uv, ub, lv, lb, random_state)
hyper_parameters[key] = {'_index': idx, '_value': choices[idx]}
else:
hyper_parameters[key] = perturbation(search_space[key]["_type"], search_space[key]["_value"],
resample_probability, uv, ub, lv, lb, random_state)
bot_trial_info.hyper_parameters = hyper_parameters bot_trial_info.hyper_parameters = hyper_parameters
bot_trial_info.clean_id() bot_trial_info.clean_id()
...@@ -70,7 +160,8 @@ class TrialInfo: ...@@ -70,7 +160,8 @@ class TrialInfo:
class PBTTuner(Tuner): class PBTTuner(Tuner):
def __init__(self, optimize_mode="maximize", all_checkpoint_dir=None, population_size=10, factors=(1.2, 0.8), fraction=0.2): def __init__(self, optimize_mode="maximize", all_checkpoint_dir=None, population_size=10, factor=0.2,
resample_probability=0.25, fraction=0.2):
""" """
Initialization Initialization
...@@ -82,8 +173,10 @@ class PBTTuner(Tuner): ...@@ -82,8 +173,10 @@ class PBTTuner(Tuner):
directory to store training model checkpoint directory to store training model checkpoint
population_size : int population_size : int
number of trials for each epoch number of trials for each epoch
factors : tuple factor : float
factors for perturbation factor for perturbation
resample_probability : float
probability for resampling
fraction : float fraction : float
fraction for selecting bottom and top trials fraction for selecting bottom and top trials
""" """
...@@ -93,7 +186,8 @@ class PBTTuner(Tuner): ...@@ -93,7 +186,8 @@ class PBTTuner(Tuner):
logger.info("Checkpoint dir is set to %s by default.", all_checkpoint_dir) logger.info("Checkpoint dir is set to %s by default.", all_checkpoint_dir)
self.all_checkpoint_dir = all_checkpoint_dir self.all_checkpoint_dir = all_checkpoint_dir
self.population_size = population_size self.population_size = population_size
self.factors = factors self.factor = factor
self.resample_probability = resample_probability
self.fraction = fraction self.fraction = fraction
# defined in trial code # defined in trial code
#self.perturbation_interval = perturbation_interval #self.perturbation_interval = perturbation_interval
...@@ -237,7 +331,7 @@ class PBTTuner(Tuner): ...@@ -237,7 +331,7 @@ class PBTTuner(Tuner):
bottoms = self.finished[self.finished_trials - cutoff:] bottoms = self.finished[self.finished_trials - cutoff:]
for bottom in bottoms: for bottom in bottoms:
top = np.random.choice(tops) top = np.random.choice(tops)
exploit_and_explore(bottom, top, self.factors, self.epoch, self.searchspace_json) exploit_and_explore(bottom, top, self.factor, self.resample_probability, self.epoch, self.searchspace_json)
for trial in self.finished: for trial in self.finished:
if trial not in bottoms: if trial not in bottoms:
trial.clean_id() trial.clean_id()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment