Commit f5fc733a authored by Byzantine's avatar Byzantine
Browse files

Removing research/community models

parent 09bc9f54
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Global parameters for the VGGish model.
See vggish_slim.py for more information.
"""
# Architectural constants.
# These fix the shape of the model input placeholder
# ([batch, NUM_FRAMES, NUM_BANDS]) and the width of the final embedding layer.
NUM_FRAMES = 96 # Frames in input mel-spectrogram patch.
NUM_BANDS = 64 # Frequency bands in input mel-spectrogram patch.
EMBEDDING_SIZE = 128 # Size of embedding layer.
# Hyperparameters used in feature and example generation.
# Audio is resampled to SAMPLE_RATE (Hz) before feature extraction.
SAMPLE_RATE = 16000
# STFT analysis window (25 ms) and hop (10 ms), expressed in seconds.
STFT_WINDOW_LENGTH_SECONDS = 0.025
STFT_HOP_LENGTH_SECONDS = 0.010
# The mel filterbank produces exactly as many bins as the model has input bands.
NUM_MEL_BINS = NUM_BANDS
# Lower and upper edges (Hz) of the mel filterbank.
MEL_MIN_HZ = 125
MEL_MAX_HZ = 7500
LOG_OFFSET = 0.01 # Offset used for stabilized log of input mel-spectrogram.
# NUM_FRAMES * STFT_HOP_LENGTH_SECONDS = 96 * 0.010 s = 0.96 s per example.
EXAMPLE_WINDOW_SECONDS = 0.96 # Each example contains 96 10ms frames
EXAMPLE_HOP_SECONDS = 0.96 # with zero overlap.
# Parameters used for embedding postprocessing.
# Keys under which the PCA matrix and the PCA means are stored in the .npz
# parameter file read by the postprocessor.
PCA_EIGEN_VECTORS_NAME = 'pca_eigen_vectors'
PCA_MEANS_NAME = 'pca_means'
# PCA-transformed embeddings are clipped to [QUANTIZE_MIN_VAL, QUANTIZE_MAX_VAL]
# and then scaled linearly onto [0, 255] for 8-bit quantization.
QUANTIZE_MIN_VAL = -2.0
QUANTIZE_MAX_VAL = +2.0
# Hyperparameters used in training.
INIT_STDDEV = 0.01 # Standard deviation used to initialize weights.
LEARNING_RATE = 1e-4 # Learning rate for the Adam optimizer.
ADAM_EPSILON = 1e-8 # Epsilon for the Adam optimizer.
# Names of ops, tensors, and features.
# The op names are what the model definition produces under the 'vggish' scope;
# the ':0' suffix turns an op name into the name of that op's first output
# tensor, as expected by Graph.get_tensor_by_name().
INPUT_OP_NAME = 'vggish/input_features'
INPUT_TENSOR_NAME = INPUT_OP_NAME + ':0'
OUTPUT_OP_NAME = 'vggish/embedding'
OUTPUT_TENSOR_NAME = OUTPUT_OP_NAME + ':0'
# NOTE(review): not referenced by the code visible here; presumably the feature
# key used when embeddings are written to serialized examples -- confirm.
AUDIO_EMBEDDING_FEATURE_NAME = 'audio_embedding'
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Post-process embeddings from VGGish."""
import numpy as np
import vggish_params
class Postprocessor(object):
  """Post-processes VGGish embeddings.

  The initial release of AudioSet included 128-D VGGish embeddings for each
  segment of AudioSet. These released embeddings were produced by applying
  a PCA transformation (technically, a whitening transform is included as well)
  and 8-bit quantization to the raw embedding output from VGGish, in order to
  stay compatible with the YouTube-8M project which provides visual embeddings
  in the same format for a large set of YouTube videos. This class implements
  the same PCA (with whitening) and quantization transformations.
  """

  def __init__(self, pca_params_npz_path):
    """Constructs a postprocessor.

    Args:
      pca_params_npz_path: Path to a NumPy-format .npz file that
        contains the PCA parameters used in postprocessing.

    Raises:
      AssertionError: If the PCA matrix is not
        [EMBEDDING_SIZE, EMBEDDING_SIZE] or the means cannot be viewed as an
        [EMBEDDING_SIZE, 1] column vector.
    """
    # np.load() on an .npz archive returns a lazily-read NpzFile that keeps the
    # underlying file handle open. Indexing it materializes a full in-memory
    # array, so both arrays stay valid after the context manager closes the
    # archive.
    with np.load(pca_params_npz_path) as params:
      self._pca_matrix = params[vggish_params.PCA_EIGEN_VECTORS_NAME]
      # Load means into a column vector for easier broadcasting later.
      self._pca_means = params[vggish_params.PCA_MEANS_NAME].reshape(-1, 1)
    assert self._pca_matrix.shape == (
        vggish_params.EMBEDDING_SIZE, vggish_params.EMBEDDING_SIZE), (
            'Bad PCA matrix shape: %r' % (self._pca_matrix.shape,))
    assert self._pca_means.shape == (vggish_params.EMBEDDING_SIZE, 1), (
        'Bad PCA means shape: %r' % (self._pca_means.shape,))

  def postprocess(self, embeddings_batch):
    """Applies postprocessing to a batch of embeddings.

    Args:
      embeddings_batch: An nparray of shape [batch_size, embedding_size]
        containing output from the embedding layer of VGGish.

    Returns:
      An nparray of the same shape as the input but of type uint8,
      containing the PCA-transformed and quantized version of the input.

    Raises:
      AssertionError: If the batch is not 2-D or its second dimension is not
        EMBEDDING_SIZE.
    """
    assert len(embeddings_batch.shape) == 2, (
        'Expected 2-d batch, got %r' % (embeddings_batch.shape,))
    assert embeddings_batch.shape[1] == vggish_params.EMBEDDING_SIZE, (
        'Bad batch shape: %r' % (embeddings_batch.shape,))

    # Apply PCA.
    # - Embeddings come in as [batch_size, embedding_size].
    # - Transpose to [embedding_size, batch_size].
    # - Subtract pca_means column vector from each column.
    # - Premultiply by PCA matrix of shape [output_dims, input_dims]
    #   where both are equal to embedding_size in our case.
    # - Transpose result back to [batch_size, embedding_size].
    pca_applied = np.dot(self._pca_matrix,
                         (embeddings_batch.T - self._pca_means)).T

    # Quantize by:
    # - clipping to [min, max] range
    clipped_embeddings = np.clip(
        pca_applied, vggish_params.QUANTIZE_MIN_VAL,
        vggish_params.QUANTIZE_MAX_VAL)
    # - scaling linearly so that [min, max] maps onto [0.0, 255.0]
    quantized_embeddings = (
        (clipped_embeddings - vggish_params.QUANTIZE_MIN_VAL) *
        (255.0 /
         (vggish_params.QUANTIZE_MAX_VAL - vggish_params.QUANTIZE_MIN_VAL)))
    # - casting the floats to uint8 (the cast truncates the fractional part)
    quantized_embeddings = quantized_embeddings.astype(np.uint8)

    return quantized_embeddings
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Defines the 'VGGish' model used to generate AudioSet embedding features.
The public AudioSet release (https://research.google.com/audioset/download.html)
includes 128-D features extracted from the embedding layer of a VGG-like model
that was trained on a large Google-internal YouTube dataset. Here we provide
a TF-Slim definition of the same model, without any dependences on libraries
internal to Google. We call it 'VGGish'.
Note that we only define the model up to the embedding layer, which is the
penultimate layer before the final classifier layer. We also provide various
hyperparameter values (in vggish_params.py) that were used to train this model
internally.
For comparison, here is TF-Slim's VGG definition:
https://github.com/tensorflow/models/blob/master/research/slim/nets/vgg.py
"""
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
import tf_slim as slim
import vggish_params as params
def define_vggish_slim(training=False):
  """Builds the VGGish embedding network in the current default graph.

  Every op is created under the scope 'vggish/'.

  Input: a float32 placeholder named 'vggish/input_features' with shape
  [batch_size, num_frames, num_bands]. batch_size is variable; num_frames and
  num_bands are the constants params.NUM_FRAMES and params.NUM_BANDS. Each
  [num_frames, num_bands] slice is a log-mel-scale spectrogram patch covering
  num_bands frequency bands over num_frames time frames (each frame step is
  usually 10ms), produced by computing the stabilized
  log(mel-spectrogram + params.LOG_OFFSET).

  Output: an op named 'vggish/embedding' carrying the activations of the
  128-D embedding layer, which is usually the penultimate layer when this
  network is used beneath a final classifier layer.

  Args:
    training: If true, all parameters are marked trainable.

  Returns:
    The op 'vggish/embedding'.
  """
  # Layer defaults, installed one arg_scope at a time:
  # - conv and fully-connected weights: truncated normal, stddev INIT_STDDEV.
  # - conv and fully-connected biases: zeros.
  # - conv and fully-connected activations: ReLU.
  # - convolutions: 3x3, stride 1, SAME padding.
  # - max-pools: 2x2, stride 2, SAME padding.
  with slim.arg_scope(
      [slim.conv2d, slim.fully_connected],
      weights_initializer=tf.truncated_normal_initializer(
          stddev=params.INIT_STDDEV),
      biases_initializer=tf.zeros_initializer(),
      activation_fn=tf.nn.relu,
      trainable=training):
    with slim.arg_scope(
        [slim.conv2d], kernel_size=[3, 3], stride=1, padding='SAME'):
      with slim.arg_scope(
          [slim.max_pool2d], kernel_size=[2, 2], stride=2, padding='SAME'):
        with tf.variable_scope('vggish'):
          # A batch of 2-D log-mel-spectrogram patches is fed in here.
          input_shape = (None, params.NUM_FRAMES, params.NUM_BANDS)
          patches = tf.placeholder(
              tf.float32, shape=input_shape, name='input_features')

          # conv2d() wants 4-D input, so append a single-channel axis.
          layer = tf.reshape(
              patches, [-1, params.NUM_FRAMES, params.NUM_BANDS, 1])

          # VGG-style stack: convolution group, then max-pool, four times.
          layer = slim.conv2d(layer, 64, scope='conv1')
          layer = slim.max_pool2d(layer, scope='pool1')

          layer = slim.conv2d(layer, 128, scope='conv2')
          layer = slim.max_pool2d(layer, scope='pool2')

          layer = slim.repeat(layer, 2, slim.conv2d, 256, scope='conv3')
          layer = slim.max_pool2d(layer, scope='pool3')

          layer = slim.repeat(layer, 2, slim.conv2d, 512, scope='conv4')
          layer = slim.max_pool2d(layer, scope='pool4')

          # Collapse the feature maps before the fully-connected layers.
          layer = slim.flatten(layer)
          layer = slim.repeat(
              layer, 2, slim.fully_connected, 4096, scope='fc1')

          # The embedding layer itself.
          layer = slim.fully_connected(
              layer, params.EMBEDDING_SIZE, scope='fc2')

          # Wrap in an identity op so the output has a stable, known name.
          return tf.identity(layer, name='embedding')
def load_vggish_slim_checkpoint(session, checkpoint_path):
  """Restores pre-trained VGGish weights from a checkpoint into a session.

  Suitable for use as an initialization function (an init_fn in TensorFlow
  documentation) that is called in a Session after all variables have been
  initialized. Only variables defined by the VGGish model definition are
  restored; any other variables in the graph are left untouched.

  Args:
    session: an active TensorFlow session.
    checkpoint_path: path to a file containing a checkpoint that is
      compatible with the VGGish model definition.
  """
  # Build an inference-mode VGGish in a scratch graph purely to learn the
  # names of the variables the checkpoint is expected to contain.
  with tf.Graph().as_default():
    define_vggish_slim(training=False)
    checkpoint_var_names = frozenset(
        var.name for var in tf.global_variables())

  # Back in the caller's default graph: keep only the variables whose names
  # were just collected.
  restorable_vars = [
      var for var in tf.global_variables()
      if var.name in checkpoint_var_names
  ]

  # A Saver restricted to that subset restores just the VGGish variables.
  saver = tf.train.Saver(
      restorable_vars, name='vggish_load_pretrained', write_version=1)
  saver.restore(session, checkpoint_path)
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A smoke test for VGGish.
This is a simple smoke test of a local install of VGGish and its associated
downloaded files. We create a synthetic sound, extract log mel spectrogram
features, run them through VGGish, post-process the embedding outputs, and
check some simple statistics of the results, allowing for variations that
might occur due to platform/version differences in the libraries we use.
Usage:
- Download the VGGish checkpoint and PCA parameters into the same directory as
the VGGish source code. If you keep them elsewhere, update the checkpoint_path
and pca_params_path variables below.
- Run:
$ python vggish_smoke_test.py
"""
from __future__ import print_function
import numpy as np
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
import vggish_input
import vggish_params
import vggish_postprocess
import vggish_slim
print('\nTesting your install of VGGish\n')

# Locations of the downloaded VGGish checkpoint and PCA parameters.
checkpoint_path = 'vggish_model.ckpt'
pca_params_path = 'vggish_pca_params.npz'

# How far the embedding mean / standard deviation may drift from the expected
# values, as a relative tolerance.
rel_error = 0.1  # Up to 10%

# Synthesize a 1 kHz sine wave sampled at 44.1 kHz. The high sampling rate is
# deliberate: it exercises the resampling to 16 kHz in feature extraction.
num_secs = 3
freq = 1000
sr = 44100
t = np.linspace(0, num_secs, int(num_secs * sr))
x = np.sin(2 * np.pi * freq * t)

# Turn the waveform into a batch of log mel spectrogram examples and check
# that one example per second of audio came out.
input_batch = vggish_input.waveform_to_examples(x, sr)
print('Log Mel Spectrogram example: ', input_batch[0])
expected_input_shape = [
    num_secs, vggish_params.NUM_FRAMES, vggish_params.NUM_BANDS]
np.testing.assert_equal(input_batch.shape, expected_input_shape)

# Build VGGish, restore the checkpoint and push the batch through the model.
with tf.Graph().as_default(), tf.Session() as sess:
  vggish_slim.define_vggish_slim()
  vggish_slim.load_vggish_slim_checkpoint(sess, checkpoint_path)

  graph = sess.graph
  features_tensor = graph.get_tensor_by_name(vggish_params.INPUT_TENSOR_NAME)
  embedding_tensor = graph.get_tensor_by_name(
      vggish_params.OUTPUT_TENSOR_NAME)
  embedding_batch = sess.run(
      embedding_tensor, feed_dict={features_tensor: input_batch})
  print('VGGish embedding: ', embedding_batch[0])

  # Compare summary statistics of the raw embeddings with reference values.
  expected_embedding_mean = 0.131
  expected_embedding_std = 0.238
  actual_embedding_stats = [np.mean(embedding_batch), np.std(embedding_batch)]
  np.testing.assert_allclose(
      actual_embedding_stats,
      [expected_embedding_mean, expected_embedding_std],
      rtol=rel_error)

# Whiten and quantize the embeddings, then compare their statistics too.
pproc = vggish_postprocess.Postprocessor(pca_params_path)
postprocessed_batch = pproc.postprocess(embedding_batch)
print('Postprocessed VGGish embedding: ', postprocessed_batch[0])
expected_postprocessed_mean = 123.0
expected_postprocessed_std = 75.0
actual_postprocessed_stats = [
    np.mean(postprocessed_batch), np.std(postprocessed_batch)]
np.testing.assert_allclose(
    actual_postprocessed_stats,
    [expected_postprocessed_mean, expected_postprocessed_std],
    rtol=rel_error)

print('\nLooks Good To Me!\n')
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""A simple demonstration of running VGGish in training mode.
This is intended as a toy example that demonstrates how to use the VGGish model
definition within a larger model that adds more layers on top, and then train
the larger model. If you let VGGish train as well, then this allows you to
fine-tune the VGGish model parameters for your application. If you don't let
VGGish train, then you use VGGish as a feature extractor for the layers above
it.
For this toy task, we are training a classifier to distinguish between three
classes: sine waves, constant signals, and white noise. We generate synthetic
waveforms from each of these classes, convert into shuffled batches of log mel
spectrogram examples with associated labels, and feed the batches into a model
that includes VGGish at the bottom and a couple of additional layers on top. We
also plumb in labels that are associated with the examples, which feed a label
loss used for training.
Usage:
# Run training for 100 steps using a model checkpoint in the default
# location (vggish_model.ckpt in the current directory). Allow VGGish
# to get fine-tuned.
$ python vggish_train_demo.py --num_batches 100
# Same as before but run for fewer steps and don't change VGGish parameters
# and use a checkpoint in a different location
$ python vggish_train_demo.py --num_batches 50 \
--train_vggish=False \
--checkpoint /path/to/model/checkpoint
"""
from __future__ import print_function
from random import shuffle
import numpy as np
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
import tf_slim as slim
import vggish_input
import vggish_params
import vggish_slim
# Command-line flags for the training demo, declared through tf.app.flags and
# read back through FLAGS once tf.app.run() has parsed the command line.
flags = tf.app.flags
# --num_batches: number of training-loop iterations (one batch is fed per step).
flags.DEFINE_integer(
'num_batches', 30,
'Number of batches of examples to feed into the model. Each batch is of '
'variable size and contains shuffled examples of each class of audio.')
# --train_vggish: passed as the `training` argument of the VGGish model
# definition, i.e. whether the VGGish variables are created as trainable.
flags.DEFINE_boolean(
'train_vggish', True,
'If True, allow VGGish parameters to change during training, thus '
'fine-tuning VGGish. If False, VGGish parameters are fixed, thus using '
'VGGish as a fixed feature extractor.')
# --checkpoint: pre-trained VGGish checkpoint restored before training starts.
flags.DEFINE_string(
'checkpoint', 'vggish_model.ckpt',
'Path to the VGGish checkpoint file.')
FLAGS = flags.FLAGS
# Width of the label vectors and of the classifier's logits layer: one output
# per toy class (sine, constant, noise).
_NUM_CLASSES = 3
def _get_examples_batch():
  """Returns a shuffled batch of examples of all audio classes.

  Note that this is just a toy function because this is a simple demo intended
  to illustrate how the training code might work.

  Three 5-second waveforms are synthesized at 44.1 kHz (a sine wave of random
  frequency, a constant signal of random level, and zero-mean Gaussian white
  noise), each is converted to log mel spectrogram examples by
  vggish_input.waveform_to_examples(), and the labeled examples of all three
  classes are shuffled together.

  Returns:
    a tuple (features, labels) of two equal-length Python lists. Each entry of
    features is one example as produced by vggish_input.waveform_to_examples()
    (a log mel spectrogram patch suitable for feeding VGGish), and the entry of
    labels at the same index is a NumPy vector of length 3 with a 1 at the
    index of that example's class (sine 0, const 1, noise 2) and 0 elsewhere.
    Both lists can be fed directly as a batch through feed_dict.
  """
  # Make a waveform for each class.
  num_seconds = 5
  sr = 44100  # Sampling rate.
  t = np.linspace(0, num_seconds, int(num_seconds * sr))  # Time axis.
  # Random sine wave.
  freq = np.random.uniform(100, 1000)
  sine = np.sin(2 * np.pi * freq * t)
  # Random constant signal: every sample holds the same level. (Scaling the
  # time axis by the level, as was done before, yields a ramp instead.)
  magnitude = np.random.uniform(-1, 1)
  const = np.full_like(t, magnitude)
  # White noise: zero-mean, unit-variance Gaussian samples. A non-zero mean
  # would superimpose a DC offset on the noise.
  noise = np.random.normal(0, 1, size=t.shape)

  # Make examples of each signal and corresponding labels.
  # Sine is class index 0, Const class index 1, Noise class index 2.
  sine_examples = vggish_input.waveform_to_examples(sine, sr)
  sine_labels = np.array([[1, 0, 0]] * sine_examples.shape[0])
  const_examples = vggish_input.waveform_to_examples(const, sr)
  const_labels = np.array([[0, 1, 0]] * const_examples.shape[0])
  noise_examples = vggish_input.waveform_to_examples(noise, sr)
  noise_labels = np.array([[0, 0, 1]] * noise_examples.shape[0])

  # Shuffle (example, label) pairs across all classes. Pairing before the
  # shuffle keeps every example attached to its own label.
  all_examples = np.concatenate((sine_examples, const_examples, noise_examples))
  all_labels = np.concatenate((sine_labels, const_labels, noise_labels))
  labeled_examples = list(zip(all_examples, all_labels))
  shuffle(labeled_examples)

  # Separate and return the features and labels.
  features = [example for (example, _) in labeled_examples]
  labels = [label for (_, label) in labeled_examples]
  return (features, labels)
def main(_):
  """Builds VGGish plus a small classifier head and runs the training loop.

  The graph is VGGish followed by a 100-unit fully connected layer and a
  logits layer with one sigmoid output per class, trained with Adam on a
  sigmoid cross-entropy loss. After variable initialization the pre-trained
  VGGish checkpoint named by --checkpoint is restored, then --num_batches
  synthetic batches are fed, printing the step count and loss for each.

  Args:
    _: Unused positional argument supplied by tf.app.run().
  """
  with tf.Graph().as_default(), tf.Session() as sess:
    # VGGish at the bottom; --train_vggish decides whether it is fine-tuned.
    embeddings = vggish_slim.define_vggish_slim(FLAGS.train_vggish)

    # A shallow classifier and its training ops, stacked on the embeddings.
    with tf.variable_scope('mymodel'):
      # Hidden fully connected layer with 100 units.
      num_units = 100
      hidden = slim.fully_connected(embeddings, num_units)

      # Classifier layer: independent logistic classifiers, one per class,
      # which allows for multi-class tasks.
      logits = slim.fully_connected(
          hidden, _NUM_CLASSES, activation_fn=None, scope='logits')
      tf.sigmoid(logits, name='prediction')

      # Training ops.
      with tf.variable_scope('train'):
        step_collections = [
            tf.GraphKeys.GLOBAL_VARIABLES, tf.GraphKeys.GLOBAL_STEP]
        global_step = tf.Variable(
            0, name='global_step', trainable=False,
            collections=step_collections)

        # Labels arrive as a batch of multi-hot vectors: 1 at the position of
        # each positive class label, 0 everywhere else.
        labels_placeholder = tf.placeholder(
            tf.float32, shape=(None, _NUM_CLASSES), name='labels')

        # Cross-entropy label loss, averaged over the batch and the classes.
        cross_entropy = tf.nn.sigmoid_cross_entropy_with_logits(
            logits=logits, labels=labels_placeholder, name='xent')
        mean_loss = tf.reduce_mean(cross_entropy, name='loss_op')
        tf.summary.scalar('loss', mean_loss)

        # Same optimizer and hyperparameters as were used to train VGGish.
        optimizer = tf.train.AdamOptimizer(
            learning_rate=vggish_params.LEARNING_RATE,
            epsilon=vggish_params.ADAM_EPSILON)
        optimizer.minimize(
            mean_loss, global_step=global_step, name='train_op')

    # Initialize every variable first, then overwrite the VGGish ones with
    # the pre-trained checkpoint values.
    sess.run(tf.global_variables_initializer())
    vggish_slim.load_vggish_slim_checkpoint(sess, FLAGS.checkpoint)

    # Look up by name the tensors and ops the training loop needs.
    graph = sess.graph
    features_tensor = graph.get_tensor_by_name(
        vggish_params.INPUT_TENSOR_NAME)
    labels_tensor = graph.get_tensor_by_name('mymodel/train/labels:0')
    global_step_tensor = graph.get_tensor_by_name(
        'mymodel/train/global_step:0')
    loss_tensor = graph.get_tensor_by_name('mymodel/train/loss_op:0')
    train_op = graph.get_operation_by_name('mymodel/train/train_op')

    # The training loop: one freshly generated batch per step.
    fetches = [global_step_tensor, loss_tensor, train_op]
    for _ in range(FLAGS.num_batches):
      batch_features, batch_labels = _get_examples_batch()
      step_count, batch_loss, _ = sess.run(
          fetches,
          feed_dict={
              features_tensor: batch_features,
              labels_tensor: batch_labels,
          })
      print('Step %d: loss %g' % (step_count, batch_loss))


if __name__ == '__main__':
  tf.app.run()
# YAMNet
YAMNet is a pretrained deep net that predicts 521 audio event classes based on
the [AudioSet-YouTube corpus](http://g.co/audioset), and employing the
[Mobilenet_v1](https://arxiv.org/pdf/1704.04861.pdf) depthwise-separable
convolution architecture.
This directory contains the Keras code to construct the model, and example code
for applying the model to input sound files.
## Installation
YAMNet depends on the following Python packages:
* [`numpy`](http://www.numpy.org/)
* [`resampy`](http://resampy.readthedocs.io/en/latest/)
* [`tensorflow`](http://www.tensorflow.org/)
* [`pysoundfile`](https://pysoundfile.readthedocs.io/)
These are all easily installable via, e.g., `pip install numpy` (as in the
example command sequence below).
Any reasonably recent version of these packages should work. TensorFlow should
be at least version 1.8 to ensure Keras support is included. Note that while
the code works fine with TensorFlow v1.x or v2.x, we explicitly enable v1.x
behavior.
YAMNet also requires downloading the following data file:
* [YAMNet model weights](https://storage.googleapis.com/audioset/yamnet.h5)
in Keras saved weights in HDF5 format.
After downloading this file into the same directory as this README, the
installation can be tested by running `python yamnet_test.py` which
runs some synthetic signals through the model and checks the outputs.
Here's a sample installation and test session:
```shell
# Upgrade pip first. Also make sure wheel is installed.
python -m pip install --upgrade pip wheel
# Install dependencies.
pip install numpy resampy tensorflow soundfile
# Clone TensorFlow models repo into a 'models' directory.
git clone https://github.com/tensorflow/models.git
cd models/research/audioset/yamnet
# Download data file into same directory as code.
curl -O https://storage.googleapis.com/audioset/yamnet.h5
# Installation ready, let's test it.
python yamnet_test.py
# If we see "Ran 4 tests ... OK ...", then we're all set.
```
## Usage
You can run the model over existing sound files using `inference.py`:
```shell
python inference.py input_sound.wav
```
The code will report the top-5 highest-scoring classes averaged over all the
frames of the input. You can access greater detail by modifying the example
code in inference.py.
See the jupyter notebook `yamnet_visualization.ipynb` for an example of
displaying the per-frame model output scores.
## About the Model
The YAMNet code layout is as follows:
* `yamnet.py`: Model definition in Keras.
* `params.py`: Hyperparameters. You can usefully modify PATCH_HOP_SECONDS.
* `features.py`: Audio feature extraction helpers.
* `inference.py`: Example code to classify input wav files.
* `yamnet_test.py`: Simple test of YAMNet installation.
### Input: Audio Features
See `features.py`.
As with our previous release
[VGGish](https://github.com/tensorflow/models/tree/master/research/audioset/vggish),
YAMNet was trained with audio features computed as follows:
* All audio is resampled to 16 kHz mono.
* A spectrogram is computed using magnitudes of the Short-Time Fourier Transform
with a window size of 25 ms, a window hop of 10 ms, and a periodic Hann
window.
* A mel spectrogram is computed by mapping the spectrogram to 64 mel bins
covering the range 125-7500 Hz.
* A stabilized log mel spectrogram is computed by applying
log(mel-spectrum + 0.001) where the offset is used to avoid taking a logarithm
of zero.
* These features are then framed into 50%-overlapping examples of 0.96 seconds,
where each example covers 64 mel bands and 96 frames of 10 ms each.
These 96x64 patches are then fed into the Mobilenet_v1 model to yield a 3x2
array of activations for 1024 kernels at the top of the convolution. These are
averaged to give a 1024-dimension embedding, then put through a single logistic
layer to get the 521 per-class output scores corresponding to the 960 ms input
waveform segment. (Because of the window framing, you need at least 975 ms of
input waveform to get the first frame of output scores.)
### Class vocabulary
The file `yamnet_class_map.csv` describes the audio event classes associated
with each of the 521 outputs of the network. Its format is:
```text
index,mid,display_name
```
where `index` is the model output index (0..520), `mid` is the machine
identifier for that class (e.g. `/m/09x0r`), and `display_name` is a
human-readable description of the class (e.g. `Speech`).
The original Audioset data release had 527 classes. This model drops six of
them on the recommendation of our Fairness reviewers to avoid potentially
offensive mislabelings. We dropped the gendered versions (Male/Female) of
Speech and Singing. We also dropped Battle cry and Funny music.
### Performance
On the 20,366-segment AudioSet eval set, over the 521 included classes, the
balanced average d-prime is 2.318, balanced mAP is 0.306, and the balanced
average lwlrap is 0.393.
According to our calculations, the classifier has 3.7M weights and performs
69.2M multiplies for each 960ms input frame.
### Contact information
This model repository is maintained by [Manoj Plakal](https://github.com/plakal) and [Dan Ellis](https://github.com/dpwe).
# Copyright 2019 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Feature computation for YAMNet."""
import numpy as np
import tensorflow as tf
def waveform_to_log_mel_spectrogram(waveform, params):
  """Computes the log mel spectrogram of a 1-D waveform.

  Args:
    waveform: Tensor of shape [<# samples>] holding the audio samples.
    params: Hyperparameter container providing SAMPLE_RATE,
      STFT_WINDOW_SECONDS, STFT_HOP_SECONDS, MEL_BANDS, MEL_MIN_HZ,
      MEL_MAX_HZ and LOG_OFFSET.

  Returns:
    Tensor of shape [<# STFT frames>, MEL_BANDS] containing
    log(mel_spectrogram + LOG_OFFSET).
  """
  with tf.name_scope('log_mel_features'):
    sample_rate = params.SAMPLE_RATE

    # STFT geometry in samples. The FFT length is the window length rounded
    # up to the next power of two.
    window_samples = int(round(sample_rate * params.STFT_WINDOW_SECONDS))
    hop_samples = int(round(sample_rate * params.STFT_HOP_SECONDS))
    fft_length = 2 ** int(np.ceil(np.log(window_samples) / np.log(2.0)))
    num_spectrogram_bins = fft_length // 2 + 1

    # Magnitude spectrogram, shape [<# STFT frames>, num_spectrogram_bins].
    # Note that tf.signal.stft() uses a periodic Hann window by default.
    stft = tf.signal.stft(
        signals=waveform,
        frame_length=window_samples,
        frame_step=hop_samples,
        fft_length=fft_length)
    magnitudes = tf.abs(stft)

    # Project the linear-frequency bins onto the mel filterbank.
    mel_weights = tf.signal.linear_to_mel_weight_matrix(
        num_mel_bins=params.MEL_BANDS,
        num_spectrogram_bins=num_spectrogram_bins,
        sample_rate=sample_rate,
        lower_edge_hertz=params.MEL_MIN_HZ,
        upper_edge_hertz=params.MEL_MAX_HZ)
    mel_spectrogram = tf.matmul(magnitudes, mel_weights)

    # Stabilized log: the offset keeps log() away from zero.
    # The result has shape [<# STFT frames>, MEL_BANDS].
    return tf.math.log(mel_spectrogram + params.LOG_OFFSET)
def spectrogram_to_patches(spectrogram, params):
  """Slices a spectrogram into a stack of fixed-size patches.

  Only complete patches are emitted, so if there is less than
  PATCH_WINDOW_SECONDS of waveform then nothing is emitted (to avoid this,
  zero-pad before processing).

  Args:
    spectrogram: Tensor of shape [<# STFT frames>, MEL_BANDS].
    params: Hyperparameter container providing SAMPLE_RATE, STFT_HOP_SECONDS,
      PATCH_WINDOW_SECONDS and PATCH_HOP_SECONDS.

  Returns:
    Tensor of shape [<# patches>, <# STFT frames in a patch>, MEL_BANDS].
  """
  with tf.name_scope('feature_patches'):
    # The spectrogram's own "sample rate" is STFT frames per second.
    stft_hop_samples = int(
        round(params.SAMPLE_RATE * params.STFT_HOP_SECONDS))
    frames_per_second = params.SAMPLE_RATE / stft_hop_samples

    # Patch length and patch hop, both counted in STFT frames.
    patch_length_frames = int(
        round(frames_per_second * params.PATCH_WINDOW_SECONDS))
    patch_hop_frames = int(
        round(frames_per_second * params.PATCH_HOP_SECONDS))

    # Frame along the time axis (axis 0) to produce the input examples.
    return tf.signal.frame(
        signal=spectrogram,
        frame_length=patch_length_frames,
        frame_step=patch_hop_frames,
        axis=0)
# Copyright 2019 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Inference demo for YAMNet."""
from __future__ import division, print_function
import sys
import numpy as np
import resampy
import soundfile as sf
import tensorflow as tf
import params
import yamnet as yamnet_model
def main(argv):
  """Classifies each WAV file named in argv and prints its top five classes.

  Args:
    argv: Non-empty sequence of audio file paths.
  """
  assert argv

  # Build the model and load its weights inside a dedicated graph; the same
  # graph is re-entered below for every prediction.
  model_graph = tf.Graph()
  with model_graph.as_default():
    model = yamnet_model.yamnet_frames_model(params)
    model.load_weights('yamnet.h5')
  labels = yamnet_model.class_names('yamnet_class_map.csv')

  for file_name in argv:
    # Read the file as 16-bit PCM and scale it to [-1.0, +1.0].
    samples, sample_rate = sf.read(file_name, dtype=np.int16)
    assert samples.dtype == np.int16, 'Bad sample type: %r' % samples.dtype
    audio = samples / 32768.0

    # Down-mix multi-channel audio and resample to the rate YAMNet expects.
    if len(audio.shape) > 1:
      audio = np.mean(audio, axis=1)
    if sample_rate != params.SAMPLE_RATE:
      audio = resampy.resample(audio, sample_rate, params.SAMPLE_RATE)

    # The second model output (the log-mel spectrogram) is not needed here.
    # steps=1 is a workaround for Keras batching limitations.
    with model_graph.as_default():
      scores, _ = model.predict(np.reshape(audio, [1, -1]), steps=1)

    # scores is (time_frames, num_classes); average over time to get one
    # score per class for the whole clip.
    clip_scores = np.mean(scores, axis=0)

    # Report the five highest-scoring classes, best first.
    best = np.argsort(clip_scores)[::-1][:5]
    report_lines = [' {:12s}: {:.3f}'.format(labels[i], clip_scores[i])
                    for i in best]
    print(file_name, ':\n' + '\n'.join(report_lines))


if __name__ == '__main__':
  main(sys.argv[1:])
# Copyright 2019 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Hyperparameters for YAMNet."""
# The following hyperparameters (except PATCH_HOP_SECONDS) were used to train YAMNet,
# so expect some variability in performance if you change these. The patch hop can
# be changed arbitrarily: a smaller hop should give you more patches from the same
# clip and possibly better performance at a larger computational cost.

# Waveform sample rate, in samples per second.
SAMPLE_RATE = 16000
# Presumably the STFT analysis window length, in seconds (its use is not
# visible here -- confirm against the feature extraction code).
STFT_WINDOW_SECONDS = 0.025
# Hop between successive STFT frames, in seconds.
STFT_HOP_SECONDS = 0.010
# Number of mel bands, i.e. the width of each spectrogram frame.
MEL_BANDS = 64
# Presumably the lower/upper frequency limits of the mel filterbank, in Hz
# (their use is not visible here -- confirm against the feature extraction code).
MEL_MIN_HZ = 125
MEL_MAX_HZ = 7500
# Added to the mel spectrogram before taking the log.
LOG_OFFSET = 0.001
# Length of each patch fed to the classifier, in seconds.
PATCH_WINDOW_SECONDS = 0.96
# Hop between the starts of successive patches, in seconds.
PATCH_HOP_SECONDS = 0.48
# Patch dimensions expected by the core model: spectrogram frames per patch
# and mel bands per frame.
PATCH_FRAMES = int(round(PATCH_WINDOW_SECONDS / STFT_HOP_SECONDS))
PATCH_BANDS = MEL_BANDS
# Number of output classes (units in the final dense layer).
NUM_CLASSES = 521
# Padding mode passed to the convolution layers.
CONV_PADDING = 'same'
# center / scale / epsilon arguments passed to the batch normalization layers.
BATCHNORM_CENTER = True
BATCHNORM_SCALE = False
BATCHNORM_EPSILON = 1e-4
# Activation applied to the logits to produce the class predictions.
CLASSIFIER_ACTIVATION = 'sigmoid'
# NOTE(review): presumably the name of a features/embedding layer; no use is
# visible in this part of the code -- confirm before relying on it.
FEATURES_LAYER_NAME = 'features'
# Name given to the final prediction (activation) layer.
EXAMPLE_PREDICTIONS_LAYER_NAME = 'predictions'
# Copyright 2019 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Core model definition of YAMNet."""
import csv
import numpy as np
import tensorflow as tf
from tensorflow.keras import Model, layers
import features as features_lib
import params
def _batch_norm(name):
  """Returns a callable that applies a BatchNormalization layer named `name`.

  The center, scale and epsilon settings come from `params` and are read when
  the returned callable is invoked.
  """
  def _bn_layer(layer_input):
    batch_norm = layers.BatchNormalization(
        name=name,
        center=params.BATCHNORM_CENTER,
        scale=params.BATCHNORM_SCALE,
        epsilon=params.BATCHNORM_EPSILON)
    return batch_norm(layer_input)
  return _bn_layer
def _conv(name, kernel, stride, filters):
  """Returns a callable applying convolution, batch norm and ReLU in sequence.

  Args:
    name: Prefix used to name the three layers.
    kernel: Kernel size passed to Conv2D.
    stride: Stride passed to Conv2D.
    filters: Number of output filters.
  """
  def _conv_layer(layer_input):
    # The convolution has no bias and no activation; batch norm and ReLU
    # follow as separate layers.
    convolution = layers.Conv2D(
        name='{}/conv'.format(name),
        filters=filters,
        kernel_size=kernel,
        strides=stride,
        padding=params.CONV_PADDING,
        use_bias=False,
        activation=None)
    net = convolution(layer_input)
    net = _batch_norm(name='{}/conv/bn'.format(name))(net)
    return layers.ReLU(name='{}/relu'.format(name))(net)
  return _conv_layer
def _separable_conv(name, kernel, stride, filters):
  """Returns a callable applying a depthwise-separable convolution block.

  The block is a depthwise convolution followed by a 1x1 pointwise
  convolution; each is followed by batch norm and ReLU.

  Args:
    name: Prefix used to name all layers in the block.
    kernel: Kernel size of the depthwise convolution.
    stride: Stride of the depthwise convolution.
    filters: Number of output filters of the pointwise convolution.
  """
  def _separable_conv_layer(layer_input):
    # Depthwise stage: per-channel spatial filtering, no bias or activation.
    depthwise = layers.DepthwiseConv2D(
        name='{}/depthwise_conv'.format(name),
        kernel_size=kernel,
        strides=stride,
        depth_multiplier=1,
        padding=params.CONV_PADDING,
        use_bias=False,
        activation=None)
    net = depthwise(layer_input)
    net = _batch_norm(name='{}/depthwise_conv/bn'.format(name))(net)
    net = layers.ReLU(name='{}/depthwise_conv/relu'.format(name))(net)

    # Pointwise stage: 1x1 convolution producing `filters` channels.
    pointwise = layers.Conv2D(
        name='{}/pointwise_conv'.format(name),
        filters=filters,
        kernel_size=(1, 1),
        strides=1,
        padding=params.CONV_PADDING,
        use_bias=False,
        activation=None)
    net = pointwise(net)
    net = _batch_norm(name='{}/pointwise_conv/bn'.format(name))(net)
    return layers.ReLU(name='{}/pointwise_conv/relu'.format(name))(net)
  return _separable_conv_layer
# Convolutional layer stack of YAMNet, listed in the order the layers are
# applied. Each entry is (layer_function, kernel, stride, num_filters), where
# layer_function is one of the layer factories defined above; it is called
# with a generated layer name plus these three values and returns a callable
# that is applied to the running tensor.
_YAMNET_LAYER_DEFS = [
    # (layer_function, kernel, stride, num_filters)
    (_conv, [3, 3], 2, 32),
    (_separable_conv, [3, 3], 1, 64),
    (_separable_conv, [3, 3], 2, 128),
    (_separable_conv, [3, 3], 1, 128),
    (_separable_conv, [3, 3], 2, 256),
    (_separable_conv, [3, 3], 1, 256),
    (_separable_conv, [3, 3], 2, 512),
    (_separable_conv, [3, 3], 1, 512),
    (_separable_conv, [3, 3], 1, 512),
    (_separable_conv, [3, 3], 1, 512),
    (_separable_conv, [3, 3], 1, 512),
    (_separable_conv, [3, 3], 1, 512),
    (_separable_conv, [3, 3], 2, 1024),
    (_separable_conv, [3, 3], 1, 1024)
]
def yamnet(features):
  """Defines the core YAMNet model in Keras.

  Args:
    features: Batch of patches, each of shape
      (params.PATCH_FRAMES, params.PATCH_BANDS).

  Returns:
    Per-patch class predictions with params.NUM_CLASSES entries each.
  """
  # Append a trailing channel axis so the 2-D convolutions can consume the
  # patches.
  patch_shape = (params.PATCH_FRAMES, params.PATCH_BANDS)
  net = layers.Reshape(patch_shape + (1,), input_shape=patch_shape)(features)

  # Apply the convolutional stack; layers are named layer1, layer2, ...
  for index, (layer_fun, kernel, stride, filters) in enumerate(
      _YAMNET_LAYER_DEFS, start=1):
    net = layer_fun('layer{}'.format(index), kernel, stride, filters)(net)

  # Classifier head: global average pool, dense logits, output activation.
  pooled = layers.GlobalAveragePooling2D()(net)
  logits = layers.Dense(units=params.NUM_CLASSES, use_bias=True)(pooled)
  return layers.Activation(
      name=params.EXAMPLE_PREDICTIONS_LAYER_NAME,
      activation=params.CLASSIFIER_ACTIVATION)(logits)
def yamnet_frames_model(feature_params):
  """Builds the YAMNet model that maps a waveform to class scores.

  Args:
    feature_params: An object with parameter fields to control the feature
      calculation.

  Returns:
    A Keras model named 'yamnet_frames' that accepts a (1, num_samples)
    waveform and emits two outputs: a (num_patches, num_classes) matrix of
    class scores per time frame, and the
    (num_spectrogram_frames, num_mel_bins) spectrogram feature matrix.
  """
  waveform = layers.Input(batch_shape=(1, None))

  # Drop the batch axis before feature extraction. The spectrogram is kept
  # as a second model output so callers can use it for visualization.
  samples = tf.squeeze(waveform, axis=0)
  spectrogram = features_lib.waveform_to_log_mel_spectrogram(
      samples, feature_params)
  patches = features_lib.spectrogram_to_patches(spectrogram, feature_params)
  predictions = yamnet(patches)

  return Model(
      name='yamnet_frames',
      inputs=waveform,
      outputs=[predictions, spectrogram])
def class_names(class_map_csv):
  """Loads the display names from a class map CSV file.

  Args:
    class_map_csv: Path to a CSV file with a header row followed by
      three-column rows whose third column is the display name.

  Returns:
    A numpy array of the display names, in file order.
  """
  names = []
  with open(class_map_csv) as csv_file:
    rows = csv.reader(csv_file)
    next(rows)  # The first row is the column header, not a class.
    for _, _, display_name in rows:
      names.append(display_name)
  return np.array(names)
index,mid,display_name
0,/m/09x0r,Speech
1,/m/0ytgt,"Child speech, kid speaking"
2,/m/01h8n0,Conversation
3,/m/02qldy,"Narration, monologue"
4,/m/0261r1,Babbling
5,/m/0brhx,Speech synthesizer
6,/m/07p6fty,Shout
7,/m/07q4ntr,Bellow
8,/m/07rwj3x,Whoop
9,/m/07sr1lc,Yell
10,/t/dd00135,Children shouting
11,/m/03qc9zr,Screaming
12,/m/02rtxlg,Whispering
13,/m/01j3sz,Laughter
14,/t/dd00001,Baby laughter
15,/m/07r660_,Giggle
16,/m/07s04w4,Snicker
17,/m/07sq110,Belly laugh
18,/m/07rgt08,"Chuckle, chortle"
19,/m/0463cq4,"Crying, sobbing"
20,/t/dd00002,"Baby cry, infant cry"
21,/m/07qz6j3,Whimper
22,/m/07qw_06,"Wail, moan"
23,/m/07plz5l,Sigh
24,/m/015lz1,Singing
25,/m/0l14jd,Choir
26,/m/01swy6,Yodeling
27,/m/02bk07,Chant
28,/m/01c194,Mantra
29,/t/dd00005,Child singing
30,/t/dd00006,Synthetic singing
31,/m/06bxc,Rapping
32,/m/02fxyj,Humming
33,/m/07s2xch,Groan
34,/m/07r4k75,Grunt
35,/m/01w250,Whistling
36,/m/0lyf6,Breathing
37,/m/07mzm6,Wheeze
38,/m/01d3sd,Snoring
39,/m/07s0dtb,Gasp
40,/m/07pyy8b,Pant
41,/m/07q0yl5,Snort
42,/m/01b_21,Cough
43,/m/0dl9sf8,Throat clearing
44,/m/01hsr_,Sneeze
45,/m/07ppn3j,Sniff
46,/m/06h7j,Run
47,/m/07qv_x_,Shuffle
48,/m/07pbtc8,"Walk, footsteps"
49,/m/03cczk,"Chewing, mastication"
50,/m/07pdhp0,Biting
51,/m/0939n_,Gargling
52,/m/01g90h,Stomach rumble
53,/m/03q5_w,"Burping, eructation"
54,/m/02p3nc,Hiccup
55,/m/02_nn,Fart
56,/m/0k65p,Hands
57,/m/025_jnm,Finger snapping
58,/m/0l15bq,Clapping
59,/m/01jg02,"Heart sounds, heartbeat"
60,/m/01jg1z,Heart murmur
61,/m/053hz1,Cheering
62,/m/028ght,Applause
63,/m/07rkbfh,Chatter
64,/m/03qtwd,Crowd
65,/m/07qfr4h,"Hubbub, speech noise, speech babble"
66,/t/dd00013,Children playing
67,/m/0jbk,Animal
68,/m/068hy,"Domestic animals, pets"
69,/m/0bt9lr,Dog
70,/m/05tny_,Bark
71,/m/07r_k2n,Yip
72,/m/07qf0zm,Howl
73,/m/07rc7d9,Bow-wow
74,/m/0ghcn6,Growling
75,/t/dd00136,Whimper (dog)
76,/m/01yrx,Cat
77,/m/02yds9,Purr
78,/m/07qrkrw,Meow
79,/m/07rjwbb,Hiss
80,/m/07r81j2,Caterwaul
81,/m/0ch8v,"Livestock, farm animals, working animals"
82,/m/03k3r,Horse
83,/m/07rv9rh,Clip-clop
84,/m/07q5rw0,"Neigh, whinny"
85,/m/01xq0k1,"Cattle, bovinae"
86,/m/07rpkh9,Moo
87,/m/0239kh,Cowbell
88,/m/068zj,Pig
89,/t/dd00018,Oink
90,/m/03fwl,Goat
91,/m/07q0h5t,Bleat
92,/m/07bgp,Sheep
93,/m/025rv6n,Fowl
94,/m/09b5t,"Chicken, rooster"
95,/m/07st89h,Cluck
96,/m/07qn5dc,"Crowing, cock-a-doodle-doo"
97,/m/01rd7k,Turkey
98,/m/07svc2k,Gobble
99,/m/09ddx,Duck
100,/m/07qdb04,Quack
101,/m/0dbvp,Goose
102,/m/07qwf61,Honk
103,/m/01280g,Wild animals
104,/m/0cdnk,"Roaring cats (lions, tigers)"
105,/m/04cvmfc,Roar
106,/m/015p6,Bird
107,/m/020bb7,"Bird vocalization, bird call, bird song"
108,/m/07pggtn,"Chirp, tweet"
109,/m/07sx8x_,Squawk
110,/m/0h0rv,"Pigeon, dove"
111,/m/07r_25d,Coo
112,/m/04s8yn,Crow
113,/m/07r5c2p,Caw
114,/m/09d5_,Owl
115,/m/07r_80w,Hoot
116,/m/05_wcq,"Bird flight, flapping wings"
117,/m/01z5f,"Canidae, dogs, wolves"
118,/m/06hps,"Rodents, rats, mice"
119,/m/04rmv,Mouse
120,/m/07r4gkf,Patter
121,/m/03vt0,Insect
122,/m/09xqv,Cricket
123,/m/09f96,Mosquito
124,/m/0h2mp,"Fly, housefly"
125,/m/07pjwq1,Buzz
126,/m/01h3n,"Bee, wasp, etc."
127,/m/09ld4,Frog
128,/m/07st88b,Croak
129,/m/078jl,Snake
130,/m/07qn4z3,Rattle
131,/m/032n05,Whale vocalization
132,/m/04rlf,Music
133,/m/04szw,Musical instrument
134,/m/0fx80y,Plucked string instrument
135,/m/0342h,Guitar
136,/m/02sgy,Electric guitar
137,/m/018vs,Bass guitar
138,/m/042v_gx,Acoustic guitar
139,/m/06w87,"Steel guitar, slide guitar"
140,/m/01glhc,Tapping (guitar technique)
141,/m/07s0s5r,Strum
142,/m/018j2,Banjo
143,/m/0jtg0,Sitar
144,/m/04rzd,Mandolin
145,/m/01bns_,Zither
146,/m/07xzm,Ukulele
147,/m/05148p4,Keyboard (musical)
148,/m/05r5c,Piano
149,/m/01s0ps,Electric piano
150,/m/013y1f,Organ
151,/m/03xq_f,Electronic organ
152,/m/03gvt,Hammond organ
153,/m/0l14qv,Synthesizer
154,/m/01v1d8,Sampler
155,/m/03q5t,Harpsichord
156,/m/0l14md,Percussion
157,/m/02hnl,Drum kit
158,/m/0cfdd,Drum machine
159,/m/026t6,Drum
160,/m/06rvn,Snare drum
161,/m/03t3fj,Rimshot
162,/m/02k_mr,Drum roll
163,/m/0bm02,Bass drum
164,/m/011k_j,Timpani
165,/m/01p970,Tabla
166,/m/01qbl,Cymbal
167,/m/03qtq,Hi-hat
168,/m/01sm1g,Wood block
169,/m/07brj,Tambourine
170,/m/05r5wn,Rattle (instrument)
171,/m/0xzly,Maraca
172,/m/0mbct,Gong
173,/m/016622,Tubular bells
174,/m/0j45pbj,Mallet percussion
175,/m/0dwsp,"Marimba, xylophone"
176,/m/0dwtp,Glockenspiel
177,/m/0dwt5,Vibraphone
178,/m/0l156b,Steelpan
179,/m/05pd6,Orchestra
180,/m/01kcd,Brass instrument
181,/m/0319l,French horn
182,/m/07gql,Trumpet
183,/m/07c6l,Trombone
184,/m/0l14_3,Bowed string instrument
185,/m/02qmj0d,String section
186,/m/07y_7,"Violin, fiddle"
187,/m/0d8_n,Pizzicato
188,/m/01xqw,Cello
189,/m/02fsn,Double bass
190,/m/085jw,"Wind instrument, woodwind instrument"
191,/m/0l14j_,Flute
192,/m/06ncr,Saxophone
193,/m/01wy6,Clarinet
194,/m/03m5k,Harp
195,/m/0395lw,Bell
196,/m/03w41f,Church bell
197,/m/027m70_,Jingle bell
198,/m/0gy1t2s,Bicycle bell
199,/m/07n_g,Tuning fork
200,/m/0f8s22,Chime
201,/m/026fgl,Wind chime
202,/m/0150b9,Change ringing (campanology)
203,/m/03qjg,Harmonica
204,/m/0mkg,Accordion
205,/m/0192l,Bagpipes
206,/m/02bxd,Didgeridoo
207,/m/0l14l2,Shofar
208,/m/07kc_,Theremin
209,/m/0l14t7,Singing bowl
210,/m/01hgjl,Scratching (performance technique)
211,/m/064t9,Pop music
212,/m/0glt670,Hip hop music
213,/m/02cz_7,Beatboxing
214,/m/06by7,Rock music
215,/m/03lty,Heavy metal
216,/m/05r6t,Punk rock
217,/m/0dls3,Grunge
218,/m/0dl5d,Progressive rock
219,/m/07sbbz2,Rock and roll
220,/m/05w3f,Psychedelic rock
221,/m/06j6l,Rhythm and blues
222,/m/0gywn,Soul music
223,/m/06cqb,Reggae
224,/m/01lyv,Country
225,/m/015y_n,Swing music
226,/m/0gg8l,Bluegrass
227,/m/02x8m,Funk
228,/m/02w4v,Folk music
229,/m/06j64v,Middle Eastern music
230,/m/03_d0,Jazz
231,/m/026z9,Disco
232,/m/0ggq0m,Classical music
233,/m/05lls,Opera
234,/m/02lkt,Electronic music
235,/m/03mb9,House music
236,/m/07gxw,Techno
237,/m/07s72n,Dubstep
238,/m/0283d,Drum and bass
239,/m/0m0jc,Electronica
240,/m/08cyft,Electronic dance music
241,/m/0fd3y,Ambient music
242,/m/07lnk,Trance music
243,/m/0g293,Music of Latin America
244,/m/0ln16,Salsa music
245,/m/0326g,Flamenco
246,/m/0155w,Blues
247,/m/05fw6t,Music for children
248,/m/02v2lh,New-age music
249,/m/0y4f8,Vocal music
250,/m/0z9c,A capella
251,/m/0164x2,Music of Africa
252,/m/0145m,Afrobeat
253,/m/02mscn,Christian music
254,/m/016cjb,Gospel music
255,/m/028sqc,Music of Asia
256,/m/015vgc,Carnatic music
257,/m/0dq0md,Music of Bollywood
258,/m/06rqw,Ska
259,/m/02p0sh1,Traditional music
260,/m/05rwpb,Independent music
261,/m/074ft,Song
262,/m/025td0t,Background music
263,/m/02cjck,Theme music
264,/m/03r5q_,Jingle (music)
265,/m/0l14gg,Soundtrack music
266,/m/07pkxdp,Lullaby
267,/m/01z7dr,Video game music
268,/m/0140xf,Christmas music
269,/m/0ggx5q,Dance music
270,/m/04wptg,Wedding music
271,/t/dd00031,Happy music
272,/t/dd00033,Sad music
273,/t/dd00034,Tender music
274,/t/dd00035,Exciting music
275,/t/dd00036,Angry music
276,/t/dd00037,Scary music
277,/m/03m9d0z,Wind
278,/m/09t49,Rustling leaves
279,/t/dd00092,Wind noise (microphone)
280,/m/0jb2l,Thunderstorm
281,/m/0ngt1,Thunder
282,/m/0838f,Water
283,/m/06mb1,Rain
284,/m/07r10fb,Raindrop
285,/t/dd00038,Rain on surface
286,/m/0j6m2,Stream
287,/m/0j2kx,Waterfall
288,/m/05kq4,Ocean
289,/m/034srq,"Waves, surf"
290,/m/06wzb,Steam
291,/m/07swgks,Gurgling
292,/m/02_41,Fire
293,/m/07pzfmf,Crackle
294,/m/07yv9,Vehicle
295,/m/019jd,"Boat, Water vehicle"
296,/m/0hsrw,"Sailboat, sailing ship"
297,/m/056ks2,"Rowboat, canoe, kayak"
298,/m/02rlv9,"Motorboat, speedboat"
299,/m/06q74,Ship
300,/m/012f08,Motor vehicle (road)
301,/m/0k4j,Car
302,/m/0912c9,"Vehicle horn, car horn, honking"
303,/m/07qv_d5,Toot
304,/m/02mfyn,Car alarm
305,/m/04gxbd,"Power windows, electric windows"
306,/m/07rknqz,Skidding
307,/m/0h9mv,Tire squeal
308,/t/dd00134,Car passing by
309,/m/0ltv,"Race car, auto racing"
310,/m/07r04,Truck
311,/m/0gvgw0,Air brake
312,/m/05x_td,"Air horn, truck horn"
313,/m/02rhddq,Reversing beeps
314,/m/03cl9h,"Ice cream truck, ice cream van"
315,/m/01bjv,Bus
316,/m/03j1ly,Emergency vehicle
317,/m/04qvtq,Police car (siren)
318,/m/012n7d,Ambulance (siren)
319,/m/012ndj,"Fire engine, fire truck (siren)"
320,/m/04_sv,Motorcycle
321,/m/0btp2,"Traffic noise, roadway noise"
322,/m/06d_3,Rail transport
323,/m/07jdr,Train
324,/m/04zmvq,Train whistle
325,/m/0284vy3,Train horn
326,/m/01g50p,"Railroad car, train wagon"
327,/t/dd00048,Train wheels squealing
328,/m/0195fx,"Subway, metro, underground"
329,/m/0k5j,Aircraft
330,/m/014yck,Aircraft engine
331,/m/04229,Jet engine
332,/m/02l6bg,"Propeller, airscrew"
333,/m/09ct_,Helicopter
334,/m/0cmf2,"Fixed-wing aircraft, airplane"
335,/m/0199g,Bicycle
336,/m/06_fw,Skateboard
337,/m/02mk9,Engine
338,/t/dd00065,Light engine (high frequency)
339,/m/08j51y,"Dental drill, dentist's drill"
340,/m/01yg9g,Lawn mower
341,/m/01j4z9,Chainsaw
342,/t/dd00066,Medium engine (mid frequency)
343,/t/dd00067,Heavy engine (low frequency)
344,/m/01h82_,Engine knocking
345,/t/dd00130,Engine starting
346,/m/07pb8fc,Idling
347,/m/07q2z82,"Accelerating, revving, vroom"
348,/m/02dgv,Door
349,/m/03wwcy,Doorbell
350,/m/07r67yg,Ding-dong
351,/m/02y_763,Sliding door
352,/m/07rjzl8,Slam
353,/m/07r4wb8,Knock
354,/m/07qcpgn,Tap
355,/m/07q6cd_,Squeak
356,/m/0642b4,Cupboard open or close
357,/m/0fqfqc,Drawer open or close
358,/m/04brg2,"Dishes, pots, and pans"
359,/m/023pjk,"Cutlery, silverware"
360,/m/07pn_8q,Chopping (food)
361,/m/0dxrf,Frying (food)
362,/m/0fx9l,Microwave oven
363,/m/02pjr4,Blender
364,/m/02jz0l,"Water tap, faucet"
365,/m/0130jx,Sink (filling or washing)
366,/m/03dnzn,Bathtub (filling or washing)
367,/m/03wvsk,Hair dryer
368,/m/01jt3m,Toilet flush
369,/m/012xff,Toothbrush
370,/m/04fgwm,Electric toothbrush
371,/m/0d31p,Vacuum cleaner
372,/m/01s0vc,Zipper (clothing)
373,/m/03v3yw,Keys jangling
374,/m/0242l,Coin (dropping)
375,/m/01lsmm,Scissors
376,/m/02g901,"Electric shaver, electric razor"
377,/m/05rj2,Shuffling cards
378,/m/0316dw,Typing
379,/m/0c2wf,Typewriter
380,/m/01m2v,Computer keyboard
381,/m/081rb,Writing
382,/m/07pp_mv,Alarm
383,/m/07cx4,Telephone
384,/m/07pp8cl,Telephone bell ringing
385,/m/01hnzm,Ringtone
386,/m/02c8p,"Telephone dialing, DTMF"
387,/m/015jpf,Dial tone
388,/m/01z47d,Busy signal
389,/m/046dlr,Alarm clock
390,/m/03kmc9,Siren
391,/m/0dgbq,Civil defense siren
392,/m/030rvx,Buzzer
393,/m/01y3hg,"Smoke detector, smoke alarm"
394,/m/0c3f7m,Fire alarm
395,/m/04fq5q,Foghorn
396,/m/0l156k,Whistle
397,/m/06hck5,Steam whistle
398,/t/dd00077,Mechanisms
399,/m/02bm9n,"Ratchet, pawl"
400,/m/01x3z,Clock
401,/m/07qjznt,Tick
402,/m/07qjznl,Tick-tock
403,/m/0l7xg,Gears
404,/m/05zc1,Pulleys
405,/m/0llzx,Sewing machine
406,/m/02x984l,Mechanical fan
407,/m/025wky1,Air conditioning
408,/m/024dl,Cash register
409,/m/01m4t,Printer
410,/m/0dv5r,Camera
411,/m/07bjf,Single-lens reflex camera
412,/m/07k1x,Tools
413,/m/03l9g,Hammer
414,/m/03p19w,Jackhammer
415,/m/01b82r,Sawing
416,/m/02p01q,Filing (rasp)
417,/m/023vsd,Sanding
418,/m/0_ksk,Power tool
419,/m/01d380,Drill
420,/m/014zdl,Explosion
421,/m/032s66,"Gunshot, gunfire"
422,/m/04zjc,Machine gun
423,/m/02z32qm,Fusillade
424,/m/0_1c,Artillery fire
425,/m/073cg4,Cap gun
426,/m/0g6b5,Fireworks
427,/g/122z_qxw,Firecracker
428,/m/07qsvvw,"Burst, pop"
429,/m/07pxg6y,Eruption
430,/m/07qqyl4,Boom
431,/m/083vt,Wood
432,/m/07pczhz,Chop
433,/m/07pl1bw,Splinter
434,/m/07qs1cx,Crack
435,/m/039jq,Glass
436,/m/07q7njn,"Chink, clink"
437,/m/07rn7sz,Shatter
438,/m/04k94,Liquid
439,/m/07rrlb6,"Splash, splatter"
440,/m/07p6mqd,Slosh
441,/m/07qlwh6,Squish
442,/m/07r5v4s,Drip
443,/m/07prgkl,Pour
444,/m/07pqc89,"Trickle, dribble"
445,/t/dd00088,Gush
446,/m/07p7b8y,Fill (with liquid)
447,/m/07qlf79,Spray
448,/m/07ptzwd,Pump (liquid)
449,/m/07ptfmf,Stir
450,/m/0dv3j,Boiling
451,/m/0790c,Sonar
452,/m/0dl83,Arrow
453,/m/07rqsjt,"Whoosh, swoosh, swish"
454,/m/07qnq_y,"Thump, thud"
455,/m/07rrh0c,Thunk
456,/m/0b_fwt,Electronic tuner
457,/m/02rr_,Effects unit
458,/m/07m2kt,Chorus effect
459,/m/018w8,Basketball bounce
460,/m/07pws3f,Bang
461,/m/07ryjzk,"Slap, smack"
462,/m/07rdhzs,"Whack, thwack"
463,/m/07pjjrj,"Smash, crash"
464,/m/07pc8lb,Breaking
465,/m/07pqn27,Bouncing
466,/m/07rbp7_,Whip
467,/m/07pyf11,Flap
468,/m/07qb_dv,Scratch
469,/m/07qv4k0,Scrape
470,/m/07pdjhy,Rub
471,/m/07s8j8t,Roll
472,/m/07plct2,Crushing
473,/t/dd00112,"Crumpling, crinkling"
474,/m/07qcx4z,Tearing
475,/m/02fs_r,"Beep, bleep"
476,/m/07qwdck,Ping
477,/m/07phxs1,Ding
478,/m/07rv4dm,Clang
479,/m/07s02z0,Squeal
480,/m/07qh7jl,Creak
481,/m/07qwyj0,Rustle
482,/m/07s34ls,Whir
483,/m/07qmpdm,Clatter
484,/m/07p9k1k,Sizzle
485,/m/07qc9xj,Clicking
486,/m/07rwm0c,Clickety-clack
487,/m/07phhsh,Rumble
488,/m/07qyrcz,Plop
489,/m/07qfgpx,"Jingle, tinkle"
490,/m/07rcgpl,Hum
491,/m/07p78v5,Zing
492,/t/dd00121,Boing
493,/m/07s12q4,Crunch
494,/m/028v0c,Silence
495,/m/01v_m0,Sine wave
496,/m/0b9m1,Harmonic
497,/m/0hdsk,Chirp tone
498,/m/0c1dj,Sound effect
499,/m/07pt_g0,Pulse
500,/t/dd00125,"Inside, small room"
501,/t/dd00126,"Inside, large room or hall"
502,/t/dd00127,"Inside, public space"
503,/t/dd00128,"Outside, urban or manmade"
504,/t/dd00129,"Outside, rural or natural"
505,/m/01b9nn,Reverberation
506,/m/01jnbd,Echo
507,/m/096m7z,Noise
508,/m/06_y0by,Environmental noise
509,/m/07rgkc5,Static
510,/m/06xkwv,Mains hum
511,/m/0g12c5,Distortion
512,/m/08p9q4,Sidetone
513,/m/07szfh9,Cacophony
514,/m/0chx_,White noise
515,/m/0cj0r,Pink noise
516,/m/07p_0gm,Throbbing
517,/m/01jwx6,Vibration
518,/m/07c52,Television
519,/m/06bz3,Radio
520,/m/07hvw1,Field recording
# Copyright 2019 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Installation test for YAMNet."""
import numpy as np
import tensorflow as tf
import params
import yamnet
class YAMNetTest(tf.test.TestCase):
  """Smoke tests that run YAMNet on synthetic three-second clips."""

  # Shared across all test methods; populated once in setUpClass.
  _yamnet_graph = None
  _yamnet = None
  _yamnet_classes = None

  @classmethod
  def setUpClass(cls):
    """Builds the model, loads its weights and reads the class names once."""
    super(YAMNetTest, cls).setUpClass()
    cls._yamnet_graph = graph = tf.Graph()
    with graph.as_default():
      cls._yamnet = model = yamnet.yamnet_frames_model(params)
      model.load_weights('yamnet.h5')
    cls._yamnet_classes = yamnet.class_names('yamnet_class_map.csv')

  def clip_test(self, waveform, expected_class_name, top_n=10):
    """Runs the model on the waveform and checks the expected class is top-n."""
    batch = np.reshape(waveform, [1, -1])
    with YAMNetTest._yamnet_graph.as_default():
      # The first model output holds the per-frame class scores; average
      # them over time to get one score per class.
      frame_scores = YAMNetTest._yamnet.predict(batch, steps=1)[0]
      prediction = np.mean(frame_scores, axis=0)
      # argsort is ascending, so the last top_n indices are the best ones.
      best_indices = np.argsort(prediction)[-top_n:]
      top_n_class_names = YAMNetTest._yamnet_classes[best_indices]
      self.assertIn(expected_class_name, top_n_class_names)

  def testZeros(self):
    num_samples = int(3 * params.SAMPLE_RATE)
    self.clip_test(
        waveform=np.zeros((1, num_samples)),
        expected_class_name='Silence')

  def testRandom(self):
    np.random.seed(51773)  # Ensure repeatability.
    num_samples = int(3 * params.SAMPLE_RATE)
    noise = np.random.uniform(-1.0, +1.0, (1, num_samples))
    self.clip_test(waveform=noise, expected_class_name='White noise')

  def testSine(self):
    # Three seconds of a 440 Hz tone.
    times = np.linspace(0, 3, int(3 * params.SAMPLE_RATE))
    tone = np.sin(2 * np.pi * 440 * times)
    self.clip_test(
        waveform=np.reshape(tone, [1, -1]),
        expected_class_name='Sine wave')


if __name__ == '__main__':
  tf.test.main()
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Copyright 2019 The TensorFlow Authors All Rights Reserved.\n",
"#\n",
"# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"# you may not use this file except in compliance with the License.\n",
"# You may obtain a copy of the License at\n",
"#\n",
"# http://www.apache.org/licenses/LICENSE-2.0\n",
"#\n",
"# Unless required by applicable law or agreed to in writing, software\n",
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"# See the License for the specific language governing permissions and\n",
"# limitations under the License.\n",
"# =============================================================================="
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Visualization of the YAMNet audio event classification model.\n",
"# See https://github.com/tensorflow/models/tree/master/research/audioset/yamnet/"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"# Imports.\n",
"import numpy as np\n",
"import soundfile as sf\n",
"\n",
"import matplotlib.pyplot as plt\n",
"\n",
"import params\n",
"import yamnet as yamnet_model\n",
"import tensorflow as tf"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Sample rate = 16000\n"
]
}
],
"source": [
"# Read in the audio.\n",
"# You can get this example waveform via:\n",
"# curl -O https://storage.googleapis.com/audioset/speech_whistling2.wav\n",
"\n",
"wav_file_name = 'speech_whistling2.wav'\n",
"\n",
"wav_data, sr = sf.read(wav_file_name, dtype=np.int16)\n",
"waveform = wav_data / 32768.0\n",
"# The graph is designed for a sampling rate of 16 kHz, but higher rates \n",
"# should work too.\n",
"params.SAMPLE_RATE = sr\n",
"print(\"Sample rate =\", params.SAMPLE_RATE)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"WARNING:tensorflow:From /Users/dpwe/google/vggish/lib/python3.7/site-packages/tensorflow_core/python/ops/resource_variable_ops.py:1635: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version.\n",
"Instructions for updating:\n",
"If using Keras pass *_constraint arguments to layers.\n"
]
}
],
"source": [
"# Set up the YAMNet model.\n",
"class_names = yamnet_model.class_names('yamnet_class_map.csv')\n",
"params.PATCH_HOP_SECONDS = 0.1 # 10 Hz scores frame rate.\n",
"graph = tf.Graph()\n",
"with graph.as_default():\n",
" yamnet = yamnet_model.yamnet_frames_model(params)\n",
" yamnet.load_weights('yamnet.h5')"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"WARNING:tensorflow:When passing input data as arrays, do not specify `steps_per_epoch`/`steps` argument. Please use `batch_size` instead.\n"
]
}
],
"source": [
"# Run the model.\n",
"with graph.as_default():\n",
" scores, spectrogram = yamnet.predict(np.reshape(waveform, [1, -1]), steps=1)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 720x576 with 3 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# Visualize the results.\n",
"plt.figure(figsize=(10, 8))\n",
"\n",
"# Plot the waveform.\n",
"plt.subplot(3, 1, 1)\n",
"plt.plot(waveform)\n",
"plt.xlim([0, len(waveform)])\n",
"# Plot the log-mel spectrogram (returned by the model).\n",
"plt.subplot(3, 1, 2)\n",
"plt.imshow(spectrogram.T, aspect='auto', interpolation='nearest', origin='bottom')\n",
"\n",
"# Plot and label the model output scores for the top-scoring classes.\n",
"mean_scores = np.mean(scores, axis=0)\n",
"top_N = 10\n",
"top_class_indices = np.argsort(mean_scores)[::-1][:top_N]\n",
"plt.subplot(3, 1, 3)\n",
"plt.imshow(scores[:, top_class_indices].T, aspect='auto', interpolation='nearest', cmap='gray_r')\n",
"# Compensate for the PATCH_WINDOW_SECONDS (0.96 s) context window to align with spectrogram.\n",
"patch_padding = (params.PATCH_WINDOW_SECONDS / 2) / params.PATCH_HOP_SECONDS\n",
"plt.xlim([-patch_padding, scores.shape[0] + patch_padding])\n",
"# Label the top_N classes.\n",
"yticks = range(0, top_N, 1)\n",
"plt.yticks(yticks, [class_names[top_class_indices[x]] for x in yticks])\n",
"_ = plt.ylim(-0.5 + np.array([top_N, 0]))\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
![TensorFlow Requirement: 1.x](https://img.shields.io/badge/TensorFlow%20Requirement-1.x-brightgreen)
![TensorFlow 2 Not Supported](https://img.shields.io/badge/TensorFlow%202%20Not%20Supported-%E2%9C%95-red.svg)
<font size=4><b>Train Wide-ResNet, Shake-Shake and ShakeDrop models on CIFAR-10
and CIFAR-100 datasets with AutoAugment.</b></font>
The CIFAR-10/CIFAR-100 data can be downloaded from:
https://www.cs.toronto.edu/~kriz/cifar.html. Use the Python version instead of the binary version.
The code replicates the results from Tables 1 and 2 on CIFAR-10/100 with the
following models: Wide-ResNet-28-10, Shake-Shake (26 2x32d), Shake-Shake (26
2x96d) and PyramidNet+ShakeDrop.
<b>Related papers:</b>
AutoAugment: Learning Augmentation Policies from Data
https://arxiv.org/abs/1805.09501
Wide Residual Networks
https://arxiv.org/abs/1605.07146
Shake-Shake regularization
https://arxiv.org/abs/1705.07485
ShakeDrop regularization
https://arxiv.org/abs/1802.02375
<b>Settings:</b>
CIFAR-10 Model | Learning Rate | Weight Decay | Num. Epochs | Batch Size
---------------------- | ------------- | ------------ | ----------- | ----------
Wide-ResNet-28-10 | 0.1 | 5e-4 | 200 | 128
Shake-Shake (26 2x32d) | 0.01 | 1e-3 | 1800 | 128
Shake-Shake (26 2x96d) | 0.01 | 1e-3 | 1800 | 128
PyramidNet + ShakeDrop | 0.05 | 5e-5 | 1800 | 64
<b>Prerequisite:</b>
1. Install TensorFlow. Be sure to run the code using python2 and not python3.
2. Download CIFAR-10/CIFAR-100 dataset.
```shell
curl -o cifar-10-python.tar.gz https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
curl -o cifar-100-python.tar.gz https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz
```
<b>How to run:</b>
```shell
# cd to your workspace.
# Specify the directory where dataset is located using the data_path flag.
# Note: You can move samples from the training set into the eval set by changing train_size and validation_size.
# For example, to train the Wide-ResNet-28-10 model on a GPU.
python train_cifar.py --model_name=wrn \
--checkpoint_dir=/tmp/training \
--data_path=/tmp/data \
--dataset='cifar10' \
--use_cpu=0
```
## Contact for Issues
* Barret Zoph, @barretzoph <barretzoph@google.com>
* Ekin Dogus Cubuk, <cubuk@google.com>
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Transforms used in the Augmentation Policies."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import random
import numpy as np
# pylint:disable=g-multiple-import
from PIL import ImageOps, ImageEnhance, ImageFilter, Image
# pylint:enable=g-multiple-import
# NOTE(review): presumably the side length, in pixels, of the square input
# images; its use is not visible in this part of the file -- confirm.
IMAGE_SIZE = 32
# Mean and standard deviation of the training-set images, three values each
# (presumably one per color channel; which dataset they were computed on is
# not stated here -- confirm before reusing for another dataset).
MEANS = [0.49139968, 0.48215841, 0.44653091]
STDS = [0.24703223, 0.24348513, 0.26158784]
# Maximum 'level' that can be predicted for a transform.
PARAMETER_MAX = 10
def random_flip(x):
  """Returns `x` mirrored left-to-right half of the time, else `x` itself.

  One sample is drawn from numpy's global random state per call.
  """
  should_flip = np.random.rand(1)[0] > 0.5
  return np.fliplr(x) if should_flip else x
def zero_pad_and_crop(img, amount=4):
  """Zero pads `img` by `amount` pixels per side, then takes a random crop.

  Args:
    img: numpy image of shape (height, width, channels) that will be zero
      padded and cropped.
    amount: number of zero pixels added on each of the four sides of `img`.

  Returns:
    A crop of the zero padded image with the same shape as `img`. The crop
    offset is drawn uniformly from all `2 * amount + 1` positions per axis.
  """
  height, width, channels = img.shape[0], img.shape[1], img.shape[2]
  padded_img = np.zeros((height + 2 * amount, width + 2 * amount, channels))
  padded_img[amount:amount + height, amount:amount + width, :] = img
  # `high` is exclusive in np.random.randint, so it has to be 2 * amount + 1
  # for the bottom/right-most offset to be reachable. This also makes
  # `amount=0` a valid no-op instead of raising on an empty range.
  max_offset = 2 * amount + 1
  top = np.random.randint(low=0, high=max_offset)
  left = np.random.randint(low=0, high=max_offset)
  return padded_img[top:top + height, left:left + width, :]
def create_cutout_mask(img_height, img_width, num_channels, size):
  """Builds a multiplicative cutout mask with one rectangular hole of zeros.

  A centre pixel is sampled uniformly inside the image and a patch extending
  `size // 2` pixels from it in each direction is zeroed, clipped to the
  image borders.

  Args:
    img_height: Height of the image the mask will be applied to.
    img_width: Width of the image the mask will be applied to. Must equal
      `img_height`.
    num_channels: Number of channels in the image.
    size: Nominal side length of the zeroed patch.

  Returns:
    A tuple `(mask, upper_coord, lower_coord)`. `mask` has shape
    (`img_height`, `img_width`, `num_channels`) and is all ones except for the
    zeroed patch; it is meant to be multiplied elementwise with the image.
    `upper_coord` is the inclusive (row, col) start of the patch and
    `lower_coord` is its exclusive (row, col) end.
  """
  assert img_height == img_width

  # Sample the patch centre: row first, then column.
  center_row = np.random.randint(low=0, high=img_height)
  center_col = np.random.randint(low=0, high=img_width)

  # Extend half the patch size around the centre, clipping at the borders.
  half = size // 2
  row_start = max(0, center_row - half)
  col_start = max(0, center_col - half)
  row_stop = min(img_height, center_row + half)
  col_stop = min(img_width, center_col + half)
  upper_coord = (row_start, col_start)
  lower_coord = (row_stop, col_stop)

  # An empty patch (e.g. `size` < 2) is rejected.
  assert row_stop - row_start > 0
  assert col_stop - col_start > 0

  mask = np.ones((img_height, img_width, num_channels))
  mask[row_start:row_stop, col_start:col_stop, :] = 0.0
  return mask, upper_coord, lower_coord
def cutout_numpy(img, size=16):
  """Apply cutout with a mask of nominal shape `size` x `size` to `img`.

  The cutout operation is from the paper https://arxiv.org/abs/1708.04552.
  A patch at a random location within `img` is set to zero by multiplying
  `img` with the mask produced by `create_cutout_mask`.

  Args:
    img: Numpy image of shape (height, width, channels) that cutout will be
      applied to.
    size: Nominal height/width of the zeroed patch; the patch is clipped where
      it overlaps the image border.

  Returns:
    A numpy tensor that is the result of applying the cutout mask to `img`.
  """
  # Validate the rank before indexing into `shape`; otherwise a 2-D input
  # fails with an IndexError and this check is never reached.
  assert len(img.shape) == 3
  img_height, img_width, num_channels = img.shape
  mask, _, _ = create_cutout_mask(img_height, img_width, num_channels, size)
  return img * mask
def float_parameter(level, maxval):
  """Scales `maxval` by `level / PARAMETER_MAX` and returns a float.

  Args:
    level: Level of the operation, expected to lie in [0, `PARAMETER_MAX`].
    maxval: Value returned when `level` equals `PARAMETER_MAX`.

  Returns:
    The float `level * maxval / PARAMETER_MAX`.
  """
  scaled = float(level) * maxval
  return scaled / PARAMETER_MAX
def int_parameter(level, maxval):
  """Scales `maxval` by `level / PARAMETER_MAX` and truncates to an int.

  Args:
    level: Level of the operation, expected to lie in [0, `PARAMETER_MAX`].
    maxval: Value returned when `level` equals `PARAMETER_MAX`.

  Returns:
    `level * maxval / PARAMETER_MAX` converted with `int()`.
  """
  fraction = level * maxval / PARAMETER_MAX
  return int(fraction)
def pil_wrap(img):
  """Turns a normalized numpy image into an RGBA PIL Image.

  The normalization is undone with the module level `STDS` and `MEANS`, the
  result is scaled to 0-255 and cast to uint8 before the conversion.
  """
  denormalized = (img * STDS + MEANS) * 255.0
  as_bytes = np.uint8(denormalized)
  return Image.fromarray(as_bytes).convert('RGBA')
def pil_unwrap(pil_img):
  """Converts an RGBA PIL image back to a normalized numpy array.

  Args:
    pil_img: PIL Image with four channels (as produced by `pil_wrap`).

  Returns:
    A float numpy array of shape (height, width, 3) normalized with the module
    level `MEANS` and `STDS`. Pixels whose alpha channel is 0 are set to
    [0, 0, 0] after normalization.
  """
  # Use the image's own dimensions rather than a hard-coded 32 x 32.
  # `getdata()` is row-major, hence the (height, width) order.
  width, height = pil_img.size
  pic_array = np.array(pil_img.getdata()).reshape((height, width, 4)) / 255.0
  # Remember fully transparent pixels before the alpha channel is dropped.
  i1, i2 = np.where(pic_array[:, :, 3] == 0)
  pic_array = (pic_array[:, :, :3] - MEANS) / STDS
  pic_array[i1, i2] = [0, 0, 0]
  return pic_array
def apply_policy(policy, img):
  """Runs every operation of `policy` on the numpy image `img`, in order.

  Args:
    policy: A list of `(name, probability, level)` tuples. `name` is a key of
      `NAME_TO_TRANSFORM`, `probability` is the chance that the operation is
      applied and `level` is its strength.
    img: Numpy image that will have `policy` applied to it.

  Returns:
    The numpy image obtained after all operations have been given the chance
    to run.
  """
  current = pil_wrap(img)
  for operation in policy:
    assert len(operation) == 3
    name, probability, level = operation
    transform = NAME_TO_TRANSFORM[name]
    current = transform.pil_transformer(probability, level)(current)
  return pil_unwrap(current)
class TransformFunction(object):
  """Callable wrapper that gives a transform function a readable repr."""

  def __init__(self, func, name):
    # `f` is the wrapped callable, `name` the label shown by repr().
    self.name = name
    self.f = func

  def __repr__(self):
    return ''.join(('<', self.name, '>'))

  def __call__(self, pil_img):
    # Delegate straight to the wrapped function.
    wrapped = self.f
    return wrapped(pil_img)
class TransformT(object):
  """A named image transform of the form `xform_fn(pil_img, level)`."""

  def __init__(self, name, xform_fn):
    self.xform = xform_fn
    self.name = name

  def pil_transformer(self, probability, level):
    """Returns a callable applying the transform with chance `probability`."""

    def maybe_apply(im):
      # One draw per call decides whether the transform runs at all.
      if random.random() < probability:
        return self.xform(im, level)
      return im

    label = self.name + '({:.1f},{})'.format(probability, level)
    return TransformFunction(maybe_apply, label)

  def do_transform(self, image, level):
    """Applies the transform to the numpy `image` unconditionally."""
    # PARAMETER_MAX is passed as the probability; random.random() is always
    # below it, so the transform is always applied.
    always = self.pil_transformer(PARAMETER_MAX, level)
    return pil_unwrap(always(pil_wrap(image)))
################## Transform Functions ##################
def _identity_impl(pil_img, level):
  """Returns `pil_img` unchanged; `level` is ignored."""
  del level  # Unused.
  return pil_img


def _transpose_with(method):
  """Builds a transform that calls `pil_img.transpose(method)`."""
  def impl(pil_img, level):
    del level  # Unused.
    return pil_img.transpose(method)
  return impl


def _on_rgb_copy(image_op):
  """Builds a transform running `image_op` on an RGB copy of the image.

  The result is converted back to RGBA so it matches what `pil_wrap` produces.
  """
  def impl(pil_img, level):
    del level  # Unused.
    return image_op(pil_img.convert('RGB')).convert('RGBA')
  return impl


def _filter_with(image_filter):
  """Builds a transform that calls `pil_img.filter(image_filter)`."""
  def impl(pil_img, level):
    del level  # Unused.
    return pil_img.filter(image_filter)
  return impl


identity = TransformT('identity', _identity_impl)
flip_lr = TransformT('FlipLR', _transpose_with(Image.FLIP_LEFT_RIGHT))
flip_ud = TransformT('FlipUD', _transpose_with(Image.FLIP_TOP_BOTTOM))
auto_contrast = TransformT('AutoContrast', _on_rgb_copy(ImageOps.autocontrast))
equalize = TransformT('Equalize', _on_rgb_copy(ImageOps.equalize))
invert = TransformT('Invert', _on_rgb_copy(ImageOps.invert))
blur = TransformT('Blur', _filter_with(ImageFilter.BLUR))
smooth = TransformT('Smooth', _filter_with(ImageFilter.SMOOTH))
def _rotate_impl(pil_img, level):
  """Rotates `pil_img` by up to 30 degrees in a random direction.

  The magnitude is `int_parameter(level, 30)`; its sign is flipped when a
  uniform draw exceeds 0.5.
  """
  magnitude = int_parameter(level, 30)
  angle = -magnitude if random.random() > 0.5 else magnitude
  return pil_img.rotate(angle)


rotate = TransformT('Rotate', _rotate_impl)
def _posterize_impl(pil_img, level):
  """Applies PIL Posterize to `pil_img`.

  `ImageOps.posterize` is called on an RGB copy with
  `4 - int_parameter(level, 4)` as its bits argument; the result is converted
  back to RGBA.
  """
  reduction = int_parameter(level, 4)
  rgb_img = pil_img.convert('RGB')
  posterized = ImageOps.posterize(rgb_img, 4 - reduction)
  return posterized.convert('RGBA')


posterize = TransformT('Posterize', _posterize_impl)
def _shear_x_impl(pil_img, level):
  """Applies PIL ShearX to `pil_img`.

  The image is sheared along the horizontal axis with an affine transform
  whose shear coefficient is `float_parameter(level, 0.3)`, negated when a
  uniform draw exceeds 0.5.

  Args:
    pil_img: Image in PIL object.
    level: Strength of the operation specified as an Integer from
      [0, `PARAMETER_MAX`].

  Returns:
    A 32 x 32 PIL Image that has had ShearX applied to it.
  """
  magnitude = float_parameter(level, 0.3)
  shear = -magnitude if random.random() > 0.5 else magnitude
  coefficients = (1, shear, 0, 0, 1, 0)
  return pil_img.transform((32, 32), Image.AFFINE, coefficients)


shear_x = TransformT('ShearX', _shear_x_impl)
def _shear_y_impl(pil_img, level):
  """Applies PIL ShearY to `pil_img`.

  The image is sheared along the vertical axis with an affine transform whose
  shear coefficient is `float_parameter(level, 0.3)`, negated when a uniform
  draw exceeds 0.5.

  Args:
    pil_img: Image in PIL object.
    level: Strength of the operation specified as an Integer from
      [0, `PARAMETER_MAX`].

  Returns:
    A 32 x 32 PIL Image that has had ShearY applied to it.
  """
  magnitude = float_parameter(level, 0.3)
  shear = -magnitude if random.random() > 0.5 else magnitude
  coefficients = (1, 0, 0, shear, 1, 0)
  return pil_img.transform((32, 32), Image.AFFINE, coefficients)


shear_y = TransformT('ShearY', _shear_y_impl)
def _translate_x_impl(pil_img, level):
  """Applies PIL TranslateX to `pil_img`.

  The image is shifted horizontally by `int_parameter(level, 10)` pixels; the
  direction is flipped when a uniform draw exceeds 0.5.

  Args:
    pil_img: Image in PIL object.
    level: Strength of the operation specified as an Integer from
      [0, `PARAMETER_MAX`].

  Returns:
    A 32 x 32 PIL Image that has had TranslateX applied to it.
  """
  magnitude = int_parameter(level, 10)
  offset = -magnitude if random.random() > 0.5 else magnitude
  coefficients = (1, 0, offset, 0, 1, 0)
  return pil_img.transform((32, 32), Image.AFFINE, coefficients)


translate_x = TransformT('TranslateX', _translate_x_impl)
def _translate_y_impl(pil_img, level):
  """Applies PIL TranslateY to `pil_img`.

  The image is shifted vertically by `int_parameter(level, 10)` pixels; the
  direction is flipped when a uniform draw exceeds 0.5.

  Args:
    pil_img: Image in PIL object.
    level: Strength of the operation specified as an Integer from
      [0, `PARAMETER_MAX`].

  Returns:
    A 32 x 32 PIL Image that has had TranslateY applied to it.
  """
  magnitude = int_parameter(level, 10)
  offset = -magnitude if random.random() > 0.5 else magnitude
  coefficients = (1, 0, 0, 0, 1, offset)
  return pil_img.transform((32, 32), Image.AFFINE, coefficients)


translate_y = TransformT('TranslateY', _translate_y_impl)
def _crop_impl(pil_img, level, interpolation=Image.BILINEAR):
  """Crops `level` pixels off every side, then resizes back to full size.

  Args:
    pil_img: Image in PIL object.
    level: Number of pixels removed from each border; used as is, without
      rescaling by `PARAMETER_MAX`.
    interpolation: PIL resampling filter used for the resize.

  Returns:
    A PIL Image of size `IMAGE_SIZE` x `IMAGE_SIZE`.
  """
  far_edge = IMAGE_SIZE - level
  box = (level, level, far_edge, far_edge)
  return pil_img.crop(box).resize((IMAGE_SIZE, IMAGE_SIZE), interpolation)


crop_bilinear = TransformT('CropBilinear', _crop_impl)
def _solarize_impl(pil_img, level):
  """Applies PIL Solarize to `pil_img`.

  `ImageOps.solarize` is called on an RGB copy with a threshold of
  `256 - int_parameter(level, 256)`, so a higher `level` gives a lower
  threshold. The result is converted back to RGBA.

  Args:
    pil_img: Image in PIL object.
    level: Strength of the operation specified as an Integer from
      [0, `PARAMETER_MAX`].

  Returns:
    A PIL Image that has had Solarize applied to it.
  """
  threshold = 256 - int_parameter(level, 256)
  rgb_img = pil_img.convert('RGB')
  return ImageOps.solarize(rgb_img, threshold).convert('RGBA')


solarize = TransformT('Solarize', _solarize_impl)
def _cutout_pil_impl(pil_img, level):
  """Applies cutout to `pil_img` in place at the specified level.

  The patch location comes from `create_cutout_mask` with a nominal size of
  `int_parameter(level, 20)`. Every pixel in the patch is overwritten with
  (125, 122, 113, 0); the zero alpha is what `pil_unwrap` later looks for to
  zero the pixel out.

  Args:
    pil_img: RGBA image in PIL object; it is modified in place.
    level: Strength of the operation specified as an Integer from
      [0, `PARAMETER_MAX`].

  Returns:
    The same `pil_img` object.
  """
  size = int_parameter(level, 20)
  if size > 0:
    # Only the patch coordinates are needed; the mask itself is discarded.
    _, upper_coord, lower_coord = create_cutout_mask(32, 32, 3, size)
    pixel_map = pil_img.load()
    fill = (125, 122, 113, 0)
    # NOTE(review): the first index of each coordinate is used as the first
    # pixel-map index; on a square image with a uniformly sampled centre the
    # axis order should not matter - confirm.
    for first in range(upper_coord[0], lower_coord[0]):
      for second in range(upper_coord[1], lower_coord[1]):
        pixel_map[first, second] = fill
  return pil_img


cutout = TransformT('Cutout', _cutout_pil_impl)
def _enhancer_impl(enhancer):
  """Builds a transform around one of PIL's `ImageEnhance` classes.

  The enhancement factor is `float_parameter(level, 1.8) + 0.1`, i.e. between
  0.1 and 1.9 for levels in [0, `PARAMETER_MAX`].
  """

  def impl(pil_img, level):
    # The 0.1 offset keeps the factor away from 0, which destroys the image.
    factor = float_parameter(level, 1.8) + .1
    return enhancer(pil_img).enhance(factor)

  return impl


color = TransformT('Color', _enhancer_impl(ImageEnhance.Color))
contrast = TransformT('Contrast', _enhancer_impl(ImageEnhance.Contrast))
brightness = TransformT('Brightness', _enhancer_impl(ImageEnhance.Brightness))
sharpness = TransformT('Sharpness', _enhancer_impl(ImageEnhance.Sharpness))
# Every transform that a policy may refer to by name. `identity` is defined
# above but is not part of this list.
ALL_TRANSFORMS = [
    flip_lr, flip_ud,
    auto_contrast, equalize, invert,
    rotate, posterize, crop_bilinear, solarize,
    color, contrast, brightness, sharpness,
    shear_x, shear_y, translate_x, translate_y,
    cutout, blur, smooth,
]

# Lookup from a transform's `name` attribute to the transform itself.
NAME_TO_TRANSFORM = dict((xform.name, xform) for xform in ALL_TRANSFORMS)
TRANSFORM_NAMES = NAME_TO_TRANSFORM.keys()
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains convenience wrappers for typical Neural Network TensorFlow layers.
Ops that have different behavior during training or eval have an is_training
parameter.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
# Module-level alias for tf.contrib.framework.arg_scope.
arg_scope = tf.contrib.framework.arg_scope
def variable(name, shape, dtype, initializer, trainable):
  """Creates or fetches a variable through `tf.get_variable`.

  Args:
    name: Name of the variable.
    shape: Shape of the variable.
    dtype: Data type of the variable.
    initializer: Initializer used when the variable is created.
    trainable: Whether the variable is trainable.

  Returns:
    Whatever `tf.get_variable` returns for these specifications.
  """
  return tf.get_variable(
      name, shape=shape, dtype=dtype, initializer=initializer,
      trainable=trainable)
def global_avg_pool(x, scope=None):
  """Averages a 4D NHWC tensor over its height and width dimensions.

  Args:
    x: 4D tensor whose dimensions 1 and 2 are statically known.
    scope: Optional name scope.

  Returns:
    A 2D tensor: `x` average pooled with a kernel covering dimensions 1 and 2
    entirely, with those two dimensions squeezed away.
  """
  assert x.get_shape().ndims == 4
  with tf.name_scope(scope, 'global_avg_pool', [x]):
    height, width = int(x.shape[1]), int(x.shape[2])
    pooled = tf.nn.avg_pool(
        x,
        ksize=(1, height, width, 1),
        strides=(1, 1, 1, 1),
        padding='VALID',
        data_format='NHWC')
    return tf.squeeze(pooled, (1, 2))
def zero_pad(inputs, in_filter, out_filter):
  """Zero pads `inputs` to have `out_filter` number of filters.

  The padding is applied to the last dimension of a 4D tensor and split
  between both ends; when the difference is odd the extra channel goes at the
  end, so the result always has exactly `out_filter` channels.

  Args:
    inputs: 4D tensor whose last dimension holds `in_filter` channels.
    in_filter: Number of channels currently in `inputs`.
    out_filter: Number of channels wanted in the output.

  Returns:
    `inputs` with zeros added along its last dimension.
  """
  total_pad = out_filter - in_filter
  pad_before = total_pad // 2
  # Padding `total_pad // 2` on both sides would drop one channel whenever
  # the difference is odd.
  pad_after = total_pad - pad_before
  return tf.pad(inputs, [[0, 0], [0, 0], [0, 0], [pad_before, pad_after]])
@tf.contrib.framework.add_arg_scope
def batch_norm(inputs,
               decay=0.999,
               center=True,
               scale=False,
               epsilon=0.001,
               is_training=True,
               reuse=None,
               scope=None):
  """Thin wrapper around `tf.contrib.layers.batch_norm`.

  Only the arguments in the signature are configurable. The wrapper always
  uses no activation, default parameter initializers, the
  `tf.GraphKeys.UPDATE_OPS` collection for the moving average updates, the
  fused implementation, trainable variables, NHWC data and no zero debiasing
  of the moving mean.
  """
  return tf.contrib.layers.batch_norm(
      inputs,
      # Caller-configurable settings.
      decay=decay,
      center=center,
      scale=scale,
      epsilon=epsilon,
      is_training=is_training,
      reuse=reuse,
      scope=scope,
      # Settings fixed by this wrapper.
      activation_fn=None,
      param_initializers=None,
      updates_collections=tf.GraphKeys.UPDATE_OPS,
      trainable=True,
      fused=True,
      data_format='NHWC',
      zero_debias_moving_mean=False)
def stride_arr(stride_h, stride_w):
  """Returns the 4-element list `[1, stride_h, stride_w, 1]`."""
  spatial = [stride_h, stride_w]
  return [1] + spatial + [1]
@tf.contrib.framework.add_arg_scope
def conv2d(inputs,
           num_filters_out,
           kernel_size,
           stride=1,
           scope=None,
           reuse=None):
  """Adds a bias-free 2D convolution with 'SAME' padding.

  A variable called 'weights' is created for the square convolution kernel. It
  is initialized from a normal distribution with standard deviation
  sqrt(2 / (kernel_size * kernel_size * num_filters_out)).

  Args:
    inputs: a 4D tensor in NHWC format.
    num_filters_out: the number of output filters.
    kernel_size: an int specifying the kernel height and width size.
    stride: an int specifying the height and width stride.
    scope: Optional scope for variable_scope.
    reuse: whether or not the layer and its variables should be reused.

  Returns:
    a tensor that is the result of a convolution being applied to `inputs`.
  """
  with tf.variable_scope(scope, 'Conv', [inputs], reuse=reuse):
    num_filters_in = int(inputs.shape[3])
    # The initializer scale is based on the kernel area times output filters.
    fan = int(kernel_size * kernel_size * num_filters_out)
    kernel = variable(
        name='weights',
        shape=[kernel_size, kernel_size, num_filters_in, num_filters_out],
        dtype=tf.float32,
        initializer=tf.random_normal_initializer(stddev=np.sqrt(2.0 / fan)),
        trainable=True)
    return tf.nn.conv2d(
        inputs,
        kernel,
        stride_arr(stride, stride),
        padding='SAME',
        data_format='NHWC')
@tf.contrib.framework.add_arg_scope
def fc(inputs,
       num_units_out,
       scope=None,
       reuse=None):
  """Creates a fully connected layer (`inputs` * weights + biases).

  The 'weights' variable is initialized uniformly in
  [-1/sqrt(num_units_out), 1/sqrt(num_units_out)] and 'biases' starts at zero.

  Args:
    inputs: a tensor that the fully connected layer will be applied to. If it
      has more than two dimensions it is flattened to 2D, which requires the
      first dimension to be statically known.
    num_units_out: the number of output units in the layer.
    scope: Optional scope for variable_scope.
    reuse: whether or not the layer and its variables should be reused.

  Returns:
    a tensor that is the result of applying a linear matrix to `inputs`.
  """
  if len(inputs.shape) > 2:
    batch_dim = int(inputs.shape[0])
    inputs = tf.reshape(inputs, [batch_dim, -1])

  with tf.variable_scope(scope, 'FC', [inputs], reuse=reuse):
    num_units_in = inputs.shape[1]
    limit = 1.0 / num_units_out ** 0.5
    weights = variable(
        name='weights',
        shape=[num_units_in, num_units_out],
        dtype=tf.float32,
        initializer=tf.random_uniform_initializer(-limit, limit),
        trainable=True)
    biases = variable(
        name='biases',
        shape=[num_units_out,],
        dtype=tf.float32,
        initializer=tf.constant_initializer(0.0),
        trainable=True)
    return tf.nn.xw_plus_b(inputs, weights, biases)
@tf.contrib.framework.add_arg_scope
def avg_pool(inputs, kernel_size, stride=2, padding='VALID', scope=None):
  """Wrapper around `tf.nn.avg_pool` for square kernels and strides.

  Args:
    inputs: tensor in NHWC format.
    kernel_size: an int used for both the kernel height and width.
    stride: an int used for both the height and width stride.
    padding: padding scheme forwarded to `tf.nn.avg_pool`.
    scope: Optional name scope.

  Returns:
    The average pooled tensor.
  """
  with tf.name_scope(scope, 'AvgPool', [inputs]):
    return tf.nn.avg_pool(
        inputs,
        ksize=stride_arr(kernel_size, kernel_size),
        strides=stride_arr(stride, stride),
        padding=padding,
        data_format='NHWC')
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Data utils for CIFAR-10 and CIFAR-100."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import copy
import cPickle
import os
import augmentation_transforms
import numpy as np
import policies as found_policies
import tensorflow as tf
# pylint:disable=logging-format-interpolation
class DataSet(object):
  """Dataset object that produces augmented training and eval data.

  After construction the object exposes `train_images`/`train_labels`,
  `val_images`/`val_labels` and, when `hparams.eval_test` is set,
  `test_images`/`test_labels`. Images are normalized float arrays of shape
  (N, 32, 32, 3) and labels are one-hot encoded.
  """

  def __init__(self, hparams):
    """Loads, normalizes and splits the dataset named by `hparams.dataset`.

    Args:
      hparams: object providing `dataset` ('cifar10' or 'cifar100'),
        `data_path`, `eval_test`, `train_size`, `validation_size` and
        `batch_size`.

    Raises:
      NotImplementedError: if `hparams.dataset` is not a supported dataset.
    """
    self.hparams = hparams
    self.epochs = 0
    self.curr_train_index = 0
    self.good_policies = found_policies.good_policies()

    assert hparams.train_size + hparams.validation_size <= 50000

    # The training portion always consists of five batches of 10000 images;
    # the test split, when requested, adds one more batch at the end.
    num_data_batches_to_load = 5
    train_dataset_size = 10000 * num_data_batches_to_load
    total_batches_to_load = num_data_batches_to_load
    total_dataset_size = train_dataset_size
    if hparams.eval_test:
      total_batches_to_load += 1
      total_dataset_size += 10000

    if hparams.dataset == 'cifar10':
      tf.logging.info('Cifar10')
      all_data = np.empty((total_batches_to_load, 10000, 3072), dtype=np.uint8)
      datafiles = [
          'data_batch_1', 'data_batch_2', 'data_batch_3', 'data_batch_4',
          'data_batch_5']
      if hparams.eval_test:
        datafiles.append('test_batch')
      label_key = 'labels'
      num_classes = 10
    elif hparams.dataset == 'cifar100':
      all_data = np.empty((1, 50000, 3072), dtype=np.uint8)
      datafiles = ['train']
      if hparams.eval_test:
        datafiles.append('test')
      label_key = 'fine_labels'
      num_classes = 100
    else:
      raise NotImplementedError(
          'Unimplemented dataset: {}'.format(hparams.dataset))

    all_labels = []
    for file_num, f in enumerate(datafiles):
      d = unpickle(os.path.join(hparams.data_path, f))
      if f == 'test':
        # The CIFAR-100 test file is appended along the sample axis of the
        # single training slab. Slice assignment copies the data.
        test_data = np.empty((1, 10000, 3072), dtype=np.uint8)
        test_data[0] = d['data']
        all_data = np.concatenate([all_data, test_data], axis=1)
      else:
        all_data[file_num] = d['data']
      all_labels.extend(np.array(d[label_key]))

    # Flat rows of 3072 bytes -> (N, 3, 32, 32) -> channels-last (N, 32, 32, 3).
    all_data = all_data.reshape(total_dataset_size, 3072)
    all_data = all_data.reshape(-1, 3, 32, 32)
    all_data = all_data.transpose(0, 2, 3, 1).copy()
    all_data = all_data / 255.0
    mean = augmentation_transforms.MEANS
    std = augmentation_transforms.STDS
    tf.logging.info('mean:{} std: {}'.format(mean, std))

    all_data = (all_data - mean) / std
    all_labels = np.eye(num_classes)[np.array(all_labels, dtype=np.int32)]
    assert len(all_data) == len(all_labels)
    tf.logging.info(
        'In CIFAR10 loader, number of images: {}'.format(len(all_data)))

    # Break off test data
    if hparams.eval_test:
      self.test_images = all_data[train_dataset_size:]
      self.test_labels = all_labels[train_dataset_size:]

    # Shuffle the rest of the data. The global numpy RNG is seeded so the
    # train/validation split is the same on every run.
    all_data = all_data[:train_dataset_size]
    all_labels = all_labels[:train_dataset_size]
    np.random.seed(0)
    perm = np.arange(len(all_data))
    np.random.shuffle(perm)
    all_data = all_data[perm]
    all_labels = all_labels[perm]

    # Break into train and val
    train_size, val_size = hparams.train_size, hparams.validation_size
    self.train_images = all_data[:train_size]
    self.train_labels = all_labels[:train_size]
    self.val_images = all_data[train_size:train_size + val_size]
    self.val_labels = all_labels[train_size:train_size + val_size]
    self.num_train = self.train_images.shape[0]

  def next_batch(self):
    """Return the next minibatch of augmented data.

    Returns:
      A tuple `(images, labels)` where `images` is a float32 array holding the
      augmented images of the batch and `labels` the matching label rows.
    """
    batch_size = self.hparams.batch_size
    if self.curr_train_index + batch_size > self.num_train:
      # Not enough samples left: reshuffle and start a new epoch. reset()
      # zeroes the epoch counter, so the incremented value is restored after.
      epoch = self.epochs + 1
      self.reset()
      self.epochs = epoch

    start = self.curr_train_index
    stop = start + batch_size
    images = self.train_images[start:stop]
    labels = self.train_labels[start:stop]

    final_imgs = []
    for data in images:
      # A sub-policy is picked at random for every image.
      epoch_policy = self.good_policies[np.random.choice(
          len(self.good_policies))]
      final_img = augmentation_transforms.apply_policy(epoch_policy, data)
      final_img = augmentation_transforms.random_flip(
          augmentation_transforms.zero_pad_and_crop(final_img, 4))
      # Apply cutout
      final_img = augmentation_transforms.cutout_numpy(final_img)
      final_imgs.append(final_img)

    self.curr_train_index += batch_size
    return (np.array(final_imgs, np.float32), labels)

  def reset(self):
    """Reset training data and index into the training data."""
    self.epochs = 0
    # Shuffle the training data
    perm = np.arange(self.num_train)
    np.random.shuffle(perm)
    assert self.num_train == self.train_images.shape[
        0], 'Error incorrect shuffling mask'
    self.train_images = self.train_images[perm]
    self.train_labels = self.train_labels[perm]
    self.curr_train_index = 0
def unpickle(f):
  """Loads and returns the pickled object stored in file `f`.

  Args:
    f: Path of the pickle file; opened through `tf.gfile`.

  Returns:
    The unpickled object.
  """
  tf.logging.info('loading file: {}'.format(f))
  # NOTE: unpickling can execute arbitrary code, so only point this at
  # trusted dataset files.
  # Pickles are binary data, hence 'rb'. The context manager closes the file
  # even when loading raises.
  with tf.gfile.Open(f, 'rb') as fo:
    return cPickle.load(fo)
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Helper functions used for training AutoAugment models."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
def setup_loss(logits, labels):
  """Builds the softmax predictions and cross entropy loss for `logits`.

  Args:
    logits: Unnormalized class scores.
    labels: One-hot labels matching `logits`.

  Returns:
    A tuple `(predictions, cost)`: the softmax of `logits` and the result of
    `tf.losses.softmax_cross_entropy`.
  """
  predictions = tf.nn.softmax(logits)
  cost = tf.losses.softmax_cross_entropy(
      onehot_labels=labels, logits=logits)
  return predictions, cost
def decay_weights(cost, weight_decay_rate):
  """Adds an l2 weight decay penalty over all trainable variables to `cost`.

  Args:
    cost: Loss the penalty is added to.
    weight_decay_rate: Multiplier applied to the summed `tf.nn.l2_loss` of
      every variable in `tf.trainable_variables()`.

  Returns:
    `cost` plus the weighted penalty.
  """
  l2_terms = [tf.nn.l2_loss(var) for var in tf.trainable_variables()]
  penalty = tf.multiply(weight_decay_rate, tf.add_n(l2_terms))
  return cost + penalty
def eval_child_model(session, model, data_loader, mode):
  """Evaluates `model` on held out data depending on `mode`.

  `model.eval_op` is run once per batch and `model.accuracy` is read once at
  the end.

  Args:
    session: TensorFlow session the model will be run with.
    model: TensorFlow model that will be evaluated.
    data_loader: DataSet object that contains data that `model` will
      evaluate.
    mode: Either 'val' or 'test', selecting which split is evaluated.

  Returns:
    Accuracy of `model` when evaluated on the specified dataset.

  Raises:
    ValueError: if invalid dataset `mode` is specified.
  """
  if mode not in ('val', 'test'):
    raise ValueError('Not valid eval mode')
  if mode == 'val':
    images, labels = data_loader.val_images, data_loader.val_labels
  else:
    images, labels = data_loader.test_images, data_loader.test_labels

  assert len(images) == len(labels)
  tf.logging.info('model.batch_size is {}'.format(model.batch_size))
  # Only whole batches are evaluated, so the split must divide evenly.
  assert len(images) % model.batch_size == 0
  eval_batches = int(len(images) / model.batch_size)
  for batch_index in range(eval_batches):
    start = batch_index * model.batch_size
    stop = (batch_index + 1) * model.batch_size
    feed = {
        model.images: images[start:stop],
        model.labels: labels[start:stop],
    }
    session.run(model.eval_op, feed_dict=feed)
  return session.run(model.accuracy)
def cosine_lr(learning_rate, epoch, iteration, batches_per_epoch, total_epochs):
  """Cosine learning rate schedule.

  The position in training is `epoch * batches_per_epoch + iteration` batches
  out of `total_epochs * batches_per_epoch`; the rate follows half a cosine
  period from `learning_rate` at position 0 down to 0 at the final position.

  Args:
    learning_rate: Initial learning rate.
    epoch: Number of epochs already counted towards the position.
    iteration: Current batch in this epoch.
    batches_per_epoch: Batches per epoch.
    total_epochs: Total epochs you are training for.

  Returns:
    The learning rate to be used for this current batch.
  """
  t_total = total_epochs * batches_per_epoch
  t_cur = float(epoch * batches_per_epoch + iteration)
  angle = np.pi * t_cur / t_total
  cosine = np.cos(angle)
  return 0.5 * learning_rate * (1 + cosine)
def get_lr(curr_epoch, hparams, iteration=None):
  """Returns the cosine learning rate for the given epoch and iteration.

  Args:
    curr_epoch: Current epoch, forwarded to `cosine_lr`.
    hparams: object providing `lr`, `train_size`, `batch_size` and
      `num_epochs`.
    iteration: Current batch within the epoch. Must be provided even though
      it has a default.

  Returns:
    The learning rate computed by `cosine_lr`.
  """
  assert iteration is not None
  batches_per_epoch = int(hparams.train_size / hparams.batch_size)
  return cosine_lr(
      hparams.lr, curr_epoch, iteration, batches_per_epoch,
      hparams.num_epochs)
def run_epoch_training(session, model, data_loader, curr_epoch):
  """Runs one epoch of training for the model passed in.

  Before every step the learning rate for that step is computed with `get_lr`
  and loaded into `model.lr_rate_ph`.

  Args:
    session: TensorFlow session the model will be run with.
    model: TensorFlow model that will be trained.
    data_loader: DataSet object that provides the training batches through
      `next_batch()`.
    curr_epoch: How many of epochs of training have been done so far.

  Returns:
    The accuracy of 'model' on the training set
  """
  steps_per_epoch = int(model.hparams.train_size / model.hparams.batch_size)
  tf.logging.info('steps per epoch: {}'.format(steps_per_epoch))
  curr_step = session.run(model.global_step)
  # Training is expected to start on an epoch boundary.
  assert curr_step % steps_per_epoch == 0

  # Get the current learning rate for the model based on the current epoch
  curr_lr = get_lr(curr_epoch, model.hparams, iteration=0)
  tf.logging.info('lr of {} for epoch {}'.format(curr_lr, curr_epoch))

  # `range` instead of the Python 2 only `xrange`, so this also runs on
  # Python 3.
  for step in range(steps_per_epoch):
    curr_lr = get_lr(curr_epoch, model.hparams, iteration=(step + 1))
    # Update the lr rate variable to the current LR.
    model.lr_rate_ph.load(curr_lr, session=session)
    if step % 20 == 0:
      tf.logging.info('Training {}/{}'.format(step, steps_per_epoch))

    train_images, train_labels = data_loader.next_batch()
    # The fetched values are not needed here; in particular the global step
    # must not overwrite the loop counter `step`.
    session.run(
        [model.train_op, model.global_step, model.eval_op],
        feed_dict={
            model.images: train_images,
            model.labels: train_labels,
        })

  train_accuracy = session.run(model.accuracy)
  tf.logging.info('Train accuracy: {}'.format(train_accuracy))
  return train_accuracy
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
def good_policies():
  """Returns the AutoAugment policies found on Cifar.

  Returns:
    A flat list of 95 sub-policies. Every sub-policy is a list of two
    `(name, probability, level)` tuples. The sub-policies are ordered by
    experiment group (exp0, exp1, exp2) and by run within each group.
  """
  exp0_runs = (
      [  # exp0_0
          [('Invert', 0.1, 7), ('Contrast', 0.2, 6)],
          [('Rotate', 0.7, 2), ('TranslateX', 0.3, 9)],
          [('Sharpness', 0.8, 1), ('Sharpness', 0.9, 3)],
          [('ShearY', 0.5, 8), ('TranslateY', 0.7, 9)],
          [('AutoContrast', 0.5, 8), ('Equalize', 0.9, 2)],
      ],
      [  # exp0_1
          [('Solarize', 0.4, 5), ('AutoContrast', 0.9, 3)],
          [('TranslateY', 0.9, 9), ('TranslateY', 0.7, 9)],
          [('AutoContrast', 0.9, 2), ('Solarize', 0.8, 3)],
          [('Equalize', 0.8, 8), ('Invert', 0.1, 3)],
          [('TranslateY', 0.7, 9), ('AutoContrast', 0.9, 1)],
      ],
      [  # exp0_2
          [('Solarize', 0.4, 5), ('AutoContrast', 0.0, 2)],
          [('TranslateY', 0.7, 9), ('TranslateY', 0.7, 9)],
          [('AutoContrast', 0.9, 0), ('Solarize', 0.4, 3)],
          [('Equalize', 0.7, 5), ('Invert', 0.1, 3)],
          [('TranslateY', 0.7, 9), ('TranslateY', 0.7, 9)],
      ],
      [  # exp0_3
          [('Solarize', 0.4, 5), ('AutoContrast', 0.9, 1)],
          [('TranslateY', 0.8, 9), ('TranslateY', 0.9, 9)],
          [('AutoContrast', 0.8, 0), ('TranslateY', 0.7, 9)],
          [('TranslateY', 0.2, 7), ('Color', 0.9, 6)],
          [('Equalize', 0.7, 6), ('Color', 0.4, 9)],
      ],
  )
  exp1_runs = (
      [  # exp1_0
          [('ShearY', 0.2, 7), ('Posterize', 0.3, 7)],
          [('Color', 0.4, 3), ('Brightness', 0.6, 7)],
          [('Sharpness', 0.3, 9), ('Brightness', 0.7, 9)],
          [('Equalize', 0.6, 5), ('Equalize', 0.5, 1)],
          [('Contrast', 0.6, 7), ('Sharpness', 0.6, 5)],
      ],
      [  # exp1_1
          [('Brightness', 0.3, 7), ('AutoContrast', 0.5, 8)],
          [('AutoContrast', 0.9, 4), ('AutoContrast', 0.5, 6)],
          [('Solarize', 0.3, 5), ('Equalize', 0.6, 5)],
          [('TranslateY', 0.2, 4), ('Sharpness', 0.3, 3)],
          [('Brightness', 0.0, 8), ('Color', 0.8, 8)],
      ],
      [  # exp1_2
          [('Solarize', 0.2, 6), ('Color', 0.8, 6)],
          [('Solarize', 0.2, 6), ('AutoContrast', 0.8, 1)],
          [('Solarize', 0.4, 1), ('Equalize', 0.6, 5)],
          [('Brightness', 0.0, 0), ('Solarize', 0.5, 2)],
          [('AutoContrast', 0.9, 5), ('Brightness', 0.5, 3)],
      ],
      [  # exp1_3
          [('Contrast', 0.7, 5), ('Brightness', 0.0, 2)],
          [('Solarize', 0.2, 8), ('Solarize', 0.1, 5)],
          [('Contrast', 0.5, 1), ('TranslateY', 0.2, 9)],
          [('AutoContrast', 0.6, 5), ('TranslateY', 0.0, 9)],
          [('AutoContrast', 0.9, 4), ('Equalize', 0.8, 4)],
      ],
      [  # exp1_4
          [('Brightness', 0.0, 7), ('Equalize', 0.4, 7)],
          [('Solarize', 0.2, 5), ('Equalize', 0.7, 5)],
          [('Equalize', 0.6, 8), ('Color', 0.6, 2)],
          [('Color', 0.3, 7), ('Color', 0.2, 4)],
          [('AutoContrast', 0.5, 2), ('Solarize', 0.7, 2)],
      ],
      [  # exp1_5
          [('AutoContrast', 0.2, 0), ('Equalize', 0.1, 0)],
          [('ShearY', 0.6, 5), ('Equalize', 0.6, 5)],
          [('Brightness', 0.9, 3), ('AutoContrast', 0.4, 1)],
          [('Equalize', 0.8, 8), ('Equalize', 0.7, 7)],
          [('Equalize', 0.7, 7), ('Solarize', 0.5, 0)],
      ],
      [  # exp1_6
          [('Equalize', 0.8, 4), ('TranslateY', 0.8, 9)],
          [('TranslateY', 0.8, 9), ('TranslateY', 0.6, 9)],
          [('TranslateY', 0.9, 0), ('TranslateY', 0.5, 9)],
          [('AutoContrast', 0.5, 3), ('Solarize', 0.3, 4)],
          [('Solarize', 0.5, 3), ('Equalize', 0.4, 4)],
      ],
  )
  exp2_runs = (
      [  # exp2_0
          [('Color', 0.7, 7), ('TranslateX', 0.5, 8)],
          [('Equalize', 0.3, 7), ('AutoContrast', 0.4, 8)],
          [('TranslateY', 0.4, 3), ('Sharpness', 0.2, 6)],
          [('Brightness', 0.9, 6), ('Color', 0.2, 8)],
          [('Solarize', 0.5, 2), ('Invert', 0.0, 3)],
      ],
      [  # exp2_1
          [('AutoContrast', 0.1, 5), ('Brightness', 0.0, 0)],
          [('Cutout', 0.2, 4), ('Equalize', 0.1, 1)],
          [('Equalize', 0.7, 7), ('AutoContrast', 0.6, 4)],
          [('Color', 0.1, 8), ('ShearY', 0.2, 3)],
          [('ShearY', 0.4, 2), ('Rotate', 0.7, 0)],
      ],
      [  # exp2_2
          [('ShearY', 0.1, 3), ('AutoContrast', 0.9, 5)],
          [('TranslateY', 0.3, 6), ('Cutout', 0.3, 3)],
          [('Equalize', 0.5, 0), ('Solarize', 0.6, 6)],
          [('AutoContrast', 0.3, 5), ('Rotate', 0.2, 7)],
          [('Equalize', 0.8, 2), ('Invert', 0.4, 0)],
      ],
      [  # exp2_3
          [('Equalize', 0.9, 5), ('Color', 0.7, 0)],
          [('Equalize', 0.1, 1), ('ShearY', 0.1, 3)],
          [('AutoContrast', 0.7, 3), ('Equalize', 0.7, 0)],
          [('Brightness', 0.5, 1), ('Contrast', 0.1, 7)],
          [('Contrast', 0.1, 4), ('Solarize', 0.6, 5)],
      ],
      [  # exp2_4
          [('Solarize', 0.2, 3), ('ShearX', 0.0, 0)],
          [('TranslateX', 0.3, 0), ('TranslateX', 0.6, 0)],
          [('Equalize', 0.5, 9), ('TranslateY', 0.6, 7)],
          [('ShearX', 0.1, 0), ('Sharpness', 0.5, 1)],
          [('Equalize', 0.8, 6), ('Invert', 0.3, 6)],
      ],
      [  # exp2_5
          [('AutoContrast', 0.3, 9), ('Cutout', 0.5, 3)],
          [('ShearX', 0.4, 4), ('AutoContrast', 0.9, 2)],
          [('ShearX', 0.0, 3), ('Posterize', 0.0, 3)],
          [('Solarize', 0.4, 3), ('Color', 0.2, 4)],
          [('Equalize', 0.1, 4), ('Equalize', 0.7, 6)],
      ],
      [  # exp2_6
          [('Equalize', 0.3, 8), ('AutoContrast', 0.4, 3)],
          [('Solarize', 0.6, 4), ('AutoContrast', 0.7, 6)],
          [('AutoContrast', 0.2, 9), ('Brightness', 0.4, 8)],
          [('Equalize', 0.1, 0), ('Equalize', 0.0, 6)],
          [('Equalize', 0.8, 4), ('Equalize', 0.0, 4)],
      ],
      [  # exp2_7
          [('Equalize', 0.5, 5), ('AutoContrast', 0.1, 2)],
          [('Solarize', 0.5, 5), ('AutoContrast', 0.9, 5)],
          [('AutoContrast', 0.6, 1), ('AutoContrast', 0.7, 8)],
          [('Equalize', 0.2, 0), ('AutoContrast', 0.1, 2)],
          [('Equalize', 0.6, 9), ('Equalize', 0.4, 4)],
      ],
  )

  # Flatten group by group, run by run, keeping the original order.
  policies = []
  for runs in (exp0_runs, exp1_runs, exp2_runs):
    for run in runs:
      policies.extend(run)
  return policies
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Builds the Shake-Shake Model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import custom_ops as ops
import tensorflow as tf
def round_int(x):
    """Returns `x` rounded to the nearest integer, as an `int`.

    Computed as floor(x + 0.5), so exact halves round toward positive
    infinity (2.5 -> 3, -2.5 -> -2) instead of to the nearest even integer.
    """
    shifted = x + 0.5
    return int(math.floor(shifted))
def shortcut(x, output_filters, stride):
    """Builds the shortcut branch so that it can be added to a block's output.

    When `stride` is 2 the input is average pooled with the same stride, and
    when the input has fewer filters than `output_filters` the last axis is
    zero padded at its end to make up the difference.

    Args:
      x: Input tensor; `x.shape[3]` is read as its filter count (presumably a
        4-D channels-last tensor -- confirm against the caller).
      output_filters: Number of filters the returned tensor must have. Must not
        be smaller than the filter count of `x`.
      stride: Stride of the block this shortcut bypasses; only the value 2
        triggers pooling.

    Returns:
      `x`, optionally pooled and/or zero padded along axis 3.
    """
    in_filters = int(x.shape[3])
    if stride == 2:
        x = ops.avg_pool(x, 2, stride=stride, padding='SAME')

    # Filter counts already agree: nothing left to do.
    if in_filters == output_filters:
        return x

    missing_filters = output_filters - in_filters
    assert missing_filters > 0
    # Zero pad only the end of the last axis with the missing filters.
    paddings = [[0, 0], [0, 0], [0, 0], [0, missing_filters]]
    return tf.pad(x, paddings)
def calc_prob(curr_layer, total_layers, p_l):
    """Computes the depth-dependent gate probability for one layer.

    The result is `1 - (curr_layer / total_layers) * p_l`: it decreases
    linearly with depth and reaches `1 - p_l` at the last layer.
    `bottleneck_layer` uses this value as the probability that its Bernoulli
    gate equals 1, i.e. that the residual branch is left unperturbed.

    Args:
      curr_layer: Index of the current layer.
      total_layers: Total number of layers; must be non-zero.
      p_l: Amount by which the probability has dropped at the final layer.

    Returns:
      The probability as a Python float.
    """
    depth_fraction = float(curr_layer) / total_layers
    return 1 - depth_fraction * p_l
def bottleneck_layer(x, n, stride, prob, is_training, alpha, beta):
    """Bottleneck residual block with ShakeDrop-style random scaling.

    The residual branch is BN -> 1x1 conv -> BN -> ReLU -> 3x3 conv (strided)
    -> BN -> ReLU -> 1x1 conv -> BN. Its output is multiplied by a per-example
    scale and then added to the shortcut branch built from the block input.

    Args:
      x: Input tensor. Presumably 4-D with the batch first and channels last
        (`tf.shape(x)[0]` is used as the batch size and `shortcut` reads
        `x.shape[3]` as the filter count) -- confirm against the caller.
      n: Width passed to the first two convolutions; the last convolution and
        the shortcut use `n * 4`.
      stride: Stride of the 3x3 convolution, also forwarded to `shortcut`.
      prob: Probability that the per-example Bernoulli gate is 1. It is also
        formatted into the variable scope name of the block.
      is_training: Evaluated with a Python `if`; selects the random (training)
        or the expected-value (evaluation) scaling.
      alpha: Pair `(low, high)` with `low < high`; range of the uniform
        scale applied in the forward pass when the gate is 0.
      beta: Pair `(low, high)` with `low < high`; range of the uniform
        scale applied to the gradient when the gate is 0.

    Returns:
      The scaled residual branch plus the shortcut branch.
    """
    # Both sampling ranges must be non-empty.
    assert alpha[1] > alpha[0]
    assert beta[1] > beta[0]
    # NOTE(review): the scope name is derived from `prob`, so two blocks only
    # get distinct scopes if they are given distinct probabilities -- confirm
    # the caller guarantees this.
    with tf.variable_scope('bottleneck_{}'.format(prob)):
        # Keep the unmodified input for the shortcut branch.
        input_layer = x
        x = ops.batch_norm(x, scope='bn_1_pre')
        x = ops.conv2d(x, n, 1, scope='1x1_conv_contract')
        x = ops.batch_norm(x, scope='bn_1_post')
        x = tf.nn.relu(x)
        x = ops.conv2d(x, n, 3, stride=stride, scope='3x3')
        x = ops.batch_norm(x, scope='bn_2')
        x = tf.nn.relu(x)
        x = ops.conv2d(x, n * 4, 1, scope='1x1_conv_expand')
        x = ops.batch_norm(x, scope='bn_3')
        # Apply the ShakeDrop regularization to the residual branch.
        if is_training:
            # One gate / scale per example, broadcast over the other axes.
            batch_size = tf.shape(x)[0]
            bern_shape = [batch_size, 1, 1, 1]
            # floor(prob + U[0, 1)) is 1 with probability `prob`, else 0.
            random_tensor = prob
            random_tensor += tf.random_uniform(bern_shape, dtype=tf.float32)
            binary_tensor = tf.floor(random_tensor)
            alpha_values = tf.random_uniform(
                [batch_size, 1, 1, 1], minval=alpha[0], maxval=alpha[1],
                dtype=tf.float32)
            beta_values = tf.random_uniform(
                [batch_size, 1, 1, 1], minval=beta[0], maxval=beta[1],
                dtype=tf.float32)
            # b + v - b * v equals 1 where the gate b is 1 and v where it is 0.
            rand_forward = (
                binary_tensor + alpha_values - binary_tensor * alpha_values)
            rand_backward = (
                binary_tensor + beta_values - binary_tensor * beta_values)
            # The forward value is x * rand_forward, but the stop_gradient term
            # contributes no gradient, so backprop sees x * rand_backward.
            x = x * rand_backward + tf.stop_gradient(x * rand_forward -
                                                     x * rand_backward)
        else:
            # Mean of the uniform distribution alpha is drawn from.
            expected_alpha = (alpha[1] + alpha[0])/2
            # prob is the expectation of the bernoulli variable, so this is the
            # expected value of the training-time forward scale.
            x = (prob + expected_alpha - prob * expected_alpha) * x
        # Match the shortcut to the residual branch's resolution and width.
        res = shortcut(input_layer, n * 4, stride)
        return x + res
def build_shake_drop_model(images, num_classes, is_training):
    """Builds the PyramidNet Shake-Drop model.

    Build the PyramidNet Shake-Drop model from https://arxiv.org/abs/1802.02375.

    The network is an initial convolution followed by three stages of `n`
    bottleneck blocks each. The first block of the second and third stage uses
    stride 2; every other block uses stride 1. The block width grows by a
    constant amount per block, and each block gets its own gate probability
    from `calc_prob`.

    Args:
      images: Tensor of images that will be fed into the model.
      num_classes: Number of classes that the model needs to predict.
      is_training: Is the model training or not.

    Returns:
      The logits of the PyramidNet Shake-Drop model.
    """
    # ShakeDrop hyperparameters.
    p_l = 0.5
    alpha_shake = [-1, 1]
    beta_shake = [0, 1]

    # PyramidNet hyperparameters.
    alpha = 200
    depth = 272
    # Blocks per stage; this formula is specific to the bottleneck architecture.
    n = int((depth - 2) / 9)
    total_layers = n * 3
    # Running (fractional) block width and its per-block increment.
    channels = 16
    channel_step = alpha / (3 * n)

    # Stem.
    net = ops.conv2d(images, 16, 3, scope='init_conv')
    net = ops.batch_norm(net, scope='init_bn')

    layer_num = 1
    # Only the first block of a stage may downsample.
    for stage_stride in (1, 2, 2):
        for block_index in range(n):
            block_stride = stage_stride if block_index == 0 else 1
            channels += channel_step
            prob = calc_prob(layer_num, total_layers, p_l)
            net = bottleneck_layer(
                net, round_int(channels), block_stride, prob, is_training,
                alpha_shake, beta_shake)
            layer_num += 1

    assert layer_num - 1 == total_layers

    net = ops.batch_norm(net, scope='final_bn')
    net = tf.nn.relu(net)
    net = ops.global_avg_pool(net)
    # Fully connected
    return ops.fc(net, num_classes)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment