Unverified Commit 9b179e8e authored by Manoj Plakal, committed by GitHub

Input/Output tweaks for YAMNet and VGGish. (#9092)

* Input/Output tweaks for YAMNet and VGGish.

- Waveform input for YAMNet is now padded so that we get at least
  one patch of log mel spectrogram. The VGGish TF-Hub exporter
  uses YAMNet's feature computation, so the VGGish export will
  also pad waveform input similarly.
- Added a 1024-D embedding output to YAMNet, so we now produce
  predicted scores, log mel spectrogram features, and embeddings,
  to satisfy a variety of uses: class prediction, acoustic
  feature visualization, and semantic feature extraction.
- Simplified usage of YAMNet in inference mode. Instead of trying
  to work around implicit batch size issues in the Model.predict()
  API, we simply __call__() the Model (see the usage sketch below).
- Switched inference.py to TF 2 and Eager execution.
- Updated the visualization notebook: now uses TF2/Eager and
  can be loaded and run in Google Colab.

* Responded to DAn's comments in https://github.com/tensorflow/models/pull/9092

- Merged spectrogram computation and framing into a single function
  that returns both spectrogram and framed features.
- Extended waveform padding to pad up to an integral number of hops
  in addition to the final STFT analysis window.
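
For reference, a minimal sketch of the new calling convention, mirroring the updated inference.py under TF 2.x eager execution. It assumes the YAMNet modules (params, yamnet) are importable as in the repo, that the released yamnet.h5 checkpoint and yamnet_class_map.csv are present, and that the caller supplies a mono float32 waveform at params.SAMPLE_RATE; the 3-second silent clip is just a placeholder input.

    import numpy as np

    import params
    import yamnet as yamnet_model

    # Build the model and load the released checkpoint and class map.
    yamnet = yamnet_model.yamnet_frames_model(params)
    yamnet.load_weights('yamnet.h5')
    class_names = yamnet_model.class_names('yamnet_class_map.csv')

    # Mono float32 waveform in [-1.0, +1.0] at params.SAMPLE_RATE (placeholder: 3 s of silence).
    waveform = np.zeros(int(3 * params.SAMPLE_RATE), dtype=np.float32)

    # Call the Model directly; no Model.predict() batching workarounds needed.
    scores, embeddings, log_mel_spectrogram = yamnet(waveform)
    # scores: (num_patches, num_classes), embeddings: (num_patches, 1024),
    # log_mel_spectrogram: (num_spectrogram_frames, num_mel_bins).
    prediction = scores.numpy().mean(axis=0)
    print('Top class:', class_names[prediction.argmax()])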
parent b7e9ad13
@@ -19,8 +19,8 @@ import numpy as np
 import tensorflow as tf


-def waveform_to_log_mel_spectrogram(waveform, params):
-  """Compute log mel spectrogram of a 1-D waveform."""
+def waveform_to_log_mel_spectrogram_patches(waveform, params):
+  """Compute log mel spectrogram patches of a 1-D waveform."""
   with tf.name_scope('log_mel_features'):
     # waveform has shape [<# samples>]
@@ -51,29 +51,50 @@ def waveform_to_log_mel_spectrogram(waveform, params):
     log_mel_spectrogram = tf.math.log(mel_spectrogram + params.LOG_OFFSET)
     # log_mel_spectrogram has shape [<# STFT frames>, MEL_BANDS]
-    return log_mel_spectrogram
-
-
-def spectrogram_to_patches(spectrogram, params):
-  """Break up a spectrogram into a stack of fixed-size patches."""
-  with tf.name_scope('feature_patches'):
-    # Frame spectrogram (shape [<# STFT frames>, MEL_BANDS]) into patches
-    # (the input examples).
-    # Only complete frames are emitted, so if there is less than
-    # PATCH_WINDOW_SECONDS of waveform then nothing is emitted
-    # (to avoid this, zero-pad before processing).
-    hop_length_samples = int(
+    # Frame spectrogram (shape [<# STFT frames>, MEL_BANDS]) into patches (the
+    # input examples). Only complete frames are emitted, so if there is less
+    # than PATCH_WINDOW_SECONDS of waveform then nothing is emitted (to avoid
+    # this, zero-pad before processing).
+    spectrogram_hop_length_samples = int(
         round(params.SAMPLE_RATE * params.STFT_HOP_SECONDS))
-    spectrogram_sr = params.SAMPLE_RATE / hop_length_samples
+    spectrogram_sample_rate = params.SAMPLE_RATE / spectrogram_hop_length_samples
     patch_window_length_samples = int(
-        round(spectrogram_sr * params.PATCH_WINDOW_SECONDS))
+        round(spectrogram_sample_rate * params.PATCH_WINDOW_SECONDS))
     patch_hop_length_samples = int(
-        round(spectrogram_sr * params.PATCH_HOP_SECONDS))
+        round(spectrogram_sample_rate * params.PATCH_HOP_SECONDS))
     features = tf.signal.frame(
-        signal=spectrogram,
+        signal=log_mel_spectrogram,
         frame_length=patch_window_length_samples,
         frame_step=patch_hop_length_samples,
         axis=0)
     # features has shape [<# patches>, <# STFT frames in a patch>, MEL_BANDS]
-    return features
+    return log_mel_spectrogram, features
+
+
+def pad_waveform(waveform, params):
+  """Pads waveform with silence if needed to get an integral number of patches."""
+  # In order to produce one patch of log mel spectrogram input to YAMNet, we
+  # need at least one patch window length of waveform plus enough extra samples
+  # to complete the final STFT analysis window.
+  min_waveform_seconds = (
+      params.PATCH_WINDOW_SECONDS +
+      params.STFT_WINDOW_SECONDS - params.STFT_HOP_SECONDS)
+  min_num_samples = tf.cast(min_waveform_seconds * params.SAMPLE_RATE, tf.int32)
+  num_samples = tf.size(waveform)
+  num_padding_samples = tf.maximum(0, min_num_samples - num_samples)
+  # In addition, there might be enough waveform for one or more additional
+  # patches formed by hopping forward. If there are more samples than one patch,
+  # round up to an integral number of hops.
+  num_samples = tf.maximum(num_samples, min_num_samples)
+  num_samples_after_first_patch = num_samples - min_num_samples
+  hop_samples = tf.cast(params.PATCH_HOP_SECONDS * params.SAMPLE_RATE, tf.int32)
+  num_hops_after_first_patch = tf.cast(tf.math.ceil(
+      tf.math.divide(num_samples_after_first_patch, hop_samples)), tf.int32)
+  num_padding_samples += (
+      hop_samples * num_hops_after_first_patch - num_samples_after_first_patch)
+  padded_waveform = tf.pad(waveform, [[0, num_padding_samples]],
+                           mode='CONSTANT', constant_values=0.0)
+  return padded_waveform
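
To make the padding arithmetic concrete, here is a small worked example. The numeric values (16 kHz sample rate, 25 ms / 10 ms STFT window and hop, 0.96 s / 0.48 s patch window and hop) are the usual defaults from params.py, which is not shown in this hunk, so treat them as assumptions:

    import math

    # Assumed defaults from params.py (not shown in this diff).
    sample_rate = 16000
    min_num_samples = int((0.96 + 0.025 - 0.010) * sample_rate)  # 15600: one patch plus the final STFT window
    hop_samples = int(0.48 * sample_rate)                        # 7680 samples per patch hop

    num_samples = 1 * sample_rate                                # a 1-second clip
    num_padding = max(0, min_num_samples - num_samples)          # 0: already longer than one patch
    after_first_patch = max(num_samples, min_num_samples) - min_num_samples  # 400
    num_padding += hop_samples * math.ceil(after_first_patch / hop_samples) - after_first_patch
    print(num_samples + num_padding)  # 23280 samples, i.e. exactly two patches after framing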
@@ -28,12 +28,10 @@ import yamnet as yamnet_model

 def main(argv):
-  assert argv
-
-  graph = tf.Graph()
-  with graph.as_default():
-    yamnet = yamnet_model.yamnet_frames_model(params)
-    yamnet.load_weights('yamnet.h5')
+  assert argv, 'Usage: inference.py <wav file> <wav file> ...'
+
+  yamnet = yamnet_model.yamnet_frames_model(params)
+  yamnet.load_weights('yamnet.h5')
   yamnet_classes = yamnet_model.class_names('yamnet_class_map.csv')

   for file_name in argv:
@@ -41,6 +39,7 @@ def main(argv):
     wav_data, sr = sf.read(file_name, dtype=np.int16)
     assert wav_data.dtype == np.int16, 'Bad sample type: %r' % wav_data.dtype
     waveform = wav_data / 32768.0  # Convert to [-1.0, +1.0]
+    waveform = waveform.astype('float32')
     # Convert to mono and the sample rate expected by YAMNet.
     if len(waveform.shape) > 1:
@@ -49,16 +48,13 @@ def main(argv):
       waveform = resampy.resample(waveform, sr, params.SAMPLE_RATE)

     # Predict YAMNet classes.
-    # Second output is log-mel-spectrogram array (used for visualizations).
-    # (steps=1 is a work around for Keras batching limitations.)
-    with graph.as_default():
-      scores, _ = yamnet.predict(np.reshape(waveform, [1, -1]), steps=1)
+    scores, embeddings, spectrogram = yamnet(waveform)
     # Scores is a matrix of (time_frames, num_classes) classifier scores.
     # Average them along time to get an overall classifier output for the clip.
     prediction = np.mean(scores, axis=0)
     # Report the highest-scoring classes and their scores.
     top5_i = np.argsort(prediction)[::-1][:5]
     print(file_name, ':\n' +
           '\n'.join('  {:12s}: {:.3f}'.format(yamnet_classes[i], prediction[i])
                     for i in top5_i))
@@ -37,6 +37,3 @@ BATCHNORM_CENTER = True
 BATCHNORM_SCALE = False
 BATCHNORM_EPSILON = 1e-4
 CLASSIFIER_ACTIVATION = 'sigmoid'
-
-FEATURES_LAYER_NAME = 'features'
-EXAMPLE_PREDICTIONS_LAYER_NAME = 'predictions'
@@ -96,16 +96,14 @@ _YAMNET_LAYER_DEFS = [

 def yamnet(features):
   """Define the core YAMNet model in Keras."""
   net = layers.Reshape(
       (params.PATCH_FRAMES, params.PATCH_BANDS, 1),
       input_shape=(params.PATCH_FRAMES, params.PATCH_BANDS))(features)
   for (i, (layer_fun, kernel, stride, filters)) in enumerate(_YAMNET_LAYER_DEFS):
     net = layer_fun('layer{}'.format(i + 1), kernel, stride, filters)(net)
-  net = layers.GlobalAveragePooling2D()(net)
-  logits = layers.Dense(units=params.NUM_CLASSES, use_bias=True)(net)
-  predictions = layers.Activation(
-      name=params.EXAMPLE_PREDICTIONS_LAYER_NAME,
-      activation=params.CLASSIFIER_ACTIVATION)(logits)
-  return predictions
+  embeddings = layers.GlobalAveragePooling2D()(net)
+  logits = layers.Dense(units=params.NUM_CLASSES, use_bias=True)(embeddings)
+  predictions = layers.Activation(activation=params.CLASSIFIER_ACTIVATION)(logits)
+  return predictions, embeddings


 def yamnet_frames_model(feature_params):
@@ -116,19 +114,19 @@ def yamnet_frames_model(feature_params):
       calculation.

   Returns:
-    A model accepting (1, num_samples) waveform input and emitting a
-    (num_patches, num_classes) matrix of class scores per time frame as
-    well as a (num_spectrogram_frames, num_mel_bins) spectrogram feature
-    matrix.
+    A model accepting (num_samples,) waveform input and emitting:
+    - predictions: (num_patches, num_classes) matrix of class scores per time frame
+    - embeddings: (num_patches, embedding size) matrix of embeddings per time frame
+    - log_mel_spectrogram: (num_spectrogram_frames, num_mel_bins) spectrogram feature matrix
   """
-  waveform = layers.Input(batch_shape=(1, None))
-  # Store the intermediate spectrogram features to use in visualization.
-  spectrogram = features_lib.waveform_to_log_mel_spectrogram(
-      tf.squeeze(waveform, axis=0), feature_params)
-  patches = features_lib.spectrogram_to_patches(spectrogram, feature_params)
-  predictions = yamnet(patches)
-  frames_model = Model(name='yamnet_frames',
-                       inputs=waveform, outputs=[predictions, spectrogram])
+  waveform = layers.Input(batch_shape=(None,), dtype=tf.float32)
+  waveform_padded = features_lib.pad_waveform(waveform, feature_params)
+  log_mel_spectrogram, features = features_lib.waveform_to_log_mel_spectrogram_patches(
+      waveform_padded, feature_params)
+  predictions, embeddings = yamnet(features)
+  frames_model = Model(
+      name='yamnet_frames', inputs=waveform,
+      outputs=[predictions, embeddings, log_mel_spectrogram])
   return frames_model
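
Because pad_waveform is now applied inside the model, even inputs shorter than one patch window produce output. A quick sanity-check sketch (assumes the same modules and default params as in the earlier sketches; no checkpoint is loaded, so the weights are random and only the output shapes are meaningful):

    import numpy as np

    import params
    import yamnet as yamnet_model

    model = yamnet_model.yamnet_frames_model(params)
    # 0.1 s of audio: shorter than one patch window, so the model pads it internally.
    short_clip = np.zeros(int(0.1 * params.SAMPLE_RATE), dtype=np.float32)
    scores, embeddings, log_mel = model(short_clip)
    print(scores.shape)      # (1, num_classes): padding guarantees at least one patch
    print(embeddings.shape)  # (1, 1024)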
This source diff could not be displayed because it is too large.