Commit f5fc733a authored by Byzantine

Removing research/community models

parent 09bc9f54
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Builds the Shake-Shake Model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import custom_ops as ops
import tensorflow as tf
def _shake_shake_skip_connection(x, output_filters, stride):
"""Adds a residual connection to the filter x for the shake-shake model."""
curr_filters = int(x.shape[3])
if curr_filters == output_filters:
return x
stride_spec = ops.stride_arr(stride, stride)
# Skip path 1
path1 = tf.nn.avg_pool(
x, [1, 1, 1, 1], stride_spec, 'VALID', data_format='NHWC')
path1 = ops.conv2d(path1, int(output_filters / 2), 1, scope='path1_conv')
# Skip path 2
# First pad with 0's then crop
pad_arr = [[0, 0], [0, 1], [0, 1], [0, 0]]
path2 = tf.pad(x, pad_arr)[:, 1:, 1:, :]
concat_axis = 3
path2 = tf.nn.avg_pool(
path2, [1, 1, 1, 1], stride_spec, 'VALID', data_format='NHWC')
path2 = ops.conv2d(path2, int(output_filters / 2), 1, scope='path2_conv')
# Concat and apply BN
final_path = tf.concat(values=[path1, path2], axis=concat_axis)
final_path = ops.batch_norm(final_path, scope='final_path_bn')
return final_path
def _shake_shake_branch(x, output_filters, stride, rand_forward, rand_backward,
is_training):
"""Building a 2 branching convnet."""
x = tf.nn.relu(x)
x = ops.conv2d(x, output_filters, 3, stride=stride, scope='conv1')
x = ops.batch_norm(x, scope='bn1')
x = tf.nn.relu(x)
x = ops.conv2d(x, output_filters, 3, scope='conv2')
x = ops.batch_norm(x, scope='bn2')
if is_training:
x = x * rand_backward + tf.stop_gradient(x * rand_forward -
x * rand_backward)
else:
x *= 1.0 / 2
return x
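# Note (editorial, illustrative): the training expression above is the
# shake-shake trick. Its forward value is x * rand_backward +
# (x * rand_forward - x * rand_backward) = x * rand_forward, while
# tf.stop_gradient blocks gradients through the parenthesized term, so
# backprop sees d/dx = rand_backward. Forward and backward branch scaling are
# thus decoupled, as in https://arxiv.org/abs/1705.07485.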
def _shake_shake_block(x, output_filters, stride, is_training):
"""Builds a full shake-shake sub layer."""
batch_size = tf.shape(x)[0]
# Generate random numbers for scaling the branches
rand_forward = [
tf.random_uniform(
[batch_size, 1, 1, 1], minval=0, maxval=1, dtype=tf.float32)
for _ in range(2)
]
rand_backward = [
tf.random_uniform(
[batch_size, 1, 1, 1], minval=0, maxval=1, dtype=tf.float32)
for _ in range(2)
]
# Normalize so that all sum to 1
total_forward = tf.add_n(rand_forward)
total_backward = tf.add_n(rand_backward)
rand_forward = [samp / total_forward for samp in rand_forward]
rand_backward = [samp / total_backward for samp in rand_backward]
zipped_rand = zip(rand_forward, rand_backward)
branches = []
for branch, (r_forward, r_backward) in enumerate(zipped_rand):
with tf.variable_scope('branch_{}'.format(branch)):
b = _shake_shake_branch(x, output_filters, stride, r_forward, r_backward,
is_training)
branches.append(b)
res = _shake_shake_skip_connection(x, output_filters, stride)
return res + tf.add_n(branches)
def _shake_shake_layer(x, output_filters, num_blocks, stride,
is_training):
"""Builds many sub layers into one full layer."""
for block_num in range(num_blocks):
curr_stride = stride if (block_num == 0) else 1
with tf.variable_scope('layer_{}'.format(block_num)):
x = _shake_shake_block(x, output_filters, curr_stride,
is_training)
return x
def build_shake_shake_model(images, num_classes, hparams, is_training):
"""Builds the Shake-Shake model.
Build the Shake-Shake model from https://arxiv.org/abs/1705.07485.
Args:
images: Tensor of images that will be fed into the Shake-Shake model.
num_classes: Number of classes that the model needs to predict.
hparams: tf.HParams object that contains additional hparams needed to
construct the model. In this case it is the `shake_shake_widen_factor`
that is used to determine how many filters the model has.
is_training: Is the model training or not.
Returns:
The logits of the Shake-Shake model.
"""
depth = 26
k = hparams.shake_shake_widen_factor # The widen factor
n = int((depth - 2) / 6)
x = images
x = ops.conv2d(x, 16, 3, scope='init_conv')
x = ops.batch_norm(x, scope='init_bn')
with tf.variable_scope('L1'):
x = _shake_shake_layer(x, 16 * k, n, 1, is_training)
with tf.variable_scope('L2'):
x = _shake_shake_layer(x, 32 * k, n, 2, is_training)
with tf.variable_scope('L3'):
x = _shake_shake_layer(x, 64 * k, n, 2, is_training)
x = tf.nn.relu(x)
x = ops.global_avg_pool(x)
# Fully connected
logits = ops.fc(x, num_classes)
return logits
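# Minimal usage sketch (illustrative, not part of the original file): assumes
# `custom_ops` is importable as `ops` and that hparams only needs
# `shake_shake_widen_factor` for this model.
def _example_build_shake_shake():
  hparams = tf.contrib.training.HParams(shake_shake_widen_factor=2)
  images = tf.placeholder(tf.float32, [None, 32, 32, 3])
  return build_shake_shake_model(images, num_classes=10, hparams=hparams,
                                 is_training=True)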
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""AutoAugment Train/Eval module.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import contextlib
import os
import time
import custom_ops as ops
import data_utils
import helper_utils
import numpy as np
from shake_drop import build_shake_drop_model
from shake_shake import build_shake_shake_model
from six.moves import xrange
import tensorflow as tf
from wrn import build_wrn_model
tf.flags.DEFINE_string('model_name', 'wrn',
'wrn, shake_shake_32, shake_shake_96, shake_shake_112, '
'pyramid_net')
tf.flags.DEFINE_string('checkpoint_dir', '/tmp/training', 'Training Directory.')
tf.flags.DEFINE_string('data_path', '/tmp/data',
'Directory where dataset is located.')
tf.flags.DEFINE_string('dataset', 'cifar10',
'Dataset to train with. Either cifar10 or cifar100')
tf.flags.DEFINE_integer('use_cpu', 1, '1 if use CPU, else GPU.')
FLAGS = tf.flags.FLAGS
arg_scope = tf.contrib.framework.arg_scope
def setup_arg_scopes(is_training):
"""Sets up the argscopes that will be used when building an image model.
Args:
is_training: Is the model training or not.
Returns:
Arg scopes to be put around the model being constructed.
"""
batch_norm_decay = 0.9
batch_norm_epsilon = 1e-5
batch_norm_params = {
# Decay for the moving averages.
'decay': batch_norm_decay,
# epsilon to prevent 0s in variance.
'epsilon': batch_norm_epsilon,
'scale': True,
# collection containing the moving mean and moving variance.
'is_training': is_training,
}
scopes = []
scopes.append(arg_scope([ops.batch_norm], **batch_norm_params))
return scopes
def build_model(inputs, num_classes, is_training, hparams):
"""Constructs the vision model being trained/evaled.
Args:
inputs: input features/images being fed to the image model being built.
num_classes: number of output classes being predicted.
is_training: is the model training or not.
hparams: additional hyperparameters associated with the image model.
Returns:
The logits of the image model.
"""
scopes = setup_arg_scopes(is_training)
with contextlib.nested(*scopes):  # Python 2 only; removed in Python 3.
if hparams.model_name == 'pyramid_net':
logits = build_shake_drop_model(
inputs, num_classes, is_training)
elif hparams.model_name == 'wrn':
logits = build_wrn_model(
inputs, num_classes, hparams.wrn_size)
elif hparams.model_name == 'shake_shake':
logits = build_shake_shake_model(
inputs, num_classes, hparams, is_training)
return logits
class CifarModel(object):
"""Builds an image model for Cifar10/Cifar100."""
def __init__(self, hparams):
self.hparams = hparams
def build(self, mode):
"""Construct the cifar model."""
assert mode in ['train', 'eval']
self.mode = mode
self._setup_misc(mode)
self._setup_images_and_labels()
self._build_graph(self.images, self.labels, mode)
self.init = tf.group(tf.global_variables_initializer(),
tf.local_variables_initializer())
def _setup_misc(self, mode):
"""Sets up miscellaneous in the cifar model constructor."""
self.lr_rate_ph = tf.Variable(0.0, name='lrn_rate', trainable=False)
self.reuse = None if (mode == 'train') else True
self.batch_size = self.hparams.batch_size
if mode == 'eval':
self.batch_size = 25
def _setup_images_and_labels(self):
"""Sets up image and label placeholders for the cifar model."""
if FLAGS.dataset == 'cifar10':
self.num_classes = 10
else:
self.num_classes = 100
self.images = tf.placeholder(tf.float32, [self.batch_size, 32, 32, 3])
self.labels = tf.placeholder(tf.float32,
[self.batch_size, self.num_classes])
def assign_epoch(self, session, epoch_value):
session.run(self._epoch_update, feed_dict={self._new_epoch: epoch_value})
def _build_graph(self, images, labels, mode):
"""Constructs the TF graph for the cifar model.
Args:
images: A 4-D image Tensor
labels: A 2-D labels Tensor.
mode: string indicating training mode (e.g., 'train', 'valid', 'test').
"""
is_training = 'train' in mode
if is_training:
self.global_step = tf.train.get_or_create_global_step()
logits = build_model(
images,
self.num_classes,
is_training,
self.hparams)
self.predictions, self.cost = helper_utils.setup_loss(
logits, labels)
self.accuracy, self.eval_op = tf.metrics.accuracy(
tf.argmax(labels, 1), tf.argmax(self.predictions, 1))
self._calc_num_trainable_params()
# Adds L2 weight decay to the cost
self.cost = helper_utils.decay_weights(self.cost,
self.hparams.weight_decay_rate)
if is_training:
self._build_train_op()
# Setup checkpointing for this child model
# Keep 2 or more checkpoints around during training.
with tf.device('/cpu:0'):
self.saver = tf.train.Saver(max_to_keep=2)
self.init = tf.group(tf.global_variables_initializer(),
tf.local_variables_initializer())
def _calc_num_trainable_params(self):
self.num_trainable_params = np.sum([
np.prod(var.get_shape().as_list()) for var in tf.trainable_variables()
])
tf.logging.info('number of trainable params: {}'.format(
self.num_trainable_params))
def _build_train_op(self):
"""Builds the train op for the cifar model."""
hparams = self.hparams
tvars = tf.trainable_variables()
grads = tf.gradients(self.cost, tvars)
if hparams.gradient_clipping_by_global_norm > 0.0:
grads, norm = tf.clip_by_global_norm(
grads, hparams.gradient_clipping_by_global_norm)
tf.summary.scalar('grad_norm', norm)
# Setup the initial learning rate
initial_lr = self.lr_rate_ph
optimizer = tf.train.MomentumOptimizer(
initial_lr,
0.9,
use_nesterov=True)
self.optimizer = optimizer
apply_op = optimizer.apply_gradients(
zip(grads, tvars), global_step=self.global_step, name='train_step')
train_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies([apply_op]):
self.train_op = tf.group(*train_ops)
class CifarModelTrainer(object):
"""Trains an instance of the CifarModel class."""
def __init__(self, hparams):
self._session = None
self.hparams = hparams
self.model_dir = os.path.join(FLAGS.checkpoint_dir, 'model')
self.log_dir = os.path.join(FLAGS.checkpoint_dir, 'log')
# Set the random seed to be sure the same validation set
# is used for each model
np.random.seed(0)
self.data_loader = data_utils.DataSet(hparams)
np.random.seed() # Put the random seed back to random
self.data_loader.reset()
def save_model(self, step=None):
"""Dumps model into the backup_dir.
Args:
step: If provided, creates a checkpoint with the given step
number, instead of overwriting the existing checkpoints.
"""
model_save_name = os.path.join(self.model_dir, 'model.ckpt')
if not tf.gfile.IsDirectory(self.model_dir):
tf.gfile.MakeDirs(self.model_dir)
self.saver.save(self.session, model_save_name, global_step=step)
tf.logging.info('Saved child model')
def extract_model_spec(self):
"""Loads a checkpoint with the architecture structure stored in the name."""
checkpoint_path = tf.train.latest_checkpoint(self.model_dir)
if checkpoint_path is not None:
self.saver.restore(self.session, checkpoint_path)
tf.logging.info('Loaded child model checkpoint from %s',
checkpoint_path)
else:
self.save_model(step=0)
def eval_child_model(self, model, data_loader, mode):
"""Evaluate the child model.
Args:
model: image model that will be evaluated.
data_loader: dataset object to extract eval data from.
mode: whether the model is evaluated on the train, val, or test split.
Returns:
Accuracy of the model on the specified dataset.
"""
tf.logging.info('Evaluating child model in mode %s', mode)
while True:
try:
with self._new_session(model):
accuracy = helper_utils.eval_child_model(
self.session,
model,
data_loader,
mode)
tf.logging.info('Eval child model accuracy: {}'.format(accuracy))
# If epoch trained without raising the below errors, break
# from loop.
break
except (tf.errors.AbortedError, tf.errors.UnavailableError) as e:
tf.logging.info('Retryable error caught: %s. Retrying.', e)
return accuracy
@contextlib.contextmanager
def _new_session(self, m):
"""Creates a new session for model m."""
# Create a new session for this model, initialize
# variables, and save / restore from
# checkpoint.
self._session = tf.Session(
'',
config=tf.ConfigProto(
allow_soft_placement=True, log_device_placement=False))
self.session.run(m.init)
# Load in a previous checkpoint, or save this one
self.extract_model_spec()
try:
yield
finally:
tf.Session.reset('')
self._session = None
def _build_models(self):
"""Builds the image models for train and eval."""
# Determine if we should build the train and eval model. When using
# distributed training we only want to build one or the other and not both.
with tf.variable_scope('model', use_resource=False):
m = CifarModel(self.hparams)
m.build('train')
self._num_trainable_params = m.num_trainable_params
self._saver = m.saver
with tf.variable_scope('model', reuse=True, use_resource=False):
meval = CifarModel(self.hparams)
meval.build('eval')
return m, meval
def _calc_starting_epoch(self, m):
"""Calculates the starting epoch for model m based on global step."""
hparams = self.hparams
batch_size = hparams.batch_size
steps_per_epoch = int(hparams.train_size / batch_size)
with self._new_session(m):
curr_step = self.session.run(m.global_step)
total_steps = steps_per_epoch * hparams.num_epochs
epochs_left = (total_steps - curr_step) // steps_per_epoch
starting_epoch = hparams.num_epochs - epochs_left
return starting_epoch
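# Worked example (illustrative): with train_size=50000 and batch_size=128,
# steps_per_epoch = 390. If num_epochs=200 and the restored global step is
# 39000, then total_steps = 78000, epochs_left = (78000 - 39000) // 390 = 100,
# and training resumes at starting_epoch = 200 - 100 = 100.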
def _run_training_loop(self, m, curr_epoch):
"""Trains the cifar model `m` for one epoch."""
start_time = time.time()
while True:
try:
with self._new_session(m):
train_accuracy = helper_utils.run_epoch_training(
self.session, m, self.data_loader, curr_epoch)
tf.logging.info('Saving model after epoch')
self.save_model(step=curr_epoch)
break
except (tf.errors.AbortedError, tf.errors.UnavailableError) as e:
tf.logging.info('Retryable error caught: %s. Retrying.', e)
tf.logging.info('Finished epoch: {}'.format(curr_epoch))
tf.logging.info('Epoch time(min): {}'.format(
(time.time() - start_time) / 60.0))
return train_accuracy
def _compute_final_accuracies(self, meval):
"""Run once training is finished to compute final val/test accuracies."""
valid_accuracy = self.eval_child_model(meval, self.data_loader, 'val')
if self.hparams.eval_test:
test_accuracy = self.eval_child_model(meval, self.data_loader, 'test')
else:
test_accuracy = 0
tf.logging.info('Test Accuracy: {}'.format(test_accuracy))
return valid_accuracy, test_accuracy
def run_model(self):
"""Trains and evalutes the image model."""
hparams = self.hparams
# Build the child graph
with tf.Graph().as_default(), tf.device(
'/cpu:0' if FLAGS.use_cpu else '/gpu:0'):
m, meval = self._build_models()
# Figure out what epoch we are on
starting_epoch = self._calc_starting_epoch(m)
# Run the validation error right at the beginning
valid_accuracy = self.eval_child_model(
meval, self.data_loader, 'val')
tf.logging.info('Before Training Epoch: {} Val Acc: {}'.format(
starting_epoch, valid_accuracy))
training_accuracy = None
for curr_epoch in xrange(starting_epoch, hparams.num_epochs):
# Run one training epoch
training_accuracy = self._run_training_loop(m, curr_epoch)
valid_accuracy = self.eval_child_model(
meval, self.data_loader, 'val')
tf.logging.info('Epoch: {} Valid Acc: {}'.format(
curr_epoch, valid_accuracy))
valid_accuracy, test_accuracy = self._compute_final_accuracies(
meval)
tf.logging.info(
'Train Acc: {} Valid Acc: {} Test Acc: {}'.format(
training_accuracy, valid_accuracy, test_accuracy))
@property
def saver(self):
return self._saver
@property
def session(self):
return self._session
@property
def num_trainable_params(self):
return self._num_trainable_params
def main(_):
if FLAGS.dataset not in ['cifar10', 'cifar100']:
raise ValueError('Invalid dataset: %s' % FLAGS.dataset)
hparams = tf.contrib.training.HParams(
train_size=50000,
validation_size=0,
eval_test=1,
dataset=FLAGS.dataset,
data_path=FLAGS.data_path,
batch_size=128,
gradient_clipping_by_global_norm=5.0)
if FLAGS.model_name == 'wrn':
hparams.add_hparam('model_name', 'wrn')
hparams.add_hparam('num_epochs', 200)
hparams.add_hparam('wrn_size', 160)
hparams.add_hparam('lr', 0.1)
hparams.add_hparam('weight_decay_rate', 5e-4)
elif FLAGS.model_name == 'shake_shake_32':
hparams.add_hparam('model_name', 'shake_shake')
hparams.add_hparam('num_epochs', 1800)
hparams.add_hparam('shake_shake_widen_factor', 2)
hparams.add_hparam('lr', 0.01)
hparams.add_hparam('weight_decay_rate', 0.001)
elif FLAGS.model_name == 'shake_shake_96':
hparams.add_hparam('model_name', 'shake_shake')
hparams.add_hparam('num_epochs', 1800)
hparams.add_hparam('shake_shake_widen_factor', 6)
hparams.add_hparam('lr', 0.01)
hparams.add_hparam('weight_decay_rate', 0.001)
elif FLAGS.model_name == 'shake_shake_112':
hparams.add_hparam('model_name', 'shake_shake')
hparams.add_hparam('num_epochs', 1800)
hparams.add_hparam('shake_shake_widen_factor', 7)
hparams.add_hparam('lr', 0.01)
hparams.add_hparam('weight_decay_rate', 0.001)
elif FLAGS.model_name == 'pyramid_net':
hparams.add_hparam('model_name', 'pyramid_net')
hparams.add_hparam('num_epochs', 1800)
hparams.add_hparam('lr', 0.05)
hparams.add_hparam('weight_decay_rate', 5e-5)
hparams.batch_size = 64
else:
raise ValueError('Invalid model name: %s' % FLAGS.model_name)
cifar_trainer = CifarModelTrainer(hparams)
cifar_trainer.run_model()
if __name__ == '__main__':
tf.logging.set_verbosity(tf.logging.INFO)
tf.app.run()
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Builds the Wide-ResNet Model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import custom_ops as ops
import numpy as np
import tensorflow as tf
def residual_block(
x, in_filter, out_filter, stride, activate_before_residual=False):
"""Adds residual connection to `x` in addition to applying BN->ReLU->3x3 Conv.
Args:
x: Tensor that is the output of the previous layer in the model.
in_filter: Number of filters `x` has.
out_filter: Number of filters that the output of this layer will have.
stride: Integer that specifies what stride should be applied to `x`.
activate_before_residual: Boolean on whether a BN->ReLU should be applied
to x before the convolution is applied.
Returns:
A Tensor that is the result of applying two sequences of BN->ReLU->3x3 Conv
and then adding that Tensor to `x`.
"""
if activate_before_residual: # Apply BN->ReLU before the split so both paths share it
with tf.variable_scope('shared_activation'):
x = ops.batch_norm(x, scope='init_bn')
x = tf.nn.relu(x)
orig_x = x
else:
orig_x = x
block_x = x
if not activate_before_residual:
with tf.variable_scope('residual_only_activation'):
block_x = ops.batch_norm(block_x, scope='init_bn')
block_x = tf.nn.relu(block_x)
with tf.variable_scope('sub1'):
block_x = ops.conv2d(
block_x, out_filter, 3, stride=stride, scope='conv1')
with tf.variable_scope('sub2'):
block_x = ops.batch_norm(block_x, scope='bn2')
block_x = tf.nn.relu(block_x)
block_x = ops.conv2d(
block_x, out_filter, 3, stride=1, scope='conv2')
with tf.variable_scope(
'sub_add'): # If the filter counts do not agree, zero-pad the skip path
if in_filter != out_filter:
orig_x = ops.avg_pool(orig_x, stride, stride)
orig_x = ops.zero_pad(orig_x, in_filter, out_filter)
x = orig_x + block_x
return x
def _res_add(in_filter, out_filter, stride, x, orig_x):
"""Adds `x` with `orig_x`, both of which are layers in the model.
Args:
in_filter: Number of filters in `orig_x`.
out_filter: Number of filters in `x`.
stride: Integer specifying the stride that should be applied to `orig_x`.
x: Tensor that is the output of the previous layer.
orig_x: Tensor that is the output of an earlier layer in the network.
Returns:
A Tensor that is the result of `x` and `orig_x` being added after
zero padding and striding are applied to `orig_x` to get the shapes
to match.
"""
if in_filter != out_filter:
orig_x = ops.avg_pool(orig_x, stride, stride)
orig_x = ops.zero_pad(orig_x, in_filter, out_filter)
x = x + orig_x
orig_x = x
return x, orig_x
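# Shape sketch (illustrative): for a residual add going from in_filter=16 to
# out_filter=32 with stride 2 on a [N, 32, 32, 16] tensor, avg_pool reduces
# orig_x to [N, 16, 16, 16] and zero_pad appends 16 zero channels, giving
# [N, 16, 16, 32] so the addition with `x` is well defined.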
def build_wrn_model(images, num_classes, wrn_size):
"""Builds the WRN model.
Build the Wide ResNet model from https://arxiv.org/abs/1605.07146.
Args:
images: Tensor of images that will be fed into the Wide ResNet Model.
num_classes: Number of classes that the model needs to predict.
wrn_size: Parameter that scales the number of filters in the Wide ResNet
model.
Returns:
The logits of the Wide ResNet model.
"""
kernel_size = wrn_size # Despite the name, this is the base filter count, not a spatial kernel size
filter_size = 3
num_blocks_per_resnet = 4
filters = [
min(kernel_size, 16), kernel_size, kernel_size * 2, kernel_size * 4
]
strides = [1, 2, 2] # stride for each resblock
# Run the first conv
with tf.variable_scope('init'):
x = images
output_filters = filters[0]
x = ops.conv2d(x, output_filters, filter_size, scope='init_conv')
first_x = x # Res from the beginning
orig_x = x # Res from previous block
for block_num in range(1, 4):
with tf.variable_scope('unit_{}_0'.format(block_num)):
activate_before_residual = (block_num == 1)
x = residual_block(
x,
filters[block_num - 1],
filters[block_num],
strides[block_num - 1],
activate_before_residual=activate_before_residual)
for i in range(1, num_blocks_per_resnet):
with tf.variable_scope('unit_{}_{}'.format(block_num, i)):
x = residual_block(
x,
filters[block_num],
filters[block_num],
1,
activate_before_residual=False)
x, orig_x = _res_add(filters[block_num - 1], filters[block_num],
strides[block_num - 1], x, orig_x)
final_stride_val = np.prod(strides)
x, _ = _res_add(filters[0], filters[3], final_stride_val, x, first_x)
with tf.variable_scope('unit_last'):
x = ops.batch_norm(x, scope='final_bn')
x = tf.nn.relu(x)
x = ops.global_avg_pool(x)
logits = ops.fc(x, num_classes)
return logits
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import sklearn.preprocessing as prep
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
from autoencoder_models.DenoisingAutoencoder import AdditiveGaussianNoiseAutoencoder
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
def standard_scale(X_train, X_test):
preprocessor = prep.StandardScaler().fit(X_train)
X_train = preprocessor.transform(X_train)
X_test = preprocessor.transform(X_test)
return X_train, X_test
def get_random_block_from_data(data, batch_size):
start_index = np.random.randint(0, len(data) - batch_size)
return data[start_index:(start_index + batch_size)]
X_train, X_test = standard_scale(mnist.train.images, mnist.test.images)
n_samples = int(mnist.train.num_examples)
training_epochs = 20
batch_size = 128
display_step = 1
autoencoder = AdditiveGaussianNoiseAutoencoder(
n_input=784,
n_hidden=200,
transfer_function=tf.nn.softplus,
optimizer=tf.train.AdamOptimizer(learning_rate = 0.001),
scale=0.01)
for epoch in range(training_epochs):
avg_cost = 0.
total_batch = int(n_samples / batch_size)
# Loop over all batches
for i in range(total_batch):
batch_xs = get_random_block_from_data(X_train, batch_size)
# Fit training using batch data
cost = autoencoder.partial_fit(batch_xs)
# Compute average loss
avg_cost += cost / n_samples * batch_size
# Display logs per epoch step
if epoch % display_step == 0:
print("Epoch:", '%d,' % (epoch + 1),
"Cost:", "{:.9f}".format(avg_cost))
print("Total cost: " + str(autoencoder.calc_total_cost(X_test)))
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import sklearn.preprocessing as prep
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
from autoencoder_models.Autoencoder import Autoencoder
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
def standard_scale(X_train, X_test):
preprocessor = prep.StandardScaler().fit(X_train)
X_train = preprocessor.transform(X_train)
X_test = preprocessor.transform(X_test)
return X_train, X_test
def get_random_block_from_data(data, batch_size):
start_index = np.random.randint(0, len(data) - batch_size)
return data[start_index:(start_index + batch_size)]
X_train, X_test = standard_scale(mnist.train.images, mnist.test.images)
n_samples = int(mnist.train.num_examples)
training_epochs = 20
batch_size = 128
display_step = 1
autoencoder = Autoencoder(n_layers=[784, 200],
transfer_function = tf.nn.softplus,
optimizer = tf.train.AdamOptimizer(learning_rate = 0.001))
for epoch in range(training_epochs):
avg_cost = 0.
total_batch = int(n_samples / batch_size)
# Loop over all batches
for i in range(total_batch):
batch_xs = get_random_block_from_data(X_train, batch_size)
# Fit training using batch data
cost = autoencoder.partial_fit(batch_xs)
# Compute average loss
avg_cost += cost / n_samples * batch_size
# Display logs per epoch step
if epoch % display_step == 0:
print("Epoch:", '%d,' % (epoch + 1),
"Cost:", "{:.9f}".format(avg_cost))
print("Total cost: " + str(autoencoder.calc_total_cost(X_test)))
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import sklearn.preprocessing as prep
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
from autoencoder_models.DenoisingAutoencoder import MaskingNoiseAutoencoder
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
def standard_scale(X_train, X_test):
preprocessor = prep.StandardScaler().fit(X_train)
X_train = preprocessor.transform(X_train)
X_test = preprocessor.transform(X_test)
return X_train, X_test
def get_random_block_from_data(data, batch_size):
start_index = np.random.randint(0, len(data) - batch_size)
return data[start_index:(start_index + batch_size)]
X_train, X_test = standard_scale(mnist.train.images, mnist.test.images)
n_samples = int(mnist.train.num_examples)
training_epochs = 100
batch_size = 128
display_step = 1
autoencoder = MaskingNoiseAutoencoder(
n_input=784,
n_hidden=200,
transfer_function=tf.nn.softplus,
optimizer=tf.train.AdamOptimizer(learning_rate=0.001),
dropout_probability=0.95)
for epoch in range(training_epochs):
avg_cost = 0.
total_batch = int(n_samples / batch_size)
for i in range(total_batch):
batch_xs = get_random_block_from_data(X_train, batch_size)
cost = autoencoder.partial_fit(batch_xs)
avg_cost += cost / n_samples * batch_size
if epoch % display_step == 0:
print("Epoch:", '%d,' % (epoch + 1),
"Cost:", "{:.9f}".format(avg_cost))
print("Total cost: " + str(autoencoder.calc_total_cost(X_test)))
![No Maintenance Intended](https://img.shields.io/badge/No%20Maintenance%20Intended-%E2%9C%95-red.svg)
![TensorFlow Requirement: 1.x](https://img.shields.io/badge/TensorFlow%20Requirement-1.x-brightgreen)
![TensorFlow 2 Not Supported](https://img.shields.io/badge/TensorFlow%202%20Not%20Supported-%E2%9C%95-red.svg)
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import sklearn.preprocessing as prep
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
from autoencoder_models.VariationalAutoencoder import VariationalAutoencoder
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
def min_max_scale(X_train, X_test):
preprocessor = prep.MinMaxScaler().fit(X_train)
X_train = preprocessor.transform(X_train)
X_test = preprocessor.transform(X_test)
return X_train, X_test
def get_random_block_from_data(data, batch_size):
start_index = np.random.randint(0, len(data) - batch_size)
return data[start_index:(start_index + batch_size)]
X_train, X_test = min_max_scale(mnist.train.images, mnist.test.images)
n_samples = int(mnist.train.num_examples)
training_epochs = 20
batch_size = 128
display_step = 1
autoencoder = VariationalAutoencoder(
n_input=784,
n_hidden=200,
optimizer=tf.train.AdamOptimizer(learning_rate = 0.001))
for epoch in range(training_epochs):
avg_cost = 0.
total_batch = int(n_samples / batch_size)
# Loop over all batches
for i in range(total_batch):
batch_xs = get_random_block_from_data(X_train, batch_size)
# Fit training using batch data
cost = autoencoder.partial_fit(batch_xs)
# Compute average loss
avg_cost += cost / n_samples * batch_size
# Display logs per epoch step
if epoch % display_step == 0:
print("Epoch:", '%d,' % (epoch + 1),
"Cost:", "{:.9f}".format(avg_cost))
print("Total cost: " + str(autoencoder.calc_total_cost(X_test)))
import numpy as np
import tensorflow as tf
class Autoencoder(object):
def __init__(self, n_layers, transfer_function=tf.nn.softplus, optimizer=tf.train.AdamOptimizer()):
self.n_layers = n_layers
self.transfer = transfer_function
network_weights = self._initialize_weights()
self.weights = network_weights
# model
self.x = tf.placeholder(tf.float32, [None, self.n_layers[0]])
self.hidden_encode = []
h = self.x
for layer in range(len(self.n_layers)-1):
h = self.transfer(
tf.add(tf.matmul(h, self.weights['encode'][layer]['w']),
self.weights['encode'][layer]['b']))
self.hidden_encode.append(h)
self.hidden_recon = []
for layer in range(len(self.n_layers)-1):
h = self.transfer(
tf.add(tf.matmul(h, self.weights['recon'][layer]['w']),
self.weights['recon'][layer]['b']))
self.hidden_recon.append(h)
self.reconstruction = self.hidden_recon[-1]
# cost
self.cost = 0.5 * tf.reduce_sum(tf.pow(tf.subtract(self.reconstruction, self.x), 2.0))
self.optimizer = optimizer.minimize(self.cost)
init = tf.global_variables_initializer()
self.sess = tf.Session()
self.sess.run(init)
def _initialize_weights(self):
all_weights = dict()
initializer = tf.contrib.layers.xavier_initializer()
# Encoding network weights
encoder_weights = []
for layer in range(len(self.n_layers)-1):
w = tf.Variable(
initializer((self.n_layers[layer], self.n_layers[layer + 1]),
dtype=tf.float32))
b = tf.Variable(
tf.zeros([self.n_layers[layer + 1]], dtype=tf.float32))
encoder_weights.append({'w': w, 'b': b})
# Recon network weights
recon_weights = []
for layer in range(len(self.n_layers)-1, 0, -1):
w = tf.Variable(
initializer((self.n_layers[layer], self.n_layers[layer - 1]),
dtype=tf.float32))
b = tf.Variable(
tf.zeros([self.n_layers[layer - 1]], dtype=tf.float32))
recon_weights.append({'w': w, 'b': b})
all_weights['encode'] = encoder_weights
all_weights['recon'] = recon_weights
return all_weights
def partial_fit(self, X):
cost, opt = self.sess.run((self.cost, self.optimizer), feed_dict={self.x: X})
return cost
def calc_total_cost(self, X):
return self.sess.run(self.cost, feed_dict={self.x: X})
def transform(self, X):
return self.sess.run(self.hidden_encode[-1], feed_dict={self.x: X})
def generate(self, hidden=None):
if hidden is None:
hidden = np.random.normal(size=(1, self.n_layers[-1]))  # sample in the shape of the last hidden layer
return self.sess.run(self.reconstruction, feed_dict={self.hidden_encode[-1]: hidden})
def reconstruct(self, X):
return self.sess.run(self.reconstruction, feed_dict={self.x: X})
def getWeights(self):
# Not implemented: `weights` is a nested per-layer structure, so a single
# weight matrix is not well defined for this multi-layer autoencoder.
raise NotImplementedError
def getBiases(self):
raise NotImplementedError
import tensorflow as tf
class AdditiveGaussianNoiseAutoencoder(object):
def __init__(self, n_input, n_hidden, transfer_function = tf.nn.softplus, optimizer = tf.train.AdamOptimizer(),
scale = 0.1):
self.n_input = n_input
self.n_hidden = n_hidden
self.transfer = transfer_function
self.scale = tf.placeholder(tf.float32)
self.training_scale = scale
network_weights = self._initialize_weights()
self.weights = network_weights
# model
self.x = tf.placeholder(tf.float32, [None, self.n_input])
self.hidden = self.transfer(tf.add(tf.matmul(self.x + self.scale * tf.random_normal((n_input,)),
self.weights['w1']),
self.weights['b1']))
self.reconstruction = tf.add(tf.matmul(self.hidden, self.weights['w2']), self.weights['b2'])
# cost
self.cost = 0.5 * tf.reduce_sum(tf.pow(tf.subtract(self.reconstruction, self.x), 2.0))
self.optimizer = optimizer.minimize(self.cost)
init = tf.global_variables_initializer()
self.sess = tf.Session()
self.sess.run(init)
def _initialize_weights(self):
all_weights = dict()
all_weights['w1'] = tf.get_variable("w1", shape=[self.n_input, self.n_hidden],
initializer=tf.contrib.layers.xavier_initializer())
all_weights['b1'] = tf.Variable(tf.zeros([self.n_hidden], dtype = tf.float32))
all_weights['w2'] = tf.Variable(tf.zeros([self.n_hidden, self.n_input], dtype = tf.float32))
all_weights['b2'] = tf.Variable(tf.zeros([self.n_input], dtype = tf.float32))
return all_weights
def partial_fit(self, X):
cost, opt = self.sess.run((self.cost, self.optimizer), feed_dict = {self.x: X,
self.scale: self.training_scale
})
return cost
def calc_total_cost(self, X):
return self.sess.run(self.cost, feed_dict = {self.x: X,
self.scale: self.training_scale
})
def transform(self, X):
return self.sess.run(self.hidden, feed_dict = {self.x: X,
self.scale: self.training_scale
})
def generate(self, hidden=None):
if hidden is None:
hidden = self.sess.run(tf.random_normal([1, self.n_hidden]))
return self.sess.run(self.reconstruction, feed_dict = {self.hidden: hidden})
def reconstruct(self, X):
return self.sess.run(self.reconstruction, feed_dict = {self.x: X,
self.scale: self.training_scale
})
def getWeights(self):
return self.sess.run(self.weights['w1'])
def getBiases(self):
return self.sess.run(self.weights['b1'])
class MaskingNoiseAutoencoder(object):
def __init__(self, n_input, n_hidden, transfer_function = tf.nn.softplus, optimizer = tf.train.AdamOptimizer(),
dropout_probability = 0.95):
self.n_input = n_input
self.n_hidden = n_hidden
self.transfer = transfer_function
self.dropout_probability = dropout_probability
self.keep_prob = tf.placeholder(tf.float32)
network_weights = self._initialize_weights()
self.weights = network_weights
# model
self.x = tf.placeholder(tf.float32, [None, self.n_input])
self.hidden = self.transfer(tf.add(tf.matmul(tf.nn.dropout(self.x, self.keep_prob), self.weights['w1']),
self.weights['b1']))
self.reconstruction = tf.add(tf.matmul(self.hidden, self.weights['w2']), self.weights['b2'])
# cost
self.cost = 0.5 * tf.reduce_sum(tf.pow(tf.subtract(self.reconstruction, self.x), 2.0))
self.optimizer = optimizer.minimize(self.cost)
init = tf.global_variables_initializer()
self.sess = tf.Session()
self.sess.run(init)
def _initialize_weights(self):
all_weights = dict()
all_weights['w1'] = tf.get_variable("w1", shape=[self.n_input, self.n_hidden],
initializer=tf.contrib.layers.xavier_initializer())
all_weights['b1'] = tf.Variable(tf.zeros([self.n_hidden], dtype = tf.float32))
all_weights['w2'] = tf.Variable(tf.zeros([self.n_hidden, self.n_input], dtype = tf.float32))
all_weights['b2'] = tf.Variable(tf.zeros([self.n_input], dtype = tf.float32))
return all_weights
def partial_fit(self, X):
cost, opt = self.sess.run((self.cost, self.optimizer),
feed_dict = {self.x: X, self.keep_prob: self.dropout_probability})
return cost
def calc_total_cost(self, X):
return self.sess.run(self.cost, feed_dict = {self.x: X, self.keep_prob: 1.0})
def transform(self, X):
return self.sess.run(self.hidden, feed_dict = {self.x: X, self.keep_prob: 1.0})
def generate(self, hidden=None):
if hidden is None:
hidden = self.sess.run(tf.random_normal([1, self.n_hidden]))
return self.sess.run(self.reconstruction, feed_dict = {self.hidden: hidden})
def reconstruct(self, X):
return self.sess.run(self.reconstruction, feed_dict = {self.x: X, self.keep_prob: 1.0})
def getWeights(self):
return self.sess.run(self.weights['w1'])
def getBiases(self):
return self.sess.run(self.weights['b1'])
import tensorflow as tf
class VariationalAutoencoder(object):
def __init__(self, n_input, n_hidden, optimizer = tf.train.AdamOptimizer()):
self.n_input = n_input
self.n_hidden = n_hidden
network_weights = self._initialize_weights()
self.weights = network_weights
# model
self.x = tf.placeholder(tf.float32, [None, self.n_input])
self.z_mean = tf.add(tf.matmul(self.x, self.weights['w1']), self.weights['b1'])
self.z_log_sigma_sq = tf.add(tf.matmul(self.x, self.weights['log_sigma_w1']), self.weights['log_sigma_b1'])
# sample from gaussian distribution
eps = tf.random_normal(tf.stack([tf.shape(self.x)[0], self.n_hidden]), 0, 1, dtype = tf.float32)
self.z = tf.add(self.z_mean, tf.multiply(tf.sqrt(tf.exp(self.z_log_sigma_sq)), eps))
self.reconstruction = tf.add(tf.matmul(self.z, self.weights['w2']), self.weights['b2'])
# cost
reconstr_loss = 0.5 * tf.reduce_sum(tf.pow(tf.subtract(self.reconstruction, self.x), 2.0), 1)
latent_loss = -0.5 * tf.reduce_sum(1 + self.z_log_sigma_sq
- tf.square(self.z_mean)
- tf.exp(self.z_log_sigma_sq), 1)
self.cost = tf.reduce_mean(reconstr_loss + latent_loss)
self.optimizer = optimizer.minimize(self.cost)
init = tf.global_variables_initializer()
self.sess = tf.Session()
self.sess.run(init)
def _initialize_weights(self):
all_weights = dict()
all_weights['w1'] = tf.get_variable("w1", shape=[self.n_input, self.n_hidden],
initializer=tf.contrib.layers.xavier_initializer())
all_weights['log_sigma_w1'] = tf.get_variable("log_sigma_w1", shape=[self.n_input, self.n_hidden],
initializer=tf.contrib.layers.xavier_initializer())
all_weights['b1'] = tf.Variable(tf.zeros([self.n_hidden], dtype=tf.float32))
all_weights['log_sigma_b1'] = tf.Variable(tf.zeros([self.n_hidden], dtype=tf.float32))
all_weights['w2'] = tf.Variable(tf.zeros([self.n_hidden, self.n_input], dtype=tf.float32))
all_weights['b2'] = tf.Variable(tf.zeros([self.n_input], dtype=tf.float32))
return all_weights
def partial_fit(self, X):
cost, opt = self.sess.run((self.cost, self.optimizer), feed_dict={self.x: X})
return cost
def calc_total_cost(self, X):
return self.sess.run(self.cost, feed_dict = {self.x: X})
def transform(self, X):
return self.sess.run(self.z_mean, feed_dict={self.x: X})
def generate(self, hidden = None):
if hidden is None:
hidden = self.sess.run(tf.random_normal([1, self.n_hidden]))
return self.sess.run(self.reconstruction, feed_dict={self.z: hidden})
def reconstruct(self, X):
return self.sess.run(self.reconstruction, feed_dict={self.x: X})
def getWeights(self):
return self.sess.run(self.weights['w1'])
def getBiases(self):
return self.sess.run(self.weights['b1'])
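# A minimal standalone sketch (illustrative, not part of the original file) of
# the reparameterization step used in the constructor above: sampling
# z = mean + sqrt(exp(log_sigma_sq)) * eps with eps ~ N(0, I) keeps the sample
# differentiable with respect to mean and log_sigma_sq.
import numpy as np

def _reparameterize_demo(z_mean, z_log_sigma_sq):
  eps = np.random.normal(size=z_mean.shape)
  return z_mean + np.sqrt(np.exp(z_log_sigma_sq)) * eps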
![No Maintenance Intended](https://img.shields.io/badge/No%20Maintenance%20Intended-%E2%9C%95-red.svg)
![TensorFlow Requirement: 1.x](https://img.shields.io/badge/TensorFlow%20Requirement-1.x-brightgreen)
![TensorFlow 2 Not Supported](https://img.shields.io/badge/TensorFlow%202%20Not%20Supported-%E2%9C%95-red.svg)
# Brain Coder
*Authors: Daniel Abolafia, Mohammad Norouzi, Quoc Le*
Brain Coder is an experimental environment for program synthesis. We provide code that reproduces the results from our recent paper [Neural Program Synthesis with Priority Queue Training](https://arxiv.org/abs/1801.03526). See single_task/README.md for details on how to build and reproduce those experiments.
## Installation
First install dependencies separately:
* [bazel](https://docs.bazel.build/versions/master/install.html)
* [TensorFlow](https://www.tensorflow.org/install/)
* [scipy](https://www.scipy.org/install.html)
* [absl-py](https://github.com/abseil/abseil-py)
Note: even if you already have these dependencies installed, make sure they are
up-to-date to avoid unnecessary debugging.
## Building
Use bazel from the top-level repo directory.
For example:
```bash
bazel build single_task:run
```
View README.md files in subdirectories for more details.
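To run a library's unit tests, point bazel at the corresponding test target
(target path illustrative):

```bash
bazel test common:bf_test
```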
git_repository(
name = "subpar",
remote = "https://github.com/google/subpar",
tag = "1.0.0",
)
licenses(["notice"])
package(default_visibility = [
"//:__subpackages__",
])
py_library(
name = "bf",
srcs = ["bf.py"],
)
py_test(
name = "bf_test",
srcs = ["bf_test.py"],
deps = [
":bf",
# tensorflow dep
],
)
py_library(
name = "config_lib",
srcs = ["config_lib.py"],
)
py_test(
name = "config_lib_test",
srcs = ["config_lib_test.py"],
deps = [
":config_lib",
# tensorflow dep
],
)
py_library(
name = "reward",
srcs = ["reward.py"],
)
py_test(
name = "reward_test",
srcs = ["reward_test.py"],
deps = [
":reward",
# numpy dep
# tensorflow dep
],
)
py_library(
name = "rollout",
srcs = ["rollout.py"],
deps = [
":utils",
# numpy dep
# scipy dep
],
)
py_test(
name = "rollout_test",
srcs = ["rollout_test.py"],
deps = [
":rollout",
# numpy dep
# tensorflow dep
],
)
py_library(
name = "schedules",
srcs = ["schedules.py"],
deps = [":config_lib"],
)
py_test(
name = "schedules_test",
srcs = ["schedules_test.py"],
deps = [
":config_lib",
":schedules",
# numpy dep
# tensorflow dep
],
)
py_library(
name = "utils",
srcs = ["utils.py"],
deps = [
# file dep
# absl dep /logging
# numpy dep
# tensorflow dep
],
)
py_test(
name = "utils_test",
srcs = ["utils_test.py"],
deps = [
":utils",
# numpy dep
# tensorflow dep
],
)
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
"""BrainF**k interpreter.
Language info: https://en.wikipedia.org/wiki/Brainfuck
Based on public implementation:
https://github.com/pocmo/Python-Brainfuck/blob/master/brainfuck.py
"""
from collections import namedtuple
import time
EvalResult = namedtuple(
'EvalResult', ['output', 'success', 'failure_reason', 'steps', 'time',
'memory', 'program_trace'])
ExecutionSnapshot = namedtuple(
'ExecutionSnapshot',
['codeptr', 'codechar', 'memptr', 'memval', 'memory', 'next_input',
'output_buffer'])
class Status(object):
SUCCESS = 'success'
TIMEOUT = 'timeout'
STEP_LIMIT = 'step-limit'
SYNTAX_ERROR = 'syntax-error'
CHARS = INT_TO_CHAR = ['>', '<', '+', '-', '[', ']', '.', ',']
CHAR_TO_INT = dict([(c, i) for i, c in enumerate(INT_TO_CHAR)])
class LookAheadIterator(object):
"""Same API as Python iterator, with additional peek method."""
def __init__(self, iterable):
self._it = iter(iterable)
self._current_element = None
self._done = False
self._preload_next()
def _preload_next(self):
try:
self._current_element = next(self._it)  # works under both Python 2 and 3
except StopIteration:
self._done = True
def next(self):
if self._done:
raise StopIteration
element = self._current_element
self._preload_next()
return element
def peek(self, default_value=None):
if self._done:
if default_value is None:
raise StopIteration
return default_value
return self._current_element
def buildbracemap(code):
"""Build jump map.
Args:
code: List or string of BF chars.
Returns:
bracemap: dict mapping open and close brace positions in the code to their
destination jumps. Specifically, positions of matching open/close braces
if they exist.
correct_syntax: True if all braces match. False if there are unmatched
braces in the code. Even if there are unmatched braces, a bracemap will
be built, and unmatched braces will map to themselves.
"""
bracestack, bracemap = [], {}
correct_syntax = True
for position, command in enumerate(code):
if command == '[':
bracestack.append(position)
if command == ']':
if not bracestack: # Unmatched closing brace.
bracemap[position] = position # Don't jump to any position.
correct_syntax = False
continue
start = bracestack.pop()
bracemap[start] = position
bracemap[position] = start
if bracestack: # Unmatched opening braces.
for pos in bracestack:
bracemap[pos] = pos # Don't jump to any position.
correct_syntax = False
return bracemap, correct_syntax
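# Worked example (illustrative): buildbracemap('+[.]') returns
# ({1: 3, 3: 1}, True), since the '[' at position 1 matches the ']' at
# position 3. buildbracemap('+]') returns ({1: 1}, False): the unmatched ']'
# maps to itself and the syntax is flagged as incorrect.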
def evaluate(code, input_buffer=None, init_memory=None, base=256, timeout=1.0,
max_steps=None, require_correct_syntax=True, output_memory=False,
debug=False):
"""Execute BF code.
Args:
code: String or list of BF characters. Any character not in CHARS will be
ignored.
input_buffer: A list of ints which will be used as the program's input
stream. Each read op "," will read an int from this list. 0's will be
read once the end of the list is reached, or if no input buffer is
given.
init_memory: A list of ints. Memory for first k positions will be
initialized to this list (where k = len(init_memory)). Memory positions
are initialized to 0 by default.
base: Integer base for the memory. When a memory value is incremented to
`base` it will overflow to 0. When a memory value is decremented to -1
it will underflow to `base` - 1.
timeout: Time limit for program execution in seconds. Set to None to
disable.
max_steps: Execution step limit. An execution step is the execution of one
operation (code character), even if that op has been executed before.
Execution exits when this many steps are reached. Set to None to
disable. Disabled by default.
require_correct_syntax: If True, unmatched braces will cause `evaluate` to
return without executing the code. The failure reason will be
`Status.SYNTAX_ERROR`. If False, unmatched braces are ignored
and execution will continue.
output_memory: If True, the state of the memory at the end of execution is
returned.
debug: If True, then a full program trace will be returned.
Returns:
EvalResult namedtuple containing
output: List of ints which were written out by the program with the "."
operation.
success: Boolean. Whether execution completed successfully.
failure_reason: One of the attributes of `Status`. Gives extra info
about why execution was not successful.
steps: Number of execution steps the program ran for.
time: Amount of time in seconds the program ran for.
memory: If `output_memory` is True, a list of memory cells up to the last
one written to; otherwise, None.
program_trace: If `debug` is True, a list of ExecutionSnapshot tuples, one
per executed step; otherwise, None.
"""
input_iter = (
LookAheadIterator(input_buffer) if input_buffer is not None
else LookAheadIterator([]))
# Null memory value. This is the value of an empty memory. Also the value
# returned by the read operation when the input buffer is empty, or the
# end of the buffer is reached.
null_value = 0
code = list(code)
bracemap, correct_syntax = buildbracemap(code)  # Build the jump table for braces.
if require_correct_syntax and not correct_syntax:
return EvalResult([], False, Status.SYNTAX_ERROR, 0, 0.0,
[] if output_memory else None, [] if debug else None)
output_buffer = []
codeptr, cellptr = 0, 0
cells = list(init_memory) if init_memory else [0]
program_trace = [] if debug else None
success = True
reason = Status.SUCCESS
start_time = time.time()
steps = 0
while codeptr < len(code):
command = code[codeptr]
if debug:
# Add step to program trace.
program_trace.append(ExecutionSnapshot(
codeptr=codeptr, codechar=command, memptr=cellptr,
memval=cells[cellptr], memory=list(cells),
next_input=input_iter.peek(null_value),
output_buffer=list(output_buffer)))
if command == '>':
cellptr += 1
if cellptr == len(cells): cells.append(null_value)
if command == '<':
cellptr = 0 if cellptr <= 0 else cellptr - 1
if command == '+':
cells[cellptr] = cells[cellptr] + 1 if cells[cellptr] < (base - 1) else 0
if command == '-':
cells[cellptr] = cells[cellptr] - 1 if cells[cellptr] > 0 else (base - 1)
if command == '[' and cells[cellptr] == 0: codeptr = bracemap[codeptr]
if command == ']' and cells[cellptr] != 0: codeptr = bracemap[codeptr]
if command == '.': output_buffer.append(cells[cellptr])
if command == ',': cells[cellptr] = next(input_iter, null_value)
codeptr += 1
steps += 1
if timeout is not None and time.time() - start_time > timeout:
success = False
reason = Status.TIMEOUT
break
if max_steps is not None and steps >= max_steps:
success = False
reason = Status.STEP_LIMIT
break
if debug:
# Add step to program trace.
command = code[codeptr] if codeptr < len(code) else ''
program_trace.append(ExecutionSnapshot(
codeptr=codeptr, codechar=command, memptr=cellptr,
memval=cells[cellptr], memory=list(cells),
next_input=input_iter.peek(null_value),
output_buffer=list(output_buffer)))
return EvalResult(
output=output_buffer,
success=success,
failure_reason=reason,
steps=steps,
time=time.time() - start_time,
memory=cells if output_memory else None,
program_trace=program_trace)
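# Illustrative usage (mirrors the tests in bf_test.py): evaluate('+++.')
# returns an EvalResult with output=[3], success=True, and
# failure_reason=Status.SUCCESS.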
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
"""Tests for common.bf."""
import tensorflow as tf
from common import bf # brain coder
class BfTest(tf.test.TestCase):
def assertCorrectOutput(self, target_output, eval_result):
self.assertEqual(target_output, eval_result.output)
self.assertTrue(eval_result.success)
self.assertEqual(bf.Status.SUCCESS, eval_result.failure_reason)
def testBasicOps(self):
self.assertCorrectOutput(
[3, 1, 2],
bf.evaluate('+++.--.+.'))
self.assertCorrectOutput(
[1, 1, 2],
bf.evaluate('+.<.>++.'))
self.assertCorrectOutput(
[0],
bf.evaluate('+,.'))
self.assertCorrectOutput(
[ord(char) for char in 'Hello World!\n'],
bf.evaluate(
'>++++++++[-<+++++++++>]<.>>+>-[+]++>++>+++[>[->+++<<+++>]<<]>-----'
'.>->+++..+++.>-.<<+[>[+>+]>>]<--------------.>>.+++.------.-------'
'-.>+.>+.'))
def testBase(self):
self.assertCorrectOutput(
[1, 4],
bf.evaluate('+.--.', base=5, input_buffer=[]))
def testInputBuffer(self):
self.assertCorrectOutput(
[2, 3, 4],
bf.evaluate('>,[>,]<[.<]', input_buffer=[4, 3, 2]))
def testBadChars(self):
self.assertCorrectOutput(
[2, 3, 4],
bf.evaluate('>,[>,]hello<world[.<]comments',
input_buffer=[4, 3, 2]))
def testUnmatchedBraces(self):
self.assertCorrectOutput(
[3, 6, 1],
bf.evaluate('+++.]]]]>----.[[[[[>+.',
input_buffer=[],
base=10,
require_correct_syntax=False))
eval_result = bf.evaluate(
'+++.]]]]>----.[[[[[>+.',
input_buffer=[],
base=10,
require_correct_syntax=True)
self.assertEqual([], eval_result.output)
self.assertFalse(eval_result.success)
self.assertEqual(bf.Status.SYNTAX_ERROR,
eval_result.failure_reason)
def testTimeout(self):
er = bf.evaluate('+.[].', base=5, input_buffer=[], timeout=0.1)
self.assertEqual(
([1], False, bf.Status.TIMEOUT),
(er.output, er.success, er.failure_reason))
self.assertTrue(0.07 < er.time < 0.21)
er = bf.evaluate('+.[-].', base=5, input_buffer=[], timeout=0.1)
self.assertEqual(
([1, 0], True, bf.Status.SUCCESS),
(er.output, er.success, er.failure_reason))
self.assertTrue(er.time < 0.15)
def testMaxSteps(self):
er = bf.evaluate('+.[].', base=5, input_buffer=[], timeout=None,
max_steps=100)
self.assertEqual(
([1], False, bf.Status.STEP_LIMIT, 100),
(er.output, er.success, er.failure_reason, er.steps))
er = bf.evaluate('+.[-].', base=5, input_buffer=[], timeout=None,
max_steps=100)
self.assertEqual(
([1, 0], True, bf.Status.SUCCESS),
(er.output, er.success, er.failure_reason))
self.assertTrue(er.steps < 100)
def testOutputMemory(self):
er = bf.evaluate('+>++>+++>++++.', base=256, input_buffer=[],
output_memory=True)
self.assertEqual(
([4], True, bf.Status.SUCCESS),
(er.output, er.success, er.failure_reason))
self.assertEqual([1, 2, 3, 4], er.memory)
def testProgramTrace(self):
es = bf.ExecutionSnapshot
er = bf.evaluate(',[.>,].', base=256, input_buffer=[2, 1], debug=True)
self.assertEqual(
[es(codeptr=0, codechar=',', memptr=0, memval=0, memory=[0],
next_input=2, output_buffer=[]),
es(codeptr=1, codechar='[', memptr=0, memval=2, memory=[2],
next_input=1, output_buffer=[]),
es(codeptr=2, codechar='.', memptr=0, memval=2, memory=[2],
next_input=1, output_buffer=[]),
es(codeptr=3, codechar='>', memptr=0, memval=2, memory=[2],
next_input=1, output_buffer=[2]),
es(codeptr=4, codechar=',', memptr=1, memval=0, memory=[2, 0],
next_input=1, output_buffer=[2]),
es(codeptr=5, codechar=']', memptr=1, memval=1, memory=[2, 1],
next_input=0, output_buffer=[2]),
es(codeptr=2, codechar='.', memptr=1, memval=1, memory=[2, 1],
next_input=0, output_buffer=[2]),
es(codeptr=3, codechar='>', memptr=1, memval=1, memory=[2, 1],
next_input=0, output_buffer=[2, 1]),
es(codeptr=4, codechar=',', memptr=2, memval=0, memory=[2, 1, 0],
next_input=0, output_buffer=[2, 1]),
es(codeptr=5, codechar=']', memptr=2, memval=0, memory=[2, 1, 0],
next_input=0, output_buffer=[2, 1]),
es(codeptr=6, codechar='.', memptr=2, memval=0, memory=[2, 1, 0],
next_input=0, output_buffer=[2, 1]),
es(codeptr=7, codechar='', memptr=2, memval=0, memory=[2, 1, 0],
next_input=0, output_buffer=[2, 1, 0])],
er.program_trace)
if __name__ == '__main__':
tf.test.main()
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
"""Objects for storing configuration and passing config into binaries.
Config class stores settings and hyperparameters for models, data, and anything
else that may be specific to a particular run.
"""
import ast
import itertools
from six.moves import xrange
class Config(dict):
"""Stores model configuration, hyperparameters, or dataset parameters."""
def __getattr__(self, attr):
return self[attr]
def __setattr__(self, attr, value):
self[attr] = value
def pretty_str(self, new_lines=True, indent=2, final_indent=0):
prefix = (' ' * indent) if new_lines else ''
final_prefix = (' ' * final_indent) if new_lines else ''
kv = ['%s%s=%s' % (prefix, k,
(repr(v) if not isinstance(v, Config)
else v.pretty_str(new_lines=new_lines,
indent=indent+2,
final_indent=indent)))
for k, v in self.items()]
if new_lines:
return 'Config(\n%s\n%s)' % (',\n'.join(kv), final_prefix)
else:
return 'Config(%s)' % ', '.join(kv)
def _update_iterator(self, *args, **kwargs):
"""Convert mixed input into an iterator over (key, value) tuples.
Follows the dict.update call signature.
Args:
*args: (Optional) Pass a dict or iterable of (key, value) 2-tuples as
an unnamed argument. Only one unnamed argument allowed.
**kwargs: (Optional) Pass (key, value) pairs as named arguments, where the
argument name is the key and the argument value is the value.
Returns:
An iterator over (key, value) tuples given in the input.
Raises:
TypeError: If more than one unnamed argument is given.
"""
if len(args) > 1:
raise TypeError('Expected at most 1 unnamed argument, got %d'
% len(args))
obj = args[0] if args else dict()
if isinstance(obj, dict):
return itertools.chain(obj.items(), kwargs.items())
# Assume obj is an iterable of 2-tuples.
return itertools.chain(obj, kwargs.items())
def make_default(self, keys=None):
"""Convert OneOf objects into their default configs.
Recursively calls into Config objects.
Args:
keys: Iterable of key names to check. If None, all keys in self will be
used.
"""
if keys is None:
keys = self.keys()
for k in keys:
# Replace OneOf with its default value.
if isinstance(self[k], OneOf):
self[k] = self[k].default()
# Recursively call into all Config objects, even those that came from
# OneOf objects in the previous code line (for nested OneOf objects).
if isinstance(self[k], Config):
self[k].make_default()
def update(self, *args, **kwargs):
"""Same as dict.update except nested Config objects are updated.
Args:
*args: (Optional) Pass a dict or list of (key, value) 2-tuples as unnamed
argument.
**kwargs: (Optional) Pass (key, value) pairs as named arguments, where the
argument name is the key and the argument value is the value.
"""
key_set = set(self.keys())
for k, v in self._update_iterator(*args, **kwargs):
if k in key_set:
key_set.remove(k) # This key is updated so exclude from make_default.
if k in self and isinstance(self[k], Config) and isinstance(v, dict):
self[k].update(v)
elif k in self and isinstance(self[k], OneOf) and isinstance(v, dict):
# Replace OneOf with the chosen config.
self[k] = self[k].update(v)
else:
self[k] = v
self.make_default(key_set)
def strict_update(self, *args, **kwargs):
"""Same as Config.update except keys and types are not allowed to change.
If a given key is not already in this instance, an exception is raised. If a
given value does not have the same type as the existing value for the same
key, an exception is raised. Use this method to catch config mistakes.
Args:
*args: (Optional) Pass a dict or list of (key, value) 2-tuples as unnamed
argument.
**kwargs: (Optional) Pass (key, value) pairs as named arguments, where the
argument name is the key and the argument value is the value.
Raises:
TypeError: If more than one unnamed argument is given.
TypeError: If new value type does not match existing type.
KeyError: If a given key is not already defined in this instance.
"""
key_set = set(self.keys())
for k, v in self._update_iterator(*args, **kwargs):
if k in self:
key_set.remove(k) # This key is updated so exclude from make_default.
if isinstance(self[k], Config):
if not isinstance(v, dict):
raise TypeError('dict required for Config value, got %s' % type(v))
self[k].strict_update(v)
elif isinstance(self[k], OneOf):
if not isinstance(v, dict):
raise TypeError('dict required for OneOf value, got %s' % type(v))
# Replace OneOf with the chosen config.
self[k] = self[k].strict_update(v)
else:
if not isinstance(v, type(self[k])):
raise TypeError('Expecting type %s for key %s, got type %s'
% (type(self[k]), k, type(v)))
self[k] = v
else:
raise KeyError(
'Key %s does not exist. New key creation not allowed in '
'strict_update.' % k)
self.make_default(key_set)
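  # Illustrative behavior (added for exposition): with c = Config(a=1, b='x'),
  # c.strict_update(a=2) succeeds, c.strict_update(a='y') raises TypeError
  # (type change), and c.strict_update(z=3) raises KeyError (new key).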
@staticmethod
def from_str(config_str):
"""Inverse of Config.__str__."""
parsed = ast.literal_eval(config_str)
assert isinstance(parsed, dict)
def _make_config(dictionary):
for k, v in dictionary.items():
if isinstance(v, dict):
dictionary[k] = _make_config(v)
return Config(**dictionary)
return _make_config(parsed)
@staticmethod
def parse(key_val_string):
"""Parse hyperparameter string into Config object.
Format is 'key=val,key=val,...'
    Values can be any Python literal, or another Config object encoded as
'c(key=val,key=val,...)'.
c(...) expressions can be arbitrarily nested.
Example:
'a=1,b=3e-5,c=[1,2,3],d="hello world",e={"a":1,"b":2},f=c(x=1,y=[10,20])'
Args:
key_val_string: The hyperparameter string.
Returns:
Config object parsed from the input string.
"""
if not key_val_string.strip():
return Config()
def _pair_to_kv(pair):
split_index = pair.find('=')
key, val = pair[:split_index].strip(), pair[split_index+1:].strip()
if val.startswith('c(') and val.endswith(')'):
val = Config.parse(val[2:-1])
else:
val = ast.literal_eval(val)
return key, val
return Config(**dict([_pair_to_kv(pair)
for pair in _comma_iterator(key_val_string)]))
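# Illustrative sketch (added for exposition, not part of the original
# library; safe to delete): how Config combines attribute access, nested
# in-place updates, and string parsing.
def _example_config_usage():
  """Demonstrates Config attribute access, update(), and parse()."""
  config = Config(model=Config(lr=0.1, layers=2), name='run1')
  assert config.name == 'run1'        # Attribute access reads the dict.
  config.update(model={'lr': 0.01})   # Nested Config is updated in place,
  assert config.model.layers == 2     # so untouched keys survive.
  assert config.model.lr == 0.01
  # parse() takes 'key=val' pairs; c(...) encodes a nested Config.
  parsed = Config.parse('a=1,b=c(x=[1,2])')
  assert parsed.b.x == [1, 2]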
class OneOf(object):
"""Stores branching config.
  In some cases there are options that each carry their own set of config
  params. For example, when specifying config for an environment, each
  environment can have custom config options. OneOf organizes such branching
  config.
Usage example:
one_of = OneOf(
[Config(a=1, b=2),
Config(a=2, c='hello'),
Config(a=3, d=10, e=-10)],
a=1)
config = one_of.strict_update(Config(a=3, d=20))
config == {'a': 3, 'd': 20, 'e': -10}
"""
def __init__(self, choices, **kwargs):
"""Constructor.
Usage: OneOf([Config(...), Config(...), ...], attribute=default_value)
Args:
choices: An iterable of Config objects. When update/strict_update is
called on this OneOf, one of these Config will be selected.
**kwargs: Give exactly one config attribute to branch on. The value of
this attribute during update/strict_update will determine which
Config is used.
Raises:
ValueError: If kwargs does not contain exactly one entry. Should give one
named argument which is used as the attribute to condition on.
"""
if len(kwargs) != 1:
raise ValueError(
'Incorrect usage. Must give exactly one named argument. The argument '
'name is the config attribute to condition on, and the argument '
'value is the default choice. Got %d named arguments.' % len(kwargs))
    key, default_value = list(kwargs.items())[0]  # list() for Python 3 compat.
self.key = key
self.default_value = default_value
# Make sure each choice is a Config object.
for config in choices:
if not isinstance(config, Config):
raise TypeError('choices must be a list of Config objects. Got %s.'
% type(config))
    # Map value for key to the config with that value.
    self.value_map = {config[key]: config for config in choices}
    # Make sure there are no duplicate values.
    if len(self.value_map) != len(choices):
      raise ValueError('Multiple choices given for the same value of %s.' % key)
    # Check that the default value is valid before using it to index
    # value_map; indexing first would raise a bare KeyError instead.
    if self.default_value not in self.value_map:
      raise ValueError(
          'Default value is not an available choice. Got %s=%s. Choices are %s.'
          % (key, self.default_value, self.value_map.keys()))
    self.default_config = self.value_map[self.default_value]
def default(self):
return self.default_config
def update(self, other):
"""Choose a config and update it.
If `other` is a Config, one of the config choices is selected and updated.
Otherwise `other` is returned.
Args:
other: Will update chosen config with this value by calling `update` on
the config.
Returns:
The chosen config after updating it, or `other` if no config could be
selected.
"""
if not isinstance(other, Config):
return other
if self.key not in other or other[self.key] not in self.value_map:
return other
target = self.value_map[other[self.key]]
target.update(other)
return target
def strict_update(self, config):
"""Choose a config and update it.
`config` must be a Config object. `config` must have the key used to select
among the config choices, and that key must have a value which one of the
config choices has.
Args:
      config: A Config object. The chosen config will be updated by calling
        `strict_update` on it.
Returns:
The chosen config after updating it.
Raises:
TypeError: If `config` is not a Config instance.
ValueError: If `config` does not have the branching key in its key set.
ValueError: If the value of the config's branching key is not one of the
valid choices.
"""
if not isinstance(config, Config):
raise TypeError('Expecting Config instance, got %s.' % type(config))
if self.key not in config:
raise ValueError(
'Branching key %s required but not found in %s' % (self.key, config))
if config[self.key] not in self.value_map:
raise ValueError(
'Value %s for key %s is not a possible choice. Choices are %s.'
% (config[self.key], self.key, self.value_map.keys()))
target = self.value_map[config[self.key]]
target.strict_update(config)
return target
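# Illustrative sketch (added for exposition, not part of the original
# library; safe to delete): OneOf selects among Config choices by the value
# of a single branching key.
def _example_oneof_usage():
  """Demonstrates branching on 'task' with OneOf."""
  data = OneOf(
      [Config(task=1, a='hello'),
       Config(task=2, a='world', b='stuff')],
      task=1)
  # update() picks the choice whose 'task' matches, then updates it in place.
  chosen = data.update(Config(task=2, a='hi'))
  assert chosen == {'task': 2, 'a': 'hi', 'b': 'stuff'}
  # default() returns the choice matching the default value (task=1).
  assert data.default() == {'task': 1, 'a': 'hello'}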
def _next_comma(string, start_index):
"""Finds the position of the next comma not used in a literal collection."""
paren_count = 0
for i in xrange(start_index, len(string)):
c = string[i]
if c == '(' or c == '[' or c == '{':
paren_count += 1
elif c == ')' or c == ']' or c == '}':
paren_count -= 1
if paren_count == 0 and c == ',':
return i
return -1
def _comma_iterator(string):
  """Yields substrings of `string` split on top-level commas only."""
  index = 0
  while True:
next_index = _next_comma(string, index)
if next_index == -1:
yield string[index:]
return
yield string[index:next_index]
index = next_index + 1
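# Illustrative sketch (added for exposition, not part of the original
# library; safe to delete): the helpers above split only on top-level
# commas, so commas inside (), [], and {} literals are preserved.
def _example_comma_split():
  assert _next_comma('a=[1,2],b=3', 0) == 7  # Skips the comma inside [...].
  assert list(_comma_iterator('a=1,b=c(x=1,y=2)')) == ['a=1', 'b=c(x=1,y=2)']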
"""Tests for common.config_lib."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from common import config_lib  # brain coder
class ConfigLibTest(tf.test.TestCase):
def testConfig(self):
config = config_lib.Config(hello='world', foo='bar', num=123, f=56.7)
self.assertEqual('world', config.hello)
self.assertEqual('bar', config['foo'])
config.hello = 'everyone'
config['bar'] = 9000
self.assertEqual('everyone', config['hello'])
self.assertEqual(9000, config.bar)
self.assertEqual(5, len(config))
def testConfigUpdate(self):
config = config_lib.Config(a=1, b=2, c=3)
config.update({'b': 10, 'd': 4})
self.assertEqual({'a': 1, 'b': 10, 'c': 3, 'd': 4}, config)
config = config_lib.Config(a=1, b=2, c=3)
config.update(b=10, d=4)
self.assertEqual({'a': 1, 'b': 10, 'c': 3, 'd': 4}, config)
config = config_lib.Config(a=1, b=2, c=3)
config.update({'e': 5}, b=10, d=4)
self.assertEqual({'a': 1, 'b': 10, 'c': 3, 'd': 4, 'e': 5}, config)
config = config_lib.Config(
a=1,
b=2,
x=config_lib.Config(
l='a',
y=config_lib.Config(m=1, n=2),
z=config_lib.Config(
q=config_lib.Config(a=10, b=20),
r=config_lib.Config(s=1, t=2))))
config.update(x={'y': {'m': 10}, 'z': {'r': {'s': 5}}})
self.assertEqual(
config_lib.Config(
a=1, b=2,
x=config_lib.Config(
l='a',
y=config_lib.Config(m=10, n=2),
z=config_lib.Config(
q=config_lib.Config(a=10, b=20),
r=config_lib.Config(s=5, t=2)))),
config)
config = config_lib.Config(
foo='bar',
num=100,
x=config_lib.Config(a=1, b=2, c=config_lib.Config(h=10, i=20, j=30)),
y=config_lib.Config(qrs=5, tuv=10),
d={'a': 1, 'b': 2},
l=[1, 2, 3])
config.update(
config_lib.Config(
foo='hat',
num=50.5,
x={'a': 5, 'z': -10},
y=config_lib.Config(wxyz=-1)),
d={'a': 10, 'c': 20},
l=[3, 4, 5, 6])
self.assertEqual(
config_lib.Config(
foo='hat',
num=50.5,
x=config_lib.Config(a=5, b=2, z=-10,
c=config_lib.Config(h=10, i=20, j=30)),
y=config_lib.Config(qrs=5, tuv=10, wxyz=-1),
d={'a': 10, 'c': 20},
l=[3, 4, 5, 6]),
config)
self.assertTrue(isinstance(config.x, config_lib.Config))
self.assertTrue(isinstance(config.x.c, config_lib.Config))
self.assertTrue(isinstance(config.y, config_lib.Config))
config = config_lib.Config(
foo='bar',
num=100,
x=config_lib.Config(a=1, b=2, c=config_lib.Config(h=10, i=20, j=30)),
y=config_lib.Config(qrs=5, tuv=10),
d={'a': 1, 'b': 2},
l=[1, 2, 3])
config.update(
config_lib.Config(
foo=1234,
num='hello',
x={'a': 5, 'z': -10, 'c': {'h': -5, 'k': 40}},
y=[1, 2, 3, 4],
d='stuff',
l={'a': 1, 'b': 2}))
self.assertEqual(
config_lib.Config(
foo=1234,
num='hello',
x=config_lib.Config(a=5, b=2, z=-10,
c=config_lib.Config(h=-5, i=20, j=30, k=40)),
y=[1, 2, 3, 4],
d='stuff',
l={'a': 1, 'b': 2}),
config)
self.assertTrue(isinstance(config.x, config_lib.Config))
self.assertTrue(isinstance(config.x.c, config_lib.Config))
self.assertTrue(isinstance(config.y, list))
def testConfigStrictUpdate(self):
config = config_lib.Config(a=1, b=2, c=3)
config.strict_update({'b': 10, 'c': 20})
self.assertEqual({'a': 1, 'b': 10, 'c': 20}, config)
config = config_lib.Config(a=1, b=2, c=3)
config.strict_update(b=10, c=20)
self.assertEqual({'a': 1, 'b': 10, 'c': 20}, config)
config = config_lib.Config(a=1, b=2, c=3, d=4)
config.strict_update({'d': 100}, b=10, a=20)
self.assertEqual({'a': 20, 'b': 10, 'c': 3, 'd': 100}, config)
config = config_lib.Config(
a=1,
b=2,
x=config_lib.Config(
l='a',
y=config_lib.Config(m=1, n=2),
z=config_lib.Config(
q=config_lib.Config(a=10, b=20),
r=config_lib.Config(s=1, t=2))))
config.strict_update(x={'y': {'m': 10}, 'z': {'r': {'s': 5}}})
self.assertEqual(
config_lib.Config(
a=1, b=2,
x=config_lib.Config(
l='a',
y=config_lib.Config(m=10, n=2),
z=config_lib.Config(
q=config_lib.Config(a=10, b=20),
r=config_lib.Config(s=5, t=2)))),
config)
config = config_lib.Config(
foo='bar',
num=100,
x=config_lib.Config(a=1, b=2, c=config_lib.Config(h=10, i=20, j=30)),
y=config_lib.Config(qrs=5, tuv=10),
d={'a': 1, 'b': 2},
l=[1, 2, 3])
config.strict_update(
config_lib.Config(
foo='hat',
num=50,
x={'a': 5, 'c': {'h': 100}},
y=config_lib.Config(tuv=-1)),
d={'a': 10, 'c': 20},
l=[3, 4, 5, 6])
self.assertEqual(
config_lib.Config(
foo='hat',
num=50,
x=config_lib.Config(a=5, b=2,
c=config_lib.Config(h=100, i=20, j=30)),
y=config_lib.Config(qrs=5, tuv=-1),
d={'a': 10, 'c': 20},
l=[3, 4, 5, 6]),
config)
def testConfigStrictUpdateFail(self):
config = config_lib.Config(a=1, b=2, c=3, x=config_lib.Config(a=1, b=2))
with self.assertRaises(KeyError):
config.strict_update({'b': 10, 'c': 20, 'd': 50})
with self.assertRaises(KeyError):
config.strict_update(b=10, d=50)
with self.assertRaises(KeyError):
config.strict_update(x={'c': 3})
with self.assertRaises(TypeError):
config.strict_update(a='string')
with self.assertRaises(TypeError):
config.strict_update(x={'a': 'string'})
with self.assertRaises(TypeError):
config.strict_update(x=[1, 2, 3])
def testConfigFromStr(self):
config = config_lib.Config.from_str("{'c': {'d': 5}, 'b': 2, 'a': 1}")
self.assertEqual(
{'c': {'d': 5}, 'b': 2, 'a': 1}, config)
self.assertTrue(isinstance(config, config_lib.Config))
self.assertTrue(isinstance(config.c, config_lib.Config))
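  def testConfigFromStrRoundTrip(self):
    # Illustrative round-trip check (added for exposition, not in the
    # original suite): Config inherits dict's repr, so from_str(str(config))
    # should reconstruct the config, including nested Config objects.
    config = config_lib.Config(a=1, b=config_lib.Config(c=2))
    recovered = config_lib.Config.from_str(str(config))
    self.assertEqual(config, recovered)
    self.assertTrue(isinstance(recovered.b, config_lib.Config))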
def testConfigParse(self):
config = config_lib.Config.parse(
'hello="world",num=1234.5,lst=[10,20.5,True,"hi",("a","b","c")],'
'dct={9:10,"stuff":"qwerty","subdict":{1:True,2:False}},'
'subconfig=c(a=1,b=[1,2,[3,4]],c=c(f="f",g="g"))')
self.assertEqual(
{'hello': 'world', 'num': 1234.5,
'lst': [10, 20.5, True, 'hi', ('a', 'b', 'c')],
'dct': {9: 10, 'stuff': 'qwerty', 'subdict': {1: True, 2: False}},
'subconfig': {'a': 1, 'b': [1, 2, [3, 4]], 'c': {'f': 'f', 'g': 'g'}}},
config)
self.assertTrue(isinstance(config, config_lib.Config))
self.assertTrue(isinstance(config.subconfig, config_lib.Config))
self.assertTrue(isinstance(config.subconfig.c, config_lib.Config))
self.assertFalse(isinstance(config.dct, config_lib.Config))
self.assertFalse(isinstance(config.dct['subdict'], config_lib.Config))
self.assertTrue(isinstance(config.lst[4], tuple))
def testConfigParseErrors(self):
with self.assertRaises(SyntaxError):
config_lib.Config.parse('a=[1,2,b="hello"')
with self.assertRaises(SyntaxError):
config_lib.Config.parse('a=1,b=c(x="a",y="b"')
with self.assertRaises(SyntaxError):
config_lib.Config.parse('a=1,b=c(x="a")y="b"')
with self.assertRaises(SyntaxError):
config_lib.Config.parse('a=1,b=c(x="a"),y="b",')
def testOneOf(self):
def make_config():
return config_lib.Config(
data=config_lib.OneOf(
[config_lib.Config(task=1, a='hello'),
config_lib.Config(task=2, a='world', b='stuff'),
config_lib.Config(task=3, c=1234)],
task=2),
model=config_lib.Config(stuff=1))
config = make_config()
config.update(config_lib.Config.parse(
'model=c(stuff=2),data=c(task=1,a="hi")'))
self.assertEqual(
config_lib.Config(
data=config_lib.Config(task=1, a='hi'),
model=config_lib.Config(stuff=2)),
config)
config = make_config()
config.update(config_lib.Config.parse(
'model=c(stuff=2),data=c(task=2,a="hi")'))
self.assertEqual(
config_lib.Config(
data=config_lib.Config(task=2, a='hi', b='stuff'),
model=config_lib.Config(stuff=2)),
config)
config = make_config()
config.update(config_lib.Config.parse(
'model=c(stuff=2),data=c(task=3)'))
self.assertEqual(
config_lib.Config(
data=config_lib.Config(task=3, c=1234),
model=config_lib.Config(stuff=2)),
config)
config = make_config()
config.update(config_lib.Config.parse(
'model=c(stuff=2)'))
self.assertEqual(
config_lib.Config(
data=config_lib.Config(task=2, a='world', b='stuff'),
model=config_lib.Config(stuff=2)),
config)
config = make_config()
config.update(config_lib.Config.parse(
'model=c(stuff=2),data=c(task=4,d=9999)'))
self.assertEqual(
config_lib.Config(
data=config_lib.Config(task=4, d=9999),
model=config_lib.Config(stuff=2)),
config)
config = make_config()
config.update(config_lib.Config.parse(
'model=c(stuff=2),data=5'))
self.assertEqual(
config_lib.Config(
data=5,
model=config_lib.Config(stuff=2)),
config)
def testOneOfStrict(self):
def make_config():
return config_lib.Config(
data=config_lib.OneOf(
[config_lib.Config(task=1, a='hello'),
config_lib.Config(task=2, a='world', b='stuff'),
config_lib.Config(task=3, c=1234)],
task=2),
model=config_lib.Config(stuff=1))
config = make_config()
config.strict_update(config_lib.Config.parse(
'model=c(stuff=2),data=c(task=1,a="hi")'))
self.assertEqual(
config_lib.Config(
data=config_lib.Config(task=1, a='hi'),
model=config_lib.Config(stuff=2)),
config)
config = make_config()
config.strict_update(config_lib.Config.parse(
'model=c(stuff=2),data=c(task=2,a="hi")'))
self.assertEqual(
config_lib.Config(
data=config_lib.Config(task=2, a='hi', b='stuff'),
model=config_lib.Config(stuff=2)),
config)
config = make_config()
config.strict_update(config_lib.Config.parse(
'model=c(stuff=2),data=c(task=3)'))
self.assertEqual(
config_lib.Config(
data=config_lib.Config(task=3, c=1234),
model=config_lib.Config(stuff=2)),
config)
config = make_config()
config.strict_update(config_lib.Config.parse(
'model=c(stuff=2)'))
self.assertEqual(
config_lib.Config(
data=config_lib.Config(task=2, a='world', b='stuff'),
model=config_lib.Config(stuff=2)),
config)
def testNestedOneOf(self):
def make_config():
return config_lib.Config(
data=config_lib.OneOf(
[config_lib.Config(task=1, a='hello'),
config_lib.Config(
task=2,
a=config_lib.OneOf(
[config_lib.Config(x=1, y=2),
config_lib.Config(x=-1, y=1000, z=4)],
x=1)),
config_lib.Config(task=3, c=1234)],
task=2),
model=config_lib.Config(stuff=1))
config = make_config()
config.update(config_lib.Config.parse(
'model=c(stuff=2),data=c(task=2,a=c(x=-1,z=8))'))
self.assertEqual(
config_lib.Config(
data=config_lib.Config(
task=2,
a=config_lib.Config(x=-1, y=1000, z=8)),
model=config_lib.Config(stuff=2)),
config)
config = make_config()
config.strict_update(config_lib.Config.parse(
'model=c(stuff=2),data=c(task=2,a=c(x=-1,z=8))'))
self.assertEqual(
config_lib.Config(
data=config_lib.Config(
task=2,
a=config_lib.Config(x=-1, y=1000, z=8)),
model=config_lib.Config(stuff=2)),
config)
config = make_config()
config.update(config_lib.Config.parse('model=c(stuff=2)'))
self.assertEqual(
config_lib.Config(
data=config_lib.Config(
task=2,
a=config_lib.Config(x=1, y=2)),
model=config_lib.Config(stuff=2)),
config)
config = make_config()
config.strict_update(config_lib.Config.parse('model=c(stuff=2)'))
self.assertEqual(
config_lib.Config(
data=config_lib.Config(
task=2,
a=config_lib.Config(x=1, y=2)),
model=config_lib.Config(stuff=2)),
config)
def testOneOfStrictErrors(self):
def make_config():
return config_lib.Config(
data=config_lib.OneOf(
[config_lib.Config(task=1, a='hello'),
config_lib.Config(task=2, a='world', b='stuff'),
config_lib.Config(task=3, c=1234)],
task=2),
model=config_lib.Config(stuff=1))
config = make_config()
with self.assertRaises(TypeError):
config.strict_update(config_lib.Config.parse(
'model=c(stuff=2),data=[1,2,3]'))
config = make_config()
with self.assertRaises(KeyError):
config.strict_update(config_lib.Config.parse(
'model=c(stuff=2),data=c(task=3,c=5678,d=9999)'))
config = make_config()
with self.assertRaises(ValueError):
config.strict_update(config_lib.Config.parse(
'model=c(stuff=2),data=c(task=4,d=9999)'))
config = make_config()
with self.assertRaises(TypeError):
config.strict_update(config_lib.Config.parse(
'model=c(stuff=2),data=5'))
if __name__ == '__main__':
tf.test.main()