Unverified Commit 0f7616bd authored by Manoj Plakal, committed by GitHub

VGGish embeddings are pre-activation, not post-activation. (#9080)

Fixed a long-standing bug where the released VGGish model used
post-activation embedding output while the released embeddings
were pre-activation. There are still discrepancies due to
other reasons: differences in choice of YouTube transcode,
repeated resamplings with different resamplers, slight differences
in feature computation, etc.
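Given those remaining sources of drift, downstream comparisons against previously released embeddings are better done with a tolerance than with exact equality. A small illustrative helper (the function name and the 0.95 threshold are arbitrary choices, not part of this change):

import numpy as np

def embeddings_roughly_match(released, recomputed, min_cosine=0.95):
  """Checks that two 128-D embeddings point in nearly the same direction."""
  released = np.asarray(released, dtype=np.float32)
  recomputed = np.asarray(recomputed, dtype=np.float32)
  cosine = np.dot(released, recomputed) / (
      np.linalg.norm(released) * np.linalg.norm(recomputed) + 1e-8)
  return cosine >= min_cosine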
parent cb487314
@@ -126,7 +126,10 @@ changes we made:
 fully connected layer. This acts as a compact embedding layer.
 The model definition provided here defines layers up to and including the
-128-wide embedding layer.
+128-wide embedding layer. Note that the embedding layer does not include
+a final non-linear activation, so the embedding value is pre-activation.
+When training a model stacked on top of VGGish, you should send the
+embedding through a non-linearity of your choice before adding more layers.
 
 ### Input: Audio Features
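To illustrate the note above about sending the pre-activation embedding through a non-linearity before stacking more layers, here is a minimal sketch. It assumes `embeddings` is the 128-D 'vggish/embedding' tensor returned by `vggish_slim.define_vggish_slim()`; the `tensorflow.compat.v1` import is just one way to run this graph-style code, and the 100-unit hidden layer and 10-class head are arbitrary choices, not values from this repository:

import tensorflow.compat.v1 as tf  # one way to run TF1 graph-style code
import vggish_slim

with tf.Graph().as_default():
  # Pre-activation, [batch, 128] 'vggish/embedding' output.
  embeddings = vggish_slim.define_vggish_slim(training=False)
  # Apply a non-linearity of your choice before stacking your own layers.
  activated = tf.nn.relu(embeddings)
  hidden = tf.layers.dense(activated, units=100, activation=tf.nn.relu)
  logits = tf.layers.dense(hidden, units=10, activation=None)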
@@ -147,14 +150,7 @@ VGGish was trained with audio features computed as follows:
 where each example covers 64 mel bands and 96 frames of 10 ms each.
 
 We provide our own NumPy implementation that produces features that are very
-similar to those produced by our internal production code. This results in
-embedding outputs that are closely match the embeddings that we have already
-released. Note that these embeddings will *not* be bit-for-bit identical to the
-released embeddings due to small differences between the feature computation
-code paths, and even between two different installations of VGGish with
-different underlying libraries and hardware. However, we expect that the
-embeddings will be equivalent in the context of a downstream classification
-task.
+similar to those produced by our internal production code.
 
 ### Output: Embeddings
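As a rough sketch of that NumPy feature path (the one-second 440 Hz tone is just stand-in audio, not anything this repo prescribes):

import numpy as np
import vggish_input  # the NumPy/SciPy feature computation shipped alongside VGGish

sample_rate = 16000
t = np.arange(sample_rate) / float(sample_rate)  # one second of audio
waveform = 0.5 * np.sin(2 * np.pi * 440 * t)     # stand-in 440 Hz tone

examples = vggish_input.waveform_to_examples(waveform, sample_rate)
print(examples.shape)  # (num_examples, 96, 64): 96 frames x 64 mel bands per example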
@@ -49,9 +49,9 @@ def define_vggish_slim(features_tensor=None, training=False):
   patch covering num_bands frequency bands and num_frames time frames (where
   each frame step is usually 10ms). This is produced by computing the stabilized
   log(mel-spectrogram + params.LOG_OFFSET). The output is a tensor named
-  'vggish/embedding' which produces the activations of a 128-D embedding layer,
-  which is usually the penultimate layer when used as part of a full model with
-  a final classifier layer.
+  'vggish/embedding' which produces the pre-activation values of a 128-D
+  embedding layer, which is usually the penultimate layer when used as part of a
+  full model with a final classifier layer.
 
   Args:
     features_tensor: If not None, the tensor containing the input features.
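For concreteness, a minimal inference sketch built on this function, in the spirit of the smoke test below; the checkpoint and WAV paths are placeholders, and the `tensorflow.compat.v1` import is one way to run this graph-mode code:

import tensorflow.compat.v1 as tf  # one way to run TF1 graph-style code
import vggish_input
import vggish_params
import vggish_slim

with tf.Graph().as_default(), tf.Session() as sess:
  vggish_slim.define_vggish_slim(training=False)
  vggish_slim.load_vggish_slim_checkpoint(sess, 'vggish_model.ckpt')  # placeholder path

  features_tensor = sess.graph.get_tensor_by_name(
      vggish_params.INPUT_TENSOR_NAME)
  embedding_tensor = sess.graph.get_tensor_by_name(
      vggish_params.OUTPUT_TENSOR_NAME)

  examples = vggish_input.wavfile_to_examples('some_audio.wav')  # placeholder path
  [embedding_batch] = sess.run([embedding_tensor],
                               feed_dict={features_tensor: examples})
  # embedding_batch: [num_examples, 128] pre-activation embeddings.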
@@ -101,7 +101,8 @@ def define_vggish_slim(features_tensor=None, training=False):
     net = slim.flatten(net)
     net = slim.repeat(net, 2, slim.fully_connected, 4096, scope='fc1')
     # The embedding layer.
-    net = slim.fully_connected(net, params.EMBEDDING_SIZE, scope='fc2')
+    net = slim.fully_connected(net, params.EMBEDDING_SIZE, scope='fc2',
+                               activation_fn=None)
     return tf.identity(net, name='embedding')
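The key detail is that slim's `fully_connected` applies `tf.nn.relu` by default, so passing `activation_fn=None` is what makes 'fc2' emit linear, pre-activation values. A standalone toy check of that default, assuming the `tf_slim` package; the scope names and toy input are arbitrary:

import tensorflow.compat.v1 as tf
import tf_slim as slim

with tf.Graph().as_default(), tf.Session() as sess:
  x = tf.constant([[-1.0, 2.0]])
  post = slim.fully_connected(x, 3, scope='demo_post')                    # default: ReLU applied
  pre = slim.fully_connected(x, 3, scope='demo_pre', activation_fn=None)  # linear output
  sess.run(tf.global_variables_initializer())
  post_val, pre_val = sess.run([post, pre])
  print(post_val.min())  # never below 0
  print(pre_val.min())   # can be negative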
@@ -76,8 +76,8 @@ with tf.Graph().as_default(), tf.Session() as sess:
   [embedding_batch] = sess.run([embedding_tensor],
                                feed_dict={features_tensor: input_batch})
   print('VGGish embedding: ', embedding_batch[0])
-  expected_embedding_mean = 0.131
-  expected_embedding_std = 0.238
+  expected_embedding_mean = -0.0333
+  expected_embedding_std = 0.380
   np.testing.assert_allclose(
       [np.mean(embedding_batch), np.std(embedding_batch)],
       [expected_embedding_mean, expected_embedding_std],
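The new expected values move in the direction you would predict once the trailing ReLU is gone: negative components survive, so the mean can dip below zero and the spread widens. A toy NumPy illustration with made-up, roughly Gaussian values (not the real embedding distribution):

import numpy as np

rng = np.random.RandomState(0)
pre = rng.normal(loc=-0.03, scale=0.38, size=100000)  # stand-in pre-activation values
post = np.maximum(pre, 0.0)                           # what a trailing ReLU would emit
print(pre.mean(), pre.std())    # near -0.03 and 0.38
print(post.mean(), post.std())  # mean pushed positive, std shrunk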
@@ -87,8 +87,8 @@ with tf.Graph().as_default(), tf.Session() as sess:
   pproc = vggish_postprocess.Postprocessor(pca_params_path)
   postprocessed_batch = pproc.postprocess(embedding_batch)
   print('Postprocessed VGGish embedding: ', postprocessed_batch[0])
-  expected_postprocessed_mean = 123.0
-  expected_postprocessed_std = 75.0
+  expected_postprocessed_mean = 122.0
+  expected_postprocessed_std = 93.5
   np.testing.assert_allclose(
       [np.mean(postprocessed_batch), np.std(postprocessed_batch)],
       [expected_postprocessed_mean, expected_postprocessed_std],
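Since the postprocessor PCA-transforms and then quantizes embeddings to 8 bits, the postprocessed values live in [0, 255], which is why these expected statistics sit near the middle of that range. A small sanity check along those lines (the function name is illustrative, not from this repo):

import numpy as np

def check_postprocessed(batch):
  """Sanity-checks an 8-bit quantized VGGish embedding batch."""
  batch = np.asarray(batch)
  assert batch.min() >= 0 and batch.max() <= 255  # 8-bit range
  return np.mean(batch), np.std(batch)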
@@ -133,9 +133,10 @@ def main(_):
     # Define a shallow classification model and associated training ops on top
     # of VGGish.
     with tf.variable_scope('mymodel'):
-      # Add a fully connected layer with 100 units.
+      # Add a fully connected layer with 100 units. Add an activation function
+      # to the embeddings since they are pre-activation.
       num_units = 100
-      fc = slim.fully_connected(embeddings, num_units)
+      fc = slim.fully_connected(tf.nn.relu(embeddings), num_units)
 
       # Add a classifier layer at the end, consisting of parallel logistic
       # classifiers, one per class. This allows for multi-class tasks.
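Continuing from the `fc` tensor and `_NUM_CLASSES` constant in the demo above, one way to realize the "parallel logistic classifiers" idea is a linear logits layer followed by an independent sigmoid per class; this is a sketch, not necessarily the exact lines in the demo:

# Linear logits, one output per class; no activation here so the sigmoid
# below (and the sigmoid cross-entropy loss) see raw scores.
logits = slim.fully_connected(fc, _NUM_CLASSES, activation_fn=None,
                              scope='logits')
predictions = tf.sigmoid(logits, name='prediction')
# Each class can then be thresholded independently, e.g. predictions > 0.5.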
@@ -145,19 +146,16 @@ def main(_):
     # Add training ops.
     with tf.variable_scope('train'):
-      global_step = tf.Variable(
-          0, name='global_step', trainable=False,
-          collections=[tf.GraphKeys.GLOBAL_VARIABLES,
-                       tf.GraphKeys.GLOBAL_STEP])
+      global_step = tf.train.create_global_step()
 
       # Labels are assumed to be fed as a batch multi-hot vectors, with
       # a 1 in the position of each positive class label, and 0 elsewhere.
-      labels = tf.placeholder(
+      labels_input = tf.placeholder(
           tf.float32, shape=(None, _NUM_CLASSES), name='labels')
 
       # Cross-entropy label loss.
       xent = tf.nn.sigmoid_cross_entropy_with_logits(
-          logits=logits, labels=labels, name='xent')
+          logits=logits, labels=labels_input, name='xent')
       loss = tf.reduce_mean(xent, name='loss_op')
       tf.summary.scalar('loss', loss)
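Since the loss above expects a batch of multi-hot label vectors, here is a tiny NumPy sketch of building one; the class count and per-example positive classes are made up:

import numpy as np

num_classes = 10                           # stands in for _NUM_CLASSES
positives_per_example = [[0, 3], [7], []]  # made-up positive class indices

labels = np.zeros((len(positives_per_example), num_classes), dtype=np.float32)
for row, positives in enumerate(positives_per_example):
  labels[row, positives] = 1.0
# labels[0] has 1.0 at columns 0 and 3; labels[2] (no positives) stays all zeros.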
@@ -165,29 +163,22 @@ def main(_):
       optimizer = tf.train.AdamOptimizer(
           learning_rate=vggish_params.LEARNING_RATE,
           epsilon=vggish_params.ADAM_EPSILON)
-      optimizer.minimize(loss, global_step=global_step, name='train_op')
+      train_op = optimizer.minimize(loss, global_step=global_step)
 
     # Initialize all variables in the model, and then load the pre-trained
     # VGGish checkpoint.
     sess.run(tf.global_variables_initializer())
     vggish_slim.load_vggish_slim_checkpoint(sess, FLAGS.checkpoint)
 
-    # Locate all the tensors and ops we need for the training loop.
-    features_tensor = sess.graph.get_tensor_by_name(
-        vggish_params.INPUT_TENSOR_NAME)
-    labels_tensor = sess.graph.get_tensor_by_name('mymodel/train/labels:0')
-    global_step_tensor = sess.graph.get_tensor_by_name(
-        'mymodel/train/global_step:0')
-    loss_tensor = sess.graph.get_tensor_by_name('mymodel/train/loss_op:0')
-    train_op = sess.graph.get_operation_by_name('mymodel/train/train_op')
-
     # The training loop.
+    features_input = sess.graph.get_tensor_by_name(
+        vggish_params.INPUT_TENSOR_NAME)
     for _ in range(FLAGS.num_batches):
       (features, labels) = _get_examples_batch()
-      [num_steps, loss, _] = sess.run(
-          [global_step_tensor, loss_tensor, train_op],
-          feed_dict={features_tensor: features, labels_tensor: labels})
-      print('Step %d: loss %g' % (num_steps, loss))
+      [num_steps, loss_value, _] = sess.run(
+          [global_step, loss, train_op],
+          feed_dict={features_input: features, labels_input: labels})
+      print('Step %d: loss %g' % (num_steps, loss_value))
 
 
 if __name__ == '__main__':
   tf.app.run()
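The demo's `_get_examples_batch()` synthesizes toy audio; to train on your own recordings, a batch provider might look roughly like this sketch, where the WAV path and the multi-hot label are supplied by you:

import numpy as np
import vggish_input

def examples_and_labels_from_wav(wav_path, multi_hot_label):
  """Pairs every 0.96 s example from one WAV file with that file's label."""
  examples = vggish_input.wavfile_to_examples(wav_path)      # [n, 96, 64]
  labels = np.tile(multi_hot_label, (examples.shape[0], 1))  # [n, num_classes]
  return examples, labels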