Commit 693cf20f authored by xuehui's avatar xuehui
Browse files

add batch_tuner

parent 9f7cce63
# Copyright (c) Microsoft Corporation
# All rights reserved.
#
# MIT License
#
# Permission is hereby granted, free of charge,
# to any person obtaining a copy of this software and associated
# documentation files (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and
# to permit persons to whom the Software is furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included
# in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
# BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
'''
batch_tuner.py including:
class BatchTuner
'''
import copy
from enum import Enum, unique
import random
import numpy as np
from nni.tuner import Tuner
from . import parameter_expressions
TYPE = '_type'
CHOICE = 'choice'
VALUE = '_value'
class BatchTuner(Tuner):
'''
BatchTuner is tuner will running all the configure that user want to run batchly.
The search space only be accepted like:
{
'combine_params': { '_type': 'choice',
'_value': '[{...}, {...}, {...}]',
}
}
'''
def __init__(self):
self.count = -1
self.values = []
def is_valid(self, search_space)
'''
Check the search space is valid: only contains 'choice' type
'''
if not len(search_space) == 1:
raise RuntimeException('BatchTuner only supprt one combined-paramreters key.')
for param in search_space:
param_type = param[TYPE]
if param_type is not CHOICE:
raise RuntimeException('BatchTuner only supprt one combined-paramreters type is choice.')
else:
if isinstance(param[VALUE], list):
return param[VALUE]
raise RuntimeException('The combined-paramreters value in BatchTuner is not a list.')
return None
def update_search_space(self, search_space):
self.values = is_valid(search_space)
def generate_parameters(self, parameter_id):
count +=1
if count>len(self.value)-1:
return None
return self.values[count]
def receive_trial_result(self, parameter_id, parameters, reward):
pass
\ No newline at end of file
...@@ -89,7 +89,12 @@ class MsgDispatcher(MsgDispatcherBase): ...@@ -89,7 +89,12 @@ class MsgDispatcher(MsgDispatcherBase):
# data: number or trial jobs # data: number or trial jobs
ids = [_create_parameter_id() for _ in range(data)] ids = [_create_parameter_id() for _ in range(data)]
params_list = self.tuner.generate_multiple_parameters(ids) params_list = self.tuner.generate_multiple_parameters(ids)
assert len(ids) == len(params_list) #assert len(ids) == len(params_list)
# when parameters is None.
if len(params_list) == 0:
send(CommandType.NoMoreTrialJobs, _pack_parameter(ids[0], ''))
else:
for i, _ in enumerate(ids): for i, _ in enumerate(ids):
send(CommandType.NewTrialJob, _pack_parameter(ids[i], params_list[i])) send(CommandType.NewTrialJob, _pack_parameter(ids[i], params_list[i]))
return True return True
......
...@@ -42,7 +42,12 @@ class Tuner(Recoverable): ...@@ -42,7 +42,12 @@ class Tuner(Recoverable):
User code must override either this function or 'generate_parameters()'. User code must override either this function or 'generate_parameters()'.
parameter_id_list: list of int parameter_id_list: list of int
""" """
return [self.generate_parameters(parameter_id) for parameter_id in parameter_id_list] result = []
for parameter_id in parameter_id_list:
temp = self.generate_parameters(parameter_id)
if temp:
result.append(temp)
return result
def receive_trial_result(self, parameter_id, parameters, reward): def receive_trial_result(self, parameter_id, parameters, reward):
"""Invoked when a trial reports its final result. Must override. """Invoked when a trial reports its final result. Must override.
......
...@@ -90,7 +90,8 @@ def parse_tuner_content(experiment_config): ...@@ -90,7 +90,8 @@ def parse_tuner_content(experiment_config):
tuner_class_name_dict = {'TPE': 'HyperoptTuner',\ tuner_class_name_dict = {'TPE': 'HyperoptTuner',\
'Random': 'HyperoptTuner',\ 'Random': 'HyperoptTuner',\
'Anneal': 'HyperoptTuner',\ 'Anneal': 'HyperoptTuner',\
'Evolution': 'EvolutionTuner'} 'Evolution': 'EvolutionTuner',\
'BatchTuning': 'BatchTuner'}
tuner_algorithm_name_dict = {'TPE': 'tpe',\ tuner_algorithm_name_dict = {'TPE': 'tpe',\
'Random': 'random_search',\ 'Random': 'random_search',\
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment