Unverified Commit 5ec50328 authored by Lukasz Kaiser's avatar Lukasz Kaiser Committed by GitHub

Merge pull request #4610 from apbusia/master

Added new model.
parents 564f8d4d f82adfe1
@@ -37,6 +37,7 @@
/research/real_nvp/ @laurent-dinh
/research/rebar/ @gjtucker
/research/resnet/ @panyx0718
/research/seq2species/ @apbusia @depristo
/research/skip_thoughts/ @cshallue
/research/slim/ @sguada @nathansilberman
/research/street/ @theraysmith
@@ -54,4 +55,4 @@
/tutorials/embedding/ @zffchen78 @a-dai
/tutorials/image/ @sherrym @shlens
/tutorials/image/cifar10_estimator/ @tfboyd @protoget
/tutorials/rnn/ @lukaszkaiser @ebrevdo
\ No newline at end of file
@@ -59,6 +59,8 @@ request.
- [rebar](rebar): low-variance, unbiased gradient estimates for discrete
  latent variable models.
- [resnet](resnet): deep and wide residual networks.
- [seq2species](seq2species): deep learning solution for read-level taxonomic
  classification.
- [skip_thoughts](skip_thoughts): recurrent neural network sentence-to-vector
  encoder.
- [slim](slim): image classification models in TF-Slim.
...
# Seq2Species: Neural Network Models for Species Classification
*A deep learning solution for read-level taxonomic classification with 16S.*
Recent improvements in sequencing technology have made possible large, public
databases of biological sequencing data, bringing about new data richness for
many important problems in bioinformatics. However, this growing availability of
data creates a need for analysis methods capable of efficiently handling these
large sequencing datasets. We on the [Genomics team in Google
Brain](https://ai.google/research/teams/brain/healthcare-biosciences) are
particularly interested in the class of problems which can be framed as
assigning meaningful labels to short biological sequences, and are exploring the
possibility of creating a general deep learning solution for this class
of sequence-labeling problems. We are excited to share our initial progress in
this direction by releasing Seq2Species, an open-source neural network framework
built on [TensorFlow](https://www.tensorflow.org/) for predicting read-level
taxonomic labels from genomic sequence. Our release includes all the code
necessary to train new Seq2Species models.
## About Seq2Species
Briefly, Seq2Species provides a framework for training deep neural networks to
predict database-derived labels directly from short reads of DNA. Thus far, our
research has focused predominantly on demonstrating the value of this deep
learning approach on the problem of determining the species of origin of
next-generation sequencing reads from [16S ribosomal
DNA](https://en.wikipedia.org/wiki/16S_ribosomal_RNA). We used this
Seq2Species framework to train depthwise separable convolutional neural networks
on short subsequences from the 16S genes of more than 13 thousand distinct
species. The resulting classification models assign species-level probabilities
to individual 16S reads.
For more information about the use cases we have explored, or for technical
details describing how Seq2Species works, please see our
[preprint](https://www.biorxiv.org/content/early/2018/06/22/353474).
## Installation
Training Seq2Species models requires installing the following dependencies:
* Python 2.7
* TensorFlow
* protocol buffers
* numpy
* absl
### Dependencies
Detailed instructions for installing TensorFlow are available on the [Installing
TensorFlow](https://www.tensorflow.org/install/) website. Please follow the
full instructions there if you need TensorFlow with GPU support. For most
users, the following command will suffice to continue with CPU support only:
```bash
# For CPU
pip install --upgrade tensorflow
```
The TensorFlow installation should also include installation of the numpy and
absl libraries, which are two of TensorFlow's python dependencies. If
necessary, instructions for standalone installation are available:
* [numpy](https://scipy.org/install.html)
* [absl](https://github.com/abseil/abseil-py)
Information about protocol buffers, as well as download and installation
instructions for the protocol buffer (protobuf) compiler, is available on the [Google
Developers website](https://developers.google.com/protocol-buffers/). A typical
Ubuntu user can install the compiler using `apt-get`:
```bash
sudo apt-get install protobuf-compiler
```
### Clone
Now, clone `tensorflow/models` to start working with the code:
```bash
git clone https://github.com/tensorflow/models.git
```
### Protobuf Compilation
Seq2Species uses protobufs to store and save dataset and model metadata. Before
the framework can be used to build and train models, the protobuf libraries must
be compiled. This can be accomplished using the following command:
```bash
# From tensorflow/models/research
protoc seq2species/protos/seq2label.proto --python_out=.
```
### Testing the Installation
One can test that Seq2Species has been installed correctly by running the
following command:
```bash
python seq2species/run_training_test.py
```
## Usage Information
Input data to Seq2Species models should be [tf.train.Example protocol messages](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/example/example.proto) stored in
[TFRecord format](https://www.tensorflow.org/versions/r1.0/api_guides/python/python_io#tfrecords_format_details).
Specifically, the input pipeline expects tf.train.Examples with a 'sequence' field
containing a genomic sequence as an upper-case string, as well as one field for each
target label (e.g. 'species'). There should also be an accompanying
Seq2LabelDatasetInfo text protobuf containing metadata about the input, including
the possible label values for each target.
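For illustration, the snippet below shows one way such inputs could be prepared.
It is a minimal sketch, assuming the protos have been compiled as described above
and that it is run from the `tensorflow/models/research` directory; the file names,
the toy read, and the 'species' label value are placeholders, and the import path
for the compiled proto module may need adjusting for your setup.
```python
import tensorflow as tf
from google.protobuf import text_format
from seq2species.protos import seq2label_pb2  # compiled via protoc, as above

reads = [('ACGT' * 25, 'Escherichia coli')]  # One toy 100 bp read and its label.

# Write tf.train.Examples with a 'sequence' field plus one field per target.
with tf.python_io.TFRecordWriter('toy.tfrecord') as writer:
  for sequence, species in reads:
    example = tf.train.Example(features=tf.train.Features(feature={
        'sequence': tf.train.Feature(
            bytes_list=tf.train.BytesList(value=[sequence])),
        'species': tf.train.Feature(
            bytes_list=tf.train.BytesList(value=[species])),
    }))
    writer.write(example.SerializeToString())

# Write the accompanying Seq2LabelDatasetInfo text proto describing the data.
dataset_info = seq2label_pb2.Seq2LabelDatasetInfo(
    read_length=100,
    num_examples=len(reads),
    dataset_path='toy.tfrecord')
dataset_info.labels.add(
    name='species', values=sorted({s for _, s in reads}))
with tf.gfile.GFile('toy.dataset_info.pbtxt', 'w') as f:
  f.write(text_format.MessageToString(dataset_info))
```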
Below, we give an example command that could be used to launch training for 1000
steps, assuming that appropriate data and metadata files are stored at
`${TFRECORD}` and `${DATASET_INFO}`:
```bash
python seq2species/run_training.py --train_files ${TFRECORD} \
  --metadata_path ${DATASET_INFO} --hparams 'train_steps=1000' \
  --logdir $HOME/seq2species
```
This will output [TensorBoard
summaries](https://www.tensorflow.org/programmers_guide/summaries_and_tensorboard), [TensorFlow
checkpoints](https://www.tensorflow.org/programmers_guide/variables#checkpoint_files), Seq2LabelModelInfo and
Seq2LabelExperimentMeasures metadata to the logdir `$HOME/seq2species`.
### Preprocessed Seq2Species Data
We have provided preprocessed data based on 16S reference sequences from the
[NCBI RefSeq Targeted Loci
Project](https://www.ncbi.nlm.nih.gov/refseq/targetedloci/) in a Seq2Species
bucket on Google Cloud Storage. After installing the
[Cloud SDK](https://cloud.google.com/sdk/install),
one can download those data (roughly 25 GB) to a local directory `${DEST}` using
the `gsutil` command:
```bash
BUCKET=gs://brain-genomics-public/research/seq2species
mkdir -p ${DEST}
gsutil -m cp ${BUCKET}/* ${DEST}
```
To check whether the copy completed successfully, list the `${DEST}` directory:
```bash
ls -1 ${DEST}
```
which should produce:
```bash
ncbi_100bp_revcomp.dataset_info.pbtxt
ncbi_100bp_revcomp.tfrecord
```
The following command can be used to train a copy of one of our best-performing
deep neural network models for 100 base pair (bp) data. This command also
illustrates how to set hyperparameter values explicitly from the command line.
The file `configuration.py` provides a full list of hyperparameters, their descriptions,
and their default values. Additional flags are described at the top of
`run_training.py`.
```bash
python seq2species/run_training.py \
--num_filters 3 \
--noise_rate 0.04 \
--train_files ${DEST}/ncbi_100bp_revcomp.tfrecord \
--metadata_path ${DEST}/ncbi_100bp_revcomp.dataset_info.pbtxt \
--logdir $HOME/seq2species \
--hparams 'filter_depths=[1,1,1],filter_widths=[5,9,13],grad_clip_norm=20.0,keep_prob=0.94017831318,lr_decay=0.0655052811,lr_init=0.000469689635793,lrelu_slope=0.0125376069918,min_read_length=100,num_fc_layers=2,num_fc_units=2828,optimizer=adam,optimizer_hp=0.885769367218,pointwise_depths=[84,58,180],pooling_type=avg,train_steps=3000000,use_depthwise_separable=true,weight_scale=1.18409526348'
```
### Visualization
[TensorBoard](https://github.com/tensorflow/tensorboard) can be used to
visualize training curves and other metrics stored in the summary files produced
by `run_training.py`. Use the following command to launch a TensorBoard instance
for the example model directory `$HOME/seq2species`:
```bash
tensorboard --logdir=$HOME/seq2species
```
## Contact
Any issues with the Seq2Species framework should be filed with the
[TensorFlow/models issue tracker](https://github.com/tensorflow/models/issues).
Questions regarding Seq2Species capabilities can be directed to
[seq2species-interest@google.com](mailto:seq2species-interest@google.com). This
code is maintained by [@apbusia](https://github.com/apbusia) and
[@depristo](https://github.com/depristo).
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Defines convolutional model graph for Seq2Species.
Builds TensorFlow computation graph for predicting the given taxonomic target
labels from short reads of DNA using convolutional filters, followed by
fully-connected layers and a softmax output layer.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import math
import tensorflow as tf
import input as seq2species_input
import seq2label_utils
class ConvolutionalNet(object):
"""Class to build and store the model's computational graph and operations.
Attributes:
read_length: int; the length in basepairs of the input reads of DNA.
placeholders: dict; mapping from name to tf.Placeholder.
global_step: tf.Variable tracking number of training iterations performed.
train_op: operation to perform one training step by gradient descent.
summary_op: operation to log model's performance metrics to TF event files.
accuracy: tf.Variable giving the model's read-level accuracy for the
current inputs.
weighted_accuracy: tf.Variable giving the model's read-level weighted
accuracy for the current inputs.
loss: tf.Variable giving the model's current cross entropy loss.
logits: tf.Variable containing the model's logits for the current inputs.
predictions: tf.Variable containing the model's current predicted
probability distributions for the current inputs.
possible_labels: a dict of possible label values (list of strings), keyed by
target name. Labels in the lists are the order used for integer encoding.
use_tpu: whether model is to be run on TPU.
"""
def __init__(self, hparams, dataset_info, targets, use_tpu=False):
"""Initializes the ConvolutionalNet according to provided hyperparameters.
Does not build the graph---this is done by calling `build_graph` on the
constructed object or using `model_fn`.
Args:
hparams: tf.contrib.training.Hparams object containing the model's
hyperparameters; see configuration.py for hyperparameter definitions.
dataset_info: a `Seq2LabelDatasetInfo` message reflecting the dataset
metadata.
targets: list of strings: the names of the prediction targets.
use_tpu: whether we are running on TPU; if True, summaries will be
disabled.
"""
self._placeholders = {}
self._targets = targets
self._dataset_info = dataset_info
self._hparams = hparams
all_label_values = seq2label_utils.get_all_label_values(self.dataset_info)
self._possible_labels = {
target: all_label_values[target]
for target in self.targets
}
self._use_tpu = use_tpu
@property
def hparams(self):
return self._hparams
@property
def dataset_info(self):
return self._dataset_info
@property
def possible_labels(self):
return self._possible_labels
@property
def bases(self):
return seq2species_input.DNA_BASES
@property
def n_bases(self):
return seq2species_input.NUM_DNA_BASES
@property
def targets(self):
return self._targets
@property
def read_length(self):
return self.dataset_info.read_length
@property
def placeholders(self):
return self._placeholders
@property
def global_step(self):
return self._global_step
@property
def train_op(self):
return self._train_op
@property
def summary_op(self):
return self._summary_op
@property
def accuracy(self):
return self._accuracy
@property
def weighted_accuracy(self):
return self._weighted_accuracy
@property
def loss(self):
return self._loss
@property
def total_loss(self):
return self._total_loss
@property
def logits(self):
return self._logits
@property
def predictions(self):
return self._predictions
@property
def use_tpu(self):
return self._use_tpu
def _summary_scalar(self, name, scalar):
"""Adds a summary scalar, if the platform supports summaries."""
if not self.use_tpu:
return tf.summary.scalar(name, scalar)
else:
return None
def _summary_histogram(self, name, values):
"""Adds a summary histogram, if the platform supports summaries."""
if not self.use_tpu:
return tf.summary.histogram(name, values)
else:
return None
def _init_weights(self, shape, scale=1.0, name='weights'):
"""Randomly initializes a weight Tensor of the given shape.
Args:
shape: list; desired Tensor dimensions.
scale: float; standard deviation scale with which to initialize weights.
name: string name for the variable.
Returns:
TF Variable containing truncated-normal-initialized weights.
"""
num_inputs = shape[0] if len(shape) < 3 else shape[0] * shape[1] * shape[2]
stddev = scale / math.sqrt(num_inputs)
return tf.get_variable(
name,
shape=shape,
initializer=tf.truncated_normal_initializer(0., stddev))
def _init_bias(self, size):
"""Initializes bias vector of given shape as zeros.
Args:
size: int; desired size of bias Tensor.
Returns:
TF Variable containing the initialized biases.
"""
return tf.get_variable(
name='b_{}'.format(size),
shape=[size],
initializer=tf.zeros_initializer())
def _add_summaries(self, mode, gradient_norm, parameter_norm):
"""Defines TensorFlow operation for logging summaries to event files.
Args:
mode: the ModeKey string.
gradient_norm: Tensor; norm of gradients produced during the current
training operation.
parameter_norm: Tensor; norm of the model parameters produced during the
current training operation.
"""
# Log summaries for TensorBoard.
if mode == tf.estimator.ModeKeys.TRAIN:
self._summary_scalar('norm_of_gradients', gradient_norm)
self._summary_scalar('norm_of_parameters', parameter_norm)
self._summary_scalar('total_loss', self.total_loss)
self._summary_scalar('learning_rate', self._learn_rate)
for target in self.targets:
self._summary_scalar('per_read_weighted_accuracy/{}'.format(target),
self.weighted_accuracy[target])
self._summary_scalar('per_read_accuracy/{}'.format(target),
self.accuracy[target])
self._summary_histogram('prediction_frequency/{}'.format(target),
self._predictions[target])
self._summary_scalar('cross_entropy_loss/{}'.format(target),
self._loss[target])
self._summary_op = tf.summary.merge_all()
else:
# Log average performance metrics over many batches using placeholders.
summaries = []
for target in self.targets:
accuracy_ph = tf.placeholder(tf.float32, shape=())
weighted_accuracy_ph = tf.placeholder(tf.float32, shape=())
cross_entropy_ph = tf.placeholder(tf.float32, shape=())
self._placeholders.update({
'accuracy/{}'.format(target): accuracy_ph,
'weighted_accuracy/{}'.format(target): weighted_accuracy_ph,
'cross_entropy/{}'.format(target): cross_entropy_ph,
})
summaries += [
self._summary_scalar('cross_entropy_loss/{}'.format(target),
cross_entropy_ph),
self._summary_scalar('per_read_accuracy/{}'.format(target),
accuracy_ph),
self._summary_scalar('per_read_weighted_accuracy/{}'.format(target),
weighted_accuracy_ph)
]
self._summary_op = tf.summary.merge(summaries)
def _convolution(self,
inputs,
filter_dim,
pointwise_dim=None,
scale=1.0,
padding='SAME'):
"""Applies convolutional filter of given dimensions to given input Tensor.
If a pointwise dimension is specified, a depthwise separable convolution is
performed.
Args:
inputs: 4D Tensor of shape (# reads, 1, # basepairs, # bases).
filter_dim: integer tuple of the form (width, depth).
pointwise_dim: int; output dimension for pointwise convolution.
scale: float; standard deviation scale with which to initialize weights.
padding: string; type of padding to use. One of "SAME" or "VALID".
Returns:
4D Tensor result of applying the convolutional filter to the inputs.
"""
in_channels = inputs.get_shape()[3].value
filter_width, filter_depth = filter_dim
filters = self._init_weights([1, filter_width, in_channels, filter_depth],
scale)
self._summary_histogram(filters.name.split(':')[0].split('/')[1], filters)
if pointwise_dim is None:
return tf.nn.conv2d(
inputs,
filters,
strides=[1, 1, 1, 1],
padding=padding,
name='weights')
pointwise_filters = self._init_weights(
[1, 1, filter_depth * in_channels, pointwise_dim],
scale,
name='pointwise_weights')
self._summary_histogram(
pointwise_filters.name.split(':')[0].split('/')[1], pointwise_filters)
return tf.nn.separable_conv2d(
inputs,
filters,
pointwise_filters,
strides=[1, 1, 1, 1],
padding=padding)
def _pool(self, inputs, pooling_type):
"""Performs pooling across width and height of the given inputs.
Args:
inputs: Tensor shaped (batch, height, width, channels) over which to pool.
In our case, height is a unitary dimension and width can be thought of
as the read dimension.
pooling_type: string; one of "avg" or "max".
Returns:
Tensor result of performing pooling of the given pooling_type over the
height and width dimensions of the given inputs.
"""
if pooling_type == 'max':
return tf.reduce_max(inputs, axis=[1, 2])
if pooling_type == 'avg':
return tf.reduce_sum(
inputs, axis=[1, 2]) / tf.to_float(tf.shape(inputs)[2])
def _leaky_relu(self, lrelu_slope, inputs):
"""Applies leaky ReLu activation to the given inputs with the given slope.
Args:
lrelu_slope: float; slope value for the activation function.
A slope of 0.0 defines a standard ReLu activation, while a positive
slope defines a leaky ReLu.
inputs: Tensor upon which to apply the activation function.
Returns:
Tensor result of applying the activation function to the given inputs.
"""
with tf.variable_scope('leaky_relu_activation'):
return tf.maximum(lrelu_slope * inputs, inputs)
def _dropout(self, inputs, keep_prob):
"""Applies dropout to the given inputs.
Args:
inputs: Tensor upon which to apply dropout.
keep_prob: float; probability with which to randomly retain values in
the given input.
Returns:
Tensor result of applying dropout to the given inputs.
"""
with tf.variable_scope('dropout'):
if keep_prob < 1.0:
return tf.nn.dropout(inputs, keep_prob)
return inputs
def build_graph(self, features, labels, mode, batch_size):
"""Creates TensorFlow model graph.
Args:
features: a dict of input features Tensors.
labels: a dict (by target name) of prediction labels.
mode: the ModeKey string.
batch_size: the integer batch size.
Side Effect:
Adds the following key Tensors and operations as class attributes:
placeholders, global_step, train_op, summary_op, accuracy,
weighted_accuracy, loss, logits, and predictions.
"""
is_train = (mode == tf.estimator.ModeKeys.TRAIN)
read = features['sequence']
# Add a unitary dimension, so we can use conv2d.
read = tf.expand_dims(read, 1)
prev_out = read
filters = zip(self.hparams.filter_widths, self.hparams.filter_depths)
for i, f in enumerate(filters):
with tf.variable_scope('convolution_' + str(i)):
if self.hparams.use_depthwise_separable:
p = self.hparams.pointwise_depths[i]
else:
p = None
conv_out = self._convolution(
prev_out, f, pointwise_dim=p, scale=self.hparams.weight_scale)
conv_act_out = self._leaky_relu(self.hparams.lrelu_slope, conv_out)
prev_out = (
self._dropout(conv_act_out, self.hparams.keep_prob)
if is_train else conv_act_out)
for i in xrange(self.hparams.num_fc_layers):
with tf.variable_scope('fully_connected_' + str(i)):
# Create a convolutional layer which is equivalent to a fully-connected
# layer when reads have length self.hparams.min_read_length.
# The convolution will tile the layer appropriately for longer reads.
biases = self._init_bias(self.hparams.num_fc_units)
if i == 0:
# Take entire min_read_length segment as input.
# Output a single value per min_read_length_segment.
filter_dimensions = (self.hparams.min_read_length,
self.hparams.num_fc_units)
else:
# Take single output value of previous layer as input.
filter_dimensions = (1, self.hparams.num_fc_units)
fc_out = biases + self._convolution(
prev_out,
filter_dimensions,
scale=self.hparams.weight_scale,
padding='VALID')
self._summary_histogram(biases.name.split(':')[0].split('/')[1], biases)
fc_act_out = self._leaky_relu(self.hparams.lrelu_slope, fc_out)
prev_out = (
self._dropout(fc_act_out, self.hparams.keep_prob)
if is_train else fc_act_out)
# Pool to collapse tiling for reads longer than hparams.min_read_length.
with tf.variable_scope('pool'):
pool_out = self._pool(prev_out, self.hparams.pooling_type)
with tf.variable_scope('output'):
self._logits = {}
self._predictions = {}
self._weighted_accuracy = {}
self._accuracy = {}
self._loss = collections.OrderedDict()
for target in self.targets:
with tf.variable_scope(target):
label = labels[target]
possible_labels = self.possible_labels[target]
weights = self._init_weights(
[pool_out.get_shape()[1].value,
len(possible_labels)],
self.hparams.weight_scale,
name='weights')
biases = self._init_bias(len(possible_labels))
self._summary_histogram(
weights.name.split(':')[0].split('/')[1], weights)
self._summary_histogram(
biases.name.split(':')[0].split('/')[1], biases)
logits = tf.matmul(pool_out, weights) + biases
predictions = tf.nn.softmax(logits)
gather_inds = tf.stack([tf.range(batch_size), label], axis=1)
self._weighted_accuracy[target] = tf.reduce_mean(
tf.gather_nd(predictions, gather_inds))
argmax_prediction = tf.cast(tf.argmax(predictions, axis=1), tf.int32)
self._accuracy[target] = tf.reduce_mean(
tf.to_float(tf.equal(label, argmax_prediction)))
losses = tf.nn.sparse_softmax_cross_entropy_with_logits(
labels=label, logits=logits)
self._loss[target] = tf.reduce_mean(losses)
self._logits[target] = logits
self._predictions[target] = predictions
# Compute total loss
self._total_loss = tf.add_n(self._loss.values())
# Define the optimizer.
# tf.estimator framework builds the global_step for us, but if we aren't
# using the framework we have to make it ourselves.
self._global_step = tf.train.get_or_create_global_step()
if self.hparams.lr_decay < 0:
self._learn_rate = self.hparams.lr_init
else:
self._learn_rate = tf.train.exponential_decay(
self.hparams.lr_init,
self._global_step,
int(self.hparams.train_steps),
self.hparams.lr_decay,
staircase=False)
if self.hparams.optimizer == 'adam':
opt = tf.train.AdamOptimizer(self._learn_rate, self.hparams.optimizer_hp)
elif self.hparams.optimizer == 'momentum':
opt = tf.train.MomentumOptimizer(self._learn_rate,
self.hparams.optimizer_hp)
if self.use_tpu:
opt = tf.contrib.tpu.CrossShardOptimizer(opt)
gradients, variables = zip(*opt.compute_gradients(self._total_loss))
clipped_gradients, _ = tf.clip_by_global_norm(gradients,
self.hparams.grad_clip_norm)
with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
self._train_op = opt.apply_gradients(
zip(clipped_gradients, variables), global_step=self._global_step)
if not self.use_tpu:
grad_norm = tf.global_norm(gradients) if is_train else None
param_norm = tf.global_norm(variables) if is_train else None
self._add_summaries(mode, grad_norm, param_norm)
def model_fn(self, features, labels, mode, params):
"""Function fulfilling the tf.estimator model_fn interface.
Args:
features: a dict containing the input features for prediction.
labels: a dict from target name to Tensor-value prediction.
mode: the ModeKey string.
params: a dictionary of parameters for building the model; current params
are params["batch_size"]: the integer batch size.
Returns:
A tf.estimator.EstimatorSpec object ready for use in training, inference,
or evaluation.
"""
self.build_graph(features, labels, mode, params['batch_size'])
return tf.estimator.EstimatorSpec(
mode,
predictions=self.predictions,
loss=self.total_loss,
train_op=self.train_op,
eval_metric_ops={})
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Defines hyperparameter configuration for ConvolutionalNet models.
Specifically, provides methods for defining and initializing TensorFlow
hyperparameters objects for a convolutional model as defined in:
seq2species.build_model
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
def parse_hparams(hparam_values='', num_filters=1):
"""Initializes TensorFlow hyperparameters object with default values.
In addition, default hyperparameter values are overwritten with the specified
ones, where necessary.
Args:
hparam_values: comma-separated string of name=value pairs for setting
particular hyperparameters.
num_filters: int; number of filters in the model.
Must be fixed outside of hyperparameter/study object as Vizier does not
support having inter-hyperparameter dependencies.
Returns:
tf.contrib.training.Hparams object containing the model's hyperparameters.
"""
hparams = tf.contrib.training.HParams()
# Specify model architecture option.
hparams.add_hparam('use_depthwise_separable', True)
# Specify number of model parameters.
hparams.add_hparam('filter_widths', [3] * num_filters)
hparams.add_hparam('filter_depths', [1] * num_filters)
hparams.add_hparam('pointwise_depths', [64] * num_filters)
hparams.add_hparam('num_fc_layers', 2)
hparams.add_hparam('num_fc_units', 455)
hparams.add_hparam('min_read_length', 100)
hparams.add_hparam('pooling_type', 'avg')
# Specify activation options.
hparams.add_hparam('lrelu_slope', 0.0) # Negative slope for leaky relu.
# Specify training options.
hparams.add_hparam('keep_prob', 1.0)
hparams.add_hparam('weight_scale', 1.0)
hparams.add_hparam('grad_clip_norm', 20.0)
hparams.add_hparam('lr_init', 0.001)
hparams.add_hparam('lr_decay', 0.1)
hparams.add_hparam('optimizer', 'adam')
# optimizer_hp is decay rate for 1st moment estimates for ADAM, and
# momentum for SGD.
hparams.add_hparam('optimizer_hp', 0.9)
hparams.add_hparam('train_steps', 400000)
# Overwrite defaults with specified values.
hparams.parse(hparam_values)
return hparams
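# Example usage (a minimal sketch; the override string and values below are
# illustrative):
#   hparams = parse_hparams('train_steps=1000,lr_init=0.0005', num_filters=3)
#   hparams.filter_widths  # -> [3, 3, 3]
#   hparams.train_steps    # -> 1000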
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Input pipe for feeding examples to a Seq2Label model graph."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
from google.protobuf import text_format
from protos import seq2label_pb2
import seq2label_utils
DNA_BASES = tuple('ACGT')
NUM_DNA_BASES = len(DNA_BASES)
# Possible FASTA characters/IUPAC ambiguity codes.
# See https://en.wikipedia.org/wiki/Nucleic_acid_notation.
AMBIGUITY_CODES = {
'K': 'GT',
'M': 'AC',
'R': 'AG',
'Y': 'CT',
'S': 'CG',
'W': 'AT',
'B': 'CGT',
'V': 'ACG',
'H': 'ACT',
'D': 'AGT',
'X': 'ACGT',
'N': 'ACGT'
}
def load_dataset_info(dataset_info_path):
"""Load a `Seq2LabelDatasetInfo` from a serialized text proto file."""
dataset_info = seq2label_pb2.Seq2LabelDatasetInfo()
with tf.gfile.Open(dataset_info_path, 'r') as f:
text_format.Parse(f.read(), dataset_info)
return dataset_info
class _InputEncoding(object):
"""A helper class providing the graph operations needed to encode input.
Instantiation of an _InputEncoding will write on the default TF graph, so it
should only be instantiated inside the `input_fn`.
Attributes:
mode: `tf.estimator.ModeKeys`; the execution mode {TRAIN, EVAL, INFER}.
targets: list of strings; the names of the labels of interest (e.g.
"species").
dna_bases: a tuple of the recognized DNA alphabet.
n_bases: the size of the DNA alphabet.
all_characters: list of recognized alphabet, including ambiguity codes.
label_values: a tuple of strings, the possible label values of the
prediction target.
n_labels: the size of label_values
fixed_read_length: an integer value of the statically-known read length, or
None if the read length is to be determined dynamically.
"""
def __init__(self,
dataset_info,
mode,
targets,
noise_rate=0.0,
fixed_read_length=None):
self.mode = mode
self.targets = targets
self.dna_bases = DNA_BASES
self.n_bases = NUM_DNA_BASES
self.all_characters = list(DNA_BASES) + sorted(AMBIGUITY_CODES.keys())
self.character_encodings = np.concatenate(
[[self._character_to_base_distribution(char)]
for char in self.all_characters],
axis=0)
all_legal_label_values = seq2label_utils.get_all_label_values(dataset_info)
# TF lookup tables.
self.characters_table = tf.contrib.lookup.index_table_from_tensor(
mapping=self.all_characters)
self.label_tables = {
target: tf.contrib.lookup.index_table_from_tensor(
all_legal_label_values[target])
for target in targets
}
self.fixed_read_length = fixed_read_length
self.noise_rate = noise_rate
def _character_to_base_distribution(self, char):
"""Maps the given character to a probability distribution over DNA bases.
Args:
char: character to be encoded as a probability distribution over bases.
Returns:
Array of size (self.n_bases,) representing the identity of the given
character as a distribution over the possible DNA bases, self.dna_bases.
Raises:
ValueError: if the given character is not contained in the recognized
alphabet, self.all_characters.
"""
if char not in self.all_characters:
raise ValueError(
'Base distribution requested for unrecognized character %s.' % char)
possible_bases = AMBIGUITY_CODES[char] if char in AMBIGUITY_CODES else char
base_indices = [self.dna_bases.index(base) for base in possible_bases]
probability_weight = 1.0 / len(possible_bases)
distribution = np.zeros((self.n_bases))
distribution[base_indices] = probability_weight
return distribution
def encode_read(self, string_seq):
"""Converts the input read sequence to one-hot encoding.
Args:
string_seq: tf.String; input read sequence.
Returns:
Input read sequence as a one-hot encoded Tensor, with depth and ordering
of one-hot encoding determined by the given bases. Ambiguous characters
such as "N" and "S" are encoded as a probability distribution over the
possible bases they represent.
"""
with tf.variable_scope('encode_read'):
read = tf.string_split([string_seq], delimiter='').values
read = self.characters_table.lookup(read)
read = tf.cast(tf.gather(self.character_encodings, read), tf.float32)
if self.fixed_read_length:
read = tf.reshape(read, (self.fixed_read_length, self.n_bases))
return read
def encode_label(self, target, string_label):
"""Converts the label value to an integer encoding.
Args:
target: str; the target name.
string_label: tf.String; value of the label for the current input read.
Returns:
Given label value as an index into the possible_target_values.
"""
with tf.variable_scope('encode_label/{}'.format(target)):
return tf.cast(self.label_tables[target].lookup(string_label), tf.int32)
def _empty_label(self):
return tf.constant((), dtype=tf.int32, shape=())
def parse_single_tfexample(self, serialized_example):
"""Parses a tf.train.Example proto to a one-hot encoded read, label pair.
Injects noise into the incoming tf.train.Example's read sequence
when noise_rate is non-zero.
Args:
serialized_example: string; the serialized tf.train.Example proto
containing the read sequence and label value of interest as
tf.FixedLenFeatures.
Returns:
Tuple (features, labels) of dicts for the input features and prediction
targets.
"""
with tf.variable_scope('parse_single_tfexample'):
features_spec = {'sequence': tf.FixedLenFeature([], tf.string)}
for target in self.targets:
features_spec[target] = tf.FixedLenFeature([], tf.string)
features = tf.parse_single_example(
serialized_example, features=features_spec)
if self.noise_rate > 0.0:
read_sequence = tf.py_func(seq2label_utils.add_read_noise,
[features['sequence'], self.noise_rate],
(tf.string))
else:
read_sequence = features['sequence']
read_sequence = self.encode_read(read_sequence)
read_features = {'sequence': read_sequence}
if self.mode in (tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL):
label = {
target: self.encode_label(target, features[target])
for target in self.targets
}
else:
label = {target: self._empty_label() for target in self.targets}
return read_features, label
class InputDataset(object):
"""A class providing access to input data for the Seq2Label model.
Attributes:
mode: `tf.estimator.ModeKeys`; the execution mode {TRAIN, EVAL, INFER}.
targets: list of strings; the names of the labels of interest (e.g.
"species").
dataset_info: a `Seq2LabelDatasetInfo` message reflecting the dataset
metadata.
initializer: the TF initializer op for the underlying iterator, which
will rewind the iterator.
is_train: Boolean indicating whether or not the execution mode is TRAIN.
"""
def __init__(self,
mode,
targets,
dataset_info,
train_epochs=None,
noise_rate=0.0,
random_seed=None,
input_tfrecord_files=None,
fixed_read_length=None,
ensure_constant_batch_size=False,
num_parallel_calls=32):
"""Constructor for InputDataset.
Args:
mode: `tf.estimator.ModeKeys`; the execution mode {TRAIN, EVAL, INFER}.
targets: list of strings; the names of the labels of interest (e.g.
"species").
dataset_info: a `Seq2LabelDatasetInfo` message reflecting the dataset
metadata.
train_epochs: the number of training epochs to perform, if mode==TRAIN.
noise_rate: float [0.0, 1.0] specifying rate at which to inject
base-flipping noise into the read sequences.
random_seed: seed to be used for shuffling, if mode==TRAIN.
input_tfrecord_files: a list of filenames for TFRecords of TF examples.
fixed_read_length: an integer value of the statically-known read length,
or None if the read length is to be determined dynamically. The read
length must be known statically for TPU execution.
ensure_constant_batch_size: ensure a constant batch size at the expense of
discarding the last "short" batch. This also gives us a statically
constant batch size, which is essential for e.g. the TPU platform.
num_parallel_calls: the number of dataset elements to process in parallel.
If None, elements will be processed sequentially.
"""
self.input_tfrecord_files = input_tfrecord_files
self.mode = mode
self.targets = targets
self.dataset_info = dataset_info
self._train_epochs = train_epochs
self._noise_rate = noise_rate
self._random_seed = random_seed
if random_seed is not None:
np.random.seed(random_seed)
self._fixed_read_length = fixed_read_length
self._ensure_constant_batch_size = ensure_constant_batch_size
self._num_parallel_calls = num_parallel_calls
@staticmethod
def from_tfrecord_files(input_tfrecord_files, *args, **kwargs):
return InputDataset(
*args, input_tfrecord_files=input_tfrecord_files, **kwargs)
@property
def is_train(self):
return self.mode == tf.estimator.ModeKeys.TRAIN
def input_fn(self, params):
"""Supplies input for the model.
This function supplies input to our model as a function of the mode.
Args:
params: a dictionary, containing:
- params['batch_size']: the integer batch size.
Returns:
A tuple of two values as follows:
1) the *features* dict, containing a tensor value for keys as follows:
- "sequence" - the encoded read input sequence.
2) the *labels* dict, containing a key for each target, whose value is:
- a string Tensor value (in TRAIN/EVAL mode), or
- a blank Tensor (PREDICT mode).
"""
randomize_input = self.is_train
batch_size = params['batch_size']
encoding = _InputEncoding(
self.dataset_info,
self.mode,
self.targets,
noise_rate=self._noise_rate,
fixed_read_length=self._fixed_read_length)
dataset = tf.data.TFRecordDataset(self.input_tfrecord_files)
dataset = dataset.map(
encoding.parse_single_tfexample,
num_parallel_calls=self._num_parallel_calls)
dataset = dataset.repeat(self._train_epochs if self.is_train else 1)
if randomize_input:
dataset = dataset.shuffle(
buffer_size=max(1000, batch_size), seed=self._random_seed)
if self._ensure_constant_batch_size:
# Only take batches of *exactly* size batch_size; then we get a
# statically knowable batch shape.
dataset = dataset.apply(
tf.contrib.data.batch_and_drop_remainder(batch_size))
else:
dataset = dataset.batch(batch_size)
# Prefetch to allow infeed to be in parallel with model computations.
dataset = dataset.prefetch(2)
# Use initializable iterator to support table lookups.
iterator = dataset.make_initializable_iterator()
self.initializer = iterator.initializer
tf.add_to_collection(tf.GraphKeys.TABLE_INITIALIZERS, iterator.initializer)
features, labels = iterator.get_next()
return (features, labels)
# Protos for the TensorFlow Seq2Species API.
package(
default_visibility = ["//visibility:public"],
)
py_proto_library(
name = "seq2label_py_pb2",
api_version = 2,
deps = [":seq2label_proto"],
)
proto_library(
name = "seq2label_proto",
srcs = ["seq2label.proto"],
)
syntax = "proto2";
package seq2species.protos;
// Summarizes metadata information for a dataset that can be used for running
// training or inference.
message Seq2LabelDatasetInfo {
// Summarizes all possible values for a given label in the dataset.
message LabelInfo {
optional string name = 1;
repeated string values = 2;
// Per-value weights used to normalize the classes in a dataset.
repeated float weights = 3;
}
repeated LabelInfo labels = 3;
// Length (in basepairs) of the reads in the dataset.
optional int32 read_length = 4;
// Stride (in number of basepairs) in the moving window.
optional int32 read_stride = 7;
// Total number of examples in the dataset.
optional int64 num_examples = 5;
// Full path to the dataset.
optional string dataset_path = 6;
}
// Summarizes metadata information about a model trained on a Seq2Label dataset.
message Seq2LabelModelInfo {
optional string hparams_string = 1;
optional string model_type = 2;
repeated string targets = 3;
optional int32 num_filters = 4;
optional int32 batch_size = 5;
optional string metadata_path = 6;
optional float training_noise_rate = 7;
}
// Summarizes resulting measures of modelling experiments.
message Seq2LabelExperimentMeasures {
optional string checkpoint_path = 1;
optional int64 steps = 2;
optional float wall_time = 3;
optional bool experiment_infeasible = 4;
message Measure {
optional string name = 1;
optional float value = 2;
}
repeated Measure measures = 5;
}
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Defines training scheme for neural networks for Seq2Species prediction.
Defines and runs the loop for training an (optionally) depthwise-separable
convolutional model for predicting taxonomic labels from short reads of DNA.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import time
from absl import flags
import numpy as np
import tensorflow as tf
from google.protobuf import text_format
import build_model
import configuration
import input as seq2species_input
from protos import seq2label_pb2
import seq2label_utils
# Define non-tunable parameters.
flags.DEFINE_integer('num_filters', 1, 'Number of filters for conv model')
flags.DEFINE_string('hparams', '',
'Comma-separated list of name=value hyperparameter '
"pairs ('hp1=value1,hp2=value2'). Unspecified "
'hyperparameters will be filled with defaults.')
flags.DEFINE_integer('batch_size', 512, 'Size of batches during training.')
flags.DEFINE_integer('min_train_steps', 1000,
'Minimum number of training steps to run.')
flags.DEFINE_float('max_task_loss', 10.0,
"Terminate trial if task loss doesn't fall below this "
'within --min_train_steps.')
flags.DEFINE_integer('n_print_progress_every', 1000,
'Print training progress every '
'--n_print_progress_every global steps.')
flags.DEFINE_list('targets', ['species'],
'Names of taxonomic ranks to use as training targets.')
flags.DEFINE_float(
'noise_rate', 0.0, 'Rate [0.0, 1.0] at which to inject '
'base-flipping noise into input read sequences.')
# Define paths to logs and data.
flags.DEFINE_list(
'train_files', [], 'Full paths to the TFRecords containing the '
'training examples.')
flags.DEFINE_string(
'metadata_path', '', 'Full path of the text proto containing configuration '
'information about the set of training examples.')
flags.DEFINE_string('logdir', '/tmp/seq2species',
'Directory to which to write logs.')
# Define supervisor/checkpointing options.
flags.DEFINE_integer('task', 0, 'Task ID of the replica running the training.')
flags.DEFINE_string('master', '', 'Name of the TF master to use.')
flags.DEFINE_integer(
'save_model_secs', 900, 'Rate at which to save model parameters. '
'Set to 0 to disable checkpointing.')
flags.DEFINE_integer('recovery_wait_secs', 30,
'Wait to recover model from checkpoint '
'before timing out.')
flags.DEFINE_integer('save_summaries_secs', 900,
'Rate at which to save Tensorboard summaries.')
flags.DEFINE_integer('ps_tasks', 0,
'Number of tasks in the ps job; 0 if no ps is used.')
FLAGS = flags.FLAGS
RANDOM_SEED = 42
def wait_until(time_sec):
"""Stalls execution until a given time.
Args:
time_sec: time, in seconds, until which to loop idly.
"""
while time.time() < time_sec:
pass
def update_measures(measures, new_measures, loss_val, max_loss=None):
"""Updates tracking of experimental measures and infeasibilty.
Args:
measures: dict; mapping from measure name to measure value.
new_measures: dict; mapping from measure name to new measure values.
loss_val: float; value of the loss metric by which to determine feasibility.
max_loss: float; maximum value at which to consider the loss feasible.
Side Effects:
Updates the given mapping of measures and values based on the current
experimental metrics stored in new_measures, and determines current
feasibility of the experiment based on the provided loss value.
"""
max_loss = max_loss if max_loss else np.finfo('f').max
measures['is_infeasible'] = (
loss_val >= max_loss or not np.isfinite(loss_val))
measures.update(new_measures)
def run_training(model, hparams, training_dataset, logdir, batch_size):
"""Trains the given model on random mini-batches of reads.
Args:
model: ConvolutionalNet instance containing the model graph and operations.
hparams: tf.contrib.training.Hparams object containing the model's
hyperparameters; see configuration.py for hyperparameter definitions.
training_dataset: an `InputDataset` that can feed labelled examples.
logdir: string; full path of directory to which to save checkpoints.
batch_size: integer batch size.
Yields:
Tuple comprising a dictionary of experimental measures and the save path
for train checkpoints and summaries.
"""
input_params = dict(batch_size=batch_size)
features, labels = training_dataset.input_fn(input_params)
model.build_graph(features, labels, tf.estimator.ModeKeys.TRAIN, batch_size)
is_chief = FLAGS.task == 0
scaffold = tf.train.Scaffold(
saver=tf.train.Saver(
tf.global_variables(),
max_to_keep=5,
keep_checkpoint_every_n_hours=1.0),
init_op=tf.global_variables_initializer(),
summary_op=model.summary_op)
with tf.train.MonitoredTrainingSession(
master=FLAGS.master,
checkpoint_dir=logdir,
is_chief=is_chief,
scaffold=scaffold,
save_summaries_secs=FLAGS.save_summaries_secs,
save_checkpoint_secs=FLAGS.save_model_secs,
max_wait_secs=FLAGS.recovery_wait_secs) as sess:
global_step = sess.run(model.global_step)
print('Initialized model at global step ', global_step)
init_time = time.time()
measures = {'is_infeasible': False}
if is_chief:
model_info = seq2label_utils.construct_seq2label_model_info(
hparams, 'conv', FLAGS.targets, FLAGS.metadata_path, FLAGS.batch_size,
FLAGS.num_filters, FLAGS.noise_rate)
write_message(model_info, os.path.join(logdir, 'model_info.pbtxt'))
ops = [
model.accuracy, model.weighted_accuracy, model.total_loss,
model.global_step, model.train_op
]
while not sess.should_stop() and global_step < hparams.train_steps:
accuracy, weighted_accuracy, loss, global_step, _ = sess.run(ops)
def gather_measures():
"""Updates the measures dictionary from this batch."""
new_measures = {'train_loss': loss, 'global_step': global_step}
for target in FLAGS.targets:
new_measures.update({
('train_accuracy/%s' % target): accuracy[target],
('train_weighted_accuracy/%s' % target): weighted_accuracy[target]
})
update_measures(
measures, new_measures, loss, max_loss=FLAGS.max_task_loss)
# Periodically track measures according to current mini-batch performance.
# Log a message.
if global_step % FLAGS.n_print_progress_every == 0:
log_message = ('\tstep: %d (%d sec), loss: %f' %
(global_step, time.time() - init_time, loss))
for target in FLAGS.targets:
log_message += (', accuracy/%s: %f ' % (target, accuracy[target]))
log_message += (', weighted_accuracy/%s: %f ' %
(target, weighted_accuracy[target]))
print(log_message)
# Gather new measures and update the measures dictionary.
gather_measures()
yield measures, scaffold.saver.last_checkpoints[-1]
# Check for additional stopping criteria.
if not np.isfinite(loss) or (loss >= FLAGS.max_task_loss and
global_step > FLAGS.min_train_steps):
break
# Always yield once at the end.
gather_measures()
yield measures, scaffold.saver.last_checkpoints[-1]
def write_message(message, filename):
"""Writes contents of the given message to the given filename as a text proto.
Args:
message: the proto message to save.
filename: full path of file to which to save the text proto.
Side Effects:
Outputs a text proto file to the given filename.
"""
message_string = text_format.MessageToString(message)
with tf.gfile.GFile(filename, 'w') as f:
f.write(message_string)
def write_measures(measures, checkpoint_file, init_time):
"""Writes performance measures to file.
Args:
measures: dict; mapping from measure name to measure value.
checkpoint_file: string; full save path for checkpoints and summaries.
init_time: int; start time for work on the current experiment.
Side Effects:
Writes given dictionary of performance measures for the current experiment
to a 'measures.pbtxt' file in the checkpoint directory.
"""
# Save experiment measures.
print('global_step: ', measures['global_step'])
experiment_measures = seq2label_pb2.Seq2LabelExperimentMeasures(
checkpoint_path=checkpoint_file,
steps=measures['global_step'],
experiment_infeasible=measures['is_infeasible'],
wall_time=time.time() - init_time) # Inaccurate for restarts.
for name, value in measures.iteritems():
if name not in ['is_infeasible', 'global_step']:
experiment_measures.measures.add(name=name, value=value)
measures_file = os.path.join(
os.path.dirname(checkpoint_file), 'measures.pbtxt')
write_message(experiment_measures, measures_file)
print('Wrote ', measures_file,
' containing the following experiment measures:\n', experiment_measures)
def main(unused_argv):
dataset_info = seq2species_input.load_dataset_info(FLAGS.metadata_path)
init_time = time.time()
# Determine model hyperparameters.
hparams = configuration.parse_hparams(FLAGS.hparams, FLAGS.num_filters)
print('Current Hyperparameters:')
for hp_name, hp_val in hparams.values().items():
print('\t', hp_name, ': ', hp_val)
# Initialize the model graph.
print('Constructing TensorFlow Graph.')
tf.reset_default_graph()
input_dataset = seq2species_input.InputDataset.from_tfrecord_files(
FLAGS.train_files,
'train',
FLAGS.targets,
dataset_info,
noise_rate=FLAGS.noise_rate,
random_seed=RANDOM_SEED)
with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
model = build_model.ConvolutionalNet(
hparams, dataset_info, targets=FLAGS.targets)
# Run the experiment.
measures, checkpoint_file = None, None
print('Starting model training.')
for cur_measures, cur_file in run_training(
model, hparams, input_dataset, FLAGS.logdir, batch_size=FLAGS.batch_size):
measures, checkpoint_file = cur_measures, cur_file
# Save experiment results.
write_measures(measures, checkpoint_file, init_time)
if __name__ == '__main__':
tf.app.run(main)
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for run_training."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import time
from absl import flags
from absl.testing import absltest
from absl.testing import flagsaver
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from google.protobuf import text_format
import run_training
from protos import seq2label_pb2
import test_utils
FLAGS = flags.FLAGS
class RunTrainingTest(parameterized.TestCase):
@parameterized.parameters(2, 4, 7)
def test_wait_until(self, wait_sec):
end_time = time.time() + wait_sec
run_training.wait_until(end_time)
self.assertEqual(round(time.time() - end_time), 0)
@parameterized.parameters(
({}, {'a': 0.7, 'b': 12.3}, 12.3, None,
{'a': 0.7, 'b': 12.3, 'is_infeasible': False}),
({'a': 0.42}, {'b': 24.5}, 24.5, 32.0,
{'a': 0.42, 'b': 24.5, 'is_infeasible': False}),
({'a': 0.503}, {'a': 0.82, 'b': 7.2}, 7.2, 0.1,
{'a': 0.82, 'b': 7.2, 'is_infeasible': True}),
({}, {'a': 0.7, 'b': 12.3}, float('Inf'), None,
{'a': 0.7, 'b': 12.3, 'is_infeasible': True})
)
def test_update_measures(self, measures, new_measures, loss, max_loss,
expected):
run_training.update_measures(measures, new_measures, loss, max_loss)
self.assertEqual(measures, expected)
def test_write_measures(self):
init_time = time.time()
measures = {
'global_step': 311448,
'train_loss': np.float32(18.36),
'train_weighted_accuracy': np.float32(0.3295),
'train_accuracy': 0.8243,
'is_infeasible': False
}
tmp_path = os.path.join(FLAGS.test_tmpdir, 'measures.pbtxt')
run_training.write_measures(measures, tmp_path, init_time)
experiment_measures = seq2label_pb2.Seq2LabelExperimentMeasures()
with tf.gfile.Open(tmp_path) as f:
text_format.Parse(f.read(), experiment_measures)
self.assertEqual(experiment_measures.checkpoint_path, tmp_path)
self.assertFalse(experiment_measures.experiment_infeasible)
self.assertEqual(experiment_measures.steps, measures['global_step'])
self.assertGreater(experiment_measures.wall_time, 0)
self.assertEqual(len(experiment_measures.measures), 3)
for measure in experiment_measures.measures:
self.assertAlmostEqual(measure.value, measures[measure.name])
@parameterized.parameters((test_utils.TEST_TARGETS[:1],),
(test_utils.TEST_TARGETS,))
def test_run_training(self, targets):
"""Tests whether the training loop can be run successfully.
Generates test input files and runs the main driving code.
Args:
targets: the targets to train on.
"""
# Create test input and metadata files.
num_examples, read_len = 20, 5
train_file = test_utils.create_tmp_train_file(num_examples, read_len)
metadata_path = test_utils.create_tmp_metadata(num_examples, read_len)
# Check that the training loop runs as expected.
logdir = os.path.join(FLAGS.test_tmpdir, 'train:{}'.format(len(targets)))
with flagsaver.flagsaver(
train_files=train_file,
metadata_path=metadata_path,
targets=targets,
logdir=logdir,
hparams='train_steps=10,min_read_length=5',
batch_size=10):
run_training.main(FLAGS)
# Check training loop ran by confirming existence of a checkpoint file.
self.assertIsNotNone(tf.train.latest_checkpoint(FLAGS.logdir))
# Check training loop ran by confirming existence of a measures file.
self.assertTrue(
os.path.exists(os.path.join(FLAGS.logdir, 'measures.pbtxt')))
if __name__ == '__main__':
absltest.main()
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for working with Seq2Label datasets and models.
This library provides utilities for parsing and generating Seq2Label protos.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from protos import seq2label_pb2
def get_all_label_values(dataset_info):
"""Retrieves possible values for modeled labels from a `Seq2LabelDatasetInfo`.
Args:
dataset_info: a `Seq2LabelDatasetInfo` message.
Returns:
A dictionary mapping each label name to a tuple of its permissible values.
"""
return {
label_info.name: tuple(label_info.values)
for label_info in dataset_info.labels
}
def construct_seq2label_model_info(hparams, model_type, targets, metadata_path,
batch_size, num_filters,
training_noise_rate):
"""Constructs a Seq2LabelModelInfo proto with the given properties.
Args:
hparams: initialized tf.contrib.training.Hparams object.
model_type: string; descriptive tag indicating the type of model, i.e. "conv".
targets: list of names of the targets the model is trained to predict.
metadata_path: string; full path to Seq2LabelDatasetInfo text proto used
to initialize the model.
batch_size: int; number of reads per mini-batch.
num_filters: int; number of filters for convolutional model.
training_noise_rate: float; rate [0.0, 1.0] of base-flipping noise injected
into input read sequences at training time.
Returns:
The Seq2LabelModelInfo proto with the hparams, model_type, targets,
num_filters, batch_size, metadata_path, and training_noise_rate fields
set to the given values.
"""
return seq2label_pb2.Seq2LabelModelInfo(
hparams_string=hparams.to_json(),
model_type=model_type,
targets=sorted(targets),
num_filters=num_filters,
batch_size=batch_size,
metadata_path=metadata_path,
training_noise_rate=training_noise_rate)
def add_read_noise(read, base_flip_probability=0.01):
"""Adds base-flipping noise to the given read sequence.
Args:
read: string; the read sequence to which to add noise.
base_flip_probability: float; probability of a base flip at each position.
Returns:
The given read with base-flipping noise added at the provided
base_flip_probability rate.
"""
base_flips = np.random.binomial(1, base_flip_probability, len(read))
if sum(base_flips) == 0:
return read
read = np.array(list(read))
possible_mutations = np.char.replace(['ACTG'] * sum(base_flips),
read[base_flips == 1], '')
mutations = map(np.random.choice, map(list, possible_mutations))
read[base_flips == 1] = mutations
return ''.join(read)
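# Example (illustrative): add_read_noise('ACGTACGT', base_flip_probability=0.5)
# returns a string of the same length in which each position has independently
# been flipped to a different base with probability 0.5.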
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utility methods for accessing and operating on test data."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from absl import flags
import tensorflow as tf
from google.protobuf import text_format
import input as seq2species_input
from protos import seq2label_pb2
FLAGS = flags.FLAGS
# Target names included in the example inputs.
TEST_TARGETS = ['test_target_1', 'test_target_2']
def _as_bytes_feature(in_string):
"""Converts the given string to a tf.train.BytesList feature.
Args:
in_string: string to be converted to BytesList Feature.
Returns:
The TF BytesList Feature representing the given string.
"""
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[in_string]))
def create_tmp_train_file(num_examples,
read_len,
characters=seq2species_input.DNA_BASES,
name='test.tfrecord'):
"""Write a test TFRecord of input examples to temporary test directory.
The generated input examples are test tf.train.Example protos, each comprised
of a toy sequence of length read_len and non-meaningful labels for targets in
TEST_TARGETS.
Args:
num_examples: int; number of examples to write to test input file.
read_len: int; length of test read sequences.
characters: string; set of characters from which to construct test reads.
Defaults to canonical DNA bases.
name: string; filename for the test input file.
Returns:
Full path to the generated temporary test input file.
"""
tmp_path = os.path.join(FLAGS.test_tmpdir, name)
with tf.python_io.TFRecordWriter(tmp_path) as writer:
for i in xrange(num_examples):
char = characters[i % len(characters)]
features_dict = {'sequence': _as_bytes_feature(char * read_len)}
for target_name in TEST_TARGETS:
nonsense_label = _as_bytes_feature(str(i))
features_dict[target_name] = nonsense_label
tf_features = tf.train.Features(feature=features_dict)
example = tf.train.Example(features=tf_features)
writer.write(example.SerializeToString())
return tmp_path
def create_tmp_metadata(num_examples, read_len):
"""Write a test Seq2LabelDatasetInfo test proto to temporary test directory.
Args:
num_examples: int; number of example labels to write into test metadata.
read_len: int; length of test read sequences.
Returns:
Full path to the generated temporary test file containing the
Seq2LabelDatasetInfo text proto.
"""
dataset_info = seq2label_pb2.Seq2LabelDatasetInfo(
read_length=read_len,
num_examples=num_examples,
read_stride=1,
dataset_path='test.tfrecord')
for target in TEST_TARGETS:
dataset_info.labels.add(
name=target, values=[str(i) for i in xrange(num_examples)])
tmp_path = os.path.join(FLAGS.test_tmpdir, 'test.pbtxt')
with tf.gfile.GFile(tmp_path, 'w') as f:
f.write(text_format.MessageToString(dataset_info))
return tmp_path