Unverified Commit 2f5272c7 authored by SparkSnail's avatar SparkSnail Committed by GitHub
Browse files

Merge pull request #208 from microsoft/master

merge master
parents c785655e 7a20792a
authorName: default
experimentName: auto_rocksdb_TPE
trialConcurrency: 1
maxExecDuration: 12h
maxTrialNum: 256
#choice: local, remote, pai
trainingServicePlatform: local
searchSpacePath: search_space.json
#choice: true, false
useAnnotation: false
tuner:
#choice: TPE, Random, Anneal, Evolution, BatchTuner, MetisTuner
#SMAC (SMAC should be installed through nnictl)
builtinTunerName: TPE
classArgs:
#choice: maximize, minimize
optimize_mode: maximize
trial:
command: python3 main.py
codeDir: .
gpuNum: 0
#!/bin/bash
# Install db_bench and its dependencies on Ubuntu
pushd $PWD 1>/dev/null
# install snappy
echo "****************** Installing snappy *******************"
sudo apt-get install libsnappy-dev -y
# install gflag
echo "****************** Installing gflag ********************"
cd /tmp
git clone https://github.com/gflags/gflags.git
cd gflags
git checkout v2.0
./configure && make && sudo make install
# install rocksdb
echo "****************** Installing rocksdb ******************"
cd /tmp
git clone https://github.com/facebook/rocksdb.git
cd rocksdb
CPATH=/usr/local/include LIBRARY_PATH=/usr/local/lib DEBUG_LEVEL=0 make db_bench -j7
DIR=$HOME/.local/bin/
if [[ ! -e $DIR ]]; then
mkdir $dir
elif [[ ! -d $DIR ]]; then
echo "$DIR already exists but is not a directory" 1>&2
exit
fi
mv db_bench $HOME/.local/bin &&
echo "Successfully installed rocksed in "$DIR" !" &&
echo "Please add "$DIR" to your PATH for runing this example."
popd 1>/dev/null
# Copyright (c) Microsoft Corporation
# All rights reserved.
#
# MIT License
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
# documentation files (the "Software"), to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and
# to permit persons to whom the Software is furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
# BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
import nni
import subprocess
import logging
LOG = logging.getLogger('rocksdb-fillrandom')
def run(**parameters):
'''Run rocksdb benchmark and return throughput'''
bench_type = parameters['benchmarks']
# recover args
args = ["--{}={}".format(k, v) for k, v in parameters.items()]
# subprocess communicate
process = subprocess.Popen(['db_bench'] + args, stdout=subprocess.PIPE)
out, err = process.communicate()
# split into lines
lines = out.decode("utf8").splitlines()
match_lines = []
for line in lines:
# find the line with matched str
if bench_type not in line:
continue
else:
match_lines.append(line)
break
results = {}
for line in match_lines:
key, _, value = line.partition(":")
key = key.strip()
value = value.split("op")[1]
results[key] = float(value)
return results[bench_type]
def generate_params(received_params):
'''generate parameters based on received parameters'''
params = {
"benchmarks": "fillrandom",
"threads": 1,
"key_size": 20,
"value_size": 100,
"num": 13107200,
"db": "/tmp/rockdb",
"disable_wal": 1,
"max_background_flushes": 1,
"max_background_compactions": 4,
"write_buffer_size": 67108864,
"max_write_buffer_number": 16,
"min_write_buffer_number_to_merge": 2,
"level0_file_num_compaction_trigger": 2,
"max_bytes_for_level_base": 268435456,
"max_bytes_for_level_multiplier": 10,
"target_file_size_base": 33554432,
"target_file_size_multiplier": 1
}
for k, v in received_params.items():
params[k] = int(v)
return params
if __name__ == "__main__":
try:
# get parameters from tuner
RECEIVED_PARAMS = nni.get_next_parameter()
LOG.debug(RECEIVED_PARAMS)
PARAMS = generate_params(RECEIVED_PARAMS)
LOG.debug(PARAMS)
# run benchmark
throughput = run(**PARAMS)
# report throughput to nni
nni.report_final_result(throughput)
except Exception as exception:
LOG.exception(exception)
raise
{
"write_buffer_size": {
"_type": "quniform",
"_value": [2097152, 16777216, 1048576]
},
"min_write_buffer_number_to_merge": {
"_type": "quniform",
"_value": [2, 16, 1]
},
"level0_file_num_compaction_trigger": {
"_type": "quniform",
"_value": [2, 16, 1]
}
}
...@@ -48,7 +48,7 @@ setup( ...@@ -48,7 +48,7 @@ setup(
python_requires = '>=3.5', python_requires = '>=3.5',
install_requires = [ install_requires = [
'astor', 'astor',
'hyperopt', 'hyperopt==0.1.2',
'json_tricks', 'json_tricks',
'numpy', 'numpy',
'psutil', 'psutil',
......
...@@ -2,7 +2,7 @@ import logging ...@@ -2,7 +2,7 @@ import logging
import tensorflow as tf import tensorflow as tf
from .compressor import Pruner from .compressor import Pruner
__all__ = [ 'LevelPruner', 'AGP_Pruner', 'SensitivityPruner' ] __all__ = ['LevelPruner', 'AGP_Pruner']
_logger = logging.getLogger(__name__) _logger = logging.getLogger(__name__)
...@@ -14,10 +14,18 @@ class LevelPruner(Pruner): ...@@ -14,10 +14,18 @@ class LevelPruner(Pruner):
- sparsity - sparsity
""" """
super().__init__(config_list) super().__init__(config_list)
self.mask_list = {}
self.if_init_list = {}
def calc_mask(self, weight, config, **kwargs): def calc_mask(self, weight, config, op_name, **kwargs):
if self.if_init_list.get(op_name, True):
threshold = tf.contrib.distributions.percentile(tf.abs(weight), config['sparsity'] * 100) threshold = tf.contrib.distributions.percentile(tf.abs(weight), config['sparsity'] * 100)
return tf.cast(tf.math.greater(tf.abs(weight), threshold), weight.dtype) mask = tf.cast(tf.math.greater(tf.abs(weight), threshold), weight.dtype)
self.mask_list.update({op_name: mask})
self.if_init_list.update({op_name: False})
else:
mask = self.mask_list[op_name]
return mask
class AGP_Pruner(Pruner): class AGP_Pruner(Pruner):
...@@ -29,6 +37,7 @@ class AGP_Pruner(Pruner): ...@@ -29,6 +37,7 @@ class AGP_Pruner(Pruner):
Learning of Phones and other Consumer Devices, Learning of Phones and other Consumer Devices,
https://arxiv.org/pdf/1710.01878.pdf https://arxiv.org/pdf/1710.01878.pdf
""" """
def __init__(self, config_list): def __init__(self, config_list):
""" """
config_list: supported keys: config_list: supported keys:
...@@ -39,15 +48,25 @@ class AGP_Pruner(Pruner): ...@@ -39,15 +48,25 @@ class AGP_Pruner(Pruner):
- frequency: if you want update every 2 epoch, you can set it 2 - frequency: if you want update every 2 epoch, you can set it 2
""" """
super().__init__(config_list) super().__init__(config_list)
self.mask_list = {}
self.if_init_list = {}
self.now_epoch = tf.Variable(0) self.now_epoch = tf.Variable(0)
self.assign_handler = [] self.assign_handler = []
def calc_mask(self, weight, config, **kwargs): def calc_mask(self, weight, config, op_name, **kwargs):
start_epoch = config.get('start_epoch', 0)
freq = config.get('frequency', 1)
if self.now_epoch >= start_epoch and self.if_init_list.get(op_name, True) and (
self.now_epoch - start_epoch) % freq == 0:
target_sparsity = self.compute_target_sparsity(config) target_sparsity = self.compute_target_sparsity(config)
threshold = tf.contrib.distributions.percentile(weight, target_sparsity * 100) threshold = tf.contrib.distributions.percentile(weight, target_sparsity * 100)
# stop gradient in case gradient change the mask # stop gradient in case gradient change the mask
mask = tf.stop_gradient(tf.cast(tf.math.greater(weight, threshold), weight.dtype)) mask = tf.stop_gradient(tf.cast(tf.math.greater(weight, threshold), weight.dtype))
self.assign_handler.append(tf.assign(weight, weight * mask)) self.assign_handler.append(tf.assign(weight, weight * mask))
self.mask_list.update({op_name: tf.constant(mask)})
self.if_init_list.update({op_name: False})
else:
mask = self.mask_list[op_name]
return mask return mask
def compute_target_sparsity(self, config): def compute_target_sparsity(self, config):
...@@ -62,49 +81,16 @@ class AGP_Pruner(Pruner): ...@@ -62,49 +81,16 @@ class AGP_Pruner(Pruner):
return final_sparsity return final_sparsity
now_epoch = tf.minimum(self.now_epoch, tf.constant(end_epoch)) now_epoch = tf.minimum(self.now_epoch, tf.constant(end_epoch))
span = int(((end_epoch - start_epoch-1)//freq)*freq) span = int(((end_epoch - start_epoch - 1) // freq) * freq)
assert span > 0 assert span > 0
base = tf.cast(now_epoch - start_epoch, tf.float32) / span base = tf.cast(now_epoch - start_epoch, tf.float32) / span
target_sparsity = (final_sparsity + target_sparsity = (final_sparsity +
(initial_sparsity - final_sparsity)* (initial_sparsity - final_sparsity) *
(tf.pow(1.0 - base, 3))) (tf.pow(1.0 - base, 3)))
return target_sparsity return target_sparsity
def update_epoch(self, epoch, sess): def update_epoch(self, epoch, sess):
sess.run(self.assign_handler) sess.run(self.assign_handler)
sess.run(tf.assign(self.now_epoch, int(epoch))) sess.run(tf.assign(self.now_epoch, int(epoch)))
for k in self.if_init_list.keys():
self.if_init_list[k] = True
class SensitivityPruner(Pruner):
"""Use algorithm from "Learning both Weights and Connections for Efficient Neural Networks"
https://arxiv.org/pdf/1506.02626v3.pdf
I.e.: "The pruning threshold is chosen as a quality parameter multiplied
by the standard deviation of a layers weights."
"""
def __init__(self, config_list):
"""
config_list: supported keys
- sparsity: chosen pruning sparsity
"""
super().__init__(config_list)
self.layer_mask = {}
self.assign_handler = []
def calc_mask(self, weight, config, op_name, **kwargs):
target_sparsity = config['sparsity'] * tf.math.reduce_std(weight)
mask = tf.get_variable(op_name + '_mask', initializer=tf.ones(weight.shape), trainable=False)
self.layer_mask[op_name] = mask
weight_assign_handler = tf.assign(weight, mask*weight)
# use control_dependencies so that weight_assign_handler will be executed before mask_update_handler
with tf.control_dependencies([weight_assign_handler]):
threshold = tf.contrib.distributions.percentile(weight, target_sparsity * 100)
# stop gradient in case gradient change the mask
new_mask = tf.stop_gradient(tf.cast(tf.math.greater(weight, threshold), weight.dtype))
mask_update_handler = tf.assign(mask, new_mask)
self.assign_handler.append(mask_update_handler)
return mask
def update_epoch(self, epoch, sess):
sess.run(self.assign_handler)
...@@ -2,7 +2,7 @@ import logging ...@@ -2,7 +2,7 @@ import logging
import torch import torch
from .compressor import Pruner from .compressor import Pruner
__all__ = [ 'LevelPruner', 'AGP_Pruner', 'SensitivityPruner' ] __all__ = ['LevelPruner', 'AGP_Pruner']
logger = logging.getLogger('torch pruner') logger = logging.getLogger('torch pruner')
...@@ -10,20 +10,29 @@ logger = logging.getLogger('torch pruner') ...@@ -10,20 +10,29 @@ logger = logging.getLogger('torch pruner')
class LevelPruner(Pruner): class LevelPruner(Pruner):
"""Prune to an exact pruning level specification """Prune to an exact pruning level specification
""" """
def __init__(self, config_list): def __init__(self, config_list):
""" """
config_list: supported keys: config_list: supported keys:
- sparsity - sparsity
""" """
super().__init__(config_list) super().__init__(config_list)
self.mask_list = {}
self.if_init_list = {}
def calc_mask(self, weight, config, **kwargs): def calc_mask(self, weight, config, op_name, **kwargs):
if self.if_init_list.get(op_name, True):
w_abs = weight.abs() w_abs = weight.abs()
k = int(weight.numel() * config['sparsity']) k = int(weight.numel() * config['sparsity'])
if k == 0: if k == 0:
return torch.ones(weight.shape) return torch.ones(weight.shape).type_as(weight)
threshold = torch.topk(w_abs.view(-1), k, largest = False).values.max() threshold = torch.topk(w_abs.view(-1), k, largest=False).values.max()
return torch.gt(w_abs, threshold).type(weight.type()) mask = torch.gt(w_abs, threshold).type_as(weight)
self.mask_list.update({op_name: mask})
self.if_init_list.update({op_name: False})
else:
mask = self.mask_list[op_name]
return mask
class AGP_Pruner(Pruner): class AGP_Pruner(Pruner):
...@@ -35,35 +44,44 @@ class AGP_Pruner(Pruner): ...@@ -35,35 +44,44 @@ class AGP_Pruner(Pruner):
Learning of Phones and other Consumer Devices, Learning of Phones and other Consumer Devices,
https://arxiv.org/pdf/1710.01878.pdf https://arxiv.org/pdf/1710.01878.pdf
""" """
def __init__(self, config_list): def __init__(self, config_list):
""" """
config_list: supported keys: config_list: supported keys:
- initial_sparsity - initial_sparsity
- final_sparsity: you should make sure initial_sparsity <= final_sparsity - final_sparsity: you should make sure initial_sparsity <= final_sparsity
- start_epoch: start epoch numer begin update mask - start_epoch: start epoch number begin update mask
- end_epoch: end epoch number stop update mask, you should make sure start_epoch <= end_epoch - end_epoch: end epoch number stop update mask, you should make sure start_epoch <= end_epoch
- frequency: if you want update every 2 epoch, you can set it 2 - frequency: if you want update every 2 epoch, you can set it 2
""" """
super().__init__(config_list) super().__init__(config_list)
self.mask_list = {} self.mask_list = {}
self.now_epoch = 1 self.now_epoch = 0
self.if_init_list = {}
def calc_mask(self, weight, config, op_name, **kwargs): def calc_mask(self, weight, config, op_name, **kwargs):
mask = self.mask_list.get(op_name, torch.ones(weight.shape)) start_epoch = config.get('start_epoch', 0)
freq = config.get('frequency', 1)
if self.now_epoch >= start_epoch and self.if_init_list.get(op_name, True) and (
self.now_epoch - start_epoch) % freq == 0:
mask = self.mask_list.get(op_name, torch.ones(weight.shape).type_as(weight))
target_sparsity = self.compute_target_sparsity(config) target_sparsity = self.compute_target_sparsity(config)
k = int(weight.numel() * target_sparsity) k = int(weight.numel() * target_sparsity)
if k == 0 or target_sparsity >= 1 or target_sparsity <= 0: if k == 0 or target_sparsity >= 1 or target_sparsity <= 0:
return mask return mask
# if we want to generate new mask, we should update weigth first # if we want to generate new mask, we should update weigth first
w_abs = weight.abs()*mask w_abs = weight.abs() * mask
threshold = torch.topk(w_abs.view(-1), k, largest = False).values.max() threshold = torch.topk(w_abs.view(-1), k, largest=False).values.max()
new_mask = torch.gt(w_abs, threshold).type(weight.type()) new_mask = torch.gt(w_abs, threshold).type_as(weight)
self.mask_list[op_name] = new_mask self.mask_list.update({op_name: new_mask})
self.if_init_list.update({op_name: False})
else:
new_mask = self.mask_list.get(op_name, torch.ones(weight.shape).type_as(weight))
return new_mask return new_mask
def compute_target_sparsity(self, config): def compute_target_sparsity(self, config):
end_epoch = config.get('end_epoch', 1) end_epoch = config.get('end_epoch', 1)
start_epoch = config.get('start_epoch', 1) start_epoch = config.get('start_epoch', 0)
freq = config.get('frequency', 1) freq = config.get('frequency', 1)
final_sparsity = config.get('final_sparsity', 0) final_sparsity = config.get('final_sparsity', 0)
initial_sparsity = config.get('initial_sparsity', 0) initial_sparsity = config.get('initial_sparsity', 0)
...@@ -74,45 +92,15 @@ class AGP_Pruner(Pruner): ...@@ -74,45 +92,15 @@ class AGP_Pruner(Pruner):
if end_epoch <= self.now_epoch: if end_epoch <= self.now_epoch:
return final_sparsity return final_sparsity
span = ((end_epoch - start_epoch-1)//freq)*freq span = ((end_epoch - start_epoch - 1) // freq) * freq
assert span > 0 assert span > 0
target_sparsity = (final_sparsity + target_sparsity = (final_sparsity +
(initial_sparsity - final_sparsity)* (initial_sparsity - final_sparsity) *
(1.0 - ((self.now_epoch - start_epoch)/span))**3) (1.0 - ((self.now_epoch - start_epoch) / span)) ** 3)
return target_sparsity return target_sparsity
def update_epoch(self, epoch): def update_epoch(self, epoch):
if epoch > 0: if epoch > 0:
self.now_epoch = epoch self.now_epoch = epoch
for k in self.if_init_list.keys():
self.if_init_list[k] = True
class SensitivityPruner(Pruner):
"""Use algorithm from "Learning both Weights and Connections for Efficient Neural Networks"
https://arxiv.org/pdf/1506.02626v3.pdf
I.e.: "The pruning threshold is chosen as a quality parameter multiplied
by the standard deviation of a layers weights."
"""
def __init__(self, config_list):
"""
config_list: supported keys:
- sparsity: chosen pruning sparsity
"""
super().__init__(config_list)
self.mask_list = {}
def calc_mask(self, weight, config, op_name, **kwargs):
mask = self.mask_list.get(op_name, torch.ones(weight.shape))
# if we want to generate new mask, we should update weigth first
weight = weight*mask
target_sparsity = config['sparsity'] * torch.std(weight).item()
k = int(weight.numel() * target_sparsity)
if k == 0:
return mask
w_abs = weight.abs()
threshold = torch.topk(w_abs.view(-1), k, largest = False).values.max()
new_mask = torch.gt(w_abs, threshold).type(weight.type())
self.mask_list[op_name] = new_mask
return new_mask
...@@ -38,7 +38,6 @@ class Compressor: ...@@ -38,7 +38,6 @@ class Compressor:
if config is not None: if config is not None:
self._instrument_layer(layer, config) self._instrument_layer(layer, config)
def bind_model(self, model): def bind_model(self, model):
"""This method is called when a model is bound to the compressor. """This method is called when a model is bound to the compressor.
Users can optionally overload this method to do model-specific initialization. Users can optionally overload this method to do model-specific initialization.
...@@ -56,7 +55,6 @@ class Compressor: ...@@ -56,7 +55,6 @@ class Compressor:
""" """
pass pass
def _instrument_layer(self, layer, config): def _instrument_layer(self, layer, config):
raise NotImplementedError() raise NotImplementedError()
...@@ -90,7 +88,6 @@ class Pruner(Compressor): ...@@ -90,7 +88,6 @@ class Pruner(Compressor):
""" """
raise NotImplementedError("Pruners must overload calc_mask()") raise NotImplementedError("Pruners must overload calc_mask()")
def _instrument_layer(self, layer, config): def _instrument_layer(self, layer, config):
# TODO: support multiple weight tensors # TODO: support multiple weight tensors
# create a wrapper forward function to replace the original one # create a wrapper forward function to replace the original one
......
...@@ -4,7 +4,7 @@ json_tricks ...@@ -4,7 +4,7 @@ json_tricks
# hyperopt tuner # hyperopt tuner
numpy numpy
scipy scipy
hyperopt hyperopt==0.1.2
# metis tuner # metis tuner
sklearn sklearn
...@@ -32,7 +32,7 @@ setuptools.setup( ...@@ -32,7 +32,7 @@ setuptools.setup(
python_requires = '>=3.5', python_requires = '>=3.5',
install_requires = [ install_requires = [
'hyperopt', 'hyperopt==0.1.2',
'json_tricks', 'json_tricks',
'numpy', 'numpy',
'scipy', 'scipy',
......
...@@ -5,24 +5,28 @@ import torch.nn.functional as F ...@@ -5,24 +5,28 @@ import torch.nn.functional as F
import nni.compression.tensorflow as tf_compressor import nni.compression.tensorflow as tf_compressor
import nni.compression.torch as torch_compressor import nni.compression.torch as torch_compressor
def weight_variable(shape): def weight_variable(shape):
return tf.Variable(tf.truncated_normal(shape, stddev = 0.1)) return tf.Variable(tf.truncated_normal(shape, stddev=0.1))
def bias_variable(shape): def bias_variable(shape):
return tf.Variable(tf.constant(0.1, shape = shape)) return tf.Variable(tf.constant(0.1, shape=shape))
def conv2d(x_input, w_matrix): def conv2d(x_input, w_matrix):
return tf.nn.conv2d(x_input, w_matrix, strides = [ 1, 1, 1, 1 ], padding = 'SAME') return tf.nn.conv2d(x_input, w_matrix, strides=[1, 1, 1, 1], padding='SAME')
def max_pool(x_input, pool_size): def max_pool(x_input, pool_size):
size = [ 1, pool_size, pool_size, 1 ] size = [1, pool_size, pool_size, 1]
return tf.nn.max_pool(x_input, ksize = size, strides = size, padding = 'SAME') return tf.nn.max_pool(x_input, ksize=size, strides=size, padding='SAME')
class TfMnist: class TfMnist:
def __init__(self): def __init__(self):
images = tf.placeholder(tf.float32, [ None, 784 ], name = 'input_x') images = tf.placeholder(tf.float32, [None, 784], name='input_x')
labels = tf.placeholder(tf.float32, [ None, 10 ], name = 'input_y') labels = tf.placeholder(tf.float32, [None, 10], name='input_y')
keep_prob = tf.placeholder(tf.float32, name='keep_prob') keep_prob = tf.placeholder(tf.float32, name='keep_prob')
self.images = images self.images = images
...@@ -37,35 +41,35 @@ class TfMnist: ...@@ -37,35 +41,35 @@ class TfMnist:
self.fcw1 = None self.fcw1 = None
self.cross = None self.cross = None
with tf.name_scope('reshape'): with tf.name_scope('reshape'):
x_image = tf.reshape(images, [ -1, 28, 28, 1 ]) x_image = tf.reshape(images, [-1, 28, 28, 1])
with tf.name_scope('conv1'): with tf.name_scope('conv1'):
w_conv1 = weight_variable([ 5, 5, 1, 32 ]) w_conv1 = weight_variable([5, 5, 1, 32])
self.w1 = w_conv1 self.w1 = w_conv1
b_conv1 = bias_variable([ 32 ]) b_conv1 = bias_variable([32])
self.b1 = b_conv1 self.b1 = b_conv1
h_conv1 = tf.nn.relu(conv2d(x_image, w_conv1) + b_conv1) h_conv1 = tf.nn.relu(conv2d(x_image, w_conv1) + b_conv1)
with tf.name_scope('pool1'): with tf.name_scope('pool1'):
h_pool1 = max_pool(h_conv1, 2) h_pool1 = max_pool(h_conv1, 2)
with tf.name_scope('conv2'): with tf.name_scope('conv2'):
w_conv2 = weight_variable([ 5, 5, 32, 64 ]) w_conv2 = weight_variable([5, 5, 32, 64])
b_conv2 = bias_variable([ 64 ]) b_conv2 = bias_variable([64])
h_conv2 = tf.nn.relu(conv2d(h_pool1, w_conv2) + b_conv2) h_conv2 = tf.nn.relu(conv2d(h_pool1, w_conv2) + b_conv2)
with tf.name_scope('pool2'): with tf.name_scope('pool2'):
h_pool2 = max_pool(h_conv2, 2) h_pool2 = max_pool(h_conv2, 2)
with tf.name_scope('fc1'): with tf.name_scope('fc1'):
w_fc1 = weight_variable([ 7 * 7 * 64, 1024 ]) w_fc1 = weight_variable([7 * 7 * 64, 1024])
self.fcw1 = w_fc1 self.fcw1 = w_fc1
b_fc1 = bias_variable([ 1024 ]) b_fc1 = bias_variable([1024])
h_pool2_flat = tf.reshape(h_pool2, [ -1, 7 * 7 * 64 ]) h_pool2_flat = tf.reshape(h_pool2, [-1, 7 * 7 * 64])
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, w_fc1) + b_fc1) h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, w_fc1) + b_fc1)
with tf.name_scope('dropout'): with tf.name_scope('dropout'):
h_fc1_drop = tf.nn.dropout(h_fc1, 0.5) h_fc1_drop = tf.nn.dropout(h_fc1, 0.5)
with tf.name_scope('fc2'): with tf.name_scope('fc2'):
w_fc2 = weight_variable([ 1024, 10 ]) w_fc2 = weight_variable([1024, 10])
b_fc2 = bias_variable([ 10 ]) b_fc2 = bias_variable([10])
y_conv = tf.matmul(h_fc1_drop, w_fc2) + b_fc2 y_conv = tf.matmul(h_fc1_drop, w_fc2) + b_fc2
with tf.name_scope('loss'): with tf.name_scope('loss'):
cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels = labels, logits = y_conv)) cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=y_conv))
self.cross = cross_entropy self.cross = cross_entropy
with tf.name_scope('adam_optimizer'): with tf.name_scope('adam_optimizer'):
self.train_step = tf.train.AdamOptimizer(0.0001).minimize(cross_entropy) self.train_step = tf.train.AdamOptimizer(0.0001).minimize(cross_entropy)
...@@ -73,6 +77,7 @@ class TfMnist: ...@@ -73,6 +77,7 @@ class TfMnist:
correct_prediction = tf.equal(tf.argmax(y_conv, 1), tf.argmax(labels, 1)) correct_prediction = tf.equal(tf.argmax(y_conv, 1), tf.argmax(labels, 1))
self.accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) self.accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
class TorchMnist(torch.nn.Module): class TorchMnist(torch.nn.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
...@@ -89,22 +94,22 @@ class TorchMnist(torch.nn.Module): ...@@ -89,22 +94,22 @@ class TorchMnist(torch.nn.Module):
x = x.view(-1, 4 * 4 * 50) x = x.view(-1, 4 * 4 * 50)
x = F.relu(self.fc1(x)) x = F.relu(self.fc1(x))
x = self.fc2(x) x = self.fc2(x)
return F.log_softmax(x, dim = 1) return F.log_softmax(x, dim=1)
class CompressorTestCase(TestCase): class CompressorTestCase(TestCase):
def test_tf_pruner(self): def test_tf_pruner(self):
model = TfMnist() model = TfMnist()
configure_list = [{'sparsity':0.8, 'op_types':'default'}] configure_list = [{'sparsity': 0.8, 'op_types': 'default'}]
tf_compressor.LevelPruner(configure_list).compress_default_graph() tf_compressor.LevelPruner(configure_list).compress_default_graph()
def test_tf_quantizer(self): def test_tf_quantizer(self):
model = TfMnist() model = TfMnist()
tf_compressor.NaiveQuantizer([{'op_types': 'default'}]).compress_default_graph() tf_compressor.NaiveQuantizer([{'op_types': 'default'}]).compress_default_graph()
def test_torch_pruner(self): def test_torch_pruner(self):
model = TorchMnist() model = TorchMnist()
configure_list = [{'sparsity':0.8, 'op_types':'default'}] configure_list = [{'sparsity': 0.8, 'op_types': 'default'}]
torch_compressor.LevelPruner(configure_list).compress(model) torch_compressor.LevelPruner(configure_list).compress(model)
def test_torch_quantizer(self): def test_torch_quantizer(self):
......
...@@ -125,6 +125,11 @@ class Compare extends React.Component<CompareProps, {}> { ...@@ -125,6 +125,11 @@ class Compare extends React.Component<CompareProps, {}> {
durationList.push(temp.duration); durationList.push(temp.duration);
parameterList.push(temp.description.parameters); parameterList.push(temp.description.parameters);
}); });
let isComplexSearchSpace;
if (parameterList.length > 0) {
isComplexSearchSpace = (typeof parameterList[0][parameterKeys[0]] === 'object')
? true : false;
}
return ( return (
<table className="compare"> <table className="compare">
<tbody> <tbody>
...@@ -164,6 +169,10 @@ class Compare extends React.Component<CompareProps, {}> { ...@@ -164,6 +169,10 @@ class Compare extends React.Component<CompareProps, {}> {
})} })}
</tr> </tr>
{ {
isComplexSearchSpace
?
null
:
Object.keys(parameterKeys).map(index => { Object.keys(parameterKeys).map(index => {
return ( return (
<tr key={index}> <tr key={index}>
......
...@@ -14,6 +14,7 @@ interface ExpDrawerProps { ...@@ -14,6 +14,7 @@ interface ExpDrawerProps {
interface ExpDrawerState { interface ExpDrawerState {
experiment: string; experiment: string;
expDrawerHeight: number;
} }
class ExperimentDrawer extends React.Component<ExpDrawerProps, ExpDrawerState> { class ExperimentDrawer extends React.Component<ExpDrawerProps, ExpDrawerState> {
...@@ -23,7 +24,8 @@ class ExperimentDrawer extends React.Component<ExpDrawerProps, ExpDrawerState> { ...@@ -23,7 +24,8 @@ class ExperimentDrawer extends React.Component<ExpDrawerProps, ExpDrawerState> {
super(props); super(props);
this.state = { this.state = {
experiment: '' experiment: '',
expDrawerHeight: window.innerHeight - 48
}; };
} }
...@@ -69,9 +71,14 @@ class ExperimentDrawer extends React.Component<ExpDrawerProps, ExpDrawerState> { ...@@ -69,9 +71,14 @@ class ExperimentDrawer extends React.Component<ExpDrawerProps, ExpDrawerState> {
downFile(experiment, 'experiment.json'); downFile(experiment, 'experiment.json');
} }
onWindowResize = () => {
this.setState(() => ({expDrawerHeight: window.innerHeight - 48}));
}
componentDidMount() { componentDidMount() {
this._isCompareMount = true; this._isCompareMount = true;
this.getExperimentContent(); this.getExperimentContent();
window.addEventListener('resize', this.onWindowResize);
} }
componentWillReceiveProps(nextProps: ExpDrawerProps) { componentWillReceiveProps(nextProps: ExpDrawerProps) {
...@@ -83,12 +90,12 @@ class ExperimentDrawer extends React.Component<ExpDrawerProps, ExpDrawerState> { ...@@ -83,12 +90,12 @@ class ExperimentDrawer extends React.Component<ExpDrawerProps, ExpDrawerState> {
componentWillUnmount() { componentWillUnmount() {
this._isCompareMount = false; this._isCompareMount = false;
window.removeEventListener('resize', this.onWindowResize);
} }
render() { render() {
const { isVisble, closeExpDrawer } = this.props; const { isVisble, closeExpDrawer } = this.props;
const { experiment } = this.state; const { experiment, expDrawerHeight } = this.state;
const heights: number = window.innerHeight - 48;
return ( return (
<Row className="logDrawer"> <Row className="logDrawer">
<Drawer <Drawer
...@@ -99,15 +106,16 @@ class ExperimentDrawer extends React.Component<ExpDrawerProps, ExpDrawerState> { ...@@ -99,15 +106,16 @@ class ExperimentDrawer extends React.Component<ExpDrawerProps, ExpDrawerState> {
onClose={closeExpDrawer} onClose={closeExpDrawer}
visible={isVisble} visible={isVisble}
width="54%" width="54%"
height={heights} height={expDrawerHeight}
> >
<div className="card-container log-tab-body" style={{ height: heights }}> {/* 104: tabHeight(40) + tabMarginBottom(16) + buttonHeight(32) + buttonMarginTop(16) */}
<Tabs type="card" style={{ height: heights + 19 }}> <div className="card-container log-tab-body">
<Tabs type="card" style={{ height: expDrawerHeight, minHeight: 190 }}>
<TabPane tab="Experiment Parameters" key="Experiment"> <TabPane tab="Experiment Parameters" key="Experiment">
<div className="just-for-log"> <div className="just-for-log">
<MonacoEditor <MonacoEditor
width="100%" width="100%"
height={heights * 0.9} height={expDrawerHeight - 104}
language="json" language="json"
value={experiment} value={experiment}
options={DRAWEROPTION} options={DRAWEROPTION}
......
...@@ -8,89 +8,48 @@ import MonacoHTML from '../public-child/MonacoEditor'; ...@@ -8,89 +8,48 @@ import MonacoHTML from '../public-child/MonacoEditor';
import '../../static/style/logDrawer.scss'; import '../../static/style/logDrawer.scss';
interface LogDrawerProps { interface LogDrawerProps {
isVisble: boolean;
closeDrawer: () => void; closeDrawer: () => void;
activeTab?: string; activeTab?: string;
} }
interface LogDrawerState { interface LogDrawerState {
nniManagerLogStr: string; nniManagerLogStr: string | null;
dispatcherLogStr: string; dispatcherLogStr: string | null;
isLoading: boolean; isLoading: boolean;
isLoadispatcher: boolean; logDrawerHeight: number;
} }
class LogDrawer extends React.Component<LogDrawerProps, LogDrawerState> { class LogDrawer extends React.Component<LogDrawerProps, LogDrawerState> {
private timerId: number | undefined;
public _isLogDrawer: boolean;
constructor(props: LogDrawerProps) { constructor(props: LogDrawerProps) {
super(props); super(props);
this.state = { this.state = {
nniManagerLogStr: 'nnimanager', nniManagerLogStr: null,
dispatcherLogStr: 'dispatcher', dispatcherLogStr: null,
isLoading: false, isLoading: true,
isLoadispatcher: false logDrawerHeight: window.innerHeight - 48
}; };
} }
getNNImanagerLogmessage = () => {
if (this._isLogDrawer === true) {
this.setState({ isLoading: true }, () => {
axios(`${DOWNLOAD_IP}/nnimanager.log`, {
method: 'GET'
})
.then(res => {
if (res.status === 200) {
setTimeout(() => { this.setNNImanager(res.data); }, 300);
}
});
});
}
}
setDispatcher = (value: string) => {
if (this._isLogDrawer === true) {
this.setState({ isLoadispatcher: false, dispatcherLogStr: value });
}
}
setNNImanager = (val: string) => {
if (this._isLogDrawer === true) {
this.setState({ isLoading: false, nniManagerLogStr: val });
}
}
getdispatcherLogmessage = () => {
if (this._isLogDrawer === true) {
this.setState({ isLoadispatcher: true }, () => {
axios(`${DOWNLOAD_IP}/dispatcher.log`, {
method: 'GET'
})
.then(res => {
if (res.status === 200) {
setTimeout(() => { this.setDispatcher(res.data); }, 300);
}
});
});
}
}
downloadNNImanager = () => { downloadNNImanager = () => {
const { nniManagerLogStr } = this.state; if (this.state.nniManagerLogStr !== null) {
downFile(nniManagerLogStr, 'nnimanager.log'); downFile(this.state.nniManagerLogStr, 'nnimanager.log');
}
} }
downloadDispatcher = () => { downloadDispatcher = () => {
const { dispatcherLogStr } = this.state; if (this.state.dispatcherLogStr !== null) {
downFile(dispatcherLogStr, 'dispatcher.log'); downFile(this.state.dispatcherLogStr, 'dispatcher.log');
}
} }
dispatcherHTML = () => { dispatcherHTML = () => {
return ( return (
<div> <div>
<span>Dispatcher Log</span> <span>Dispatcher Log</span>
<span className="refresh" onClick={this.getdispatcherLogmessage}> <span className="refresh" onClick={this.manualRefresh}>
<Icon type="sync" /> <Icon type="sync" />
</span> </span>
</div> </div>
...@@ -101,37 +60,28 @@ class LogDrawer extends React.Component<LogDrawerProps, LogDrawerState> { ...@@ -101,37 +60,28 @@ class LogDrawer extends React.Component<LogDrawerProps, LogDrawerState> {
return ( return (
<div> <div>
<span>NNImanager Log</span> <span>NNImanager Log</span>
<span className="refresh" onClick={this.getNNImanagerLogmessage}><Icon type="sync" /></span> <span className="refresh" onClick={this.manualRefresh}><Icon type="sync" /></span>
</div> </div>
); );
} }
componentDidMount() { setLogDrawerHeight = () => {
this._isLogDrawer = true; this.setState(() => ({ logDrawerHeight: window.innerHeight - 48 }));
this.getNNImanagerLogmessage();
this.getdispatcherLogmessage();
} }
componentWillReceiveProps(nextProps: LogDrawerProps) { async componentDidMount() {
const { isVisble, activeTab } = nextProps; this.refresh();
if (isVisble === true) { window.addEventListener('resize', this.setLogDrawerHeight);
if (activeTab === 'nnimanager') {
this.getNNImanagerLogmessage();
}
if (activeTab === 'dispatcher') {
this.getdispatcherLogmessage();
}
}
} }
componentWillUnmount() { componentWillUnmount() {
this._isLogDrawer = false; window.clearTimeout(this.timerId);
window.removeEventListener('resize', this.setLogDrawerHeight);
} }
render() { render() {
const { isVisble, closeDrawer, activeTab } = this.props; const { closeDrawer, activeTab } = this.props;
const { nniManagerLogStr, dispatcherLogStr, isLoadispatcher, isLoading } = this.state; const { nniManagerLogStr, dispatcherLogStr, isLoading, logDrawerHeight } = this.state;
const heights: number = window.innerHeight - 48; // padding top and bottom
return ( return (
<Row> <Row>
<Drawer <Drawer
...@@ -139,18 +89,26 @@ class LogDrawer extends React.Component<LogDrawerProps, LogDrawerState> { ...@@ -139,18 +89,26 @@ class LogDrawer extends React.Component<LogDrawerProps, LogDrawerState> {
closable={false} closable={false}
destroyOnClose={true} destroyOnClose={true}
onClose={closeDrawer} onClose={closeDrawer}
visible={isVisble} visible={true}
width="76%" width="76%"
height={heights} height={logDrawerHeight}
// className="logDrawer" // className="logDrawer"
> >
<div className="card-container log-tab-body" style={{ height: heights }}> <div className="card-container log-tab-body">
<Tabs type="card" defaultActiveKey={activeTab} style={{ height: heights + 19 }}> <Tabs
type="card"
defaultActiveKey={activeTab}
style={{ height: logDrawerHeight, minHeight: 190 }}
>
{/* <Tabs type="card" onTabClick={this.selectwhichLog} defaultActiveKey={activeTab}> */} {/* <Tabs type="card" onTabClick={this.selectwhichLog} defaultActiveKey={activeTab}> */}
{/* <TabPane tab="Dispatcher Log" key="dispatcher"> */} {/* <TabPane tab="Dispatcher Log" key="dispatcher"> */}
<TabPane tab={this.dispatcherHTML()} key="dispatcher"> <TabPane tab={this.dispatcherHTML()} key="dispatcher">
<div> <div>
<MonacoHTML content={dispatcherLogStr} loading={isLoadispatcher} /> <MonacoHTML
content={dispatcherLogStr || 'Loading...'}
loading={isLoading}
height={logDrawerHeight - 104}
/>
</div> </div>
<Row className="buttons"> <Row className="buttons">
<Col span={12} className="download"> <Col span={12} className="download">
...@@ -174,7 +132,11 @@ class LogDrawer extends React.Component<LogDrawerProps, LogDrawerState> { ...@@ -174,7 +132,11 @@ class LogDrawer extends React.Component<LogDrawerProps, LogDrawerState> {
<TabPane tab={this.nnimanagerHTML()} key="nnimanager"> <TabPane tab={this.nnimanagerHTML()} key="nnimanager">
{/* <TabPane tab="NNImanager Log" key="nnimanager"> */} {/* <TabPane tab="NNImanager Log" key="nnimanager"> */}
<div> <div>
<MonacoHTML content={nniManagerLogStr} loading={isLoading} /> <MonacoHTML
content={nniManagerLogStr || 'Loading...'}
loading={isLoading}
height={logDrawerHeight - 104}
/>
</div> </div>
<Row className="buttons"> <Row className="buttons">
<Col span={12} className="download"> <Col span={12} className="download">
...@@ -201,6 +163,31 @@ class LogDrawer extends React.Component<LogDrawerProps, LogDrawerState> { ...@@ -201,6 +163,31 @@ class LogDrawer extends React.Component<LogDrawerProps, LogDrawerState> {
</Row> </Row>
); );
} }
private refresh = () => {
window.clearTimeout(this.timerId);
const dispatcherPromise = axios.get(`${DOWNLOAD_IP}/dispatcher.log`);
const nniManagerPromise = axios.get(`${DOWNLOAD_IP}/nnimanager.log`);
dispatcherPromise.then(res => {
if (res.status === 200) {
this.setState({ dispatcherLogStr: res.data });
}
});
nniManagerPromise.then(res => {
if (res.status === 200) {
this.setState({ nniManagerLogStr: res.data });
}
});
Promise.all([dispatcherPromise, nniManagerPromise]).then(() => {
this.setState({ isLoading: false });
this.timerId = window.setTimeout(this.refresh, 300);
});
}
private manualRefresh = () => {
this.setState({ isLoading: true });
this.refresh();
}
} }
export default LogDrawer; export default LogDrawer;
...@@ -214,7 +214,12 @@ class SlideBar extends React.Component<SliderProps, SliderState> { ...@@ -214,7 +214,12 @@ class SlideBar extends React.Component<SliderProps, SliderState> {
type="ghost" type="ghost"
> >
<a target="_blank" href="https://nni.readthedocs.io/en/latest/Tutorial/WebUI.html"> <a target="_blank" href="https://nni.readthedocs.io/en/latest/Tutorial/WebUI.html">
<Icon type="question" /><span>Help</span> <img
src={require('../static/img/icon/ques.png')}
alt="question"
className="question"
/>
<span>Help</span>
</a> </a>
</Button> </Button>
</span> </span>
...@@ -329,8 +334,8 @@ class SlideBar extends React.Component<SliderProps, SliderState> { ...@@ -329,8 +334,8 @@ class SlideBar extends React.Component<SliderProps, SliderState> {
render() { render() {
const mobile = (<MediaQuery maxWidth={884}>{this.mobileHTML()}</MediaQuery>); const mobile = (<MediaQuery maxWidth={884}>{this.mobileHTML()}</MediaQuery>);
const tablet = (<MediaQuery minWidth={885} maxWidth={1241}>{this.tabeltHTML()}</MediaQuery>); const tablet = (<MediaQuery minWidth={885} maxWidth={1281}>{this.tabeltHTML()}</MediaQuery>);
const desktop = (<MediaQuery minWidth={1242}>{this.desktopHTML()}</MediaQuery>); const desktop = (<MediaQuery minWidth={1282}>{this.desktopHTML()}</MediaQuery>);
const { isvisibleLogDrawer, activeKey, isvisibleExperimentDrawer } = this.state; const { isvisibleLogDrawer, activeKey, isvisibleExperimentDrawer } = this.state;
return ( return (
<div> <div>
...@@ -338,11 +343,12 @@ class SlideBar extends React.Component<SliderProps, SliderState> { ...@@ -338,11 +343,12 @@ class SlideBar extends React.Component<SliderProps, SliderState> {
{tablet} {tablet}
{desktop} {desktop}
{/* the drawer for dispatcher & nnimanager log message */} {/* the drawer for dispatcher & nnimanager log message */}
{isvisibleLogDrawer ? (
<LogDrawer <LogDrawer
isVisble={isvisibleLogDrawer}
closeDrawer={this.closeLogDrawer} closeDrawer={this.closeLogDrawer}
activeTab={activeKey} activeTab={activeKey}
/> />
) : null}
<ExperimentDrawer <ExperimentDrawer
isVisble={isvisibleExperimentDrawer} isVisble={isvisibleExperimentDrawer}
closeExpDrawer={this.closeExpDrawer} closeExpDrawer={this.closeExpDrawer}
......
...@@ -100,10 +100,6 @@ class TrialsDetail extends React.Component<TrialsDetailProps, TrialDetailState> ...@@ -100,10 +100,6 @@ class TrialsDetail extends React.Component<TrialsDetailProps, TrialDetailState>
this.setState({ whichGraph: activeKey }); this.setState({ whichGraph: activeKey });
} }
test = () => {
alert('TableList component was not properly initialized.');
}
updateSearchFilterType = (value: string) => { updateSearchFilterType = (value: string) => {
// clear input value and re-render table // clear input value and re-render table
if (this.searchInput !== null) { if (this.searchInput !== null) {
...@@ -167,14 +163,14 @@ class TrialsDetail extends React.Component<TrialsDetailProps, TrialDetailState> ...@@ -167,14 +163,14 @@ class TrialsDetail extends React.Component<TrialsDetailProps, TrialDetailState>
<Col span={14} className="right"> <Col span={14} className="right">
<Button <Button
className="common" className="common"
onClick={this.tableList ? this.tableList.addColumn : this.test} onClick={() => { if (this.tableList) { this.tableList.addColumn(); }}}
> >
Add column Add column
</Button> </Button>
<Button <Button
className="mediateBtn common" className="mediateBtn common"
// use child-component tableList's function, the function is in child-component. // use child-component tableList's function, the function is in child-component.
onClick={this.tableList ? this.tableList.compareBtn : this.test} onClick={() => { if (this.tableList) { this.tableList.compareBtn(); }}}
> >
Compare Compare
</Button> </Button>
......
...@@ -184,11 +184,12 @@ class Progressed extends React.Component<ProgressProps, ProgressState> { ...@@ -184,11 +184,12 @@ class Progressed extends React.Component<ProgressProps, ProgressState> {
</Col> </Col>
</Row> </Row>
{/* learn about click -> default active key is dispatcher. */} {/* learn about click -> default active key is dispatcher. */}
{isShowLogDrawer ? (
<LogDrawer <LogDrawer
isVisble={isShowLogDrawer}
closeDrawer={this.closeDrawer} closeDrawer={this.closeDrawer}
activeTab="dispatcher" activeTab="dispatcher"
/> />
) : null}
</Row> </Row>
); );
} }
......
...@@ -6,6 +6,7 @@ import MonacoEditor from 'react-monaco-editor'; ...@@ -6,6 +6,7 @@ import MonacoEditor from 'react-monaco-editor';
interface MonacoEditorProps { interface MonacoEditorProps {
content: string; content: string;
loading: boolean; loading: boolean;
height: number;
} }
class MonacoHTML extends React.Component<MonacoEditorProps, {}> { class MonacoHTML extends React.Component<MonacoEditorProps, {}> {
...@@ -25,18 +26,17 @@ class MonacoHTML extends React.Component<MonacoEditorProps, {}> { ...@@ -25,18 +26,17 @@ class MonacoHTML extends React.Component<MonacoEditorProps, {}> {
} }
render() { render() {
const { content, loading } = this.props; const { content, loading, height } = this.props;
const heights: number = window.innerHeight - 48;
return ( return (
<div className="just-for-log"> <div className="just-for-log">
<Spin <Spin
// tip="Loading..." // tip="Loading..."
style={{ width: '100%', height: heights * 0.9 }} style={{ width: '100%', height: height }}
spinning={loading} spinning={loading}
> >
<MonacoEditor <MonacoEditor
width="100%" width="100%"
height={heights * 0.9} height={height}
language="json" language="json"
value={content} value={content}
options={DRAWEROPTION} options={DRAWEROPTION}
......
...@@ -135,17 +135,17 @@ function generateScatterSeries(trials: Trial[]) { ...@@ -135,17 +135,17 @@ function generateScatterSeries(trials: Trial[]) {
function generateBestCurveSeries(trials: Trial[]) { function generateBestCurveSeries(trials: Trial[]) {
let best = trials[0]; let best = trials[0];
const data = [[ best.sequenceId, best.accuracy, best.info.hyperParameters ]]; const data = [[ best.sequenceId, best.accuracy, best.description.parameters ]];
for (let i = 1; i < trials.length; i++) { for (let i = 1; i < trials.length; i++) {
const trial = trials[i]; const trial = trials[i];
const delta = trial.accuracy! - best.accuracy!; const delta = trial.accuracy! - best.accuracy!;
const better = (EXPERIMENT.optimizeMode === 'minimize') ? (delta < 0) : (delta > 0); const better = (EXPERIMENT.optimizeMode === 'minimize') ? (delta < 0) : (delta > 0);
if (better) { if (better) {
data.push([ trial.sequenceId, trial.accuracy, trial.info.hyperParameters ]); data.push([ trial.sequenceId, trial.accuracy, trial.description.parameters ]);
best = trial; best = trial;
} else { } else {
data.push([ trial.sequenceId, best.accuracy, trial.info.hyperParameters ]); data.push([ trial.sequenceId, best.accuracy, trial.description.parameters ]);
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment