"tests/vscode:/vscode.git/clone" did not exist on "3a0bbb3eddc23559c670693fc8f9f3f0032914c8"
Commit 36203f09 authored by Lukasz Kaiser's avatar Lukasz Kaiser Committed by GitHub
Browse files

Merge pull request #2165 from plakal/master

Added model and supporting code for use with AudioSet
parents a2931072 fc75956f
adversarial_crypto/* @dave-andersen adversarial_crypto/* @dave-andersen
adversarial_text/* @rsepassi adversarial_text/* @rsepassi
attention_ocr/* @alexgorban attention_ocr/* @alexgorban
audioset/* @plakal @dpwe
autoencoders/* @snurkabill autoencoders/* @snurkabill
cognitive_mapping_and_planning/* @s-gupta cognitive_mapping_and_planning/* @s-gupta
compression/* @nmjohn compression/* @nmjohn
......
...@@ -14,6 +14,7 @@ running TensorFlow 0.12 or earlier, please ...@@ -14,6 +14,7 @@ running TensorFlow 0.12 or earlier, please
- [adversarial_crypto](adversarial_crypto): protecting communications with adversarial neural cryptography. - [adversarial_crypto](adversarial_crypto): protecting communications with adversarial neural cryptography.
- [adversarial_text](adversarial_text): semi-supervised sequence learning with adversarial training. - [adversarial_text](adversarial_text): semi-supervised sequence learning with adversarial training.
- [attention_ocr](attention_ocr): a model for real-world image text extraction. - [attention_ocr](attention_ocr): a model for real-world image text extraction.
- [audioset](audioset): Models and supporting code for use with [AudioSet](http://g.co.audioset).
- [autoencoder](autoencoder): various autoencoders. - [autoencoder](autoencoder): various autoencoders.
- [cognitive_mapping_and_planning](cognitive_mapping_and_planning): implementation of a spatial memory based mapping and planning architecture for visual navigation. - [cognitive_mapping_and_planning](cognitive_mapping_and_planning): implementation of a spatial memory based mapping and planning architecture for visual navigation.
- [compression](compression): compressing and decompressing images using a pre-trained Residual GRU network. - [compression](compression): compressing and decompressing images using a pre-trained Residual GRU network.
......
# Models for AudioSet: A Large Scale Dataset of Audio Events
This repository provides models and supporting code associated with
[AudioSet](http://g.co/audioset), a dataset of over 2 million human-labeled
10-second YouTube video soundtracks, with labels taken from an ontology of
more than 600 audio event classes.
AudioSet was
[released](https://research.googleblog.com/2017/03/announcing-audioset-dataset-for-audio.html)
in March 2017 by Google's Sound Understanding team to provide a common
large-scale evaluation task for audio event detection as well as a starting
point for a comprehensive vocabulary of sound events.
For more details about AudioSet and the various models we have trained, please
visit the [AudioSet website](http://g.co/audioset) and read our papers:
* Gemmeke, J. et. al.,
[AudioSet: An ontology and human-labelled dataset for audio events](https://research.google.com/pubs/pub45857.html),
ICASSP 2017
* Hershey, S. et. al.,
[CNN Architectures for Large-Scale Audio Classification](https://research.google.com/pubs/pub45611.html),
ICASSP 2017
If you use the pre-trained VGGish model in your published research, we ask that
you cite [CNN Architectures for Large-Scale Audio Classification](https://research.google.com/pubs/pub45611.html).
If you use the AudioSet dataset or the released 128-D embeddings of AudioSet
segments, please cite
[AudioSet: An ontology and human-labelled dataset for audio events](https://research.google.com/pubs/pub45857.html).
## VGGish
The initial AudioSet release included 128-dimensional embeddings of each
AudioSet segment produced from a VGG-like audio classification model that was
trained on a large YouTube dataset (a preliminary version of what later became
[YouTube-8M](https://research.google.com/youtube8m)).
We provide a TensorFlow definition of this model, which we call __*VGGish*__, as
well as supporting code to extract input features for the model from audio
waveforms and to post-process the model embedding output into the same format as
the released embedding features.
### Installation
VGGish depends on the following Python packages:
* [`numpy`](http://www.numpy.org/)
* [`scipy`](http://www.scipy.org/)
* [`resampy`](http://resampy.readthedocs.io/en/latest/)
* [`tensorflow`](http://www.tensorflow.org/)
* [`six`](https://pythonhosted.org/six/)
These are all easily installable via, e.g., `pip install numpy` (as in the
example command sequence below).
Any reasonably recent version of these packages should work. TensorFlow should
be at least version 1.0. We have tested with Python 2.7.6 and 3.4.3 on an
Ubuntu-like system with NumPy v1.13.1, SciPy v0.19.1, resampy v0.1.5, TensorFlow
v1.2.1, and Six v1.10.0.
VGGish also requires downloading two data files:
* [VGGish model checkpoint](https://storage.googleapis.com/audioset/vggish_model.ckpt),
in TensorFlow checkpoint format.
* [Embedding PCA parameters](https://storage.googleapis.com/audioset/vggish_pca_params.npz),
in NumPy compressed archive format.
After downloading these files into the same directory as this README, the
installation can be tested by running `python vggish_smoke_test.py` which
runs a known signal through the model and checks the output.
Here's a sample installation and test session:
```shell
# You can optionally install and test VGGish within a Python virtualenv, which
# is useful for isolating changes from the rest of your system. For example, you
# may have an existing version of some packages that you do not want to upgrade,
# or you want to try Python 3 instead of Python 2. If you decide to use a
# virtualenv, you can create one by running
# $ virtualenv vggish # For Python 2
# or
# $ python3 -m venv vggish # For Python 3
# and then enter the virtual environment by running
# $ source vggish/bin/activate # Assuming you use bash
# Leave the virtual environment at the end of the session by running
# $ deactivate
# Upgrade pip first.
$ python -m pip install --upgrade pip
# Install dependences. Resampy needs to be installed after NumPy and SciPy
# are already installed.
$ pip install numpy scipy
$ pip install resampy tensorflow six
# Clone TensorFlow models repo into a 'models' directory.
$ git clone https://github.com/tensorflow/models.git
$ cd models/audioset
# Download data files into same directory as code.
$ curl -O https://storage.googleapis.com/audioset/vggish_model.ckpt
$ curl -O https://storage.googleapis.com/audioset/vggish_pca_params.npz
# Installation ready, let's test it.
$ python vggish_smoke_test.py
# If we see "Looks Good To Me", then we're all set.
```
### Usage
VGGish can be used in two ways:
* *As a feature extractor*: VGGish converts audio input features into a
semantically meaningful, high-level 128-D embedding which can be fed as input
to a downstream classification model. The downstream model can be shallower
than usual because the VGGish embedding is more semantically compact than raw
audio features.
So, for example, you could train a classifier for 10 of the AudioSet classes
by using the released embeddings as features. Then, you could use that
trained classifier with any arbitrary audio input by running the audio through
the audio feature extractor and VGGish model provided here, passing the
resulting embedding features as input to your trained model.
`vggish_inference_demo.py` shows how to produce VGGish embeddings from
arbitrary audio.
* *As part of a larger model*: Here, we treat VGGish as a "warm start" for the
lower layers of a model that takes audio features as input and adds more
layers on top of the VGGish embedding. This can be used to fine-tune VGGish
(or parts thereof) if you have large datasets that might be very different
from the typical YouTube video clip. `vggish_train_demo.py` shows how to add
layers on top of VGGish and train the whole model.
### About the Model
The VGGish code layout is as follows:
* `vggish_slim.py`: Model definition in TensorFlow Slim notation.
* `vggish_params.py`: Hyperparameters.
* `vggish_input.py`: Converter from audio waveform into input examples.
* `mel_features.py`: Audio feature extraction helpers.
* `vggish_postprocess.py`: Embedding postprocessing.
* `vggish_inference_demo.py`: Demo of VGGish in inference mode.
* `vggish_train_demo.py`: Demo of VGGish in training mode.
* `vggish_smoke_test.py`: Simple test of a VGGish installation
#### Architecture
See `vggish_slim.py` and `vggish_params.py`.
VGGish is a variant of the [VGG](https://arxiv.org/abs/1409.1556) model, in
particular Configuration A with 11 weight layers. Specifically, here are the
changes we made:
* The input size was changed to 96x64 for log mel spectrogram audio inputs.
* We drop the last group of convolutional and maxpool layers, so we now have
only four groups of convolution/maxpool layers instead of five.
* Instead of a 1000-wide fully connected layer at the end, we use a 128-wide
fully connected layer. This acts as a compact embedding layer.
The model definition provided here defines layers up to and including the
128-wide embedding layer.
#### Input: Audio Features
See `vggish_input.py` and `mel_features.py`.
VGGish was trained with audio features computed as follows:
* All audio is resampled to 16 kHz mono.
* A spectrogram is computed using magnitudes of the Short-Time Fourier Transform
with a window size of 25 ms, a window hop of 10 ms, and a periodic Hann
window.
* A mel spectrogram is computed by mapping the spectrogram to 64 mel bins
covering the range 125-7500 Hz.
* A stabilized log mel spectrogram is computed by applying
log(mel-spectrum + 0.01) where the offset is used to avoid taking a logarithm
of zero.
* These features are then framed into non-overlapping examples of 0.96 seconds,
where each example covers 64 mel bands and 96 frames of 10 ms each.
We provide our own NumPy implementation that produces features that are very
similar to those produced by our internal production code. This results in
embedding outputs that are closely match the embeddings that we have already
released. Note that these embeddings will *not* be bit-for-bit identical to the
released embeddings due to small differences between the feature computation
code paths, and even between two different installations of VGGish with
different underlying libraries and hardware. However, we expect that the
embeddings will be equivalent in the context of a downstream classification
task.
#### Output: Embeddings
See `vggish_postprocess.py`.
The released AudioSet embeddings were postprocessed before release by applying a
PCA transformation (which performs both PCA and whitening) as well as
quantization to 8 bits per embedding element. This was done to be compatible
with the [YouTube-8M](https://research.google.com/youtube8m) project which has
released visual and audio embeddings for millions of YouTube videos in the same
PCA/whitened/quantized format.
We provide a Python implementation of the postprocessing which can be applied to
batches of embeddings produced by VGGish. `vggish_inference_demo.py` shows how
the postprocessor can be run after inference.
If you don't need to use the released embeddings or YouTube-8M, then you could
skip postprocessing and use raw embeddings.
### Future Work
Below are some of the things we would like to add to this repository. We
welcome pull requests for these or other enhancements, but please consider
sending an email to the mailing list (see the Contact section) describing what
you plan to do before you invest a lot of time, to get feedback from us and the
rest of the community.
* An AudioSet classifier trained on top of the VGGish embeddings to predict all
the AudioSet labels. This can act as a baseline for audio research using
AudioSet.
* Feature extraction implemented within TensorFlow using the upcoming
[tf.contrib.signal](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/docs_src/api_guides/python/contrib.signal.md)
ops.
* A Keras version of the VGGish model definition and checkpoint.
* Jupyter notebook demonstrating audio feature extraction and model performance.
## Contact
For general questions about AudioSet and VGGish, please use the
[audioset-users@googlegroups.com](https://groups.google.com/forum/#!forum/audioset-users)
mailing list.
For technical problems with the released model and code, please open an issue on
the [tensorflow/models issue tracker](https://github.com/tensorflow/models/issues)
and __*assign to @plakal and @dpwe*__. Please note that because the issue tracker
is shared across all models released by Google, we won't be notified about an
issue unless you explicitly @-mention us (@plakal and @dpwe) or assign the issue
to us.
## Credits
Original authors and reviewers of the code in this package include (in
alphabetical order):
* DAn Ellis
* Shawn Hershey
* Aren Jansen
* Manoj Plakal
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Defines routines to compute mel spectrogram features from audio waveform."""
import numpy as np
def frame(data, window_length, hop_length):
"""Convert array into a sequence of successive possibly overlapping frames.
An n-dimensional array of shape (num_samples, ...) is converted into an
(n+1)-D array of shape (num_frames, window_length, ...), where each frame
starts hop_length points after the preceding one.
This is accomplished using stride_tricks, so the original data is not
copied. However, there is no zero-padding, so any incomplete frames at the
end are not included.
Args:
data: np.array of dimension N >= 1.
window_length: Number of samples in each frame.
hop_length: Advance (in samples) between each window.
Returns:
(N+1)-D np.array with as many rows as there are complete frames that can be
extracted.
"""
num_samples = data.shape[0]
num_frames = 1 + int(np.floor((num_samples - window_length) / hop_length))
shape = (num_frames, window_length) + data.shape[1:]
strides = (data.strides[0] * hop_length,) + data.strides
return np.lib.stride_tricks.as_strided(data, shape=shape, strides=strides)
def periodic_hann(window_length):
"""Calculate a "periodic" Hann window.
The classic Hann window is defined as a raised cosine that starts and
ends on zero, and where every value appears twice, except the middle
point for an odd-length window. Matlab calls this a "symmetric" window
and np.hanning() returns it. However, for Fourier analysis, this
actually represents just over one cycle of a period N-1 cosine, and
thus is not compactly expressed on a length-N Fourier basis. Instead,
it's better to use a raised cosine that ends just before the final
zero value - i.e. a complete cycle of a period-N cosine. Matlab
calls this a "periodic" window. This routine calculates it.
Args:
window_length: The number of points in the returned window.
Returns:
A 1D np.array containing the periodic hann window.
"""
return 0.5 - (0.5 * np.cos(2 * np.pi / window_length *
np.arange(window_length)))
def stft_magnitude(signal, fft_length,
hop_length=None,
window_length=None):
"""Calculate the short-time Fourier transform magnitude.
Args:
signal: 1D np.array of the input time-domain signal.
fft_length: Size of the FFT to apply.
hop_length: Advance (in samples) between each frame passed to FFT.
window_length: Length of each block of samples to pass to FFT.
Returns:
2D np.array where each row contains the magnitudes of the fft_length/2+1
unique values of the FFT for the corresponding frame of input samples.
"""
frames = frame(signal, window_length, hop_length)
# Apply frame window to each frame. We use a periodic Hann (cosine of period
# window_length) instead of the symmetric Hann of np.hanning (period
# window_length-1).
window = periodic_hann(window_length)
windowed_frames = frames * window
return np.abs(np.fft.rfft(windowed_frames, int(fft_length)))
# Mel spectrum constants and functions.
_MEL_BREAK_FREQUENCY_HERTZ = 700.0
_MEL_HIGH_FREQUENCY_Q = 1127.0
def hertz_to_mel(frequencies_hertz):
"""Convert frequencies to mel scale using HTK formula.
Args:
frequencies_hertz: Scalar or np.array of frequencies in hertz.
Returns:
Object of same size as frequencies_hertz containing corresponding values
on the mel scale.
"""
return _MEL_HIGH_FREQUENCY_Q * np.log(
1.0 + (frequencies_hertz / _MEL_BREAK_FREQUENCY_HERTZ))
def spectrogram_to_mel_matrix(num_mel_bins=20,
num_spectrogram_bins=129,
audio_sample_rate=8000,
lower_edge_hertz=125.0,
upper_edge_hertz=3800.0):
"""Return a matrix that can post-multiply spectrogram rows to make mel.
Returns a np.array matrix A that can be used to post-multiply a matrix S of
spectrogram values (STFT magnitudes) arranged as frames x bins to generate a
"mel spectrogram" M of frames x num_mel_bins. M = S A.
The classic HTK algorithm exploits the complementarity of adjacent mel bands
to multiply each FFT bin by only one mel weight, then add it, with positive
and negative signs, to the two adjacent mel bands to which that bin
contributes. Here, by expressing this operation as a matrix multiply, we go
from num_fft multiplies per frame (plus around 2*num_fft adds) to around
num_fft^2 multiplies and adds. However, because these are all presumably
accomplished in a single call to np.dot(), it's not clear which approach is
faster in Python. The matrix multiplication has the attraction of being more
general and flexible, and much easier to read.
Args:
num_mel_bins: How many bands in the resulting mel spectrum. This is
the number of columns in the output matrix.
num_spectrogram_bins: How many bins there are in the source spectrogram
data, which is understood to be fft_size/2 + 1, i.e. the spectrogram
only contains the nonredundant FFT bins.
audio_sample_rate: Samples per second of the audio at the input to the
spectrogram. We need this to figure out the actual frequencies for
each spectrogram bin, which dictates how they are mapped into mel.
lower_edge_hertz: Lower bound on the frequencies to be included in the mel
spectrum. This corresponds to the lower edge of the lowest triangular
band.
upper_edge_hertz: The desired top edge of the highest frequency band.
Returns:
An np.array with shape (num_spectrogram_bins, num_mel_bins).
Raises:
ValueError: if frequency edges are incorrectly ordered.
"""
nyquist_hertz = audio_sample_rate / 2.
if lower_edge_hertz >= upper_edge_hertz:
raise ValueError("lower_edge_hertz %.1f >= upper_edge_hertz %.1f" %
(lower_edge_hertz, upper_edge_hertz))
spectrogram_bins_hertz = np.linspace(0.0, nyquist_hertz, num_spectrogram_bins)
spectrogram_bins_mel = hertz_to_mel(spectrogram_bins_hertz)
# The i'th mel band (starting from i=1) has center frequency
# band_edges_mel[i], lower edge band_edges_mel[i-1], and higher edge
# band_edges_mel[i+1]. Thus, we need num_mel_bins + 2 values in
# the band_edges_mel arrays.
band_edges_mel = np.linspace(hertz_to_mel(lower_edge_hertz),
hertz_to_mel(upper_edge_hertz), num_mel_bins + 2)
# Matrix to post-multiply feature arrays whose rows are num_spectrogram_bins
# of spectrogram values.
mel_weights_matrix = np.empty((num_spectrogram_bins, num_mel_bins))
for i in range(num_mel_bins):
lower_edge_mel, center_mel, upper_edge_mel = band_edges_mel[i:i + 3]
# Calculate lower and upper slopes for every spectrogram bin.
# Line segments are linear in the *mel* domain, not hertz.
lower_slope = ((spectrogram_bins_mel - lower_edge_mel) /
(center_mel - lower_edge_mel))
upper_slope = ((upper_edge_mel - spectrogram_bins_mel) /
(upper_edge_mel - center_mel))
# .. then intersect them with each other and zero.
mel_weights_matrix[:, i] = np.maximum(0.0, np.minimum(lower_slope,
upper_slope))
# HTK excludes the spectrogram DC bin; make sure it always gets a zero
# coefficient.
mel_weights_matrix[0, :] = 0.0
return mel_weights_matrix
def log_mel_spectrogram(data,
audio_sample_rate=8000,
log_offset=0.0,
window_length_secs=0.025,
hop_length_secs=0.010,
**kwargs):
"""Convert waveform to a log magnitude mel-frequency spectrogram.
Args:
data: 1D np.array of waveform data.
audio_sample_rate: The sampling rate of data.
log_offset: Add this to values when taking log to avoid -Infs.
window_length_secs: Duration of each window to analyze.
hop_length_secs: Advance between successive analysis windows.
**kwargs: Additional arguments to pass to spectrogram_to_mel_matrix.
Returns:
2D np.array of (num_frames, num_mel_bins) consisting of log mel filterbank
magnitudes for successive frames.
"""
window_length_samples = int(round(audio_sample_rate * window_length_secs))
hop_length_samples = int(round(audio_sample_rate * hop_length_secs))
fft_length = 2 ** int(np.ceil(np.log(window_length_samples) / np.log(2.0)))
spectrogram = stft_magnitude(
data,
fft_length=fft_length,
hop_length=hop_length_samples,
window_length=window_length_samples)
mel_spectrogram = np.dot(spectrogram, spectrogram_to_mel_matrix(
num_spectrogram_bins=spectrogram.shape[1],
audio_sample_rate=audio_sample_rate, **kwargs))
return np.log(mel_spectrogram + log_offset)
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""A simple demonstration of running VGGish in inference mode.
This is intended as a toy example that demonstrates how the various building
blocks (feature extraction, model definition and loading, postprocessing) work
together in an inference context.
A WAV file (assumed to contain signed 16-bit PCM samples) is read in, converted
into log mel spectrogram examples, fed into VGGish, the raw embedding output is
whitened and quantized, and the postprocessed embeddings are optionally written
in a SequenceExample to a TFRecord file (using the same format as the embedding
features released in AudioSet).
Usage:
# Run a WAV file through the model and print the embeddings. The model
# checkpoint is loaded from vggish_model.ckpt and the PCA parameters are
# loaded from vggish_pca_params.npz in the current directory.
$ python vggish_inference_demo.py --wav_file /path/to/a/wav/file
# Run a WAV file through the model and also write the embeddings to
# a TFRecord file. The model checkpoint and PCA parameters are explicitly
# passed in as well.
$ python vggish_inference_demo.py --wav_file /path/to/a/wav/file \
--tfrecord_file /path/to/tfrecord/file \
--checkpoint /path/to/model/checkpoint \
--pca_params /path/to/pca/params
# Run a built-in input (a sine wav) through the model and print the
# embeddings. Associated model files are read from the current directory.
$ python vggish_inference_demo.py
"""
from __future__ import print_function
import numpy as np
from scipy.io import wavfile
import six
import tensorflow as tf
import vggish_input
import vggish_params
import vggish_postprocess
import vggish_slim
flags = tf.app.flags
flags.DEFINE_string(
'wav_file', None,
'Path to a wav file. Should contain signed 16-bit PCM samples. '
'If none is provided, a synthetic sound is used.')
flags.DEFINE_string(
'checkpoint', 'vggish_model.ckpt',
'Path to the VGGish checkpoint file.')
flags.DEFINE_string(
'pca_params', 'vggish_pca_params.npz',
'Path to the VGGish PCA parameters file.')
flags.DEFINE_string(
'tfrecord_file', None,
'Path to a TFRecord file where embeddings will be written.')
FLAGS = flags.FLAGS
def main(_):
# In this simple example, we run the examples from a single audio file through
# the model. If none is provided, we generate a synthetic input.
if FLAGS.wav_file:
wav_file = FLAGS.wav_file
else:
# Write a WAV of a sine wav into an in-memory file object.
num_secs = 5
freq = 1000
sr = 44100
t = np.linspace(0, num_secs, int(num_secs * sr))
x = np.sin(2 * np.pi * freq * t)
# Convert to signed 16-bit samples.
samples = np.clip(x * 32768, -32768, 32767).astype(np.int16)
wav_file = six.BytesIO()
wavfile.write(wav_file, sr, samples)
wav_file.seek(0)
examples_batch = vggish_input.wavfile_to_examples(wav_file)
print(examples_batch)
# Prepare a postprocessor to munge the model embeddings.
pproc = vggish_postprocess.Postprocessor(FLAGS.pca_params)
# If needed, prepare a record writer to store the postprocessed embeddings.
writer = tf.python_io.TFRecordWriter(
FLAGS.tfrecord_file) if FLAGS.tfrecord_file else None
with tf.Graph().as_default(), tf.Session() as sess:
# Define the model in inference mode, load the checkpoint, and
# locate input and output tensors.
vggish_slim.define_vggish_slim(training=False)
vggish_slim.load_vggish_slim_checkpoint(sess, FLAGS.checkpoint)
features_tensor = sess.graph.get_tensor_by_name(
vggish_params.INPUT_TENSOR_NAME)
embedding_tensor = sess.graph.get_tensor_by_name(
vggish_params.OUTPUT_TENSOR_NAME)
# Run inference and postprocessing.
[embedding_batch] = sess.run([embedding_tensor],
feed_dict={features_tensor: examples_batch})
print(embedding_batch)
postprocessed_batch = pproc.postprocess(embedding_batch)
print(postprocessed_batch)
# Write the postprocessed embeddings as a SequenceExample, in a similar
# format as the features released in AudioSet. Each row of the batch of
# embeddings corresponds to roughly a second of audio (96 10ms frames), and
# the rows are written as a sequence of bytes-valued features, where each
# feature value contains the 128 bytes of the whitened quantized embedding.
seq_example = tf.train.SequenceExample(
feature_lists=tf.train.FeatureLists(
feature_list={
vggish_params.AUDIO_EMBEDDING_FEATURE_NAME:
tf.train.FeatureList(
feature=[
tf.train.Feature(
bytes_list=tf.train.BytesList(
value=[embedding.tobytes()]))
for embedding in postprocessed_batch
]
)
}
)
)
print(seq_example)
if writer:
writer.write(seq_example.SerializeToString())
if writer:
writer.close()
if __name__ == '__main__':
tf.app.run()
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Compute input examples for VGGish from audio waveform."""
import numpy as np
import resampy
from scipy.io import wavfile
import mel_features
import vggish_params
def waveform_to_examples(data, sample_rate):
"""Converts audio waveform into an array of examples for VGGish.
Args:
data: np.array of either one dimension (mono) or two dimensions
(multi-channel, with the outer dimension representing channels).
Each sample is generally expected to lie in the range [-1.0, +1.0],
although this is not required.
sample_rate: Sample rate of data.
Returns:
3-D np.array of shape [num_examples, num_frames, num_bands] which represents
a sequence of examples, each of which contains a patch of log mel
spectrogram, covering num_frames frames of audio and num_bands mel frequency
bands, where the frame length is vggish_params.STFT_HOP_LENGTH_SECONDS.
"""
# Convert to mono.
if len(data.shape) > 1:
data = np.mean(data, axis=1)
# Resample to the rate assumed by VGGish.
if sample_rate != vggish_params.SAMPLE_RATE:
data = resampy.resample(data, sample_rate, vggish_params.SAMPLE_RATE)
# Compute log mel spectrogram features.
log_mel = mel_features.log_mel_spectrogram(
data,
audio_sample_rate=vggish_params.SAMPLE_RATE,
log_offset=vggish_params.LOG_OFFSET,
window_length_secs=vggish_params.STFT_WINDOW_LENGTH_SECONDS,
hop_length_secs=vggish_params.STFT_HOP_LENGTH_SECONDS,
num_mel_bins=vggish_params.NUM_MEL_BINS,
lower_edge_hertz=vggish_params.MEL_MIN_HZ,
upper_edge_hertz=vggish_params.MEL_MAX_HZ)
# Frame features into examples.
features_sample_rate = 1.0 / vggish_params.STFT_HOP_LENGTH_SECONDS
example_window_length = int(round(
vggish_params.EXAMPLE_WINDOW_SECONDS * features_sample_rate))
example_hop_length = int(round(
vggish_params.EXAMPLE_HOP_SECONDS * features_sample_rate))
log_mel_examples = mel_features.frame(
log_mel,
window_length=example_window_length,
hop_length=example_hop_length)
return log_mel_examples
def wavfile_to_examples(wav_file):
"""Convenience wrapper around waveform_to_examples() for a common WAV format.
Args:
wav_file: String path to a file, or a file-like object. The file
is assumed to contain WAV audio data with signed 16-bit PCM samples.
Returns:
See waveform_to_examples.
"""
sr, wav_data = wavfile.read(wav_file)
assert wav_data.dtype == np.int16, 'Bad sample type: %r' % wav_data.dtype
samples = wav_data / 32768.0 # Convert to [-1.0, +1.0]
return waveform_to_examples(samples, sr)
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Global parameters for the VGGish model.
See vggish_slim.py for more information.
"""
# Architectural constants.
NUM_FRAMES = 96 # Frames in input mel-spectrogram patch.
NUM_BANDS = 64 # Frequency bands in input mel-spectrogram patch.
EMBEDDING_SIZE = 128 # Size of embedding layer.
# Hyperparameters used in feature and example generation.
SAMPLE_RATE = 16000
STFT_WINDOW_LENGTH_SECONDS = 0.025
STFT_HOP_LENGTH_SECONDS = 0.010
NUM_MEL_BINS = NUM_BANDS
MEL_MIN_HZ = 125
MEL_MAX_HZ = 7500
LOG_OFFSET = 0.01 # Offset used for stabilized log of input mel-spectrogram.
EXAMPLE_WINDOW_SECONDS = 0.96 # Each example contains 96 10ms frames
EXAMPLE_HOP_SECONDS = 0.96 # with zero overlap.
# Parameters used for embedding postprocessing.
PCA_EIGEN_VECTORS_NAME = 'pca_eigen_vectors'
PCA_MEANS_NAME = 'pca_means'
QUANTIZE_MIN_VAL = -2.0
QUANTIZE_MAX_VAL = +2.0
# Hyperparameters used in training.
INIT_STDDEV = 0.01 # Standard deviation used to initialize weights.
LEARNING_RATE = 1e-4 # Learning rate for the Adam optimizer.
ADAM_EPSILON = 1e-8 # Epsilon for the Adam optimizer.
# Names of ops, tensors, and features.
INPUT_OP_NAME = 'vggish/input_features'
INPUT_TENSOR_NAME = INPUT_OP_NAME + ':0'
OUTPUT_OP_NAME = 'vggish/embedding'
OUTPUT_TENSOR_NAME = OUTPUT_OP_NAME + ':0'
AUDIO_EMBEDDING_FEATURE_NAME = 'audio_embedding'
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Post-process embeddings from VGGish."""
import numpy as np
import vggish_params
class Postprocessor(object):
"""Post-processes VGGish embeddings.
The initial release of AudioSet included 128-D VGGish embeddings for each
segment of AudioSet. These released embeddings were produced by applying
a PCA transformation (technically, a whitening transform is included as well)
and 8-bit quantization to the raw embedding output from VGGish, in order to
stay compatible with the YouTube-8M project which provides visual embeddings
in the same format for a large set of YouTube videos. This class implements
the same PCA (with whitening) and quantization transformations.
"""
def __init__(self, pca_params_npz_path):
"""Constructs a postprocessor.
Args:
pca_params_npz_path: Path to a NumPy-format .npz file that
contains the PCA parameters used in postprocessing.
"""
params = np.load(pca_params_npz_path)
self._pca_matrix = params[vggish_params.PCA_EIGEN_VECTORS_NAME]
# Load means into a column vector for easier broadcasting later.
self._pca_means = params[vggish_params.PCA_MEANS_NAME].reshape(-1, 1)
assert self._pca_matrix.shape == (
vggish_params.EMBEDDING_SIZE, vggish_params.EMBEDDING_SIZE), (
'Bad PCA matrix shape: %r' % (self._pca_matrix.shape,))
assert self._pca_means.shape == (vggish_params.EMBEDDING_SIZE, 1), (
'Bad PCA means shape: %r' % (self._pca_means.shape,))
def postprocess(self, embeddings_batch):
"""Applies postprocessing to a batch of embeddings.
Args:
embeddings_batch: An nparray of shape [batch_size, embedding_size]
containing output from the embedding layer of VGGish.
Returns:
An nparray of the same shape as the input but of type uint8,
containing the PCA-transformed and quantized version of the input.
"""
assert len(embeddings_batch.shape) == 2, (
'Expected 2-d batch, got %r' % (embeddings_batch.shape,))
assert embeddings_batch.shape[1] == vggish_params.EMBEDDING_SIZE, (
'Bad batch shape: %r' % (embeddings_batch.shape,))
# Apply PCA.
# - Embeddings come in as [batch_size, embedding_size].
# - Transpose to [embedding_size, batch_size].
# - Subtract pca_means column vector from each column.
# - Premultiply by PCA matrix of shape [output_dims, input_dims]
# where both are are equal to embedding_size in our case.
# - Transpose result back to [batch_size, embedding_size].
pca_applied = np.dot(self._pca_matrix,
(embeddings_batch.T - self._pca_means)).T
# Quantize by:
# - clipping to [min, max] range
clipped_embeddings = np.clip(
pca_applied, vggish_params.QUANTIZE_MIN_VAL,
vggish_params.QUANTIZE_MAX_VAL)
# - convert to 8-bit in range [0.0, 255.0]
quantized_embeddings = (
(clipped_embeddings - vggish_params.QUANTIZE_MIN_VAL) *
(255.0 /
(vggish_params.QUANTIZE_MAX_VAL - vggish_params.QUANTIZE_MIN_VAL)))
# - cast 8-bit float to uint8
quantized_embeddings = quantized_embeddings.astype(np.uint8)
return quantized_embeddings
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Defines the 'VGGish' model used to generate AudioSet embedding features.
The public AudioSet release (https://research.google.com/audioset/download.html)
includes 128-D features extracted from the embedding layer of a VGG-like model
that was trained on a large Google-internal YouTube dataset. Here we provide
a TF-Slim definition of the same model, without any dependences on libraries
internal to Google. We call it 'VGGish'.
Note that we only define the model up to the embedding layer, which is the
penultimate layer before the final classifier layer. We also provide various
hyperparameter values (in vggish_params.py) that were used to train this model
internally.
For comparison, here is TF-Slim's VGG definition:
https://github.com/tensorflow/models/blob/master/slim/nets/vgg.py
"""
import tensorflow as tf
import vggish_params as params
slim = tf.contrib.slim
def define_vggish_slim(training=False):
"""Defines the VGGish TensorFlow model.
All ops are created in the current default graph, under the scope 'vggish/'.
The input is a placeholder named 'vggish/input_features' of type float32 and
shape [batch_size, num_frames, num_bands] where batch_size is variable and
num_frames and num_bands are constants, and [num_frames, num_bands] represents
a log-mel-scale spectrogram patch covering num_bands frequency bands and
num_frames time frames (where each frame step is usually 10ms). This is
produced by computing the stabilized log(mel-spectrogram + params.LOG_OFFSET).
The output is an op named 'vggish/embedding' which produces the activations of
a 128-D embedding layer, which is usually the penultimate layer when used as
part of a full model with a final classifier layer.
Args:
training: If true, all parameters are marked trainable.
Returns:
The op 'vggish/embeddings'.
"""
# Defaults:
# - All weights are initialized to N(0, INIT_STDDEV).
# - All biases are initialized to 0.
# - All activations are ReLU.
# - All convolutions are 3x3 with stride 1 and SAME padding.
# - All max-pools are 2x2 with stride 2 and SAME padding.
with slim.arg_scope([slim.conv2d, slim.fully_connected],
weights_initializer=tf.truncated_normal_initializer(
stddev=params.INIT_STDDEV),
biases_initializer=tf.zeros_initializer(),
activation_fn=tf.nn.relu,
trainable=training), \
slim.arg_scope([slim.conv2d],
kernel_size=[3, 3], stride=1, padding='SAME'), \
slim.arg_scope([slim.max_pool2d],
kernel_size=[2, 2], stride=2, padding='SAME'), \
tf.variable_scope('vggish'):
# Input: a batch of 2-D log-mel-spectrogram patches.
features = tf.placeholder(
tf.float32, shape=(None, params.NUM_FRAMES, params.NUM_BANDS),
name='input_features')
# Reshape to 4-D so that we can convolve a batch with conv2d().
net = tf.reshape(features, [-1, params.NUM_FRAMES, params.NUM_BANDS, 1])
# The VGG stack of alternating convolutions and max-pools.
net = slim.conv2d(net, 64, scope='conv1')
net = slim.max_pool2d(net, scope='pool1')
net = slim.conv2d(net, 128, scope='conv2')
net = slim.max_pool2d(net, scope='pool2')
net = slim.repeat(net, 2, slim.conv2d, 256, scope='conv3')
net = slim.max_pool2d(net, scope='pool3')
net = slim.repeat(net, 2, slim.conv2d, 512, scope='conv4')
net = slim.max_pool2d(net, scope='pool4')
# Flatten before entering fully-connected layers
net = slim.flatten(net)
net = slim.repeat(net, 2, slim.fully_connected, 4096, scope='fc1')
# The embedding layer.
net = slim.fully_connected(net, params.EMBEDDING_SIZE, scope='fc2')
return tf.identity(net, name='embedding')
def load_vggish_slim_checkpoint(session, checkpoint_path):
"""Loads a pre-trained VGGish-compatible checkpoint.
This function can be used as an initialization function (referred to as
init_fn in TensorFlow documentation) which is called in a Session after
initializating all variables. When used as an init_fn, this will load
a pre-trained checkpoint that is compatible with the VGGish model
definition. Only variables defined by VGGish will be loaded.
Args:
session: an active TensorFlow session.
checkpoint_path: path to a file containing a checkpoint that is
compatible with the VGGish model definition.
"""
# Get the list of names of all VGGish variables that exist in
# the checkpoint (i.e., all inference-mode VGGish variables).
with tf.Graph().as_default():
define_vggish_slim(training=False)
vggish_var_names = [v.name for v in tf.global_variables()]
# Get the list of all currently existing variables that match
# the list of variable names we just computed.
vggish_vars = [v for v in tf.global_variables() if v.name in vggish_var_names]
# Use a Saver to restore just the variables selected above.
saver = tf.train.Saver(vggish_vars, name='vggish_load_pretrained')
saver.restore(session, checkpoint_path)
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A smoke test for VGGish.
This is a simple smoke test of a local install of VGGish and its associated
downloaded files. We create a synthetic sound, extract log mel spectrogram
features, run them through VGGish, post-process the embedding ouputs, and
check some simple statistics of the results, allowing for variations that
might occur due to platform/version differences in the libraries we use.
Usage:
- Download the VGGish checkpoint and PCA parameters into the same directory as
the VGGish source code. If you keep them elsewhere, update the checkpoint_path
and pca_params_path variables below.
- Run:
$ python vggish_smoke_test.py
"""
from __future__ import print_function
import numpy as np
import tensorflow as tf
import vggish_input
import vggish_params
import vggish_postprocess
import vggish_slim
print('\nTesting your install of VGGish\n')
# Paths to downloaded VGGish files.
checkpoint_path = 'vggish_model.ckpt'
pca_params_path = 'vggish_pca_params.npz'
# Relative tolerance of errors in mean and standard deviation of embeddings.
rel_error = 0.1 # Up to 10%
# Generate a 1 kHz sine wave at 44.1 kHz (we use a high sampling rate
# to test resampling to 16 kHz during feature extraction).
num_secs = 3
freq = 1000
sr = 44100
t = np.linspace(0, num_secs, int(num_secs * sr))
x = np.sin(2 * np.pi * freq * t)
# Produce a batch of log mel spectrogram examples.
input_batch = vggish_input.waveform_to_examples(x, sr)
print('Log Mel Spectrogram example: ', input_batch[0])
np.testing.assert_equal(
input_batch.shape,
[num_secs, vggish_params.NUM_FRAMES, vggish_params.NUM_BANDS])
# Define VGGish, load the checkpoint, and run the batch through the model to
# produce embeddings.
with tf.Graph().as_default(), tf.Session() as sess:
vggish_slim.define_vggish_slim()
vggish_slim.load_vggish_slim_checkpoint(sess, checkpoint_path)
features_tensor = sess.graph.get_tensor_by_name(
vggish_params.INPUT_TENSOR_NAME)
embedding_tensor = sess.graph.get_tensor_by_name(
vggish_params.OUTPUT_TENSOR_NAME)
[embedding_batch] = sess.run([embedding_tensor],
feed_dict={features_tensor: input_batch})
print('VGGish embedding: ', embedding_batch[0])
expected_embedding_mean = 0.131
expected_embedding_std = 0.238
np.testing.assert_allclose(
[np.mean(embedding_batch), np.std(embedding_batch)],
[expected_embedding_mean, expected_embedding_std],
rtol=rel_error)
# Postprocess the results to produce whitened quantized embeddings.
pproc = vggish_postprocess.Postprocessor(pca_params_path)
postprocessed_batch = pproc.postprocess(embedding_batch)
print('Postprocessed VGGish embedding: ', postprocessed_batch[0])
expected_postprocessed_mean = 123.0
expected_postprocessed_std = 75.0
np.testing.assert_allclose(
[np.mean(postprocessed_batch), np.std(postprocessed_batch)],
[expected_postprocessed_mean, expected_postprocessed_std],
rtol=rel_error)
print('\nLooks Good To Me!\n')
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""A simple demonstration of running VGGish in training mode.
This is intended as a toy example that demonstrates how to use the VGGish model
definition within a larger model that adds more layers on top, and then train
the larger model. If you let VGGish train as well, then this allows you to
fine-tune the VGGish model parameters for your application. If you don't let
VGGish train, then you use VGGish as a feature extractor for the layers above
it.
For this toy task, we are training a classifier to distinguish between three
classes: sine waves, constant signals, and white noise. We generate synthetic
waveforms from each of these classes, convert into shuffled batches of log mel
spectrogram examples with associated labels, and feed the batches into a model
that includes VGGish at the bottom and a couple of additional layers on top. We
also plumb in labels that are associated with the examples, which feed a label
loss used for training.
Usage:
# Run training for 100 steps using a model checkpoint in the default
# location (vggish_model.ckpt in the current directory). Allow VGGish
# to get fine-tuned.
$ python vggish_train_demo.py --num_batches 100
# Same as before but run for fewer steps and don't change VGGish parameters
# and use a checkpoint in a different location
$ python vggish_train_demo.py --num_batches 50 \
--train_vggish=False \
--checkpoint /path/to/model/checkpoint
"""
from __future__ import print_function
from random import shuffle
import numpy as np
import tensorflow as tf
import vggish_input
import vggish_params
import vggish_slim
flags = tf.app.flags
slim = tf.contrib.slim
flags.DEFINE_integer(
'num_batches', 30,
'Number of batches of examples to feed into the model. Each batch is of '
'variable size and contains shuffled examples of each class of audio.')
flags.DEFINE_boolean(
'train_vggish', True,
'If Frue, allow VGGish parameters to change during training, thus '
'fine-tuning VGGish. If False, VGGish parameters are fixed, thus using '
'VGGish as a fixed feature extractor.')
flags.DEFINE_string(
'checkpoint', 'vggish_model.ckpt',
'Path to the VGGish checkpoint file.')
FLAGS = flags.FLAGS
_NUM_CLASSES = 3
def _get_examples_batch():
"""Returns a shuffled batch of examples of all audio classes.
Note that this is just a toy function because this is a simple demo intended
to illustrate how the training code might work.
Returns:
a tuple (features, labels) where features is a NumPy array of shape
[batch_size, num_frames, num_bands] where the batch_size is variable and
each row is a log mel spectrogram patch of shape [num_frames, num_bands]
suitable for feeding VGGish, while labels is a NumPy array of shape
[batch_size, num_classes] where each row is a multi-hot label vector that
provides the labels for corresponding rows in features.
"""
# Make a waveform for each class.
num_seconds = 5
sr = 44100 # Sampling rate.
t = np.linspace(0, num_seconds, int(num_seconds * sr)) # Time axis.
# Random sine wave.
freq = np.random.uniform(100, 1000)
sine = np.sin(2 * np.pi * freq * t)
# Random constant signal.
magnitude = np.random.uniform(-1, 1)
const = magnitude * t
# White noise.
noise = np.random.normal(-1, 1, size=t.shape)
# Make examples of each signal and corresponding labels.
# Sine is class index 0, Const class index 1, Noise class index 2.
sine_examples = vggish_input.waveform_to_examples(sine, sr)
sine_labels = np.array([[1, 0, 0]] * sine_examples.shape[0])
const_examples = vggish_input.waveform_to_examples(const, sr)
const_labels = np.array([[0, 1, 0]] * const_examples.shape[0])
noise_examples = vggish_input.waveform_to_examples(noise, sr)
noise_labels = np.array([[0, 0, 1]] * noise_examples.shape[0])
# Shuffle (example, label) pairs across all classes.
all_examples = sine_examples + const_examples + noise_examples
all_labels = sine_labels + const_labels + noise_labels
labeled_examples = list(zip(all_examples, all_labels))
shuffle(labeled_examples)
# Separate and return the features and labels.
features = [example for (example, _) in labeled_examples]
labels = [label for (_, label) in labeled_examples]
return (features, labels)
def main(_):
with tf.Graph().as_default(), tf.Session() as sess:
# Define VGGish.
embeddings = vggish_slim.define_vggish_slim(FLAGS.train_vggish)
# Define a shallow classification model and associated training ops on top
# of VGGish.
with tf.variable_scope('mymodel'):
# Add a fully connected layer with 100 units.
num_units = 100
fc = slim.fully_connected(embeddings, num_units)
# Add a classifier layer at the end, consisting of parallel logistic
# classifiers, one per class. This allows for multi-class tasks.
logits = slim.fully_connected(
fc, _NUM_CLASSES, activation_fn=None, scope='logits')
tf.sigmoid(logits, name='prediction')
# Add training ops.
with tf.variable_scope('train'):
global_step = tf.Variable(
0, name='global_step', trainable=False,
collections=[tf.GraphKeys.GLOBAL_VARIABLES,
tf.GraphKeys.GLOBAL_STEP])
# Labels are assumed to be fed as a batch multi-hot vectors, with
# a 1 in the position of each positive class label, and 0 elsewhere.
labels = tf.placeholder(
tf.float32, shape=(None, _NUM_CLASSES), name='labels')
# Cross-entropy label loss.
xent = tf.nn.sigmoid_cross_entropy_with_logits(
logits=logits, labels=labels, name='xent')
loss = tf.reduce_mean(xent, name='loss_op')
tf.summary.scalar('loss', loss)
# We use the same optimizer and hyperparameters as used to train VGGish.
optimizer = tf.train.AdamOptimizer(
learning_rate=vggish_params.LEARNING_RATE,
epsilon=vggish_params.ADAM_EPSILON)
optimizer.minimize(loss, global_step=global_step, name='train_op')
# Initialize all variables in the model, and then load the pre-trained
# VGGish checkpoint.
sess.run(tf.global_variables_initializer())
vggish_slim.load_vggish_slim_checkpoint(sess, FLAGS.checkpoint)
# Locate all the tensors and ops we need for the training loop.
features_tensor = sess.graph.get_tensor_by_name(
vggish_params.INPUT_TENSOR_NAME)
labels_tensor = sess.graph.get_tensor_by_name('mymodel/train/labels:0')
global_step_tensor = sess.graph.get_tensor_by_name(
'mymodel/train/global_step:0')
loss_tensor = sess.graph.get_tensor_by_name('mymodel/train/loss_op:0')
train_op = sess.graph.get_operation_by_name('mymodel/train/train_op')
# The training loop.
for _ in range(FLAGS.num_batches):
(features, labels) = _get_examples_batch()
[num_steps, loss, _] = sess.run(
[global_step_tensor, loss_tensor, train_op],
feed_dict={features_tensor: features, labels_tensor: labels})
print('Step %d: loss %g' % (num_steps, loss))
if __name__ == '__main__':
tf.app.run()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment