Unverified Commit 9b179e8e authored by Manoj Plakal, committed by GitHub

Input/Output tweaks for YAMNet and VGGish. (#9092)

* Input/Output tweaks for YAMNet and VGGish.

- Waveform input for YAMNet is now padded so that we get at least
  one patch of log mel spectrogram. The VGGish TF-Hub exporter
  uses YAMNet's feature computation, so the VGGish export will
  pad waveform input in the same way.
- Added a 1024-D embedding output to YAMNet, so we now produce
  predicted scores, log mel spectrogram features, and embeddings
  to satisfy a variety of uses: class prediction, acoustic
  feature visualization, and semantic feature extraction.
- Simplified usage of YAMNet in inference mode. Instead of working
  around implicit batch size issues in the Model.predict() API, we
  simply __call__() the Model (see the usage sketch below).
- Switched inference.py to TF 2 and Eager execution.
- Updated the visualization notebook: now uses TF2/Eager and
  can be loaded and run in Google Colab.

* Responded to DAn's comments in https://github.com/tensorflow/models/pull/9092

- Merged spectrogram computation and framing into a single function
  that returns both spectrogram and framed features.
- Extended waveform padding to pad up to an integral number of hops
  in addition to the final STFT analysis window.
parent b7e9ad13
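
To make the simplified calling convention concrete, here is a minimal inference sketch (editorial illustration only; it assumes the yamnet/ model directory is importable and that the yamnet.h5 weights have been downloaded, and the output shapes in the comments follow the description above):

    import numpy as np
    import params
    import yamnet as yamnet_model

    yamnet = yamnet_model.yamnet_frames_model(params)
    yamnet.load_weights('yamnet.h5')

    # Any 1-D float32 waveform sampled at params.SAMPLE_RATE (16 kHz); short
    # inputs are now zero-padded internally to yield at least one patch.
    waveform = np.zeros(3 * params.SAMPLE_RATE, dtype=np.float32)

    # __call__() the Model directly instead of Model.predict(); three outputs.
    scores, embeddings, log_mel_spectrogram = yamnet(waveform)
    # scores: (num_patches, num_classes) class scores per patch
    # embeddings: (num_patches, 1024) embedding per patch
    # log_mel_spectrogram: (num_spectrogram_frames, num_mel_bins)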
@@ -19,8 +19,8 @@ import numpy as np
 import tensorflow as tf
-def waveform_to_log_mel_spectrogram(waveform, params):
-  """Compute log mel spectrogram of a 1-D waveform."""
+def waveform_to_log_mel_spectrogram_patches(waveform, params):
+  """Compute log mel spectrogram patches of a 1-D waveform."""
   with tf.name_scope('log_mel_features'):
     # waveform has shape [<# samples>]
@@ -51,29 +51,50 @@ def waveform_to_log_mel_spectrogram(waveform, params):
     log_mel_spectrogram = tf.math.log(mel_spectrogram + params.LOG_OFFSET)
     # log_mel_spectrogram has shape [<# STFT frames>, MEL_BANDS]
-    return log_mel_spectrogram
-def spectrogram_to_patches(spectrogram, params):
-  """Break up a spectrogram into a stack of fixed-size patches."""
-  with tf.name_scope('feature_patches'):
-    # Frame spectrogram (shape [<# STFT frames>, MEL_BANDS]) into patches
-    # (the input examples).
-    # Only complete frames are emitted, so if there is less than
-    # PATCH_WINDOW_SECONDS of waveform then nothing is emitted
-    # (to avoid this, zero-pad before processing).
-    hop_length_samples = int(
+    # Frame spectrogram (shape [<# STFT frames>, MEL_BANDS]) into patches (the
+    # input examples). Only complete frames are emitted, so if there is less
+    # than PATCH_WINDOW_SECONDS of waveform then nothing is emitted (to avoid
+    # this, zero-pad before processing).
+    spectrogram_hop_length_samples = int(
         round(params.SAMPLE_RATE * params.STFT_HOP_SECONDS))
-    spectrogram_sr = params.SAMPLE_RATE / hop_length_samples
+    spectrogram_sample_rate = params.SAMPLE_RATE / spectrogram_hop_length_samples
     patch_window_length_samples = int(
-        round(spectrogram_sr * params.PATCH_WINDOW_SECONDS))
+        round(spectrogram_sample_rate * params.PATCH_WINDOW_SECONDS))
     patch_hop_length_samples = int(
-        round(spectrogram_sr * params.PATCH_HOP_SECONDS))
+        round(spectrogram_sample_rate * params.PATCH_HOP_SECONDS))
     features = tf.signal.frame(
-        signal=spectrogram,
+        signal=log_mel_spectrogram,
         frame_length=patch_window_length_samples,
         frame_step=patch_hop_length_samples,
         axis=0)
     # features has shape [<# patches>, <# STFT frames in an patch>, MEL_BANDS]
-    return features
+    return log_mel_spectrogram, features
+def pad_waveform(waveform, params):
+  """Pads waveform with silence if needed to get an integral number of patches."""
+  # In order to produce one patch of log mel spectrogram input to YAMNet, we
+  # need at least one patch window length of waveform plus enough extra samples
+  # to complete the final STFT analysis window.
+  min_waveform_seconds = (
+      params.PATCH_WINDOW_SECONDS +
+      params.STFT_WINDOW_SECONDS - params.STFT_HOP_SECONDS)
+  min_num_samples = tf.cast(min_waveform_seconds * params.SAMPLE_RATE, tf.int32)
+  num_samples = tf.size(waveform)
+  num_padding_samples = tf.maximum(0, min_num_samples - num_samples)
+  # In addition, there might be enough waveform for one or more additional
+  # patches formed by hopping forward. If there are more samples than one patch,
+  # round up to an integral number of hops.
+  num_samples = tf.maximum(num_samples, min_num_samples)
+  num_samples_after_first_patch = num_samples - min_num_samples
+  hop_samples = tf.cast(params.PATCH_HOP_SECONDS * params.SAMPLE_RATE, tf.int32)
+  num_hops_after_first_patch = tf.cast(tf.math.ceil(
+      tf.math.divide(num_samples_after_first_patch, hop_samples)), tf.int32)
+  num_padding_samples += (
+      hop_samples * num_hops_after_first_patch - num_samples_after_first_patch)
+  padded_waveform = tf.pad(waveform, [[0, num_padding_samples]],
+                           mode='CONSTANT', constant_values=0.0)
+  return padded_waveform
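
As a worked example of the padding arithmetic above, the following plain-Python sketch (illustration only; it assumes YAMNet's default parameters of 16 kHz sample rate, 0.96 s patch window, 0.48 s patch hop, 25 ms STFT window, 10 ms STFT hop, and uses round() where the graph code above casts) computes the padded length:

    import math

    SAMPLE_RATE = 16000
    STFT_WINDOW_SECONDS = 0.025
    STFT_HOP_SECONDS = 0.010
    PATCH_WINDOW_SECONDS = 0.96
    PATCH_HOP_SECONDS = 0.48

    def padded_length(num_samples):
      # Minimum waveform for one patch: the patch window plus the extra samples
      # needed to complete the final STFT analysis window.
      min_samples = round(
          (PATCH_WINDOW_SECONDS + STFT_WINDOW_SECONDS - STFT_HOP_SECONDS)
          * SAMPLE_RATE)                                    # 15600 samples
      hop_samples = round(PATCH_HOP_SECONDS * SAMPLE_RATE)  # 7680 samples
      n = max(num_samples, min_samples)
      # Round whatever extends beyond the first patch up to whole patch hops.
      extra = n - min_samples
      return min_samples + hop_samples * math.ceil(extra / hop_samples)

    print(padded_length(8000))   # 15600: a 0.5 s clip is padded up to one patch
    print(padded_length(16000))  # 23280: 15600 plus one extra 7680-sample hop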
@@ -28,12 +28,10 @@ import yamnet as yamnet_model
 def main(argv):
-  assert argv
+  assert argv, 'Usage: inference.py <wav file> <wav file> ...'
-  graph = tf.Graph()
-  with graph.as_default():
-    yamnet = yamnet_model.yamnet_frames_model(params)
-    yamnet.load_weights('yamnet.h5')
+  yamnet = yamnet_model.yamnet_frames_model(params)
+  yamnet.load_weights('yamnet.h5')
   yamnet_classes = yamnet_model.class_names('yamnet_class_map.csv')
   for file_name in argv:
@@ -41,6 +39,7 @@ def main(argv):
     wav_data, sr = sf.read(file_name, dtype=np.int16)
     assert wav_data.dtype == np.int16, 'Bad sample type: %r' % wav_data.dtype
     waveform = wav_data / 32768.0  # Convert to [-1.0, +1.0]
+    waveform = waveform.astype('float32')
     # Convert to mono and the sample rate expected by YAMNet.
     if len(waveform.shape) > 1:
@@ -49,16 +48,13 @@ def main(argv):
       waveform = resampy.resample(waveform, sr, params.SAMPLE_RATE)
     # Predict YAMNet classes.
-    # Second output is log-mel-spectrogram array (used for visualizations).
-    # (steps=1 is a work around for Keras batching limitations.)
-    with graph.as_default():
-      scores, _ = yamnet.predict(np.reshape(waveform, [1, -1]), steps=1)
+    scores, embeddings, spectrogram = yamnet(waveform)
     # Scores is a matrix of (time_frames, num_classes) classifier scores.
     # Average them along time to get an overall classifier output for the clip.
     prediction = np.mean(scores, axis=0)
     # Report the highest-scoring classes and their scores.
     top5_i = np.argsort(prediction)[::-1][:5]
-    print(file_name, ':\n' +
+    print(file_name, ':\n' +
           '\n'.join('  {:12s}: {:.3f}'.format(yamnet_classes[i], prediction[i])
                     for i in top5_i))
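
Beyond class prediction, the new embeddings output covers the semantic feature extraction use case mentioned in the commit message. A minimal sketch of one common pattern (the per-clip averaging and cosine similarity here are illustrative choices, not anything prescribed by the model):

    import numpy as np

    def clip_embedding(embeddings):
      # embeddings: the model's second output, shape (num_patches, 1024), one
      # vector per 0.96 s patch hopped every 0.48 s. Averaging over patches is
      # one simple way to get a clip-level semantic feature vector.
      return np.asarray(embeddings).mean(axis=0)

    def cosine_similarity(a, b):
      # Similarity between two clip-level embeddings, e.g. for retrieval.
      return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

    # Usage, continuing the per-file loop in inference.py above:
    #   _, embeddings, _ = yamnet(waveform)
    #   e = clip_embedding(embeddings)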
@@ -37,6 +37,3 @@ BATCHNORM_CENTER = True
 BATCHNORM_SCALE = False
 BATCHNORM_EPSILON = 1e-4
 CLASSIFIER_ACTIVATION = 'sigmoid'
-FEATURES_LAYER_NAME = 'features'
-EXAMPLE_PREDICTIONS_LAYER_NAME = 'predictions'
@@ -96,16 +96,14 @@ _YAMNET_LAYER_DEFS = [
 def yamnet(features):
   """Define the core YAMNet mode in Keras."""
   net = layers.Reshape(
-    (params.PATCH_FRAMES, params.PATCH_BANDS, 1),
-    input_shape=(params.PATCH_FRAMES, params.PATCH_BANDS))(features)
+      (params.PATCH_FRAMES, params.PATCH_BANDS, 1),
+      input_shape=(params.PATCH_FRAMES, params.PATCH_BANDS))(features)
   for (i, (layer_fun, kernel, stride, filters)) in enumerate(_YAMNET_LAYER_DEFS):
     net = layer_fun('layer{}'.format(i + 1), kernel, stride, filters)(net)
-  net = layers.GlobalAveragePooling2D()(net)
-  logits = layers.Dense(units=params.NUM_CLASSES, use_bias=True)(net)
-  predictions = layers.Activation(
-      name=params.EXAMPLE_PREDICTIONS_LAYER_NAME,
-      activation=params.CLASSIFIER_ACTIVATION)(logits)
-  return predictions
+  embeddings = layers.GlobalAveragePooling2D()(net)
+  logits = layers.Dense(units=params.NUM_CLASSES, use_bias=True)(embeddings)
+  predictions = layers.Activation(activation=params.CLASSIFIER_ACTIVATION)(logits)
+  return predictions, embeddings
 def yamnet_frames_model(feature_params):
@@ -116,19 +114,19 @@ def yamnet_frames_model(feature_params):
       calculation.
   Returns:
-    A model accepting (1, num_samples) waveform input and emitting a
-    (num_patches, num_classes) matrix of class scores per time frame as
-    well as a (num_spectrogram_frames, num_mel_bins) spectrogram feature
-    matrix.
+    A model accepting (num_samples,) waveform input and emitting:
+    - predictions: (num_patches, num_classes) matrix of class scores per time frame
+    - embeddings: (num_patches, embedding size) matrix of embeddings per time frame
+    - log_mel_spectrogram: (num_spectrogram_frames, num_mel_bins) spectrogram feature matrix
   """
-  waveform = layers.Input(batch_shape=(1, None))
-  # Store the intermediate spectrogram features to use in visualization.
-  spectrogram = features_lib.waveform_to_log_mel_spectrogram(
-      tf.squeeze(waveform, axis=0), feature_params)
-  patches = features_lib.spectrogram_to_patches(spectrogram, feature_params)
-  predictions = yamnet(patches)
-  frames_model = Model(name='yamnet_frames',
-                       inputs=waveform, outputs=[predictions, spectrogram])
+  waveform = layers.Input(batch_shape=(None,), dtype=tf.float32)
+  waveform_padded = features_lib.pad_waveform(waveform, feature_params)
+  log_mel_spectrogram, features = features_lib.waveform_to_log_mel_spectrogram_patches(
+      waveform_padded, feature_params)
+  predictions, embeddings = yamnet(features)
+  frames_model = Model(
+      name='yamnet_frames', inputs=waveform,
+      outputs=[predictions, embeddings, log_mel_spectrogram])
   return frames_model
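
One practical consequence of the new (num_samples,) input plus pad_waveform: a clip shorter than one 0.96 s patch now yields a full set of outputs instead of empty tensors. A quick check (sketch; assumes the yamnet.h5 weights are available locally):

    import numpy as np
    import params
    import yamnet as yamnet_model

    model = yamnet_model.yamnet_frames_model(params)
    model.load_weights('yamnet.h5')

    # 0.25 s of audio: shorter than one patch window, but the internal padding
    # guarantees at least one patch in scores, embeddings and the spectrogram.
    short_clip = np.zeros(int(0.25 * params.SAMPLE_RATE), dtype=np.float32)
    scores, embeddings, log_mel = model(short_clip)
    print(scores.shape[0], embeddings.shape[0])  # both >= 1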