Commit b40e3db7 authored by quzha's avatar quzha
Browse files

Merge branch 'master' of github.com:Microsoft/nni into dev-retiarii

parents efa4e31c 95f731e4
...@@ -46,7 +46,7 @@ def create_bracket_parameter_id(brackets_id, brackets_curr_decay, increased_id=- ...@@ -46,7 +46,7 @@ def create_bracket_parameter_id(brackets_id, brackets_curr_decay, increased_id=-
Parameters Parameters
---------- ----------
brackets_id: int brackets_id: string
brackets id brackets id
brackets_curr_decay: brackets_curr_decay:
brackets curr decay brackets curr decay
...@@ -60,7 +60,7 @@ def create_bracket_parameter_id(brackets_id, brackets_curr_decay, increased_id=- ...@@ -60,7 +60,7 @@ def create_bracket_parameter_id(brackets_id, brackets_curr_decay, increased_id=-
""" """
if increased_id == -1: if increased_id == -1:
increased_id = str(create_parameter_id()) increased_id = str(create_parameter_id())
params_id = '_'.join([str(brackets_id), params_id = '_'.join([brackets_id,
str(brackets_curr_decay), str(brackets_curr_decay),
increased_id]) increased_id])
return params_id return params_id
...@@ -108,6 +108,8 @@ class Bracket(): ...@@ -108,6 +108,8 @@ class Bracket():
Parameters Parameters
---------- ----------
bracket_id: string
The id of this bracket, usually be set as '{Hyperband index}-{SH iteration index}'
s: int s: int
The current SH iteration index. The current SH iteration index.
s_max: int s_max: int
...@@ -122,8 +124,9 @@ class Bracket(): ...@@ -122,8 +124,9 @@ class Bracket():
optimize mode, 'maximize' or 'minimize' optimize mode, 'maximize' or 'minimize'
""" """
def __init__(self, s, s_max, eta, R, optimize_mode): def __init__(self, bracket_id, s, s_max, eta, R, optimize_mode):
self.bracket_id = s self.bracket_id = bracket_id
self.s = s
self.s_max = s_max self.s_max = s_max
self.eta = eta self.eta = eta
self.n = math.ceil((s_max + 1) * (eta ** s) / (s + 1) - _epsilon) self.n = math.ceil((s_max + 1) * (eta ** s) / (s + 1) - _epsilon)
...@@ -147,7 +150,7 @@ class Bracket(): ...@@ -147,7 +150,7 @@ class Bracket():
def increase_i(self): def increase_i(self):
"""i means the ith round. Increase i by 1""" """i means the ith round. Increase i by 1"""
self.i += 1 self.i += 1
if self.i > self.bracket_id: if self.i > self.s:
self.no_more_trial = True self.no_more_trial = True
def set_config_perf(self, i, parameter_id, seq, value): def set_config_perf(self, i, parameter_id, seq, value):
...@@ -256,13 +259,14 @@ class HyperbandClassArgsValidator(ClassArgsValidator): ...@@ -256,13 +259,14 @@ class HyperbandClassArgsValidator(ClassArgsValidator):
def validate_class_args(self, **kwargs): def validate_class_args(self, **kwargs):
Schema({ Schema({
'optimize_mode': self.choices('optimize_mode', 'maximize', 'minimize'), 'optimize_mode': self.choices('optimize_mode', 'maximize', 'minimize'),
Optional('exec_mode'): self.choices('exec_mode', 'serial', 'parallelism'),
Optional('R'): int, Optional('R'): int,
Optional('eta'): int Optional('eta'): int
}).validate(kwargs) }).validate(kwargs)
class Hyperband(MsgDispatcherBase): class Hyperband(MsgDispatcherBase):
"""Hyperband inherit from MsgDispatcherBase rather than Tuner, because it integrates both tuner's functions and assessor's functions. """Hyperband inherit from MsgDispatcherBase rather than Tuner, because it integrates both tuner's functions and assessor's functions.
This is an implementation that could fully leverage available resources, i.e., high parallelism. This is an implementation that could fully leverage available resources or follow the algorithm process, i.e., high parallelism or serial.
A single execution of Hyperband takes a finite budget of (s_max + 1)B. A single execution of Hyperband takes a finite budget of (s_max + 1)B.
Parameters Parameters
...@@ -273,9 +277,11 @@ class Hyperband(MsgDispatcherBase): ...@@ -273,9 +277,11 @@ class Hyperband(MsgDispatcherBase):
the variable that controls the proportion of configurations discarded in each round of SuccessiveHalving the variable that controls the proportion of configurations discarded in each round of SuccessiveHalving
optimize_mode: str optimize_mode: str
optimize mode, 'maximize' or 'minimize' optimize mode, 'maximize' or 'minimize'
exec_mode: str
execution mode, 'serial' or 'parallelism'
""" """
def __init__(self, R=60, eta=3, optimize_mode='maximize'): def __init__(self, R=60, eta=3, optimize_mode='maximize', exec_mode='parallelism'):
"""B = (s_max + 1)R""" """B = (s_max + 1)R"""
super(Hyperband, self).__init__() super(Hyperband, self).__init__()
self.R = R self.R = R
...@@ -285,6 +291,9 @@ class Hyperband(MsgDispatcherBase): ...@@ -285,6 +291,9 @@ class Hyperband(MsgDispatcherBase):
self.completed_hyper_configs = [] # all the completed configs self.completed_hyper_configs = [] # all the completed configs
self.s_max = math.floor(math.log(self.R, self.eta) + _epsilon) self.s_max = math.floor(math.log(self.R, self.eta) + _epsilon)
self.curr_s = self.s_max self.curr_s = self.s_max
self.curr_hb = 0
self.exec_mode = exec_mode
self.curr_bracket_id = None
self.searchspace_json = None self.searchspace_json = None
self.random_state = None self.random_state = None
...@@ -316,25 +325,44 @@ class Hyperband(MsgDispatcherBase): ...@@ -316,25 +325,44 @@ class Hyperband(MsgDispatcherBase):
data: int data: int
number of trial jobs number of trial jobs
""" """
for _ in range(data): self.credit += data
ret = self._get_one_trial_job()
for _ in range(self.credit):
self._request_one_trial_job()
def _request_one_trial_job(self):
ret = self._get_one_trial_job()
if ret is not None:
send(CommandType.NewTrialJob, json_tricks.dumps(ret)) send(CommandType.NewTrialJob, json_tricks.dumps(ret))
self.credit -= 1
def _get_one_trial_job(self): def _get_one_trial_job(self):
"""get one trial job, i.e., one hyperparameter configuration.""" """get one trial job, i.e., one hyperparameter configuration."""
if not self.generated_hyper_configs: if not self.generated_hyper_configs:
if self.curr_s < 0: if self.exec_mode == 'parallelism' or \
self.curr_s = self.s_max (self.exec_mode == 'serial' and (self.curr_bracket_id is None or self.brackets[self.curr_bracket_id].is_completed())):
_logger.debug('create a new bracket, self.curr_s=%d', self.curr_s) if self.curr_s < 0:
self.brackets[self.curr_s] = Bracket(self.curr_s, self.s_max, self.eta, self.R, self.optimize_mode) self.curr_s = self.s_max
next_n, next_r = self.brackets[self.curr_s].get_n_r() self.curr_hb += 1
_logger.debug('new bracket, next_n=%d, next_r=%d', next_n, next_r) _logger.debug('create a new bracket, self.curr_hb=%d, self.curr_s=%d', self.curr_hb, self.curr_s)
assert self.searchspace_json is not None and self.random_state is not None self.curr_bracket_id = '{}-{}'.format(self.curr_hb, self.curr_s)
generated_hyper_configs = self.brackets[self.curr_s].get_hyperparameter_configurations(next_n, next_r, self.brackets[self.curr_bracket_id] = Bracket(self.curr_bracket_id, self.curr_s, self.s_max, self.eta, self.R, self.optimize_mode)
self.searchspace_json, next_n, next_r = self.brackets[self.curr_bracket_id].get_n_r()
self.random_state) _logger.debug('new bracket, next_n=%d, next_r=%d', next_n, next_r)
self.generated_hyper_configs = generated_hyper_configs.copy() assert self.searchspace_json is not None and self.random_state is not None
self.curr_s -= 1 generated_hyper_configs = self.brackets[self.curr_bracket_id].get_hyperparameter_configurations(next_n, next_r,
self.searchspace_json,
self.random_state)
self.generated_hyper_configs = generated_hyper_configs.copy()
self.curr_s -= 1
else:
ret = {
'parameter_id': '-1_0_0',
'parameter_source': 'algorithm',
'parameters': ''
}
send(CommandType.NoMoreTrialJobs, json_tricks.dumps(ret))
return None
assert self.generated_hyper_configs assert self.generated_hyper_configs
params = self.generated_hyper_configs.pop(0) params = self.generated_hyper_configs.pop(0)
...@@ -358,10 +386,12 @@ class Hyperband(MsgDispatcherBase): ...@@ -358,10 +386,12 @@ class Hyperband(MsgDispatcherBase):
parameter_id: parameter id of the finished config parameter_id: parameter id of the finished config
""" """
bracket_id, i, _ = parameter_id.split('_') bracket_id, i, _ = parameter_id.split('_')
hyper_configs = self.brackets[int(bracket_id)].inform_trial_end(int(i)) hyper_configs = self.brackets[bracket_id].inform_trial_end(int(i))
if hyper_configs is not None: if hyper_configs is not None:
_logger.debug('bracket %s next round %s, hyper_configs: %s', bracket_id, i, hyper_configs) _logger.debug('bracket %s next round %s, hyper_configs: %s', bracket_id, i, hyper_configs)
self.generated_hyper_configs = self.generated_hyper_configs + hyper_configs self.generated_hyper_configs = self.generated_hyper_configs + hyper_configs
for _ in range(self.credit):
self._request_one_trial_job()
def handle_trial_end(self, data): def handle_trial_end(self, data):
""" """
...@@ -392,6 +422,7 @@ class Hyperband(MsgDispatcherBase): ...@@ -392,6 +422,7 @@ class Hyperband(MsgDispatcherBase):
""" """
if 'value' in data: if 'value' in data:
data['value'] = json_tricks.loads(data['value']) data['value'] = json_tricks.loads(data['value'])
# multiphase? need to check
if data['type'] == MetricType.REQUEST_PARAMETER: if data['type'] == MetricType.REQUEST_PARAMETER:
assert multi_phase_enabled() assert multi_phase_enabled()
assert data['trial_job_id'] is not None assert data['trial_job_id'] is not None
...@@ -408,7 +439,6 @@ class Hyperband(MsgDispatcherBase): ...@@ -408,7 +439,6 @@ class Hyperband(MsgDispatcherBase):
else: else:
value = extract_scalar_reward(data['value']) value = extract_scalar_reward(data['value'])
bracket_id, i, _ = data['parameter_id'].split('_') bracket_id, i, _ = data['parameter_id'].split('_')
bracket_id = int(bracket_id)
# add <trial_job_id, parameter_id> to self.job_id_para_id_map here, # add <trial_job_id, parameter_id> to self.job_id_para_id_map here,
# because when the first parameter_id is created, trial_job_id is not known yet. # because when the first parameter_id is created, trial_job_id is not known yet.
......
...@@ -58,7 +58,7 @@ def accuracy(output, target, topk=(1,)): ...@@ -58,7 +58,7 @@ def accuracy(output, target, topk=(1,)):
res = [] res = []
for k in topk: for k in topk:
correct_k = correct[:k].view(-1).float().sum(0) correct_k = correct[:k].reshape(-1).float().sum(0)
res.append(correct_k.mul_(1.0 / batch_size)) res.append(correct_k.mul_(1.0 / batch_size))
return res return res
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .trainer import CreamSupernetTrainer
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import os
import torch
import logging
from copy import deepcopy
from nni.nas.pytorch.trainer import Trainer
from nni.nas.pytorch.utils import AverageMeterGroup
from .utils import accuracy, reduce_metrics
logger = logging.getLogger(__name__)
class CreamSupernetTrainer(Trainer):
    """
    This trainer trains a supernet and outputs prioritized architectures that can be used for other tasks.

    Parameters
    ----------
    model : nn.Module
        Model with mutables.
    loss : callable
        Called with logits and targets. Returns a loss tensor.
    val_loss : callable
        Called with logits and targets for validation only. Returns a loss tensor.
    optimizer : Optimizer
        Optimizer that optimizes the model.
    num_epochs : int
        Number of epochs of training.
    train_loader : iterable
        Data loader of training. Raise ``StopIteration`` when one epoch is exhausted.
    valid_loader : iterable
        Data loader of validation. Raise ``StopIteration`` when one epoch is exhausted.
    mutator : Mutator
        A mutator object that has been initialized with the model.
    batch_size : int
        Batch size.
    log_frequency : int
        Number of mini-batches to log metrics.
    meta_sta_epoch : int
        start epoch of using meta matching network to pick teacher architecture
    update_iter : int
        interval of updating meta matching networks
    slices : int
        batch size of mini training data in the process of training meta matching network
    pool_size : int
        board size
    pick_method : basestring
        how to pick teacher network
    choice_num : int
        number of operations in supernet
    sta_num : int
        layer number of each stage in supernet (5 stage in supernet)
    acc_gap : int
        maximum accuracy improvement to omit the limitation of flops
    flops_dict : Dict
        dictionary of each layer's operations in supernet
    flops_fixed : int
        flops of fixed part in supernet
    local_rank : int
        index of current rank
    callbacks : list of Callback
        Callbacks to plug into the trainer. See Callbacks.
    """

    def __init__(self, model, loss, val_loss,
                 optimizer, num_epochs, train_loader, valid_loader,
                 mutator=None, batch_size=64, log_frequency=None,
                 meta_sta_epoch=20, update_iter=200, slices=2,
                 pool_size=10, pick_method='meta', choice_num=6,
                 sta_num=(4, 4, 4, 4, 4), acc_gap=5,
                 flops_dict=None, flops_fixed=0, local_rank=0, callbacks=None):
        # CUDA is mandatory: both training and validation move batches to GPU unconditionally.
        assert torch.cuda.is_available()
        super(CreamSupernetTrainer, self).__init__(model, mutator, loss, None,
                                                   optimizer, num_epochs, None, None,
                                                   batch_size, None, None, log_frequency, callbacks)
        self.model = model
        self.loss = loss
        self.val_loss = val_loss
        self.train_loader = train_loader
        self.valid_loader = valid_loader
        self.log_frequency = log_frequency
        self.batch_size = batch_size
        self.optimizer = optimizer
        self.num_epochs = num_epochs
        self.meta_sta_epoch = meta_sta_epoch
        self.update_iter = update_iter
        self.slices = slices
        self.pick_method = pick_method
        self.pool_size = pool_size
        self.local_rank = local_rank
        self.choice_num = choice_num
        self.sta_num = sta_num
        self.acc_gap = acc_gap
        self.flops_dict = flops_dict
        self.flops_fixed = flops_fixed

        self.current_student_arch = None
        self.current_teacher_arch = None
        self.main_proc = (local_rank == 0)
        self.current_epoch = 0
        # Each entry is (val_prec1, prec1, flops, arch, training_data, softmax_features),
        # kept sorted descending so the last element is the weakest candidate.
        self.prioritized_board = []

    def _board_size(self):
        """Return the current number of architectures on the prioritized board."""
        return len(self.prioritized_board)

    def _select_teacher(self):
        """Select a teacher architecture and its weight according to ``pick_method``.

        Returns
        -------
        tuple
            (meta_value, teacher_cand) where meta_value weights the distillation loss.
        """
        self._replace_mutator_cand(self.current_student_arch)

        if self.pick_method == 'top1':
            # Highest validation precision on the board; weight fixed at 0.5.
            meta_value, teacher_cand = 0.5, sorted(
                self.prioritized_board, reverse=True)[0][3]
        elif self.pick_method == 'meta':
            # Pick the board entry whose (logit difference) the meta network scores highest.
            meta_value, cand_idx, teacher_cand = -1000000000, -1, None
            for now_idx, item in enumerate(self.prioritized_board):
                inputx = item[4]
                output = torch.nn.functional.softmax(self.model(inputx), dim=1)
                weight = self.model.module.forward_meta(output - item[5])
                if weight > meta_value:
                    meta_value = weight
                    cand_idx = now_idx
                    teacher_cand = self.prioritized_board[cand_idx][3]
            assert teacher_cand is not None
            # NOTE(review): this uses `weight` from the *last* loop iteration, not the best
            # one — preserved as-is to match the reference implementation; confirm intent.
            meta_value = torch.sigmoid(-weight)
        else:
            raise ValueError('Method Not supported')

        return meta_value, teacher_cand

    def _isUpdateBoard(self, prec1, flops):
        """Decide whether an architecture with accuracy ``prec1`` and cost ``flops`` enters the board."""
        if self.current_epoch <= self.meta_sta_epoch:
            return False
        # Board not full yet: accept unconditionally.
        if len(self.prioritized_board) < self.pool_size:
            return True
        # Clearly better than the weakest entry (by more than acc_gap): accept regardless of flops.
        if prec1 > self.prioritized_board[-1][1] + self.acc_gap:
            return True
        # Slightly better AND cheaper than the weakest entry: accept.
        if prec1 > self.prioritized_board[-1][1] and flops < self.prioritized_board[-1][2]:
            return True
        return False

    def _update_prioritized_board(self, inputs, teacher_output, outputs, prec1, flops):
        """Insert the current architecture into the board (if eligible) and trim to ``pool_size``."""
        if self._isUpdateBoard(prec1, flops):
            val_prec1 = prec1
            training_data = deepcopy(inputs[:self.slices].detach())
            if len(self.prioritized_board) == 0:
                features = deepcopy(outputs[:self.slices].detach())
            else:
                features = deepcopy(
                    teacher_output[:self.slices].detach())
            self.prioritized_board.append(
                (val_prec1,
                 prec1,
                 flops,
                 self.current_teacher_arch,
                 training_data,
                 torch.nn.functional.softmax(
                     features,
                     dim=1)))
            # One sort is sufficient; drop the weakest entry when over capacity.
            self.prioritized_board = sorted(
                self.prioritized_board, reverse=True)
            if len(self.prioritized_board) > self.pool_size:
                del self.prioritized_board[-1]

    def _update_student_weights_only(self, grad_1):
        """Apply ``grad_1`` to the student sub-network parameters only, with gradient clipping."""
        for weight, grad_item in zip(
                self.model.module.rand_parameters(self.current_student_arch), grad_1):
            weight.grad = grad_item
        torch.nn.utils.clip_grad_norm_(
            self.model.module.rand_parameters(self.current_student_arch), 1)
        self.optimizer.step()
        # Drop the manually assigned grads so they do not leak into the next step.
        for weight, grad_item in zip(
                self.model.module.rand_parameters(self.current_student_arch), grad_1):
            del weight.grad

    def _update_meta_weights_only(self, teacher_cand, grad_teacher):
        """Apply ``grad_teacher`` to the meta matching network parameters only."""
        for weight, grad_item in zip(self.model.module.rand_parameters(
                teacher_cand, self.pick_method == 'meta'), grad_teacher):
            weight.grad = grad_item

        # NOTE(review): clipping uses current_student_arch while the step applies
        # teacher grads — preserved from the reference implementation; confirm intent.
        torch.nn.utils.clip_grad_norm_(
            self.model.module.rand_parameters(
                self.current_student_arch, self.pick_method == 'meta'), 1)

        self.optimizer.step()
        for weight, grad_item in zip(self.model.module.rand_parameters(
                teacher_cand, self.pick_method == 'meta'), grad_teacher):
            del weight.grad

    def _simulate_sgd_update(self, w, g, optimizer):
        """Return the weight that one plain-SGD step (current lr) would produce.

        NOTE(review): sign is ``w + lr * g`` (ascent form) in the reference code — preserved.
        """
        return g * optimizer.param_groups[-1]['lr'] + w

    def _get_minibatch_input(self, input):
        """Slice the first ``self.slices`` samples out of ``input`` as a detached mini-batch."""
        slice = self.slices
        x = deepcopy(input[:slice].clone().detach())
        return x

    def _calculate_1st_gradient(self, kd_loss):
        """First-order gradient of the distillation loss w.r.t. student parameters."""
        self.optimizer.zero_grad()
        grad = torch.autograd.grad(
            kd_loss,
            self.model.module.rand_parameters(self.current_student_arch),
            create_graph=True)
        return grad

    def _calculate_2nd_gradient(self, validation_loss, teacher_cand, students_weight):
        """Second-order gradient of the validation loss w.r.t. the meta network parameters."""
        self.optimizer.zero_grad()
        # bug fix: the reference code read undefined ``self.random_cand`` here;
        # the sampled candidate is stored as ``self.current_student_arch``.
        grad_student_val = torch.autograd.grad(
            validation_loss,
            self.model.module.rand_parameters(self.current_student_arch),
            retain_graph=True)

        grad_teacher = torch.autograd.grad(
            students_weight[0],
            self.model.module.rand_parameters(
                teacher_cand,
                self.pick_method == 'meta'),
            grad_outputs=grad_student_val)
        return grad_teacher

    def _forward_training(self, x, meta_value):
        """Compute the soft-target distillation loss of the student against the teacher."""
        self._replace_mutator_cand(self.current_student_arch)
        output = self.model(x)

        with torch.no_grad():
            self._replace_mutator_cand(self.current_teacher_arch)
            teacher_output = self.model(x)
            soft_label = torch.nn.functional.softmax(teacher_output, dim=1)

        kd_loss = meta_value * \
            self._cross_entropy_loss_with_soft_target(output, soft_label)
        return kd_loss

    def _cross_entropy_loss_with_soft_target(self, pred, soft_target):
        """Cross entropy between ``pred`` logits and a soft (probability) target."""
        # Explicit dim=1 (class axis); the dim-less form is deprecated in PyTorch.
        logsoftmax = torch.nn.LogSoftmax(dim=1)
        return torch.mean(torch.sum(- soft_target * logsoftmax(pred), 1))

    def _forward_validation(self, input, target):
        """Validation loss of the student on the second ``slices``-sized chunk of the batch."""
        slice = self.slices
        x = input[slice:slice * 2].clone()

        self._replace_mutator_cand(self.current_student_arch)
        output_2 = self.model(x)

        validation_loss = self.loss(output_2, target[slice:slice * 2])
        return validation_loss

    def _isUpdateMeta(self, batch_idx):
        """Whether the meta matching network should be updated at this batch."""
        isUpdate = True
        isUpdate &= (self.current_epoch > self.meta_sta_epoch)
        isUpdate &= (batch_idx > 0)
        isUpdate &= (batch_idx % self.update_iter == 0)
        isUpdate &= (self._board_size() > 0)
        return isUpdate

    def _replace_mutator_cand(self, cand):
        """Force the mutator to use architecture ``cand`` for subsequent forwards."""
        self.mutator._cache = cand

    def _run_update(self, input, target, batch_idx):
        """Update the meta matching network via the simulated second-order procedure."""
        if self._isUpdateMeta(batch_idx):
            x = self._get_minibatch_input(input)

            meta_value, teacher_cand = self._select_teacher()

            kd_loss = self._forward_training(x, meta_value)

            # calculate 1st gradient
            grad_1st = self._calculate_1st_gradient(kd_loss)

            # simulate updated student weights
            students_weight = [
                self._simulate_sgd_update(
                    p, grad_item, self.optimizer) for p, grad_item in zip(
                    self.model.module.rand_parameters(self.current_student_arch), grad_1st)]

            # update student weights
            self._update_student_weights_only(grad_1st)

            validation_loss = self._forward_validation(input, target)

            # calculate 2nd gradient
            grad_teacher = self._calculate_2nd_gradient(validation_loss, teacher_cand, students_weight)

            # update meta matching networks
            self._update_meta_weights_only(teacher_cand, grad_teacher)

            # delete internal variants
            del grad_teacher, grad_1st, x, validation_loss, kd_loss, students_weight

    def _get_cand_flops(self, cand):
        """Sum the flops of the chosen operations in ``cand`` plus the fixed part."""
        flops = 0
        for block_id, block in enumerate(cand):
            # bug fix: the reference code compared the integer ``block_id`` against
            # the string 'LayerChoice23' (always False); both checks are on the key.
            if block in ('LayerChoice1', 'LayerChoice23'):
                continue
            for idx, choice in enumerate(cand[block]):
                flops += self.flops_dict[block_id][idx] * (1 if choice else 0)
        return flops + self.flops_fixed

    def train_one_epoch(self, epoch):
        """Train the supernet for one epoch, distilling from board teachers after warm-up."""
        self.current_epoch = epoch
        meters = AverageMeterGroup()
        self.steps_per_epoch = len(self.train_loader)
        for step, (input_data, target) in enumerate(self.train_loader):
            self.mutator.reset()
            self.current_student_arch = self.mutator._cache

            input_data, target = input_data.cuda(), target.cuda()

            # calculate flops of current architecture
            cand_flops = self._get_cand_flops(self.mutator._cache)

            # update meta matching network
            self._run_update(input_data, target, step)

            if self._board_size() > 0:
                # select teacher architecture
                meta_value, teacher_cand = self._select_teacher()
                self.current_teacher_arch = teacher_cand

            # forward supernet
            if self._board_size() == 0 or epoch <= self.meta_sta_epoch:
                # Warm-up phase: plain supervised loss, no distillation.
                self._replace_mutator_cand(self.current_student_arch)
                output = self.model(input_data)

                loss = self.loss(output, target)
                kd_loss, teacher_output, teacher_cand = None, None, None
            else:
                self._replace_mutator_cand(self.current_student_arch)
                output = self.model(input_data)

                gt_loss = self.loss(output, target)

                with torch.no_grad():
                    self._replace_mutator_cand(self.current_teacher_arch)
                    teacher_output = self.model(input_data).detach()
                    soft_label = torch.nn.functional.softmax(teacher_output, dim=1)
                kd_loss = self._cross_entropy_loss_with_soft_target(output, soft_label)

                # Blend distillation and ground-truth losses weighted by meta_value.
                loss = (meta_value * kd_loss + (2 - meta_value) * gt_loss) / 2

            # update network
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # update metrics
            prec1, prec5 = accuracy(output, target, topk=(1, 5))
            metrics = {"prec1": prec1, "prec5": prec5, "loss": loss}
            metrics = reduce_metrics(metrics)
            meters.update(metrics)

            # update prioritized board
            self._update_prioritized_board(input_data, teacher_output, output, metrics['prec1'], cand_flops)

            if self.main_proc and (step % self.log_frequency == 0 or step + 1 == self.steps_per_epoch):
                logger.info("Epoch [%d/%d] Step [%d/%d] %s", epoch + 1, self.num_epochs,
                            step + 1, len(self.train_loader), meters)

        if self.main_proc and self.num_epochs == epoch + 1:
            # bug fix: the reference code read undefined ``self.best_children_pool``;
            # the board of prioritized architectures is ``self.prioritized_board``.
            for idx, i in enumerate(self.prioritized_board):
                logger.info("No.%s %s", idx, i[:4])

    def validate_one_epoch(self, epoch):
        """Evaluate randomly sampled architectures on the validation loader for one epoch."""
        self.model.eval()
        meters = AverageMeterGroup()
        with torch.no_grad():
            for step, (x, y) in enumerate(self.valid_loader):
                self.mutator.reset()
                # Keep device handling consistent with train_one_epoch.
                x, y = x.cuda(), y.cuda()
                logits = self.model(x)
                loss = self.val_loss(logits, y)
                # bug fix: accuracy/reduce_metrics are module-level helpers imported from
                # .utils, not methods; ``self.accuracy``/``self.reduce_metrics`` and
                # ``self.distributed`` do not exist on this class.
                prec1, prec5 = accuracy(logits, y, topk=(1, 5))
                metrics = {"prec1": prec1, "prec5": prec5, "loss": loss}
                metrics = reduce_metrics(metrics)
                meters.update(metrics)

                if self.log_frequency is not None and step % self.log_frequency == 0:
                    logger.info("Epoch [%s/%s] Validation Step [%s/%s] %s", epoch + 1,
                                self.num_epochs, step + 1, len(self.valid_loader), meters)
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import os
import torch.distributed as dist
def accuracy(output, target, topk=(1,)):
    """Compute the precision@k for the specified values of k.

    Parameters
    ----------
    output : torch.Tensor
        Logits of shape (batch, num_classes).
    target : torch.Tensor
        Ground-truth labels, either class indices of shape (batch,) or one-hot
        of shape (batch, num_classes).
    topk : tuple of int
        The k values for which precision is computed.

    Returns
    -------
    list of torch.Tensor
        One scalar tensor (fraction in [0, 1]) per entry of ``topk``.
    """
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    # one-hot case: reduce to class indices
    if target.ndimension() > 1:
        target = target.max(1)[1]

    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # bug fix: ``correct[:k]`` is a slice of a transposed (non-contiguous) tensor,
        # so ``.view(-1)`` raises in modern PyTorch; ``.reshape(-1)`` handles both cases.
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(1.0 / batch_size))
    return res
def reduce_metrics(metrics):
    """Average every metric tensor across distributed workers, returning plain floats."""
    reduced = {}
    for name, metric in metrics.items():
        reduced[name] = reduce_tensor(metric).item()
    return reduced
def reduce_tensor(tensor):
    """All-reduce ``tensor`` (sum) across workers and divide by the world size.

    The input tensor is left untouched; the averaged copy is returned.
    """
    world_size = float(os.environ["WORLD_SIZE"])
    summed = tensor.clone()
    dist.all_reduce(summed, op=dist.ReduceOp.SUM)
    summed /= world_size
    return summed
...@@ -210,5 +210,5 @@ class DartsTrainer(Trainer): ...@@ -210,5 +210,5 @@ class DartsTrainer(Trainer):
dalphas.append(torch.autograd.grad(loss, self.mutator.parameters())) dalphas.append(torch.autograd.grad(loss, self.mutator.parameters()))
dalpha_pos, dalpha_neg = dalphas # dalpha { L_trn(w+) }, # dalpha { L_trn(w-) } dalpha_pos, dalpha_neg = dalphas # dalpha { L_trn(w+) }, # dalpha { L_trn(w-) }
hessian = [(p - n) / 2. * eps for p, n in zip(dalpha_pos, dalpha_neg)] hessian = [(p - n) / (2. * eps) for p, n in zip(dalpha_pos, dalpha_neg)]
return hessian return hessian
...@@ -15,6 +15,7 @@ LIST_CONSTRUCT_KIND = 'prim::ListConstruct' ...@@ -15,6 +15,7 @@ LIST_CONSTRUCT_KIND = 'prim::ListConstruct'
LIST_UNPACK_KIND = 'prim::ListUnpack' LIST_UNPACK_KIND = 'prim::ListUnpack'
TUPLE_CONSTRUCT_KIND = 'prim::TupleConstruct' TUPLE_CONSTRUCT_KIND = 'prim::TupleConstruct'
TUPLE_UNPACK_KIND = 'prim::TupleUnpack' TUPLE_UNPACK_KIND = 'prim::TupleUnpack'
CONSTANT_KIND = 'prim::Constant'
_logger = logging.getLogger(__name__) _logger = logging.getLogger(__name__)
...@@ -68,9 +69,11 @@ class TorchGraph: ...@@ -68,9 +69,11 @@ class TorchGraph:
'Please provide model & dummy_input or the traced_model as inputs') 'Please provide model & dummy_input or the traced_model as inputs')
def _trace(self, model, dummy_input): def _trace(self, model, dummy_input):
with torch.onnx.set_training(model, False): training = model.training
self.trace = torch.jit.trace(model, dummy_input) model.eval()
torch._C._jit_pass_inline(self.trace.graph) self.trace = torch.jit.trace(model, dummy_input)
torch._C._jit_pass_inline(self.trace.graph)
model.train(training)
class TorchProtoGraph(TorchGraph): class TorchProtoGraph(TorchGraph):
...@@ -282,27 +285,35 @@ class TorchModuleGraph(TorchGraph): ...@@ -282,27 +285,35 @@ class TorchModuleGraph(TorchGraph):
self.global_count += 1 self.global_count += 1
op_type = node.kind() op_type = node.kind()
node_group = [node] node_group = [node]
inputs = list() inputs = set()
outputs = list() outputs = set()
node_queue = queue.Queue() node_queue = queue.Queue()
node_queue.put(node) node_queue.put(node)
while not node_queue.empty(): while not node_queue.empty():
curr_node = node_queue.get() curr_node = node_queue.get()
for _input in curr_node.inputs(): for _input in curr_node.inputs():
if _input.node().kind() == CONSTANT_KIND:
continue
input_name = _input.debugName() input_name = _input.debugName()
if input_name in output_to_node and output_to_node[input_name] in nodes: if input_name in output_to_node:
predecessor_node = output_to_node[input_name] for predecessor_node in output_to_node[input_name]:
if not self._is_key_func(predecessor_node): if predecessor_node in nodes:
node_group.append(predecessor_node) if not self._is_key_func(predecessor_node):
node_queue.put(predecessor_node) if predecessor_node not in node_group:
else: node_group.append(predecessor_node)
inputs.append(input_name) node_queue.put(predecessor_node)
else:
inputs.add(input_name)
else:
inputs.add(input_name)
else: else:
inputs.append(input_name) inputs.add(input_name)
for output in node.outputs(): for output in node.outputs():
outputs.append(output.debugName()) if output.node().kind() == CONSTANT_KIND:
continue
outputs.add(output.debugName())
nodepy = NodePyGroup(node_name, unique_name, module_type, op_type, nodepy = NodePyGroup(node_name, unique_name, module_type, op_type,
node_group, inputs=inputs, outputs=outputs, key_node=node) node_group, inputs=list(inputs), outputs=list(outputs), key_node=node)
return nodepy return nodepy
def _expand_module_node(self, node, node_name, unique_name, op_type, nodes, def _expand_module_node(self, node, node_name, unique_name, op_type, nodes,
...@@ -342,36 +353,46 @@ class TorchModuleGraph(TorchGraph): ...@@ -342,36 +353,46 @@ class TorchModuleGraph(TorchGraph):
if not op_type: if not op_type:
op_type = node.kind() op_type = node.kind()
node_group = [node] node_group = [node]
inputs = list() inputs = set()
outputs = list() outputs = set()
node_queue = queue.Queue() node_queue = queue.Queue()
node_queue.put(node) node_queue.put(node)
visited = {node} visited = {node}
while not node_queue.empty(): while not node_queue.empty():
curr_node = node_queue.get() curr_node = node_queue.get()
for _input in curr_node.inputs(): for _input in curr_node.inputs():
if _input.node().kind() == CONSTANT_KIND:
continue
input_name = _input.debugName() input_name = _input.debugName()
if input_name in output_to_node and output_to_node[input_name] in nodes: if input_name in output_to_node:
predecessor_node = output_to_node[input_name] for predecessor_node in output_to_node[input_name]:
if predecessor_node not in visited: if predecessor_node in nodes:
node_group.append(predecessor_node) if predecessor_node not in visited:
node_queue.put(predecessor_node) node_group.append(predecessor_node)
visited.add(predecessor_node) node_queue.put(predecessor_node)
visited.add(predecessor_node)
else:
inputs.add(input_name)
else: else:
inputs.append(input_name) inputs.add(input_name)
for _output in curr_node.outputs(): for _output in curr_node.outputs():
if _output.node().kind() == CONSTANT_KIND:
continue
output_name = _output.debugName() output_name = _output.debugName()
if output_name in input_to_node and input_to_node[output_name] in nodes: if output_name in input_to_node:
successor_node = input_to_node[output_name] for successor_node in input_to_node[output_name]:
if successor_node not in visited: if successor_node in nodes:
node_group.append(successor_node) if successor_node not in visited:
node_queue.put(successor_node) node_group.append(successor_node)
visited.add(successor_node) node_queue.put(successor_node)
visited.add(successor_node)
else:
outputs.add(output_name)
else: else:
outputs.append(output_name) outputs.add(output_name)
nodepy = NodePyGroup(node_name, unique_name, module_type, op_type, nodepy = NodePyGroup(node_name, unique_name, module_type, op_type,
node_group, inputs=inputs, outputs=outputs) node_group, inputs=list(inputs), outputs=list(outputs))
return nodepy return nodepy
def _extract_cat_info(self, node_group, cpp_node): def _extract_cat_info(self, node_group, cpp_node):
...@@ -544,7 +565,7 @@ class TorchModuleGraph(TorchGraph): ...@@ -544,7 +565,7 @@ class TorchModuleGraph(TorchGraph):
input_to_node[_input].append(node) input_to_node[_input].append(node)
for output in node.outputs: for output in node.outputs:
assert not output in output_to_node, \ assert not output in output_to_node, \
"One output cannot be generated by multiple nodes" "One output cannot be generated by multiple nodes %s" % output
output_to_node[output] = node output_to_node[output] = node
return name_to_node, input_to_node, output_to_node return name_to_node, input_to_node, output_to_node
...@@ -642,12 +663,22 @@ class TorchModuleGraph(TorchGraph): ...@@ -642,12 +663,22 @@ class TorchModuleGraph(TorchGraph):
omit_useless_nodes = True omit_useless_nodes = True
graph = self.trace.graph graph = self.trace.graph
_logger.debug(graph) _logger.debug(graph)
# build output mapping, from output debugName to its node # build input/output mapping, from input/output debugName to its node
output_to_node = {x.debugName(): n for n in graph.nodes() input_to_node = defaultdict(list)
for x in n.outputs()} output_to_node = defaultdict(list)
# build input mapping, from input debugName to its node for node in graph.nodes():
input_to_node = {x.debugName(): n for n in graph.nodes() if node.kind() == CONSTANT_KIND:
for x in n.inputs()} continue
for x in node.outputs():
if x.node().kind() == CONSTANT_KIND:
continue
output_to_node[x.debugName()].append(node)
assert len(output_to_node[x.debugName()]) <= 1, "One output cannot be generated by multiple nodes %s" % x.debugName()
for x in node.inputs():
if x.node().kind() == CONSTANT_KIND:
continue
input_to_node[x.debugName()].append(node)
# build module mapping, from module name to all nodes (as list) under this module scope # build module mapping, from module name to all nodes (as list) under this module scope
module_to_nodes = defaultdict(list) module_to_nodes = defaultdict(list)
# the mapping of function (non-module in forward) to nodes, key is scope name # the mapping of function (non-module in forward) to nodes, key is scope name
...@@ -668,6 +699,8 @@ class TorchModuleGraph(TorchGraph): ...@@ -668,6 +699,8 @@ class TorchModuleGraph(TorchGraph):
# associate module name with their trace graph nodes # associate module name with their trace graph nodes
for node in graph.nodes(): for node in graph.nodes():
if node.kind() == CONSTANT_KIND:
continue
module_name = self._get_module_name(node.scopeName()) module_name = self._get_module_name(node.scopeName())
if module_name in self.leaf_modules: if module_name in self.leaf_modules:
module_to_nodes[module_name].append(node) module_to_nodes[module_name].append(node)
......
# Copyright (c) Microsoft Corporation. # Copyright (c) Microsoft Corporation.
# Licensed under the MIT license. # Licensed under the MIT license.
import functools
from collections import Counter
from prettytable import PrettyTable
import torch import torch
import torch.nn as nn import torch.nn as nn
from nni.compression.pytorch.compressor import PrunerModuleWrapper from nni.compression.pytorch.compressor import PrunerModuleWrapper
try:
from thop import profile
except Exception as e:
print('thop is not found, please install the python package: thop')
raise
__all__ = ['count_flops_params']
def count_flops_params(model: nn.Module, input_size, custom_ops=None, verbose=True):
"""
Count FLOPs and Params of the given model.
This function would identify the mask on the module
and take the pruned shape into consideration.
Note that, for sturctured pruning, we only identify
the remained filters according to its mask, which
not taking the pruned input channels into consideration,
so the calculated FLOPs will be larger than real number.
Parameters def _get_params(m):
--------- return sum([p.numel() for p in m.parameters()])
model : nn.Module
target model.
input_size: list, tuple
the input shape of data
custom_ops: dict
a mapping of (module: custom operation)
the custom operation will overwrite the default operation.
for reference, please see ``custom_mask_ops``.
Returns
-------
flops: float
total flops of the model
params:
total params of the model
"""
assert input_size is not None class ModelProfiler:
device = next(model.parameters()).device def __init__(self, custom_ops=None, mode='default'):
inputs = torch.randn(input_size).to(device) """
ModelProfiler is used to share state to hooks.
hook_module_list = [] Parameters
if custom_ops is None: ----------
custom_ops = {} custom_ops: dict
custom_mask_ops.update(custom_ops) a mapping of (module -> torch.nn.Module : custom operation)
prev_m = None the custom operation is a callback funtion to calculate
for m in model.modules(): the module flops, parameters and the weight shape, it will overwrite the default operation.
weight_mask = None for reference, please see ``self.ops``.
m_type = type(m) mode:
if m_type in custom_mask_ops: the mode of how to collect information. If the mode is set to `default`,
if isinstance(prev_m, PrunerModuleWrapper): only the information of convolution and linear will be collected.
weight_mask = prev_m.weight_mask If the mode is set to `full`, other operations will also be collected.
"""
m.register_buffer('weight_mask', weight_mask) self.ops = {
hook_module_list.append(m) nn.Conv1d: self._count_convNd,
prev_m = m nn.Conv2d: self._count_convNd,
nn.Conv3d: self._count_convNd,
nn.Linear: self._count_linear
}
self._count_bias = False
if mode == 'full':
self.ops.update({
nn.ConvTranspose1d: self._count_convNd,
nn.ConvTranspose2d: self._count_convNd,
nn.ConvTranspose3d: self._count_convNd,
nn.BatchNorm1d: self._count_bn,
nn.BatchNorm2d: self._count_bn,
nn.BatchNorm3d: self._count_bn,
nn.LeakyReLU: self._count_relu,
nn.AvgPool1d: self._count_avgpool,
nn.AvgPool2d: self._count_avgpool,
nn.AvgPool3d: self._count_avgpool,
nn.AdaptiveAvgPool1d: self._count_adap_avgpool,
nn.AdaptiveAvgPool2d: self._count_adap_avgpool,
nn.AdaptiveAvgPool3d: self._count_adap_avgpool,
nn.Upsample: self._count_upsample,
nn.UpsamplingBilinear2d: self._count_upsample,
nn.UpsamplingNearest2d: self._count_upsample
})
self._count_bias = True
flops, params = profile(model, inputs=(inputs, ), custom_ops=custom_mask_ops, verbose=verbose) if custom_ops is not None:
self.ops.update(custom_ops)
self.mode = mode
self.results = []
for m in hook_module_list: def _push_result(self, result):
m._buffers.pop("weight_mask") self.results.append(result)
# Remove registerd buffer on the model, and fixed following issue:
# https://github.com/Lyken17/pytorch-OpCounter/issues/96
for m in model.modules():
if 'total_ops' in m._buffers:
m._buffers.pop("total_ops")
if 'total_params' in m._buffers:
m._buffers.pop("total_params")
return flops, params def _get_result(self, m, flops):
# assume weight is called `weight`, otherwise it's not applicable
# if user customize the operation, the callback function should
# return the dict result, inluding calculated flops, params and weight_shape.
def count_convNd_mask(m, x, y): result = {
""" 'flops': flops,
The forward hook to count FLOPs and Parameters of convolution operation. 'params': _get_params(m),
Parameters 'weight_shape': tuple(m.weight.size()) if hasattr(m, 'weight') else 0,
---------- }
m : torch.nn.Module return result
convolution module to calculate the FLOPs and Parameters
x : torch.Tensor def _count_convNd(self, m, x, y):
input data cin = m.in_channels
y : torch.Tensor kernel_ops = m.weight.size()[2] * m.weight.size()[3]
output data output_size = torch.zeros(y.size()[2:]).numel()
""" cout = y.size()[1]
output_channel = y.size()[1]
output_size = torch.zeros(y.size()[2:]).numel() if hasattr(m, 'weight_mask'):
kernel_size = torch.zeros(m.weight.size()[2:]).numel() cout = m.weight_mask.sum() // (cin * kernel_ops)
total_ops = cout * output_size * kernel_ops * cin // m.groups # cout x oW x oH
if self._count_bias:
bias_flops = 1 if m.bias is not None else 0
total_ops += cout * output_size * bias_flops
return self._get_result(m, total_ops)
def _count_linear(self, m, x, y):
out_features = m.out_features
if hasattr(m, 'weight_mask'):
out_features = m.weight_mask.sum() // m.in_features
total_ops = out_features * m.in_features
if self._count_bias:
bias_flops = 1 if m.bias is not None else 0
total_ops += out_features * bias_flops
return self._get_result(m, total_ops)
def _count_bn(self, m, x, y):
total_ops = 2 * x[0].numel()
return self._get_result(m, total_ops)
def _count_relu(self, m, x, y):
total_ops = x[0].numel()
return self._get_result(m, total_ops)
bias_flops = 1 if m.bias is not None else 0 def _count_avgpool(self, m, x, y):
total_ops = y.numel()
return self._get_result(m, total_ops)
if m.weight_mask is not None: def _count_adap_avgpool(self, m, x, y):
output_channel = m.weight_mask.sum() // (m.in_channels * kernel_size) kernel = torch.Tensor([*(x[0].shape[2:])]) // torch.Tensor(list((m.output_size,))).squeeze()
total_add = int(torch.prod(kernel))
total_div = 1
kernel_ops = total_add + total_div
num_elements = y.numel()
total_ops = kernel_ops * num_elements
total_ops = output_channel * output_size * (m.in_channels // m.groups * kernel_size + bias_flops) return self._get_result(m, total_ops)
m.total_ops += torch.DoubleTensor([int(total_ops)]) def _count_upsample(self, m, x, y):
if m.mode == 'linear':
total_ops = y.nelement() * 5 # 2 muls + 3 add
elif m.mode == 'bilinear':
# https://en.wikipedia.org/wiki/Bilinear_interpolation
total_ops = y.nelement() * 11 # 6 muls + 5 adds
elif m.mode == 'bicubic':
# https://en.wikipedia.org/wiki/Bicubic_interpolation
# Product matrix [4x4] x [4x4] x [4x4]
ops_solve_A = 224 # 128 muls + 96 adds
ops_solve_p = 35 # 16 muls + 12 adds + 4 muls + 3 adds
total_ops = y.nelement() * (ops_solve_A + ops_solve_p)
elif m.mode == 'trilinear':
# https://en.wikipedia.org/wiki/Trilinear_interpolation
# can viewed as 2 bilinear + 1 linear
total_ops = y.nelement() * (13 * 2 + 5)
else:
total_ops = 0
return self._get_result(m, total_ops)
def count_linear_mask(m, x, y): def count_module(self, m, x, y, name):
# assume x is tuple of single tensor
result = self.ops[type(m)](m, x, y)
total_result = {
'name': name,
'input_size': tuple(x[0].size()),
'output_size': tuple(y.size()),
'module_type': type(m).__name__,
**result
}
self._push_result(total_result)
def sum_flops(self):
return sum([s['flops'] for s in self.results])
def sum_params(self):
return sum({s['name']: s['params'] for s in self.results}.values())
def format_results(self):
table = PrettyTable()
name_counter = Counter([s['name'] for s in self.results])
has_multi_use = any(map(lambda v: v > 1, name_counter.values()))
name_counter = Counter() # clear the counter to count from 0
headers = [
'Index',
'Name',
'Type',
'Weight Shape',
'FLOPs',
'#Params',
]
if has_multi_use:
headers.append('#Call')
table.field_names = headers
for i, result in enumerate(self.results):
row_values = [
i,
result['name'],
result['module_type'],
str(result['weight_shape']),
result['flops'],
result['params'],
]
name_counter[result['name']] += 1
if has_multi_use:
row_values.append(name_counter[result['name']])
table.add_row(row_values)
return table
def count_flops_params(model, x, custom_ops=None, verbose=True, mode='default'):
""" """
The forward hook to count FLOPs and Parameters of linear transformation. Count FLOPs and Params of the given model. This function would
identify the mask on the module and take the pruned shape into consideration.
Note that, for sturctured pruning, we only identify the remained filters
according to its mask, and do not take the pruned input channels into consideration,
so the calculated FLOPs will be larger than real number.
Parameters Parameters
---------- ---------
m : torch.nn.Module model : nn.Module
linear to calculate the FLOPs and Parameters Target model.
x : torch.Tensor x : tuple or tensor
input data The input shape of data (a tuple), a tensor or a tuple of tensor as input data.
y : torch.Tensor custom_ops : dict
output data A mapping of (module -> torch.nn.Module : custom operation)
the custom operation is a callback funtion to calculate
the module flops and parameters, it will overwrite the default operation.
for reference, please see ``ops`` in ``ModelProfiler``.
verbose : bool
If False, mute detail information about modules. Default is True.
mode : str
the mode of how to collect information. If the mode is set to ``default``,
only the information of convolution and linear will be collected.
If the mode is set to ``full``, other operations will also be collected.
Returns
-------
tuple of int, int and dict
Representing total FLOPs, total parameters, and a detailed list of results respectively.
The list of results are a list of dict, each of which contains (name, module_type, weight_shape,
flops, params, input_size, output_size) as its keys.
""" """
output_channel = y.numel()
bias_flops = 1 if m.bias is not None else 0 assert isinstance(x, tuple) or isinstance(x, torch.Tensor)
assert mode in ['default', 'full']
original_device = next(model.parameters()).device
training = model.training
if isinstance(x, tuple) and all(isinstance(t, int) for t in x):
x = (torch.zeros(x).to(original_device), )
elif torch.is_tensor(x):
x = (x.to(original_device), )
else:
x = (t.to(original_device) for t in x)
handler_collection = []
profiler = ModelProfiler(custom_ops, mode)
prev_m = None
for name, m in model.named_modules():
# dealing with weight mask here
if isinstance(prev_m, PrunerModuleWrapper):
# weight mask is set to weight mask of its parent (wrapper)
weight_mask = prev_m.weight_mask
m.weight_mask = weight_mask
prev_m = m
if type(m) in profiler.ops:
# if a leaf node
_handler = m.register_forward_hook(functools.partial(profiler.count_module, name=name))
handler_collection.append(_handler)
model.eval()
if m.weight_mask is not None: with torch.no_grad():
output_channel = m.weight_mask.sum() // m.in_features model(*x)
total_ops = output_channel * (m.in_features + bias_flops) # restore origin status
for name, m in model.named_modules():
if hasattr(m, 'weight_mask'):
delattr(m, 'weight_mask')
m.total_ops += torch.DoubleTensor([int(total_ops)]) model.train(training).to(original_device)
for handler in handler_collection:
handler.remove()
if verbose:
# get detail information
print(profiler.format_results())
print(f'FLOPs total: {profiler.sum_flops()}')
print(f'#Params total: {profiler.sum_params()}')
custom_mask_ops = { return profiler.sum_flops(), profiler.sum_params(), profiler.results
nn.Conv1d: count_convNd_mask, \ No newline at end of file
nn.Conv2d: count_convNd_mask,
nn.Conv3d: count_convNd_mask,
nn.Linear: count_linear_mask,
}
...@@ -36,9 +36,11 @@ def fix_mask_conflict(masks, model=None, dummy_input=None, traced=None): ...@@ -36,9 +36,11 @@ def fix_mask_conflict(masks, model=None, dummy_input=None, traced=None):
# this traced model. # this traced model.
if traced is None: if traced is None:
assert model is not None and dummy_input is not None assert model is not None and dummy_input is not None
with torch.onnx.set_training(model, False): training = model.training
# We need to trace the model in this way, else it will have problems model.eval()
traced = torch.jit.trace(model, dummy_input) # We need to trace the model in eval mode
traced = torch.jit.trace(model, dummy_input)
model.train(training)
fix_group_mask = GroupMaskConflict(masks, model, dummy_input, traced) fix_group_mask = GroupMaskConflict(masks, model, dummy_input, traced)
masks = fix_group_mask.fix_mask() masks = fix_group_mask.fix_mask()
......
...@@ -100,10 +100,7 @@ class TrialResult: ...@@ -100,10 +100,7 @@ class TrialResult:
self.value = None self.value = None
self.trialJobId = None self.trialJobId = None
for key in json_obj.keys(): for key in json_obj.keys():
if key == 'id': setattr(self, key, json_obj[key])
setattr(self, 'trialJobId', json_obj[key])
elif hasattr(self, key):
setattr(self, key, json_obj[key])
self.value = json.loads(self.value) self.value = json.loads(self.value)
def __repr__(self): def __repr__(self):
...@@ -220,10 +217,7 @@ class TrialJob: ...@@ -220,10 +217,7 @@ class TrialJob:
self.finalMetricData = None self.finalMetricData = None
self.stderrPath = None self.stderrPath = None
for key in json_obj.keys(): for key in json_obj.keys():
if key == 'id': setattr(self, key, json_obj[key])
setattr(self, 'trialJobId', json_obj[key])
elif hasattr(self, key):
setattr(self, key, json_obj[key])
if self.hyperParameters: if self.hyperParameters:
self.hyperParameters = [TrialHyperParameters(json.loads(e)) for e in self.hyperParameters] self.hyperParameters = [TrialHyperParameters(json.loads(e)) for e in self.hyperParameters]
if self.finalMetricData: if self.finalMetricData:
......
...@@ -39,7 +39,7 @@ def _sort_history(history): ...@@ -39,7 +39,7 @@ def _sort_history(history):
# Tuner global variables # Tuner global variables
_next_parameter_id = 0 _next_parameter_id = 0
_trial_params = {} _trial_params = {}
'''key: trial job ID; value: parameters''' '''key: parameter ID; value: parameters'''
_customized_parameter_ids = set() _customized_parameter_ids = set()
...@@ -114,7 +114,7 @@ class MsgDispatcher(MsgDispatcherBase): ...@@ -114,7 +114,7 @@ class MsgDispatcher(MsgDispatcherBase):
data: a list of dictionaries, each of which has at least two keys, 'parameter' and 'value' data: a list of dictionaries, each of which has at least two keys, 'parameter' and 'value'
""" """
for entry in data: for entry in data:
entry['value'] = entry['value'] if type(entry['value']) is str else json_tricks.dumps(entry['value']) entry['value'] = entry['value'] if type(entry['value']) is str else json_tricks.dumps(entry['value'])
entry['value'] = json_tricks.loads(entry['value']) entry['value'] = json_tricks.loads(entry['value'])
self.tuner.import_data(data) self.tuner.import_data(data)
...@@ -182,8 +182,11 @@ class MsgDispatcher(MsgDispatcherBase): ...@@ -182,8 +182,11 @@ class MsgDispatcher(MsgDispatcherBase):
customized = True customized = True
else: else:
customized = False customized = False
if id_ in _trial_params:
self.tuner.receive_trial_result(id_, _trial_params[id_], value, customized=customized, self.tuner.receive_trial_result(id_, _trial_params[id_], value, customized=customized,
trial_job_id=data.get('trial_job_id')) trial_job_id=data.get('trial_job_id'))
else:
_logger.warning('Find unknown job parameter id %s, maybe something goes wrong.', _trial_params[id_])
def _handle_intermediate_metric_data(self, data): def _handle_intermediate_metric_data(self, data):
"""Call assessor to process intermediate results """Call assessor to process intermediate results
......
...@@ -9,7 +9,7 @@ if trial_env_vars.NNI_PLATFORM is None: ...@@ -9,7 +9,7 @@ if trial_env_vars.NNI_PLATFORM is None:
from .standalone import * from .standalone import *
elif trial_env_vars.NNI_PLATFORM == 'unittest': elif trial_env_vars.NNI_PLATFORM == 'unittest':
from .test import * from .test import *
elif trial_env_vars.NNI_PLATFORM in ('local', 'remote', 'pai', 'kubeflow', 'frameworkcontroller', 'paiYarn', 'dlts', 'aml'): elif trial_env_vars.NNI_PLATFORM in ('adl', 'local', 'remote', 'pai', 'kubeflow', 'frameworkcontroller', 'paiYarn', 'dlts', 'aml'):
from .local import * from .local import *
else: else:
raise RuntimeError('Unknown platform %s' % trial_env_vars.NNI_PLATFORM) raise RuntimeError('Unknown platform %s' % trial_env_vars.NNI_PLATFORM)
...@@ -5,11 +5,14 @@ import os ...@@ -5,11 +5,14 @@ import os
import sys import sys
import json import json
import tempfile import tempfile
import time
import socket import socket
import string import string
import random import random
import ruamel.yaml as yaml import ruamel.yaml as yaml
import psutil import psutil
import filelock
import glob
from colorama import Fore from colorama import Fore
from .constants import ERROR_INFO, NORMAL_INFO, WARNING_INFO from .constants import ERROR_INFO, NORMAL_INFO, WARNING_INFO
...@@ -95,3 +98,36 @@ def generate_temp_dir(): ...@@ -95,3 +98,36 @@ def generate_temp_dir():
temp_dir = generate_folder_name() temp_dir = generate_folder_name()
os.makedirs(temp_dir) os.makedirs(temp_dir)
return temp_dir return temp_dir
class SimplePreemptiveLock(filelock.SoftFileLock):
'''this is a lock support check lock expiration, if you do not need check expiration, you can use SoftFileLock'''
def __init__(self, lock_file, stale=-1):
super(__class__, self).__init__(lock_file, timeout=-1)
self._lock_file_name = '{}.{}'.format(self._lock_file, os.getpid())
self._stale = stale
def _acquire(self):
open_mode = os.O_WRONLY | os.O_CREAT | os.O_EXCL | os.O_TRUNC
try:
lock_file_names = glob.glob(self._lock_file + '.*')
for file_name in lock_file_names:
if os.path.exists(file_name) and (self._stale < 0 or time.time() - os.stat(file_name).st_mtime < self._stale):
return None
fd = os.open(self._lock_file_name, open_mode)
except (IOError, OSError):
pass
else:
self._lock_file_fd = fd
return None
def _release(self):
os.close(self._lock_file_fd)
self._lock_file_fd = None
try:
os.remove(self._lock_file_name)
except OSError:
pass
return None
def get_file_lock(path: string, stale=-1):
return SimplePreemptiveLock(path + '.lock', stale=-1)
...@@ -124,7 +124,7 @@ common_schema = { ...@@ -124,7 +124,7 @@ common_schema = {
Optional('maxExecDuration'): And(Regex(r'^[1-9][0-9]*[s|m|h|d]$', error='ERROR: maxExecDuration format is [digit]{s,m,h,d}')), Optional('maxExecDuration'): And(Regex(r'^[1-9][0-9]*[s|m|h|d]$', error='ERROR: maxExecDuration format is [digit]{s,m,h,d}')),
Optional('maxTrialNum'): setNumberRange('maxTrialNum', int, 1, 99999), Optional('maxTrialNum'): setNumberRange('maxTrialNum', int, 1, 99999),
'trainingServicePlatform': setChoice( 'trainingServicePlatform': setChoice(
'trainingServicePlatform', 'remote', 'local', 'pai', 'kubeflow', 'frameworkcontroller', 'paiYarn', 'dlts', 'aml'), 'trainingServicePlatform', 'adl', 'remote', 'local', 'pai', 'kubeflow', 'frameworkcontroller', 'paiYarn', 'dlts', 'aml'),
Optional('searchSpacePath'): And(os.path.exists, error=SCHEMA_PATH_ERROR % 'searchSpacePath'), Optional('searchSpacePath'): And(os.path.exists, error=SCHEMA_PATH_ERROR % 'searchSpacePath'),
Optional('multiPhase'): setType('multiPhase', bool), Optional('multiPhase'): setType('multiPhase', bool),
Optional('multiThread'): setType('multiThread', bool), Optional('multiThread'): setType('multiThread', bool),
...@@ -262,6 +262,30 @@ aml_config_schema = { ...@@ -262,6 +262,30 @@ aml_config_schema = {
} }
} }
adl_trial_schema = {
'trial':{
'codeDir': setType('codeDir', str),
'command': setType('command', str),
'gpuNum': setNumberRange('gpuNum', int, 0, 99999),
'image': setType('image', str),
Optional('imagePullSecrets'): [{
'name': setType('name', str)
}],
Optional('nfs'): {
'server': setType('server', str),
'path': setType('path', str),
'containerMountPath': setType('containerMountPath', str)
},
Optional('adaptive'): setType('adaptive', bool),
Optional('checkpoint'): {
'storageClass': setType('storageClass', str),
'storageSize': setType('storageSize', str)
},
Optional('cpuNum'): setNumberRange('cpuNum', int, 0, 99999),
Optional('memorySize'): setType('memorySize', str)
}
}
kubeflow_trial_schema = { kubeflow_trial_schema = {
'trial': { 'trial': {
'codeDir': setPathCheck('codeDir'), 'codeDir': setPathCheck('codeDir'),
...@@ -404,6 +428,7 @@ machine_list_schema = { ...@@ -404,6 +428,7 @@ machine_list_schema = {
} }
training_service_schema_dict = { training_service_schema_dict = {
'adl': Schema({**common_schema, **adl_trial_schema}),
'local': Schema({**common_schema, **common_trial_schema}), 'local': Schema({**common_schema, **common_trial_schema}),
'remote': Schema({**common_schema, **common_trial_schema, **machine_list_schema, **remote_config_schema}), 'remote': Schema({**common_schema, **common_trial_schema, **machine_list_schema, **remote_config_schema}),
'pai': Schema({**common_schema, **pai_trial_schema, **pai_config_schema}), 'pai': Schema({**common_schema, **pai_trial_schema, **pai_config_schema}),
......
...@@ -4,8 +4,10 @@ ...@@ -4,8 +4,10 @@
import os import os
import json import json
import shutil import shutil
import time
from .constants import NNICTL_HOME_DIR from .constants import NNICTL_HOME_DIR
from .command_utils import print_error from .command_utils import print_error
from .common_utils import get_file_lock
class Config: class Config:
'''a util class to load and save config''' '''a util class to load and save config'''
...@@ -34,7 +36,7 @@ class Config: ...@@ -34,7 +36,7 @@ class Config:
if self.config: if self.config:
try: try:
with open(self.config_file, 'w') as file: with open(self.config_file, 'w') as file:
json.dump(self.config, file) json.dump(self.config, file, indent=4)
except IOError as error: except IOError as error:
print('Error:', error) print('Error:', error)
return return
...@@ -54,39 +56,53 @@ class Experiments: ...@@ -54,39 +56,53 @@ class Experiments:
def __init__(self, home_dir=NNICTL_HOME_DIR): def __init__(self, home_dir=NNICTL_HOME_DIR):
os.makedirs(home_dir, exist_ok=True) os.makedirs(home_dir, exist_ok=True)
self.experiment_file = os.path.join(home_dir, '.experiment') self.experiment_file = os.path.join(home_dir, '.experiment')
self.experiments = self.read_file() self.lock = get_file_lock(self.experiment_file, stale=2)
with self.lock:
self.experiments = self.read_file()
def add_experiment(self, expId, port, startTime, file_name, platform, experiment_name, endTime='N/A', status='INITIALIZED'): def add_experiment(self, expId, port, startTime, platform, experiment_name, endTime='N/A', status='INITIALIZED',
'''set {key:value} paris to self.experiment''' tag=[], pid=None, webuiUrl=[], logDir=[]):
self.experiments[expId] = {} '''set {key:value} pairs to self.experiment'''
self.experiments[expId]['port'] = port with self.lock:
self.experiments[expId]['startTime'] = startTime self.experiments = self.read_file()
self.experiments[expId]['endTime'] = endTime self.experiments[expId] = {}
self.experiments[expId]['status'] = status self.experiments[expId]['id'] = expId
self.experiments[expId]['fileName'] = file_name self.experiments[expId]['port'] = port
self.experiments[expId]['platform'] = platform self.experiments[expId]['startTime'] = startTime
self.experiments[expId]['experimentName'] = experiment_name self.experiments[expId]['endTime'] = endTime
self.write_file() self.experiments[expId]['status'] = status
self.experiments[expId]['platform'] = platform
self.experiments[expId]['experimentName'] = experiment_name
self.experiments[expId]['tag'] = tag
self.experiments[expId]['pid'] = pid
self.experiments[expId]['webuiUrl'] = webuiUrl
self.experiments[expId]['logDir'] = logDir
self.write_file()
def update_experiment(self, expId, key, value): def update_experiment(self, expId, key, value):
'''Update experiment''' '''Update experiment'''
if expId not in self.experiments: with self.lock:
return False self.experiments = self.read_file()
self.experiments[expId][key] = value if expId not in self.experiments:
self.write_file() return False
return True self.experiments[expId][key] = value
self.write_file()
return True
def remove_experiment(self, expId): def remove_experiment(self, expId):
'''remove an experiment by id''' '''remove an experiment by id'''
if expId in self.experiments: with self.lock:
fileName = self.experiments.pop(expId).get('fileName') self.experiments = self.read_file()
if fileName: if expId in self.experiments:
logPath = os.path.join(NNICTL_HOME_DIR, fileName) self.experiments.pop(expId)
try: fileName = expId
shutil.rmtree(logPath) if fileName:
except FileNotFoundError: logPath = os.path.join(NNICTL_HOME_DIR, fileName)
print_error('{0} does not exist.'.format(logPath)) try:
self.write_file() shutil.rmtree(logPath)
except FileNotFoundError:
print_error('{0} does not exist.'.format(logPath))
self.write_file()
def get_all_experiments(self): def get_all_experiments(self):
'''return all of experiments''' '''return all of experiments'''
...@@ -96,7 +112,7 @@ class Experiments: ...@@ -96,7 +112,7 @@ class Experiments:
'''save config to local file''' '''save config to local file'''
try: try:
with open(self.experiment_file, 'w') as file: with open(self.experiment_file, 'w') as file:
json.dump(self.experiments, file) json.dump(self.experiments, file, indent=4)
except IOError as error: except IOError as error:
print('Error:', error) print('Error:', error)
return '' return ''
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
import os import os
from colorama import Fore from colorama import Fore
NNICTL_HOME_DIR = os.path.join(os.path.expanduser('~'), '.local', 'nnictl') NNICTL_HOME_DIR = os.path.join(os.path.expanduser('~'), 'nni-experiments')
NNI_HOME_DIR = os.path.join(os.path.expanduser('~'), 'nni-experiments') NNI_HOME_DIR = os.path.join(os.path.expanduser('~'), 'nni-experiments')
...@@ -64,21 +64,21 @@ TRIAL_MONITOR_TAIL = '---------------------------------------------------------- ...@@ -64,21 +64,21 @@ TRIAL_MONITOR_TAIL = '----------------------------------------------------------
INSTALLABLE_PACKAGE_META = { INSTALLABLE_PACKAGE_META = {
'SMAC': { 'SMAC': {
'type': 'tuner', 'type': 'tuner',
'class_name': 'nni.smac_tuner.smac_tuner.SMACTuner', 'class_name': 'nni.algorithms.hpo.smac_tuner.smac_tuner.SMACTuner',
'code_sub_dir': 'smac_tuner', 'code_sub_dir': 'smac_tuner',
'class_args_validator': 'nni.smac_tuner.smac_tuner.SMACClassArgsValidator' 'class_args_validator': 'nni.algorithms.hpo.smac_tuner.smac_tuner.SMACClassArgsValidator'
}, },
'BOHB': { 'BOHB': {
'type': 'advisor', 'type': 'advisor',
'class_name': 'nni.bohb_advisor.bohb_advisor.BOHB', 'class_name': 'nni.algorithms.hpo.bohb_advisor.bohb_advisor.BOHB',
'code_sub_dir': 'bohb_advisor', 'code_sub_dir': 'bohb_advisor',
'class_args_validator': 'nni.bohb_advisor.bohb_advisor.BOHBClassArgsValidator' 'class_args_validator': 'nni.algorithms.hpo.bohb_advisor.bohb_advisor.BOHBClassArgsValidator'
}, },
'PPOTuner': { 'PPOTuner': {
'type': 'tuner', 'type': 'tuner',
'class_name': 'nni.ppo_tuner.ppo_tuner.PPOTuner', 'class_name': 'nni.algorithms.hpo.ppo_tuner.ppo_tuner.PPOTuner',
'code_sub_dir': 'ppo_tuner', 'code_sub_dir': 'ppo_tuner',
'class_args_validator': 'nni.ppo_tuner.ppo_tuner.PPOClassArgsValidator' 'class_args_validator': 'nni.algorithms.hpo.ppo_tuner.ppo_tuner.PPOClassArgsValidator'
} }
} }
......
...@@ -23,10 +23,11 @@ from .constants import NNICTL_HOME_DIR, ERROR_INFO, REST_TIME_OUT, EXPERIMENT_SU ...@@ -23,10 +23,11 @@ from .constants import NNICTL_HOME_DIR, ERROR_INFO, REST_TIME_OUT, EXPERIMENT_SU
from .command_utils import check_output_command, kill_command from .command_utils import check_output_command, kill_command
from .nnictl_utils import update_experiment from .nnictl_utils import update_experiment
def get_log_path(config_file_name): def get_log_path(experiment_id):
'''generate stdout and stderr log path''' '''generate stdout and stderr log path'''
stdout_full_path = os.path.join(NNICTL_HOME_DIR, config_file_name, 'stdout') os.makedirs(os.path.join(NNICTL_HOME_DIR, experiment_id, 'log'), exist_ok=True)
stderr_full_path = os.path.join(NNICTL_HOME_DIR, config_file_name, 'stderr') stdout_full_path = os.path.join(NNICTL_HOME_DIR, experiment_id, 'log', 'nnictl_stdout.log')
stderr_full_path = os.path.join(NNICTL_HOME_DIR, experiment_id, 'log', 'nnictl_stderr.log')
return stdout_full_path, stderr_full_path return stdout_full_path, stderr_full_path
def print_log_content(config_file_name): def print_log_content(config_file_name):
...@@ -38,7 +39,7 @@ def print_log_content(config_file_name): ...@@ -38,7 +39,7 @@ def print_log_content(config_file_name):
print_normal(' Stderr:') print_normal(' Stderr:')
print(check_output_command(stderr_full_path)) print(check_output_command(stderr_full_path))
def start_rest_server(port, platform, mode, config_file_name, foreground=False, experiment_id=None, log_dir=None, log_level=None): def start_rest_server(port, platform, mode, experiment_id, foreground=False, log_dir=None, log_level=None):
'''Run nni manager process''' '''Run nni manager process'''
if detect_port(port): if detect_port(port):
print_error('Port %s is used by another process, please reset the port!\n' \ print_error('Port %s is used by another process, please reset the port!\n' \
...@@ -63,7 +64,8 @@ def start_rest_server(port, platform, mode, config_file_name, foreground=False, ...@@ -63,7 +64,8 @@ def start_rest_server(port, platform, mode, config_file_name, foreground=False,
node_command = os.path.join(entry_dir, 'node.exe') node_command = os.path.join(entry_dir, 'node.exe')
else: else:
node_command = os.path.join(entry_dir, 'node') node_command = os.path.join(entry_dir, 'node')
cmds = [node_command, '--max-old-space-size=4096', entry_file, '--port', str(port), '--mode', platform] cmds = [node_command, '--max-old-space-size=4096', entry_file, '--port', str(port), '--mode', platform, \
'--experiment_id', experiment_id]
if mode == 'view': if mode == 'view':
cmds += ['--start_mode', 'resume'] cmds += ['--start_mode', 'resume']
cmds += ['--readonly', 'true'] cmds += ['--readonly', 'true']
...@@ -73,13 +75,12 @@ def start_rest_server(port, platform, mode, config_file_name, foreground=False, ...@@ -73,13 +75,12 @@ def start_rest_server(port, platform, mode, config_file_name, foreground=False,
cmds += ['--log_dir', log_dir] cmds += ['--log_dir', log_dir]
if log_level is not None: if log_level is not None:
cmds += ['--log_level', log_level] cmds += ['--log_level', log_level]
if mode in ['resume', 'view']:
cmds += ['--experiment_id', experiment_id]
if foreground: if foreground:
cmds += ['--foreground', 'true'] cmds += ['--foreground', 'true']
stdout_full_path, stderr_full_path = get_log_path(config_file_name) stdout_full_path, stderr_full_path = get_log_path(experiment_id)
with open(stdout_full_path, 'a+') as stdout_file, open(stderr_full_path, 'a+') as stderr_file: with open(stdout_full_path, 'a+') as stdout_file, open(stderr_full_path, 'a+') as stderr_file:
time_now = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())) start_time = time.time()
time_now = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(start_time))
#add time information in the header of log files #add time information in the header of log files
log_header = LOG_HEADER % str(time_now) log_header = LOG_HEADER % str(time_now)
stdout_file.write(log_header) stdout_file.write(log_header)
...@@ -95,7 +96,7 @@ def start_rest_server(port, platform, mode, config_file_name, foreground=False, ...@@ -95,7 +96,7 @@ def start_rest_server(port, platform, mode, config_file_name, foreground=False,
process = Popen(cmds, cwd=entry_dir, stdout=PIPE, stderr=PIPE) process = Popen(cmds, cwd=entry_dir, stdout=PIPE, stderr=PIPE)
else: else:
process = Popen(cmds, cwd=entry_dir, stdout=stdout_file, stderr=stderr_file) process = Popen(cmds, cwd=entry_dir, stdout=stdout_file, stderr=stderr_file)
return process, str(time_now) return process, int(start_time * 1000)
def set_trial_config(experiment_config, port, config_file_name): def set_trial_config(experiment_config, port, config_file_name):
'''set trial configuration''' '''set trial configuration'''
...@@ -136,6 +137,14 @@ def set_local_config(experiment_config, port, config_file_name): ...@@ -136,6 +137,14 @@ def set_local_config(experiment_config, port, config_file_name):
return set_trial_config(experiment_config, port, config_file_name), None return set_trial_config(experiment_config, port, config_file_name), None
def set_adl_config(experiment_config, port, config_file_name):
'''set adl configuration'''
result, message = setNNIManagerIp(experiment_config, port, config_file_name)
if not result:
return result, message
#set trial_config
return set_trial_config(experiment_config, port, config_file_name), None
def set_remote_config(experiment_config, port, config_file_name): def set_remote_config(experiment_config, port, config_file_name):
'''Call setClusterMetadata to pass trial''' '''Call setClusterMetadata to pass trial'''
#set machine_list #set machine_list
...@@ -393,7 +402,9 @@ def set_platform_config(platform, experiment_config, port, config_file_name, res ...@@ -393,7 +402,9 @@ def set_platform_config(platform, experiment_config, port, config_file_name, res
'''call set_cluster_metadata for specific platform''' '''call set_cluster_metadata for specific platform'''
print_normal('Setting {0} config...'.format(platform)) print_normal('Setting {0} config...'.format(platform))
config_result, err_msg = None, None config_result, err_msg = None, None
if platform == 'local': if platform == 'adl':
config_result, err_msg = set_adl_config(experiment_config, port, config_file_name)
elif platform == 'local':
config_result, err_msg = set_local_config(experiment_config, port, config_file_name) config_result, err_msg = set_local_config(experiment_config, port, config_file_name)
elif platform == 'remote': elif platform == 'remote':
config_result, err_msg = set_remote_config(experiment_config, port, config_file_name) config_result, err_msg = set_remote_config(experiment_config, port, config_file_name)
...@@ -422,9 +433,9 @@ def set_platform_config(platform, experiment_config, port, config_file_name, res ...@@ -422,9 +433,9 @@ def set_platform_config(platform, experiment_config, port, config_file_name, res
raise Exception(ERROR_INFO % 'Rest server stopped!') raise Exception(ERROR_INFO % 'Rest server stopped!')
exit(1) exit(1)
def launch_experiment(args, experiment_config, mode, config_file_name, experiment_id=None): def launch_experiment(args, experiment_config, mode, experiment_id):
'''follow steps to start rest server and start experiment''' '''follow steps to start rest server and start experiment'''
nni_config = Config(config_file_name) nni_config = Config(experiment_id)
# check packages for tuner # check packages for tuner
package_name, module_name = None, None package_name, module_name = None, None
if experiment_config.get('tuner') and experiment_config['tuner'].get('builtinTunerName'): if experiment_config.get('tuner') and experiment_config['tuner'].get('builtinTunerName'):
...@@ -435,15 +446,15 @@ def launch_experiment(args, experiment_config, mode, config_file_name, experimen ...@@ -435,15 +446,15 @@ def launch_experiment(args, experiment_config, mode, config_file_name, experimen
module_name, _ = get_builtin_module_class_name('advisors', package_name) module_name, _ = get_builtin_module_class_name('advisors', package_name)
if package_name and module_name: if package_name and module_name:
try: try:
stdout_full_path, stderr_full_path = get_log_path(config_file_name) stdout_full_path, stderr_full_path = get_log_path(experiment_id)
with open(stdout_full_path, 'a+') as stdout_file, open(stderr_full_path, 'a+') as stderr_file: with open(stdout_full_path, 'a+') as stdout_file, open(stderr_full_path, 'a+') as stderr_file:
check_call([sys.executable, '-c', 'import %s'%(module_name)], stdout=stdout_file, stderr=stderr_file) check_call([sys.executable, '-c', 'import %s'%(module_name)], stdout=stdout_file, stderr=stderr_file)
except CalledProcessError: except CalledProcessError:
print_error('some errors happen when import package %s.' %(package_name)) print_error('some errors happen when import package %s.' %(package_name))
print_log_content(config_file_name) print_log_content(experiment_id)
if package_name in INSTALLABLE_PACKAGE_META: if package_name in INSTALLABLE_PACKAGE_META:
print_error('If %s is not installed, it should be installed through '\ print_error('If %s is not installed, it should be installed through '\
'\'nnictl package install --name %s\''%(package_name, package_name)) '\'nnictl package install --name %s\'' % (package_name, package_name))
exit(1) exit(1)
log_dir = experiment_config['logDir'] if experiment_config.get('logDir') else None log_dir = experiment_config['logDir'] if experiment_config.get('logDir') else None
log_level = experiment_config['logLevel'] if experiment_config.get('logLevel') else None log_level = experiment_config['logLevel'] if experiment_config.get('logLevel') else None
...@@ -455,7 +466,7 @@ def launch_experiment(args, experiment_config, mode, config_file_name, experimen ...@@ -455,7 +466,7 @@ def launch_experiment(args, experiment_config, mode, config_file_name, experimen
log_level = 'debug' log_level = 'debug'
# start rest server # start rest server
rest_process, start_time = start_rest_server(args.port, experiment_config['trainingServicePlatform'], \ rest_process, start_time = start_rest_server(args.port, experiment_config['trainingServicePlatform'], \
mode, config_file_name, foreground, experiment_id, log_dir, log_level) mode, experiment_id, foreground, log_dir, log_level)
nni_config.set_config('restServerPid', rest_process.pid) nni_config.set_config('restServerPid', rest_process.pid)
# Deal with annotation # Deal with annotation
if experiment_config.get('useAnnotation'): if experiment_config.get('useAnnotation'):
...@@ -481,7 +492,7 @@ def launch_experiment(args, experiment_config, mode, config_file_name, experimen ...@@ -481,7 +492,7 @@ def launch_experiment(args, experiment_config, mode, config_file_name, experimen
print_normal('Successfully started Restful server!') print_normal('Successfully started Restful server!')
else: else:
print_error('Restful server start failed!') print_error('Restful server start failed!')
print_log_content(config_file_name) print_log_content(experiment_id)
try: try:
kill_command(rest_process.pid) kill_command(rest_process.pid)
except Exception: except Exception:
...@@ -490,21 +501,25 @@ def launch_experiment(args, experiment_config, mode, config_file_name, experimen ...@@ -490,21 +501,25 @@ def launch_experiment(args, experiment_config, mode, config_file_name, experimen
if mode != 'view': if mode != 'view':
# set platform configuration # set platform configuration
set_platform_config(experiment_config['trainingServicePlatform'], experiment_config, args.port,\ set_platform_config(experiment_config['trainingServicePlatform'], experiment_config, args.port,\
config_file_name, rest_process) experiment_id, rest_process)
# start a new experiment # start a new experiment
print_normal('Starting experiment...') print_normal('Starting experiment...')
# save experiment information
nnictl_experiment_config = Experiments()
nnictl_experiment_config.add_experiment(experiment_id, args.port, start_time,
experiment_config['trainingServicePlatform'],
experiment_config['experimentName'], pid=rest_process.pid, logDir=log_dir)
# set debug configuration # set debug configuration
if mode != 'view' and experiment_config.get('debug') is None: if mode != 'view' and experiment_config.get('debug') is None:
experiment_config['debug'] = args.debug experiment_config['debug'] = args.debug
response = set_experiment(experiment_config, mode, args.port, config_file_name) response = set_experiment(experiment_config, mode, args.port, experiment_id)
if response: if response:
if experiment_id is None: if experiment_id is None:
experiment_id = json.loads(response.text).get('experiment_id') experiment_id = json.loads(response.text).get('experiment_id')
nni_config.set_config('experimentId', experiment_id)
else: else:
print_error('Start experiment failed!') print_error('Start experiment failed!')
print_log_content(config_file_name) print_log_content(experiment_id)
try: try:
kill_command(rest_process.pid) kill_command(rest_process.pid)
except Exception: except Exception:
...@@ -516,12 +531,6 @@ def launch_experiment(args, experiment_config, mode, config_file_name, experimen ...@@ -516,12 +531,6 @@ def launch_experiment(args, experiment_config, mode, config_file_name, experimen
web_ui_url_list = get_local_urls(args.port) web_ui_url_list = get_local_urls(args.port)
nni_config.set_config('webuiUrl', web_ui_url_list) nni_config.set_config('webuiUrl', web_ui_url_list)
# save experiment information
nnictl_experiment_config = Experiments()
nnictl_experiment_config.add_experiment(experiment_id, args.port, start_time, config_file_name,
experiment_config['trainingServicePlatform'],
experiment_config['experimentName'])
print_normal(EXPERIMENT_SUCCESS_INFO % (experiment_id, ' '.join(web_ui_url_list))) print_normal(EXPERIMENT_SUCCESS_INFO % (experiment_id, ' '.join(web_ui_url_list)))
if mode != 'view' and args.foreground: if mode != 'view' and args.foreground:
try: try:
...@@ -534,8 +543,9 @@ def launch_experiment(args, experiment_config, mode, config_file_name, experimen ...@@ -534,8 +543,9 @@ def launch_experiment(args, experiment_config, mode, config_file_name, experimen
def create_experiment(args): def create_experiment(args):
'''start a new experiment''' '''start a new experiment'''
config_file_name = ''.join(random.sample(string.ascii_letters + string.digits, 8)) experiment_id = ''.join(random.sample(string.ascii_letters + string.digits, 8))
nni_config = Config(config_file_name) nni_config = Config(experiment_id)
nni_config.set_config('experimentId', experiment_id)
config_path = os.path.abspath(args.config) config_path = os.path.abspath(args.config)
if not os.path.exists(config_path): if not os.path.exists(config_path):
print_error('Please set correct config path!') print_error('Please set correct config path!')
...@@ -550,9 +560,9 @@ def create_experiment(args): ...@@ -550,9 +560,9 @@ def create_experiment(args):
nni_config.set_config('experimentConfig', experiment_config) nni_config.set_config('experimentConfig', experiment_config)
nni_config.set_config('restServerPort', args.port) nni_config.set_config('restServerPort', args.port)
try: try:
launch_experiment(args, experiment_config, 'new', config_file_name) launch_experiment(args, experiment_config, 'new', experiment_id)
except Exception as exception: except Exception as exception:
nni_config = Config(config_file_name) nni_config = Config(experiment_id)
restServerPid = nni_config.get_config('restServerPid') restServerPid = nni_config.get_config('restServerPid')
if restServerPid: if restServerPid:
kill_command(restServerPid) kill_command(restServerPid)
...@@ -579,17 +589,13 @@ def manage_stopped_experiment(args, mode): ...@@ -579,17 +589,13 @@ def manage_stopped_experiment(args, mode):
exit(1) exit(1)
experiment_id = args.id experiment_id = args.id
print_normal('{0} experiment {1}...'.format(mode, experiment_id)) print_normal('{0} experiment {1}...'.format(mode, experiment_id))
nni_config = Config(experiment_dict[experiment_id]['fileName']) nni_config = Config(experiment_id)
experiment_config = nni_config.get_config('experimentConfig') experiment_config = nni_config.get_config('experimentConfig')
experiment_id = nni_config.get_config('experimentId') nni_config.set_config('restServerPort', args.port)
new_config_file_name = ''.join(random.sample(string.ascii_letters + string.digits, 8))
new_nni_config = Config(new_config_file_name)
new_nni_config.set_config('experimentConfig', experiment_config)
new_nni_config.set_config('restServerPort', args.port)
try: try:
launch_experiment(args, experiment_config, mode, new_config_file_name, experiment_id) launch_experiment(args, experiment_config, mode, experiment_id)
except Exception as exception: except Exception as exception:
nni_config = Config(new_config_file_name) nni_config = Config(experiment_id)
restServerPid = nni_config.get_config('restServerPid') restServerPid = nni_config.get_config('restServerPid')
if restServerPid: if restServerPid:
kill_command(restServerPid) kill_command(restServerPid)
......
...@@ -32,6 +32,8 @@ def parse_time(time): ...@@ -32,6 +32,8 @@ def parse_time(time):
def parse_path(experiment_config, config_path): def parse_path(experiment_config, config_path):
'''Parse path in config file''' '''Parse path in config file'''
expand_path(experiment_config, 'searchSpacePath') expand_path(experiment_config, 'searchSpacePath')
if experiment_config.get('logDir'):
expand_path(experiment_config, 'logDir')
if experiment_config.get('trial'): if experiment_config.get('trial'):
expand_path(experiment_config['trial'], 'codeDir') expand_path(experiment_config['trial'], 'codeDir')
if experiment_config['trial'].get('authFile'): if experiment_config['trial'].get('authFile'):
...@@ -65,6 +67,8 @@ def parse_path(experiment_config, config_path): ...@@ -65,6 +67,8 @@ def parse_path(experiment_config, config_path):
root_path = os.path.dirname(config_path) root_path = os.path.dirname(config_path)
if experiment_config.get('searchSpacePath'): if experiment_config.get('searchSpacePath'):
parse_relative_path(root_path, experiment_config, 'searchSpacePath') parse_relative_path(root_path, experiment_config, 'searchSpacePath')
if experiment_config.get('logDir'):
parse_relative_path(root_path, experiment_config, 'logDir')
if experiment_config.get('trial'): if experiment_config.get('trial'):
parse_relative_path(root_path, experiment_config['trial'], 'codeDir') parse_relative_path(root_path, experiment_config['trial'], 'codeDir')
if experiment_config['trial'].get('authFile'): if experiment_config['trial'].get('authFile'):
......
...@@ -10,6 +10,7 @@ import re ...@@ -10,6 +10,7 @@ import re
import shutil import shutil
import subprocess import subprocess
from functools import cmp_to_key from functools import cmp_to_key
import traceback
from datetime import datetime, timezone from datetime import datetime, timezone
from subprocess import Popen from subprocess import Popen
from pyhdfs import HdfsClient from pyhdfs import HdfsClient
...@@ -21,6 +22,7 @@ from .config_utils import Config, Experiments ...@@ -21,6 +22,7 @@ from .config_utils import Config, Experiments
from .constants import NNICTL_HOME_DIR, NNI_HOME_DIR, EXPERIMENT_INFORMATION_FORMAT, EXPERIMENT_DETAIL_FORMAT, \ from .constants import NNICTL_HOME_DIR, NNI_HOME_DIR, EXPERIMENT_INFORMATION_FORMAT, EXPERIMENT_DETAIL_FORMAT, \
EXPERIMENT_MONITOR_INFO, TRIAL_MONITOR_HEAD, TRIAL_MONITOR_CONTENT, TRIAL_MONITOR_TAIL, REST_TIME_OUT EXPERIMENT_MONITOR_INFO, TRIAL_MONITOR_HEAD, TRIAL_MONITOR_CONTENT, TRIAL_MONITOR_TAIL, REST_TIME_OUT
from .common_utils import print_normal, print_error, print_warning, detect_process, get_yml_content, generate_temp_dir from .common_utils import print_normal, print_error, print_warning, detect_process, get_yml_content, generate_temp_dir
from .common_utils import print_green
from .command_utils import check_output_command, kill_command from .command_utils import check_output_command, kill_command
from .ssh_utils import create_ssh_sftp_client, remove_remote_directory from .ssh_utils import create_ssh_sftp_client, remove_remote_directory
...@@ -28,7 +30,7 @@ def get_experiment_time(port): ...@@ -28,7 +30,7 @@ def get_experiment_time(port):
'''get the startTime and endTime of an experiment''' '''get the startTime and endTime of an experiment'''
response = rest_get(experiment_url(port), REST_TIME_OUT) response = rest_get(experiment_url(port), REST_TIME_OUT)
if response and check_response(response): if response and check_response(response):
content = convert_time_stamp_to_date(json.loads(response.text)) content = json.loads(response.text)
return content.get('startTime'), content.get('endTime') return content.get('startTime'), content.get('endTime')
return None, None return None, None
...@@ -48,20 +50,11 @@ def update_experiment(): ...@@ -48,20 +50,11 @@ def update_experiment():
for key in experiment_dict.keys(): for key in experiment_dict.keys():
if isinstance(experiment_dict[key], dict): if isinstance(experiment_dict[key], dict):
if experiment_dict[key].get('status') != 'STOPPED': if experiment_dict[key].get('status') != 'STOPPED':
nni_config = Config(experiment_dict[key]['fileName']) nni_config = Config(key)
rest_pid = nni_config.get_config('restServerPid') rest_pid = nni_config.get_config('restServerPid')
if not detect_process(rest_pid): if not detect_process(rest_pid):
experiment_config.update_experiment(key, 'status', 'STOPPED') experiment_config.update_experiment(key, 'status', 'STOPPED')
continue continue
rest_port = nni_config.get_config('restServerPort')
startTime, endTime = get_experiment_time(rest_port)
if startTime:
experiment_config.update_experiment(key, 'startTime', startTime)
if endTime:
experiment_config.update_experiment(key, 'endTime', endTime)
status = get_experiment_status(rest_port)
if status:
experiment_config.update_experiment(key, 'status', status)
def check_experiment_id(args, update=True): def check_experiment_id(args, update=True):
'''check if the id is valid '''check if the id is valid
...@@ -182,9 +175,7 @@ def get_config_filename(args): ...@@ -182,9 +175,7 @@ def get_config_filename(args):
if experiment_id is None: if experiment_id is None:
print_error('Please set correct experiment id.') print_error('Please set correct experiment id.')
exit(1) exit(1)
experiment_config = Experiments() return experiment_id
experiment_dict = experiment_config.get_all_experiments()
return experiment_dict[experiment_id]['fileName']
def get_experiment_port(args): def get_experiment_port(args):
'''get the port of experiment''' '''get the port of experiment'''
...@@ -226,11 +217,9 @@ def stop_experiment(args): ...@@ -226,11 +217,9 @@ def stop_experiment(args):
exit(1) exit(1)
experiment_id_list = parse_ids(args) experiment_id_list = parse_ids(args)
if experiment_id_list: if experiment_id_list:
experiment_config = Experiments()
experiment_dict = experiment_config.get_all_experiments()
for experiment_id in experiment_id_list: for experiment_id in experiment_id_list:
print_normal('Stopping experiment %s' % experiment_id) print_normal('Stopping experiment %s' % experiment_id)
nni_config = Config(experiment_dict[experiment_id]['fileName']) nni_config = Config(experiment_id)
rest_pid = nni_config.get_config('restServerPid') rest_pid = nni_config.get_config('restServerPid')
if rest_pid: if rest_pid:
kill_command(rest_pid) kill_command(rest_pid)
...@@ -243,9 +232,6 @@ def stop_experiment(args): ...@@ -243,9 +232,6 @@ def stop_experiment(args):
print_error(exception) print_error(exception)
nni_config.set_config('tensorboardPidList', []) nni_config.set_config('tensorboardPidList', [])
print_normal('Stop experiment success.') print_normal('Stop experiment success.')
experiment_config.update_experiment(experiment_id, 'status', 'STOPPED')
time_now = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time()))
experiment_config.update_experiment(experiment_id, 'endTime', str(time_now))
def trial_ls(args): def trial_ls(args):
'''List trial''' '''List trial'''
...@@ -372,6 +358,40 @@ def log_stderr(args): ...@@ -372,6 +358,40 @@ def log_stderr(args):
'''get stderr log''' '''get stderr log'''
log_internal(args, 'stderr') log_internal(args, 'stderr')
def log_trial_adl_helper(args, experiment_id):
# adljob_id format should be consistent to the one in "adlTrainingService.ts":
# const adlJobName: string = `nni-exp-${this.experimentId}-trial-${trialJobId}`.toLowerCase();
adlJobName = "nni-exp-{}-trial-{}".format(experiment_id, args.trial_id).lower()
print_warning('Note that no log will show when trial is pending or done (succeeded or failed). '
'You can retry the command.')
print_green('>>> Trial log streaming:')
try:
subprocess.run(
[
"kubectl", "logs",
"-l", "adaptdl/job=%s" % adlJobName,
"-f" # Follow the stream
], # TODO: support remaining argument, uncomment the lines in nnictl.py
) # TODO: emulate tee behaviors, not necessary tho.
except KeyboardInterrupt:
pass
except Exception:
print_error('Error! Please check kubectl:')
traceback.print_exc()
exit(1)
finally:
print_green('<<< [adlJobName:%s]' % adlJobName)
nni_manager_collection_path = os.path.expanduser('~/nni-experiments/%s/trials/%s/stdout_log_collection.log' %
(experiment_id, args.trial_id))
print_green('>>> (Optional) How to persist the complete trial log locally:')
print(
'Please ensure `logCollection: http` '
'exists in the experiment configuration yaml. '
'After trial done, you can check it from the file below: \n %s'
% nni_manager_collection_path
)
def log_trial(args): def log_trial(args):
''''get trial log path''' ''''get trial log path'''
trial_id_path_dict = {} trial_id_path_dict = {}
...@@ -388,16 +408,24 @@ def log_trial(args): ...@@ -388,16 +408,24 @@ def log_trial(args):
if response and check_response(response): if response and check_response(response):
content = json.loads(response.text) content = json.loads(response.text)
for trial in content: for trial in content:
trial_id_list.append(trial.get('id')) trial_id_list.append(trial.get('trialJobId'))
if trial.get('logPath'): if trial.get('logPath'):
trial_id_path_dict[trial.get('id')] = trial['logPath'] trial_id_path_dict[trial.get('trialJobId')] = trial['logPath']
else: else:
print_error('Restful server is not running...') print_error('Restful server is not running...')
exit(1) exit(1)
is_adl = nni_config.get_config('experimentConfig').get('trainingServicePlatform') == 'adl'
if is_adl and not args.trial_id:
print_error('Trial ID is required to retrieve the log for adl. Please specify it with "--trial_id".')
exit(1)
if args.trial_id: if args.trial_id:
if args.trial_id not in trial_id_list: if args.trial_id not in trial_id_list:
print_error('Trial id {0} not correct, please check your command!'.format(args.trial_id)) print_error('Trial id {0} not correct, please check your command!'.format(args.trial_id))
exit(1) exit(1)
if is_adl:
log_trial_adl_helper(args, nni_config.get_config('experimentId'))
# adl has its own way to log trial, and it thus returns right after the helper returns
return
if trial_id_path_dict.get(args.trial_id): if trial_id_path_dict.get(args.trial_id):
print_normal('id:' + args.trial_id + ' path:' + trial_id_path_dict[args.trial_id]) print_normal('id:' + args.trial_id + ' path:' + trial_id_path_dict[args.trial_id])
else: else:
...@@ -429,7 +457,7 @@ def webui_nas(args): ...@@ -429,7 +457,7 @@ def webui_nas(args):
if sys.platform == 'win32': if sys.platform == 'win32':
node_command = os.path.join(entry_dir, 'node.exe') node_command = os.path.join(entry_dir, 'node.exe')
else: else:
node_command = 'node' node_command = os.path.join(entry_dir, 'node')
cmds = [node_command, '--max-old-space-size=4096', entry_file, '--port', str(args.port), '--logdir', args.logdir] cmds = [node_command, '--max-old-space-size=4096', entry_file, '--port', str(args.port), '--logdir', args.logdir]
subprocess.run(cmds, cwd=entry_dir) subprocess.run(cmds, cwd=entry_dir)
except KeyboardInterrupt: except KeyboardInterrupt:
...@@ -509,7 +537,7 @@ def experiment_clean(args): ...@@ -509,7 +537,7 @@ def experiment_clean(args):
else: else:
break break
for experiment_id in experiment_id_list: for experiment_id in experiment_id_list:
nni_config = Config(experiment_dict[experiment_id]['fileName']) nni_config = Config(experiment_id)
platform = nni_config.get_config('experimentConfig').get('trainingServicePlatform') platform = nni_config.get_config('experimentConfig').get('trainingServicePlatform')
experiment_id = nni_config.get_config('experimentId') experiment_id = nni_config.get_config('experimentId')
if platform == 'remote': if platform == 'remote':
...@@ -624,18 +652,15 @@ def experiment_list(args): ...@@ -624,18 +652,15 @@ def experiment_list(args):
experiment_dict[key]['status'], experiment_dict[key]['status'],
experiment_dict[key]['port'], experiment_dict[key]['port'],
experiment_dict[key].get('platform'), experiment_dict[key].get('platform'),
experiment_dict[key]['startTime'], time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(experiment_dict[key]['startTime'] / 1000)) if isinstance(experiment_dict[key]['startTime'], int) else experiment_dict[key]['startTime'],
experiment_dict[key]['endTime']) time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(experiment_dict[key]['endTime'] / 1000)) if isinstance(experiment_dict[key]['endTime'], int) else experiment_dict[key]['endTime'])
print(EXPERIMENT_INFORMATION_FORMAT % experiment_information) print(EXPERIMENT_INFORMATION_FORMAT % experiment_information)
return experiment_id_list return experiment_id_list
def get_time_interval(time1, time2): def get_time_interval(time1, time2):
'''get the interval of two times''' '''get the interval of two times'''
try: try:
#convert time to timestamp seconds = int((time2 - time1) / 1000)
time1 = time.mktime(time.strptime(time1, '%Y/%m/%d %H:%M:%S'))
time2 = time.mktime(time.strptime(time2, '%Y/%m/%d %H:%M:%S'))
seconds = (datetime.fromtimestamp(time2) - datetime.fromtimestamp(time1)).seconds
#convert seconds to day:hour:minute:second #convert seconds to day:hour:minute:second
days = seconds / 86400 days = seconds / 86400
seconds %= 86400 seconds %= 86400
...@@ -664,8 +689,8 @@ def show_experiment_info(): ...@@ -664,8 +689,8 @@ def show_experiment_info():
return return
for key in experiment_id_list: for key in experiment_id_list:
print(EXPERIMENT_MONITOR_INFO % (key, experiment_dict[key]['status'], experiment_dict[key]['port'], \ print(EXPERIMENT_MONITOR_INFO % (key, experiment_dict[key]['status'], experiment_dict[key]['port'], \
experiment_dict[key].get('platform'), experiment_dict[key]['startTime'], \ experiment_dict[key].get('platform'), time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(experiment_dict[key]['startTime'] / 1000)) if isinstance(experiment_dict[key]['startTime'], int) else experiment_dict[key]['startTime'], \
get_time_interval(experiment_dict[key]['startTime'], experiment_dict[key]['endTime']))) get_time_interval(experiment_dict[key]['startTime'], experiment_dict[key]['endTime'])))
print(TRIAL_MONITOR_HEAD) print(TRIAL_MONITOR_HEAD)
running, response = check_rest_server_quick(experiment_dict[key]['port']) running, response = check_rest_server_quick(experiment_dict[key]['port'])
if running: if running:
...@@ -674,7 +699,7 @@ def show_experiment_info(): ...@@ -674,7 +699,7 @@ def show_experiment_info():
content = json.loads(response.text) content = json.loads(response.text)
for index, value in enumerate(content): for index, value in enumerate(content):
content[index] = convert_time_stamp_to_date(value) content[index] = convert_time_stamp_to_date(value)
print(TRIAL_MONITOR_CONTENT % (content[index].get('id'), content[index].get('startTime'), \ print(TRIAL_MONITOR_CONTENT % (content[index].get('trialJobId'), content[index].get('startTime'), \
content[index].get('endTime'), content[index].get('status'))) content[index].get('endTime'), content[index].get('status')))
print(TRIAL_MONITOR_TAIL) print(TRIAL_MONITOR_TAIL)
...@@ -747,7 +772,7 @@ def export_trials_data(args): ...@@ -747,7 +772,7 @@ def export_trials_data(args):
return return
intermediate_results = groupby_trial_id(json.loads(intermediate_results_response.text)) intermediate_results = groupby_trial_id(json.loads(intermediate_results_response.text))
for record in content: for record in content:
record['intermediate'] = intermediate_results[record['id']] record['intermediate'] = intermediate_results[record['trialJobId']]
if args.type == 'json': if args.type == 'json':
with open(args.path, 'w') as file: with open(args.path, 'w') as file:
file.write(json.dumps(content)) file.write(json.dumps(content))
...@@ -759,9 +784,9 @@ def export_trials_data(args): ...@@ -759,9 +784,9 @@ def export_trials_data(args):
formated_record['intermediate'] = '[' + ','.join(record['intermediate']) + ']' formated_record['intermediate'] = '[' + ','.join(record['intermediate']) + ']'
record_value = json.loads(record['value']) record_value = json.loads(record['value'])
if not isinstance(record_value, (float, int)): if not isinstance(record_value, (float, int)):
formated_record.update({**record['parameter'], **record_value, **{'id': record['id']}}) formated_record.update({**record['parameter'], **record_value, **{'trialJobId': record['trialJobId']}})
else: else:
formated_record.update({**record['parameter'], **{'reward': record_value, 'id': record['id']}}) formated_record.update({**record['parameter'], **{'reward': record_value, 'trialJobId': record['trialJobId']}})
trial_records.append(formated_record) trial_records.append(formated_record)
if not trial_records: if not trial_records:
print_error('No trial results collected! Please check your trial log...') print_error('No trial results collected! Please check your trial log...')
...@@ -806,7 +831,7 @@ def save_experiment(args): ...@@ -806,7 +831,7 @@ def save_experiment(args):
print_error('Can only save stopped experiment!') print_error('Can only save stopped experiment!')
exit(1) exit(1)
print_normal('Saving...') print_normal('Saving...')
nni_config = Config(experiment_dict[args.id]['fileName']) nni_config = Config(args.id)
logDir = os.path.join(NNI_HOME_DIR, args.id) logDir = os.path.join(NNI_HOME_DIR, args.id)
if nni_config.get_config('logDir'): if nni_config.get_config('logDir'):
logDir = os.path.join(nni_config.get_config('logDir'), args.id) logDir = os.path.join(nni_config.get_config('logDir'), args.id)
...@@ -829,8 +854,8 @@ def save_experiment(args): ...@@ -829,8 +854,8 @@ def save_experiment(args):
except IOError: except IOError:
print_error('Write file to %s failed!' % os.path.join(temp_nnictl_dir, '.experiment')) print_error('Write file to %s failed!' % os.path.join(temp_nnictl_dir, '.experiment'))
exit(1) exit(1)
nnictl_config_dir = os.path.join(NNICTL_HOME_DIR, experiment_dict[args.id]['fileName']) nnictl_config_dir = os.path.join(NNICTL_HOME_DIR, args.id)
shutil.copytree(nnictl_config_dir, os.path.join(temp_nnictl_dir, experiment_dict[args.id]['fileName'])) shutil.copytree(nnictl_config_dir, os.path.join(temp_nnictl_dir, args.id))
# Step3. Copy code dir # Step3. Copy code dir
if args.saveCodeDir: if args.saveCodeDir:
...@@ -903,20 +928,20 @@ def load_experiment(args): ...@@ -903,20 +928,20 @@ def load_experiment(args):
print_error('Invalid: experiment id already exist!') print_error('Invalid: experiment id already exist!')
shutil.rmtree(temp_root_dir) shutil.rmtree(temp_root_dir)
exit(1) exit(1)
if not os.path.exists(os.path.join(nnictl_temp_dir, experiment_metadata.get('fileName'))): if not os.path.exists(os.path.join(nnictl_temp_dir, experiment_id)):
print_error('Invalid: experiment metadata does not exist!') print_error('Invalid: experiment metadata does not exist!')
shutil.rmtree(temp_root_dir) shutil.rmtree(temp_root_dir)
exit(1) exit(1)
# Step2. Copy nnictl metadata # Step2. Copy nnictl metadata
src_path = os.path.join(nnictl_temp_dir, experiment_metadata.get('fileName')) src_path = os.path.join(nnictl_temp_dir, experiment_id)
dest_path = os.path.join(NNICTL_HOME_DIR, experiment_metadata.get('fileName')) dest_path = os.path.join(NNICTL_HOME_DIR, experiment_id)
if os.path.exists(dest_path): if os.path.exists(dest_path):
shutil.rmtree(dest_path) shutil.rmtree(dest_path)
shutil.copytree(src_path, dest_path) shutil.copytree(src_path, dest_path)
# Step3. Copy experiment data # Step3. Copy experiment data
nni_config = Config(experiment_metadata.get('fileName')) nni_config = Config(experiment_id)
nnictl_exp_config = nni_config.get_config('experimentConfig') nnictl_exp_config = nni_config.get_config('experimentConfig')
if args.logDir: if args.logDir:
logDir = args.logDir logDir = args.logDir
...@@ -983,13 +1008,15 @@ def load_experiment(args): ...@@ -983,13 +1008,15 @@ def load_experiment(args):
experiment_config.add_experiment(experiment_id, experiment_config.add_experiment(experiment_id,
experiment_metadata.get('port'), experiment_metadata.get('port'),
experiment_metadata.get('startTime'), experiment_metadata.get('startTime'),
experiment_metadata.get('fileName'),
experiment_metadata.get('platform'), experiment_metadata.get('platform'),
experiment_metadata.get('experimentName'), experiment_metadata.get('experimentName'),
experiment_metadata.get('endTime'), experiment_metadata.get('endTime'),
experiment_metadata.get('status')) experiment_metadata.get('status'),
experiment_metadata.get('tag'),
experiment_metadata.get('pid'),
experiment_metadata.get('webUrl'),
experiment_metadata.get('logDir'))
print_normal('Load experiment %s succsss!' % experiment_id) print_normal('Load experiment %s succsss!' % experiment_id)
# Step6. Cleanup temp data # Step6. Cleanup temp data
shutil.rmtree(temp_root_dir) shutil.rmtree(temp_root_dir)
...@@ -10,8 +10,8 @@ from .rest_utils import rest_get, check_rest_server_quick, check_response ...@@ -10,8 +10,8 @@ from .rest_utils import rest_get, check_rest_server_quick, check_response
from .config_utils import Config, Experiments from .config_utils import Config, Experiments
from .url_utils import trial_jobs_url, get_local_urls from .url_utils import trial_jobs_url, get_local_urls
from .constants import REST_TIME_OUT from .constants import REST_TIME_OUT
from .common_utils import print_normal, print_error, print_green, detect_process, detect_port, check_tensorboard_version from .common_utils import print_normal, print_warning, print_error, print_green, detect_process, detect_port, check_tensorboard_version
from .nnictl_utils import check_experiment_id, check_experiment_id from .nnictl_utils import check_experiment_id
from .ssh_utils import create_ssh_sftp_client, copy_remote_directory_to_local from .ssh_utils import create_ssh_sftp_client, copy_remote_directory_to_local
def parse_log_path(args, trial_content): def parse_log_path(args, trial_content):
...@@ -19,7 +19,7 @@ def parse_log_path(args, trial_content): ...@@ -19,7 +19,7 @@ def parse_log_path(args, trial_content):
path_list = [] path_list = []
host_list = [] host_list = []
for trial in trial_content: for trial in trial_content:
if args.trial_id and args.trial_id != 'all' and trial.get('id') != args.trial_id: if args.trial_id and args.trial_id != 'all' and trial.get('trialJobId') != args.trial_id:
continue continue
pattern = r'(?P<head>.+)://(?P<host>.+):(?P<path>.*)' pattern = r'(?P<head>.+)://(?P<host>.+):(?P<path>.*)'
match = re.search(pattern, trial['logPath']) match = re.search(pattern, trial['logPath'])
...@@ -40,7 +40,7 @@ def copy_data_from_remote(args, nni_config, trial_content, path_list, host_list, ...@@ -40,7 +40,7 @@ def copy_data_from_remote(args, nni_config, trial_content, path_list, host_list,
machine_dict[machine['ip']] = {'port': machine['port'], 'passwd': machine['passwd'], 'username': machine['username'], machine_dict[machine['ip']] = {'port': machine['port'], 'passwd': machine['passwd'], 'username': machine['username'],
'sshKeyPath': machine.get('sshKeyPath'), 'passphrase': machine.get('passphrase')} 'sshKeyPath': machine.get('sshKeyPath'), 'passphrase': machine.get('passphrase')}
for index, host in enumerate(host_list): for index, host in enumerate(host_list):
local_path = os.path.join(temp_nni_path, trial_content[index].get('id')) local_path = os.path.join(temp_nni_path, trial_content[index].get('trialJobId'))
local_path_list.append(local_path) local_path_list.append(local_path)
print_normal('Copying log data from %s to %s' % (host + ':' + path_list[index], local_path)) print_normal('Copying log data from %s to %s' % (host + ':' + path_list[index], local_path))
sftp = create_ssh_sftp_client(host, machine_dict[host]['port'], machine_dict[host]['username'], machine_dict[host]['passwd'], sftp = create_ssh_sftp_client(host, machine_dict[host]['port'], machine_dict[host]['username'], machine_dict[host]['passwd'],
...@@ -95,8 +95,7 @@ def stop_tensorboard(args): ...@@ -95,8 +95,7 @@ def stop_tensorboard(args):
experiment_id = check_experiment_id(args) experiment_id = check_experiment_id(args)
experiment_config = Experiments() experiment_config = Experiments()
experiment_dict = experiment_config.get_all_experiments() experiment_dict = experiment_config.get_all_experiments()
config_file_name = experiment_dict[experiment_id]['fileName'] nni_config = Config(experiment_id)
nni_config = Config(config_file_name)
tensorboard_pid_list = nni_config.get_config('tensorboardPidList') tensorboard_pid_list = nni_config.get_config('tensorboardPidList')
if tensorboard_pid_list: if tensorboard_pid_list:
for tensorboard_pid in tensorboard_pid_list: for tensorboard_pid in tensorboard_pid_list:
...@@ -110,14 +109,36 @@ def stop_tensorboard(args): ...@@ -110,14 +109,36 @@ def stop_tensorboard(args):
else: else:
print_error('No tensorboard configuration!') print_error('No tensorboard configuration!')
def adl_tensorboard_helper(args):
'''start tensorboard on adl'''
import subprocess
if args.trial_id is not None:
print_warning('Tensorboard on adl platform will show all trials. No trial ids needed.')
cmd = "kubectl port-forward --address 0.0.0.0 deployment/{} {}:{}".format(
"adaptdl-tensorboard" + "-" + args.id.lower(),
args.port,
6006
)
print_green('Tensorboard is accessible at 0.0.0.0:{port} or localhost:{port}'.format(port=args.port))
subprocess.run(args=cmd, shell=True)
def start_tensorboard(args): def start_tensorboard(args):
'''start tensorboard''' '''start tensorboard'''
experiment_id = check_experiment_id(args) experiment_id = check_experiment_id(args)
if not experiment_id:
return
if args.id is None:
args.id = experiment_id
experiment_config = Experiments() experiment_config = Experiments()
experiment_dict = experiment_config.get_all_experiments() experiment_dict = experiment_config.get_all_experiments()
if experiment_dict[args.id]["status"] == "STOPPED":
print_error("Experiment {} is stopped...".format(args.id))
return
config_file_name = experiment_dict[experiment_id]['fileName'] config_file_name = experiment_dict[experiment_id]['fileName']
nni_config = Config(config_file_name) nni_config = Config(args.id)
if nni_config.get_config('experimentConfig').get('trainingServicePlatform') == 'adl':
adl_tensorboard_helper(args)
return
rest_port = nni_config.get_config('restServerPort') rest_port = nni_config.get_config('restServerPort')
rest_pid = nni_config.get_config('restServerPid') rest_pid = nni_config.get_config('restServerPid')
if not detect_process(rest_pid): if not detect_process(rest_pid):
...@@ -144,4 +165,4 @@ def start_tensorboard(args): ...@@ -144,4 +165,4 @@ def start_tensorboard(args):
os.makedirs(temp_nni_path, exist_ok=True) os.makedirs(temp_nni_path, exist_ok=True)
path_list = get_path_list(args, nni_config, trial_content, temp_nni_path) path_list = get_path_list(args, nni_config, trial_content, temp_nni_path)
start_tensorboard_process(args, nni_config, path_list, temp_nni_path) start_tensorboard_process(args, nni_config, path_list, temp_nni_path)
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment