Commit 557eec27 authored by Manoj Plakal, committed by GitHub

Made VGGish and YAMNet work in TF2 without disabling TF2 behavior. (#9077)

* Made VGGish and YAMNet work in TF2 without disabling TF2 behavior.

Allowed TF2 behavior and allowed passing in a features tensor into the
VGGish model definition. Both of these changes are needed for making
TF-Hub exports of these models. Lifted constraints on TF versions since
tf_slim has been updated to work with TF 2.

* Responded to Dan's comments in https://github.com/tensorflow/models/pull/9077

* Fixed typo in comment.
parent d768c91c
......@@ -16,17 +16,14 @@ VGGish depends on the following Python packages:
* [`numpy`](http://www.numpy.org/)
* [`resampy`](http://resampy.readthedocs.io/en/latest/)
* [`tensorflow`](http://www.tensorflow.org/) (currently, only TF v1.x)
* [`tensorflow`](http://www.tensorflow.org/)
* [`tf_slim`](https://github.com/google-research/tf-slim)
* [`six`](https://pythonhosted.org/six/)
* [`soundfile`](https://pysoundfile.readthedocs.io/)
These are all easily installable via, e.g., `pip install numpy` (as in the
sample installation session below).
Any reasonably recent version of these packages should work. Note that we currently only support
TensorFlow v1.x due to a [`tf_slim` limitation](https://github.com/google-research/tf-slim/pull/1).
TensorFlow v1.15 (the latest version as of Jan 2020) has been tested to work.
sample installation session below). Any reasonably recent version of these
packages should work.
VGGish also requires downloading two data files:
......@@ -60,7 +57,7 @@ Here's a sample installation and test session:
$ sudo python -m pip install --upgrade pip wheel
# Install all dependencies.
$ sudo pip install numpy resampy tensorflow==1.15 tf_slim six soundfile
$ sudo pip install numpy resampy tensorflow tf_slim six soundfile
# Clone TensorFlow models repo into a 'models' directory.
$ git clone https://github.com/tensorflow/models.git
......
......@@ -50,7 +50,6 @@ import numpy as np
import six
import soundfile
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
import vggish_input
import vggish_params
......
......@@ -31,28 +31,31 @@ https://github.com/tensorflow/models/blob/master/research/slim/nets/vgg.py
"""
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
import tf_slim as slim
import vggish_params as params
def define_vggish_slim(training=False):
def define_vggish_slim(features_tensor=None, training=False):
"""Defines the VGGish TensorFlow model.
All ops are created in the current default graph, under the scope 'vggish/'.
The input is a placeholder named 'vggish/input_features' of type float32 and
shape [batch_size, num_frames, num_bands] where batch_size is variable and
num_frames and num_bands are constants, and [num_frames, num_bands] represents
a log-mel-scale spectrogram patch covering num_bands frequency bands and
num_frames time frames (where each frame step is usually 10ms). This is
produced by computing the stabilized log(mel-spectrogram + params.LOG_OFFSET).
The output is an op named 'vggish/embedding' which produces the activations of
a 128-D embedding layer, which is usually the penultimate layer when used as
part of a full model with a final classifier layer.
The input is either a tensor passed in via the optional 'features_tensor'
argument or a placeholder created below named 'vggish/input_features'. The
input is expected to have dtype float32 and shape [batch_size, num_frames,
num_bands] where batch_size is variable and num_frames and num_bands are
constants, and [num_frames, num_bands] represents a log-mel-scale spectrogram
patch covering num_bands frequency bands and num_frames time frames (where
each frame step is usually 10ms). This is produced by computing the stabilized
log(mel-spectrogram + params.LOG_OFFSET). The output is a tensor named
'vggish/embedding' which produces the activations of a 128-D embedding layer,
which is usually the penultimate layer when used as part of a full model with
a final classifier layer.
Args:
features_tensor: If not None, the tensor containing the input features.
If None, a placeholder input is created.
training: If true, all parameters are marked trainable.
Returns:
......@@ -76,11 +79,13 @@ def define_vggish_slim(training=False):
kernel_size=[2, 2], stride=2, padding='SAME'), \
tf.variable_scope('vggish'):
# Input: a batch of 2-D log-mel-spectrogram patches.
features = tf.placeholder(
tf.float32, shape=(None, params.NUM_FRAMES, params.NUM_BANDS),
name='input_features')
if features_tensor is None:
features_tensor = tf.placeholder(
tf.float32, shape=(None, params.NUM_FRAMES, params.NUM_BANDS),
name='input_features')
# Reshape to 4-D so that we can convolve a batch with conv2d().
net = tf.reshape(features, [-1, params.NUM_FRAMES, params.NUM_BANDS, 1])
net = tf.reshape(features_tensor,
[-1, params.NUM_FRAMES, params.NUM_BANDS, 1])
# The VGG stack of alternating convolutions and max-pools.
net = slim.conv2d(net, 64, scope='conv1')
......
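The hunk above changes where the model input comes from: a caller-supplied tensor is used as-is, and a placeholder is created only when none is given. A minimal sketch of that fallback follows, in plain Python without TensorFlow. `FakePlaceholder`, `resolve_features`, and `conv_input_shape` are hypothetical names used only for illustration, and the values 96 and 64 are assumed stand-ins for `params.NUM_FRAMES` and `params.NUM_BANDS`, which this diff does not show.

```python
# Assumed stand-ins for params.NUM_FRAMES and params.NUM_BANDS.
NUM_FRAMES = 96
NUM_BANDS = 64


class FakePlaceholder:
    """Hypothetical stand-in for tf.placeholder; records only shape and name."""

    def __init__(self, shape, name):
        self.shape = shape
        self.name = name


def resolve_features(features_tensor=None):
    """Returns the caller's tensor, or a new placeholder if none was given."""
    if features_tensor is None:
        features_tensor = FakePlaceholder(
            shape=(None, NUM_FRAMES, NUM_BANDS), name='input_features')
    return features_tensor


def conv_input_shape(batch_size):
    """Shape after the reshape to 4-D: a trailing channel axis of size 1."""
    return (batch_size, NUM_FRAMES, NUM_BANDS, 1)


default_input = resolve_features()         # No argument: placeholder is created.
caller_input = object()                    # Any object standing in for a tensor.
passed_through = resolve_features(caller_input)  # Argument is returned unchanged.
```

Because the supplied tensor is passed through untouched, the caller is responsible for giving it the expected float32 dtype and [batch_size, num_frames, num_bands] shape described in the docstring.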
......@@ -33,7 +33,6 @@ from __future__ import print_function
import numpy as np
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
import vggish_input
import vggish_params
......
......@@ -49,7 +49,6 @@ from random import shuffle
import numpy as np
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
import tf_slim as slim
import vggish_input
......@@ -129,7 +128,7 @@ def _get_examples_batch():
def main(_):
with tf.Graph().as_default(), tf.Session() as sess:
# Define VGGish.
embeddings = vggish_slim.define_vggish_slim(FLAGS.train_vggish)
embeddings = vggish_slim.define_vggish_slim(training=FLAGS.train_vggish)
# Define a shallow classification model and associated training ops on top
# of VGGish.
......
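The switch to `training=FLAGS.train_vggish` in the hunk above matters because `features_tensor` is now the first positional parameter. The sketch below uses two hypothetical stub functions that mirror only the old and new signatures (they build no model) to show how an unchanged positional call would bind the flag to the wrong parameter.

```python
def define_old(training=False):
    """Stub with the pre-change signature; returns its bound arguments."""
    return {'training': training}


def define_new(features_tensor=None, training=False):
    """Stub with the post-change signature; returns its bound arguments."""
    return {'features_tensor': features_tensor, 'training': training}


train_flag = True

# Old call style against the old signature: the flag reaches 'training'.
old_call = define_old(train_flag)

# Old call style against the new signature: the flag lands in
# 'features_tensor' and 'training' silently keeps its default.
positional_call = define_new(train_flag)

# Keyword call, as in the updated training demo: the flag reaches 'training'.
keyword_call = define_new(training=train_flag)
```

Passing the flag by keyword keeps existing call sites correct regardless of where new optional parameters are added to the signature.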
......@@ -18,12 +18,8 @@ YAMNet depends on the following Python packages:
* [`pysoundfile`](https://pysoundfile.readthedocs.io/)
These are all easily installable via, e.g., `pip install numpy` (as in the
example command sequence below).
Any reasonably recent version of these packages should work. TensorFlow should
be at least version 1.8 to ensure Keras support is included. Note that while
the code works fine with TensorFlow v1.x or v2.x, we explicitly enable v1.x
behavior.
example command sequence below). Any reasonably recent version of these
packages should work.
YAMNet also requires downloading the following data file:
......