Unverified Commit b6894c1e authored by Erik Fäßler's avatar Erik Fäßler Committed by GitHub
Browse files

Pass ConfigSpace definition file directly to BOHB (#4153)

parent f1bfdd80
...@@ -4,7 +4,6 @@ ...@@ -4,7 +4,6 @@
''' '''
bohb_advisor.py bohb_advisor.py
''' '''
import sys import sys
import math import math
import logging import logging
...@@ -12,6 +11,7 @@ import json_tricks ...@@ -12,6 +11,7 @@ import json_tricks
from schema import Schema, Optional from schema import Schema, Optional
import ConfigSpace as CS import ConfigSpace as CS
import ConfigSpace.hyperparameters as CSH import ConfigSpace.hyperparameters as CSH
from ConfigSpace.read_and_write import pcs_new
from nni import ClassArgsValidator from nni import ClassArgsValidator
from nni.runtime.protocol import CommandType, send from nni.runtime.protocol import CommandType, send
...@@ -244,6 +244,7 @@ class BOHBClassArgsValidator(ClassArgsValidator): ...@@ -244,6 +244,7 @@ class BOHBClassArgsValidator(ClassArgsValidator):
Optional('random_fraction'): self.range('random_fraction', float, 0, 9999), Optional('random_fraction'): self.range('random_fraction', float, 0, 9999),
Optional('bandwidth_factor'): self.range('bandwidth_factor', float, 0, 9999), Optional('bandwidth_factor'): self.range('bandwidth_factor', float, 0, 9999),
Optional('min_bandwidth'): self.range('min_bandwidth', float, 0, 9999), Optional('min_bandwidth'): self.range('min_bandwidth', float, 0, 9999),
Optional('config_space'): self.path('config_space')
}).validate(kwargs) }).validate(kwargs)
class BOHB(MsgDispatcherBase): class BOHB(MsgDispatcherBase):
...@@ -297,7 +298,8 @@ class BOHB(MsgDispatcherBase): ...@@ -297,7 +298,8 @@ class BOHB(MsgDispatcherBase):
num_samples=64, num_samples=64,
random_fraction=1/3, random_fraction=1/3,
bandwidth_factor=3, bandwidth_factor=3,
min_bandwidth=1e-3): min_bandwidth=1e-3,
config_space=None):
super(BOHB, self).__init__() super(BOHB, self).__init__()
self.optimize_mode = OptimizeMode(optimize_mode) self.optimize_mode = OptimizeMode(optimize_mode)
self.min_budget = min_budget self.min_budget = min_budget
...@@ -309,6 +311,7 @@ class BOHB(MsgDispatcherBase): ...@@ -309,6 +311,7 @@ class BOHB(MsgDispatcherBase):
self.random_fraction = random_fraction self.random_fraction = random_fraction
self.bandwidth_factor = bandwidth_factor self.bandwidth_factor = bandwidth_factor
self.min_bandwidth = min_bandwidth self.min_bandwidth = min_bandwidth
self.config_space = config_space
# all the configs waiting for run # all the configs waiting for run
self.generated_hyper_configs = [] self.generated_hyper_configs = []
...@@ -468,48 +471,56 @@ class BOHB(MsgDispatcherBase): ...@@ -468,48 +471,56 @@ class BOHB(MsgDispatcherBase):
search space of this experiment search space of this experiment
""" """
search_space = data search_space = data
cs = CS.ConfigurationSpace() cs = None
for var in search_space: logger.debug(f'Received data: {data}')
_type = str(search_space[var]["_type"]) if self.config_space:
if _type == 'choice': logger.info(f'Got a ConfigSpace file path, parsing the search space directly from {self.config_space}. '
cs.add_hyperparameter(CSH.CategoricalHyperparameter( 'The NNI search space is ignored.')
var, choices=search_space[var]["_value"])) with open(self.config_space, 'r') as fh:
elif _type == 'randint': cs = pcs_new.read(fh)
cs.add_hyperparameter(CSH.UniformIntegerHyperparameter( else:
var, lower=search_space[var]["_value"][0], upper=search_space[var]["_value"][1] - 1)) cs = CS.ConfigurationSpace()
elif _type == 'uniform': for var in search_space:
cs.add_hyperparameter(CSH.UniformFloatHyperparameter( _type = str(search_space[var]["_type"])
var, lower=search_space[var]["_value"][0], upper=search_space[var]["_value"][1])) if _type == 'choice':
elif _type == 'quniform': cs.add_hyperparameter(CSH.CategoricalHyperparameter(
cs.add_hyperparameter(CSH.UniformFloatHyperparameter( var, choices=search_space[var]["_value"]))
var, lower=search_space[var]["_value"][0], upper=search_space[var]["_value"][1], elif _type == 'randint':
q=search_space[var]["_value"][2])) cs.add_hyperparameter(CSH.UniformIntegerHyperparameter(
elif _type == 'loguniform': var, lower=search_space[var]["_value"][0], upper=search_space[var]["_value"][1] - 1))
cs.add_hyperparameter(CSH.UniformFloatHyperparameter( elif _type == 'uniform':
var, lower=search_space[var]["_value"][0], upper=search_space[var]["_value"][1], cs.add_hyperparameter(CSH.UniformFloatHyperparameter(
log=True)) var, lower=search_space[var]["_value"][0], upper=search_space[var]["_value"][1]))
elif _type == 'qloguniform': elif _type == 'quniform':
cs.add_hyperparameter(CSH.UniformFloatHyperparameter( cs.add_hyperparameter(CSH.UniformFloatHyperparameter(
var, lower=search_space[var]["_value"][0], upper=search_space[var]["_value"][1], var, lower=search_space[var]["_value"][0], upper=search_space[var]["_value"][1],
q=search_space[var]["_value"][2], log=True)) q=search_space[var]["_value"][2]))
elif _type == 'normal': elif _type == 'loguniform':
cs.add_hyperparameter(CSH.NormalFloatHyperparameter( cs.add_hyperparameter(CSH.UniformFloatHyperparameter(
var, mu=search_space[var]["_value"][1], sigma=search_space[var]["_value"][2])) var, lower=search_space[var]["_value"][0], upper=search_space[var]["_value"][1],
elif _type == 'qnormal': log=True))
cs.add_hyperparameter(CSH.NormalFloatHyperparameter( elif _type == 'qloguniform':
var, mu=search_space[var]["_value"][1], sigma=search_space[var]["_value"][2], cs.add_hyperparameter(CSH.UniformFloatHyperparameter(
q=search_space[var]["_value"][3])) var, lower=search_space[var]["_value"][0], upper=search_space[var]["_value"][1],
elif _type == 'lognormal': q=search_space[var]["_value"][2], log=True))
cs.add_hyperparameter(CSH.NormalFloatHyperparameter( elif _type == 'normal':
var, mu=search_space[var]["_value"][1], sigma=search_space[var]["_value"][2], cs.add_hyperparameter(CSH.NormalFloatHyperparameter(
log=True)) var, mu=search_space[var]["_value"][1], sigma=search_space[var]["_value"][2]))
elif _type == 'qlognormal': elif _type == 'qnormal':
cs.add_hyperparameter(CSH.NormalFloatHyperparameter( cs.add_hyperparameter(CSH.NormalFloatHyperparameter(
var, mu=search_space[var]["_value"][1], sigma=search_space[var]["_value"][2], var, mu=search_space[var]["_value"][1], sigma=search_space[var]["_value"][2],
q=search_space[var]["_value"][3], log=True)) q=search_space[var]["_value"][3]))
else: elif _type == 'lognormal':
raise ValueError( cs.add_hyperparameter(CSH.NormalFloatHyperparameter(
'unrecognized type in search_space, type is {}'.format(_type)) var, mu=search_space[var]["_value"][1], sigma=search_space[var]["_value"][2],
log=True))
elif _type == 'qlognormal':
cs.add_hyperparameter(CSH.NormalFloatHyperparameter(
var, mu=search_space[var]["_value"][1], sigma=search_space[var]["_value"][2],
q=search_space[var]["_value"][3], log=True))
else:
raise ValueError(
'unrecognized type in search_space, type is {}'.format(_type))
self.search_space = cs self.search_space = cs
......
...@@ -94,6 +94,10 @@ def parse_path(experiment_config, config_path): ...@@ -94,6 +94,10 @@ def parse_path(experiment_config, config_path):
parse_relative_path(root_path, experiment_config['assessor'], 'codeDir') parse_relative_path(root_path, experiment_config['assessor'], 'codeDir')
if experiment_config.get('advisor'): if experiment_config.get('advisor'):
parse_relative_path(root_path, experiment_config['advisor'], 'codeDir') parse_relative_path(root_path, experiment_config['advisor'], 'codeDir')
# for BOHB when delivering a ConfigSpace file directly
if experiment_config.get('advisor').get('classArgs') and experiment_config.get('advisor').get('classArgs').get('config_space'):
parse_relative_path(root_path, experiment_config.get('advisor').get('classArgs'), 'config_space')
if experiment_config.get('machineList'): if experiment_config.get('machineList'):
for index in range(len(experiment_config['machineList'])): for index in range(len(experiment_config['machineList'])):
parse_relative_path(root_path, experiment_config['machineList'][index], 'sshKeyPath') parse_relative_path(root_path, experiment_config['machineList'][index], 'sshKeyPath')
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
import copy import copy
import functools import functools
from enum import Enum, unique from enum import Enum, unique
from pathlib import Path
import json_tricks import json_tricks
from schema import And from schema import And
...@@ -305,3 +306,9 @@ class ClassArgsValidator(object): ...@@ -305,3 +306,9 @@ class ClassArgsValidator(object):
And(keyType, error='%s should be %s type!' % (key, keyType.__name__)), And(keyType, error='%s should be %s type!' % (key, keyType.__name__)),
And(lambda n: start <= n <= end, error='%s should be in range of (%s, %s)!' % (key, start, end)) And(lambda n: start <= n <= end, error='%s should be in range of (%s, %s)!' % (key, start, end))
) )
def path(self, key):
return And(
And(str, error='%s should be a string!' % key),
And(lambda p: Path(p).exists(), error='%s path does not exist!' % (key))
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment