Commit 252f36f8 authored by Deshui Yu's avatar Deshui Yu
Browse files

NNI dogfood version 1

parent 781cea26
# Copyright (c) Microsoft Corporation
# All rights reserved.
#
# MIT License
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
# documentation files (the "Software"), to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and
# to permit persons to whom the Software is furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
# BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
import argparse
import logging
from customer_tuner import CustomerTuner, OptimizeMode
logger = logging.getLogger('nni.ga_customer_tuner')
logger.debug('START')
def main():
parser = argparse.ArgumentParser(description='parse command line parameters.')
parser.add_argument('--optimize_mode', type=str, default='maximize',
help='Select optimize mode for Tuner: minimize or maximize.')
FLAGS, unparsed = parser.parse_known_args()
if FLAGS.optimize_mode not in [ mode.value for mode in OptimizeMode ]:
raise AttributeError('Unsupported optimize mode "%s"' % FLAGS.optimize_mode)
tuner = CustomerTuner(FLAGS.optimize_mode)
tuner.run()
try:
main()
except Exception as e:
logger.exception(e)
raise
# Copyright (c) Microsoft Corporation
# All rights reserved.
#
# MIT License
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
# documentation files (the "Software"), to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and
# to permit persons to whom the Software is furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
# BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
from graph import *
import copy
import json
import logging
import random
import numpy as np
from nni.tuner import Tuner
logger = logging.getLogger('ga_customer_tuner')
@unique
class OptimizeMode(Enum):
Minimize = 'minimize'
Maximize = 'maximize'
def init_population(population_size=32):
population = []
graph = Graph(4,
input=[Layer(LayerType.input.value, output=[4, 5], size='x'), Layer(LayerType.input.value, output=[4, 5], size='y')],
output=[Layer(LayerType.output.value, input=[4], size='x'), Layer(LayerType.output.value, input=[5], size='y')],
hide=[Layer(LayerType.attention.value, input=[0, 1], output=[2]), Layer(LayerType.attention.value, input=[1, 0], output=[3])])
for _ in range(population_size):
g = copy.deepcopy(graph)
for _ in range(1):
g.mutation()
population.append(Individual(g, result=None))
return population
class Individual(object):
def __init__(self, config=None, info=None, result=None, save_dir=None):
self.config = config
self.result = result
self.info = info
self.restore_dir = None
self.save_dir = save_dir
def __str__(self):
return "info: " + str(self.info) + ", config :" + str(self.config) + ", result: " + str(self.result)
def mutation(self, config=None, info=None, save_dir=None):
self.result = None
self.config = config
self.config.mutation()
self.restore_dir = self.save_dir
self.save_dir = save_dir
self.info = info
class CustomerTuner(Tuner):
def __init__(self, optimize_mode, population_size = 32):
self.optimize_mode = OptimizeMode(optimize_mode)
self.population = init_population(population_size)
assert len(self.population) == population_size
logger.debug('init population done.')
return
def generate_parameters(self, parameter_id):
"""Returns a set of trial graph config, as a serializable object.
parameter_id : int
"""
if len(self.population) <= 0:
logger.debug("the len of poplution lower than zero.")
raise Exception('The population is empty')
pos = -1
for i in range(len(self.population)):
if self.population[i].result == None:
pos = i
break
if pos != -1:
indiv = copy.deepcopy(self.population[pos])
self.population.pop(pos)
temp = json.loads(graph_dumps(indiv.config))
else:
random.shuffle(self.population)
if self.population[0].result > self.population[1].result:
self.population[0] = self.population[1]
indiv = copy.deepcopy(self.population[0])
self.population.pop(1)
indiv.mutation()
graph = indiv.config
temp = json.loads(graph_dumps(graph))
logger.debug('generate_parameter return value is:')
logger.debug(temp)
return temp
def receive_trial_result(self, parameter_id, parameters, reward):
'''
Record an observation of the objective function
parameter_id : int
parameters : dict of parameters
reward : reward of one trial
'''
if self.optimize_mode is OptimizeMode.Minimize:
reward = -reward
logger.debug('receive trial result is:\n')
logger.debug(str(parameters))
logger.debug(str(reward))
indiv = graph_loads(parameters)
indiv.result = reward
self.population.append(indiv)
return
def update_search_space(self, data):
pass
if __name__ =='__main__':
tuner = CustomerTuner(OptimizeMode.Maximize)
config = tuner.generate_parameter(0)
with open('./data.json', 'w') as outfile:
json.dump(config, outfile)
tuner.receive_trial_result(0, config, 0.99)
# -*- coding: utf-8 -*-
import copy
import json
import random
from enum import Enum, unique
@unique
class LayerType(Enum):
attention = 0
self_attention = 1
rnn = 2
input = 3
output = 4
class Layer(object):
def __init__(self, type, input=None, output=None, size=None):
self.input = input if input is not None else []
self.output = output if output is not None else []
self.type = type
self.is_delete = False
self.size = size
if type == LayerType.attention.value:
self.input_size = 2
self.output_size = 1
elif type == LayerType.rnn.value:
self.input_size = 1
self.output_size = 1
elif type == LayerType.self_attention.value:
self.input_size = 1
self.output_size = 1
elif type == LayerType.input.value:
self.input_size = 0
self.output_size = 1
elif type == LayerType.output.value:
self.input_size = 1
self.output_size = 0
else:
print(type)
def set_size(self, id, size):
if self.type == LayerType.attention.value:
if self.input[0] == id:
self.size = size
if self.type == LayerType.rnn.value:
self.size = size
if self.type == LayerType.self_attention.value:
self.size = size
if self.type == LayerType.output.value:
if self.size != size:
return False
return True
def clear_size(self):
if self.type == LayerType.attention.value or LayerType.rnn.value or LayerType.self_attention.value:
self.size = None
def __str__(self):
return 'input:' + str(self.input) + ' output:' + str(self.output) + ' type:' + str(
self.type) + ' is_delete:' + str(self.is_delete) + ' size:' + str(self.size)
def graph_dumps(graph):
return json.dumps(graph, default=lambda obj: obj.__dict__)
def graph_loads(js):
layers = []
for layer in js['layers']:
p = Layer(layer['type'],layer['input'],layer['output'],layer['size'])
p.is_delete = layer['is_delete']
layers.append(p)
graph = Graph(js['max_layer_num'],[], [], [])
graph.layers = layers
return graph
class Graph(object):
def __init__(self, max_layer_num, input, output, hide):
self.layers = []
self.max_layer_num = max_layer_num
for layer in input:
self.layers.append(layer)
for layer in output:
self.layers.append(layer)
if hide is not None:
for layer in hide:
self.layers.append(layer)
assert self.is_legal()
def is_topology(self, layers=None):
if layers == None:
layers = self.layers
layers_nodle = []
xx = []
for i in range(len(layers)):
if layers[i].is_delete == False:
layers_nodle.append(i)
while True:
flag_break = True
layers_toremove = []
for layer1 in layers_nodle:
flag_arrive = True
for layer2 in layers[layer1].input:
if layer2 in layers_nodle:
flag_arrive = False
if flag_arrive == True:
for layer2 in layers[layer1].output:
if layers[layer2].set_size(layer1, layers[layer1].size) == False: # Size is error
return False
layers_toremove.append(layer1)
xx.append(layer1)
flag_break = False
for layer in layers_toremove:
layers_nodle.remove(layer)
xx.append('|')
if flag_break == True:
break
if len(layers_nodle) > 0: # There is loop in graph || some layers can't to arrive
return False
return xx
def layer_num(self, layers=None):
if layers == None:
layers = self.layers
layer_num = 0
for layer in layers:
if layer.is_delete == False and layer.type != LayerType.input.value and layer.type != LayerType.output.value:
layer_num += 1
return layer_num
def is_legal(self, layers=None):
if layers == None:
layers = self.layers
for layer in layers:
if layer.is_delete == False:
if len(layer.input) != layer.input_size:
return False
if len(layer.output) < layer.output_size:
return False
# layer_num <= max_layer_num
if self.layer_num(layers) > self.max_layer_num:
return False
if self.is_topology(layers) == False: # There is loop in graph || some layers can't to arrive
return False
return True
def mutation(self, only_add=False):
types = []
if self.layer_num() < self.max_layer_num:
types.append(0)
types.append(1)
if self.layer_num() > 0:
types.append(2)
types.append(3)
# 0 : add a layer , delete a edge
# 1 : add a layer , change a edge
# 2 : delete a layer, delete a edge
# 3 : delete a layer, change a edge
type = random.choice(types)
layer_type = random.choice([LayerType.attention.value, LayerType.self_attention.value, LayerType.rnn.value])
layers = copy.deepcopy(self.layers)
cnt_try = 0
while True:
layers_in = []
layers_out = []
layers_del = []
for layer1 in range(len(layers)):
layer = layers[layer1]
if layer.is_delete == False:
if layer.type != LayerType.output.value:
layers_in.append(layer1)
if layer.type != LayerType.input.value:
layers_out.append(layer1)
if layer.type != LayerType.output.value and layer.type != LayerType.input.value:
layers_del.append(layer1)
if type <= 1:
new_id = len(layers)
out = random.choice(layers_out)
input = []
output = [out]
pos = random.randint(0, len(layers[out].input) - 1)
last_in = layers[out].input[pos]
layers[out].input[pos] = new_id
if type == 0:
layers[last_in].output.remove(out)
if type == 1:
layers[last_in].output.remove(out)
layers[last_in].output.append(new_id)
input = [last_in]
lay = Layer(type=layer_type, input=input, output=output)
while len(input) < lay.input_size:
layer1 = random.choice(layers_in)
input.append(layer1)
layers[layer1].output.append(new_id)
lay.input = input
layers.append(lay)
else:
layer1 = random.choice(layers_del)
for layer2 in layers[layer1].output:
layers[layer2].input.remove(layer1)
if type == 2:
v2 = random.choice(layers_in)
else:
v2 = random.choice(layers[layer1].input)
layers[layer2].input.append(v2)
layers[v2].output.append(layer2)
for layer2 in layers[layer1].input:
layers[layer2].output.remove(layer1)
layers[layer1].is_delete = True
if self.is_legal(layers):
self.layers = layers
break
else:
layers = copy.deepcopy(self.layers)
cnt_try += 1
def __str__(self):
info = ""
for id, layer in enumerate(self.layers):
if layer.is_delete == False:
info += 'id:%d ' % id + str(layer) + '\n'
return info
if __name__ == '__main__':
graph = Graph(10,
input=[Layer(LayerType.input.value, output=[4, 5], size='x'), Layer(LayerType.input.value, output=[4, 5], size='y')],
output=[Layer(LayerType.output.value, input=[4], size='x'), Layer(LayerType.output.value, input=[5], size='y')],
hide=[Layer(LayerType.attention.value, input=[0, 1], output=[2]), Layer(LayerType.attention.value, input=[1, 0], output=[3])])
s = graph_dumps(graph)
g = graph_loads(json.loads(s))
print(g)
print(s)
s = '''{"count":2,"array":[{"input":%s,"output":{"output":0.7}}]}'''%s
print(len(s))
print(s)
\ No newline at end of file
# Usage:
# pylint --rcfile=PATH_TO_THIS_FILE PACKAGE_NAME
# or
# pylint --rcfile=PATH_TO_THIS_FILE SOURCE_FILE.py
[SETTINGS]
max-line-length=140
max-args=5
max-locals=15
max-statements=50
max-attributes=7
const-naming-style=any
disable=duplicate-code,
super-init-not-called
#Build result
dist/
#node modules
node_modules/
/**
* Copyright (c) Microsoft Corporation
* All rights reserved.
*
* MIT License
*
* Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
* documentation files (the "Software"), to deal in the Software without restriction, including without limitation
* the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and
* to permit persons to whom the Software is furnished to do so, subject to the following conditions:
* The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
* BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
* NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
* DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
'use strict';
import * as ioc from 'typescript-ioc';
// tslint:disable-next-line:no-any
const Inject: (...args: any[]) => any = ioc.Inject;
const Singleton: (target: Function) => void = ioc.Singleton;
const Container = ioc.Container;
const Provides = ioc.Provides;
function get<T>(source: Function): T {
return ioc.Container.get(source) as T;
}
export { Provides, Container, Inject, Singleton, get };
/**
* Copyright (c) Microsoft Corporation
* All rights reserved.
*
* MIT License
*
* Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
* documentation files (the "Software"), to deal in the Software without restriction, including without limitation
* the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and
* to permit persons to whom the Software is furnished to do so, subject to the following conditions:
* The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
* BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
* NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
* DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
'use strict';
import { ExperimentProfile, TrialJobStatistics } from './manager';
import { TrialJobDetail, TrialJobStatus } from './trainingService';
type TrialJobEvent = TrialJobStatus | 'USER_TO_CANCEL' | 'ADD_CUSTOMIZED';
type MetricType = 'PERIODICAL' | 'FINAL' | 'CUSTOM';
interface ExperimentProfileRecord {
readonly timestamp: Date;
readonly experimentId: number;
readonly revision: number;
readonly data: ExperimentProfile;
}
interface TrialJobEventRecord {
readonly timestamp: Date;
readonly trialJobId: string;
readonly event: TrialJobEvent;
readonly data?: string;
readonly logPath?: string;
}
interface MetricData {
readonly parameter_id: string;
readonly trial_job_id: string;
readonly type: MetricType;
readonly sequence: number;
readonly value: any;
}
interface MetricDataRecord {
readonly timestamp: Date;
readonly trialJobId: string;
readonly parameterId: string;
readonly type: MetricType;
readonly sequence: number;
readonly data: any;
}
interface TrialJobInfo {
id: string;
status: TrialJobStatus;
startTime?: Date;
endTime?: Date;
hyperParameters?: string;
logPath?: string;
finalMetricData?: string;
stderrPath?: string;
}
abstract class DataStore {
public abstract init(): Promise<void>;
public abstract close(): Promise<void>;
public abstract storeExperimentProfile(experimentProfile: ExperimentProfile): Promise<void>;
public abstract getExperimentProfile(experimentId: string): Promise<ExperimentProfile>;
public abstract storeTrialJobEvent(event: TrialJobEvent, trialJobId: string, data?: string, logPath?: string): Promise<void>;
public abstract getTrialJobStatistics(): Promise<TrialJobStatistics[]>;
public abstract listTrialJobs(status?: TrialJobStatus): Promise<TrialJobInfo[]>;
public abstract getTrialJob(trialJobId: string): Promise<TrialJobInfo>;
public abstract storeMetricData(trialJobId: string, data: string): Promise<void>;
public abstract getMetricData(trialJobId: string, metricType: MetricType): Promise<MetricDataRecord[]>;
}
abstract class Database {
public abstract init(createNew: boolean, dbDir: string): Promise<void>;
public abstract close(): Promise<void>;
public abstract storeExperimentProfile(experimentProfile: ExperimentProfile): Promise<void>;
public abstract queryExperimentProfile(experimentId: string, revision?: number): Promise<ExperimentProfile[]>;
public abstract queryLatestExperimentProfile(experimentId: string): Promise<ExperimentProfile>;
public abstract storeTrialJobEvent(event: TrialJobEvent, trialJobId: string, data?: string, logPath?: string): Promise<void>;
public abstract queryTrialJobEvent(trialJobId?: string, event?: TrialJobEvent): Promise<TrialJobEventRecord[]>;
public abstract storeMetricData(trialJobId: string, data: string): Promise<void>;
public abstract queryMetricData(trialJobId?: string, type?: MetricType): Promise<MetricDataRecord[]>;
}
export {
DataStore, Database, TrialJobEvent, MetricType, MetricData, TrialJobInfo,
ExperimentProfileRecord, TrialJobEventRecord, MetricDataRecord
}
/**
* Copyright (c) Microsoft Corporation
* All rights reserved.
*
* MIT License
*
* Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
* documentation files (the "Software"), to deal in the Software without restriction, including without limitation
* the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and
* to permit persons to whom the Software is furnished to do so, subject to the following conditions:
* The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
* BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
* NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
* DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
'use strict';
export namespace NNIErrorNames {
export const NOT_FOUND: string = 'NOT_FOUND';
export const INVALID_JOB_DETAIL: string = 'NO_VALID_JOB_DETAIL_FOUND';
export const RESOURCE_NOT_AVAILABLE: string = 'RESOURCE_NOT_AVAILABLE';
}
export class NNIError extends Error {
constructor (name: string, message: string) {
super(message);
this.name = name;
}
}
export class MethodNotImplementedError extends Error {
constructor() {
super('Method not implemented.');
}
}
/**
* Copyright (c) Microsoft Corporation
* All rights reserved.
*
* MIT License
*
* Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
* documentation files (the "Software"), to deal in the Software without restriction, including without limitation
* the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and
* to permit persons to whom the Software is furnished to do so, subject to the following conditions:
* The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
* BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
* NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
* DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
'use strict';
import * as assert from 'assert';
import * as component from '../common/component';
@component.Singleton
class ExperimentStartupInfo {
private experimentId: string = '';
private newExperiment: boolean = true;
private initialized: boolean = false;
public setStartupInfo(newExperiment: boolean, experimentId: string): void {
assert(!this.initialized);
assert(experimentId.trim().length > 0);
this.newExperiment = newExperiment;
this.experimentId = experimentId;
this.initialized = true;
}
public getExperimentId(): string {
assert(this.initialized);
return this.experimentId;
}
public isNewExperiment(): boolean {
assert(this.initialized);
return this.newExperiment;
}
}
function getExperimentId(): string {
return component.get<ExperimentStartupInfo>(ExperimentStartupInfo).getExperimentId();
}
function isNewExperiment(): boolean {
return component.get<ExperimentStartupInfo>(ExperimentStartupInfo).isNewExperiment();
}
function setExperimentStartupInfo(newExperiment: boolean, experimentId: string): void {
component.get<ExperimentStartupInfo>(ExperimentStartupInfo).setStartupInfo(newExperiment, experimentId);
}
export { ExperimentStartupInfo, getExperimentId, isNewExperiment, setExperimentStartupInfo };
/**
* Copyright (c) Microsoft Corporation
* All rights reserved.
*
* MIT License
*
* Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
* documentation files (the "Software"), to deal in the Software without restriction, including without limitation
* the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and
* to permit persons to whom the Software is furnished to do so, subject to the following conditions:
* The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
* BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
* NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
* DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
'use strict';
/* tslint:disable:no-any */
import * as fs from 'fs';
import * as path from 'path';
import { Writable } from 'stream';
import { WritableStreamBuffer } from 'stream-buffers';
import { format } from 'util';
import * as component from '../common/component';
import { getLogDir } from './utils';
const CRITICAL: number = 1;
const ERROR: number = 2;
const WARNING: number = 3;
const INFO: number = 4;
const DEBUG: number = 5;
class BufferSerialEmitter {
private buffer: Buffer;
private emitting: boolean;
private writable: Writable;
constructor(writable: Writable) {
this.buffer = new Buffer(0);
this.emitting = false;
this.writable = writable;
}
public feed(buffer: Buffer): void {
this.buffer = Buffer.concat([this.buffer, buffer]);
if (!this.emitting) {
this.emit();
}
}
private emit(): void {
this.emitting = true;
this.writable.write(this.buffer, () => {
if (this.buffer.length === 0) {
this.emitting = false;
} else {
this.emit();
}
});
this.buffer = new Buffer(0);
}
}
@component.Singleton
class Logger {
private DEFAULT_LOGFILE: string = path.join(getLogDir(), 'nnimanager.log');
private level: number = DEBUG;
private bufferSerialEmitter: BufferSerialEmitter;
constructor(fileName?: string) {
let logFile: string | undefined = fileName;
if (logFile === undefined) {
logFile = this.DEFAULT_LOGFILE;
}
this.bufferSerialEmitter = new BufferSerialEmitter(fs.createWriteStream(logFile, {
flags: 'a+',
encoding: 'utf8',
autoClose: true
}));
}
public debug(...param: any[]): void {
if (this.level >= DEBUG) {
this.log('DEBUG', param);
}
}
public info(...param: any[]): void {
if (this.level >= INFO) {
this.log('INFO', param);
}
}
public warning(...param: any[]): void {
if (this.level >= WARNING) {
this.log('WARNING', param);
}
}
public error(...param: any[]): void {
if (this.level >= ERROR) {
this.log('ERROR', param);
}
}
public critical(...param: any[]): void {
this.log('CRITICAL', param);
}
private log(level: string, param: any[]): void {
const buffer: WritableStreamBuffer = new WritableStreamBuffer();
buffer.write(`[${(new Date()).toISOString()}] ${level} `);
buffer.write(format.apply(null, param));
buffer.write('\n');
buffer.end();
this.bufferSerialEmitter.feed(buffer.getContents());
}
}
function getLogger(fileName?: string): Logger {
component.Container.bind(Logger).provider({
get: (): Logger => new Logger(fileName)
});
return component.get(Logger);
}
export { Logger, getLogger };
/**
* Copyright (c) Microsoft Corporation
* All rights reserved.
*
* MIT License
*
* Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
* documentation files (the "Software"), to deal in the Software without restriction, including without limitation
* the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and
* to permit persons to whom the Software is furnished to do so, subject to the following conditions:
* The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
* BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
* NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
* DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
'use strict';
import { MetricDataRecord, MetricType, TrialJobInfo } from './datastore';
import { TrialJobStatus } from './trainingService';
type ProfileUpdateType = 'TRIAL_CONCURRENCY' | 'MAX_EXEC_DURATION' | 'SEARCH_SPACE';
interface ExperimentParams {
authorName: string;
experimentName: string;
trialConcurrency: number;
maxExecDuration: number; //seconds
maxTrialNum: number;
searchSpace: string;
tuner: {
tunerCommand: string;
tunerCwd: string;
tunerCheckpointDirectory: string;
tunerGpuNum?: number;
};
assessor?: {
assessorCommand: string;
assessorCwd: string;
assessorCheckpointDirectory: string;
assessorGpuNum?: number;
};
clusterMetaData?: {
key: string;
value: string;
}[];
}
interface ExperimentProfile {
params: ExperimentParams;
id: string;
execDuration: number;
startTime?: Date;
endTime?: Date;
revision: number;
}
interface TrialJobStatistics {
trialJobStatus: TrialJobStatus;
trialJobNumber: number;
}
abstract class Manager {
public abstract startExperiment(experimentParams: ExperimentParams): Promise<string>;
public abstract resumeExperiment(): Promise<void>;
public abstract stopExperiment(): Promise<void>;
public abstract getExperimentProfile(): Promise<ExperimentProfile>;
public abstract updateExperimentProfile(experimentProfile: ExperimentProfile, updateType: ProfileUpdateType): Promise<void>;
public abstract addCustomizedTrialJob(hyperParams: string): Promise<void>;
public abstract cancelTrialJobByUser(trialJobId: string): Promise<void>;
public abstract listTrialJobs(status?: TrialJobStatus): Promise<TrialJobInfo[]>;
public abstract getTrialJob(trialJobId: string): Promise<TrialJobInfo>;
public abstract setClusterMetadata(key: string, value: string): Promise<void>;
public abstract getClusterMetadata(key: string): Promise<string>;
public abstract getMetricData(trialJobId: string, metricType: MetricType): Promise<MetricDataRecord[]>;
public abstract getTrialJobStatistics(): Promise<TrialJobStatistics[]>;
}
export { Manager, ExperimentParams, ExperimentProfile, TrialJobStatistics, ProfileUpdateType };
/**
* Copyright (c) Microsoft Corporation
* All rights reserved.
*
* MIT License
*
* Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
* documentation files (the "Software"), to deal in the Software without restriction, including without limitation
* the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and
* to permit persons to whom the Software is furnished to do so, subject to the following conditions:
* The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
* BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
* NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
* DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
'use strict';
import * as rx from 'rx';
import * as component from '../common/component';
@component.Singleton
class ObservableTimer {
private observableSource: rx.Observable<number>;
constructor() {
// TODO: move 100 and 1000 into constants class
this.observableSource = rx.Observable.timer(100, 1000).takeWhile(() => true);
}
public subscribe(onNext?: (value: any) => void, onError?: (exception: any) => void, onCompleted?: () => void): Rx.IDisposable {
return this.observableSource.subscribe(onNext, onError, onCompleted);
}
public unsubscribe( subscription : Rx.IDisposable) {
if(typeof subscription !== undefined) {
subscription.dispose();
}
}
}
export { ObservableTimer };
/**
* Copyright (c) Microsoft Corporation
* All rights reserved.
*
* MIT License
*
* Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
* documentation files (the "Software"), to deal in the Software without restriction, including without limitation
* the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and
* to permit persons to whom the Software is furnished to do so, subject to the following conditions:
* The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
* BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
* NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
* DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
'use strict';
/**
* define TrialJobStatus
*/
type TrialJobStatus = 'UNKNOWN' | 'WAITING' | 'RUNNING' | 'SUCCEEDED' | 'FAILED' | 'USER_CANCELED' | 'SYS_CANCELED';
type JobType = 'TRIAL' | 'HOST';
interface TrainingServiceMetadata {
readonly key: string;
readonly value: string;
}
/**
* define JobApplicationForm
*/
interface JobApplicationForm {
readonly jobType: JobType;
}
/**
* define TrialJobApplicationForm
*/
interface TrialJobApplicationForm extends JobApplicationForm {
readonly hyperParameters: string;
}
/**
* define HostJobApplicationForm
*/
interface HostJobApplicationForm extends JobApplicationForm {
readonly host: string;
readonly cmd: string;
}
/**
* define TrialJobDetail
*/
interface TrialJobDetail {
readonly id: string;
readonly status: TrialJobStatus;
readonly submitTime: Date;
readonly startTime?: Date;
readonly endTime?: Date;
readonly tags?: string[];
readonly url?: string;
readonly workingDirectory: string;
readonly form: JobApplicationForm;
}
interface HostJobDetail {
readonly id: string;
readonly status: string;
}
/**
* define TrialJobMetric
*/
interface TrialJobMetric {
readonly id: string;
readonly data: string;
}
/**
* define TrainingServiceError
*/
class TrainingServiceError extends Error {
private errCode: number;
constructor(errorCode: number, errorMessage: string) {
super(errorMessage);
this.errCode = errorCode;
}
get errorCode(): number {
return this.errCode;
}
}
/**
* define TrainingService
*/
abstract class TrainingService {
public abstract listTrialJobs(): Promise<TrialJobDetail[]>;
public abstract getTrialJob(trialJobId: string): Promise<TrialJobDetail>;
public abstract addTrialJobMetricListener(listener: (metric: TrialJobMetric) => void): void;
public abstract removeTrialJobMetricListener(listener: (metric: TrialJobMetric) => void): void;
public abstract submitTrialJob(form: JobApplicationForm): Promise<TrialJobDetail>;
public abstract cancelTrialJob(trialJobId: string): Promise<void>;
public abstract setClusterMetadata(key: string, value: string): Promise<void>;
public abstract getClusterMetadata(key: string): Promise<string>;
public abstract cleanUp(): Promise<void>;
public abstract run(): Promise<void>;
}
export {
TrainingService, TrainingServiceError, TrialJobStatus, TrialJobApplicationForm,
TrainingServiceMetadata, TrialJobDetail, TrialJobMetric,
HostJobApplicationForm, JobApplicationForm, JobType
};
/**
* Copyright (c) Microsoft Corporation
* All rights reserved.
*
* MIT License
*
* Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
* documentation files (the "Software"), to deal in the Software without restriction, including without limitation
* the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and
* to permit persons to whom the Software is furnished to do so, subject to the following conditions:
* The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
* BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
* NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
* DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
'use strict';
import { randomBytes } from 'crypto';
import * as fs from 'fs';
import * as os from 'os';
import * as path from 'path';
import { Deferred } from 'ts-deferred';
import { Container } from 'typescript-ioc';
import * as util from 'util';
import { Database, DataStore } from './datastore';
import { ExperimentStartupInfo, setExperimentStartupInfo, getExperimentId } from './experimentStartupInfo';
import { Manager } from './manager';
import { TrainingService } from './trainingService';
function getExperimentRootDir(): string{
return path.join(os.homedir(), 'nni', 'experiments', getExperimentId());
}
function getLogDir(): string{
return path.join(getExperimentRootDir(), 'log');
}
function getDefaultDatabaseDir(): string {
return path.join(getExperimentRootDir(), 'db');
}
function mkDirP(dirPath: string): Promise<void> {
const deferred: Deferred<void> = new Deferred<void>();
fs.exists(dirPath, (exists: boolean) => {
if (exists) {
deferred.resolve();
} else {
const parent: string = path.dirname(dirPath);
mkDirP(parent).then(() => {
fs.mkdir(dirPath, (err: Error) => {
if (err) {
deferred.reject(err);
} else {
deferred.resolve();
}
});
}).catch((err: Error) => {
deferred.reject(err);
});
}
});
return deferred.promise;
}
function mkDirPSync(dirPath: string): void {
if (fs.existsSync(dirPath)) {
return;
}
mkDirPSync(path.dirname(dirPath));
fs.mkdirSync(dirPath);
}
const delay: (ms: number) => Promise<void> = util.promisify(setTimeout);
/**
* Convert index to character
* @param index index
* @returns a mapping character
*/
function charMap(index: number): number {
if (index < 26) {
return index + 97;
} else if (index < 52) {
return index - 26 + 65;
} else {
return index - 52 + 48;
}
}
/**
* Generate a unique string by length
* @param len length of string
* @returns a unique string
*/
function uniqueString(len: number): string {
if (len === 0) {
return '';
}
const byteLength: number = Math.ceil((Math.log2(52) + Math.log2(62) * (len - 1)) / 8);
let num: number = randomBytes(byteLength).reduce((a: number, b: number) => a * 256 + b, 0);
const codes: number[] = [];
codes.push(charMap(num % 52));
num = Math.floor(num / 52);
for (let i: number = 1; i < len; i++) {
codes.push(charMap(num % 62));
num = Math.floor(num / 62);
}
return String.fromCharCode(...codes);
}
function parseArg(names: string[]): string {
if (process.argv.length >= 4) {
for (let i: number = 2; i < process.argv.length - 1; i++) {
if (names.includes(process.argv[i])) {
return process.argv[i + 1];
}
}
}
return '';
}
/**
* Initialize a pseudo experiment environment for unit test.
* Must be paired with `cleanupUnitTest()`.
*/
function prepareUnitTest(): void {
Container.snapshot(ExperimentStartupInfo);
Container.snapshot(Database);
Container.snapshot(DataStore);
Container.snapshot(TrainingService);
Container.snapshot(Manager);
setExperimentStartupInfo(true, 'unittest');
mkDirPSync(getLogDir());
const sqliteFile: string = path.join(getDefaultDatabaseDir(), 'nni.sqlite');
try {
fs.unlinkSync(sqliteFile);
} catch (err) {
// file not exists, good
}
}
/**
* Clean up unit test pseudo experiment.
* Must be paired with `prepareUnitTest()`.
*/
function cleanupUnitTest(): void {
Container.restore(Manager);
Container.restore(TrainingService);
Container.restore(DataStore);
Container.restore(Database);
Container.restore(ExperimentStartupInfo);
}
export { getLogDir, getExperimentRootDir, getDefaultDatabaseDir, mkDirP, delay, prepareUnitTest,
parseArg, cleanupUnitTest, uniqueString };
/**
* Copyright (c) Microsoft Corporation
* All rights reserved.
* MIT License
* Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
* documentation files (the "Software"), to deal in the Software without restriction, including without limitation
* the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and
* to permit persons to whom the Software is furnished to do so, subject to the following conditions:
* The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
* THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
* BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
* NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
* DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
const INITIALIZE = 'IN';
const REQUEST_TRIAL_JOBS = 'GE';
const REPORT_METRIC_DATA = 'ME';
const UPDATE_SEARCH_SPACE = 'SS';
const ADD_CUSTOMIZED_TRIAL_JOB = 'AD';
const TRIAL_END = 'EN';
const TERMINATE = 'TE';
const NEW_TRIAL_JOB = 'TR';
const NO_MORE_TRIAL_JOBS = 'NO';
const KILL_TRIAL_JOB = 'KI';
const TUNER_COMMANDS: Set<string> = new Set([
INITIALIZE,
REQUEST_TRIAL_JOBS,
REPORT_METRIC_DATA,
UPDATE_SEARCH_SPACE,
ADD_CUSTOMIZED_TRIAL_JOB,
TERMINATE,
NEW_TRIAL_JOB,
NO_MORE_TRIAL_JOBS
]);
const ASSESSOR_COMMANDS: Set<string> = new Set([
INITIALIZE,
REPORT_METRIC_DATA,
TRIAL_END,
TERMINATE,
KILL_TRIAL_JOB
]);
export {
INITIALIZE,
REQUEST_TRIAL_JOBS,
REPORT_METRIC_DATA,
UPDATE_SEARCH_SPACE,
ADD_CUSTOMIZED_TRIAL_JOB,
TRIAL_END,
TERMINATE,
NEW_TRIAL_JOB,
NO_MORE_TRIAL_JOBS,
KILL_TRIAL_JOB,
TUNER_COMMANDS,
ASSESSOR_COMMANDS
};
/**
* Copyright (c) Microsoft Corporation
* All rights reserved.
*
* MIT License
*
* Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
* documentation files (the "Software"), to deal in the Software without restriction, including without limitation
* the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and
* to permit persons to whom the Software is furnished to do so, subject to the following conditions:
* The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
* BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
* NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
* DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
'use strict';
import * as assert from 'assert';
import { ChildProcess } from 'child_process';
import { EventEmitter } from 'events';
import { Readable, Writable } from 'stream';
import { getLogger, Logger } from '../common/log';
import * as CommandType from './commands';
const ipcOutgoingFd: number = 3;
const ipcIncomingFd: number = 4;
/**
* Encode a command
* @param commandType a command type defined in 'core/commands'
* @param content payload of the command
* @returns binary command data
*/
function encodeCommand(commandType: string, content: string): Buffer {
const contentBuffer: Buffer = Buffer.from(content);
if (contentBuffer.length >= 1_000_000) {
throw new RangeError('Command too long');
}
const contentLengthBuffer: Buffer = Buffer.from(contentBuffer.length.toString().padStart(6, '0'));
return Buffer.concat([Buffer.from(commandType), contentLengthBuffer, contentBuffer]);
}
/**
* Decode a command
* @param Buffer binary incoming data
* @returns a tuple of (success, commandType, content, remain)
* success: true if the buffer contains at least one complete command; otherwise false
* remain: remaining data after the first command
*/
function decodeCommand(data: Buffer): [boolean, string, string, Buffer] {
if (data.length < 8) {
return [false, '', '', data];
}
const commandType: string = data.slice(0, 2).toString();
const contentLength: number = parseInt(data.slice(2, 8).toString(), 10);
if (data.length < contentLength + 8) {
return [false, '', '', data];
}
const content: string = data.slice(8, contentLength + 8).toString();
const remain: Buffer = data.slice(contentLength + 8);
return [true, commandType, content, remain];
}
class IpcInterface {
private acceptCommandTypes: Set<string>;
private outgoingStream: Writable;
private incomingStream: Readable;
private eventEmitter: EventEmitter;
private readBuffer: Buffer;
private logger: Logger = getLogger();
/**
* Construct a IPC proxy
* @param proc the process to wrap
* @param acceptCommandTypes set of accepted commands for this process
*/
constructor(proc: ChildProcess, acceptCommandTypes: Set<string>) {
this.acceptCommandTypes = acceptCommandTypes;
this.outgoingStream = <Writable>proc.stdio[ipcOutgoingFd];
this.incomingStream = <Readable>proc.stdio[ipcIncomingFd];
this.eventEmitter = new EventEmitter();
this.readBuffer = Buffer.alloc(0);
this.incomingStream.on('data', (data: Buffer) => { this.receive(data); });
}
/**
* Send a command to process
* @param commandType: a command type defined in 'core/commands'
* @param content: payload of command
*/
public sendCommand(commandType: string, content: string = ''): void {
assert.ok(this.acceptCommandTypes.has(commandType));
const data: Buffer = encodeCommand(commandType, content);
if (!this.outgoingStream.write(data)) {
//this.logger.warning('Commands jammed in buffer!');
}
}
/**
* Add a command listener
* @param listener the listener callback
*/
public onCommand(listener: (commandType: string, content: string) => void): void {
this.eventEmitter.on('command', listener);
}
/**
* Deal with incoming data from process
* Invoke listeners for each complete command received, save incomplete command to buffer
* @param data binary incoming data
*/
private receive(data: Buffer): void {
this.readBuffer = Buffer.concat([this.readBuffer, data]);
while (this.readBuffer.length > 0) {
const [success, commandType, content, remain] = decodeCommand(this.readBuffer);
if (!success) {
break;
}
assert.ok(this.acceptCommandTypes.has(commandType));
this.eventEmitter.emit('command', commandType, content);
this.readBuffer = remain;
}
}
}
/**
* Create IPC proxy for tuner process
* @param process_ the tuner process
*/
function createTunerInterface(process: ChildProcess): IpcInterface {
return new IpcInterface(process, CommandType.TUNER_COMMANDS);
}
/**
* Create IPC proxy for assessor process
* @param process_ the assessor process
*/
function createAssessorInterface(process: ChildProcess): IpcInterface {
return new IpcInterface(process, CommandType.ASSESSOR_COMMANDS);
}
export { IpcInterface, createTunerInterface, createAssessorInterface };
/**
* Copyright (c) Microsoft Corporation
* All rights reserved.
*
* MIT License
*
* Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
* documentation files (the "Software"), to deal in the Software without restriction, including without limitation
* the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and
* to permit persons to whom the Software is furnished to do so, subject to the following conditions:
* The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
* BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
* NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
* DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
'use strict';
import * as assert from 'assert';
import { Deferred } from 'ts-deferred';
import * as component from '../common/component';
import { Database, DataStore, MetricData, MetricDataRecord, MetricType,
TrialJobEvent, TrialJobEventRecord, TrialJobInfo } from '../common/datastore';
import { isNewExperiment } from '../common/experimentStartupInfo';
import { getLogger, Logger } from '../common/log';
import { ExperimentProfile, TrialJobStatistics } from '../common/manager';
import { TrialJobStatus } from '../common/trainingService';
import { getDefaultDatabaseDir, mkDirP } from '../common/utils';
class NNIDataStore implements DataStore {
private db: Database = component.get(Database);
private log: Logger = getLogger();
private initTask!: Deferred<void>;
public init(): Promise<void> {
if (this.initTask !== undefined) {
return this.initTask.promise;
}
this.initTask = new Deferred<void>();
// TODO support specify database dir
const databaseDir: string = getDefaultDatabaseDir();
if(isNewExperiment()) {
mkDirP(databaseDir).then(() => {
this.db.init(true, databaseDir).then(() => {
this.initTask.resolve();
}).catch((err: Error) => {
this.initTask.reject(err);
});
}).catch((err: Error) => {
this.initTask.reject(err);
});
} else {
this.db.init(false, databaseDir).then(() => {
this.initTask.resolve();
}).catch((err: Error) => {
this.initTask.reject(err);
});
}
return this.initTask.promise;
}
public async close(): Promise<void> {
await this.db.close();
}
public async storeExperimentProfile(experimentProfile: ExperimentProfile): Promise<void> {
await this.db.storeExperimentProfile(experimentProfile);
}
public getExperimentProfile(experimentId: string): Promise<ExperimentProfile> {
return this.db.queryLatestExperimentProfile(experimentId);
}
public storeTrialJobEvent(event: TrialJobEvent, trialJobId: string, data?: string, logPath?: string): Promise<void> {
this.log.debug(`storeTrialJobEvent: event: ${event}, data: ${data}, logpath: ${logPath}`);
return this.db.storeTrialJobEvent(event, trialJobId, data, logPath);
}
public async getTrialJobStatistics(): Promise<any[]> {
const result: TrialJobStatistics[] = [];
const jobs: TrialJobInfo[] = await this.listTrialJobs();
const map: Map<TrialJobStatus, number> = new Map();
jobs.forEach((value: TrialJobInfo) => {
let n: number|undefined = map.get(value.status);
if (!n) {
n = 0;
}
map.set(value.status, n + 1);
});
map.forEach((value: number, key: TrialJobStatus) => {
const statistics: TrialJobStatistics = {
trialJobStatus: key,
trialJobNumber: value
};
result.push(statistics);
});
return result;
}
public listTrialJobs(status?: TrialJobStatus): Promise<TrialJobInfo[]> {
return this.queryTrialJobs(status);
}
public async getTrialJob(trialJobId: string): Promise<TrialJobInfo> {
const trialJobs = await this.queryTrialJobs(undefined, trialJobId);
return trialJobs[0];
}
public async storeMetricData(trialJobId: string, data: string): Promise<void> {
const metrics = JSON.parse(data) as MetricData;
assert(trialJobId === metrics.trial_job_id);
await this.db.storeMetricData(trialJobId, JSON.stringify({
trialJobId: metrics.trial_job_id,
parameterId: metrics.parameter_id,
type: metrics.type,
sequence: metrics.sequence,
data: metrics.value,
timestamp: new Date()
}));
}
public getMetricData(trialJobId: string, metricType: MetricType): Promise<MetricDataRecord[]> {
return this.db.queryMetricData(trialJobId, metricType);
}
private async queryTrialJobs(status?: TrialJobStatus, trialJobId?: string): Promise<TrialJobInfo[]> {
const result: TrialJobInfo[]= [];
const trialJobEvents: TrialJobEventRecord[] = await this.db.queryTrialJobEvent(trialJobId);
if (trialJobEvents === undefined) {
return result;
}
const map: Map<string, TrialJobInfo> = this.getTrialJobsByReplayEvents(trialJobEvents);
for (let key of map.keys()) {
const jobInfo = map.get(key);
if (jobInfo === undefined) {
continue;
}
if (!(status !== undefined && jobInfo.status !== status)) {
if (jobInfo.status === 'SUCCEEDED') {
jobInfo.finalMetricData = await this.getFinalMetricData(jobInfo.id);
}
result.push(jobInfo);
}
}
return result;
}
private async getFinalMetricData(trialJobId: string): Promise<any> {
const metrics: MetricDataRecord[] = await this.getMetricData(trialJobId, 'FINAL');
assert(metrics.length <= 1);
if (metrics.length === 1) {
return metrics[0];
} else {
return undefined;
}
}
private getJobStatusByLatestEvent(event: TrialJobEvent): TrialJobStatus {
switch (event) {
case 'USER_TO_CANCEL':
return 'USER_CANCELED';
case 'ADD_CUSTOMIZED':
return 'WAITING';
default:
}
return <TrialJobStatus>event;
}
private getTrialJobsByReplayEvents(trialJobEvents: TrialJobEventRecord[]): Map<string, TrialJobInfo> {
const map: Map<string, TrialJobInfo> = new Map();
// assume data is stored by time ASC order
for (const record of trialJobEvents) {
let jobInfo: TrialJobInfo | undefined;
if (map.has(record.trialJobId)) {
jobInfo = map.get(record.trialJobId);
} else {
jobInfo = {
id: record.trialJobId,
status: this.getJobStatusByLatestEvent(record.event)
};
}
if (!jobInfo) {
throw new Error('Empty JobInfo');
}
switch (record.event) {
case 'RUNNING':
if (record.timestamp !== undefined) {
jobInfo.startTime = record.timestamp;
}
case 'WAITING':
if (record.logPath !== undefined) {
jobInfo.logPath = record.logPath;
}
break;
case 'SUCCEEDED':
case 'FAILED':
case 'USER_CANCELED':
case 'SYS_CANCELED':
if (record.logPath !== undefined) {
jobInfo.logPath = record.logPath;
}
jobInfo.endTime = record.timestamp;
default:
}
jobInfo.status = this.getJobStatusByLatestEvent(record.event);
if (record.data !== undefined && record.data.trim().length > 0) {
jobInfo.hyperParameters = record.data;
}
map.set(record.trialJobId, jobInfo);
}
return map;
}
}
export { NNIDataStore };
/**
* Copyright (c) Microsoft Corporation
* All rights reserved.
*
* MIT License
*
* Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
* documentation files (the "Software"), to deal in the Software without restriction, including without limitation
* the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and
* to permit persons to whom the Software is furnished to do so, subject to the following conditions:
* The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
* BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
* NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
* DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
'use strict';
import * as assert from 'assert';
import * as cpp from 'child-process-promise';
import { ChildProcess, spawn } from 'child_process';
import { Deferred } from 'ts-deferred';
import * as component from '../common/component';
import { DataStore, MetricDataRecord, MetricType, TrialJobInfo } from '../common/datastore';
import { getExperimentId } from '../common/experimentStartupInfo';
import { getLogger, Logger } from '../common/log';
import {
ExperimentParams, ExperimentProfile, Manager,
ProfileUpdateType, TrialJobStatistics
} from '../common/manager';
import {
TrainingService, TrialJobApplicationForm, TrialJobDetail, TrialJobMetric, TrialJobStatus
} from '../common/trainingService';
import { delay , getLogDir} from '../common/utils';
import {
ADD_CUSTOMIZED_TRIAL_JOB, KILL_TRIAL_JOB, NEW_TRIAL_JOB, NO_MORE_TRIAL_JOBS, REPORT_METRIC_DATA,
REQUEST_TRIAL_JOBS, TERMINATE, TRIAL_END, UPDATE_SEARCH_SPACE
} from './commands';
import { createAssessorInterface, createTunerInterface, IpcInterface } from './ipcInterface';
import { TrialJobMaintainerEvent, TrialJobs } from './trialJobs';
/**
* NNIManager
*/
class NNIManager implements Manager {
private trainingService: TrainingService;
private tuner: IpcInterface | undefined;
private assessor: IpcInterface | undefined;
private trialJobsMaintainer: TrialJobs | undefined;
private currSubmittedTrialNum: number; // need to be recovered
private trialConcurrencyReduction: number;
private customizedTrials: string[]; // need to be recovered
private log: Logger;
private dataStore: DataStore;
private experimentProfile: ExperimentProfile;
// TO DO: could use struct here
private tunerPid: number;
private assessorPid: number;
constructor() {
this.currSubmittedTrialNum = 0;
this.trialConcurrencyReduction = 0;
this.customizedTrials = [];
const experimentId: string = getExperimentId();
this.trainingService = component.get(TrainingService);
assert(this.trainingService);
this.tunerPid = 0;
this.assessorPid = 0;
this.log = getLogger();
this.dataStore = component.get(DataStore);
this.experimentProfile = {
id: experimentId,
revision: 0,
execDuration: 0,
params: {
authorName: '',
experimentName: '',
trialConcurrency: 0,
maxExecDuration: 0, // unit: second
maxTrialNum: 0, // maxTrialNum includes all the submitted trial jobs
searchSpace: '',
tuner: {
tunerCommand: '',
tunerCwd: '',
tunerCheckpointDirectory: ''
}
}
};
}
public updateExperimentProfile(experimentProfile: ExperimentProfile, updateType: ProfileUpdateType): Promise<void> {
// TO DO: remove this line, and let rest server do data type validation
experimentProfile.startTime = new Date(<string><any>experimentProfile.startTime);
switch (updateType) {
case 'TRIAL_CONCURRENCY':
this.updateTrialConcurrency(experimentProfile.params.trialConcurrency);
break;
case 'MAX_EXEC_DURATION':
this.updateMaxExecDuration(experimentProfile.params.maxExecDuration);
break;
case 'SEARCH_SPACE':
this.updateSearchSpace(experimentProfile.params.searchSpace);
break;
default:
throw new Error('Error: unrecognized updateType');
}
return this.storeExperimentProfile();
}
public addCustomizedTrialJob(hyperParams: string): Promise<void> {
if (this.currSubmittedTrialNum >= this.experimentProfile.params.maxTrialNum) {
return Promise.reject(
new Error('reach maxTrialNum')
);
}
this.customizedTrials.push(hyperParams);
// trial id has not been generated yet, thus use '' instead
return this.dataStore.storeTrialJobEvent('ADD_CUSTOMIZED', '', hyperParams);
}
public async cancelTrialJobByUser(trialJobId: string): Promise<void> {
await this.trainingService.cancelTrialJob(trialJobId);
await this.dataStore.storeTrialJobEvent('USER_TO_CANCEL', trialJobId, '');
}
public async startExperiment(expParams: ExperimentParams): Promise<string> {
this.log.debug(`Starting experiment: ${this.experimentProfile.id}`);
this.experimentProfile.params = expParams;
await this.storeExperimentProfile();
this.log.debug('Setup tuner...');
this.setupTuner(
expParams.tuner.tunerCommand,
expParams.tuner.tunerCwd,
'start',
expParams.tuner.tunerCheckpointDirectory);
if (expParams.assessor !== undefined) {
this.log.debug('Setup assessor...');
this.setupAssessor(
expParams.assessor.assessorCommand,
expParams.assessor.assessorCwd,
'start',
expParams.assessor.assessorCheckpointDirectory
);
}
this.experimentProfile.startTime = new Date();
await this.storeExperimentProfile();
this.run().catch(err => {
this.log.error(err.stack);
});
return this.experimentProfile.id;
}
public async resumeExperiment(): Promise<void> {
//Fetch back the experiment profile
const experimentId: string = getExperimentId();
this.experimentProfile = await this.dataStore.getExperimentProfile(experimentId);
const expParams: ExperimentParams = this.experimentProfile.params;
this.setupTuner(
expParams.tuner.tunerCommand,
expParams.tuner.tunerCwd,
'resume',
expParams.tuner.tunerCheckpointDirectory);
if (expParams.assessor !== undefined) {
this.setupAssessor(
expParams.assessor.assessorCommand,
expParams.assessor.assessorCwd,
'resume',
expParams.assessor.assessorCheckpointDirectory
);
}
const allTrialJobs: TrialJobInfo[] = await this.dataStore.listTrialJobs();
// Resume currSubmittedTrialNum
this.currSubmittedTrialNum = allTrialJobs.length;
// Check the final status for WAITING and RUNNING jobs
await Promise.all(allTrialJobs
.filter((job: TrialJobInfo) => job.status === 'WAITING' || job.status === 'RUNNING')
.map((job: TrialJobInfo) => this.dataStore.storeTrialJobEvent('FAILED', job.id)));
// TO DO: update database record for resume event
this.run().catch(console.error);
}
public getTrialJob(trialJobId: string): Promise<TrialJobDetail> {
return Promise.resolve(
this.trainingService.getTrialJob(trialJobId)
);
}
public async setClusterMetadata(key: string, value: string): Promise<void> {
let timeoutId: NodeJS.Timer;
// TO DO: move timeout value to constants file
const delay1: Promise<{}> = new Promise((resolve: Function, reject: Function): void => {
timeoutId = setTimeout(
() => { reject(new Error('TrainingService setClusterMetadata timeout.')); },
10000);
});
await Promise.race([delay1, this.trainingService.setClusterMetadata(key, value)]).finally(() => {
clearTimeout(timeoutId);
});
}
public getClusterMetadata(key: string): Promise<string> {
return Promise.resolve(
this.trainingService.getClusterMetadata(key)
);
}
public async getTrialJobStatistics(): Promise<TrialJobStatistics[]> {
return this.dataStore.getTrialJobStatistics();
}
public stopExperiment(): Promise<void> {
if (this.trialJobsMaintainer !== undefined) {
this.trialJobsMaintainer.setStopLoop();
return Promise.resolve();
} else {
return Promise.reject(new Error('Error: undefined trialJobsMaintainer'));
}
}
public async getMetricData(trialJobId: string, metricType: MetricType): Promise<MetricDataRecord[]> {
return this.dataStore.getMetricData(trialJobId, metricType);
}
public getExperimentProfile(): Promise<ExperimentProfile> {
// TO DO: using Promise.resolve()
const deferred: Deferred<ExperimentProfile> = new Deferred<ExperimentProfile>();
deferred.resolve(this.experimentProfile);
return deferred.promise;
}
public async listTrialJobs(status?: TrialJobStatus): Promise<TrialJobInfo[]> {
return this.dataStore.listTrialJobs(status);
}
private setupTuner(command: string, cwd: string, mode: 'start' | 'resume', dataDirectory: string): void {
if (this.tuner !== undefined) {
return;
}
const stdio: (string | NodeJS.WriteStream)[] = ['ignore', process.stdout, process.stderr, 'pipe', 'pipe'];
let newCwd: string;
if (cwd === undefined || cwd === '') {
newCwd = getLogDir();
} else {
newCwd = cwd;
}
// TO DO: add CUDA_VISIBLE_DEVICES
const tunerProc: ChildProcess = spawn(command, [], {
stdio,
cwd: newCwd,
env: {
NNI_MODE: mode,
NNI_CHECKPOINT_DIRECTORY: dataDirectory,
NNI_LOG_DIRECTORY: getLogDir()
},
shell: true
});
this.tunerPid = tunerProc.pid;
this.tuner = createTunerInterface(tunerProc);
return;
}
private setupAssessor(command: string, cwd: string, mode: 'start' | 'resume', dataDirectory: string): void {
if (this.assessor !== undefined) {
return;
}
const stdio: (string | NodeJS.WriteStream)[] = ['ignore', process.stdout, process.stderr, 'pipe', 'pipe'];
let newCwd: string;
if (cwd === undefined || cwd === '') {
newCwd = getLogDir();
} else {
newCwd = cwd;
}
// TO DO: add CUDA_VISIBLE_DEVICES
const assessorProc: ChildProcess = spawn(command, [], {
stdio,
cwd: newCwd,
env: {
NNI_MODE: mode,
NNI_CHECKPOINT_DIRECTORY: dataDirectory,
NNI_LOG_DIRECTORY: getLogDir()
},
shell: true
});
this.assessorPid = assessorProc.pid;
this.assessor = createAssessorInterface(assessorProc);
return;
}
private updateTrialConcurrency(trialConcurrency: number): void {
// TO DO: this method can only be called after startExperiment/resumeExperiment
if (trialConcurrency > this.experimentProfile.params.trialConcurrency) {
if (this.tuner === undefined) {
throw new Error('Error: tuner has to be initialized');
}
this.tuner.sendCommand(
REQUEST_TRIAL_JOBS,
String(trialConcurrency - this.experimentProfile.params.trialConcurrency)
);
} else {
// we assume trialConcurrency >= 0, which is checked by restserver
this.trialConcurrencyReduction += (this.experimentProfile.params.trialConcurrency - trialConcurrency);
}
this.experimentProfile.params.trialConcurrency = trialConcurrency;
return;
}
private updateMaxExecDuration(duration: number): void {
if (this.trialJobsMaintainer !== undefined) {
this.trialJobsMaintainer.updateMaxExecDuration(duration);
}
this.experimentProfile.params.maxExecDuration = duration;
return;
}
private updateSearchSpace(searchSpace: string): void {
if (this.tuner === undefined) {
throw new Error('Error: tuner has not been setup');
}
this.tuner.sendCommand(UPDATE_SEARCH_SPACE, searchSpace);
this.experimentProfile.params.searchSpace = searchSpace;
return;
}
private async experimentDoneCleanUp(): Promise<void> {
if (this.tuner === undefined) {
throw new Error('Error: tuner has not been setup');
}
this.tuner.sendCommand(TERMINATE);
if (this.assessor !== undefined) {
this.assessor.sendCommand(TERMINATE);
}
let tunerAlive: boolean = true;
let assessorAlive: boolean = true;
// gracefully terminate tuner and assessor here, wait at most 30 seconds.
for (let i: number = 0; i < 30; i++) {
if (!tunerAlive && !assessorAlive) { break; }
try {
await cpp.exec(`kill -0 ${this.tunerPid}`);
} catch (error) { tunerAlive = false; }
if (this.assessor !== undefined) {
try {
await cpp.exec(`kill -0 ${this.assessorPid}`);
} catch (error) { assessorAlive = false; }
} else {
assessorAlive = false;
}
await delay(1000);
}
try {
await cpp.exec(`kill ${this.tunerPid}`);
if (this.assessorPid !== undefined) {
await cpp.exec(`kill ${this.assessorPid}`);
}
} catch (error) {
// this.tunerPid does not exist, do nothing here
}
const trialJobList: TrialJobDetail[] = await this.trainingService.listTrialJobs();
// TO DO: to promise all
for (const trialJob of trialJobList) {
if (trialJob.status === 'RUNNING' ||
trialJob.status === 'WAITING') {
try {
await this.trainingService.cancelTrialJob(trialJob.id);
} catch (error) {
// pid does not exist, do nothing here
}
}
}
await this.trainingService.cleanUp();
this.experimentProfile.endTime = new Date();
await this.storeExperimentProfile();
}
private async periodicallyUpdateExecDuration(): Promise<void> {
const startTime: Date = new Date();
const execDuration: number = this.experimentProfile.execDuration;
for (; ;) {
await delay(1000 * 60 * 10); // 10 minutes
this.experimentProfile.execDuration = execDuration + (Date.now() - startTime.getTime()) / 1000;
await this.storeExperimentProfile();
}
}
private storeExperimentProfile(): Promise<void> {
this.experimentProfile.revision += 1;
return this.dataStore.storeExperimentProfile(this.experimentProfile);
}
private runInternal(): Promise<void> {
// TO DO: cannot run this method more than once in one NNIManager instance
if (this.tuner === undefined) {
throw new Error('Error: tuner has not been setup');
}
this.trainingService.addTrialJobMetricListener(async (metric: TrialJobMetric) => {
await this.dataStore.storeMetricData(metric.id, metric.data);
if (this.tuner === undefined) {
throw new Error('Error: tuner has not been setup');
}
this.tuner.sendCommand(REPORT_METRIC_DATA, metric.data);
if (this.assessor !== undefined) {
try {
this.assessor.sendCommand(REPORT_METRIC_DATA, metric.data);
} catch (error) {
this.log.critical(`ASSESSOR ERROR: ${error.message}`);
this.log.critical(`ASSESSOR ERROR: ${error.stack}`);
}
}
});
this.trialJobsMaintainer = new TrialJobs(
this.trainingService,
this.experimentProfile.execDuration,
this.experimentProfile.params.maxExecDuration);
this.trialJobsMaintainer.on(async (event: TrialJobMaintainerEvent, trialJobDetail: TrialJobDetail) => {
if (trialJobDetail !== undefined) {
this.log.debug(`Job event: ${event}, id: ${trialJobDetail.id}`);
} else {
this.log.debug(`Job event: ${event}`);
}
if (this.tuner === undefined) {
throw new Error('Error: tuner has not been setup');
}
switch (event) {
case 'SUCCEEDED':
case 'FAILED':
case 'USER_CANCELED':
case 'SYS_CANCELED':
if (this.trialConcurrencyReduction > 0) {
this.trialConcurrencyReduction--;
} else {
if (this.currSubmittedTrialNum < this.experimentProfile.params.maxTrialNum) {
if (this.customizedTrials.length > 0) {
const hyperParams: string | undefined = this.customizedTrials.shift();
this.tuner.sendCommand(ADD_CUSTOMIZED_TRIAL_JOB, hyperParams);
} else {
this.tuner.sendCommand(REQUEST_TRIAL_JOBS, '1');
}
}
}
if (this.assessor !== undefined) {
this.assessor.sendCommand(TRIAL_END, JSON.stringify({trial_job_id: trialJobDetail.id, event: event}));
}
await this.dataStore.storeTrialJobEvent(event, trialJobDetail.id, undefined, trialJobDetail.url);
break;
case 'RUNNING':
await this.dataStore.storeTrialJobEvent(event, trialJobDetail.id, undefined, trialJobDetail.url);
break;
case 'EXPERIMENT_DONE':
this.log.info('Experiment done, cleaning up...');
await this.experimentDoneCleanUp();
this.log.info('Experiment done.');
break;
default:
throw new Error('Error: unrecognized event from trialJobsMaintainer');
}
});
// TO DO: we should send INITIALIZE command to tuner if user's tuner needs to run init method in tuner
// TO DO: we should send INITIALIZE command to assessor if user's tuner needs to run init method in tuner
this.log.debug(`Send tuner command: update search space: ${this.experimentProfile.params.searchSpace}`)
this.tuner.sendCommand(UPDATE_SEARCH_SPACE, this.experimentProfile.params.searchSpace);
if (this.trialConcurrencyReduction !== 0) {
return Promise.reject(new Error('Error: cannot modify trialConcurrency before startExperiment'));
}
this.log.debug(`Send tuner command: ${this.experimentProfile.params.trialConcurrency}`)
this.tuner.sendCommand(REQUEST_TRIAL_JOBS, String(this.experimentProfile.params.trialConcurrency));
this.tuner.onCommand(async (commandType: string, content: string) => {
this.log.info(`Command from tuner: ${commandType}, ${content}`);
if (this.trialJobsMaintainer === undefined) {
throw new Error('Error: trialJobsMaintainer not initialized');
}
switch (commandType) {
case NEW_TRIAL_JOB:
if (this.currSubmittedTrialNum < this.experimentProfile.params.maxTrialNum) {
this.currSubmittedTrialNum++;
const trialJobAppForm: TrialJobApplicationForm = {
jobType: 'TRIAL',
hyperParameters: content
};
const trialJobDetail: TrialJobDetail = await this.trainingService.submitTrialJob(trialJobAppForm);
this.trialJobsMaintainer.setTrialJob(trialJobDetail.id, Object.assign({}, trialJobDetail));
// TO DO: to uncomment
//assert(trialJobDetail.status === 'WAITING');
await this.dataStore.storeTrialJobEvent(trialJobDetail.status, trialJobDetail.id, content, trialJobDetail.url);
if (this.currSubmittedTrialNum === this.experimentProfile.params.maxTrialNum) {
this.trialJobsMaintainer.setNoMoreTrials();
}
}
break;
case NO_MORE_TRIAL_JOBS:
this.trialJobsMaintainer.setNoMoreTrials();
break;
default:
throw new Error('Error: unsupported command type from tuner');
}
});
if (this.assessor !== undefined) {
this.assessor.onCommand(async (commandType: string, content: string) => {
if (commandType === KILL_TRIAL_JOB) {
await this.trainingService.cancelTrialJob(JSON.parse(content));
} else {
throw new Error('Error: unsupported command type from assessor');
}
});
}
return this.trialJobsMaintainer.run();
}
private async run(): Promise<void> {
await Promise.all([
this.periodicallyUpdateExecDuration(),
this.trainingService.run(),
this.runInternal()]);
}
}
export { NNIManager };
/**
* Copyright (c) Microsoft Corporation
* All rights reserved.
*
* MIT License
*
* Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
* documentation files (the "Software"), to deal in the Software without restriction, including without limitation
* the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and
* to permit persons to whom the Software is furnished to do so, subject to the following conditions:
* The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
* BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
* NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
* DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
'use strict';
import * as assert from 'assert';
import * as fs from 'fs';
import * as path from 'path';
import * as sqlite3 from 'sqlite3';
import { Deferred } from 'ts-deferred';
import {
Database,
MetricDataRecord,
MetricType,
TrialJobEvent,
TrialJobEventRecord
} from '../common/datastore';
import { ExperimentProfile } from '../common/manager';
/* tslint:disable:no-any */
const createTables: string = `
create table TrialJobEvent (timestamp integer, trialJobId text, event text, data text, logPath text);
create index TrialJobEvent_trialJobId on TrialJobEvent(trialJobId);
create index TrialJobEvent_event on TrialJobEvent(event);
create table MetricData (timestamp integer, trialJobId text, parameterId text, type text, sequence integer, data text);
create index MetricData_trialJobId on MetricData(trialJobId);
create index MetricData_type on MetricData(type);
create table ExperimentProfile (
params text,
id text,
execDuration integer,
startTime integer,
endTime integer,
revision integer);
create index ExperimentProfile_id on ExperimentProfile(id);
`;
function loadExperimentProfile(row: any): ExperimentProfile {
return {
params: JSON.parse(row.params),
id: row.id,
execDuration: row.execDuration,
startTime: row.startTime === null ? undefined : new Date(row.startTime),
endTime: row.endTime === null ? undefined : new Date(row.endTime),
revision: row.revision
};
}
function loadTrialJobEvent(row: any): TrialJobEventRecord {
return {
timestamp: new Date(row.timestamp),
trialJobId: row.trialJobId,
event: row.event,
data: row.data === null ? undefined : row.data,
logPath: row.logPath === null ? undefined : row.logPath
};
}
function loadMetricData(row: any): MetricDataRecord {
return {
timestamp: new Date(row.timestamp),
trialJobId: row.trialJobId,
parameterId: row.parameterId,
type: row.type,
sequence: row.sequence,
data: row.data
};
}
class SqlDB implements Database {
private db!: sqlite3.Database;
private initTask!: Deferred<void>;
public init(createNew: boolean, dbDir: string): Promise<void> {
if (this.initTask !== undefined) {
return this.initTask.promise;
}
this.initTask = new Deferred<void>();
assert(fs.existsSync(dbDir));
// tslint:disable-next-line:no-bitwise
const mode: number = createNew ? (sqlite3.OPEN_CREATE | sqlite3.OPEN_READWRITE) : sqlite3.OPEN_READWRITE;
const dbFileName: string = path.join(dbDir, 'nni.sqlite');
this.db = new sqlite3.Database(dbFileName, mode, (err: Error | null): void => {
if (err) {
this.resolve(this.initTask, err);
} else {
if (createNew) {
this.db.exec(createTables, (error: Error | null) => {
this.resolve(this.initTask, err);
});
} else {
this.initTask.resolve();
}
}
});
return this.initTask.promise;
}
public close(): Promise<void> {
const deferred: Deferred<void> = new Deferred<void>();
this.db.close((err: Error | null) => { this.resolve(deferred, err); });
return deferred.promise;
}
public storeExperimentProfile(exp: ExperimentProfile): Promise<void> {
const sql: string = 'insert into ExperimentProfile values (?,?,?,?,?,?)';
const args: any[] = [
JSON.stringify(exp.params),
exp.id,
exp.execDuration,
exp.startTime === undefined ? null : exp.startTime.getTime(),
exp.endTime === undefined ? null : exp.endTime.getTime(),
exp.revision
];
const deferred: Deferred<void> = new Deferred<void>();
this.db.run(sql, args, (err: Error | null) => { this.resolve(deferred, err); });
return deferred.promise;
}
public queryExperimentProfile(experimentId: string, revision?: number): Promise<ExperimentProfile[]> {
let sql: string = '';
let args: any[] = [];
if (revision === undefined) {
sql = 'select * from ExperimentProfile where id=? order by revision DESC';
args = [experimentId];
} else {
sql = 'select * from ExperimentProfile where id=? and revision=?';
args = [experimentId, revision];
}
const deferred: Deferred<ExperimentProfile[]> = new Deferred<ExperimentProfile[]>();
this.db.all(sql, args, (err: Error | null, rows: any[]) => {
this.resolve(deferred, err, rows, loadExperimentProfile);
});
return deferred.promise;
}
public async queryLatestExperimentProfile(experimentId: string): Promise<ExperimentProfile> {
const profiles: ExperimentProfile[] = await this.queryExperimentProfile(experimentId);
return profiles[0];
}
public storeTrialJobEvent(event: TrialJobEvent, trialJobId: string, data?: string, logPath?: string): Promise<void> {
const sql: string = 'insert into TrialJobEvent values (?,?,?,?,?)';
const args: any[] = [Date.now(), trialJobId, event, data, logPath];
const deferred: Deferred<void> = new Deferred<void>();
this.db.run(sql, args, (err: Error | null) => { this.resolve(deferred, err); });
return deferred.promise;
}
public queryTrialJobEvent(trialJobId?: string, event?: TrialJobEvent): Promise<TrialJobEventRecord[]> {
let sql: string = '';
let args: any[] | undefined;
if (trialJobId === undefined && event === undefined) {
sql = 'select * from TrialJobEvent';
} else if (trialJobId === undefined) {
sql = 'select * from TrialJobEvent where event=?';
args = [event];
} else if (event === undefined) {
sql = 'select * from TrialJobEvent where trialJobId=?';
args = [trialJobId];
} else {
sql = 'select * from TrialJobEvent where trialJobId=? and event=?';
args = [trialJobId, event];
}
const deferred: Deferred<TrialJobEventRecord[]> = new Deferred<TrialJobEventRecord[]>();
this.db.all(sql, args, (err: Error | null, rows: any[]) => {
this.resolve(deferred, err, rows, loadTrialJobEvent);
});
return deferred.promise;
}
public storeMetricData(trialJobId: string, data: string): Promise<void> {
const sql: string = 'insert into MetricData values (?,?,?,?,?,?)';
const json: MetricDataRecord = JSON.parse(data);
const args: any[] = [Date.now(), json.trialJobId, json.parameterId, json.type, json.sequence, JSON.stringify(json.data)];
const deferred: Deferred<void> = new Deferred<void>();
this.db.run(sql, args, (err: Error | null) => { this.resolve(deferred, err); });
return deferred.promise;
}
public queryMetricData(trialJobId?: string, metricType?: MetricType): Promise<MetricDataRecord[]> {
let sql: string = '';
let args: any[] | undefined;
if (metricType === undefined && trialJobId === undefined) {
sql = 'select * from MetricData';
} else if (trialJobId === undefined) {
sql = 'select * from MetricData where type=?';
args = [metricType];
} else if (metricType === undefined) {
sql = 'select * from MetricData where trialJobId=?';
args = [trialJobId];
} else {
sql = 'select * from MetricData where trialJobId=? and type=?';
args = [trialJobId, metricType];
}
const deferred: Deferred<MetricDataRecord[]> = new Deferred<MetricDataRecord[]>();
this.db.all(sql, args, (err: Error | null, rows: any[]) => {
this.resolve(deferred, err, rows, loadMetricData);
});
return deferred.promise;
}
private resolve<T>(
deferred: Deferred<T[]> | Deferred<void>,
error: Error | null,
rows?: any[],
rowLoader?: (row: any) => T
): void {
if (error !== null) {
deferred.reject(error);
return;
}
if (rowLoader === undefined) {
(<Deferred<void>>deferred).resolve();
} else {
const data: T[] = [];
for (const row of (<any[]>rows)) {
data.push(rowLoader(row));
}
(<Deferred<T[]>>deferred).resolve(data);
}
}
}
export { SqlDB };
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment