Unverified Commit b6894c1e authored by Erik Fäßler's avatar Erik Fäßler Committed by GitHub
Browse files

Pass ConfigSpace definition file directly to BOHB (#4153)

parent f1bfdd80
......@@ -4,7 +4,6 @@
'''
bohb_advisor.py
'''
import sys
import math
import logging
......@@ -12,6 +11,7 @@ import json_tricks
from schema import Schema, Optional
import ConfigSpace as CS
import ConfigSpace.hyperparameters as CSH
from ConfigSpace.read_and_write import pcs_new
from nni import ClassArgsValidator
from nni.runtime.protocol import CommandType, send
......@@ -244,6 +244,7 @@ class BOHBClassArgsValidator(ClassArgsValidator):
Optional('random_fraction'): self.range('random_fraction', float, 0, 9999),
Optional('bandwidth_factor'): self.range('bandwidth_factor', float, 0, 9999),
Optional('min_bandwidth'): self.range('min_bandwidth', float, 0, 9999),
Optional('config_space'): self.path('config_space')
}).validate(kwargs)
class BOHB(MsgDispatcherBase):
......@@ -297,7 +298,8 @@ class BOHB(MsgDispatcherBase):
num_samples=64,
random_fraction=1/3,
bandwidth_factor=3,
min_bandwidth=1e-3):
min_bandwidth=1e-3,
config_space=None):
super(BOHB, self).__init__()
self.optimize_mode = OptimizeMode(optimize_mode)
self.min_budget = min_budget
......@@ -309,6 +311,7 @@ class BOHB(MsgDispatcherBase):
self.random_fraction = random_fraction
self.bandwidth_factor = bandwidth_factor
self.min_bandwidth = min_bandwidth
self.config_space = config_space
# all the configs waiting for run
self.generated_hyper_configs = []
......@@ -468,6 +471,14 @@ class BOHB(MsgDispatcherBase):
search space of this experiment
"""
search_space = data
cs = None
logger.debug(f'Received data: {data}')
if self.config_space:
logger.info(f'Got a ConfigSpace file path, parsing the search space directly from {self.config_space}. '
'The NNI search space is ignored.')
with open(self.config_space, 'r') as fh:
cs = pcs_new.read(fh)
else:
cs = CS.ConfigurationSpace()
for var in search_space:
_type = str(search_space[var]["_type"])
......
......@@ -94,6 +94,10 @@ def parse_path(experiment_config, config_path):
parse_relative_path(root_path, experiment_config['assessor'], 'codeDir')
if experiment_config.get('advisor'):
parse_relative_path(root_path, experiment_config['advisor'], 'codeDir')
# for BOHB when delivering a ConfigSpace file directly
if experiment_config.get('advisor').get('classArgs') and experiment_config.get('advisor').get('classArgs').get('config_space'):
parse_relative_path(root_path, experiment_config.get('advisor').get('classArgs'), 'config_space')
if experiment_config.get('machineList'):
for index in range(len(experiment_config['machineList'])):
parse_relative_path(root_path, experiment_config['machineList'][index], 'sshKeyPath')
......
......@@ -4,6 +4,7 @@
import copy
import functools
from enum import Enum, unique
from pathlib import Path
import json_tricks
from schema import And
......@@ -305,3 +306,9 @@ class ClassArgsValidator(object):
And(keyType, error='%s should be %s type!' % (key, keyType.__name__)),
And(lambda n: start <= n <= end, error='%s should be in range of (%s, %s)!' % (key, start, end))
)
def path(self, key):
return And(
And(str, error='%s should be a string!' % key),
And(lambda p: Path(p).exists(), error='%s path does not exist!' % (key))
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment