Unverified commit c785655e, authored by SparkSnail, committed by GitHub

Merge pull request #207 from microsoft/master

merge master
parents 9fae194a d6b61e2f
import tensorflow as tf
from src.utils import DEFINE_boolean
from src.utils import DEFINE_float
from src.utils import DEFINE_integer
from src.utils import DEFINE_string
flags = tf.app.flags
FLAGS = flags.FLAGS
DEFINE_boolean("reset_output_dir", False, "Delete output_dir if exists.")
DEFINE_string("data_path", "", "")
DEFINE_string("output_dir", "", "")
DEFINE_string("data_format", "NHWC", "'NHWC' or 'NCWH'")
DEFINE_string("search_for", None, "Must be [macro|micro]")
DEFINE_integer("train_data_size", 45000, "")
DEFINE_integer("batch_size", 32, "")
DEFINE_integer("num_epochs", 300, "")
DEFINE_integer("child_lr_dec_every", 100, "")
DEFINE_integer("child_num_layers", 5, "")
DEFINE_integer("child_num_cells", 5, "")
DEFINE_integer("child_filter_size", 5, "")
DEFINE_integer("child_out_filters", 48, "")
DEFINE_integer("child_out_filters_scale", 1, "")
DEFINE_integer("child_num_branches", 4, "")
DEFINE_integer("child_num_aggregate", None, "")
DEFINE_integer("child_num_replicas", 1, "")
DEFINE_integer("child_block_size", 3, "")
DEFINE_integer("child_lr_T_0", None, "for lr schedule")
DEFINE_integer("child_lr_T_mul", None, "for lr schedule")
DEFINE_integer("child_cutout_size", None, "CutOut size")
DEFINE_float("child_grad_bound", 5.0, "Gradient clipping")
DEFINE_float("child_lr", 0.1, "")
DEFINE_float("child_lr_dec_rate", 0.1, "")
DEFINE_float("child_keep_prob", 0.5, "")
DEFINE_float("child_drop_path_keep_prob", 1.0, "minimum drop_path_keep_prob")
DEFINE_float("child_l2_reg", 1e-4, "")
DEFINE_float("child_lr_max", None, "for lr schedule")
DEFINE_float("child_lr_min", None, "for lr schedule")
DEFINE_string("child_skip_pattern", None, "Must be ['dense', None]")
DEFINE_string("child_fixed_arc", None, "")
DEFINE_boolean("child_use_aux_heads", False, "Should we use an aux head")
DEFINE_boolean("child_sync_replicas", False, "To sync or not to sync.")
DEFINE_boolean("child_lr_cosine", False, "Use cosine lr schedule")
DEFINE_integer("log_every", 50, "How many steps to log")
DEFINE_integer("eval_every_epochs", 1, "How many epochs to eval")
import numpy as np
import tensorflow as tf
from tensorflow.python.training import moving_averages
def lstm(x, prev_c, prev_h, w):
ifog = tf.matmul(tf.concat([x, prev_h], axis=1), w)
i, f, o, g = tf.split(ifog, 4, axis=1)
i = tf.sigmoid(i)
f = tf.sigmoid(f)
o = tf.sigmoid(o)
g = tf.tanh(g)
next_c = i * g + f * prev_c
next_h = o * tf.tanh(next_c)
return next_c, next_h
def stack_lstm(x, prev_c, prev_h, w):
next_c, next_h = [], []
for layer_id, (_c, _h, _w) in enumerate(zip(prev_c, prev_h, w)):
inputs = x if layer_id == 0 else next_h[-1]
curr_c, curr_h = lstm(inputs, _c, _h, _w)
next_c.append(curr_c)
next_h.append(curr_h)
return next_c, next_h
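# A minimal shape-check sketch for the LSTM helpers above (hypothetical sizes;
# assumes a TF 1.x graph). Each weight packs the i/f/o/g gates along its last
# axis; layer 0 consumes x, while deeper layers consume the previous layer's
# hidden state, which fixes the row counts below.
def _stack_lstm_shape_example():
  batch, input_dim, hidden = 4, 32, 64
  x = tf.zeros([batch, input_dim])
  prev_c = [tf.zeros([batch, hidden]) for _ in range(2)]
  prev_h = [tf.zeros([batch, hidden]) for _ in range(2)]
  w = [tf.zeros([input_dim + hidden, 4 * hidden]),  # layer 0: [x, h] -> gates
       tf.zeros([2 * hidden, 4 * hidden])]          # layer 1: [h0, h1] -> gates
  return stack_lstm(x, prev_c, prev_h, w)  # each output: [batch, hidden]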
def create_weight(name, shape, initializer=None, trainable=True, seed=None):
if initializer is None:
initializer = tf.contrib.keras.initializers.he_normal(seed=seed)
return tf.get_variable(name, shape, initializer=initializer, trainable=trainable)
def create_bias(name, shape, initializer=None):
if initializer is None:
initializer = tf.constant_initializer(0.0, dtype=tf.float32)
return tf.get_variable(name, shape, initializer=initializer)
def conv_op(inputs, filter_size, is_training, count, out_filters,
data_format, ch_mul=1, start_idx=None, separable=False):
"""
Args:
    start_idx: where to start taking the output channels. If None, assume
      fixed_arc mode.
    count: how many output channels to take.
"""
if data_format == "NHWC":
inp_c = inputs.get_shape()[3].value
elif data_format == "NCHW":
inp_c = inputs.get_shape()[1].value
with tf.variable_scope("inp_conv_1"):
w = create_weight("w", [1, 1, inp_c, out_filters])
x = tf.nn.conv2d(inputs, w, [1, 1, 1, 1],
"SAME", data_format=data_format)
x = batch_norm(x, is_training, data_format=data_format)
x = tf.nn.relu(x)
with tf.variable_scope("out_conv_{}".format(filter_size)):
if start_idx is None:
if separable:
w_depth = create_weight(
"w_depth", [filter_size, filter_size, out_filters, ch_mul])
w_point = create_weight(
"w_point", [1, 1, out_filters * ch_mul, count])
x = tf.nn.separable_conv2d(x, w_depth, w_point, strides=[1, 1, 1, 1],
padding="SAME", data_format=data_format)
x = batch_norm(
x, is_training, data_format=data_format)
else:
w = create_weight(
"w", [filter_size, filter_size, inp_c, count])
x = tf.nn.conv2d(
x, w, [1, 1, 1, 1], "SAME", data_format=data_format)
x = batch_norm(
x, is_training, data_format=data_format)
else:
if separable:
w_depth = create_weight(
"w_depth", [filter_size, filter_size, out_filters, ch_mul])
#test_depth = w_depth
w_point = create_weight(
"w_point", [out_filters, out_filters * ch_mul])
w_point = w_point[start_idx:start_idx+count, :]
w_point = tf.transpose(w_point, [1, 0])
w_point = tf.reshape(
w_point, [1, 1, out_filters * ch_mul, count])
x = tf.nn.separable_conv2d(x, w_depth, w_point, strides=[1, 1, 1, 1],
padding="SAME", data_format=data_format)
mask = tf.range(0, out_filters, dtype=tf.int32)
mask = tf.logical_and(
start_idx <= mask, mask < start_idx + count)
x = batch_norm_with_mask(
x, is_training, mask, out_filters, data_format=data_format)
else:
w = create_weight(
"w", [filter_size, filter_size, out_filters, out_filters])
w = tf.transpose(w, [3, 0, 1, 2])
w = w[start_idx:start_idx+count, :, :, :]
w = tf.transpose(w, [1, 2, 3, 0])
x = tf.nn.conv2d(
x, w, [1, 1, 1, 1], "SAME", data_format=data_format)
mask = tf.range(0, out_filters, dtype=tf.int32)
mask = tf.logical_and(
start_idx <= mask, mask < start_idx + count)
x = batch_norm_with_mask(
x, is_training, mask, out_filters, data_format=data_format)
x = tf.nn.relu(x)
return x
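# A hedged usage sketch for conv_op (hypothetical sizes). In search mode the
# branch carves `count` of the shared `out_filters` output channels starting
# at `start_idx`, so competing branches reuse one set of shared weights;
# passing start_idx=None selects the fixed-arc path instead.
def _conv_op_example(images, is_training):
  # images: [N, 32, 32, C] in NHWC layout
  return conv_op(images, filter_size=3, is_training=is_training, count=12,
                 out_filters=24, data_format="NHWC", start_idx=4)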
def pool_op(inputs, is_training, count, out_filters, avg_or_max, data_format, start_idx=None):
"""
Args:
    start_idx: where to start taking the output channels. If None, assume
      fixed_arc mode.
    count: how many output channels to take.
"""
if data_format == "NHWC":
inp_c = inputs.get_shape()[3].value
elif data_format == "NCHW":
inp_c = inputs.get_shape()[1].value
with tf.variable_scope("conv_1"):
w = create_weight("w", [1, 1, inp_c, out_filters])
x = tf.nn.conv2d(inputs, w, [1, 1, 1, 1],
"SAME", data_format=data_format)
x = batch_norm(x, is_training, data_format=data_format)
x = tf.nn.relu(x)
with tf.variable_scope("pool"):
if data_format == "NHWC":
actual_data_format = "channels_last"
elif data_format == "NCHW":
actual_data_format = "channels_first"
if avg_or_max == "avg":
x = tf.layers.average_pooling2d(
x, [3, 3], [1, 1], "SAME", data_format=actual_data_format)
elif avg_or_max == "max":
x = tf.layers.max_pooling2d(
x, [3, 3], [1, 1], "SAME", data_format=actual_data_format)
else:
raise ValueError("Unknown pool {}".format(avg_or_max))
if start_idx is not None:
if data_format == "NHWC":
x = x[:, :, :, start_idx: start_idx+count]
elif data_format == "NCHW":
x = x[:, start_idx: start_idx+count, :, :]
return x
def global_avg_pool(x, data_format="NHWC"):
if data_format == "NHWC":
x = tf.reduce_mean(x, [1, 2])
elif data_format == "NCHW":
x = tf.reduce_mean(x, [2, 3])
else:
raise NotImplementedError("Unknown data_format {}".format(data_format))
return x
def batch_norm(x, is_training, name="bn", decay=0.9, epsilon=1e-5,
data_format="NHWC"):
if data_format == "NHWC":
shape = [x.get_shape()[3]]
elif data_format == "NCHW":
shape = [x.get_shape()[1]]
else:
raise NotImplementedError("Unknown data_format {}".format(data_format))
with tf.variable_scope(name, reuse=None if is_training else True):
offset = tf.get_variable(
"offset", shape,
initializer=tf.constant_initializer(0.0, dtype=tf.float32))
scale = tf.get_variable(
"scale", shape,
initializer=tf.constant_initializer(1.0, dtype=tf.float32))
moving_mean = tf.get_variable(
"moving_mean", shape, trainable=False,
initializer=tf.constant_initializer(0.0, dtype=tf.float32))
moving_variance = tf.get_variable(
"moving_variance", shape, trainable=False,
initializer=tf.constant_initializer(1.0, dtype=tf.float32))
if is_training:
x, mean, variance = tf.nn.fused_batch_norm(
x, scale, offset, epsilon=epsilon, data_format=data_format,
is_training=True)
update_mean = moving_averages.assign_moving_average(
moving_mean, mean, decay)
update_variance = moving_averages.assign_moving_average(
moving_variance, variance, decay)
with tf.control_dependencies([update_mean, update_variance]):
x = tf.identity(x)
else:
x, _, _ = tf.nn.fused_batch_norm(x, scale, offset, mean=moving_mean,
variance=moving_variance,
epsilon=epsilon, data_format=data_format,
is_training=False)
return x
def batch_norm_with_mask(x, is_training, mask, num_channels, name="bn",
decay=0.9, epsilon=1e-3, data_format="NHWC"):
shape = [num_channels]
indices = tf.where(mask)
indices = tf.to_int32(indices)
indices = tf.reshape(indices, [-1])
with tf.variable_scope(name, reuse=None if is_training else True):
offset = tf.get_variable(
"offset", shape,
initializer=tf.constant_initializer(0.0, dtype=tf.float32))
scale = tf.get_variable(
"scale", shape,
initializer=tf.constant_initializer(1.0, dtype=tf.float32))
offset = tf.boolean_mask(offset, mask)
scale = tf.boolean_mask(scale, mask)
moving_mean = tf.get_variable(
"moving_mean", shape, trainable=False,
initializer=tf.constant_initializer(0.0, dtype=tf.float32))
moving_variance = tf.get_variable(
"moving_variance", shape, trainable=False,
initializer=tf.constant_initializer(1.0, dtype=tf.float32))
if is_training:
x, mean, variance = tf.nn.fused_batch_norm(
x, scale, offset, epsilon=epsilon, data_format=data_format,
is_training=True)
mean = (1.0 - decay) * (tf.boolean_mask(moving_mean, mask) - mean)
variance = (1.0 - decay) * \
(tf.boolean_mask(moving_variance, mask) - variance)
update_mean = tf.scatter_sub(
moving_mean, indices, mean, use_locking=True)
update_variance = tf.scatter_sub(
moving_variance, indices, variance, use_locking=True)
with tf.control_dependencies([update_mean, update_variance]):
x = tf.identity(x)
else:
masked_moving_mean = tf.boolean_mask(moving_mean, mask)
masked_moving_variance = tf.boolean_mask(moving_variance, mask)
x, _, _ = tf.nn.fused_batch_norm(x, scale, offset,
mean=masked_moving_mean,
variance=masked_moving_variance,
epsilon=epsilon, data_format=data_format,
is_training=False)
return x
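# A hedged sketch of how the masked batch norm above is driven (hypothetical
# sizes; assumes a TF 1.x graph). The boolean mask marks which of the
# out_filters channels the sampled branch produced, so only those slices of
# the moving statistics are updated via scatter_sub.
def _batch_norm_with_mask_example():
  out_filters, start_idx, count = 8, 2, 3
  x = tf.zeros([4, 32, 32, count])  # only the selected channels are present
  channels = tf.range(0, out_filters, dtype=tf.int32)
  mask = tf.logical_and(start_idx <= channels, channels < start_idx + count)
  return batch_norm_with_mask(x, True, mask, out_filters)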
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import sys
import numpy as np
import tensorflow as tf
user_flags = []
def DEFINE_string(name, default_value, doc_string):
tf.app.flags.DEFINE_string(name, default_value, doc_string)
global user_flags
user_flags.append(name)
def DEFINE_integer(name, default_value, doc_string):
tf.app.flags.DEFINE_integer(name, default_value, doc_string)
global user_flags
user_flags.append(name)
def DEFINE_float(name, default_value, doc_string):
tf.app.flags.DEFINE_float(name, default_value, doc_string)
global user_flags
user_flags.append(name)
def DEFINE_boolean(name, default_value, doc_string):
tf.app.flags.DEFINE_boolean(name, default_value, doc_string)
global user_flags
user_flags.append(name)
def print_user_flags(line_limit=80):
print("-" * 80)
global user_flags
FLAGS = tf.app.flags.FLAGS
for flag_name in sorted(user_flags):
value = "{}".format(getattr(FLAGS, flag_name))
log_string = flag_name
log_string += "." * (line_limit - len(flag_name) - len(value))
log_string += value
print(log_string)
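# Example of the aligned "name....value" rows print_user_flags produces
# (hypothetical flag; the dots pad each row to line_limit characters):
#   batch_size....................................................................32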
def get_C(x, data_format):
"""
Args:
x: tensor of shape [N, H, W, C] or [N, C, H, W]
"""
if data_format == "NHWC":
return x.get_shape()[3].value
elif data_format == "NCHW":
return x.get_shape()[1].value
else:
raise ValueError(
"Unknown data_format '{0}'".format(data_format))
def get_HW(x, data_format):
"""
Args:
x: tensor of shape [N, H, W, C] or [N, C, H, W]
"""
return x.get_shape()[2].value
def get_strides(stride, data_format):
"""
Args:
x: tensor of shape [N, H, W, C] or [N, C, H, W]
"""
if data_format == "NHWC":
return [1, stride, stride, 1]
elif data_format == "NCHW":
return [1, 1, stride, stride]
else:
raise ValueError(
"Unknown data_format '{0}'".format(data_format))
class TextColors:
HEADER = '\033[95m'
OKBLUE = '\033[94m'
OKGREEN = '\033[92m'
WARNING = '\033[93m'
FAIL = '\033[91m'
ENDC = '\033[0m'
BOLD = '\033[1m'
UNDERLINE = '\033[4m'
class Logger(object):
def __init__(self, output_file):
self.terminal = sys.stdout
self.log = open(output_file, "a")
def write(self, message):
self.terminal.write(message)
self.terminal.flush()
self.log.write(message)
self.log.flush()
def count_model_params(tf_variables):
"""
Args:
tf_variables: list of all model variables
"""
num_vars = 0
for var in tf_variables:
num_vars += np.prod([dim.value for dim in var.get_shape()])
return num_vars
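# Illustrative only (hypothetical variable names and shapes): count_model_params
# sums the element counts of every variable in the list.
def _count_params_example():
  w = tf.get_variable("example_w", [3, 3, 16, 32])
  b = tf.get_variable("example_b", [32])
  return count_model_params([w, b])  # 3*3*16*32 + 32 = 4640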
def get_train_ops(
loss,
tf_variables,
train_step,
clip_mode=None,
grad_bound=None,
l2_reg=1e-4,
lr_warmup_val=None,
lr_warmup_steps=100,
lr_init=0.1,
lr_dec_start=0,
lr_dec_every=10000,
lr_dec_rate=0.1,
lr_dec_min=None,
lr_cosine=False,
lr_max=None,
lr_min=None,
lr_T_0=None,
lr_T_mul=None,
num_train_batches=None,
optim_algo=None,
sync_replicas=False,
num_aggregate=None,
num_replicas=None,
get_grad_norms=False,
moving_average=None):
"""
Args:
clip_mode: "global", "norm", or None.
moving_average: store the moving average of parameters
"""
if l2_reg > 0:
l2_losses = []
for var in tf_variables:
l2_losses.append(tf.reduce_sum(var ** 2))
l2_loss = tf.add_n(l2_losses)
loss += l2_reg * l2_loss
grads = tf.gradients(loss, tf_variables)
grad_norm = tf.global_norm(grads)
grad_norms = {}
for v, g in zip(tf_variables, grads):
if v is None or g is None:
continue
if isinstance(g, tf.IndexedSlices):
grad_norms[v.name] = tf.sqrt(tf.reduce_sum(g.values ** 2))
else:
grad_norms[v.name] = tf.sqrt(tf.reduce_sum(g ** 2))
if clip_mode is not None:
assert grad_bound is not None, "Need grad_bound to clip gradients."
if clip_mode == "global":
grads, _ = tf.clip_by_global_norm(grads, grad_bound)
elif clip_mode == "norm":
clipped = []
for g in grads:
if isinstance(g, tf.IndexedSlices):
          c_g = tf.clip_by_norm(g.values, grad_bound)
          c_g = tf.IndexedSlices(c_g, g.indices)
        else:
          c_g = tf.clip_by_norm(g, grad_bound)
        clipped.append(c_g)
grads = clipped
else:
raise NotImplementedError("Unknown clip_mode {}".format(clip_mode))
if lr_cosine:
assert lr_max is not None, "Need lr_max to use lr_cosine"
assert lr_min is not None, "Need lr_min to use lr_cosine"
assert lr_T_0 is not None, "Need lr_T_0 to use lr_cosine"
assert lr_T_mul is not None, "Need lr_T_mul to use lr_cosine"
assert num_train_batches is not None, ("Need num_train_batches to use"
" lr_cosine")
curr_epoch = train_step // num_train_batches
last_reset = tf.Variable(0, dtype=tf.int32, trainable=False,
name="last_reset")
T_i = tf.Variable(lr_T_0, dtype=tf.int32, trainable=False, name="T_i")
T_curr = curr_epoch - last_reset
def _update():
update_last_reset = tf.assign(
last_reset, curr_epoch, use_locking=True)
update_T_i = tf.assign(T_i, T_i * lr_T_mul, use_locking=True)
with tf.control_dependencies([update_last_reset, update_T_i]):
rate = tf.to_float(T_curr) / tf.to_float(T_i) * 3.1415926
lr = lr_min + 0.5 * (lr_max - lr_min) * (1.0 + tf.cos(rate))
return lr
def _no_update():
rate = tf.to_float(T_curr) / tf.to_float(T_i) * 3.1415926
lr = lr_min + 0.5 * (lr_max - lr_min) * (1.0 + tf.cos(rate))
return lr
learning_rate = tf.cond(
tf.greater_equal(T_curr, T_i), _update, _no_update)
else:
learning_rate = tf.train.exponential_decay(
lr_init, tf.maximum(train_step - lr_dec_start, 0), lr_dec_every,
lr_dec_rate, staircase=True)
if lr_dec_min is not None:
learning_rate = tf.maximum(learning_rate, lr_dec_min)
if lr_warmup_val is not None:
learning_rate = tf.cond(tf.less(train_step, lr_warmup_steps),
lambda: lr_warmup_val, lambda: learning_rate)
if optim_algo == "momentum":
opt = tf.train.MomentumOptimizer(
learning_rate, 0.9, use_locking=True, use_nesterov=True)
elif optim_algo == "sgd":
opt = tf.train.GradientDescentOptimizer(
learning_rate, use_locking=True)
elif optim_algo == "adam":
opt = tf.train.AdamOptimizer(learning_rate, beta1=0.0, epsilon=1e-3,
use_locking=True)
else:
raise ValueError("Unknown optim_algo {}".format(optim_algo))
if sync_replicas:
assert num_aggregate is not None, "Need num_aggregate to sync."
assert num_replicas is not None, "Need num_replicas to sync."
opt = tf.train.SyncReplicasOptimizer(
opt,
replicas_to_aggregate=num_aggregate,
total_num_replicas=num_replicas,
use_locking=True)
if moving_average is not None:
opt = tf.contrib.opt.MovingAverageOptimizer(
opt, average_decay=moving_average)
train_op = opt.apply_gradients(
zip(grads, tf_variables), global_step=train_step)
if get_grad_norms:
return train_op, learning_rate, grad_norm, opt, grad_norms
else:
return train_op, learning_rate, grad_norm, opt
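# A hedged sketch of wiring get_train_ops into a child training graph
# (hypothetical loss/variables; the literal values mirror the defaults of the
# child_* flags defined elsewhere in this package).
def _get_train_ops_example(loss, tf_variables):
  train_step = tf.Variable(0, dtype=tf.int32, trainable=False, name="train_step")
  return get_train_ops(
    loss, tf_variables, train_step,
    clip_mode="norm", grad_bound=5.0,  # clip each gradient's norm to 5.0
    l2_reg=1e-4, lr_init=0.1,
    lr_dec_start=0, lr_dec_every=10000, lr_dec_rate=0.1,
    optim_algo="momentum")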
authorName: default
experimentName: example_customized_advisor
trialConcurrency: 4
maxExecDuration: 1h
maxTrialNum: 200
#choice: local, remote, pai
trainingServicePlatform: local
searchSpacePath: search_space.json
#choice: true, false
useAnnotation: false
advisor:
codeDir: .
classFileName: dummy_advisor.py
className: DummyAdvisor
classArgs:
k: 3
trial:
command: python3 mnist_keras.py --epochs 100 --num_train 600 --num_test 100
codeDir: .
gpuNum: 0
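#assuming NNI's nnictl CLI is installed, this experiment is typically launched with:
#  nnictl create --config config.yml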
# Copyright (c) Microsoft Corporation
# All rights reserved.
#
# MIT License
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
# documentation files (the "Software"), to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and
# to permit persons to whom the Software is furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
# BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
import logging
from collections import defaultdict
import json_tricks
import numpy as np
from nni import parameter_expressions as param
from nni.msg_dispatcher_base import MsgDispatcherBase
from nni.protocol import CommandType, send
from nni.utils import MetricType
logger = logging.getLogger('customized_advisor')
class DummyAdvisor(MsgDispatcherBase):
"""WARNING: Advisor API is subject to change in future releases.
    This advisor creates a new trial whenever the validation accuracy of any running trial drops.
    A trial is killed if its validation accuracy has not improved over its last k reported metrics.
    To demonstrate the high flexibility of writing advisors, this example uses neither tuners nor the
    standard definition of search space. It is only a demo of customizing an advisor and is not intended to make any sense.
"""
def __init__(self, k=3):
super(DummyAdvisor, self).__init__()
self.k = k
self.random_state = np.random.RandomState()
def handle_initialize(self, data):
logger.info("Advisor initialized: {}".format(data))
self.handle_update_search_space(data)
self.parameters_count = 0
self.parameter_best_metric = defaultdict(float)
self.parameter_cooldown = defaultdict(int)
send(CommandType.Initialized, '')
def _send_new_trial(self):
self.parameters_count += 1
new_trial = {
"parameter_id": self.parameters_count,
"parameters": {
"optimizer": param.choice(self.searchspace_json["optimizer"], self.random_state),
"learning_rate": param.loguniform(self.searchspace_json["learning_rate"][0],
self.searchspace_json["learning_rate"][1],
self.random_state)
},
"parameter_source": "algorithm"
}
logger.info("New trial sent: {}".format(new_trial))
send(CommandType.NewTrialJob, json_tricks.dumps(new_trial))
def handle_request_trial_jobs(self, data):
logger.info("Request trial jobs: {}".format(data))
for _ in range(data):
self._send_new_trial()
def handle_update_search_space(self, data):
logger.info("Search space update: {}".format(data))
self.searchspace_json = data
def handle_trial_end(self, data):
logger.info("Trial end: {}".format(data)) # do nothing
def handle_report_metric_data(self, data):
logger.info("Metric reported: {}".format(data))
if data['type'] == MetricType.REQUEST_PARAMETER:
raise ValueError("Request parameter not supported")
elif data["type"] == MetricType.PERIODICAL:
parameter_id = data["parameter_id"]
if data["value"] > self.parameter_best_metric[parameter_id]:
self.parameter_best_metric[parameter_id] = data["value"]
self.parameter_cooldown[parameter_id] = 0
else:
self.parameter_cooldown[parameter_id] += 1
logger.info("Accuracy dropped, cooldown {}, sending a new trial".format(
self.parameter_cooldown[parameter_id]))
self._send_new_trial()
if self.parameter_cooldown[parameter_id] >= self.k:
logger.info("Send kill signal to {}".format(data))
send(CommandType.KillTrialJob, json_tricks.dumps(data["trial_job_id"]))
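# For reference, a hedged sketch of the periodical metric payload consumed by
# handle_report_metric_data above (field set inferred from the accesses in
# this file; the literal values are made up):
#   {
#       "trial_job_id": "abc123",
#       "parameter_id": 1,
#       "type": MetricType.PERIODICAL,
#       "value": 0.93,  # e.g. validation accuracy reported by the trial
#   }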
# Copyright (c) Microsoft Corporation
# All rights reserved.
#
# MIT License
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
# documentation files (the "Software"), to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and
# to permit persons to whom the Software is furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
# BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
import argparse
import logging
import os
import keras
import numpy as np
from keras import backend as K
from keras.callbacks import TensorBoard
from keras.datasets import mnist
from keras.layers import Conv2D, Dense, Flatten, MaxPooling2D
from keras.models import Sequential
import nni
LOG = logging.getLogger('mnist_keras')
K.set_image_data_format('channels_last')
TENSORBOARD_DIR = os.environ['NNI_OUTPUT_DIR']
H, W = 28, 28
NUM_CLASSES = 10
def create_mnist_model(hyper_params, input_shape=(H, W, 1), num_classes=NUM_CLASSES):
"""
Create simple convolutional model
"""
layers = [
Conv2D(32, kernel_size=(3, 3), activation='relu', input_shape=input_shape),
Conv2D(64, (3, 3), activation='relu'),
MaxPooling2D(pool_size=(2, 2)),
Flatten(),
Dense(100, activation='relu'),
Dense(num_classes, activation='softmax')
]
model = Sequential(layers)
if hyper_params['optimizer'] == 'Adam':
optimizer = keras.optimizers.Adam(lr=hyper_params['learning_rate'])
else:
optimizer = keras.optimizers.SGD(lr=hyper_params['learning_rate'], momentum=0.9)
model.compile(loss=keras.losses.categorical_crossentropy, optimizer=optimizer, metrics=['accuracy'])
return model
def load_mnist_data(args):
"""
Load MNIST dataset
"""
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = (np.expand_dims(x_train, -1).astype(np.float) / 255.)[:args.num_train]
x_test = (np.expand_dims(x_test, -1).astype(np.float) / 255.)[:args.num_test]
y_train = keras.utils.to_categorical(y_train, NUM_CLASSES)[:args.num_train]
y_test = keras.utils.to_categorical(y_test, NUM_CLASSES)[:args.num_test]
LOG.debug('x_train shape: %s', (x_train.shape,))
LOG.debug('x_test shape: %s', (x_test.shape,))
return x_train, y_train, x_test, y_test
class SendMetrics(keras.callbacks.Callback):
"""
Keras callback to send metrics to NNI framework
"""
def on_epoch_end(self, epoch, logs={}):
"""
Run on end of each epoch
"""
LOG.debug(logs)
        # Keras 2.3.0 renamed `val_acc` to `val_accuracy`; older Keras versions
        # still report `val_acc`, so use the key matching your Keras version.
        nni.report_intermediate_result(logs["val_accuracy"])
def train(args, params):
"""
Train model
"""
x_train, y_train, x_test, y_test = load_mnist_data(args)
model = create_mnist_model(params)
model.fit(x_train, y_train, batch_size=args.batch_size, epochs=args.epochs, verbose=1,
validation_data=(x_test, y_test), callbacks=[SendMetrics(), TensorBoard(log_dir=TENSORBOARD_DIR)])
_, acc = model.evaluate(x_test, y_test, verbose=0)
    LOG.debug('Final result is: %g', acc)
nni.report_final_result(acc)
def generate_default_params():
"""
Generate default hyper parameters
"""
return {
'optimizer': 'Adam',
'learning_rate': 0.001
}
if __name__ == '__main__':
PARSER = argparse.ArgumentParser()
PARSER.add_argument("--batch_size", type=int, default=200, help="batch size", required=False)
PARSER.add_argument("--epochs", type=int, default=10, help="Train epochs", required=False)
PARSER.add_argument("--num_train", type=int, default=60000,
help="Number of train samples to be used, maximum 60000", required=False)
PARSER.add_argument("--num_test", type=int, default=10000, help="Number of test samples to be used, maximum 10000",
required=False)
ARGS, UNKNOWN = PARSER.parse_known_args()
# get parameters from tuner
RECEIVED_PARAMS = nni.get_next_parameter()
LOG.debug(RECEIVED_PARAMS)
PARAMS = generate_default_params()
PARAMS.update(RECEIVED_PARAMS)
# train
train(ARGS, PARAMS)
{
"README": "To demonstrate the flexibility, this search space does not follow the standard definition.",
"optimizer": ["Adam", "SGD"],
"learning_rate": [0.001, 0.1]
}
@@ -2,13 +2,14 @@ import numpy as np
 from nni.tuner import Tuner

 def random_archi_generator(nas_ss, random_state):
     '''random
     '''
     chosen_archi = {}
-    print("zql: nas search space: ", nas_ss)
     for block_name, block_value in nas_ss.items():
-        assert block_value['_type'] == "mutable_layer", "Random NAS Tuner only receives NAS search space whose _type is 'mutable_layer'"
+        assert block_value['_type'] == "mutable_layer", \
+            "Random NAS Tuner only receives NAS search space whose _type is 'mutable_layer'"
         block = block_value['_value']
         tmp_block = {}
         for layer_name, layer in block.items():
@@ -19,13 +20,12 @@ def random_archi_generator(nas_ss, random_state):
                 tmp_layer['chosen_layer'] = value[index]
             elif key == 'optional_inputs':
                 tmp_layer['chosen_inputs'] = []
-                print("zql: optional_inputs", layer['optional_inputs'])
                 if layer['optional_inputs']:
                     if isinstance(layer['optional_input_size'], int):
                         choice_num = layer['optional_input_size']
                     else:
                         choice_range = layer['optional_input_size']
-                        choice_num = random_state.randint(choice_range[0], choice_range[1]+1)
+                        choice_num = random_state.randint(choice_range[0], choice_range[1] + 1)
                     for _ in range(choice_num):
                         index = random_state.randint(len(layer['optional_inputs']))
                         tmp_layer['chosen_inputs'].append(layer['optional_inputs'][index])
@@ -37,6 +37,7 @@ def random_archi_generator(nas_ss, random_state):
         chosen_archi[block_name] = tmp_block
     return chosen_archi

+
 class RandomNASTuner(Tuner):
     '''RandomNASTuner
     '''
...
@@ -30,14 +30,13 @@ class ExperimentStartupInfo {
     private newExperiment: boolean = true;
     private basePort: number = -1;
     private initialized: boolean = false;
-    private initTrialSequenceID: number = 0;
     private logDir: string = '';
     private logLevel: string = '';
+    private readonly: boolean = false;

-    public setStartupInfo(newExperiment: boolean, experimentId: string, basePort: number, logDir?: string, logLevel?: string): void {
+    public setStartupInfo(newExperiment: boolean, experimentId: string, basePort: number, logDir?: string, logLevel?: string, readonly?: boolean): void {
         assert(!this.initialized);
         assert(experimentId.trim().length > 0);
         this.newExperiment = newExperiment;
         this.experimentId = experimentId;
         this.basePort = basePort;
@@ -52,6 +51,10 @@ class ExperimentStartupInfo {
         if (logLevel !== undefined && logLevel.length > 1) {
             this.logLevel = logLevel;
         }
+
+        if (readonly !== undefined) {
+            this.readonly = readonly;
+        }
     }

     public getExperimentId(): string {
@@ -84,15 +87,10 @@ class ExperimentStartupInfo {
         return this.logLevel;
     }

-    public setInitTrialSequenceId(initSequenceId: number): void {
-        assert(this.initialized);
-        this.initTrialSequenceID = initSequenceId;
-    }
-
-    public getInitTrialSequenceId(): number {
+    public isReadonly(): boolean {
         assert(this.initialized);
-        return this.initTrialSequenceID;
+        return this.readonly;
     }
 }
@@ -108,23 +106,19 @@ function isNewExperiment(): boolean {
     return component.get<ExperimentStartupInfo>(ExperimentStartupInfo).isNewExperiment();
 }

-function setInitTrialSequenceId(initSequenceId: number): void {
-    component.get<ExperimentStartupInfo>(ExperimentStartupInfo).setInitTrialSequenceId(initSequenceId);
-}
-
-function getInitTrialSequenceId(): number {
-    return component.get<ExperimentStartupInfo>(ExperimentStartupInfo).getInitTrialSequenceId();
-}
-
 function getExperimentStartupInfo(): ExperimentStartupInfo {
     return component.get<ExperimentStartupInfo>(ExperimentStartupInfo);
 }

 function setExperimentStartupInfo(
-    newExperiment: boolean, experimentId: string, basePort: number, logDir?: string, logLevel?: string): void {
+    newExperiment: boolean, experimentId: string, basePort: number, logDir?: string, logLevel?: string, readonly?: boolean): void {
     component.get<ExperimentStartupInfo>(ExperimentStartupInfo)
-        .setStartupInfo(newExperiment, experimentId, basePort, logDir, logLevel);
+        .setStartupInfo(newExperiment, experimentId, basePort, logDir, logLevel, readonly);
 }

+function isReadonly(): boolean {
+    return component.get<ExperimentStartupInfo>(ExperimentStartupInfo).isReadonly();
+}
+
 export { ExperimentStartupInfo, getBasePort, getExperimentId, isNewExperiment, getExperimentStartupInfo,
-    setExperimentStartupInfo, setInitTrialSequenceId, getInitTrialSequenceId };
+    setExperimentStartupInfo, isReadonly };
@@ -26,7 +26,7 @@ import { Writable } from 'stream';
 import { WritableStreamBuffer } from 'stream-buffers';
 import { format } from 'util';
 import * as component from '../common/component';
-import { getExperimentStartupInfo } from './experimentStartupInfo';
+import { getExperimentStartupInfo, isReadonly } from './experimentStartupInfo';
 import { getLogDir } from './utils';

 const FATAL: number = 1;
@@ -76,6 +76,7 @@ class Logger {
     private level: number = INFO;
     private bufferSerialEmitter: BufferSerialEmitter;
     private writable: Writable;
+    private readonly: boolean = false;

     constructor(fileName?: string) {
         let logFile: string | undefined = fileName;
@@ -95,6 +96,8 @@ class Logger {
         if (logLevel !== undefined) {
             this.level = logLevel;
         }
+
+        this.readonly = isReadonly();
     }

     public close() {
@@ -135,7 +138,13 @@ class Logger {
         this.log('FATAL', param);
     }

+    /**
+     * if the experiment is not in readonly mode, write log content to stream
+     * @param level log level
+     * @param param the params to be written
+     */
     private log(level: string, param: any[]): void {
+        if (!this.readonly) {
         const buffer: WritableStreamBuffer = new WritableStreamBuffer();
         buffer.write(`[${(new Date()).toLocaleString()}] ${level} `);
         buffer.write(format(param));
@@ -143,6 +152,7 @@ class Logger {
         buffer.end();
         this.bufferSerialEmitter.feed(buffer.getContents());
+        }
     }
 }

 function getLogger(fileName?: string): Logger {
...
@@ -24,6 +24,10 @@ import { TrialJobStatus } from './trainingService';
 type ProfileUpdateType = 'TRIAL_CONCURRENCY' | 'MAX_EXEC_DURATION' | 'SEARCH_SPACE' | 'MAX_TRIAL_NUM';
 type ExperimentStatus = 'INITIALIZED' | 'RUNNING' | 'ERROR' | 'STOPPING' | 'STOPPED' | 'DONE' | 'NO_MORE_TRIAL' | 'TUNER_NO_MORE_TRIAL';

+namespace ExperimentStartUpMode {
+    export const NEW = 'new';
+    export const RESUME = 'resume';
+}
+
 interface ExperimentParams {
     authorName: string;
@@ -45,8 +49,8 @@ interface ExperimentParams {
         classArgs?: any;
         classFileName?: string;
         checkpointDir: string;
-        gpuNum?: number;
         includeIntermediateResults?: boolean;
+        gpuIndices?: string;
     };
     assessor?: {
         className: string;
@@ -55,7 +59,6 @@ interface ExperimentParams {
         classArgs?: any;
         classFileName?: string;
         checkpointDir: string;
-        gpuNum?: number;
     };
     advisor?: {
         className: string;
@@ -64,7 +67,7 @@ interface ExperimentParams {
         classArgs?: any;
         classFileName?: string;
         checkpointDir: string;
-        gpuNum?: number;
+        gpuIndices?: string;
     };
     clusterMetaData?: {
         key: string;
@@ -79,7 +82,7 @@ interface ExperimentProfile {
     logDir?: string;
     startTime?: number;
     endTime?: number;
-    maxSequenceId: number;
+    nextSequenceId: number;
     revision: number;
 }
@@ -95,7 +98,7 @@ interface NNIManagerStatus {
 abstract class Manager {
     public abstract startExperiment(experimentParams: ExperimentParams): Promise<string>;
-    public abstract resumeExperiment(): Promise<void>;
+    public abstract resumeExperiment(readonly: boolean): Promise<void>;
     public abstract stopExperiment(): Promise<void>;
     public abstract getExperimentProfile(): Promise<ExperimentProfile>;
     public abstract updateExperimentProfile(experimentProfile: ExperimentProfile, updateType: ProfileUpdateType): Promise<void>;
@@ -111,8 +114,11 @@ abstract class Manager {
     public abstract getClusterMetadata(key: string): Promise<string>;

     public abstract getMetricData(trialJobId?: string, metricType?: MetricType): Promise<MetricDataRecord[]>;
+    public abstract getMetricDataByRange(minSeqId: number, maxSeqId: number): Promise<MetricDataRecord[]>;
+    public abstract getLatestMetricData(): Promise<MetricDataRecord[]>;
     public abstract getTrialJobStatistics(): Promise<TrialJobStatistics[]>;
     public abstract getStatus(): NNIManagerStatus;
 }

-export { Manager, ExperimentParams, ExperimentProfile, TrialJobStatistics, ProfileUpdateType, NNIManagerStatus, ExperimentStatus };
+export { Manager, ExperimentParams, ExperimentProfile, TrialJobStatistics, ProfileUpdateType, NNIManagerStatus, ExperimentStatus, ExperimentStartUpMode };
@@ -23,20 +23,12 @@
  * define TrialJobStatus
  */
 type TrialJobStatus = 'UNKNOWN' | 'WAITING' | 'RUNNING' | 'SUCCEEDED' | 'FAILED' | 'USER_CANCELED' | 'SYS_CANCELED' | 'EARLY_STOPPED';
-type JobType = 'TRIAL' | 'HOST';

 interface TrainingServiceMetadata {
     readonly key: string;
     readonly value: string;
 }

-/**
- * define JobApplicationForm
- */
-interface JobApplicationForm {
-    readonly jobType: JobType;
-}
-
 interface HyperParameters {
     readonly value: string;
     readonly index: number;
@@ -45,18 +37,11 @@ interface HyperParameters {
 /**
  * define TrialJobApplicationForm
  */
-interface TrialJobApplicationForm extends JobApplicationForm {
+interface TrialJobApplicationForm {
+    readonly sequenceId: number;
     readonly hyperParameters: HyperParameters;
 }

-/**
- * define HostJobApplicationForm
- */
-interface HostJobApplicationForm extends JobApplicationForm {
-    readonly host: string;
-    readonly cmd: string;
-}
-
 /**
  * define TrialJobDetail
  */
@@ -69,8 +54,7 @@ interface TrialJobDetail {
     readonly tags?: string[];
     readonly url?: string;
     readonly workingDirectory: string;
-    readonly form: JobApplicationForm;
-    readonly sequenceId: number;
+    readonly form: TrialJobApplicationForm;
     isEarlyStopped?: boolean;
 }
@@ -112,8 +96,8 @@ abstract class TrainingService {
     public abstract getTrialJob(trialJobId: string): Promise<TrialJobDetail>;
     public abstract addTrialJobMetricListener(listener: (metric: TrialJobMetric) => void): void;
     public abstract removeTrialJobMetricListener(listener: (metric: TrialJobMetric) => void): void;
-    public abstract submitTrialJob(form: JobApplicationForm): Promise<TrialJobDetail>;
-    public abstract updateTrialJob(trialJobId: string, form: JobApplicationForm): Promise<TrialJobDetail>;
+    public abstract submitTrialJob(form: TrialJobApplicationForm): Promise<TrialJobDetail>;
+    public abstract updateTrialJob(trialJobId: string, form: TrialJobApplicationForm): Promise<TrialJobDetail>;
     public abstract get isMultiPhaseJobSupported(): boolean;
     public abstract cancelTrialJob(trialJobId: string, isEarlyStopped?: boolean): Promise<void>;
     public abstract setClusterMetadata(key: string, value: string): Promise<void>;
@@ -135,5 +119,5 @@ class NNIManagerIpConfig {
 export {
     TrainingService, TrainingServiceError, TrialJobStatus, TrialJobApplicationForm,
     TrainingServiceMetadata, TrialJobDetail, TrialJobMetric, HyperParameters,
-    HostJobApplicationForm, JobApplicationForm, JobType, NNIManagerIpConfig
+    NNIManagerIpConfig
 };
@@ -445,10 +445,10 @@ function getTunerProc(command: string, stdio: StdioOptions, newCwd: string, newE
 /**
  * judge whether the process is alive
  */
-async function isAlive(pid:any): Promise<boolean> {
+async function isAlive(pid: any): Promise<boolean> {
     let deferred : Deferred<boolean> = new Deferred<boolean>();
     let alive: boolean = false;
-    if(process.platform ==='win32'){
+    if (process.platform === 'win32') {
         try {
             const str = cp.execSync(`powershell.exe Get-Process -Id ${pid} -ErrorAction SilentlyContinue`).toString();
             if (str) {
@@ -458,7 +458,7 @@ async function isAlive(pid: any): Promise<boolean> {
         catch (error) {
         }
     }
-    else{
+    else {
         try {
             await cpp.exec(`kill -0 ${pid}`);
             alive = true;
@@ -473,11 +473,11 @@ async function isAlive(pid: any): Promise<boolean> {
 /**
  * kill process
  */
-async function killPid(pid:any): Promise<void> {
+async function killPid(pid: any): Promise<void> {
     let deferred : Deferred<void> = new Deferred<void>();
     try {
         if (process.platform === "win32") {
-            await cpp.exec(`cmd /c taskkill /PID ${pid} /F`);
+            await cpp.exec(`cmd.exe /c taskkill /PID ${pid} /F`);
         }
         else{
             await cpp.exec(`kill -9 ${pid}`);
...
...@@ -26,7 +26,7 @@ import { Deferred } from 'ts-deferred'; ...@@ -26,7 +26,7 @@ import { Deferred } from 'ts-deferred';
import * as component from '../common/component'; import * as component from '../common/component';
import { DataStore, MetricDataRecord, MetricType, TrialJobInfo } from '../common/datastore'; import { DataStore, MetricDataRecord, MetricType, TrialJobInfo } from '../common/datastore';
import { NNIError } from '../common/errors'; import { NNIError } from '../common/errors';
import { getExperimentId, setInitTrialSequenceId } from '../common/experimentStartupInfo'; import { getExperimentId } from '../common/experimentStartupInfo';
import { getLogger, Logger } from '../common/log'; import { getLogger, Logger } from '../common/log';
import { import {
ExperimentParams, ExperimentProfile, Manager, ExperimentStatus, ExperimentParams, ExperimentProfile, Manager, ExperimentStatus,
...@@ -59,6 +59,7 @@ class NNIManager implements Manager { ...@@ -59,6 +59,7 @@ class NNIManager implements Manager {
private waitingTrials: string[]; private waitingTrials: string[];
private trialJobs: Map<string, TrialJobDetail>; private trialJobs: Map<string, TrialJobDetail>;
private trialDataForTuner: string; private trialDataForTuner: string;
private readonly: boolean;
private trialJobMetricListener: (metric: TrialJobMetric) => void; private trialJobMetricListener: (metric: TrialJobMetric) => void;
...@@ -72,6 +73,7 @@ class NNIManager implements Manager { ...@@ -72,6 +73,7 @@ class NNIManager implements Manager {
this.waitingTrials = []; this.waitingTrials = [];
this.trialJobs = new Map<string, TrialJobDetail>(); this.trialJobs = new Map<string, TrialJobDetail>();
this.trialDataForTuner = ''; this.trialDataForTuner = '';
this.readonly = false;
this.log = getLogger(); this.log = getLogger();
this.dataStore = component.get(DataStore); this.dataStore = component.get(DataStore);
...@@ -88,6 +90,9 @@ class NNIManager implements Manager { ...@@ -88,6 +90,9 @@ class NNIManager implements Manager {
} }
public updateExperimentProfile(experimentProfile: ExperimentProfile, updateType: ProfileUpdateType): Promise<void> { public updateExperimentProfile(experimentProfile: ExperimentProfile, updateType: ProfileUpdateType): Promise<void> {
if (this.readonly) {
return Promise.reject(new Error('Error: can not update experiment profile in readonly mode!'));
}
switch (updateType) { switch (updateType) {
case 'TRIAL_CONCURRENCY': case 'TRIAL_CONCURRENCY':
this.updateTrialConcurrency(experimentProfile.params.trialConcurrency); this.updateTrialConcurrency(experimentProfile.params.trialConcurrency);
...@@ -109,6 +114,9 @@ class NNIManager implements Manager { ...@@ -109,6 +114,9 @@ class NNIManager implements Manager {
} }
public importData(data: string): Promise<void> { public importData(data: string): Promise<void> {
if (this.readonly) {
return Promise.reject(new Error('Error: can not import data in readonly mode!'));
}
if (this.dispatcher === undefined) { if (this.dispatcher === undefined) {
return Promise.reject( return Promise.reject(
new Error('tuner has not been setup') new Error('tuner has not been setup')
...@@ -124,6 +132,9 @@ class NNIManager implements Manager { ...@@ -124,6 +132,9 @@ class NNIManager implements Manager {
} }
public addCustomizedTrialJob(hyperParams: string): Promise<void> { public addCustomizedTrialJob(hyperParams: string): Promise<void> {
if (this.readonly) {
return Promise.reject(new Error('Error: can not add customized trial job in readonly mode!'));
}
if (this.currSubmittedTrialNum >= this.experimentProfile.params.maxTrialNum) { if (this.currSubmittedTrialNum >= this.experimentProfile.params.maxTrialNum) {
return Promise.reject( return Promise.reject(
new Error('reach maxTrialNum') new Error('reach maxTrialNum')
...@@ -136,6 +147,9 @@ class NNIManager implements Manager { ...@@ -136,6 +147,9 @@ class NNIManager implements Manager {
} }
public async cancelTrialJobByUser(trialJobId: string): Promise<void> { public async cancelTrialJobByUser(trialJobId: string): Promise<void> {
if (this.readonly) {
return Promise.reject(new Error('Error: can not cancel trial job in readonly mode!'));
}
this.log.info(`User cancelTrialJob: ${trialJobId}`); this.log.info(`User cancelTrialJob: ${trialJobId}`);
await this.trainingService.cancelTrialJob(trialJobId); await this.trainingService.cancelTrialJob(trialJobId);
await this.dataStore.storeTrialJobEvent('USER_TO_CANCEL', trialJobId, ''); await this.dataStore.storeTrialJobEvent('USER_TO_CANCEL', trialJobId, '');
...@@ -180,15 +194,17 @@ class NNIManager implements Manager { ...@@ -180,15 +194,17 @@ class NNIManager implements Manager {
return this.experimentProfile.id; return this.experimentProfile.id;
} }
public async resumeExperiment(): Promise<void> { public async resumeExperiment(readonly: boolean): Promise<void> {
this.log.info(`Resuming experiment: ${this.experimentProfile.id}`); this.log.info(`Resuming experiment: ${this.experimentProfile.id}`);
//Fetch back the experiment profile //Fetch back the experiment profile
const experimentId: string = getExperimentId(); const experimentId: string = getExperimentId();
this.experimentProfile = await this.dataStore.getExperimentProfile(experimentId); this.experimentProfile = await this.dataStore.getExperimentProfile(experimentId);
this.readonly = readonly;
if (readonly) {
return Promise.resolve();
}
const expParams: ExperimentParams = this.experimentProfile.params; const expParams: ExperimentParams = this.experimentProfile.params;
setInitTrialSequenceId(this.experimentProfile.maxSequenceId + 1);
// Set up multiphase config // Set up multiphase config
if (expParams.multiPhase && this.trainingService.isMultiPhaseJobSupported) { if (expParams.multiPhase && this.trainingService.isMultiPhaseJobSupported) {
this.trainingService.setClusterMetadata('multiPhase', expParams.multiPhase.toString()); this.trainingService.setClusterMetadata('multiPhase', expParams.multiPhase.toString());
...@@ -196,7 +212,7 @@ class NNIManager implements Manager { ...@@ -196,7 +212,7 @@ class NNIManager implements Manager {
// Set up versionCheck config // Set up versionCheck config
if (expParams.versionCheck !== undefined) { if (expParams.versionCheck !== undefined) {
this.trainingService.setClusterMetadata('versionCheck', expParams.versionCheck.toString()); this.trainingService.setClusterMetadata('version_check', expParams.versionCheck.toString());
} }
const dispatcherCommand: string = getMsgDispatcherCommand(expParams.tuner, expParams.assessor, expParams.advisor, const dispatcherCommand: string = getMsgDispatcherCommand(expParams.tuner, expParams.assessor, expParams.advisor,
...@@ -247,6 +263,9 @@ class NNIManager implements Manager { ...@@ -247,6 +263,9 @@ class NNIManager implements Manager {
} }
public async setClusterMetadata(key: string, value: string): Promise<void> { public async setClusterMetadata(key: string, value: string): Promise<void> {
if (this.readonly) {
return Promise.reject(new Error('Error: can not set cluster metadata in readonly mode!'));
}
this.log.info(`NNIManager setClusterMetadata, key: ${key}, value: ${value}`); this.log.info(`NNIManager setClusterMetadata, key: ${key}, value: ${value}`);
let timeoutId: NodeJS.Timer; let timeoutId: NodeJS.Timer;
// TO DO: move timeout value to constants file // TO DO: move timeout value to constants file
...@@ -281,6 +300,37 @@ class NNIManager implements Manager { ...@@ -281,6 +300,37 @@ class NNIManager implements Manager {
return this.dataStore.getMetricData(trialJobId, metricType); return this.dataStore.getMetricData(trialJobId, metricType);
} }
public async getMetricDataByRange(minSeqId: number, maxSeqId: number): Promise<MetricDataRecord[]> {
const trialJobs = await this.dataStore.listTrialJobs();
const targetTrials = trialJobs.filter(trial => (
// FIXME: can this be undefined?
trial.sequenceId !== undefined && minSeqId <= trial.sequenceId && trial.sequenceId <= maxSeqId
));
const targetTrialIds = new Set(targetTrials.map(trial => trial.id));
const allMetrics = await this.dataStore.getMetricData();
return allMetrics.filter(metric => targetTrialIds.has(metric.trialJobId));
}
public async getLatestMetricData(): Promise<MetricDataRecord[]> {
// FIXME: this can take a long time
const allMetrics: MetricDataRecord[] = await this.dataStore.getMetricData();
const finals: MetricDataRecord[] = [];
const latestIntermediates: Map<string, MetricDataRecord> = new Map<string, MetricDataRecord>();
for (const metric of allMetrics) {
if (metric.type !== 'PERIODICAL') {
finals.push(metric);
} else {
const old: MetricDataRecord | undefined = latestIntermediates.get(metric.trialJobId);
if (old === undefined || old.sequence <= metric.sequence) {
latestIntermediates.set(metric.trialJobId, metric);
}
}
}
return finals.concat(Array.from(latestIntermediates.values()));
// FIXME: unit test
}
public getExperimentProfile(): Promise<ExperimentProfile> { public getExperimentProfile(): Promise<ExperimentProfile> {
// TO DO: using Promise.resolve() // TO DO: using Promise.resolve()
const deferred: Deferred<ExperimentProfile> = new Deferred<ExperimentProfile>(); const deferred: Deferred<ExperimentProfile> = new Deferred<ExperimentProfile>();
...@@ -319,7 +369,8 @@ class NNIManager implements Manager { ...@@ -319,7 +369,8 @@ class NNIManager implements Manager {
NNI_CHECKPOINT_DIRECTORY: dataDirectory, NNI_CHECKPOINT_DIRECTORY: dataDirectory,
NNI_LOG_DIRECTORY: getLogDir(), NNI_LOG_DIRECTORY: getLogDir(),
NNI_LOG_LEVEL: getLogLevel(), NNI_LOG_LEVEL: getLogLevel(),
NNI_INCLUDE_INTERMEDIATE_RESULTS: includeIntermediateResultsEnv NNI_INCLUDE_INTERMEDIATE_RESULTS: includeIntermediateResultsEnv,
CUDA_VISIBLE_DEVICES: this.getGpuEnvvarValue()
}; };
let newEnv = Object.assign({}, process.env, nniEnv); let newEnv = Object.assign({}, process.env, nniEnv);
const tunerProc: ChildProcess = getTunerProc(command,stdio,newCwd,newEnv); const tunerProc: ChildProcess = getTunerProc(command,stdio,newCwd,newEnv);
...@@ -329,6 +380,22 @@ class NNIManager implements Manager { ...@@ -329,6 +380,22 @@ class NNIManager implements Manager {
return; return;
} }
private getGpuEnvvarValue(): string {
let cudaDevices: string | undefined;
if (this.experimentProfile.params.advisor !== undefined) {
cudaDevices = this.experimentProfile.params.advisor.gpuIndices;
} else if (this.experimentProfile.params.tuner !== undefined) {
cudaDevices = this.experimentProfile.params.tuner.gpuIndices;
}
if (cudaDevices === undefined) {
return '';
} else {
return cudaDevices;
}
}
private updateTrialConcurrency(trialConcurrency: number): void { private updateTrialConcurrency(trialConcurrency: number): void {
// we assume trialConcurrency >= 0, which is checked by restserver // we assume trialConcurrency >= 0, which is checked by restserver
this.trialConcurrencyChange += (trialConcurrency - this.experimentProfile.params.trialConcurrency); this.trialConcurrencyChange += (trialConcurrency - this.experimentProfile.params.trialConcurrency);
...@@ -436,11 +503,7 @@ class NNIManager implements Manager { ...@@ -436,11 +503,7 @@ class NNIManager implements Manager {
case 'EARLY_STOPPED': case 'EARLY_STOPPED':
this.trialJobs.delete(trialJobId); this.trialJobs.delete(trialJobId);
finishedTrialJobNum++; finishedTrialJobNum++;
if (trialJobDetail.form.jobType === 'TRIAL') { hyperParams = trialJobDetail.form.hyperParameters.value;
hyperParams = (<TrialJobApplicationForm>trialJobDetail.form).hyperParameters.value;
} else {
throw new Error('Error: jobType error, not TRIAL');
}
this.dispatcher.sendCommand(TRIAL_END, JSON.stringify({ this.dispatcher.sendCommand(TRIAL_END, JSON.stringify({
trial_job_id: trialJobDetail.id, trial_job_id: trialJobDetail.id,
event: trialJobDetail.status, event: trialJobDetail.status,
...@@ -453,11 +516,7 @@ class NNIManager implements Manager {
                    // TO DO: push this job to queue for retry
                    this.trialJobs.delete(trialJobId);
                    finishedTrialJobNum++;
-                   if (trialJobDetail.form.jobType === 'TRIAL') {
-                       hyperParams = (<TrialJobApplicationForm>trialJobDetail.form).hyperParameters.value;
-                   } else {
-                       throw new Error('Error: jobType error, not TRIAL');
-                   }
+                   hyperParams = trialJobDetail.form.hyperParameters.value;
                    this.dispatcher.sendCommand(TRIAL_END, JSON.stringify({
                        trial_job_id: trialJobDetail.id,
                        event: trialJobDetail.status,
...@@ -556,7 +615,7 @@ class NNIManager implements Manager {
        }
        this.currSubmittedTrialNum++;
        const trialJobAppForm: TrialJobApplicationForm = {
-           jobType: 'TRIAL',
+           sequenceId: this.experimentProfile.nextSequenceId++,
            hyperParameters: {
                value: hyperParams,
                index: 0
...@@ -564,7 +623,7 @@ class NNIManager implements Manager {
        };
        this.log.info(`submitTrialJob: form: ${JSON.stringify(trialJobAppForm)}`);
        const trialJobDetail: TrialJobDetail = await this.trainingService.submitTrialJob(trialJobAppForm);
-       await this.storeMaxSequenceId(trialJobDetail.sequenceId);
+       await this.storeExperimentProfile();
        this.trialJobs.set(trialJobDetail.id, Object.assign({}, trialJobDetail));
        const trialJobDetailSnapshot: TrialJobDetail | undefined = this.trialJobs.get(trialJobDetail.id);
        if (trialJobDetailSnapshot != undefined) {
...@@ -683,7 +742,7 @@ class NNIManager implements Manager {
                assert(tunerCommand.trial_job_id !== undefined);
                const trialJobForm: TrialJobApplicationForm = {
-                   jobType: 'TRIAL',
+                   sequenceId: -1, // FIXME: multi-phase tuner should use sequence ID instead of trial job ID
                    hyperParameters: {
                        value: content,
                        index: tunerCommand.parameter_index
...@@ -691,8 +750,11 @@ class NNIManager implements Manager {
                };
                this.log.info(`updateTrialJob: job id: ${tunerCommand.trial_job_id}, form: ${JSON.stringify(trialJobForm)}`);
                await this.trainingService.updateTrialJob(tunerCommand.trial_job_id, trialJobForm);
+               if (tunerCommand['parameters'] !== null) {
+                   // the parameters field is set to an empty string when the tuner can generate no more hyperparameters
                    await this.dataStore.storeTrialJobEvent(
                        'ADD_HYPERPARAMETER', tunerCommand.trial_job_id, content, undefined);
+               }
                break;
            case NO_MORE_TRIAL_JOBS:
                if (!['ERROR', 'STOPPING', 'STOPPED'].includes(this.status.status)) {
...@@ -734,7 +796,7 @@ class NNIManager implements Manager {
            revision: 0,
            execDuration: 0,
            logDir: getExperimentRootDir(),
-           maxSequenceId: 0,
+           nextSequenceId: 0,
            params: {
                authorName: '',
                experimentName: '',
...@@ -765,13 +827,6 @@ class NNIManager implements Manager {
        return Promise.resolve(chkpDir);
    }
-   private async storeMaxSequenceId(sequenceId: number): Promise<void> {
-       if (sequenceId > this.experimentProfile.maxSequenceId) {
-           this.experimentProfile.maxSequenceId = sequenceId;
-           await this.storeExperimentProfile();
-       }
-   }
}
export { NNIManager };
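The storeMaxSequenceId helper is gone because sequence IDs are no longer reported back by the training service: the manager now allocates them itself from nextSequenceId and persists the whole profile on each submission. A minimal sketch of the allocation scheme, with the profile type reduced to the one relevant field:

// Reduced profile shape; the real ExperimentProfile carries many more fields.
interface ProfileCounter {
    nextSequenceId: number;
}

const profile: ProfileCounter = { nextSequenceId: 0 };

// Read-then-increment: the stored profile always holds the next unused ID,
// so a resumed experiment cannot hand out a duplicate sequence number.
function allocateSequenceId(p: ProfileCounter): number {
    return p.nextSequenceId++;
}

console.log(allocateSequenceId(profile)); // 0
console.log(allocateSequenceId(profile)); // 1
console.log(profile.nextSequenceId);      // 2, the value written to storage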
...@@ -54,7 +54,7 @@ create table ExperimentProfile (
    startTime integer,
    endTime integer,
    logDir text,
-   maxSequenceId integer,
+   nextSequenceId integer,
    revision integer);
create index ExperimentProfile_id on ExperimentProfile(id);
`;
...@@ -67,7 +67,7 @@ function loadExperimentProfile(row: any): ExperimentProfile {
        startTime: row.startTime === null ? undefined : row.startTime,
        endTime: row.endTime === null ? undefined : row.endTime,
        logDir: row.logDir === null ? undefined : row.logDir,
-       maxSequenceId: row.maxSequenceId,
+       nextSequenceId: row.nextSequenceId,
        revision: row.revision
    };
}
...@@ -144,7 +144,7 @@ class SqlDB implements Database {
            exp.startTime === undefined ? null : exp.startTime,
            exp.endTime === undefined ? null : exp.endTime,
            exp.logDir === undefined ? null : exp.logDir,
-           exp.maxSequenceId,
+           exp.nextSequenceId,
            exp.revision
        ];
        this.log.trace(`storeExperimentProfile: SQL: ${sql}, args: ${JSON.stringify(args)}`);
...@@ -183,7 +183,7 @@ class SqlDB implements Database {
        event: TrialJobEvent, trialJobId: string, timestamp: number, hyperParameter?: string, jobDetail?: TrialJobDetail): Promise<void> {
        const sql: string = 'insert into TrialJobEvent values (?,?,?,?,?,?)';
        const logPath: string | undefined = jobDetail === undefined ? undefined : jobDetail.url;
-       const sequenceId: number | undefined = jobDetail === undefined ? undefined : jobDetail.sequenceId;
+       const sequenceId: number | undefined = jobDetail === undefined ? undefined : jobDetail.form.sequenceId;
        const args: any[] = [timestamp, trialJobId, event, hyperParameter, logPath, sequenceId];
        this.log.trace(`storeTrialJobEvent: SQL: ${sql}, args: ${JSON.stringify(args)}`);
...
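Since sequenceId now lives on the submission form rather than on the job detail itself, the datastore derives it through jobDetail.form. A sketch with the interfaces reduced to the fields this diff touches (the real declarations in common/trainingService.ts are larger):

interface TrialJobApplicationForm {
    sequenceId: number;
    hyperParameters: { value: string; index: number };
}

interface TrialJobDetail {
    id: string;
    url?: string;
    form: TrialJobApplicationForm;
}

// Mirrors the ternary above: events recorded before a job detail exists
// simply store an undefined sequence ID.
function sequenceIdOf(jobDetail?: TrialJobDetail): number | undefined {
    return jobDetail === undefined ? undefined : jobDetail.form.sequenceId;
}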
...@@ -72,15 +72,14 @@ describe('Unit test for dataStore', () => {
        }`,
        tuner: {
            className: 'testTuner',
-           checkpointDir: '/tmp/cp',
-           gpuNum: 0
+           checkpointDir: '/tmp/cp'
        }
    },
    id: 'exp123',
    execDuration: 0,
    startTime: Date.now(),
    endTime: Date.now(),
-   maxSequenceId: 0,
+   nextSequenceId: 0,
    revision: 0
}
const id: string = profile.id;
...
...@@ -41,9 +41,9 @@ class MockedTrainingService extends TrainingService {
        url: 'http://test',
        workingDirectory: '/tmp/mocked',
        form: {
-           jobType: 'TRIAL'
+           sequenceId: 0,
+           hyperParameters: { value: '', index: 0 }
        },
-       sequenceId: 0
    };
    public jobDetail2: TrialJobDetail = {
        id: '3456',
...@@ -55,9 +55,9 @@ class MockedTrainingService extends TrainingService {
        url: 'http://test',
        workingDirectory: '/tmp/mocked',
        form: {
-           jobType: 'TRIAL'
+           sequenceId: 1,
+           hyperParameters: { value: '', index: 1 }
        },
-       sequenceId: 0
    };
    public listTrialJobs(): Promise<TrialJobDetail[]> {
...
...@@ -101,7 +101,7 @@ describe('Unit test for nnimanager', function () {
    params: updateExperimentParams,
    id: 'test',
    execDuration: 0,
-   maxSequenceId: 0,
+   nextSequenceId: 0,
    revision: 0
}
...
...@@ -40,8 +40,7 @@ const expParams1: ExperimentParams = {
    searchSpace: 'SS',
    tuner: {
        className: 'testTuner',
-       checkpointDir: '/tmp',
-       gpuNum: 0
+       checkpointDir: '/tmp'
    }
};
...@@ -64,10 +63,10 @@ const expParams2: ExperimentParams = {
};
const profiles: ExperimentProfile[] = [
-   { params: expParams1, id: '#1', execDuration: 0, logDir: '/log', startTime: Date.now(), endTime: undefined, maxSequenceId: 0, revision: 1 },
-   { params: expParams1, id: '#1', execDuration: 0, logDir: '/log', startTime: Date.now(), endTime: Date.now(), maxSequenceId: 0, revision: 2 },
-   { params: expParams2, id: '#2', execDuration: 0, logDir: '/log', startTime: Date.now(), endTime: Date.now(), maxSequenceId: 0, revision: 2 },
-   { params: expParams2, id: '#2', execDuration: 0, logDir: '/log', startTime: Date.now(), endTime: Date.now(), maxSequenceId: 0, revision: 3 }
+   { params: expParams1, id: '#1', execDuration: 0, logDir: '/log', startTime: Date.now(), endTime: undefined, nextSequenceId: 0, revision: 1 },
+   { params: expParams1, id: '#1', execDuration: 0, logDir: '/log', startTime: Date.now(), endTime: Date.now(), nextSequenceId: 1, revision: 2 },
+   { params: expParams2, id: '#2', execDuration: 0, logDir: '/log', startTime: Date.now(), endTime: Date.now(), nextSequenceId: 0, revision: 2 },
+   { params: expParams2, id: '#2', execDuration: 0, logDir: '/log', startTime: Date.now(), endTime: Date.now(), nextSequenceId: 2, revision: 3 }
];
const events: TrialJobEventRecord[] = [
...
...@@ -26,7 +26,7 @@ import * as component from './common/component';
import { Database, DataStore } from './common/datastore';
import { setExperimentStartupInfo } from './common/experimentStartupInfo';
import { getLogger, Logger, logLevelNameMap } from './common/log';
-import { Manager } from './common/manager';
+import { Manager, ExperimentStartUpMode } from './common/manager';
import { TrainingService } from './common/trainingService';
import { getLogDir, mkDirP, parseArg, uniqueString } from './common/utils';
import { NNIDataStore } from './core/nniDataStore';
...@@ -43,10 +43,10 @@ import {
function initStartupInfo(
    startExpMode: string, resumeExperimentId: string, basePort: number,
-   logDirectory: string, experimentLogLevel: string): void {
+   logDirectory: string, experimentLogLevel: string, readonly: boolean): void {
-   const createNew: boolean = (startExpMode === 'new');
+   const createNew: boolean = (startExpMode === ExperimentStartUpMode.NEW);
    const expId: string = createNew ? uniqueString(8) : resumeExperimentId;
-   setExperimentStartupInfo(createNew, expId, basePort, logDirectory, experimentLogLevel);
+   setExperimentStartupInfo(createNew, expId, basePort, logDirectory, experimentLogLevel, readonly);
}
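The string literals 'new' and 'resume' are replaced by enum members throughout main.ts. The enum itself lives in common/manager and is not shown in this diff; judging from the comparisons it replaces, its shape is presumably:

// Presumed declaration (values inferred from the literals it replaces in
// this file; the actual definition in common/manager.ts may differ).
enum ExperimentStartUpMode {
    NEW = 'new',
    RESUME = 'resume'
}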
async function initContainer(platformMode: string): Promise<void> {
...@@ -108,15 +108,15 @@ if (!['local', 'remote', 'pai', 'kubeflow', 'frameworkcontroller'].includes(mode
}
const startMode: string = parseArg(['--start_mode', '-s']);
-if (!['new', 'resume'].includes(startMode)) {
+if (![ExperimentStartUpMode.NEW, ExperimentStartUpMode.RESUME].includes(startMode)) {
    console.log(`FATAL: unknown start_mode: ${startMode}`);
    usage();
    process.exit(1);
}
const experimentId: string = parseArg(['--experiment_id', '-id']);
-if (startMode === 'resume' && experimentId.trim().length < 1) {
-   console.log(`FATAL: cannot resume experiment, invalid experiment_id: ${experimentId}`);
+if ((startMode === ExperimentStartUpMode.RESUME) && experimentId.trim().length < 1) {
+   console.log(`FATAL: cannot resume the experiment, invalid experiment_id: ${experimentId}`);
    usage();
    process.exit(1);
}
...@@ -133,7 +133,15 @@ if (logLevel.length > 0 && !logLevelNameMap.has(logLevel)) {
    console.log(`FATAL: invalid log_level: ${logLevel}`);
}
-initStartupInfo(startMode, experimentId, port, logDir, logLevel);
+const readonlyArg: string = parseArg(['--readonly', '-r']);
+if (!['true', 'false'].includes(readonlyArg.toLowerCase())) {
+    console.log(`FATAL: readonly property should only be true or false`);
+    usage();
+    process.exit(1);
+}
+const readonly: boolean = (readonlyArg.toLowerCase() === 'true');
+initStartupInfo(startMode, experimentId, port, logDir, logLevel, readonly);
mkDirP(getLogDir())
    .then(async () => {
...
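One edge case worth noting: parseArg presumably returns an empty string when --readonly is omitted, which the strict true/false check above would reject. A defensive variant that tolerates the absent flag; parseBooleanArg is a hypothetical helper, not part of this change:

// Hedged sketch: accepts 'true'/'false' in any case, falls back to a
// default when the flag is absent, and rejects everything else.
function parseBooleanArg(raw: string, defaultValue: boolean): boolean {
    const normalized: string = raw.trim().toLowerCase();
    if (normalized === '') {
        return defaultValue; // flag not supplied on the command line
    }
    if (normalized !== 'true' && normalized !== 'false') {
        throw new Error(`expected true or false, got: ${raw}`);
    }
    return normalized === 'true';
}

// e.g. const readonly: boolean = parseBooleanArg(parseArg(['--readonly', '-r']), false);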