"git@developer.sourcefind.cn:change/sglang.git" did not exist on "adba172fd1554fdb52c06542a2444a81d170f104"
Unverified Commit 557eec27 authored by Manoj Plakal's avatar Manoj Plakal Committed by GitHub
Browse files

Made VGGish and YAMNet work in TF2 without disabling TF2 behavior. (#9077)

* Made VGGish and YAMNet work in TF2 without disabling TF2 behavior.

Allowed TF2 behavior and allowed passing in a features tensor into the
VGGish model definition. Both of these changes are needed for making
TF-Hub exports of these models. Lifted constraints on TF versions since
tf_slim has been updated to work with TF 2.

* Responded to DAn's comments in https://github.com/tensorflow/models/pull/9077

* Fixed typo in comment.
parent d768c91c
...@@ -16,17 +16,14 @@ VGGish depends on the following Python packages: ...@@ -16,17 +16,14 @@ VGGish depends on the following Python packages:
* [`numpy`](http://www.numpy.org/) * [`numpy`](http://www.numpy.org/)
* [`resampy`](http://resampy.readthedocs.io/en/latest/) * [`resampy`](http://resampy.readthedocs.io/en/latest/)
* [`tensorflow`](http://www.tensorflow.org/) (currently, only TF v1.x) * [`tensorflow`](http://www.tensorflow.org/)
* [`tf_slim`](https://github.com/google-research/tf-slim) * [`tf_slim`](https://github.com/google-research/tf-slim)
* [`six`](https://pythonhosted.org/six/) * [`six`](https://pythonhosted.org/six/)
* [`soundfile`](https://pysoundfile.readthedocs.io/) * [`soundfile`](https://pysoundfile.readthedocs.io/)
These are all easily installable via, e.g., `pip install numpy` (as in the These are all easily installable via, e.g., `pip install numpy` (as in the
sample installation session below). sample installation session below). Any reasonably recent version of these
packages shold work.
Any reasonably recent version of these packages shold work. Note that we currently only support
TensorFlow v1.x due to a [`tf_slim` limitation](https://github.com/google-research/tf-slim/pull/1).
TensorFlow v1.15 (the latest version as of Jan 2020) has been tested to work.
VGGish also requires downloading two data files: VGGish also requires downloading two data files:
...@@ -60,7 +57,7 @@ Here's a sample installation and test session: ...@@ -60,7 +57,7 @@ Here's a sample installation and test session:
$ sudo python -m pip install --upgrade pip wheel $ sudo python -m pip install --upgrade pip wheel
# Install all dependences. # Install all dependences.
$ sudo pip install numpy resampy tensorflow==1.15 tf_slim six soundfile $ sudo pip install numpy resampy tensorflow tf_slim six soundfile
# Clone TensorFlow models repo into a 'models' directory. # Clone TensorFlow models repo into a 'models' directory.
$ git clone https://github.com/tensorflow/models.git $ git clone https://github.com/tensorflow/models.git
......
...@@ -50,7 +50,6 @@ import numpy as np ...@@ -50,7 +50,6 @@ import numpy as np
import six import six
import soundfile import soundfile
import tensorflow.compat.v1 as tf import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
import vggish_input import vggish_input
import vggish_params import vggish_params
......
...@@ -31,28 +31,31 @@ https://github.com/tensorflow/models/blob/master/research/slim/nets/vgg.py ...@@ -31,28 +31,31 @@ https://github.com/tensorflow/models/blob/master/research/slim/nets/vgg.py
""" """
import tensorflow.compat.v1 as tf import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
import tf_slim as slim import tf_slim as slim
import vggish_params as params import vggish_params as params
def define_vggish_slim(training=False): def define_vggish_slim(features_tensor=None, training=False):
"""Defines the VGGish TensorFlow model. """Defines the VGGish TensorFlow model.
All ops are created in the current default graph, under the scope 'vggish/'. All ops are created in the current default graph, under the scope 'vggish/'.
The input is a placeholder named 'vggish/input_features' of type float32 and The input is either a tensor passed in via the optional 'features_tensor'
shape [batch_size, num_frames, num_bands] where batch_size is variable and argument or a placeholder created below named 'vggish/input_features'. The
num_frames and num_bands are constants, and [num_frames, num_bands] represents input is expected to have dtype float32 and shape [batch_size, num_frames,
a log-mel-scale spectrogram patch covering num_bands frequency bands and num_bands] where batch_size is variable and num_frames and num_bands are
num_frames time frames (where each frame step is usually 10ms). This is constants, and [num_frames, num_bands] represents a log-mel-scale spectrogram
produced by computing the stabilized log(mel-spectrogram + params.LOG_OFFSET). patch covering num_bands frequency bands and num_frames time frames (where
The output is an op named 'vggish/embedding' which produces the activations of each frame step is usually 10ms). This is produced by computing the stabilized
a 128-D embedding layer, which is usually the penultimate layer when used as log(mel-spectrogram + params.LOG_OFFSET). The output is a tensor named
part of a full model with a final classifier layer. 'vggish/embedding' which produces the activations of a 128-D embedding layer,
which is usually the penultimate layer when used as part of a full model with
a final classifier layer.
Args: Args:
features_tensor: If not None, the tensor containing the input features.
If None, a placeholder input is created.
training: If true, all parameters are marked trainable. training: If true, all parameters are marked trainable.
Returns: Returns:
...@@ -76,11 +79,13 @@ def define_vggish_slim(training=False): ...@@ -76,11 +79,13 @@ def define_vggish_slim(training=False):
kernel_size=[2, 2], stride=2, padding='SAME'), \ kernel_size=[2, 2], stride=2, padding='SAME'), \
tf.variable_scope('vggish'): tf.variable_scope('vggish'):
# Input: a batch of 2-D log-mel-spectrogram patches. # Input: a batch of 2-D log-mel-spectrogram patches.
features = tf.placeholder( if features_tensor is None:
tf.float32, shape=(None, params.NUM_FRAMES, params.NUM_BANDS), features_tensor = tf.placeholder(
name='input_features') tf.float32, shape=(None, params.NUM_FRAMES, params.NUM_BANDS),
name='input_features')
# Reshape to 4-D so that we can convolve a batch with conv2d(). # Reshape to 4-D so that we can convolve a batch with conv2d().
net = tf.reshape(features, [-1, params.NUM_FRAMES, params.NUM_BANDS, 1]) net = tf.reshape(features_tensor,
[-1, params.NUM_FRAMES, params.NUM_BANDS, 1])
# The VGG stack of alternating convolutions and max-pools. # The VGG stack of alternating convolutions and max-pools.
net = slim.conv2d(net, 64, scope='conv1') net = slim.conv2d(net, 64, scope='conv1')
......
...@@ -33,7 +33,6 @@ from __future__ import print_function ...@@ -33,7 +33,6 @@ from __future__ import print_function
import numpy as np import numpy as np
import tensorflow.compat.v1 as tf import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
import vggish_input import vggish_input
import vggish_params import vggish_params
......
...@@ -49,7 +49,6 @@ from random import shuffle ...@@ -49,7 +49,6 @@ from random import shuffle
import numpy as np import numpy as np
import tensorflow.compat.v1 as tf import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
import tf_slim as slim import tf_slim as slim
import vggish_input import vggish_input
...@@ -129,7 +128,7 @@ def _get_examples_batch(): ...@@ -129,7 +128,7 @@ def _get_examples_batch():
def main(_): def main(_):
with tf.Graph().as_default(), tf.Session() as sess: with tf.Graph().as_default(), tf.Session() as sess:
# Define VGGish. # Define VGGish.
embeddings = vggish_slim.define_vggish_slim(FLAGS.train_vggish) embeddings = vggish_slim.define_vggish_slim(training=FLAGS.train_vggish)
# Define a shallow classification model and associated training ops on top # Define a shallow classification model and associated training ops on top
# of VGGish. # of VGGish.
......
...@@ -18,12 +18,8 @@ YAMNet depends on the following Python packages: ...@@ -18,12 +18,8 @@ YAMNet depends on the following Python packages:
* [`pysoundfile`](https://pysoundfile.readthedocs.io/) * [`pysoundfile`](https://pysoundfile.readthedocs.io/)
These are all easily installable via, e.g., `pip install numpy` (as in the These are all easily installable via, e.g., `pip install numpy` (as in the
example command sequence below). example command sequence below). Any reasonably recent version of these
packages should work.
Any reasonably recent version of these packages should work. TensorFlow should
be at least version 1.8 to ensure Keras support is included. Note that while
the code works fine with TensorFlow v1.x or v2.x, we explicitly enable v1.x
behavior.
YAMNet also requires downloading the following data file: YAMNet also requires downloading the following data file:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment