Commit dff0f0c1 authored by Alexander Gorban

Merge branch 'master' of github.com:tensorflow/models

parents da341f70 36203f09
adversarial_crypto/* @dave-andersen
adversarial_text/* @rsepassi
attention_ocr/* @alexgorban
audioset/* @plakal @dpwe
autoencoders/* @snurkabill
cognitive_mapping_and_planning/* @s-gupta
compression/* @nmjohn
differential_privacy/* @panyx0718
domain_adaptation/* @bousmalis @ddohan
im2txt/* @cshallue
inception/* @shlens @vincentvanhoucke
learning_to_remember_rare_events/* @lukaszkaiser @ofirnachum
lfads/* @jazcollins @susillo
lm_1b/* @oriolvinyals @panyx0718
namignizer/* @knathanieltucker
neural_gpu/* @lukaszkaiser
neural_programmer/* @arvind2505
next_frame_prediction/* @panyx0718
object_detection/* @jch1 @tombstone @derekjchow @jesu9 @dreamdragon
pcl_rl/* @ofirnachum
ptn/* @xcyan @arkanath @hellojas @honglaklee
real_nvp/* @laurent-dinh
rebar/* @gjtucker
resnet/* @panyx0718
skip_thoughts/* @cshallue
slim/* @sguada @nathansilberman
street/* @theraysmith
swivel/* @waterson
syntaxnet/* @calberti @andorardo
textsum/* @panyx0718 @peterjliu
transformer/* @daviddao
tutorials/embedding/* @zffchen78 @a-dai
tutorials/image/* @sherrym @shlens
tutorials/rnn/* @lukaszkaiser @ebrevdo
video_prediction/* @cbfinn
# Contributing guidelines

If you have created a model and would like to publish it here, please send us a
pull request. For those just getting started with pull requests, GitHub has a
[howto](https://help.github.com/articles/using-pull-requests/).

The code for any model in this repository is licensed under the Apache License
...
@@ -14,6 +14,7 @@ running TensorFlow 0.12 or earlier, please
- [adversarial_crypto](adversarial_crypto): protecting communications with adversarial neural cryptography.
- [adversarial_text](adversarial_text): semi-supervised sequence learning with adversarial training.
- [attention_ocr](attention_ocr): a model for real-world image text extraction.
- [audioset](audioset): Models and supporting code for use with [AudioSet](http://g.co/audioset).
- [autoencoder](autoencoder): various autoencoders.
- [cognitive_mapping_and_planning](cognitive_mapping_and_planning): implementation of a spatial memory based mapping and planning architecture for visual navigation.
- [compression](compression): compressing and decompressing images using a pre-trained Residual GRU network.
@@ -30,6 +31,7 @@ running TensorFlow 0.12 or earlier, please
- [next_frame_prediction](next_frame_prediction): probabilistic future frame synthesis via cross convolutional networks.
- [object_detection](object_detection): localizing and identifying multiple objects in a single image.
- [real_nvp](real_nvp): density estimation using real-valued non-volume preserving (real NVP) transformations.
- [rebar](rebar): low-variance, unbiased gradient estimates for discrete latent variable models.
- [resnet](resnet): deep and wide residual networks.
- [skip_thoughts](skip_thoughts): recurrent neural network sentence-to-vector encoder.
- [slim](slim): image classification models in TF-Slim.
...
@@ -28,7 +28,7 @@ Pull requests:
virtualenv --system-site-packages ~/.tensorflow
source ~/.tensorflow/bin/activate
pip install --upgrade pip
pip install --upgrade tensorflow-gpu
```
2. At least 158GB of free disk space to download the FSNS dataset:
@@ -65,7 +65,7 @@ To train a model using pre-trained Inception weights as initialization:
```
wget http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz
tar xf inception_v3_2016_08_28.tar.gz
python train.py --checkpoint_inception=./inception_v3.ckpt
```
To fine tune the Attention OCR model using a checkpoint:
...
# Models for AudioSet: A Large Scale Dataset of Audio Events
This repository provides models and supporting code associated with
[AudioSet](http://g.co/audioset), a dataset of over 2 million human-labeled
10-second YouTube video soundtracks, with labels taken from an ontology of
more than 600 audio event classes.
AudioSet was
[released](https://research.googleblog.com/2017/03/announcing-audioset-dataset-for-audio.html)
in March 2017 by Google's Sound Understanding team to provide a common
large-scale evaluation task for audio event detection as well as a starting
point for a comprehensive vocabulary of sound events.
For more details about AudioSet and the various models we have trained, please
visit the [AudioSet website](http://g.co/audioset) and read our papers:
* Gemmeke, J. et al.,
[AudioSet: An ontology and human-labelled dataset for audio events](https://research.google.com/pubs/pub45857.html),
ICASSP 2017
* Hershey, S. et al.,
[CNN Architectures for Large-Scale Audio Classification](https://research.google.com/pubs/pub45611.html),
ICASSP 2017
If you use the pre-trained VGGish model in your published research, we ask that
you cite [CNN Architectures for Large-Scale Audio Classification](https://research.google.com/pubs/pub45611.html).
If you use the AudioSet dataset or the released 128-D embeddings of AudioSet
segments, please cite
[AudioSet: An ontology and human-labelled dataset for audio events](https://research.google.com/pubs/pub45857.html).
## VGGish
The initial AudioSet release included 128-dimensional embeddings of each
AudioSet segment produced from a VGG-like audio classification model that was
trained on a large YouTube dataset (a preliminary version of what later became
[YouTube-8M](https://research.google.com/youtube8m)).
We provide a TensorFlow definition of this model, which we call __*VGGish*__, as
well as supporting code to extract input features for the model from audio
waveforms and to post-process the model embedding output into the same format as
the released embedding features.
### Installation
VGGish depends on the following Python packages:
* [`numpy`](http://www.numpy.org/)
* [`scipy`](http://www.scipy.org/)
* [`resampy`](http://resampy.readthedocs.io/en/latest/)
* [`tensorflow`](http://www.tensorflow.org/)
* [`six`](https://pythonhosted.org/six/)
These are all easily installable via, e.g., `pip install numpy` (as in the
example command sequence below).
Any reasonably recent version of these packages should work. TensorFlow should
be at least version 1.0. We have tested with Python 2.7.6 and 3.4.3 on an
Ubuntu-like system with NumPy v1.13.1, SciPy v0.19.1, resampy v0.1.5, TensorFlow
v1.2.1, and Six v1.10.0.
VGGish also requires downloading two data files:
* [VGGish model checkpoint](https://storage.googleapis.com/audioset/vggish_model.ckpt),
in TensorFlow checkpoint format.
* [Embedding PCA parameters](https://storage.googleapis.com/audioset/vggish_pca_params.npz),
in NumPy compressed archive format.
After downloading these files into the same directory as this README, the
installation can be tested by running `python vggish_smoke_test.py` which
runs a known signal through the model and checks the output.
Here's a sample installation and test session:
```shell
# You can optionally install and test VGGish within a Python virtualenv, which
# is useful for isolating changes from the rest of your system. For example, you
# may have an existing version of some packages that you do not want to upgrade,
# or you want to try Python 3 instead of Python 2. If you decide to use a
# virtualenv, you can create one by running
# $ virtualenv vggish # For Python 2
# or
# $ python3 -m venv vggish # For Python 3
# and then enter the virtual environment by running
# $ source vggish/bin/activate # Assuming you use bash
# Leave the virtual environment at the end of the session by running
# $ deactivate
# Upgrade pip first.
$ python -m pip install --upgrade pip
# Install dependencies. Resampy needs to be installed after NumPy and SciPy
# are already installed.
$ pip install numpy scipy
$ pip install resampy tensorflow six
# Clone TensorFlow models repo into a 'models' directory.
$ git clone https://github.com/tensorflow/models.git
$ cd models/audioset
# Download data files into same directory as code.
$ curl -O https://storage.googleapis.com/audioset/vggish_model.ckpt
$ curl -O https://storage.googleapis.com/audioset/vggish_pca_params.npz
# Installation ready, let's test it.
$ python vggish_smoke_test.py
# If we see "Looks Good To Me", then we're all set.
```
### Usage
VGGish can be used in two ways:
* *As a feature extractor*: VGGish converts audio input features into a
semantically meaningful, high-level 128-D embedding which can be fed as input
to a downstream classification model. The downstream model can be shallower
than usual because the VGGish embedding is more semantically compact than raw
audio features.
So, for example, you could train a classifier for 10 of the AudioSet classes
by using the released embeddings as features. Then, you could use that
trained classifier with any arbitrary audio input by running the audio through
the audio feature extractor and VGGish model provided here, passing the
resulting embedding features as input to your trained model.
`vggish_inference_demo.py` shows how to produce VGGish embeddings from
arbitrary audio.
* *As part of a larger model*: Here, we treat VGGish as a "warm start" for the
lower layers of a model that takes audio features as input and adds more
layers on top of the VGGish embedding. This can be used to fine-tune VGGish
(or parts thereof) if you have large datasets that might be very different
from the typical YouTube video clip. `vggish_train_demo.py` shows how to add
layers on top of VGGish and train the whole model.
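As a minimal sketch of the feature-extractor mode described above (the WAV path
is hypothetical, and we assume the checkpoint has been downloaded as shown in
the Installation section), you can compute embeddings and hand them to a
downstream model of your own:

```python
import tensorflow as tf
import vggish_input
import vggish_params
import vggish_slim

# Convert a WAV file into a batch of log mel spectrogram examples.
examples = vggish_input.wavfile_to_examples('my_audio.wav')  # hypothetical path

with tf.Graph().as_default(), tf.Session() as sess:
  # Define VGGish in inference mode and load the released checkpoint.
  vggish_slim.define_vggish_slim(training=False)
  vggish_slim.load_vggish_slim_checkpoint(sess, 'vggish_model.ckpt')
  features_tensor = sess.graph.get_tensor_by_name(
      vggish_params.INPUT_TENSOR_NAME)
  embedding_tensor = sess.graph.get_tensor_by_name(
      vggish_params.OUTPUT_TENSOR_NAME)
  # One 128-D embedding per 0.96 s example; feed these to your own classifier.
  [embeddings] = sess.run([embedding_tensor],
                          feed_dict={features_tensor: examples})
```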
### About the Model
The VGGish code layout is as follows:
* `vggish_slim.py`: Model definition in TensorFlow Slim notation.
* `vggish_params.py`: Hyperparameters.
* `vggish_input.py`: Converter from audio waveform into input examples.
* `mel_features.py`: Audio feature extraction helpers.
* `vggish_postprocess.py`: Embedding postprocessing.
* `vggish_inference_demo.py`: Demo of VGGish in inference mode.
* `vggish_train_demo.py`: Demo of VGGish in training mode.
* `vggish_smoke_test.py`: Simple test of a VGGish installation.
#### Architecture
See `vggish_slim.py` and `vggish_params.py`.
VGGish is a variant of the [VGG](https://arxiv.org/abs/1409.1556) model, in
particular Configuration A with 11 weight layers. Specifically, here are the
changes we made:
* The input size was changed to 96x64 for log mel spectrogram audio inputs.
* We drop the last group of convolutional and maxpool layers, so we now have
only four groups of convolution/maxpool layers instead of five.
* Instead of a 1000-wide fully connected layer at the end, we use a 128-wide
fully connected layer. This acts as a compact embedding layer.
The model definition provided here defines layers up to and including the
128-wide embedding layer.
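As a back-of-the-envelope check of those changes (a sketch, not code from this
package), the shapes flow through the truncated stack as follows:

```python
# Four 2x2/stride-2 max-pools with SAME padding halve each spatial dimension.
frames, bands = 96, 64  # input log mel spectrogram patch
for _ in range(4):
  frames, bands = (frames + 1) // 2, (bands + 1) // 2
assert (frames, bands) == (6, 4)
flattened = frames * bands * 512  # 12288 units entering the FC stack
# 12288 -> 4096 -> 4096 -> 128, ending at the 128-wide embedding layer.
```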
#### Input: Audio Features
See `vggish_input.py` and `mel_features.py`.
VGGish was trained with audio features computed as follows:
* All audio is resampled to 16 kHz mono.
* A spectrogram is computed using magnitudes of the Short-Time Fourier Transform
with a window size of 25 ms, a window hop of 10 ms, and a periodic Hann
window.
* A mel spectrogram is computed by mapping the spectrogram to 64 mel bins
covering the range 125-7500 Hz.
* A stabilized log mel spectrogram is computed by applying
log(mel-spectrum + 0.01) where the offset is used to avoid taking a logarithm
of zero.
* These features are then framed into non-overlapping examples of 0.96 seconds,
where each example covers 64 mel bands and 96 frames of 10 ms each.
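In concrete numbers (a sketch of the arithmetic implied by the parameters
above; the canonical constants live in `vggish_params.py`):

```python
import numpy as np

sample_rate = 16000
window_length = int(round(0.025 * sample_rate))         # 400 samples per window
hop_length = int(round(0.010 * sample_rate))            # 160 samples per hop
fft_length = 2 ** int(np.ceil(np.log2(window_length)))  # 512-point FFT
num_spectrogram_bins = fft_length // 2 + 1              # 257 bins before mel
frames_per_example = int(round(0.96 / 0.010))           # 96 frames per example
```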
We provide our own NumPy implementation that produces features that are very
similar to those produced by our internal production code. This results in
embedding outputs that closely match the embeddings that we have already
released. Note that these embeddings will *not* be bit-for-bit identical to the
released embeddings due to small differences between the feature computation
code paths, and even between two different installations of VGGish with
different underlying libraries and hardware. However, we expect that the
embeddings will be equivalent in the context of a downstream classification
task.
#### Output: Embeddings
See `vggish_postprocess.py`.
The released AudioSet embeddings were postprocessed before release by applying a
PCA transformation (which performs both PCA and whitening) as well as
quantization to 8 bits per embedding element. This was done to be compatible
with the [YouTube-8M](https://research.google.com/youtube8m) project which has
released visual and audio embeddings for millions of YouTube videos in the same
PCA/whitened/quantized format.
We provide a Python implementation of the postprocessing which can be applied to
batches of embeddings produced by VGGish. `vggish_inference_demo.py` shows how
the postprocessor can be run after inference.
If you don't need to use the released embeddings or YouTube-8M, then you could
skip postprocessing and use raw embeddings.
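For intuition, here is a minimal NumPy sketch of the quantization half of that
postprocessing, using the [-2.0, +2.0] clipping range from `vggish_params.py`
(the full PCA/whitening step is implemented in `vggish_postprocess.py`):

```python
import numpy as np

def quantize(embeddings, min_val=-2.0, max_val=2.0):
  """Clip to [min_val, max_val], then map linearly onto uint8 in [0, 255]."""
  clipped = np.clip(embeddings, min_val, max_val)
  return ((clipped - min_val) *
          (255.0 / (max_val - min_val))).astype(np.uint8)
```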
### Future Work
Below are some of the things we would like to add to this repository. We
welcome pull requests for these or other enhancements, but please consider
sending an email to the mailing list (see the Contact section) describing what
you plan to do before you invest a lot of time, to get feedback from us and the
rest of the community.
* An AudioSet classifier trained on top of the VGGish embeddings to predict all
the AudioSet labels. This can act as a baseline for audio research using
AudioSet.
* Feature extraction implemented within TensorFlow using the upcoming
[tf.contrib.signal](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/docs_src/api_guides/python/contrib.signal.md)
ops.
* A Keras version of the VGGish model definition and checkpoint.
* Jupyter notebook demonstrating audio feature extraction and model performance.
## Contact
For general questions about AudioSet and VGGish, please use the
[audioset-users@googlegroups.com](https://groups.google.com/forum/#!forum/audioset-users)
mailing list.
For technical problems with the released model and code, please open an issue on
the [tensorflow/models issue tracker](https://github.com/tensorflow/models/issues)
and __*assign to @plakal and @dpwe*__. Please note that because the issue tracker
is shared across all models released by Google, we won't be notified about an
issue unless you explicitly @-mention us (@plakal and @dpwe) or assign the issue
to us.
## Credits
Original authors and reviewers of the code in this package include (in
alphabetical order):
* DAn Ellis
* Shawn Hershey
* Aren Jansen
* Manoj Plakal
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Defines routines to compute mel spectrogram features from audio waveform."""
import numpy as np
def frame(data, window_length, hop_length):
"""Convert array into a sequence of successive possibly overlapping frames.
An n-dimensional array of shape (num_samples, ...) is converted into an
(n+1)-D array of shape (num_frames, window_length, ...), where each frame
starts hop_length points after the preceding one.
This is accomplished using stride_tricks, so the original data is not
copied. However, there is no zero-padding, so any incomplete frames at the
end are not included.
Args:
data: np.array of dimension N >= 1.
window_length: Number of samples in each frame.
hop_length: Advance (in samples) between each window.
Returns:
(N+1)-D np.array with as many rows as there are complete frames that can be
extracted.
"""
num_samples = data.shape[0]
num_frames = 1 + int(np.floor((num_samples - window_length) / hop_length))
shape = (num_frames, window_length) + data.shape[1:]
strides = (data.strides[0] * hop_length,) + data.strides
return np.lib.stride_tricks.as_strided(data, shape=shape, strides=strides)
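# Example usage (an illustrative sketch): framing 10 samples with
# window_length=4 and hop_length=2 yields 4 complete frames; the trailing
# incomplete frame is dropped.
#   frame(np.arange(10), 4, 2)
#   # -> [[0 1 2 3], [2 3 4 5], [4 5 6 7], [6 7 8 9]]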
def periodic_hann(window_length):
"""Calculate a "periodic" Hann window.
The classic Hann window is defined as a raised cosine that starts and
ends on zero, and where every value appears twice, except the middle
point for an odd-length window. Matlab calls this a "symmetric" window
and np.hanning() returns it. However, for Fourier analysis, this
actually represents just over one cycle of a period N-1 cosine, and
thus is not compactly expressed on a length-N Fourier basis. Instead,
it's better to use a raised cosine that ends just before the final
zero value - i.e. a complete cycle of a period-N cosine. Matlab
calls this a "periodic" window. This routine calculates it.
Args:
window_length: The number of points in the returned window.
Returns:
A 1D np.array containing the periodic hann window.
"""
return 0.5 - (0.5 * np.cos(2 * np.pi / window_length *
np.arange(window_length)))
def stft_magnitude(signal, fft_length,
hop_length=None,
window_length=None):
"""Calculate the short-time Fourier transform magnitude.
Args:
signal: 1D np.array of the input time-domain signal.
fft_length: Size of the FFT to apply.
hop_length: Advance (in samples) between each frame passed to FFT.
window_length: Length of each block of samples to pass to FFT.
Returns:
2D np.array where each row contains the magnitudes of the fft_length/2+1
unique values of the FFT for the corresponding frame of input samples.
"""
frames = frame(signal, window_length, hop_length)
# Apply frame window to each frame. We use a periodic Hann (cosine of period
# window_length) instead of the symmetric Hann of np.hanning (period
# window_length-1).
window = periodic_hann(window_length)
windowed_frames = frames * window
return np.abs(np.fft.rfft(windowed_frames, int(fft_length)))
# Mel spectrum constants and functions.
_MEL_BREAK_FREQUENCY_HERTZ = 700.0
_MEL_HIGH_FREQUENCY_Q = 1127.0
def hertz_to_mel(frequencies_hertz):
"""Convert frequencies to mel scale using HTK formula.
Args:
frequencies_hertz: Scalar or np.array of frequencies in hertz.
Returns:
Object of same size as frequencies_hertz containing corresponding values
on the mel scale.
"""
return _MEL_HIGH_FREQUENCY_Q * np.log(
1.0 + (frequencies_hertz / _MEL_BREAK_FREQUENCY_HERTZ))
def spectrogram_to_mel_matrix(num_mel_bins=20,
num_spectrogram_bins=129,
audio_sample_rate=8000,
lower_edge_hertz=125.0,
upper_edge_hertz=3800.0):
"""Return a matrix that can post-multiply spectrogram rows to make mel.
Returns a np.array matrix A that can be used to post-multiply a matrix S of
spectrogram values (STFT magnitudes) arranged as frames x bins to generate a
"mel spectrogram" M of frames x num_mel_bins. M = S A.
The classic HTK algorithm exploits the complementarity of adjacent mel bands
to multiply each FFT bin by only one mel weight, then add it, with positive
and negative signs, to the two adjacent mel bands to which that bin
contributes. Here, by expressing this operation as a matrix multiply, we go
from num_fft multiplies per frame (plus around 2*num_fft adds) to around
num_fft^2 multiplies and adds. However, because these are all presumably
accomplished in a single call to np.dot(), it's not clear which approach is
faster in Python. The matrix multiplication has the attraction of being more
general and flexible, and much easier to read.
Args:
num_mel_bins: How many bands in the resulting mel spectrum. This is
the number of columns in the output matrix.
num_spectrogram_bins: How many bins there are in the source spectrogram
data, which is understood to be fft_size/2 + 1, i.e. the spectrogram
only contains the nonredundant FFT bins.
audio_sample_rate: Samples per second of the audio at the input to the
spectrogram. We need this to figure out the actual frequencies for
each spectrogram bin, which dictates how they are mapped into mel.
lower_edge_hertz: Lower bound on the frequencies to be included in the mel
spectrum. This corresponds to the lower edge of the lowest triangular
band.
upper_edge_hertz: The desired top edge of the highest frequency band.
Returns:
An np.array with shape (num_spectrogram_bins, num_mel_bins).
Raises:
ValueError: if frequency edges are incorrectly ordered.
"""
nyquist_hertz = audio_sample_rate / 2.
if lower_edge_hertz >= upper_edge_hertz:
raise ValueError("lower_edge_hertz %.1f >= upper_edge_hertz %.1f" %
(lower_edge_hertz, upper_edge_hertz))
spectrogram_bins_hertz = np.linspace(0.0, nyquist_hertz, num_spectrogram_bins)
spectrogram_bins_mel = hertz_to_mel(spectrogram_bins_hertz)
# The i'th mel band (starting from i=1) has center frequency
# band_edges_mel[i], lower edge band_edges_mel[i-1], and higher edge
# band_edges_mel[i+1]. Thus, we need num_mel_bins + 2 values in
# the band_edges_mel arrays.
band_edges_mel = np.linspace(hertz_to_mel(lower_edge_hertz),
hertz_to_mel(upper_edge_hertz), num_mel_bins + 2)
# Matrix to post-multiply feature arrays whose rows are num_spectrogram_bins
# of spectrogram values.
mel_weights_matrix = np.empty((num_spectrogram_bins, num_mel_bins))
for i in range(num_mel_bins):
lower_edge_mel, center_mel, upper_edge_mel = band_edges_mel[i:i + 3]
# Calculate lower and upper slopes for every spectrogram bin.
# Line segments are linear in the *mel* domain, not hertz.
lower_slope = ((spectrogram_bins_mel - lower_edge_mel) /
(center_mel - lower_edge_mel))
upper_slope = ((upper_edge_mel - spectrogram_bins_mel) /
(upper_edge_mel - center_mel))
# .. then intersect them with each other and zero.
mel_weights_matrix[:, i] = np.maximum(0.0, np.minimum(lower_slope,
upper_slope))
# HTK excludes the spectrogram DC bin; make sure it always gets a zero
# coefficient.
mel_weights_matrix[0, :] = 0.0
return mel_weights_matrix
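# Example usage (an illustrative sketch): with the defaults above, the matrix
# maps 129 spectrogram bins to 20 mel bins, so a frame of STFT magnitudes
# post-multiplies it directly.
#   m = spectrogram_to_mel_matrix()            # shape (129, 20)
#   mel_frame = np.dot(spectrogram_frame, m)   # spectrogram_frame: shape (129,)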
def log_mel_spectrogram(data,
audio_sample_rate=8000,
log_offset=0.0,
window_length_secs=0.025,
hop_length_secs=0.010,
**kwargs):
"""Convert waveform to a log magnitude mel-frequency spectrogram.
Args:
data: 1D np.array of waveform data.
audio_sample_rate: The sampling rate of data.
log_offset: Add this to values when taking log to avoid -Infs.
window_length_secs: Duration of each window to analyze.
hop_length_secs: Advance between successive analysis windows.
**kwargs: Additional arguments to pass to spectrogram_to_mel_matrix.
Returns:
2D np.array of (num_frames, num_mel_bins) consisting of log mel filterbank
magnitudes for successive frames.
"""
window_length_samples = int(round(audio_sample_rate * window_length_secs))
hop_length_samples = int(round(audio_sample_rate * hop_length_secs))
fft_length = 2 ** int(np.ceil(np.log(window_length_samples) / np.log(2.0)))
spectrogram = stft_magnitude(
data,
fft_length=fft_length,
hop_length=hop_length_samples,
window_length=window_length_samples)
mel_spectrogram = np.dot(spectrogram, spectrogram_to_mel_matrix(
num_spectrogram_bins=spectrogram.shape[1],
audio_sample_rate=audio_sample_rate, **kwargs))
return np.log(mel_spectrogram + log_offset)
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""A simple demonstration of running VGGish in inference mode.
This is intended as a toy example that demonstrates how the various building
blocks (feature extraction, model definition and loading, postprocessing) work
together in an inference context.
A WAV file (assumed to contain signed 16-bit PCM samples) is read in, converted
into log mel spectrogram examples, fed into VGGish, the raw embedding output is
whitened and quantized, and the postprocessed embeddings are optionally written
in a SequenceExample to a TFRecord file (using the same format as the embedding
features released in AudioSet).
Usage:
# Run a WAV file through the model and print the embeddings. The model
# checkpoint is loaded from vggish_model.ckpt and the PCA parameters are
# loaded from vggish_pca_params.npz in the current directory.
$ python vggish_inference_demo.py --wav_file /path/to/a/wav/file
# Run a WAV file through the model and also write the embeddings to
# a TFRecord file. The model checkpoint and PCA parameters are explicitly
# passed in as well.
$ python vggish_inference_demo.py --wav_file /path/to/a/wav/file \
--tfrecord_file /path/to/tfrecord/file \
--checkpoint /path/to/model/checkpoint \
--pca_params /path/to/pca/params
# Run a built-in input (a sine wave) through the model and print the
# embeddings. Associated model files are read from the current directory.
$ python vggish_inference_demo.py
"""
from __future__ import print_function
import numpy as np
from scipy.io import wavfile
import six
import tensorflow as tf
import vggish_input
import vggish_params
import vggish_postprocess
import vggish_slim
flags = tf.app.flags
flags.DEFINE_string(
'wav_file', None,
'Path to a wav file. Should contain signed 16-bit PCM samples. '
'If none is provided, a synthetic sound is used.')
flags.DEFINE_string(
'checkpoint', 'vggish_model.ckpt',
'Path to the VGGish checkpoint file.')
flags.DEFINE_string(
'pca_params', 'vggish_pca_params.npz',
'Path to the VGGish PCA parameters file.')
flags.DEFINE_string(
'tfrecord_file', None,
'Path to a TFRecord file where embeddings will be written.')
FLAGS = flags.FLAGS
def main(_):
# In this simple example, we run the examples from a single audio file through
# the model. If none is provided, we generate a synthetic input.
if FLAGS.wav_file:
wav_file = FLAGS.wav_file
else:
# Write a WAV of a sine wave into an in-memory file object.
num_secs = 5
freq = 1000
sr = 44100
t = np.linspace(0, num_secs, int(num_secs * sr))
x = np.sin(2 * np.pi * freq * t)
# Convert to signed 16-bit samples.
samples = np.clip(x * 32768, -32768, 32767).astype(np.int16)
wav_file = six.BytesIO()
wavfile.write(wav_file, sr, samples)
wav_file.seek(0)
examples_batch = vggish_input.wavfile_to_examples(wav_file)
print(examples_batch)
# Prepare a postprocessor to munge the model embeddings.
pproc = vggish_postprocess.Postprocessor(FLAGS.pca_params)
# If needed, prepare a record writer to store the postprocessed embeddings.
writer = tf.python_io.TFRecordWriter(
FLAGS.tfrecord_file) if FLAGS.tfrecord_file else None
with tf.Graph().as_default(), tf.Session() as sess:
# Define the model in inference mode, load the checkpoint, and
# locate input and output tensors.
vggish_slim.define_vggish_slim(training=False)
vggish_slim.load_vggish_slim_checkpoint(sess, FLAGS.checkpoint)
features_tensor = sess.graph.get_tensor_by_name(
vggish_params.INPUT_TENSOR_NAME)
embedding_tensor = sess.graph.get_tensor_by_name(
vggish_params.OUTPUT_TENSOR_NAME)
# Run inference and postprocessing.
[embedding_batch] = sess.run([embedding_tensor],
feed_dict={features_tensor: examples_batch})
print(embedding_batch)
postprocessed_batch = pproc.postprocess(embedding_batch)
print(postprocessed_batch)
# Write the postprocessed embeddings as a SequenceExample, in a similar
# format as the features released in AudioSet. Each row of the batch of
# embeddings corresponds to roughly a second of audio (96 10ms frames), and
# the rows are written as a sequence of bytes-valued features, where each
# feature value contains the 128 bytes of the whitened quantized embedding.
seq_example = tf.train.SequenceExample(
feature_lists=tf.train.FeatureLists(
feature_list={
vggish_params.AUDIO_EMBEDDING_FEATURE_NAME:
tf.train.FeatureList(
feature=[
tf.train.Feature(
bytes_list=tf.train.BytesList(
value=[embedding.tobytes()]))
for embedding in postprocessed_batch
]
)
}
)
)
print(seq_example)
if writer:
writer.write(seq_example.SerializeToString())
if writer:
writer.close()
if __name__ == '__main__':
tf.app.run()
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Compute input examples for VGGish from audio waveform."""
import numpy as np
import resampy
from scipy.io import wavfile
import mel_features
import vggish_params
def waveform_to_examples(data, sample_rate):
"""Converts audio waveform into an array of examples for VGGish.
Args:
data: np.array of either one dimension (mono) or two dimensions
(multi-channel, with the outer dimension representing channels).
Each sample is generally expected to lie in the range [-1.0, +1.0],
although this is not required.
sample_rate: Sample rate of data.
Returns:
3-D np.array of shape [num_examples, num_frames, num_bands] which represents
a sequence of examples, each of which contains a patch of log mel
spectrogram, covering num_frames frames of audio and num_bands mel frequency
bands, where the frame length is vggish_params.STFT_HOP_LENGTH_SECONDS.
"""
# Convert to mono.
if len(data.shape) > 1:
data = np.mean(data, axis=1)
# Resample to the rate assumed by VGGish.
if sample_rate != vggish_params.SAMPLE_RATE:
data = resampy.resample(data, sample_rate, vggish_params.SAMPLE_RATE)
# Compute log mel spectrogram features.
log_mel = mel_features.log_mel_spectrogram(
data,
audio_sample_rate=vggish_params.SAMPLE_RATE,
log_offset=vggish_params.LOG_OFFSET,
window_length_secs=vggish_params.STFT_WINDOW_LENGTH_SECONDS,
hop_length_secs=vggish_params.STFT_HOP_LENGTH_SECONDS,
num_mel_bins=vggish_params.NUM_MEL_BINS,
lower_edge_hertz=vggish_params.MEL_MIN_HZ,
upper_edge_hertz=vggish_params.MEL_MAX_HZ)
# Frame features into examples.
features_sample_rate = 1.0 / vggish_params.STFT_HOP_LENGTH_SECONDS
example_window_length = int(round(
vggish_params.EXAMPLE_WINDOW_SECONDS * features_sample_rate))
example_hop_length = int(round(
vggish_params.EXAMPLE_HOP_SECONDS * features_sample_rate))
log_mel_examples = mel_features.frame(
log_mel,
window_length=example_window_length,
hop_length=example_hop_length)
return log_mel_examples
def wavfile_to_examples(wav_file):
"""Convenience wrapper around waveform_to_examples() for a common WAV format.
Args:
wav_file: String path to a file, or a file-like object. The file
is assumed to contain WAV audio data with signed 16-bit PCM samples.
Returns:
See waveform_to_examples.
"""
sr, wav_data = wavfile.read(wav_file)
assert wav_data.dtype == np.int16, 'Bad sample type: %r' % wav_data.dtype
samples = wav_data / 32768.0 # Convert to [-1.0, +1.0]
return waveform_to_examples(samples, sr)
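# Example usage (an illustrative sketch; the path is hypothetical):
#   examples = wavfile_to_examples('/path/to/16bit_pcm.wav')
#   examples.shape  # -> (num_examples, 96, 64) with the vggish_params defaults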
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Global parameters for the VGGish model.
See vggish_slim.py for more information.
"""
# Architectural constants.
NUM_FRAMES = 96 # Frames in input mel-spectrogram patch.
NUM_BANDS = 64 # Frequency bands in input mel-spectrogram patch.
EMBEDDING_SIZE = 128 # Size of embedding layer.
# Hyperparameters used in feature and example generation.
SAMPLE_RATE = 16000
STFT_WINDOW_LENGTH_SECONDS = 0.025
STFT_HOP_LENGTH_SECONDS = 0.010
NUM_MEL_BINS = NUM_BANDS
MEL_MIN_HZ = 125
MEL_MAX_HZ = 7500
LOG_OFFSET = 0.01 # Offset used for stabilized log of input mel-spectrogram.
EXAMPLE_WINDOW_SECONDS = 0.96 # Each example contains 96 10ms frames
EXAMPLE_HOP_SECONDS = 0.96 # with zero overlap.
# Parameters used for embedding postprocessing.
PCA_EIGEN_VECTORS_NAME = 'pca_eigen_vectors'
PCA_MEANS_NAME = 'pca_means'
QUANTIZE_MIN_VAL = -2.0
QUANTIZE_MAX_VAL = +2.0
# Hyperparameters used in training.
INIT_STDDEV = 0.01 # Standard deviation used to initialize weights.
LEARNING_RATE = 1e-4 # Learning rate for the Adam optimizer.
ADAM_EPSILON = 1e-8 # Epsilon for the Adam optimizer.
# Names of ops, tensors, and features.
INPUT_OP_NAME = 'vggish/input_features'
INPUT_TENSOR_NAME = INPUT_OP_NAME + ':0'
OUTPUT_OP_NAME = 'vggish/embedding'
OUTPUT_TENSOR_NAME = OUTPUT_OP_NAME + ':0'
AUDIO_EMBEDDING_FEATURE_NAME = 'audio_embedding'
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Post-process embeddings from VGGish."""
import numpy as np
import vggish_params
class Postprocessor(object):
"""Post-processes VGGish embeddings.
The initial release of AudioSet included 128-D VGGish embeddings for each
segment of AudioSet. These released embeddings were produced by applying
a PCA transformation (technically, a whitening transform is included as well)
and 8-bit quantization to the raw embedding output from VGGish, in order to
stay compatible with the YouTube-8M project which provides visual embeddings
in the same format for a large set of YouTube videos. This class implements
the same PCA (with whitening) and quantization transformations.
"""
def __init__(self, pca_params_npz_path):
"""Constructs a postprocessor.
Args:
pca_params_npz_path: Path to a NumPy-format .npz file that
contains the PCA parameters used in postprocessing.
"""
params = np.load(pca_params_npz_path)
self._pca_matrix = params[vggish_params.PCA_EIGEN_VECTORS_NAME]
# Load means into a column vector for easier broadcasting later.
self._pca_means = params[vggish_params.PCA_MEANS_NAME].reshape(-1, 1)
assert self._pca_matrix.shape == (
vggish_params.EMBEDDING_SIZE, vggish_params.EMBEDDING_SIZE), (
'Bad PCA matrix shape: %r' % (self._pca_matrix.shape,))
assert self._pca_means.shape == (vggish_params.EMBEDDING_SIZE, 1), (
'Bad PCA means shape: %r' % (self._pca_means.shape,))
def postprocess(self, embeddings_batch):
"""Applies postprocessing to a batch of embeddings.
Args:
embeddings_batch: An nparray of shape [batch_size, embedding_size]
containing output from the embedding layer of VGGish.
Returns:
An nparray of the same shape as the input but of type uint8,
containing the PCA-transformed and quantized version of the input.
"""
assert len(embeddings_batch.shape) == 2, (
'Expected 2-d batch, got %r' % (embeddings_batch.shape,))
assert embeddings_batch.shape[1] == vggish_params.EMBEDDING_SIZE, (
'Bad batch shape: %r' % (embeddings_batch.shape,))
# Apply PCA.
# - Embeddings come in as [batch_size, embedding_size].
# - Transpose to [embedding_size, batch_size].
# - Subtract pca_means column vector from each column.
# - Premultiply by PCA matrix of shape [output_dims, input_dims]
# where both are equal to embedding_size in our case.
# - Transpose result back to [batch_size, embedding_size].
pca_applied = np.dot(self._pca_matrix,
(embeddings_batch.T - self._pca_means)).T
# Quantize by:
# - clipping to [min, max] range
clipped_embeddings = np.clip(
pca_applied, vggish_params.QUANTIZE_MIN_VAL,
vggish_params.QUANTIZE_MAX_VAL)
# - convert to 8-bit in range [0.0, 255.0]
quantized_embeddings = (
(clipped_embeddings - vggish_params.QUANTIZE_MIN_VAL) *
(255.0 /
(vggish_params.QUANTIZE_MAX_VAL - vggish_params.QUANTIZE_MIN_VAL)))
# - cast 8-bit float to uint8
quantized_embeddings = quantized_embeddings.astype(np.uint8)
return quantized_embeddings
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Defines the 'VGGish' model used to generate AudioSet embedding features.
The public AudioSet release (https://research.google.com/audioset/download.html)
includes 128-D features extracted from the embedding layer of a VGG-like model
that was trained on a large Google-internal YouTube dataset. Here we provide
a TF-Slim definition of the same model, without any dependencies on libraries
internal to Google. We call it 'VGGish'.
Note that we only define the model up to the embedding layer, which is the
penultimate layer before the final classifier layer. We also provide various
hyperparameter values (in vggish_params.py) that were used to train this model
internally.
For comparison, here is TF-Slim's VGG definition:
https://github.com/tensorflow/models/blob/master/slim/nets/vgg.py
"""
import tensorflow as tf
import vggish_params as params
slim = tf.contrib.slim
def define_vggish_slim(training=False):
"""Defines the VGGish TensorFlow model.
All ops are created in the current default graph, under the scope 'vggish/'.
The input is a placeholder named 'vggish/input_features' of type float32 and
shape [batch_size, num_frames, num_bands] where batch_size is variable and
num_frames and num_bands are constants, and [num_frames, num_bands] represents
a log-mel-scale spectrogram patch covering num_bands frequency bands and
num_frames time frames (where each frame step is usually 10ms). This is
produced by computing the stabilized log(mel-spectrogram + params.LOG_OFFSET).
The output is an op named 'vggish/embedding' which produces the activations of
a 128-D embedding layer, which is usually the penultimate layer when used as
part of a full model with a final classifier layer.
Args:
training: If true, all parameters are marked trainable.
Returns:
The op 'vggish/embedding'.
"""
# Defaults:
# - All weights are initialized to N(0, INIT_STDDEV).
# - All biases are initialized to 0.
# - All activations are ReLU.
# - All convolutions are 3x3 with stride 1 and SAME padding.
# - All max-pools are 2x2 with stride 2 and SAME padding.
with slim.arg_scope([slim.conv2d, slim.fully_connected],
weights_initializer=tf.truncated_normal_initializer(
stddev=params.INIT_STDDEV),
biases_initializer=tf.zeros_initializer(),
activation_fn=tf.nn.relu,
trainable=training), \
slim.arg_scope([slim.conv2d],
kernel_size=[3, 3], stride=1, padding='SAME'), \
slim.arg_scope([slim.max_pool2d],
kernel_size=[2, 2], stride=2, padding='SAME'), \
tf.variable_scope('vggish'):
# Input: a batch of 2-D log-mel-spectrogram patches.
features = tf.placeholder(
tf.float32, shape=(None, params.NUM_FRAMES, params.NUM_BANDS),
name='input_features')
# Reshape to 4-D so that we can convolve a batch with conv2d().
net = tf.reshape(features, [-1, params.NUM_FRAMES, params.NUM_BANDS, 1])
# The VGG stack of alternating convolutions and max-pools.
net = slim.conv2d(net, 64, scope='conv1')
net = slim.max_pool2d(net, scope='pool1')
net = slim.conv2d(net, 128, scope='conv2')
net = slim.max_pool2d(net, scope='pool2')
net = slim.repeat(net, 2, slim.conv2d, 256, scope='conv3')
net = slim.max_pool2d(net, scope='pool3')
net = slim.repeat(net, 2, slim.conv2d, 512, scope='conv4')
net = slim.max_pool2d(net, scope='pool4')
# Flatten before entering fully-connected layers
net = slim.flatten(net)
net = slim.repeat(net, 2, slim.fully_connected, 4096, scope='fc1')
# The embedding layer.
net = slim.fully_connected(net, params.EMBEDDING_SIZE, scope='fc2')
return tf.identity(net, name='embedding')
def load_vggish_slim_checkpoint(session, checkpoint_path):
"""Loads a pre-trained VGGish-compatible checkpoint.
This function can be used as an initialization function (referred to as
init_fn in TensorFlow documentation) which is called in a Session after
initializing all variables. When used as an init_fn, this will load
a pre-trained checkpoint that is compatible with the VGGish model
definition. Only variables defined by VGGish will be loaded.
Args:
session: an active TensorFlow session.
checkpoint_path: path to a file containing a checkpoint that is
compatible with the VGGish model definition.
"""
# Get the list of names of all VGGish variables that exist in
# the checkpoint (i.e., all inference-mode VGGish variables).
with tf.Graph().as_default():
define_vggish_slim(training=False)
vggish_var_names = [v.name for v in tf.global_variables()]
# Get the list of all currently existing variables that match
# the list of variable names we just computed.
vggish_vars = [v for v in tf.global_variables() if v.name in vggish_var_names]
# Use a Saver to restore just the variables selected above.
saver = tf.train.Saver(vggish_vars, name='vggish_load_pretrained')
saver.restore(session, checkpoint_path)
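# Example usage (an illustrative sketch, mirroring vggish_smoke_test.py):
#   with tf.Graph().as_default(), tf.Session() as sess:
#     define_vggish_slim(training=False)
#     load_vggish_slim_checkpoint(sess, 'vggish_model.ckpt')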
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A smoke test for VGGish.
This is a simple smoke test of a local install of VGGish and its associated
downloaded files. We create a synthetic sound, extract log mel spectrogram
features, run them through VGGish, post-process the embedding outputs, and
check some simple statistics of the results, allowing for variations that
might occur due to platform/version differences in the libraries we use.
Usage:
- Download the VGGish checkpoint and PCA parameters into the same directory as
the VGGish source code. If you keep them elsewhere, update the checkpoint_path
and pca_params_path variables below.
- Run:
$ python vggish_smoke_test.py
"""
from __future__ import print_function
import numpy as np
import tensorflow as tf
import vggish_input
import vggish_params
import vggish_postprocess
import vggish_slim
print('\nTesting your install of VGGish\n')
# Paths to downloaded VGGish files.
checkpoint_path = 'vggish_model.ckpt'
pca_params_path = 'vggish_pca_params.npz'
# Relative tolerance of errors in mean and standard deviation of embeddings.
rel_error = 0.1 # Up to 10%
# Generate a 1 kHz sine wave at 44.1 kHz (we use a high sampling rate
# to test resampling to 16 kHz during feature extraction).
num_secs = 3
freq = 1000
sr = 44100
t = np.linspace(0, num_secs, int(num_secs * sr))
x = np.sin(2 * np.pi * freq * t)
# Produce a batch of log mel spectrogram examples.
input_batch = vggish_input.waveform_to_examples(x, sr)
print('Log Mel Spectrogram example: ', input_batch[0])
np.testing.assert_equal(
input_batch.shape,
[num_secs, vggish_params.NUM_FRAMES, vggish_params.NUM_BANDS])
# Define VGGish, load the checkpoint, and run the batch through the model to
# produce embeddings.
with tf.Graph().as_default(), tf.Session() as sess:
vggish_slim.define_vggish_slim()
vggish_slim.load_vggish_slim_checkpoint(sess, checkpoint_path)
features_tensor = sess.graph.get_tensor_by_name(
vggish_params.INPUT_TENSOR_NAME)
embedding_tensor = sess.graph.get_tensor_by_name(
vggish_params.OUTPUT_TENSOR_NAME)
[embedding_batch] = sess.run([embedding_tensor],
feed_dict={features_tensor: input_batch})
print('VGGish embedding: ', embedding_batch[0])
expected_embedding_mean = 0.131
expected_embedding_std = 0.238
np.testing.assert_allclose(
[np.mean(embedding_batch), np.std(embedding_batch)],
[expected_embedding_mean, expected_embedding_std],
rtol=rel_error)
# Postprocess the results to produce whitened quantized embeddings.
pproc = vggish_postprocess.Postprocessor(pca_params_path)
postprocessed_batch = pproc.postprocess(embedding_batch)
print('Postprocessed VGGish embedding: ', postprocessed_batch[0])
expected_postprocessed_mean = 123.0
expected_postprocessed_std = 75.0
np.testing.assert_allclose(
[np.mean(postprocessed_batch), np.std(postprocessed_batch)],
[expected_postprocessed_mean, expected_postprocessed_std],
rtol=rel_error)
print('\nLooks Good To Me!\n')
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""A simple demonstration of running VGGish in training mode.
This is intended as a toy example that demonstrates how to use the VGGish model
definition within a larger model that adds more layers on top, and then train
the larger model. If you let VGGish train as well, then this allows you to
fine-tune the VGGish model parameters for your application. If you don't let
VGGish train, then you use VGGish as a feature extractor for the layers above
it.
For this toy task, we are training a classifier to distinguish between three
classes: sine waves, constant signals, and white noise. We generate synthetic
waveforms from each of these classes, convert into shuffled batches of log mel
spectrogram examples with associated labels, and feed the batches into a model
that includes VGGish at the bottom and a couple of additional layers on top. We
also plumb in labels that are associated with the examples, which feed a label
loss used for training.
Usage:
# Run training for 100 steps using a model checkpoint in the default
# location (vggish_model.ckpt in the current directory). Allow VGGish
# to get fine-tuned.
$ python vggish_train_demo.py --num_batches 100
# Same as before but run for fewer steps and don't change VGGish parameters
# and use a checkpoint in a different location
$ python vggish_train_demo.py --num_batches 50 \
--train_vggish=False \
--checkpoint /path/to/model/checkpoint
"""
from __future__ import print_function
from random import shuffle
import numpy as np
import tensorflow as tf
import vggish_input
import vggish_params
import vggish_slim
flags = tf.app.flags
slim = tf.contrib.slim
flags.DEFINE_integer(
'num_batches', 30,
'Number of batches of examples to feed into the model. Each batch is of '
'variable size and contains shuffled examples of each class of audio.')
flags.DEFINE_boolean(
'train_vggish', True,
'If True, allow VGGish parameters to change during training, thus '
'fine-tuning VGGish. If False, VGGish parameters are fixed, thus using '
'VGGish as a fixed feature extractor.')
flags.DEFINE_string(
'checkpoint', 'vggish_model.ckpt',
'Path to the VGGish checkpoint file.')
FLAGS = flags.FLAGS
_NUM_CLASSES = 3
def _get_examples_batch():
"""Returns a shuffled batch of examples of all audio classes.
Note that this is just a toy function because this is a simple demo intended
to illustrate how the training code might work.
Returns:
a tuple (features, labels) where features is a NumPy array of shape
[batch_size, num_frames, num_bands] where the batch_size is variable and
each row is a log mel spectrogram patch of shape [num_frames, num_bands]
suitable for feeding VGGish, while labels is a NumPy array of shape
[batch_size, num_classes] where each row is a multi-hot label vector that
provides the labels for corresponding rows in features.
"""
# Make a waveform for each class.
num_seconds = 5
sr = 44100 # Sampling rate.
t = np.linspace(0, num_seconds, int(num_seconds * sr)) # Time axis.
# Random sine wave.
freq = np.random.uniform(100, 1000)
sine = np.sin(2 * np.pi * freq * t)
# Random constant signal.
magnitude = np.random.uniform(-1, 1)
const = magnitude * np.ones_like(t)
# White noise.
noise = np.random.normal(-1, 1, size=t.shape)
# Make examples of each signal and corresponding labels.
# Sine is class index 0, Const class index 1, Noise class index 2.
sine_examples = vggish_input.waveform_to_examples(sine, sr)
sine_labels = np.array([[1, 0, 0]] * sine_examples.shape[0])
const_examples = vggish_input.waveform_to_examples(const, sr)
const_labels = np.array([[0, 1, 0]] * const_examples.shape[0])
noise_examples = vggish_input.waveform_to_examples(noise, sr)
noise_labels = np.array([[0, 0, 1]] * noise_examples.shape[0])
# Shuffle (example, label) pairs across all classes.
all_examples = np.concatenate((sine_examples, const_examples, noise_examples))
all_labels = np.concatenate((sine_labels, const_labels, noise_labels))
labeled_examples = list(zip(all_examples, all_labels))
shuffle(labeled_examples)
# Separate and return the features and labels.
features = [example for (example, _) in labeled_examples]
labels = [label for (_, label) in labeled_examples]
return (features, labels)
def main(_):
with tf.Graph().as_default(), tf.Session() as sess:
# Define VGGish.
embeddings = vggish_slim.define_vggish_slim(FLAGS.train_vggish)
# Define a shallow classification model and associated training ops on top
# of VGGish.
with tf.variable_scope('mymodel'):
# Add a fully connected layer with 100 units.
num_units = 100
fc = slim.fully_connected(embeddings, num_units)
# Add a classifier layer at the end, consisting of parallel logistic
# classifiers, one per class. This allows for multi-class tasks.
logits = slim.fully_connected(
fc, _NUM_CLASSES, activation_fn=None, scope='logits')
tf.sigmoid(logits, name='prediction')
# Add training ops.
with tf.variable_scope('train'):
global_step = tf.Variable(
0, name='global_step', trainable=False,
collections=[tf.GraphKeys.GLOBAL_VARIABLES,
tf.GraphKeys.GLOBAL_STEP])
# Labels are assumed to be fed as a batch of multi-hot vectors, with
# a 1 in the position of each positive class label, and 0 elsewhere.
labels = tf.placeholder(
tf.float32, shape=(None, _NUM_CLASSES), name='labels')
# Cross-entropy label loss.
xent = tf.nn.sigmoid_cross_entropy_with_logits(
logits=logits, labels=labels, name='xent')
loss = tf.reduce_mean(xent, name='loss_op')
tf.summary.scalar('loss', loss)
# We use the same optimizer and hyperparameters as used to train VGGish.
optimizer = tf.train.AdamOptimizer(
learning_rate=vggish_params.LEARNING_RATE,
epsilon=vggish_params.ADAM_EPSILON)
optimizer.minimize(loss, global_step=global_step, name='train_op')
# Initialize all variables in the model, and then load the pre-trained
# VGGish checkpoint.
sess.run(tf.global_variables_initializer())
vggish_slim.load_vggish_slim_checkpoint(sess, FLAGS.checkpoint)
# Locate all the tensors and ops we need for the training loop.
features_tensor = sess.graph.get_tensor_by_name(
vggish_params.INPUT_TENSOR_NAME)
labels_tensor = sess.graph.get_tensor_by_name('mymodel/train/labels:0')
global_step_tensor = sess.graph.get_tensor_by_name(
'mymodel/train/global_step:0')
loss_tensor = sess.graph.get_tensor_by_name('mymodel/train/loss_op:0')
train_op = sess.graph.get_operation_by_name('mymodel/train/train_op')
# The training loop.
for _ in range(FLAGS.num_batches):
(features, labels) = _get_examples_batch()
[num_steps, loss, _] = sess.run(
[global_step_tensor, loss_tensor, train_op],
feed_dict={features_tensor: features, labels_tensor: labels})
print('Step %d: loss %g' % (num_steps, loss))
if __name__ == '__main__':
tf.app.run()
@@ -22,7 +22,7 @@ citing the following paper:
### Contents
1. [Requirements: software](#requirements-software)
2. [Requirements: data](#requirements-data)
3. [Test Pre-trained Models](#test-pre-trained-models)
4. [Train your Own Models](#train-your-own-models)
### Requirements: software
@@ -46,7 +46,8 @@ citing the following paper:
```
2. Install [Tensorflow](https://www.tensorflow.org/) inside this virtual
environment. You will need to use one of the latest nightly builds
(see instructions [here](https://github.com/tensorflow/tensorflow#installation)).
3. Swiftshader: We use
[Swiftshader](https://github.com/google/swiftshader.git), a CPU based
@@ -99,8 +100,7 @@ citing the following paper:
`data/README.md`
### Test Pre-trained Models
1. Download pre-trained models. See `output/README.md`.
2. Test models using `scripts/script_test_pretrained_models.sh`.
......
## Introduction
This is the code used for two domain adaptation papers.

The `domain_separation` directory contains code for the "Domain Separation
Networks" paper by Bousmalis K., Trigeorgis G., et al. which was presented at
NIPS 2016. The paper can be found here: https://arxiv.org/abs/1608.06019.

The `pixel_domain_adaptation` directory contains the code used for the
"Unsupervised Pixel-Level Domain Adaptation with Generative Adversarial
Networks" paper by Bousmalis K., et al. (presented at CVPR 2017). The paper can
be found here: https://arxiv.org/abs/1612.05424. PixelDA aims to perform domain
adaptation by transferring the visual style of the target domain (which has few
or no labels) to a source domain (which has many labels). This is accomplished
using a Generative Adversarial Network (GAN).

## Contact
The domain separation code was open-sourced by
[Konstantinos Bousmalis](https://github.com/bousmalis)
(konstantinos@google.com), while the pixel level domain adaptation code was
open-sourced by [David Dohan](https://github.com/dmrd) (ddohan@google.com).
## Installation
You will need to have the following installed on your machine before trying out the DSN code.
...@@ -16,26 +26,70 @@ You will need to have the following installed on your machine before trying out
* Bazel: https://bazel.build/

## Important Note
We are working to open source the pose estimation dataset. For now, the MNIST to
MNIST-M dataset is available. Check back here in a few weeks or wait for a
relevant announcement from [@bousmalis](https://twitter.com/bousmalis).
## Initial setup
In order to run the MNIST to MNIST-M experiments, you will need to set the
data directory:
```
$ export DSN_DATA_DIR=/your/dir
```

Add models and models/slim to your `$PYTHONPATH` (assumes $PWD is /models):
```
$ export PYTHONPATH=$PYTHONPATH:$PWD:$PWD/slim
```
## Getting the datasets
You can fetch the MNIST data by running
```
$ bazel run slim:download_and_convert_data -- --dataset_dir $DSN_DATA_DIR --dataset_name=mnist
```
The MNIST-M dataset is available online [here](http://bit.ly/2nrlUAJ). Once it is downloaded and extracted into your data directory, create TFRecord files by running:
```
$ bazel run domain_adaptation/datasets:download_and_convert_mnist_m -- --dataset_dir $DSN_DATA_DIR
```
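Before training, it can be worth sanity-checking the converted data. A minimal sketch, assuming the `mnist_m_<split>.tfrecord` filenames produced by the conversion script shown later in this commit:
```
import os

import tensorflow as tf

data_dir = os.environ['DSN_DATA_DIR']
for split in ('train', 'valid', 'test'):
  path = os.path.join(data_dir, 'mnist_m_%s.tfrecord' % split)
  num_records = sum(1 for _ in tf.python_io.tf_record_iterator(path))
  print('%s: %d records' % (split, num_records))
```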
## Running PixelDA from MNIST to MNIST-M
You can run PixelDA as follows (using TensorBoard to examine the results):
```
$ bazel run domain_adaptation/pixel_domain_adaptation:pixelda_train -- --dataset_dir $DSN_DATA_DIR --source_dataset mnist --target_dataset mnist_m
```
And evaluation as:
```
$ bazel run domain_adaptation/pixel_domain_adaptation:pixelda_eval -- --dataset_dir $DSN_DATA_DIR --source_dataset mnist --target_dataset mnist_m --target_split_name test
```
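While these jobs run, you can monitor losses and generated images with TensorBoard (`tensorboard --logdir=...`), pointing it at whichever log directory your train and eval flags write to.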
The MNIST-M results in the paper were run with the following hparams flag:
```
--hparams arch=resnet,domain_loss_weight=0.135603587834,num_training_examples=16000000,style_transfer_loss_weight=0.0113173311334,task_loss_in_g_weight=0.0100959947002,task_tower=mnist,task_tower_in_g_step=true
```
### A note on the terminology/language of the code
The components of the network can be grouped into two jointly optimized
parts: the generator component and the discriminator component.

The generator component takes either an image or a noise vector and produces
an output image.
The discriminator component takes the generated images and the target images
and attempts to discriminate between them.
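As a rough sketch of that split (illustrative only, not the actual PixelDA architecture, which lives in `pixelda_model.py` with its hyperparameters in `hparams.py`), the two jointly optimized components might look like:
```
import tensorflow as tf

slim = tf.contrib.slim


def generator(source_images, noise):
  """Maps a source image plus a noise vector to a target-styled image."""
  with tf.variable_scope('generator'):
    height, width, channels = source_images.shape.as_list()[1:]
    # Broadcast the noise vector to a per-pixel feature map and concatenate
    # it onto the image channels.
    noise_map = tf.tile(tf.expand_dims(tf.expand_dims(noise, 1), 1),
                        [1, height, width, 1])
    net = tf.concat([source_images, noise_map], axis=3)
    net = slim.conv2d(net, 64, [3, 3])
    # Produce an output image with the same shape as the input.
    return slim.conv2d(net, channels, [3, 3], activation_fn=tf.nn.tanh)


def discriminator(images, reuse=False):
  """Produces a real-vs-generated logit for a batch of images."""
  with tf.variable_scope('discriminator', reuse=reuse):
    net = slim.conv2d(images, 64, [3, 3], stride=2)
    net = slim.flatten(net)
    return slim.fully_connected(net, 1, activation_fn=None)
```
In a typical setup the discriminator is applied twice per step, once to target images and once to generator outputs (the second call with `reuse=True`), and the two losses are optimized in alternation.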
## Running DSN code for adapting MNIST to MNIST-M
Then you need to build the binaries with Bazel:
```
......
...@@ -26,10 +26,20 @@ py_library(
    ],
)
py_binary(
name = "download_and_convert_mnist_m",
srcs = ["download_and_convert_mnist_m.py"],
deps = [
"//slim:dataset_utils",
],
)
py_binary(
    name = "mnist_m",
    srcs = ["mnist_m.py"],
    deps = [
        "//slim:dataset_utils",
    ],
)
# Copyright 2017 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...@@ -11,13 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A factory-pattern class which returns image/label pairs."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# Dependency imports

import tensorflow as tf

from slim.datasets import mnist
......
# Copyright 2017 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""Downloads and converts MNIST-M data to TFRecords of TF-Example protos.
This module downloads the MNIST-M data, uncompresses it, reads the files
that make up the MNIST-M data and creates two TFRecord datasets: one for train
and one for test. Each TFRecord dataset is comprised of a set of TF-Example
protocol buffers, each of which contain a single image and label.
The script should take about a minute to run.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import random
import sys
# Dependency imports
import numpy as np
from six.moves import urllib
import tensorflow as tf
from slim.datasets import dataset_utils
tf.app.flags.DEFINE_string(
'dataset_dir', None,
'The directory where the output TFRecords and temporary files are saved.')
FLAGS = tf.app.flags.FLAGS
_IMAGE_SIZE = 32
_NUM_CHANNELS = 3
# The number of images in the training set.
_NUM_TRAIN_SAMPLES = 59001
# The number of images to be kept from the training set for the validation set.
_NUM_VALIDATION = 1000
# The number of images in the test set.
_NUM_TEST_SAMPLES = 9001
# Seed for repeatability.
_RANDOM_SEED = 0
# The names of the classes.
_CLASS_NAMES = [
'zero',
'one',
'two',
'three',
'four',
'five',
    'six',
'seven',
'eight',
'nine',
]
class ImageReader(object):
"""Helper class that provides TensorFlow image coding utilities."""
def __init__(self):
# Initializes function that decodes RGB PNG data.
self._decode_png_data = tf.placeholder(dtype=tf.string)
self._decode_png = tf.image.decode_png(self._decode_png_data, channels=3)
def read_image_dims(self, sess, image_data):
image = self.decode_png(sess, image_data)
return image.shape[0], image.shape[1]
def decode_png(self, sess, image_data):
image = sess.run(
self._decode_png, feed_dict={self._decode_png_data: image_data})
assert len(image.shape) == 3
assert image.shape[2] == 3
return image
def _convert_dataset(split_name, filenames, filename_to_class_id, dataset_dir):
"""Converts the given filenames to a TFRecord dataset.
Args:
    split_name: The name of the split, one of 'train', 'valid' or 'test'.
filenames: A list of absolute paths to png images.
filename_to_class_id: A dictionary from filenames (strings) to class ids
(integers).
dataset_dir: The directory where the converted datasets are stored.
"""
print('Converting the {} split.'.format(split_name))
# Train and validation splits are both in the train directory.
if split_name in ['train', 'valid']:
png_directory = os.path.join(dataset_dir, 'mnist_m', 'mnist_m_train')
elif split_name == 'test':
png_directory = os.path.join(dataset_dir, 'mnist_m', 'mnist_m_test')
with tf.Graph().as_default():
image_reader = ImageReader()
with tf.Session('') as sess:
output_filename = _get_output_filename(dataset_dir, split_name)
with tf.python_io.TFRecordWriter(output_filename) as tfrecord_writer:
for filename in filenames:
          # Read the encoded PNG bytes of the image file:
          image_data = tf.gfile.FastGFile(
              os.path.join(png_directory, filename), 'rb').read()
height, width = image_reader.read_image_dims(sess, image_data)
class_id = filename_to_class_id[filename]
example = dataset_utils.image_to_tfexample(image_data, 'png', height,
width, class_id)
tfrecord_writer.write(example.SerializeToString())
sys.stdout.write('\n')
sys.stdout.flush()
def _extract_labels(label_filename):
  """Extracts the labels into a dict of filenames to int labels.

  Each line of the labels file is expected to contain a filename and an
  integer label separated by whitespace, e.g. `00000000.png 5` (the example
  filename is illustrative).

  Args:
    label_filename: The filename of the MNIST-M labels.

  Returns:
    A dictionary of filenames to int labels.
  """
print('Extracting labels from: ', label_filename)
label_file = tf.gfile.FastGFile(label_filename, 'r').readlines()
label_lines = [line.rstrip('\n').split() for line in label_file]
labels = {}
for line in label_lines:
assert len(line) == 2
labels[line[0]] = int(line[1])
return labels
def _get_output_filename(dataset_dir, split_name):
"""Creates the output filename.
Args:
dataset_dir: The directory where the temporary files are stored.
split_name: The name of the train/test split.
Returns:
An absolute file path.
"""
return '%s/mnist_m_%s.tfrecord' % (dataset_dir, split_name)
def _get_filenames(dataset_dir):
  """Returns a list of image filenames found in the given directory.

  Args:
    dataset_dir: A directory containing a set of PNG-encoded MNIST-M images.

  Returns:
    A list of image filenames, relative to `dataset_dir`.
  """
photo_filenames = []
for filename in os.listdir(dataset_dir):
photo_filenames.append(filename)
return photo_filenames
def run(dataset_dir):
"""Runs the download and conversion operation.
Args:
dataset_dir: The dataset directory where the dataset is stored.
"""
if not tf.gfile.Exists(dataset_dir):
tf.gfile.MakeDirs(dataset_dir)
train_filename = _get_output_filename(dataset_dir, 'train')
testing_filename = _get_output_filename(dataset_dir, 'test')
if tf.gfile.Exists(train_filename) and tf.gfile.Exists(testing_filename):
print('Dataset files already exist. Exiting without re-creating them.')
return
# TODO(konstantinos): Add download and cleanup functionality
train_validation_filenames = _get_filenames(
os.path.join(dataset_dir, 'mnist_m', 'mnist_m_train'))
test_filenames = _get_filenames(
os.path.join(dataset_dir, 'mnist_m', 'mnist_m_test'))
# Divide into train and validation:
random.seed(_RANDOM_SEED)
random.shuffle(train_validation_filenames)
train_filenames = train_validation_filenames[_NUM_VALIDATION:]
validation_filenames = train_validation_filenames[:_NUM_VALIDATION]
train_validation_filenames_to_class_ids = _extract_labels(
os.path.join(dataset_dir, 'mnist_m', 'mnist_m_train_labels.txt'))
test_filenames_to_class_ids = _extract_labels(
os.path.join(dataset_dir, 'mnist_m', 'mnist_m_test_labels.txt'))
# Convert the train, validation, and test sets.
_convert_dataset('train', train_filenames,
train_validation_filenames_to_class_ids, dataset_dir)
_convert_dataset('valid', validation_filenames,
train_validation_filenames_to_class_ids, dataset_dir)
_convert_dataset('test', test_filenames, test_filenames_to_class_ids,
dataset_dir)
# Finally, write the labels file:
labels_to_class_names = dict(zip(range(len(_CLASS_NAMES)), _CLASS_NAMES))
dataset_utils.write_label_file(labels_to_class_names, dataset_dir)
print('\nFinished converting the MNIST-M dataset!')
def main(_):
assert FLAGS.dataset_dir
run(FLAGS.dataset_dir)
if __name__ == '__main__':
tf.app.run()
# Copyright 2017 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Provides data for the MNIST-M dataset.

The script used to create the dataset can be found at:
tensorflow_models/domain_adaptation/datasets/download_and_convert_mnist_m.py
"""
from __future__ import absolute_import
...@@ -20,6 +23,7 @@ from __future__ import division
from __future__ import print_function

import os

# Dependency imports

import tensorflow as tf

from slim.datasets import dataset_utils
......
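Downstream, the converted MNIST-M data is consumed through this factory. Here is a hedged sketch of loading a batch with slim's data provider; the exact `get_dataset(name, split_name, dataset_dir)` signature is an assumption modeled on slim's own dataset factory:
```
import tensorflow as tf

from domain_adaptation.datasets import dataset_factory

slim = tf.contrib.slim

# Assumed factory signature, mirroring slim/datasets/dataset_factory.py.
dataset = dataset_factory.get_dataset(
    'mnist_m', split_name='train', dataset_dir='/your/dir')
provider = slim.dataset_data_provider.DatasetDataProvider(dataset)
[image, label] = provider.get(['image', 'label'])
images, labels = tf.train.batch([image, label], batch_size=32)
```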
# Description:
# Contains code for domain-adaptation style transfer.
package(
default_visibility = [
":internal",
],
)
licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
package_group(
name = "internal",
packages = [
"//domain_adaptation/...",
],
)
py_library(
name = "pixelda_preprocess",
srcs = ["pixelda_preprocess.py"],
deps = [
],
)
py_test(
name = "pixelda_preprocess_test",
srcs = ["pixelda_preprocess_test.py"],
deps = [
":pixelda_preprocess",
],
)
py_library(
name = "pixelda_model",
srcs = [
"pixelda_model.py",
"pixelda_task_towers.py",
"hparams.py",
],
deps = [
],
)
py_library(
name = "pixelda_utils",
srcs = ["pixelda_utils.py"],
deps = [
],
)
py_library(
name = "pixelda_losses",
srcs = ["pixelda_losses.py"],
deps = [
],
)
py_binary(
name = "pixelda_train",
srcs = ["pixelda_train.py"],
deps = [
":pixelda_losses",
":pixelda_model",
":pixelda_preprocess",
":pixelda_utils",
"//domain_adaptation/datasets:dataset_factory",
],
)
py_binary(
name = "pixelda_eval",
srcs = ["pixelda_eval.py"],
deps = [
":pixelda_losses",
":pixelda_model",
":pixelda_preprocess",
":pixelda_utils",
"//domain_adaptation/datasets:dataset_factory",
],
)