Unverified Commit c577553d authored by SparkSnail's avatar SparkSnail Committed by GitHub
Browse files

Merge pull request #166 from Microsoft/master

merge master
parents 50326948 c7cc8db3
...@@ -164,6 +164,11 @@ class GridSearchTuner(Tuner): ...@@ -164,6 +164,11 @@ class GridSearchTuner(Tuner):
_completed_num += 1 _completed_num += 1
assert "parameter" in trial_info assert "parameter" in trial_info
_params = trial_info["parameter"] _params = trial_info["parameter"]
assert "value" in trial_info
_value = trial_info['value']
if not _value:
logger.info("Useless trial data, value is %s, skip this trial data." %_value)
continue
_params_tuple = convert_dict2tuple(_params) _params_tuple = convert_dict2tuple(_params)
self.supplement_data[_params_tuple] = True self.supplement_data[_params_tuple] = True
logger.info("Successfully import data to grid search tuner.") logger.info("Successfully import data to grid search tuner.")
...@@ -139,19 +139,50 @@ def json2vals(in_x, vals, out_y, name=ROOT): ...@@ -139,19 +139,50 @@ def json2vals(in_x, vals, out_y, name=ROOT):
for i, temp in enumerate(in_x): for i, temp in enumerate(in_x):
json2vals(temp, vals[i], out_y, name + '[%d]' % i) json2vals(temp, vals[i], out_y, name + '[%d]' % i)
def _add_index(in_x, parameter):
"""
change parameters in NNI format to parameters in hyperopt format(This function also support nested dict.).
For example, receive parameters like:
{'dropout_rate': 0.8, 'conv_size': 3, 'hidden_size': 512}
Will change to format in hyperopt, like:
{'dropout_rate': 0.8, 'conv_size': {'_index': 1, '_value': 3}, 'hidden_size': {'_index': 1, '_value': 512}}
"""
if TYPE not in in_x: # if at the top level
out_y = dict()
for key, value in parameter.items():
out_y[key] = _add_index(in_x[key], value)
return out_y
elif isinstance(in_x, dict):
value_type = in_x[TYPE]
value_format = in_x[VALUE]
if value_type == "choice":
choice_name = parameter[0] if isinstance(parameter, list) else parameter
for pos, item in enumerate(value_format): # here value_format is a list
if isinstance(item, list): # this format is ["choice_key", format_dict]
choice_key = item[0]
choice_value_format = item[1]
if choice_key == choice_name:
return {INDEX: pos, VALUE: [choice_name, _add_index(choice_value_format, parameter[1])]}
elif choice_name == item:
return {INDEX: pos, VALUE: item}
else:
return parameter
def _split_index(params): def _split_index(params):
""" """
Delete index infromation from params Delete index infromation from params
""" """
result = {} if isinstance(params, list):
for key in params: return [params[0], _split_index(params[1])]
if isinstance(params[key], dict): elif isinstance(params, dict):
value = params[key][VALUE] if INDEX in params.keys():
else: return _split_index(params[VALUE])
value = params[key] result = dict()
result[key] = value for key in params:
return result result[key] = _split_index(params[key])
return result
else:
return params
class HyperoptTuner(Tuner): class HyperoptTuner(Tuner):
...@@ -373,8 +404,11 @@ class HyperoptTuner(Tuner): ...@@ -373,8 +404,11 @@ class HyperoptTuner(Tuner):
_params = trial_info["parameter"] _params = trial_info["parameter"]
assert "value" in trial_info assert "value" in trial_info
_value = trial_info['value'] _value = trial_info['value']
if not _value:
logger.info("Useless trial data, value is %s, skip this trial data." %_value)
continue
self.supplement_data_num += 1 self.supplement_data_num += 1
_parameter_id = '_'.join(["ImportData", str(self.supplement_data_num)]) _parameter_id = '_'.join(["ImportData", str(self.supplement_data_num)])
self.total_data[_parameter_id] = _params self.total_data[_parameter_id] = _add_index(in_x=self.json, parameter=_params)
self.receive_trial_result(parameter_id=_parameter_id, parameters=_params, value=_value) self.receive_trial_result(parameter_id=_parameter_id, parameters=_params, value=_value)
logger.info("Successfully import data to TPE/Anneal tuner.") logger.info("Successfully import data to TPE/Anneal tuner.")
...@@ -65,7 +65,7 @@ class MetisTuner(Tuner): ...@@ -65,7 +65,7 @@ class MetisTuner(Tuner):
https://www.microsoft.com/en-us/research/publication/metis-robustly-tuning-tail-latencies-cloud-systems/ https://www.microsoft.com/en-us/research/publication/metis-robustly-tuning-tail-latencies-cloud-systems/
""" """
def __init__(self, optimize_mode="maximize", no_resampling=True, no_candidates=True, def __init__(self, optimize_mode="maximize", no_resampling=True, no_candidates=False,
selection_num_starting_points=600, cold_start_num=10, exploration_probability=0.9): selection_num_starting_points=600, cold_start_num=10, exploration_probability=0.9):
""" """
Parameters Parameters
...@@ -417,6 +417,9 @@ class MetisTuner(Tuner): ...@@ -417,6 +417,9 @@ class MetisTuner(Tuner):
_params = trial_info["parameter"] _params = trial_info["parameter"]
assert "value" in trial_info assert "value" in trial_info
_value = trial_info['value'] _value = trial_info['value']
if not _value:
logger.info("Useless trial data, value is %s, skip this trial data." %_value)
continue
self.supplement_data_num += 1 self.supplement_data_num += 1
_parameter_id = '_'.join(["ImportData", str(self.supplement_data_num)]) _parameter_id = '_'.join(["ImportData", str(self.supplement_data_num)])
self.total_data.append(_params) self.total_data.append(_params)
......
...@@ -77,6 +77,8 @@ class MsgDispatcherBase(Recoverable): ...@@ -77,6 +77,8 @@ class MsgDispatcherBase(Recoverable):
break break
else: else:
self.enqueue_command(command, data) self.enqueue_command(command, data)
if self.worker_exceptions:
break
_logger.info('Dispatcher exiting...') _logger.info('Dispatcher exiting...')
self.stopping = True self.stopping = True
......
...@@ -30,13 +30,15 @@ class DefaultPoint extends React.Component<DefaultPointProps, DefaultPointState> ...@@ -30,13 +30,15 @@ class DefaultPoint extends React.Component<DefaultPointProps, DefaultPointState>
const accSource: Array<DetailAccurPoint> = []; const accSource: Array<DetailAccurPoint> = [];
Object.keys(showSource).map(item => { Object.keys(showSource).map(item => {
const temp = showSource[item]; const temp = showSource[item];
if (temp.status === 'SUCCEEDED' && temp.acc.default !== undefined) { if (temp.status === 'SUCCEEDED' && temp.acc !== undefined) {
const searchSpace = temp.description.parameters; if (temp.acc.default !== undefined) {
accSource.push({ const searchSpace = temp.description.parameters;
acc: temp.acc.default, accSource.push({
index: temp.sequenceId, acc: temp.acc.default,
searchSpace: JSON.stringify(searchSpace) index: temp.sequenceId,
}); searchSpace: JSON.stringify(searchSpace)
});
}
} }
}); });
const resultList: Array<number | string>[] = []; const resultList: Array<number | string>[] = [];
......
...@@ -32,9 +32,12 @@ def get_yml_content(file_path): ...@@ -32,9 +32,12 @@ def get_yml_content(file_path):
try: try:
with open(file_path, 'r') as file: with open(file_path, 'r') as file:
return yaml.load(file, Loader=yaml.Loader) return yaml.load(file, Loader=yaml.Loader)
except TypeError as err: except yaml.scanner.ScannerError as err:
print('Error: ', err) print_error('yaml file format error!')
return None exit(1)
except Exception as exception:
print_error(exception)
exit(1)
def get_json_content(file_path): def get_json_content(file_path):
'''Load json file content''' '''Load json file content'''
...@@ -42,7 +45,7 @@ def get_json_content(file_path): ...@@ -42,7 +45,7 @@ def get_json_content(file_path):
with open(file_path, 'r') as file: with open(file_path, 'r') as file:
return json.load(file) return json.load(file)
except TypeError as err: except TypeError as err:
print('Error: ', err) print_error('json file format error!')
return None return None
def print_error(content): def print_error(content):
......
...@@ -113,8 +113,11 @@ def start_rest_server(port, platform, mode, config_file_name, experiment_id=None ...@@ -113,8 +113,11 @@ def start_rest_server(port, platform, mode, config_file_name, experiment_id=None
entry_dir = get_nni_installation_path() entry_dir = get_nni_installation_path()
entry_file = os.path.join(entry_dir, 'main.js') entry_file = os.path.join(entry_dir, 'main.js')
cmds = ['node', entry_file, '--port', str(port), '--mode', platform, '--start_mode', mode] node_command = 'node'
if sys.platform == 'win32':
node_command = os.path.join(entry_dir[:-3], 'Scripts', 'node.exe')
cmds = [node_command, entry_file, '--port', str(port), '--mode', platform, '--start_mode', mode]
if log_dir is not None: if log_dir is not None:
cmds += ['--log_dir', log_dir] cmds += ['--log_dir', log_dir]
if log_level is not None: if log_level is not None:
......
...@@ -136,7 +136,7 @@ def import_data(args): ...@@ -136,7 +136,7 @@ def import_data(args):
args.port = get_experiment_port(args) args.port = get_experiment_port(args)
if args.port is not None: if args.port is not None:
if import_data_to_restful_server(args, content): if import_data_to_restful_server(args, content):
print_normal('Import data success!') pass
else: else:
print_error('Import data failed!') print_error('Import data failed!')
......
...@@ -25,15 +25,20 @@ import time ...@@ -25,15 +25,20 @@ import time
from xml.dom import minidom from xml.dom import minidom
def check_ready_to_run(): def check_ready_to_run():
#TODO check process in windows
if sys.platform == 'win32': if sys.platform == 'win32':
return True pgrep_output = subprocess.check_output('wmic process where "CommandLine like \'%nni_gpu_tool.gpu_metrics_collector%\' and name like \'%python%\'" get processId')
pgrep_output =subprocess.check_output('pgrep -fx \'python3 -m nni_gpu_tool.gpu_metrics_collector\'', shell=True) pidList = pgrep_output.decode("utf-8").strip().split()
pidList = [] pidList.pop(0) # remove the key word 'ProcessId'
for pid in pgrep_output.splitlines(): pidList = list(map(int, pidList))
pidList.append(int(pid)) pidList.remove(os.getpid())
pidList.remove(os.getpid()) return len(pidList) == 0
return len(pidList) == 0 else:
pgrep_output =subprocess.check_output('pgrep -fx \'python3 -m nni_gpu_tool.gpu_metrics_collector\'', shell=True)
pidList = []
for pid in pgrep_output.splitlines():
pidList.append(int(pid))
pidList.remove(os.getpid())
return len(pidList) == 0
def main(argv): def main(argv):
metrics_output_dir = os.environ['METRIC_OUTPUT_DIR'] metrics_output_dir = os.environ['METRIC_OUTPUT_DIR']
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment