"...layers/attention_layers_fp8/BaseAttentionFP8Layer.h" did not exist on "720fc533da804ac3f46ee938864403e51fcd9fa7"
Unverified commit 6c3148c7, authored by SparkSnail, committed by GitHub

Merge pull request #239 from microsoft/master

merge master
parents 0fb78620 a2e524d3
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import sys
import os
import logging
import pickle
import shutil
import random
import math
import time
import datetime
import argparse
import distutils.util
import numpy as np
import torch
from torch import nn
from torch import optim
from torch.utils.data import DataLoader
import torch.nn.functional as Func
from model import Model
from nni.nas.pytorch.fixed import apply_fixed_architecture
from dataloader import read_data_sst
logger = logging.getLogger("nni.textnas")
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--reset_output_dir",
type=distutils.util.strtobool,
default=True,
help="Whether to clean the output dir if existed. (default: %(default)s)")
parser.add_argument(
"--child_fixed_arc",
type=str,
required=True,
help="Architecture json file. (default: %(default)s)")
parser.add_argument(
"--data_path",
type=str,
default="data",
help="Directory containing the dataset and embedding file. (default: %(default)s)")
parser.add_argument(
"--output_dir",
type=str,
default="output",
help="The output directory. (default: %(default)s)")
parser.add_argument(
"--child_lr_decay_scheme",
type=str,
default="cosine",
help="Learning rate annealing strategy, only 'cosine' supported. (default: %(default)s)")
parser.add_argument(
"--batch_size",
type=int,
default=128,
help="Number of samples each batch for training. (default: %(default)s)")
parser.add_argument(
"--eval_batch_size",
type=int,
default=128,
help="Number of samples each batch for evaluation. (default: %(default)s)")
parser.add_argument(
"--class_num",
type=int,
default=5,
help="The number of categories. (default: %(default)s)")
parser.add_argument(
"--global_seed",
type=int,
default=1234,
help="Seed for reproduction. (default: %(default)s)")
parser.add_argument(
"--max_input_length",
type=int,
default=64,
help="The maximum length of the sentence. (default: %(default)s)")
parser.add_argument(
"--num_epochs",
type=int,
default=10,
help="The number of training epochs. (default: %(default)s)")
parser.add_argument(
"--child_num_layers",
type=int,
default=24,
help="The layer number of the architecture. (default: %(default)s)")
parser.add_argument(
"--child_out_filters",
type=int,
default=256,
help="The dimension of hidden states. (default: %(default)s)")
parser.add_argument(
"--child_out_filters_scale",
type=int,
default=1,
help="The scale of hidden state dimension. (default: %(default)s)")
parser.add_argument(
"--child_lr_T_0",
type=int,
default=10,
help="The length of one cycle. (default: %(default)s)")
parser.add_argument(
"--child_lr_T_mul",
type=int,
default=2,
help="The multiplication factor per cycle. (default: %(default)s)")
parser.add_argument(
"--min_count",
type=int,
default=1,
help="The threshold to cut off low frequent words. (default: %(default)s)")
parser.add_argument(
"--train_ratio",
type=float,
default=1.0,
help="The sample ratio for the training set. (default: %(default)s)")
parser.add_argument(
"--valid_ratio",
type=float,
default=1.0,
help="The sample ratio for the dev set. (default: %(default)s)")
parser.add_argument(
"--child_grad_bound",
type=float,
default=5.0,
help="The threshold for gradient clipping. (default: %(default)s)")
parser.add_argument(
"--child_lr",
type=float,
default=0.02,
help="The initial learning rate. (default: %(default)s)")
parser.add_argument(
"--cnn_keep_prob",
type=float,
default=0.8,
help="Keep prob for cnn layer. (default: %(default)s)")
parser.add_argument(
"--final_output_keep_prob",
type=float,
default=1.0,
help="Keep prob for the last output layer. (default: %(default)s)")
parser.add_argument(
"--lstm_out_keep_prob",
type=float,
default=0.8,
help="Keep prob for the RNN layer. (default: %(default)s)")
parser.add_argument(
"--embed_keep_prob",
type=float,
default=0.8,
help="Keep prob for the embedding layer. (default: %(default)s)")
parser.add_argument(
"--attention_keep_prob",
type=float,
default=0.8,
help="Keep prob for the self-attention layer. (default: %(default)s)")
parser.add_argument(
"--child_l2_reg",
type=float,
default=3e-6,
help="Weight decay factor. (default: %(default)s)")
parser.add_argument(
"--child_lr_max",
type=float,
default=0.002,
help="The max learning rate. (default: %(default)s)")
parser.add_argument(
"--child_lr_min",
type=float,
default=0.001,
help="The min learning rate. (default: %(default)s)")
parser.add_argument(
"--child_optim_algo",
type=str,
default="adam",
help="Optimization algorithm. (default: %(default)s)")
parser.add_argument(
"--checkpoint_dir",
type=str,
default="best_checkpoint",
help="Path for saved checkpoints. (default: %(default)s)")
parser.add_argument(
"--output_type",
type=str,
default="avg",
help="Opertor type for the time steps reduction. (default: %(default)s)")
parser.add_argument(
"--multi_path",
type=distutils.util.strtobool,
default=False,
help="Search for multiple path in the architecture. (default: %(default)s)")
parser.add_argument(
"--is_binary",
type=distutils.util.strtobool,
default=False,
help="Binary label for sst dataset. (default: %(default)s)")
parser.add_argument(
"--is_cuda",
type=distutils.util.strtobool,
default=True,
help="Specify the device type. (default: %(default)s)")
parser.add_argument(
"--is_mask",
type=distutils.util.strtobool,
default=True,
help="Apply mask. (default: %(default)s)")
parser.add_argument(
"--fixed_seed",
type=distutils.util.strtobool,
default=True,
help="Fix the seed. (default: %(default)s)")
parser.add_argument(
"--load_checkpoint",
type=distutils.util.strtobool,
default=False,
help="Wether to load checkpoint. (default: %(default)s)")
parser.add_argument(
"--log_every",
type=int,
default=50,
help="How many steps to log. (default: %(default)s)")
parser.add_argument(
"--eval_every_epochs",
type=int,
default=1,
help="How many epochs to eval. (default: %(default)s)")
global FLAGS
FLAGS = parser.parse_args()
def set_random_seed(seed):
logger.info("set random seed for data reading: {}".format(seed))
random.seed(seed)
os.environ['PYTHONHASHSEED'] = str(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if FLAGS.is_cuda:
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True
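# Note: cudnn.deterministic=True makes cuDNN pick deterministic convolution
# algorithms, trading some GPU throughput for run-to-run reproducibility.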
def get_model(embedding, num_layers):
logger.info("num layers: {0}".format(num_layers))
assert FLAGS.child_fixed_arc is not None, "Architecture should be provided."
child_model = Model(
embedding=embedding,
hidden_units=FLAGS.child_out_filters_scale * FLAGS.child_out_filters,
num_layers=num_layers,
num_classes=FLAGS.class_num,
choose_from_k=5 if FLAGS.multi_path else 1,
lstm_keep_prob=FLAGS.lstm_out_keep_prob,
cnn_keep_prob=FLAGS.cnn_keep_prob,
att_keep_prob=FLAGS.attention_keep_prob,
att_mask=FLAGS.is_mask,
embed_keep_prob=FLAGS.embed_keep_prob,
final_output_keep_prob=FLAGS.final_output_keep_prob,
global_pool=FLAGS.output_type)
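# apply_fixed_architecture loads the architecture JSON exported by the search
# phase and fixes every mutable choice in the model to the recorded decision.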
apply_fixed_architecture(child_model, FLAGS.child_fixed_arc)
return child_model
def eval_once(child_model, device, eval_set, criterion, valid_dataloader=None, test_dataloader=None):
if eval_set == "test":
assert test_dataloader is not None
dataloader = test_dataloader
elif eval_set == "valid":
assert valid_dataloader is not None
dataloader = valid_dataloader
else:
raise NotImplementedError("Unknown eval_set '{}'".format(eval_set))
tot_acc = 0
tot = 0
losses = []
with torch.no_grad(): # save memory
for batch in dataloader:
(sent_ids, mask), labels = batch
sent_ids = sent_ids.to(device, non_blocking=True)
mask = mask.to(device, non_blocking=True)
labels = labels.to(device, non_blocking=True)
logits = child_model((sent_ids, mask)) # run
loss = criterion(logits, labels.long())
loss = loss.mean()
preds = logits.argmax(dim=1).long()
acc = torch.eq(preds, labels.long()).long().sum().item()
losses.append(loss)
tot_acc += acc
tot += len(labels)
losses = torch.tensor(losses)
loss = losses.mean()
if tot > 0:
final_acc = float(tot_acc) / tot
else:
final_acc = 0
logger.info("Error in calculating final_acc")
return final_acc, loss
def print_user_flags(FLAGS, line_limit=80):
log_strings = "\n" + "-" * line_limit + "\n"
for flag_name in sorted(vars(FLAGS)):
value = "{}".format(getattr(FLAGS, flag_name))
log_string = flag_name
log_string += "." * (line_limit - len(flag_name) - len(value))
log_string += value
log_strings = log_strings + log_string
log_strings = log_strings + "\n"
log_strings += "-" * line_limit
logger.info(log_strings)
def count_model_params(trainable_params):
num_vars = 0
for var in trainable_params:
num_vars += np.prod([dim for dim in var.size()])
return num_vars
def update_lr(
optimizer,
epoch,
l2_reg=1e-4,
lr_warmup_val=None,
lr_init=0.1,
lr_decay_scheme="cosine",
lr_max=0.002,
lr_min=0.000000001,
lr_T_0=4,
lr_T_mul=1,
sync_replicas=False,
num_aggregate=None,
num_replicas=None):
if lr_decay_scheme == "cosine":
assert lr_max is not None, "Need lr_max to use lr_cosine"
assert lr_min is not None, "Need lr_min to use lr_cosine"
assert lr_T_0 is not None, "Need lr_T_0 to use lr_cosine"
assert lr_T_mul is not None, "Need lr_T_mul to use lr_cosine"
T_i = lr_T_0
t_epoch = epoch
last_reset = 0
while True:
t_epoch -= T_i
if t_epoch < 0:
break
last_reset += T_i
T_i *= lr_T_mul
T_curr = epoch - last_reset
def _update():
rate = T_curr / T_i * math.pi
lr = lr_min + 0.5 * (lr_max - lr_min) * (1.0 + math.cos(rate))
return lr
learning_rate = _update()
else:
raise ValueError("Unknown learning rate decay scheme {}".format(lr_decay_scheme))
#update lr in optimizer
for params_group in optimizer.param_groups:
params_group['lr'] = learning_rate
return learning_rate
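# Worked example of the warm-restart bookkeeping above, with the defaults
# lr_T_0=10 and lr_T_mul=2: cycle lengths are 10, 20, 40, ... epochs, so
# epoch 25 falls in the second cycle (last_reset=10, T_i=20, T_curr=15) and
# the LR is lr_min + 0.5 * (lr_max - lr_min) * (1 + cos(0.75 * pi)).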
def train(device, data_path, output_dir, num_layers):
logger.info("Build dataloader")
train_dataset, valid_dataset, test_dataset, embedding = \
read_data_sst(data_path,
FLAGS.max_input_length,
FLAGS.min_count,
train_ratio=FLAGS.train_ratio,
valid_ratio=FLAGS.valid_ratio,
is_binary=FLAGS.is_binary)
train_dataloader = DataLoader(train_dataset, batch_size=FLAGS.batch_size, shuffle=True, pin_memory=True)
test_dataloader = DataLoader(test_dataset, batch_size=FLAGS.eval_batch_size, pin_memory=True)
valid_dataloader = DataLoader(valid_dataset, batch_size=FLAGS.eval_batch_size, pin_memory=True)
logger.info("Build model")
child_model = get_model(embedding, num_layers)
logger.info("Finish build model")
#for name, var in child_model.named_parameters():
# logger.info(name, var.size(), var.requires_grad) # output all params
num_vars = count_model_params(child_model.parameters())
logger.info("Model has {} params".format(num_vars))
for m in child_model.modules(): # initializer
if isinstance(m, (nn.Conv1d, nn.Linear)):
nn.init.xavier_uniform_(m.weight)
criterion = nn.CrossEntropyLoss()
# get optimizer
if FLAGS.child_optim_algo == "adam":
optimizer = optim.Adam(child_model.parameters(), eps=1e-3, weight_decay=FLAGS.child_l2_reg) # with L2
else:
raise ValueError("Unknown optim_algo {}".format(FLAGS.child_optim_algo))
child_model.to(device)
criterion.to(device)
logger.info("Start training")
start_time = time.time()
step = 0
# save path
model_save_path = os.path.join(FLAGS.output_dir, "model.pth")
best_model_save_path = os.path.join(FLAGS.output_dir, "best_model.pth")
best_acc = 0
start_epoch = 0
if FLAGS.load_checkpoint:
if os.path.isfile(model_save_path):
checkpoint = torch.load(model_save_path, map_location=torch.device('cpu'))
step = checkpoint['step']
start_epoch = checkpoint['epoch']
child_model.load_state_dict(checkpoint['child_model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
for epoch in range(start_epoch, FLAGS.num_epochs):
lr = update_lr(optimizer,
epoch,
l2_reg=FLAGS.child_l2_reg,
lr_warmup_val=None,
lr_init=FLAGS.child_lr,
lr_decay_scheme=FLAGS.child_lr_decay_scheme,
lr_max=FLAGS.child_lr_max,
lr_min=FLAGS.child_lr_min,
lr_T_0=FLAGS.child_lr_T_0,
lr_T_mul=FLAGS.child_lr_T_mul)
child_model.train()
for batch in train_dataloader:
(sent_ids, mask), labels = batch
sent_ids = sent_ids.to(device, non_blocking=True)
mask = mask.to(device, non_blocking=True)
labels = labels.to(device, non_blocking=True)
step += 1
logits = child_model((sent_ids, mask)) # run
loss = criterion(logits, labels.long())
loss = loss.mean()
preds = logits.argmax(dim=1).long()
acc = torch.eq(preds, labels.long()).long().sum().item()
optimizer.zero_grad()
loss.backward()
grad_norm = 0
trainable_params = list(child_model.parameters()) # materialize the generator so it can be iterated more than once
assert FLAGS.child_grad_bound is not None, "Need grad_bound to clip gradients."
# compute the overall gradient norm (the huge bound means nothing is clipped here)
grad_norm = nn.utils.clip_grad_norm_(trainable_params, 99999999)
for param in trainable_params:
nn.utils.clip_grad_norm_(param, FLAGS.child_grad_bound) # clip each parameter's gradient
optimizer.step()
if step % FLAGS.log_every == 0:
curr_time = time.time()
log_string = ""
log_string += "epoch={:<6d}".format(epoch)
log_string += "ch_step={:<6d}".format(step)
log_string += " loss={:<8.6f}".format(loss)
log_string += " lr={:<8.4f}".format(lr)
log_string += " |g|={:<8.4f}".format(grad_norm)
log_string += " tr_acc={:<3d}/{:>3d}".format(acc, logits.size()[0])
log_string += " mins={:<10.2f}".format(float(curr_time - start_time) / 60)
logger.info(log_string)
epoch += 1 # store the next epoch index so a resumed run starts from it
save_state = {
'step' : step,
'epoch' : epoch,
'child_model_state_dict' : child_model.state_dict(),
'optimizer_state_dict' : optimizer.state_dict()}
torch.save(save_state, model_save_path)
child_model.eval()
logger.info("Epoch {}: Eval".format(epoch))
eval_acc, eval_loss = eval_once(child_model, device, "test", criterion, test_dataloader=test_dataloader)
logger.info("ch_step={} {}_accuracy={:<6.4f} {}_loss={:<6.4f}".format(step, "test", eval_acc, "test", eval_loss))
if eval_acc > best_acc:
best_acc = eval_acc
logger.info("Save best model")
save_state = {
'step' : step,
'epoch' : epoch,
'child_model_state_dict' : child_model.state_dict(),
'optimizer_state_dict' : optimizer.state_dict()}
torch.save(save_state, best_model_save_path)
return eval_acc
def main():
parse_args()
if not os.path.isdir(FLAGS.output_dir):
logger.info("Path {} does not exist. Creating.".format(FLAGS.output_dir))
os.makedirs(FLAGS.output_dir)
elif FLAGS.reset_output_dir:
logger.info("Path {} exists. Remove and remake.".format(FLAGS.output_dir))
shutil.rmtree(FLAGS.output_dir, ignore_errors=True)
os.makedirs(FLAGS.output_dir)
print_user_flags(FLAGS)
if FLAGS.fixed_seed:
set_random_seed(FLAGS.global_seed)
device = torch.device("cuda" if FLAGS.is_cuda else "cpu")
train(device, FLAGS.data_path, FLAGS.output_dir, FLAGS.child_num_layers)
if __name__ == "__main__":
main()
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
export PYTHONPATH="$(pwd)"
export CUDA_VISIBLE_DEVICES=0
python -u retrain.py \
--train_ratio=1.0 \
--valid_ratio=1.0 \
--min_count=1 \
--is_mask=True \
--is_binary=True \
--child_lr_decay_scheme="cosine" \
--data_path="data" \
--class_num=2 \
--child_optim_algo="adam" \
--output_dir="output_sst2" \
--global_seed=1234 \
--max_input_length=64 \
--batch_size=128 \
--eval_batch_size=128 \
--num_epochs=10 \
--log_every=50 \
--eval_every_epochs=1 \
--child_num_layers=24 \
--child_out_filters=256 \
--child_l2_reg=1e-6 \
--cnn_keep_prob=0.8 \
--final_output_keep_prob=1.0 \
--embed_keep_prob=0.8 \
--lstm_out_keep_prob=0.8 \
--attention_keep_prob=0.8 \
--child_lr=0.02 \
--child_lr_max=0.002 \
--child_lr_min=5e-6 \
--child_lr_T_0=10 \
--child_lr_T_mul=2 \
--multi_path=True \
--child_fixed_arc="./checkpoints/architecture_00.json" \
--fixed_seed=True \
"$@"
authorName: default
experimentName: example_mnist_pbt_tuner_pytorch
trialConcurrency: 3
maxExecDuration: 2h
maxTrialNum: 100
#choice: local, remote, pai
trainingServicePlatform: local
searchSpacePath: search_space.json
#choice: true, false
useAnnotation: false
tuner:
# codeDir: ~/nni/src/sdk/pynni/nni/pbt_tuner
# classFileName: pbt_tuner.py
# className: PBTTuner
builtinTunerName: PBTTuner
classArgs:
#choice: maximize, minimize
optimize_mode: maximize
trial:
command: python3 mnist.py
codeDir: .
gpuNum: 1
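# A typical way to launch this experiment, assuming NNI is installed and this
# file is saved as config.yml next to mnist.py and search_space.json:
#   nnictl create --config config.yml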
import argparse
import logging
import os
import nni
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
logger = logging.getLogger('mnist_pbt_tuner_pytorch_AutoML')
class Net(nn.Module):
def __init__(self, hidden_size):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 20, 5, 1)
self.conv2 = nn.Conv2d(20, 50, 5, 1)
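# Shape arithmetic for MNIST (28x28 input): conv1 (5x5) -> 24x24, pool/2 -> 12x12,
# conv2 (5x5) -> 8x8, pool/2 -> 4x4 with 50 channels, hence 4*4*50 features below.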
self.fc1 = nn.Linear(4*4*50, hidden_size)
self.fc2 = nn.Linear(hidden_size, 10)
def forward(self, x):
x = F.relu(self.conv1(x))
x = F.max_pool2d(x, 2, 2)
x = F.relu(self.conv2(x))
x = F.max_pool2d(x, 2, 2)
x = x.view(-1, 4*4*50)
x = F.relu(self.fc1(x))
x = self.fc2(x)
return F.log_softmax(x, dim=1)
def train(args, model, device, train_loader, optimizer, epoch):
model.train()
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
loss = F.nll_loss(output, target)
loss.backward()
optimizer.step()
if batch_idx % args['log_interval'] == 0:
logger.info('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
epoch, batch_idx * len(data), len(train_loader.dataset),
100. * batch_idx / len(train_loader), loss.item()))
def test(args, model, device, test_loader):
model.eval()
test_loss = 0
correct = 0
with torch.no_grad():
for data, target in test_loader:
data, target = data.to(device), target.to(device)
output = model(data)
# sum up batch loss
test_loss += F.nll_loss(output, target, reduction='sum').item()
# get the index of the max log-probability
pred = output.argmax(dim=1, keepdim=True)
correct += pred.eq(target.view_as(pred)).sum().item()
test_loss /= len(test_loader.dataset)
accuracy = 100. * correct / len(test_loader.dataset)
logger.info('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
test_loss, correct, len(test_loader.dataset), accuracy))
return accuracy
def save_checkpoint(model, checkpoint_path):
torch.save(model.state_dict(), checkpoint_path)
def load_checkpoint(checkpoint_path):
model_state_dict = torch.load(checkpoint_path)
return model_state_dict
def main(args):
use_cuda = not args['no_cuda'] and torch.cuda.is_available()
torch.manual_seed(args['seed'])
device = torch.device("cuda" if use_cuda else "cpu")
kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
data_dir = os.path.join(args['data_dir'], nni.get_trial_id())
train_loader = torch.utils.data.DataLoader(
datasets.MNIST(data_dir, train=True, download=True,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])),
batch_size=args['batch_size'], shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(
datasets.MNIST(data_dir, train=False, transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])),
batch_size=1000, shuffle=True, **kwargs)
hidden_size = args['hidden_size']
model = Net(hidden_size=hidden_size).to(device)
save_checkpoint_dir = args['save_checkpoint_dir']
save_checkpoint_path = os.path.join(save_checkpoint_dir, 'model.pth')
load_checkpoint_path = os.path.join(args['load_checkpoint_dir'], 'model.pth')
if os.path.isfile(load_checkpoint_path):
model_state_dict = load_checkpoint(load_checkpoint_path)
logger.info("test : " + load_checkpoint_path)
logger.info(type(model_state_dict))
model.load_state_dict(model_state_dict)
optimizer = optim.SGD(model.parameters(), lr=args['lr'],
momentum=args['momentum'])
# each epoch serves as one PBT perturbation interval
for epoch in range(1, args['epochs'] + 1):
train(args, model, device, train_loader, optimizer, epoch)
test_acc = test(args, model, device, test_loader)
if epoch < args['epochs']:
# report intermediate result
nni.report_intermediate_result(test_acc)
logger.debug('test accuracy %g', test_acc)
logger.debug('Pipe send intermediate result done.')
else:
# report final result
nni.report_final_result(test_acc)
logger.debug('Final result is %g', test_acc)
logger.debug('Send final result done.')
if not os.path.exists(save_checkpoint_dir):
os.makedirs(save_checkpoint_dir)
save_checkpoint(model, save_checkpoint_path)
def get_params():
# Training settings
parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
parser.add_argument("--data_dir", type=str,
default='./tmp/pytorch/mnist/input_data', help="data directory")
parser.add_argument('--batch_size', type=int, default=64, metavar='N',
help='input batch size for training (default: 64)')
parser.add_argument("--hidden_size", type=int, default=512, metavar='N',
help='hidden layer size (default: 512)')
parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
help='learning rate (default: 0.01)')
parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
help='SGD momentum (default: 0.5)')
parser.add_argument('--epochs', type=int, default=10, metavar='N',
help='number of epochs to train (default: 10)')
parser.add_argument('--seed', type=int, default=1, metavar='S',
help='random seed (default: 1)')
parser.add_argument('--no_cuda', action='store_true', default=False,
help='disables CUDA training')
parser.add_argument('--log_interval', type=int, default=1000, metavar='N',
help='how many batches to wait before logging training status')
parser.add_argument('--save_checkpoint_dir', type=str,
help='where to save checkpoint of this trial')
parser.add_argument('--load_checkpoint_dir', type=str,
help='where to load the model')
args, _ = parser.parse_known_args()
return args
if __name__ == '__main__':
try:
# get parameters from tuner
tuner_params = nni.get_next_parameter()
logger.debug(tuner_params)
params = vars(get_params())
params.update(tuner_params)
main(params)
except Exception as exception:
logger.exception(exception)
raise
\ No newline at end of file
{
"batch_size": {"_type":"choice", "_value": [16, 32, 64, 128]},
"hidden_size":{"_type":"choice","_value":[128, 256, 512, 1024]},
"lr":{"_type":"choice","_value":[0.0001, 0.001, 0.01, 0.1]},
"momentum":{"_type":"uniform","_value":[0, 1]}
}
@@ -15,14 +15,6 @@ import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
# Temporary patch this example until the MNIST dataset download issue get resolved
# https://github.com/pytorch/vision/issues/1938
import urllib
opener = urllib.request.build_opener()
opener.addheaders = [('User-agent', 'Mozilla/5.0')]
urllib.request.install_opener(opener)
logger = logging.getLogger('mnist_AutoML')
@@ -48,6 +40,8 @@ class Net(nn.Module):
def train(args, model, device, train_loader, optimizer, epoch):
model.train()
for batch_idx, (data, target) in enumerate(train_loader):
if (args['batch_num'] is not None) and batch_idx >= args['batch_num']:
break
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
@@ -119,12 +113,11 @@ def main(args):
train(args, model, device, train_loader, optimizer, epoch)
test_acc = test(args, model, device, test_loader)
if epoch < args['epochs']:
# report intermediate result
nni.report_intermediate_result(test_acc)
logger.debug('test accuracy %g', test_acc)
logger.debug('Pipe send intermediate result done.')
else:
# report final result
nni.report_final_result(test_acc)
logger.debug('Final result is %g', test_acc)
@@ -138,6 +131,7 @@ def get_params():
default='/tmp/pytorch/mnist/input_data', help="data directory")
parser.add_argument('--batch_size', type=int, default=64, metavar='N',
help='input batch size for training (default: 64)')
parser.add_argument("--batch_num", type=int, default=None)
parser.add_argument("--hidden_size", type=int, default=512, metavar='N',
help='hidden layer size (default: 512)')
parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
@@ -165,6 +159,7 @@ if __name__ == '__main__':
logger.debug(tuner_params)
params = vars(get_params())
params.update(tuner_params)
print(params)
main(params)
except Exception as exception:
logger.exception(exception)
......
@@ -1784,11 +1784,6 @@ abab@^2.0.0:
resolved "https://registry.yarnpkg.com/abab/-/abab-2.0.3.tgz#623e2075e02eb2d3f2475e49f99c91846467907a"
integrity sha512-tsFzPpcttalNjFBCFMqsKYQcWxxen1pgJR56by//QwvJc4/OUS3kPOOttx2tSIfjsylB0pYu7f5D3K1RCxUnUg==
abbrev@1:
version "1.1.1"
resolved "https://registry.yarnpkg.com/abbrev/-/abbrev-1.1.1.tgz#f8f2c887ad10bf67f634f005b6987fed3179aac8"
integrity sha512-nne9/IiQ/hzIhY6pdDnbBtz7DjPTKrY00P/zvPSm5pOFkl6xuGrGnXn/VtTNNfNtAfZ9/1RtehkszU9qcTii0Q==
accepts@~1.3.4, accepts@~1.3.5, accepts@~1.3.7:
version "1.3.7"
resolved "https://registry.yarnpkg.com/accepts/-/accepts-1.3.7.tgz#531bc726517a3b2b41f850021c6cc15eaab507cd"
@@ -1949,19 +1944,11 @@ anymatch@~3.1.1:
normalize-path "^3.0.0"
picomatch "^2.0.4"
aproba@^1.0.3, aproba@^1.1.1:
aproba@^1.1.1:
version "1.2.0"
resolved "https://registry.yarnpkg.com/aproba/-/aproba-1.2.0.tgz#6802e6264efd18c790a1b0d517f0f2627bf2c94a"
integrity sha512-Y9J6ZjXtoYh8RnXVCMOU/ttDmk1aBjunq9vO0ta5x85WDQiQfUF9sIPBITdbiiIVcBo03Hi3jMxigBtsddlXRw==
are-we-there-yet@~1.1.2:
version "1.1.5"
resolved "https://registry.yarnpkg.com/are-we-there-yet/-/are-we-there-yet-1.1.5.tgz#4b35c2944f062a8bfcda66410760350fe9ddfc21"
integrity sha512-5hYdAkZlcG8tOLujVDTgCT+uPX0VnpAH28gWsLfzpXYm7wP6mp5Q/gYyR7YQ0cKVJcXJnl3j2kpBan13PtQf6w==
dependencies:
delegates "^1.0.0"
readable-stream "^2.0.6"
argparse@^1.0.7:
version "1.0.10"
resolved "https://registry.yarnpkg.com/argparse/-/argparse-1.0.10.tgz#bcd6791ea5ae09725e17e5ad988134cd40b3d911"
@@ -3029,11 +3016,6 @@ console-browserify@^1.1.0:
resolved "https://registry.yarnpkg.com/console-browserify/-/console-browserify-1.2.0.tgz#67063cef57ceb6cf4993a2ab3a55840ae8c49336"
integrity sha512-ZMkYO/LkF17QvCPqM0gxw8yUzigAOZOSWSHg91FH6orS7vcEj5dVZTidN2fQ14yBSdg97RqhSNwLUXInd52OTA==
console-control-strings@^1.0.0, console-control-strings@~1.1.0:
version "1.1.0"
resolved "https://registry.yarnpkg.com/console-control-strings/-/console-control-strings-1.1.0.tgz#3d7cf4464db6446ea644bf4b39507f9851008e8e"
integrity sha1-PXz0Rk22RG6mRL9LOVB/mFEAjo4=
constants-browserify@^1.0.0:
version "1.0.0"
resolved "https://registry.yarnpkg.com/constants-browserify/-/constants-browserify-1.0.0.tgz#c20b96d8c617748aaf1c16021760cd27fcb8cb75"
@@ -3504,7 +3486,7 @@ debug@2.6.9, debug@^2.2.0, debug@^2.3.3, debug@^2.6.0, debug@^2.6.9:
dependencies:
ms "2.0.0"
debug@^3.0.0, debug@^3.1.1, debug@^3.2.5, debug@^3.2.6:
debug@^3.0.0, debug@^3.1.1, debug@^3.2.5:
version "3.2.6"
resolved "https://registry.yarnpkg.com/debug/-/debug-3.2.6.tgz#e83d17de16d8a7efb7717edbe5fb10135eee629b"
integrity sha512-mel+jf7nrtEl5Pn1Qx46zARXKDpBbvzezse7p7LqINmdoIk8PYP5SySaxEmYv6TZ0JyEKA1hsCId6DIhgITtWQ==
@@ -3540,11 +3522,6 @@ deep-equal@^1.0.1:
object-keys "^1.1.1"
regexp.prototype.flags "^1.2.0"
deep-extend@^0.6.0:
version "0.6.0"
resolved "https://registry.yarnpkg.com/deep-extend/-/deep-extend-0.6.0.tgz#c4fa7c95404a17a9c3e8ca7e1537312b736330ac"
integrity sha512-LOHxIOaPYdHlJRtCQfDIVZtfw/ufM8+rVj649RIHzcm/vGwQRXFt6OPqIFWsm2XEMrNIEtWR64sY1LEKD2vAOA==
deep-is@~0.1.3:
version "0.1.3"
resolved "https://registry.yarnpkg.com/deep-is/-/deep-is-0.1.3.tgz#b369d6fb5dbc13eecf524f91b070feedc357cf34"
......@@ -3605,11 +3582,6 @@ delayed-stream@~1.0.0:
resolved "https://registry.yarnpkg.com/delayed-stream/-/delayed-stream-1.0.0.tgz#df3ae199acadfb7d440aaae0b29e2272b24ec619"
integrity sha1-3zrhmayt+31ECqrgsp4icrJOxhk=
delegates@^1.0.0:
version "1.0.0"
resolved "https://registry.yarnpkg.com/delegates/-/delegates-1.0.0.tgz#84c6e159b81904fdca59a0ef44cd870d31250f9a"
integrity sha1-hMbhWbgZBP3KWaDvRM2HDTElD5o=
depd@~1.1.2:
version "1.1.2"
resolved "https://registry.yarnpkg.com/depd/-/depd-1.1.2.tgz#9bcd52e14c097763e749b274c4346ed2e560b5a9"
@@ -3628,11 +3600,6 @@ destroy@~1.0.4:
resolved "https://registry.yarnpkg.com/destroy/-/destroy-1.0.4.tgz#978857442c44749e4206613e37946205826abd80"
integrity sha1-l4hXRCxEdJ5CBmE+N5RiBYJqvYA=
detect-libc@^1.0.2:
version "1.0.3"
resolved "https://registry.yarnpkg.com/detect-libc/-/detect-libc-1.0.3.tgz#fa137c4bd698edf55cd5cd02ac559f91a4c4ba9b"
integrity sha1-+hN8S9aY7fVc1c0CrFWfkaTEups=
detect-newline@^2.1.0:
version "2.1.0"
resolved "https://registry.yarnpkg.com/detect-newline/-/detect-newline-2.1.0.tgz#f41f1c10be4b00e87b5f13da680759f2c5bfd3e2"
@@ -4678,13 +4645,6 @@ fs-extra@^8.1.0:
jsonfile "^4.0.0"
universalify "^0.1.0"
fs-minipass@^1.2.5:
version "1.2.7"
resolved "https://registry.yarnpkg.com/fs-minipass/-/fs-minipass-1.2.7.tgz#ccff8570841e7fe4265693da88936c55aed7f7c7"
integrity sha512-GWSSJGFy4e9GUeCcbIkED+bgAoFyj7XF1mV8rma3QW4NIqX9Kyx79N/PF61H5udOV3aY1IaMLs6pGbH71nlCTA==
dependencies:
minipass "^2.6.0"
fs-minipass@^2.0.0:
version "2.1.0"
resolved "https://registry.yarnpkg.com/fs-minipass/-/fs-minipass-2.1.0.tgz#7f5036fdbf12c63c169190cbe4199c852271f9fb"
@@ -4730,20 +4690,6 @@ functional-red-black-tree@^1.0.1:
resolved "https://registry.yarnpkg.com/functional-red-black-tree/-/functional-red-black-tree-1.0.1.tgz#1b0ab3bd553b2a0d6399d29c0e3ea0b252078327"
integrity sha1-GwqzvVU7Kg1jmdKcDj6gslIHgyc=
gauge@~2.7.3:
version "2.7.4"
resolved "https://registry.yarnpkg.com/gauge/-/gauge-2.7.4.tgz#2c03405c7538c39d7eb37b317022e325fb018bf7"
integrity sha1-LANAXHU4w51+s3sxcCLjJfsBi/c=
dependencies:
aproba "^1.0.3"
console-control-strings "^1.0.0"
has-unicode "^2.0.0"
object-assign "^4.1.0"
signal-exit "^3.0.0"
string-width "^1.0.1"
strip-ansi "^3.0.1"
wide-align "^1.1.0"
gensync@^1.0.0-beta.1:
version "1.0.0-beta.1"
resolved "https://registry.yarnpkg.com/gensync/-/gensync-1.0.0-beta.1.tgz#58f4361ff987e5ff6e1e7a210827aa371eaac269"
@@ -4937,11 +4883,6 @@ has-symbols@^1.0.0, has-symbols@^1.0.1:
resolved "https://registry.yarnpkg.com/has-symbols/-/has-symbols-1.0.1.tgz#9f5214758a44196c406d9bd76cebf81ec2dd31e8"
integrity sha512-PLcsoqu++dmEIZB+6totNFKq/7Do+Z0u4oT0zKOJNl3lYK6vGwwu2hjHs+68OEZbTjiUE9bgOABXbP/GvrS0Kg==
has-unicode@^2.0.0:
version "2.0.1"
resolved "https://registry.yarnpkg.com/has-unicode/-/has-unicode-2.0.1.tgz#e0e6fe6a28cf51138855e086d1691e771de2a8b9"
integrity sha1-4Ob+aijPUROIVeCG0Wkedx3iqLk=
has-value@^0.3.1:
version "0.3.1"
resolved "https://registry.yarnpkg.com/has-value/-/has-value-0.3.1.tgz#7b1f58bada62ca827ec0a2078025654845995e1f"
@@ -5191,7 +5132,7 @@ hyphenate-style-name@^1.0.3:
resolved "https://registry.yarnpkg.com/hyphenate-style-name/-/hyphenate-style-name-1.0.3.tgz#097bb7fa0b8f1a9cf0bd5c734cf95899981a9b48"
integrity sha512-EcuixamT82oplpoJ2XU4pDtKGWQ7b00CD9f1ug9IaQ3p1bkHMiKCZ9ut9QDI6qsa6cpUuB+A/I+zLtdNK4n2DQ==
iconv-lite@0.4.24, iconv-lite@^0.4.24, iconv-lite@^0.4.4:
iconv-lite@0.4.24, iconv-lite@^0.4.24:
version "0.4.24"
resolved "https://registry.yarnpkg.com/iconv-lite/-/iconv-lite-0.4.24.tgz#2022b4b25fbddc21d2f524974a474aafe733908b"
integrity sha512-v3MXnZAcvnywkTUEZomIActle7RXXeedOR31wwl7VlyoXO4Qi9arvSenNQWne1TcRwhCL1HwLI21bEqdpj8/rA==
@@ -5222,13 +5163,6 @@ iferr@^0.1.5:
resolved "https://registry.yarnpkg.com/iferr/-/iferr-0.1.5.tgz#c60eed69e6d8fdb6b3104a1fcbca1c192dc5b501"
integrity sha1-xg7taebY/bazEEofy8ocGS3FtQE=
ignore-walk@^3.0.1:
version "3.0.3"
resolved "https://registry.yarnpkg.com/ignore-walk/-/ignore-walk-3.0.3.tgz#017e2447184bfeade7c238e4aefdd1e8f95b1e37"
integrity sha512-m7o6xuOaT1aqheYHKf8W6J5pYH85ZI9w077erOzLje3JsB1gkafkAhHHY19dqjulgIZHFm32Cp5uNZgcQqdJKw==
dependencies:
minimatch "^3.0.4"
ignore@^3.3.5:
version "3.3.10"
resolved "https://registry.yarnpkg.com/ignore/-/ignore-3.3.10.tgz#0a97fb876986e8081c631160f8f9f389157f0043"
@@ -5325,7 +5259,7 @@ inherits@2.0.3:
resolved "https://registry.yarnpkg.com/inherits/-/inherits-2.0.3.tgz#633c2c83e3da42a502f52466022480f4208261de"
integrity sha1-Yzwsg+PaQqUC9SRmAiSA9CCCYd4=
ini@^1.3.5, ini@~1.3.0:
ini@^1.3.5:
version "1.3.5"
resolved "https://registry.yarnpkg.com/ini/-/ini-1.3.5.tgz#eee25f56db1c9ec6085e0c22778083f596abf927"
integrity sha512-RZY5huIKCMRWDUqZlEi72f/lmXKMvuszcMBduliQ3nnWbx9X/ZBQO7DijMEYS9EhHBb2qacRUMtC7svLwe0lcw==
@@ -6888,14 +6822,6 @@ minipass-pipeline@^1.2.2:
dependencies:
minipass "^3.0.0"
minipass@^2.6.0, minipass@^2.8.6, minipass@^2.9.0:
version "2.9.0"
resolved "https://registry.yarnpkg.com/minipass/-/minipass-2.9.0.tgz#e713762e7d3e32fed803115cf93e04bca9fcc9a6"
integrity sha512-wxfUjg9WebH+CUDX/CdbRlh5SmfZiy/hpkxaRI16Y9W56Pa75sWgd/rvFilSgrauD9NyFymP/+JFV3KwzIsJeg==
dependencies:
safe-buffer "^5.1.2"
yallist "^3.0.0"
minipass@^3.0.0, minipass@^3.1.1:
version "3.1.1"
resolved "https://registry.yarnpkg.com/minipass/-/minipass-3.1.1.tgz#7607ce778472a185ad6d89082aa2070f79cedcd5"
@@ -6903,13 +6829,6 @@ minipass@^3.0.0, minipass@^3.1.1:
dependencies:
yallist "^4.0.0"
minizlib@^1.2.1:
version "1.3.3"
resolved "https://registry.yarnpkg.com/minizlib/-/minizlib-1.3.3.tgz#2290de96818a34c29551c8a8d301216bd65a861d"
integrity sha512-6ZYMOEnmVsdCeTJVE0W9ZD+pVnE8h9Hma/iOwwRDsdQoePpoX56/8B6z3P9VNwppJuBKNRuFDRNRqRWexT9G9Q==
dependencies:
minipass "^2.9.0"
mississippi@^3.0.0:
version "3.0.0"
resolved "https://registry.yarnpkg.com/mississippi/-/mississippi-3.0.0.tgz#ea0a3291f97e0b5e8776b363d5f0a12d94c67022"
@@ -6942,7 +6861,7 @@ mixin-object@^2.0.1:
for-in "^0.1.3"
is-extendable "^0.1.1"
mkdirp@0.5.1, mkdirp@^0.5.0, mkdirp@^0.5.1, mkdirp@~0.5.1:
mkdirp@0.5.1, mkdirp@^0.5.1, mkdirp@~0.5.1:
version "0.5.1"
resolved "https://registry.yarnpkg.com/mkdirp/-/mkdirp-0.5.1.tgz#30057438eac6cf7f8c4767f38648d6697d75c903"
integrity sha1-MAV0OOrGz3+MR2fzhkjWaX11yQM=
@@ -7021,15 +6940,6 @@ natural-compare@^1.4.0:
resolved "https://registry.yarnpkg.com/natural-compare/-/natural-compare-1.4.0.tgz#4abebfeed7541f2c27acfb29bdbbd15c8d5ba4f7"
integrity sha1-Sr6/7tdUHywnrPspvbvRXI1bpPc=
needle@^2.2.1:
version "2.3.2"
resolved "https://registry.yarnpkg.com/needle/-/needle-2.3.2.tgz#3342dea100b7160960a450dc8c22160ac712a528"
integrity sha512-DUzITvPVDUy6vczKKYTnWc/pBZ0EnjMJnQ3y+Jo5zfKFimJs7S3HFCxCRZYB9FUZcrzUQr3WsmvZgddMEIZv6w==
dependencies:
debug "^3.2.6"
iconv-lite "^0.4.4"
sax "^1.2.4"
negotiator@0.6.2:
version "0.6.2"
resolved "https://registry.yarnpkg.com/negotiator/-/negotiator-0.6.2.tgz#feacf7ccf525a77ae9634436a64883ffeca346fb"
@@ -7113,22 +7023,6 @@ node-notifier@^5.4.2:
shellwords "^0.1.1"
which "^1.3.0"
node-pre-gyp@*:
version "0.14.0"
resolved "https://registry.yarnpkg.com/node-pre-gyp/-/node-pre-gyp-0.14.0.tgz#9a0596533b877289bcad4e143982ca3d904ddc83"
integrity sha512-+CvDC7ZttU/sSt9rFjix/P05iS43qHCOOGzcr3Ry99bXG7VX953+vFyEuph/tfqoYu8dttBkE86JSKBO2OzcxA==
dependencies:
detect-libc "^1.0.2"
mkdirp "^0.5.1"
needle "^2.2.1"
nopt "^4.0.1"
npm-packlist "^1.1.6"
npmlog "^4.0.2"
rc "^1.2.7"
rimraf "^2.6.1"
semver "^5.3.0"
tar "^4.4.2"
node-releases@^1.1.47, node-releases@^1.1.50:
version "1.1.50"
resolved "https://registry.yarnpkg.com/node-releases/-/node-releases-1.1.50.tgz#803c40d2c45db172d0410e4efec83aa8c6ad0592"
@@ -7136,14 +7030,6 @@ node-releases@^1.1.47, node-releases@^1.1.50:
dependencies:
semver "^6.3.0"
nopt@^4.0.1:
version "4.0.1"
resolved "https://registry.yarnpkg.com/nopt/-/nopt-4.0.1.tgz#d0d4685afd5415193c8c7505602d0d17cd64474d"
integrity sha1-0NRoWv1UFRk8jHUFYC0NF81kR00=
dependencies:
abbrev "1"
osenv "^0.1.4"
normalize-package-data@^2.3.2:
version "2.5.0"
resolved "https://registry.yarnpkg.com/normalize-package-data/-/normalize-package-data-2.5.0.tgz#e66db1838b200c1dfc233225d12cb36520e234a8"
@@ -7186,27 +7072,6 @@ normalize-url@^3.0.0:
resolved "https://registry.yarnpkg.com/normalize-url/-/normalize-url-3.3.0.tgz#b2e1c4dc4f7c6d57743df733a4f5978d18650559"
integrity sha512-U+JJi7duF1o+u2pynbp2zXDW2/PADgC30f0GsHZtRh+HOcXHnw137TrNlyxxRvWW5fjKd3bcLHPxofWuCjaeZg==
npm-bundled@^1.0.1:
version "1.1.1"
resolved "https://registry.yarnpkg.com/npm-bundled/-/npm-bundled-1.1.1.tgz#1edd570865a94cdb1bc8220775e29466c9fb234b"
integrity sha512-gqkfgGePhTpAEgUsGEgcq1rqPXA+tv/aVBlgEzfXwA1yiUJF7xtEt3CtVwOjNYQOVknDk0F20w58Fnm3EtG0fA==
dependencies:
npm-normalize-package-bin "^1.0.1"
npm-normalize-package-bin@^1.0.1:
version "1.0.1"
resolved "https://registry.yarnpkg.com/npm-normalize-package-bin/-/npm-normalize-package-bin-1.0.1.tgz#6e79a41f23fd235c0623218228da7d9c23b8f6e2"
integrity sha512-EPfafl6JL5/rU+ot6P3gRSCpPDW5VmIzX959Ob1+ySFUuuYHWHekXpwdUZcKP5C+DS4GEtdJluwBjnsNDl+fSA==
npm-packlist@^1.1.6:
version "1.4.8"
resolved "https://registry.yarnpkg.com/npm-packlist/-/npm-packlist-1.4.8.tgz#56ee6cc135b9f98ad3d51c1c95da22bbb9b2ef3e"
integrity sha512-5+AZgwru5IevF5ZdnFglB5wNlHG1AOOuw28WhUq8/8emhBmLv6jX5by4WJCh7lW0uSYZYS6DXqIsyZVIXRZU9A==
dependencies:
ignore-walk "^3.0.1"
npm-bundled "^1.0.1"
npm-normalize-package-bin "^1.0.1"
npm-run-path@^2.0.0:
version "2.0.2"
resolved "https://registry.yarnpkg.com/npm-run-path/-/npm-run-path-2.0.2.tgz#35a9232dfa35d7067b4cb2ddf2357b1871536c5f"
@@ -7214,16 +7079,6 @@ npm-run-path@^2.0.0:
dependencies:
path-key "^2.0.0"
npmlog@^4.0.2:
version "4.1.2"
resolved "https://registry.yarnpkg.com/npmlog/-/npmlog-4.1.2.tgz#08a7f2a8bf734604779a9efa4ad5cc717abb954b"
integrity sha512-2uUqazuKlTaSI/dC8AzicUck7+IrEaOnN/e0jd3Xtt1KcGpwx30v50mL7oPyr/h9bL3E4aZccVwpwP+5W9Vjkg==
dependencies:
are-we-there-yet "~1.1.2"
console-control-strings "~1.1.0"
gauge "~2.7.3"
set-blocking "~2.0.0"
nth-check@^1.0.2, nth-check@~1.0.1:
version "1.0.2"
resolved "https://registry.yarnpkg.com/nth-check/-/nth-check-1.0.2.tgz#b2bd295c37e3dd58a3bf0700376663ba4d9cf05c"
@@ -7430,11 +7285,6 @@ os-browserify@^0.3.0:
resolved "https://registry.yarnpkg.com/os-browserify/-/os-browserify-0.3.0.tgz#854373c7f5c2315914fc9bfc6bd8238fdda1ec27"
integrity sha1-hUNzx/XCMVkU/Jv8a9gjj92h7Cc=
os-homedir@^1.0.0:
version "1.0.2"
resolved "https://registry.yarnpkg.com/os-homedir/-/os-homedir-1.0.2.tgz#ffbc4988336e0e833de0c168c7ef152121aa7fb3"
integrity sha1-/7xJiDNuDoM94MFox+8VISGqf7M=
os-locale@^3.0.0:
version "3.1.0"
resolved "https://registry.yarnpkg.com/os-locale/-/os-locale-3.1.0.tgz#a802a6ee17f24c10483ab9935719cef4ed16bf1a"
@@ -7444,19 +7294,11 @@ os-locale@^3.0.0:
lcid "^2.0.0"
mem "^4.0.0"
os-tmpdir@^1.0.0, os-tmpdir@~1.0.2:
os-tmpdir@~1.0.2:
version "1.0.2"
resolved "https://registry.yarnpkg.com/os-tmpdir/-/os-tmpdir-1.0.2.tgz#bbe67406c79aa85c5cfec766fe5734555dfa1274"
integrity sha1-u+Z0BseaqFxc/sdm/lc0VV36EnQ=
osenv@^0.1.4:
version "0.1.5"
resolved "https://registry.yarnpkg.com/osenv/-/osenv-0.1.5.tgz#85cdfafaeb28e8677f416e287592b5f3f49ea410"
integrity sha512-0CWcCECdMVc2Rw3U5w9ZjqX6ga6ubk1xDVKxtBQPK7wis/0F2r9T6k4ydGYhecl7YUBxBVxhL5oisPsNxAPe2g==
dependencies:
os-homedir "^1.0.0"
os-tmpdir "^1.0.0"
p-defer@^1.0.0:
version "1.0.0"
resolved "https://registry.yarnpkg.com/p-defer/-/p-defer-1.0.0.tgz#9f6eb182f6c9aa8cd743004a7d4f96b196b0fb0c"
@@ -8727,16 +8569,6 @@ raw-body@2.4.0:
iconv-lite "0.4.24"
unpipe "1.0.0"
rc@^1.2.7:
version "1.2.8"
resolved "https://registry.yarnpkg.com/rc/-/rc-1.2.8.tgz#cd924bf5200a075b83c188cd6b9e211b7fc0d3ed"
integrity sha512-y3bGgqKj3QBdxLbLkomlohkvsA8gdAiUQlSBJnBhfn+BPxg4bc62d8TcBW15wavDfgexCgccckhcZvywyQYPOw==
dependencies:
deep-extend "^0.6.0"
ini "~1.3.0"
minimist "^1.2.0"
strip-json-comments "~2.0.1"
react-app-polyfill@^1.0.6:
version "1.0.6"
resolved "https://registry.yarnpkg.com/react-app-polyfill/-/react-app-polyfill-1.0.6.tgz#890f8d7f2842ce6073f030b117de9130a5f385f0"
@@ -8912,7 +8744,7 @@ read-pkg@^3.0.0:
normalize-package-data "^2.3.2"
path-type "^3.0.0"
"readable-stream@1 || 2", readable-stream@^2.0.0, readable-stream@^2.0.1, readable-stream@^2.0.2, readable-stream@^2.0.6, readable-stream@^2.1.5, readable-stream@^2.2.2, readable-stream@^2.3.3, readable-stream@^2.3.6, readable-stream@~2.3.6:
"readable-stream@1 || 2", readable-stream@^2.0.0, readable-stream@^2.0.1, readable-stream@^2.0.2, readable-stream@^2.1.5, readable-stream@^2.2.2, readable-stream@^2.3.3, readable-stream@^2.3.6, readable-stream@~2.3.6:
version "2.3.7"
resolved "https://registry.yarnpkg.com/readable-stream/-/readable-stream-2.3.7.tgz#1eca1cf711aef814c04f62252a36a62f6cb23b57"
integrity sha512-Ebho8K4jIbHAxnuxi7o42OrZgF/ZTNcsZj6nRKyUmkhLFq8CHItp/fy6hQZuZmP/n3yZ9VBUbp4zz/mX8hmYPw==
@@ -9254,7 +9086,7 @@ rimraf@2.6.3:
dependencies:
glob "^7.1.3"
rimraf@^2.5.4, rimraf@^2.6.1, rimraf@^2.6.3, rimraf@^2.7.1:
rimraf@^2.5.4, rimraf@^2.6.3, rimraf@^2.7.1:
version "2.7.1"
resolved "https://registry.yarnpkg.com/rimraf/-/rimraf-2.7.1.tgz#35797f13a7fdadc566142c29d4f07ccad483e3ec"
integrity sha512-uWjbaKIK3T1OSVptzX7Nl6PvQ3qAGtKEtVRjRuazjfL3Bx5eI409VZSqgND+4UNnmzLVdPj9FqFJNPqBZFve4w==
@@ -9397,7 +9229,7 @@ selfsigned@^1.10.7:
dependencies:
node-forge "0.9.0"
"semver@2 || 3 || 4 || 5", semver@^5.3.0, semver@^5.4.1, semver@^5.5.0, semver@^5.5.1, semver@^5.6.0:
"semver@2 || 3 || 4 || 5", semver@^5.4.1, semver@^5.5.0, semver@^5.5.1, semver@^5.6.0:
version "5.7.1"
resolved "https://registry.yarnpkg.com/semver/-/semver-5.7.1.tgz#a954f931aeba508d307bbf069eff0c01c96116f7"
integrity sha512-sauaDf/PZdVgrLTNYHRtpXa1iRiKcaebiKQ1BJdpQlWH2lCvexQdX55snPFyK7QzpudqbCI0qXFfOasHdyNDGQ==
@@ -9459,7 +9291,7 @@ serve-static@1.14.1:
parseurl "~1.3.3"
send "0.17.1"
set-blocking@^2.0.0, set-blocking@~2.0.0:
set-blocking@^2.0.0:
version "2.0.0"
resolved "https://registry.yarnpkg.com/set-blocking/-/set-blocking-2.0.0.tgz#045f9782d011ae9a6803ddd382b24392b3d890f7"
integrity sha1-BF+XgtARrppoA93TgrJDkrPYkPc=
@@ -9866,7 +9698,7 @@ string-width@^1.0.1:
is-fullwidth-code-point "^1.0.0"
strip-ansi "^3.0.0"
"string-width@^1.0.2 || 2", string-width@^2.0.0, string-width@^2.1.1:
string-width@^2.0.0, string-width@^2.1.1:
version "2.1.1"
resolved "https://registry.yarnpkg.com/string-width/-/string-width-2.1.1.tgz#ab93f27a8dc13d28cac815c462143a6d9012ae9e"
integrity sha512-nOqH59deCq9SRHlxq1Aw85Jnt4w6KvLKqWVik6oA9ZklXLNIOlqg4F2yrT1MVaTjAqvVwdfeZ7w7aCvJD7ugkw==
@@ -9989,11 +9821,6 @@ strip-json-comments@^3.0.1:
resolved "https://registry.yarnpkg.com/strip-json-comments/-/strip-json-comments-3.0.1.tgz#85713975a91fb87bf1b305cca77395e40d2a64a7"
integrity sha512-VTyMAUfdm047mwKl+u79WIdrZxtFtn+nBxHeb844XBQ9uMNTuTHdx2hc5RiAJYqwTj3wc/xe5HLSdJSkJ+WfZw==
strip-json-comments@~2.0.1:
version "2.0.1"
resolved "https://registry.yarnpkg.com/strip-json-comments/-/strip-json-comments-2.0.1.tgz#3c531942e908c2697c0ec344858c286c7ca0a60a"
integrity sha1-PFMZQukIwml8DsNEhYwobHygpgo=
style-loader@0.23.1:
version "0.23.1"
resolved "https://registry.yarnpkg.com/style-loader/-/style-loader-0.23.1.tgz#cb9154606f3e771ab6c4ab637026a1049174d925"
@@ -10081,19 +9908,6 @@ tapable@^1.0.0, tapable@^1.1.3:
resolved "https://registry.yarnpkg.com/tapable/-/tapable-1.1.3.tgz#a1fccc06b58db61fd7a45da2da44f5f3a3e67ba2"
integrity sha512-4WK/bYZmj8xLr+HUCODHGF1ZFzsYffasLUgEiMBY4fgtltdO6B4WJtlSbPaDTLpYTcGVwM2qLnFTICEcNxs3kA==
tar@^4.4.2:
version "4.4.13"
resolved "https://registry.yarnpkg.com/tar/-/tar-4.4.13.tgz#43b364bc52888d555298637b10d60790254ab525"
integrity sha512-w2VwSrBoHa5BsSyH+KxEqeQBAllHhccyMFVHtGtdMpF4W7IRWfZjFiQceJPChOeTsSDVUpER2T8FA93pr0L+QA==
dependencies:
chownr "^1.1.1"
fs-minipass "^1.2.5"
minipass "^2.8.6"
minizlib "^1.2.1"
mkdirp "^0.5.0"
safe-buffer "^5.1.2"
yallist "^3.0.3"
terser-webpack-plugin@2.3.4:
version "2.3.4"
resolved "https://registry.yarnpkg.com/terser-webpack-plugin/-/terser-webpack-plugin-2.3.4.tgz#ac045703bd8da0936ce910d8fb6350d0e1dee5fe"
@@ -10784,13 +10598,6 @@ which@^2.0.1:
dependencies:
isexe "^2.0.0"
wide-align@^1.1.0:
version "1.1.3"
resolved "https://registry.yarnpkg.com/wide-align/-/wide-align-1.1.3.tgz#ae074e6bdc0c14a431e804e624549c633b000457"
integrity sha512-QGkOQc8XL6Bt5PwnsExKBPuMKBxnGxWWW3fU55Xt4feHozMUhdUMaBCk290qpm/wG5u/RSKzwdAC4i51YigihA==
dependencies:
string-width "^1.0.2 || 2"
word-wrap@~1.2.3:
version "1.2.3"
resolved "https://registry.yarnpkg.com/word-wrap/-/word-wrap-1.2.3.tgz#610636f6b1f703891bd34771ccb17fb93b47079c"
@@ -11017,7 +10824,7 @@ xtend@^4.0.0, xtend@~4.0.1:
resolved "https://registry.yarnpkg.com/y18n/-/y18n-4.0.0.tgz#95ef94f85ecc81d007c264e190a120f0a3c8566b"
integrity sha512-r9S/ZyXu/Xu9q1tYlpsLIsa3EeLXXk0VwlxqTcFRfg9EhMW+17kbt9G0NrgCmhGb5vT2hyhJZLfDGx+7+5Uj/w==
yallist@^3.0.0, yallist@^3.0.2, yallist@^3.0.3:
yallist@^3.0.2:
version "3.1.1"
resolved "https://registry.yarnpkg.com/yallist/-/yallist-3.1.1.tgz#dbb7daf9bfd8bac9ab45ebf602b8cbad0d5d08fd"
integrity sha512-a4UGQaWPH59mOXUYnAG2ewncQS4i4F43Tv3JoAM+s2VDAmS9NsK8GpDMLrCHPksFT7h3K6TOoUNn2pb7RoXx4g==
......
@@ -178,7 +178,7 @@ export namespace ValidationSchemas {
gpuIndices: joi.string()
}),
tuner: joi.object({
builtinTunerName: joi.string().valid('TPE', 'Random', 'Anneal', 'Evolution', 'SMAC', 'BatchTuner', 'GridSearch', 'NetworkMorphism', 'MetisTuner', 'GPTuner', 'PPOTuner'),
builtinTunerName: joi.string().valid('TPE', 'Random', 'Anneal', 'Evolution', 'SMAC', 'BatchTuner', 'GridSearch', 'NetworkMorphism', 'MetisTuner', 'GPTuner', 'PPOTuner', 'PBTTuner'),
codeDir: joi.string(),
classFileName: joi.string(),
className: joi.string(),
......
@@ -229,42 +229,39 @@ class ModelSpeedup:
        list
            a list of scope name of all the leaf modules
        """
-        pieces = [] # each element is a dict
-        for node in graph.nodes():
-            scope_name = node.scopeName()
-            if scope_name == '':
-                continue
-            segs = scope_name.split('/')
-            segs_len = len(segs)
-            # increase the length of `pieces` if not enough
-            for _ in range(segs_len - len(pieces)):
-                pieces.append({})
-            # process internal segments of the scope name
-            # 'L' means leaf segment
-            # 'I' means internal segment
-            # internal segment can replace leaf segment at the same position of `pieces`
-            for i, seg in enumerate(segs[:-1]):
-                seg_name_dict = pieces[i]
-                if seg in seg_name_dict:
-                    if seg_name_dict[seg][0] == 'L':
-                        seg_name_dict[seg] = ('I', node)
-                else:
-                    seg_name_dict[seg] = ('I', node)
-            # process the leaf segment of the scope name
-            last_segs_dict = pieces[len(segs) - 1]
-            if not segs[-1] in last_segs_dict:
-                last_segs_dict[segs[-1]] = ('L', node)
-        # traverse `pieces` to obtain all the leaf modules which are labeled with 'L'
-        leaf_modules = []
-        for piece in pieces:
-            for _, value in piece.items():
-                if value[0] == 'L':
-                    assert value[1].scopeName() not in leaf_modules
-                    # if this is a leaf module, the last segment of its scope name
-                    # must be in pattern `xxx[xxx]`
-                    if value[1].scopeName()[-1] == ']':
-                        leaf_modules.append(value[1].scopeName())
-        return leaf_modules
+        class SNode:
+            def __init__(self, name):
+                self.sname = name
+                self.childs = {}
+        root = None
+        for node in graph.nodes():
+            scope_name = node.scopeName()
+            if scope_name == '':
+                continue
+            segs = scope_name.split('/')
+            if root is None:
+                root = SNode(segs[0])
+            curr = root
+            for seg in segs[1:]:
+                if not seg in curr.childs:
+                    curr.childs[seg] = SNode(seg)
+                curr = curr.childs[seg]
+        leaf_nodes = []
+        def traverse_tree(node, scope_name):
+            if scope_name == '':
+                sn = node.sname
+            else:
+                sn = scope_name + '/' + node.sname
+            if not node.childs:
+                if node.sname[-1] == ']':
+                    leaf_nodes.append(sn)
+            else:
+                for key in node.childs:
+                    traverse_tree(node.childs[key], sn)
+        traverse_tree(root, '')
+        return leaf_nodes
def _build_graph(self):
"""
......
@@ -15,7 +15,8 @@ ModuleName = {
'Curvefitting': 'nni.curvefitting_assessor.curvefitting_assessor',
'MetisTuner': 'nni.metis_tuner.metis_tuner',
'GPTuner': 'nni.gp_tuner.gp_tuner',
'PPOTuner': 'nni.ppo_tuner.ppo_tuner'
'PPOTuner': 'nni.ppo_tuner.ppo_tuner',
'PBTTuner': 'nni.pbt_tuner.pbt_tuner'
}
ClassName = {
@@ -30,6 +31,7 @@ ClassName = {
'MetisTuner':'MetisTuner',
'GPTuner':'GPTuner',
'PPOTuner': 'PPOTuner',
'PBTTuner': 'PBTTuner',
'Medianstop': 'MedianstopAssessor',
'Curvefitting': 'CurvefittingAssessor'
......
@@ -10,97 +10,8 @@ import random
import numpy as np
from nni.tuner import Tuner
from nni.utils import NodeType, OptimizeMode, extract_scalar_reward, split_index
from nni.utils import OptimizeMode, extract_scalar_reward, split_index, json2parameter, json2space
import nni.parameter_expressions as parameter_expressions
def json2space(x, oldy=None, name=NodeType.ROOT):
"""
Change search space from json format to hyperopt format
"""
y = list()
if isinstance(x, dict):
if NodeType.TYPE in x.keys():
_type = x[NodeType.TYPE]
name = name + '-' + _type
if _type == 'choice':
if oldy is not None:
_index = oldy[NodeType.INDEX]
y += json2space(x[NodeType.VALUE][_index],
oldy[NodeType.VALUE], name=name+'[%d]' % _index)
else:
y += json2space(x[NodeType.VALUE], None, name=name)
y.append(name)
else:
for key in x.keys():
y += json2space(x[key], oldy[key] if oldy else None, name+"[%s]" % str(key))
elif isinstance(x, list):
for i, x_i in enumerate(x):
if isinstance(x_i, dict):
if NodeType.NAME not in x_i.keys():
raise RuntimeError('\'_name\' key is not found in this nested search space.')
y += json2space(x_i, oldy[i] if oldy else None, name + "[%d]" % i)
return y
def json2parameter(x, is_rand, random_state, oldy=None, Rand=False, name=NodeType.ROOT):
"""
Json to pramaters.
"""
if isinstance(x, dict):
if NodeType.TYPE in x.keys():
_type = x[NodeType.TYPE]
_value = x[NodeType.VALUE]
name = name + '-' + _type
Rand |= is_rand[name]
if Rand is True:
if _type == 'choice':
_index = random_state.randint(len(_value))
y = {
NodeType.INDEX: _index,
NodeType.VALUE: json2parameter(
x[NodeType.VALUE][_index],
is_rand,
random_state,
None,
Rand,
name=name+"[%d]" % _index
)
}
else:
y = getattr(parameter_expressions, _type)(*(_value + [random_state]))
else:
y = copy.deepcopy(oldy)
else:
y = dict()
for key in x.keys():
y[key] = json2parameter(
x[key],
is_rand,
random_state,
oldy[key] if oldy else None,
Rand,
name + "[%s]" % str(key)
)
elif isinstance(x, list):
y = list()
for i, x_i in enumerate(x):
if isinstance(x_i, dict):
if NodeType.NAME not in x_i.keys():
raise RuntimeError('\'_name\' key is not found in this nested search space.')
y.append(json2parameter(
x_i,
is_rand,
random_state,
oldy[i] if oldy else None,
Rand,
name + "[%d]" % i
))
else:
y = copy.deepcopy(x)
return y
class Individual:
"""
......
@@ -9,7 +9,7 @@ import numpy as np
from unittest import TestCase, main
from nni.evolution_tuner.evolution_tuner import json2space, json2parameter
from nni.utils import json2space, json2parameter
class EvolutionTunerTestCase(TestCase):
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import copy
import logging
import os
import numpy as np
import nni
from nni.tuner import Tuner
from nni.utils import OptimizeMode, extract_scalar_reward, split_index, json2parameter, json2space
logger = logging.getLogger('pbt_tuner_AutoML')
def exploit_and_explore(bot_trial_info, top_trial_info, factors, epoch, search_space):
"""
Replace checkpoint of bot_trial with top, and perturb hyperparameters
Parameters
----------
bot_trial_info : TrialInfo
bottom model whose parameters should be replaced
top_trial_info : TrialInfo
better model
factors : tuple
factors for perturbation
epoch : int
step of PBTTuner
search_space : dict
search_space to keep perturbed hyperparameters in range
"""
bot_checkpoint_dir = bot_trial_info.checkpoint_dir
top_hyper_parameters = top_trial_info.hyper_parameters
hyper_parameters = copy.deepcopy(top_hyper_parameters)
# TODO think about different type of hyperparameters for 1.perturbation 2.within search space
for key in hyper_parameters.keys():
if key == 'load_checkpoint_dir':
hyper_parameters[key] = hyper_parameters['save_checkpoint_dir']
elif key == 'save_checkpoint_dir':
hyper_parameters[key] = os.path.join(bot_checkpoint_dir, str(epoch))
elif isinstance(hyper_parameters[key], float):
perturb = np.random.choice(factors)
val = hyper_parameters[key] * perturb
lb, ub = search_space[key]["_value"][:2]
if search_space[key]["_type"] in ("uniform", "normal"):
val = np.clip(val, lb, ub).item()
hyper_parameters[key] = val
else:
continue
bot_trial_info.hyper_parameters = hyper_parameters
bot_trial_info.clean_id()
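# Illustration with hypothetical values: if factors=(1.2, 0.8) and the copied
# hyperparameters contain lr=0.01, the bottom trial gets lr=0.012 or lr=0.008
# (clipped into the first two search-space values for 'uniform'/'normal' types),
# loads the top trial's checkpoint, and saves under its own directory.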
class TrialInfo:
"""
Information of each trial, refreshed at each epoch
"""
def __init__(self, checkpoint_dir=None, hyper_parameters=None, parameter_id=None, score=None):
self.checkpoint_dir = checkpoint_dir
self.hyper_parameters = hyper_parameters
self.parameter_id = parameter_id
self.score = score
def clean_id(self):
self.parameter_id = None
class PBTTuner(Tuner):
def __init__(self, optimize_mode="maximize", all_checkpoint_dir=None, population_size=10, factors=(1.2, 0.8), fraction=0.2):
"""
Initialization
Parameters
----------
optimize_mode : str
maximize or minimize
all_checkpoint_dir : str
directory to store training model checkpoint
population_size : int
number of trials for each epoch
factors : tuple
factors for perturbation
fraction : float
fraction for selecting bottom and top trials
"""
self.optimize_mode = OptimizeMode(optimize_mode)
if all_checkpoint_dir is None:
all_checkpoint_dir = os.getenv('NNI_CHECKPOINT_DIRECTORY')
logger.info("Checkpoint dir is set to %s by default.", all_checkpoint_dir)
self.all_checkpoint_dir = all_checkpoint_dir
self.population_size = population_size
self.factors = factors
self.fraction = fraction
# defined in trial code
#self.perturbation_interval = perturbation_interval
self.population = None
self.pos = -1
self.param_ids = []
self.running = {}
self.finished = []
self.credit = 0
self.finished_trials = 0
self.epoch = 0
self.searchspace_json = None
self.space = None
self.send_trial_callback = None
logger.info('PBT tuner initialization')
def update_search_space(self, search_space):
"""
Update the search space and initialize the population
Parameters
----------
search_space : dict
Search space
"""
logger.info('Update search space %s', search_space)
self.searchspace_json = search_space
self.space = json2space(self.searchspace_json)
self.random_state = np.random.RandomState()
self.population = []
is_rand = dict()
for item in self.space:
is_rand[item] = True
for i in range(self.population_size):
hyper_parameters = json2parameter(
self.searchspace_json, is_rand, self.random_state)
checkpoint_dir = os.path.join(self.all_checkpoint_dir, str(i))
hyper_parameters['load_checkpoint_dir'] = os.path.join(checkpoint_dir, str(self.epoch))
hyper_parameters['save_checkpoint_dir'] = os.path.join(checkpoint_dir, str(self.epoch))
self.population.append(TrialInfo(checkpoint_dir=checkpoint_dir, hyper_parameters=hyper_parameters))
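# Resulting layout under all_checkpoint_dir (e.g. '/ckpt'): trial i starts with
# load and save both pointing at '/ckpt/i/0'; after each epoch,
# exploit_and_explore redirects a bottom trial's load dir to a top trial's
# save dir while it keeps saving under its own '/ckpt/i/<epoch>'.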
def generate_multiple_parameters(self, parameter_id_list, **kwargs):
"""
Returns multiple sets of trial (hyper-)parameters, as iterable of serializable objects.
Parameters
----------
parameter_id_list : list of int
Unique identifiers for each set of requested hyper-parameters.
These will later be used in :meth:`receive_trial_result`.
**kwargs
Used for send_trial_callback.
Returns
-------
list
A list of newly generated configurations
"""
result = []
self.send_trial_callback = kwargs['st_callback']
for parameter_id in parameter_id_list:
had_exception = False
try:
logger.debug("generating param for %s", parameter_id)
res = self.generate_parameters(parameter_id, **kwargs)
except nni.NoMoreTrialError:
had_exception = True
if not had_exception:
result.append(res)
return result
def generate_parameters(self, parameter_id, **kwargs):
"""
Generate parameters. If no trial configuration is available right now, add one to ``self.credit`` so the configuration can be sent later
Parameters
----------
parameter_id : int
Unique identifier for requested hyper-parameters.
This will later be used in :meth:`receive_trial_result`.
**kwargs
Not used
Returns
-------
dict
One newly generated configuration
"""
if self.pos == self.population_size - 1:
logger.debug('Credit added by one in parameters request')
self.credit += 1
self.param_ids.append(parameter_id)
raise nni.NoMoreTrialError('No more parameters now.')
self.pos += 1
trial_info = self.population[self.pos]
trial_info.parameter_id = parameter_id
self.running[parameter_id] = trial_info
logger.info('Generate parameter : %s', trial_info.hyper_parameters)
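# split_index strips the internal '_index' bookkeeping that json2parameter adds
# for nested choices, so the trial receives a plain hyperparameter dict.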
return split_index(trial_info.hyper_parameters)
def receive_trial_result(self, parameter_id, parameters, value, **kwargs):
"""
Receive a trial's result. If the number of finished trials equals ``self.population_size``, start the next epoch to
train the model.
Parameters
----------
parameter_id : int
Unique identifier of used hyper-parameters, same with :meth:`generate_parameters`.
parameters : dict
Hyper-parameters generated by :meth:`generate_parameters`.
value : dict
Result from trial (the return value of :func:`nni.report_final_result`).
"""
logger.info('Get one trial result, id = %d, value = %s', parameter_id, value)
value = extract_scalar_reward(value)
if self.optimize_mode == OptimizeMode.Minimize:
value = -value
trial_info = self.running.pop(parameter_id, None)
trial_info.score = value
self.finished.append(trial_info)
self.finished_trials += 1
if self.finished_trials == self.population_size:
logger.info('Proceeding to next epoch')
self.epoch += 1
self.population = []
self.pos = -1
self.running = {}
#exploit and explore
self.finished = sorted(self.finished, key=lambda x: x.score, reverse=True)
cutoff = int(np.ceil(self.fraction * len(self.finished)))
tops = self.finished[:cutoff]
bottoms = self.finished[self.finished_trials - cutoff:]
for bottom in bottoms:
top = np.random.choice(tops)
exploit_and_explore(bottom, top, self.factors, self.epoch, self.searchspace_json)
for trial in self.finished:
if trial not in bottoms:
trial.clean_id()
trial.hyper_parameters['load_checkpoint_dir'] = trial.hyper_parameters['save_checkpoint_dir']
trial.hyper_parameters['save_checkpoint_dir'] = os.path.join(trial.checkpoint_dir, str(self.epoch))
self.finished_trials = 0
for _ in range(self.population_size):
trial_info = self.finished.pop()
self.population.append(trial_info)
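# Drain any credit accrued while the previous population was exhausted,
# handing the queued parameter ids fresh configurations via the callback.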
while self.credit > 0 and self.pos + 1 < len(self.population):
self.credit -= 1
self.pos += 1
parameter_id = self.param_ids.pop()
trial_info = self.population[self.pos]
trial_info.parameter_id = parameter_id
self.running[parameter_id] = trial_info
self.send_trial_callback(parameter_id, split_index(trial_info.hyper_parameters))
def import_data(self, data):
pass
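A minimal driver sketch for the protocol above (illustrative only; in a real experiment NNI's dispatcher issues these calls, and the search space, checkpoint directory, and reported metric below are made-up assumptions):

from collections import deque

from nni.pbt_tuner.pbt_tuner import PBTTuner

queue = deque()  # collects trials deferred through st_callback
tuner = PBTTuner(all_checkpoint_dir='/tmp/pbt_checkpoints', population_size=4)
tuner.update_search_space({'lr': {'_type': 'choice', '_value': [0.01, 0.1]}})
# One id per population member; extra requests would accrue credit instead.
params = tuner.generate_multiple_parameters(
    list(range(4)), st_callback=lambda pid, p: queue.append((pid, p)))
for pid, config in enumerate(params):
    tuner.receive_trial_result(pid, config, 0.9)  # fake final metric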
@@ -2,13 +2,16 @@
# Licensed under the MIT license.
import os
import copy
import functools
from enum import Enum, unique
import json_tricks
from . import parameter_expressions
from .common import init_logger
from .env_vars import dispatcher_env_vars
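# allow_nan=True lets json_tricks serialize NaN/Infinity metric values instead of raising.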
to_json = functools.partial(json_tricks.dumps, allow_nan=True)
@unique
@@ -124,3 +127,92 @@ def init_dispatcher_logger():
if dispatcher_env_vars.NNI_LOG_DIRECTORY is not None:
logger_file_path = os.path.join(dispatcher_env_vars.NNI_LOG_DIRECTORY, logger_file_path)
init_logger(logger_file_path, dispatcher_env_vars.NNI_LOG_LEVEL)
def json2space(x, oldy=None, name=NodeType.ROOT):
"""
Convert the search space from JSON format to hyperopt format.
"""
y = list()
if isinstance(x, dict):
if NodeType.TYPE in x.keys():
_type = x[NodeType.TYPE]
name = name + '-' + _type
if _type == 'choice':
if oldy is not None:
_index = oldy[NodeType.INDEX]
y += json2space(x[NodeType.VALUE][_index],
oldy[NodeType.VALUE], name=name+'[%d]' % _index)
else:
y += json2space(x[NodeType.VALUE], None, name=name)
y.append(name)
else:
for key in x.keys():
y += json2space(x[key], oldy[key] if oldy else None, name+"[%s]" % str(key))
elif isinstance(x, list):
for i, x_i in enumerate(x):
if isinstance(x_i, dict):
if NodeType.NAME not in x_i.keys():
raise RuntimeError('\'_name\' key is not found in this nested search space.')
y += json2space(x_i, oldy[i] if oldy else None, name + "[%d]" % i)
return y
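For example, on a made-up one-key space, json2space collects one hyperopt-style name per choice node:

space = {'lr': {'_type': 'choice', '_value': [0.01, 0.1]}}
json2space(space)  # -> ['root[lr]-choice'], assuming NodeType.ROOT == 'root'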
def json2parameter(x, is_rand, random_state, oldy=None, Rand=False, name=NodeType.ROOT):
"""
Convert a JSON-format search space into concrete parameter values.
"""
if isinstance(x, dict):
if NodeType.TYPE in x.keys():
_type = x[NodeType.TYPE]
_value = x[NodeType.VALUE]
name = name + '-' + _type
Rand |= is_rand[name]
if Rand is True:
if _type == 'choice':
_index = random_state.randint(len(_value))
y = {
NodeType.INDEX: _index,
NodeType.VALUE: json2parameter(
x[NodeType.VALUE][_index],
is_rand,
random_state,
None,
Rand,
name=name+"[%d]" % _index
)
}
else:
y = getattr(parameter_expressions, _type)(*(_value + [random_state]))
else:
y = copy.deepcopy(oldy)
else:
y = dict()
for key in x.keys():
y[key] = json2parameter(
x[key],
is_rand,
random_state,
oldy[key] if oldy else None,
Rand,
name + "[%s]" % str(key)
)
elif isinstance(x, list):
y = list()
for i, x_i in enumerate(x):
if isinstance(x_i, dict):
if NodeType.NAME not in x_i.keys():
raise RuntimeError('\'_name\' key is not found in this nested search space.')
y.append(json2parameter(
x_i,
is_rand,
random_state,
oldy[i] if oldy else None,
Rand,
name + "[%d]" % i
))
else:
y = copy.deepcopy(x)
return y
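A matching sketch for json2parameter on the same made-up space: the sampled dict keeps both the chosen index and the chosen value so nested choices can be rebuilt later (split_index then flattens it to {'lr': 0.01}):

import numpy as np

space = {'lr': {'_type': 'choice', '_value': [0.01, 0.1]}}
is_rand = {name: True for name in json2space(space)}
json2parameter(space, is_rand, np.random.RandomState(0))
# -> {'lr': {'_index': 0, '_value': 0.01}} with this seed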
@@ -8,6 +8,7 @@ import os
import random
import shutil
import sys
from collections import deque
from unittest import TestCase, main
from nni.batch_tuner.batch_tuner import BatchTuner
@@ -16,6 +17,8 @@ from nni.gp_tuner.gp_tuner import GPTuner
from nni.gridsearch_tuner.gridsearch_tuner import GridSearchTuner
from nni.hyperopt_tuner.hyperopt_tuner import HyperoptTuner
from nni.metis_tuner.metis_tuner import MetisTuner
from nni.msg_dispatcher import _pack_parameter, MsgDispatcher
from nni.pbt_tuner.pbt_tuner import PBTTuner
try:
from nni.smac_tuner.smac_tuner import SMACTuner
@@ -23,6 +26,7 @@ except ImportError:
assert sys.platform == "win32"
from nni.tuner import Tuner
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger('test_tuner')
@@ -44,18 +48,29 @@ class BuiltinTunersTestCase(TestCase):
self.params_each_round = 50
self.exhaustive = False
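# Factory for the st_callback handed to generate_multiple_parameters: the
# returned closure records each deferred trial in the caller's queue.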
def send_trial_callback(self, param_queue):
def receive(*args):
param_queue.append(tuple(args))
return receive
def search_space_test_one(self, tuner_factory, search_space):
tuner = tuner_factory()
self.assertIsInstance(tuner, Tuner)
tuner.update_search_space(search_space)
for i in range(self.test_round):
queue = deque()
parameters = tuner.generate_multiple_parameters(list(range(i * self.params_each_round,
-                                                           (i + 1) * self.params_each_round)))
+                                                           (i + 1) * self.params_each_round)),
+                                                st_callback=self.send_trial_callback(queue))
logger.debug(parameters)
self.check_range(parameters, search_space)
for k in range(min(len(parameters), self.params_each_round)):
tuner.receive_trial_result(self.params_each_round * i + k, parameters[k], random.uniform(-100, 100))
while queue:
id_, params = queue.popleft()
self.check_range([params], search_space)
tuner.receive_trial_result(id_, params, random.uniform(-100, 100))
if not parameters and not self.exhaustive:
raise ValueError("No parameters generated")
@@ -65,6 +80,9 @@
if self._testMethodName == "test_batch":
param = {list(search_space.keys())[0]: param}
for k, v in param.items():
if k == "load_checkpoint_dir" or k == "save_checkpoint_dir":
self.assertIsInstance(v, str)
continue
if k.startswith("_mutable_layer"):
_, block, layer, choice = k.split("/")
cand = search_space[block]["_value"][layer].get(choice)
@@ -270,6 +288,16 @@ class BuiltinTunersTestCase(TestCase):
def test_ppo(self):
pass
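# Two population sizes relative to the 50 parameters requested per round:
# 12 presumably forces the credit/deferred-send path through st_callback,
# while 100 spreads a single PBT epoch across several request rounds.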
def test_pbt(self):
self.search_space_test_all(lambda: PBTTuner(
all_checkpoint_dir=os.path.expanduser("~/nni/checkpoint/test/"),
population_size=12
))
self.search_space_test_all(lambda: PBTTuner(
all_checkpoint_dir=os.path.expanduser("~/nni/checkpoint/test/"),
population_size=100
))
def tearDown(self):
file_list = glob.glob("smac3*") + ["param_config_space.pcs", "scenario.txt", "model_path"]
for file in file_list:
@@ -1518,8 +1518,8 @@ acorn-jsx@^5.0.0:
resolved "https://registry.yarnpkg.com/acorn-jsx/-/acorn-jsx-5.1.0.tgz#294adb71b57398b0680015f0a38c563ee1db5384"
acorn@^6.0.5, acorn@^6.0.7:
version "6.4.0"
resolved "https://registry.yarnpkg.com/acorn/-/acorn-6.4.0.tgz#b659d2ffbafa24baf5db1cdbb2c94a983ecd2784"
version "6.4.1"
resolved "https://registry.yarnpkg.com/acorn/-/acorn-6.4.1.tgz#531e58ba3f51b9dacb9a6646ca4debf5b14ca474"
address@1.1.2, address@^1.0.1:
version "1.1.2"
authorName: nni
experimentName: default_test
maxExecDuration: 15m
-maxTrialNum: 4
+maxTrialNum: 2
trialConcurrency: 2
searchSpacePath: ./mnist_pytorch_search_space.json
@@ -13,7 +13,7 @@ assessor:
optimize_mode: maximize
trial:
codeDir: ../../../examples/trials/mnist-pytorch
-  command: python3 mnist.py --epochs 1
+  command: python3 mnist.py --epochs 1 --batch_num 10
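# --batch_num 10 presumably caps the number of batches per epoch so this pipeline test finishes quickly.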
gpuNum: 0
useAnnotation: false