Unverified Commit 8da48573 authored by Manoj Plakal's avatar Manoj Plakal Committed by GitHub
Browse files

Added TF-Lite-compatible feature extractor and model exporter for YAMNet (#9098)

* Added TF-Lite-compatible feature extractor and model exporter for YAMNet.

- Added a TF-Lite compatible feature extractor. With the latest TF-Lite,
  that involves a DFT-multiplication replacement for tf.abs(tf.signal.stft())
  and not a lot else. Note that TF-Lite now allows variable-length inputs.
- Added a YAMNet exporter that produces TF2 SavedModels, TF-Lite models,
  and TF-JS models.
- Cleanups: switched hyperparameters to a dataclass, got rid of
  some lingering cruft in yamnet_test.

* Responded to DAn's comments in https://github.com/tensorflow/models/pull/9098

- Switched some hparams to float
- Made class map asset available on the exported model, and tested that
  it can be loaded from the various exports.
parent ea5fc64d
...@@ -5,8 +5,10 @@ containing an audio waveform (assumed to be mono 16 kHz samples in the [-1, +1] ...@@ -5,8 +5,10 @@ containing an audio waveform (assumed to be mono 16 kHz samples in the [-1, +1]
range) and returns a 2-d float32 batch of 128-d VGGish embeddings, one per range) and returns a 2-d float32 batch of 128-d VGGish embeddings, one per
0.96s example generated from the waveform. 0.96s example generated from the waveform.
Requires pip-installing tensorflow_hub.
Usage: Usage:
export_tfhub.py <path/to/VGGish/checkpoint> <path/to/tfhub/export> vggish_export_tfhub.py <path/to/VGGish/checkpoint> <path/to/tfhub/export>
""" """
import sys import sys
...@@ -41,19 +43,19 @@ def vggish_definer(variables, checkpoint_path): ...@@ -41,19 +43,19 @@ def vggish_definer(variables, checkpoint_path):
def waveform_to_features(waveform): def waveform_to_features(waveform):
"""Creates VGGish features using the YAMNet feature extractor.""" """Creates VGGish features using the YAMNet feature extractor."""
yamnet_params.SAMPLE_RATE = vggish_params.SAMPLE_RATE params = yamnet_params.Params(
yamnet_params.STFT_WINDOW_SECONDS = vggish_params.STFT_WINDOW_LENGTH_SECONDS sample_rate=vggish_params.SAMPLE_RATE,
yamnet_params.STFT_HOP_SECONDS = vggish_params.STFT_HOP_LENGTH_SECONDS stft_window_seconds=vggish_params.STFT_WINDOW_LENGTH_SECONDS,
yamnet_params.MEL_BANDS = vggish_params.NUM_MEL_BINS stft_hop_seconds=vggish_params.STFT_HOP_LENGTH_SECONDS,
yamnet_params.MEL_MIN_HZ = vggish_params.MEL_MIN_HZ mel_bands=vggish_params.NUM_MEL_BINS,
yamnet_params.MEL_MAX_HZ = vggish_params.MEL_MAX_HZ mel_min_hz=vggish_params.MEL_MIN_HZ,
yamnet_params.LOG_OFFSET = vggish_params.LOG_OFFSET mel_max_hz=vggish_params.MEL_MAX_HZ,
yamnet_params.PATCH_WINDOW_SECONDS = vggish_params.EXAMPLE_WINDOW_SECONDS log_offset=vggish_params.LOG_OFFSET,
yamnet_params.PATCH_HOP_SECONDS = vggish_params.EXAMPLE_HOP_SECONDS patch_window_seconds=vggish_params.EXAMPLE_WINDOW_SECONDS,
log_mel_spectrogram = yamnet_features.waveform_to_log_mel_spectrogram( patch_hop_seconds=vggish_params.EXAMPLE_HOP_SECONDS)
waveform, yamnet_params) log_mel_spectrogram, features = yamnet_features.waveform_to_log_mel_spectrogram_patches(
return yamnet_features.spectrogram_to_patches( waveform, params)
log_mel_spectrogram, yamnet_params) return features
def define_vggish(waveform): def define_vggish(waveform):
with tf.variable_creator_scope(var_tracker): with tf.variable_creator_scope(var_tracker):
......
"""Exports YAMNet as: TF2 SavedModel, TF-Lite model, TF-JS model.
The exported models all accept as input:
- 1-d float32 Tensor of arbitrary shape containing an audio waveform
(assumed to be mono 16 kHz samples in the [-1, +1] range)
and return as output:
- a 2-d float32 Tensor of shape [num_frames, num_classes] containing
predicted class scores for each frame of audio extracted from the input.
- a 2-d float32 Tensor of shape [num_frames, embedding_size] containing
embeddings of each frame of audio.
- a 2-d float32 Tensor of shape [num_spectrogram_frames, num_mel_bins]
containing the log mel spectrogram of the entire waveform.
The SavedModels will also contain (as an asset) a class map CSV file that maps
class indices to AudioSet class names and Freebase MIDs. The path to the class
map is available as the 'class_map_path()' method of the restored model.
Requires pip-installing tensorflow_hub and tensorflowjs.
Usage:
export.py <path/to/YAMNet/weights-hdf-file> <path/to/output/directory>
and the various exports will be created in subdirectories of the output directory.
Assumes that it will be run in the yamnet source directory from where it loads
the class map. Skips an export if the corresponding directory already exists.
"""
import os
import sys
import tempfile
import time
import numpy as np
import tensorflow as tf
# Require TF 2.x. Compare the major version numerically: a plain string
# comparison would misorder versions (e.g. '10.0.0' < '2.0.0' lexically).
assert int(tf.version.VERSION.split('.')[0]) >= 2, (
    'Need at least TF 2.0, you have TF v{}'.format(tf.version.VERSION))
import tensorflow_hub as tfhub
from tensorflowjs.converters import tf_saved_model_conversion_v2 as tfjs_saved_model_converter
import params as yamnet_params
import yamnet
def log(msg):
  """Print a timestamped status message framed by separator lines."""
  banner = '\n=====\n{} | {}\n=====\n'
  print(banner.format(time.asctime(), msg), flush=True)
class YAMNet(tf.Module):
  """A TF2 Module wrapper around YAMNet."""
  # (Fixed the docstring delimiter typo: it previously opened with "'' which
  # parsed as two concatenated strings rather than a triple-quoted docstring.)

  def __init__(self, weights_path, params):
    super().__init__()
    self._yamnet = yamnet.yamnet_frames_model(params)
    self._yamnet.load_weights(weights_path)
    # Expose the class map CSV as a SavedModel asset so that it is bundled
    # with every export and reachable via class_map_path().
    self._class_map_asset = tf.saved_model.Asset('yamnet_class_map.csv')

  @tf.function
  def class_map_path(self):
    # Path to the bundled class map CSV asset.
    return self._class_map_asset.asset_path

  @tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.float32),))
  def __call__(self, waveform):
    # Accepts a 1-d float32 waveform of arbitrary length and returns the
    # wrapped model's outputs (predictions, embeddings, log mel spectrogram).
    return self._yamnet(waveform)
def check_model(model_fn, class_map_path, params):
  """Applies yamnet_test's sanity checks to an instance of YAMNet.

  Args:
    model_fn: callable mapping a 1-d float32 waveform to
      (predictions, embeddings, log_mel_spectrogram).
    class_map_path: path to the class map CSV used to decode class indices.
    params: a Params instance; only sample_rate is read here.

  Raises:
    AssertionError: if an expected class is missing from the top-n
      predictions for one of the probe waveforms.
  """
  # NOTE: the docstring above was previously placed after this statement,
  # where it was a discarded expression rather than the function docstring.
  yamnet_classes = yamnet.class_names(class_map_path)

  def clip_test(waveform, expected_class_name, top_n=10):
    # Runs the model on one clip and checks the expected class is in the
    # top-n of the clip-level (time-averaged) scores.
    predictions, embeddings, log_mel_spectrogram = model_fn(waveform)
    clip_predictions = np.mean(predictions, axis=0)
    top_n_indices = np.argsort(clip_predictions)[-top_n:]
    top_n_scores = clip_predictions[top_n_indices]
    top_n_class_names = yamnet_classes[top_n_indices]
    top_n_predictions = list(zip(top_n_class_names, top_n_scores))
    assert expected_class_name in top_n_class_names, (
        'Did not find expected class {} in top {} predictions: {}'.format(
            expected_class_name, top_n, top_n_predictions))

  # Three seconds of silence should score as 'Silence'.
  clip_test(
      waveform=np.zeros((int(3 * params.sample_rate),), dtype=np.float32),
      expected_class_name='Silence')

  np.random.seed(51773)  # Ensure repeatability.
  clip_test(
      waveform=np.random.uniform(-1.0, +1.0,
                                 (int(3 * params.sample_rate),)).astype(np.float32),
      expected_class_name='White noise')

  # A 440 Hz tone should score as 'Sine wave'.
  clip_test(
      waveform=np.sin(2 * np.pi * 440 *
                      np.arange(0, 3, 1 / params.sample_rate), dtype=np.float32),
      expected_class_name='Sine wave')
def make_tf2_export(weights_path, export_dir):
  """Exports YAMNet as a TF2 SavedModel and checks it via TF-Hub in TF2/TF1.

  Skips the export entirely if export_dir already exists.

  Args:
    weights_path: path to the YAMNet weights HDF file.
    export_dir: directory to write the TF2 SavedModel into.
  """
  if os.path.exists(export_dir):
    log('TF2 export already exists in {}, skipping TF2 export'.format(
        export_dir))
    return

  # Create a TF2 Module wrapper around YAMNet.
  log('Building and checking TF2 Module ...')
  params = yamnet_params.Params()
  # Named yamnet_model (not yamnet) so we don't shadow the imported module.
  yamnet_model = YAMNet(weights_path, params)
  check_model(yamnet_model, yamnet_model.class_map_path(), params)
  log('Done')

  # Make TF2 SavedModel export.
  log('Making TF2 SavedModel export ...')
  tf.saved_model.save(yamnet_model, export_dir)
  log('Done')

  # Check export with TF-Hub in TF2.
  log('Checking TF2 SavedModel export in TF2 ...')
  model = tfhub.load(export_dir)
  check_model(model, model.class_map_path(), params)
  log('Done')

  # Check export with TF-Hub in TF1.
  log('Checking TF2 SavedModel export in TF1 ...')
  with tf.compat.v1.Graph().as_default(), tf.compat.v1.Session() as sess:
    model = tfhub.load(export_dir)
    sess.run(tf.compat.v1.global_variables_initializer())
    def run_model(waveform):
      return sess.run(model(waveform))
    # class_map_path() is a tensor in TF1 graph mode, hence .eval().
    check_model(run_model, model.class_map_path().eval(), params)
  log('Done')
def make_tflite_export(weights_path, export_dir):
  """Exports YAMNet as a TF-Lite-compatible SavedModel and a .tflite model.

  Skips the export entirely if export_dir already exists.

  Args:
    weights_path: path to the YAMNet weights HDF file.
    export_dir: directory for the exports.

  Returns:
    Path to the TF-Lite-compatible SavedModel subdirectory, or None if the
    export was skipped.
  """
  if os.path.exists(export_dir):
    log('TF-Lite export already exists in {}, skipping TF-Lite export'.format(
        export_dir))
    return

  # Create a TF-Lite compatible Module wrapper around YAMNet.
  log('Building and checking TF-Lite Module ...')
  params = yamnet_params.Params(tflite_compatible=True)
  # Named yamnet_model (not yamnet) so we don't shadow the imported module.
  yamnet_model = YAMNet(weights_path, params)
  check_model(yamnet_model, yamnet_model.class_map_path(), params)
  log('Done')

  # Make TF-Lite SavedModel export.
  log('Making TF-Lite SavedModel export ...')
  saved_model_dir = os.path.join(export_dir, 'saved_model')
  os.makedirs(saved_model_dir)
  tf.saved_model.save(yamnet_model, saved_model_dir)
  log('Done')

  # Check that the export can be loaded and works.
  log('Checking TF-Lite SavedModel export in TF2 ...')
  model = tf.saved_model.load(saved_model_dir)
  check_model(model, model.class_map_path(), params)
  log('Done')

  # Make a TF-Lite model from the SavedModel.
  log('Making TF-Lite model ...')
  tflite_converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
  tflite_model = tflite_converter.convert()
  tflite_model_path = os.path.join(export_dir, 'yamnet.tflite')
  with open(tflite_model_path, 'wb') as f:
    f.write(tflite_model)
  log('Done')

  # Check the TF-Lite export.
  log('Checking TF-Lite model ...')
  interpreter = tf.lite.Interpreter(tflite_model_path)
  audio_input_index = interpreter.get_input_details()[0]['index']
  # NOTE(review): assumes the converter preserves the model's output order
  # [scores, embeddings, spectrogram] — confirm against the converted model.
  scores_output_index = interpreter.get_output_details()[0]['index']
  embeddings_output_index = interpreter.get_output_details()[1]['index']
  spectrogram_output_index = interpreter.get_output_details()[2]['index']
  def run_model(waveform):
    # TF-Lite supports variable-length input: resize per waveform.
    interpreter.resize_tensor_input(audio_input_index, [len(waveform)], strict=True)
    interpreter.allocate_tensors()
    interpreter.set_tensor(audio_input_index, waveform)
    interpreter.invoke()
    return (interpreter.get_tensor(scores_output_index),
            interpreter.get_tensor(embeddings_output_index),
            interpreter.get_tensor(spectrogram_output_index))
  check_model(run_model, 'yamnet_class_map.csv', params)
  log('Done')

  return saved_model_dir
def make_tfjs_export(tflite_saved_model_dir, export_dir):
  """Converts the TF-Lite SavedModel export into a TF-JS model."""
  if not os.path.exists(export_dir):
    # Make a TF-JS model from the TF-Lite SavedModel export.
    log('Making TF-JS model ...')
    os.makedirs(export_dir)
    tfjs_saved_model_converter.convert_tf_saved_model(
        tflite_saved_model_dir, export_dir)
    log('Done')
  else:
    log('TF-JS export already exists in {}, skipping TF-JS export'.format(
        export_dir))
def main(args):
  """Drives the TF2, TF-Lite, and TF-JS exports into output subdirectories.

  Args:
    args: [weights_hdf_path, output_directory].
  """
  weights_path = args[0]
  output_dir = args[1]

  make_tf2_export(weights_path, os.path.join(output_dir, 'tf2'))

  # The TF-JS export is converted from the TF-Lite SavedModel export.
  tflite_saved_model_dir = make_tflite_export(
      weights_path, os.path.join(output_dir, 'tflite'))
  make_tfjs_export(tflite_saved_model_dir, os.path.join(output_dir, 'tfjs'))
# Script entry point: argv[1] is the weights HDF file, argv[2] the output dir.
if __name__ == '__main__':
  main(sys.argv[1:])
...@@ -27,47 +27,54 @@ def waveform_to_log_mel_spectrogram_patches(waveform, params): ...@@ -27,47 +27,54 @@ def waveform_to_log_mel_spectrogram_patches(waveform, params):
# Convert waveform into spectrogram using a Short-Time Fourier Transform. # Convert waveform into spectrogram using a Short-Time Fourier Transform.
# Note that tf.signal.stft() uses a periodic Hann window by default. # Note that tf.signal.stft() uses a periodic Hann window by default.
window_length_samples = int( window_length_samples = int(
round(params.SAMPLE_RATE * params.STFT_WINDOW_SECONDS)) round(params.sample_rate * params.stft_window_seconds))
hop_length_samples = int( hop_length_samples = int(
round(params.SAMPLE_RATE * params.STFT_HOP_SECONDS)) round(params.sample_rate * params.stft_hop_seconds))
fft_length = 2 ** int(np.ceil(np.log(window_length_samples) / np.log(2.0))) fft_length = 2 ** int(np.ceil(np.log(window_length_samples) / np.log(2.0)))
num_spectrogram_bins = fft_length // 2 + 1 num_spectrogram_bins = fft_length // 2 + 1
magnitude_spectrogram = tf.abs(tf.signal.stft( if params.tflite_compatible:
signals=waveform, magnitude_spectrogram = _tflite_stft_magnitude(
frame_length=window_length_samples, signal=waveform,
frame_step=hop_length_samples, frame_length=window_length_samples,
fft_length=fft_length)) frame_step=hop_length_samples,
fft_length=fft_length)
else:
magnitude_spectrogram = tf.abs(tf.signal.stft(
signals=waveform,
frame_length=window_length_samples,
frame_step=hop_length_samples,
fft_length=fft_length))
# magnitude_spectrogram has shape [<# STFT frames>, num_spectrogram_bins] # magnitude_spectrogram has shape [<# STFT frames>, num_spectrogram_bins]
# Convert spectrogram into log mel spectrogram. # Convert spectrogram into log mel spectrogram.
linear_to_mel_weight_matrix = tf.signal.linear_to_mel_weight_matrix( linear_to_mel_weight_matrix = tf.signal.linear_to_mel_weight_matrix(
num_mel_bins=params.MEL_BANDS, num_mel_bins=params.mel_bands,
num_spectrogram_bins=num_spectrogram_bins, num_spectrogram_bins=num_spectrogram_bins,
sample_rate=params.SAMPLE_RATE, sample_rate=params.sample_rate,
lower_edge_hertz=params.MEL_MIN_HZ, lower_edge_hertz=params.mel_min_hz,
upper_edge_hertz=params.MEL_MAX_HZ) upper_edge_hertz=params.mel_max_hz)
mel_spectrogram = tf.matmul( mel_spectrogram = tf.matmul(
magnitude_spectrogram, linear_to_mel_weight_matrix) magnitude_spectrogram, linear_to_mel_weight_matrix)
log_mel_spectrogram = tf.math.log(mel_spectrogram + params.LOG_OFFSET) log_mel_spectrogram = tf.math.log(mel_spectrogram + params.log_offset)
# log_mel_spectrogram has shape [<# STFT frames>, MEL_BANDS] # log_mel_spectrogram has shape [<# STFT frames>, params.mel_bands]
# Frame spectrogram (shape [<# STFT frames>, MEL_BANDS]) into patches (the # Frame spectrogram (shape [<# STFT frames>, params.mel_bands]) into patches
# input examples). Only complete frames are emitted, so if there is less # (the input examples). Only complete frames are emitted, so if there is
# than PATCH_WINDOW_SECONDS of waveform then nothing is emitted (to avoid # less than params.patch_window_seconds of waveform then nothing is emitted
# this, zero-pad before processing). # (to avoid this, zero-pad before processing).
spectrogram_hop_length_samples = int( spectrogram_hop_length_samples = int(
round(params.SAMPLE_RATE * params.STFT_HOP_SECONDS)) round(params.sample_rate * params.stft_hop_seconds))
spectrogram_sample_rate = params.SAMPLE_RATE / spectrogram_hop_length_samples spectrogram_sample_rate = params.sample_rate / spectrogram_hop_length_samples
patch_window_length_samples = int( patch_window_length_samples = int(
round(spectrogram_sample_rate * params.PATCH_WINDOW_SECONDS)) round(spectrogram_sample_rate * params.patch_window_seconds))
patch_hop_length_samples = int( patch_hop_length_samples = int(
round(spectrogram_sample_rate * params.PATCH_HOP_SECONDS)) round(spectrogram_sample_rate * params.patch_hop_seconds))
features = tf.signal.frame( features = tf.signal.frame(
signal=log_mel_spectrogram, signal=log_mel_spectrogram,
frame_length=patch_window_length_samples, frame_length=patch_window_length_samples,
frame_step=patch_hop_length_samples, frame_step=patch_hop_length_samples,
axis=0) axis=0)
# features has shape [<# patches>, <# STFT frames in an patch>, MEL_BANDS] # features has shape [<# patches>, <# STFT frames in an patch>, params.mel_bands]
return log_mel_spectrogram, features return log_mel_spectrogram, features
...@@ -78,10 +85,10 @@ def pad_waveform(waveform, params): ...@@ -78,10 +85,10 @@ def pad_waveform(waveform, params):
# need at least one patch window length of waveform plus enough extra samples # need at least one patch window length of waveform plus enough extra samples
# to complete the final STFT analysis window. # to complete the final STFT analysis window.
min_waveform_seconds = ( min_waveform_seconds = (
params.PATCH_WINDOW_SECONDS + params.patch_window_seconds +
params.STFT_WINDOW_SECONDS - params.STFT_HOP_SECONDS) params.stft_window_seconds - params.stft_hop_seconds)
min_num_samples = tf.cast(min_waveform_seconds * params.SAMPLE_RATE, tf.int32) min_num_samples = tf.cast(min_waveform_seconds * params.sample_rate, tf.int32)
num_samples = tf.size(waveform) num_samples = tf.shape(waveform)[0]
num_padding_samples = tf.maximum(0, min_num_samples - num_samples) num_padding_samples = tf.maximum(0, min_num_samples - num_samples)
# In addition, there might be enough waveform for one or more additional # In addition, there might be enough waveform for one or more additional
...@@ -89,12 +96,70 @@ def pad_waveform(waveform, params): ...@@ -89,12 +96,70 @@ def pad_waveform(waveform, params):
# round up to an integral number of hops. # round up to an integral number of hops.
num_samples = tf.maximum(num_samples, min_num_samples) num_samples = tf.maximum(num_samples, min_num_samples)
num_samples_after_first_patch = num_samples - min_num_samples num_samples_after_first_patch = num_samples - min_num_samples
hop_samples = tf.cast(params.PATCH_HOP_SECONDS * params.SAMPLE_RATE, tf.int32) hop_samples = tf.cast(params.patch_hop_seconds * params.sample_rate, tf.int32)
num_hops_after_first_patch = tf.cast(tf.math.ceil( num_hops_after_first_patch = tf.cast(tf.math.ceil(
tf.math.divide(num_samples_after_first_patch, hop_samples)), tf.int32) tf.cast(num_samples_after_first_patch, tf.float32) /
tf.cast(hop_samples, tf.float32)), tf.int32)
num_padding_samples += ( num_padding_samples += (
hop_samples * num_hops_after_first_patch - num_samples_after_first_patch) hop_samples * num_hops_after_first_patch - num_samples_after_first_patch)
padded_waveform = tf.pad(waveform, [[0, num_padding_samples]], padded_waveform = tf.pad(waveform, [[0, num_padding_samples]],
mode='CONSTANT', constant_values=0.0) mode='CONSTANT', constant_values=0.0)
return padded_waveform return padded_waveform
def _tflite_stft_magnitude(signal, frame_length, frame_step, fft_length):
  """TF-Lite-compatible version of tf.abs(tf.signal.stft()).

  Uses only ops TF-Lite supports: framing, an explicit Hann-window multiply,
  and real/imaginary DFTs expressed as matmuls against constant DFT matrices
  (TF-Lite lacks the complex-valued FFT ops behind tf.signal.stft).

  Args:
    signal: 1-d float32 waveform tensor.
    frame_length: analysis window length in samples.
    frame_step: hop between successive windows in samples.
    fft_length: DFT size; frames are zero-padded to this length.

  Returns:
    float32 tensor of shape [<# frames>, fft_length // 2 + 1] holding the
    STFT magnitude.
  """
  def _hann_window():
    # Periodic Hann window precomputed in NumPy and embedded as a constant,
    # shaped [1, frame_length] so it broadcasts over all frames.
    return tf.reshape(
      tf.constant(
          (0.5 - 0.5 * np.cos(2 * np.pi * np.arange(0, 1.0, 1.0 / frame_length))
          ).astype(np.float32),
          name='hann_window'), [1, frame_length])

  def _dft_matrix(dft_length):
    """Calculate the full DFT matrix in NumPy."""
    # See https://en.wikipedia.org/wiki/DFT_matrix
    omega = (0 + 1j) * 2.0 * np.pi / float(dft_length)
    # Don't include 1/sqrt(N) scaling, tf.signal.rfft doesn't apply it.
    return np.exp(omega * np.outer(np.arange(dft_length), np.arange(dft_length)))

  def _rdft(framed_signal, fft_length):
    """Implement real-input Discrete Fourier Transform by matmul."""
    # We are right-multiplying by the DFT matrix, and we are keeping only the
    # first half ("positive frequencies"). So discard the second half of rows,
    # but transpose the array for right-multiplication. The DFT matrix is
    # symmetric, so we could have done it more directly, but this reflects our
    # intention better.
    complex_dft_matrix_kept_values = _dft_matrix(fft_length)[:(
        fft_length // 2 + 1), :].transpose()
    real_dft_matrix = tf.constant(
        np.real(complex_dft_matrix_kept_values).astype(np.float32),
        name='real_dft_matrix')
    imag_dft_matrix = tf.constant(
        np.imag(complex_dft_matrix_kept_values).astype(np.float32),
        name='imaginary_dft_matrix')
    signal_frame_length = tf.shape(framed_signal)[-1]
    half_pad = (fft_length - signal_frame_length) // 2
    # NOTE(review): padding is centered within the FFT buffer (tf.signal.rfft
    # pads at the end); that is a pure time shift per frame, which changes
    # phase but not the magnitude returned here.
    padded_frames = tf.pad(
        framed_signal,
        [
            # Don't add any padding in the frame dimension.
            [0, 0],
            # Pad before and after the signal within each frame.
            [half_pad, fft_length - signal_frame_length - half_pad]
        ],
        mode='CONSTANT',
        constant_values=0.0)
    real_stft = tf.matmul(padded_frames, real_dft_matrix)
    imag_stft = tf.matmul(padded_frames, imag_dft_matrix)
    return real_stft, imag_stft

  def _complex_abs(real, imag):
    # Element-wise complex magnitude: sqrt(re^2 + im^2).
    return tf.sqrt(tf.add(real * real, imag * imag))

  framed_signal = tf.signal.frame(signal, frame_length, frame_step)
  windowed_signal = framed_signal * _hann_window()
  real_stft, imag_stft = _rdft(windowed_signal, fft_length)
  stft_magnitude = _complex_abs(real_stft, imag_stft)
  return stft_magnitude
...@@ -23,13 +23,14 @@ import resampy ...@@ -23,13 +23,14 @@ import resampy
import soundfile as sf import soundfile as sf
import tensorflow as tf import tensorflow as tf
import params import params as yamnet_params
import yamnet as yamnet_model import yamnet as yamnet_model
def main(argv): def main(argv):
assert argv, 'Usage: inference.py <wav file> <wav file> ...' assert argv, 'Usage: inference.py <wav file> <wav file> ...'
params = yamnet_params.Params()
yamnet = yamnet_model.yamnet_frames_model(params) yamnet = yamnet_model.yamnet_frames_model(params)
yamnet.load_weights('yamnet.h5') yamnet.load_weights('yamnet.h5')
yamnet_classes = yamnet_model.class_names('yamnet_class_map.csv') yamnet_classes = yamnet_model.class_names('yamnet_class_map.csv')
...@@ -44,8 +45,8 @@ def main(argv): ...@@ -44,8 +45,8 @@ def main(argv):
# Convert to mono and the sample rate expected by YAMNet. # Convert to mono and the sample rate expected by YAMNet.
if len(waveform.shape) > 1: if len(waveform.shape) > 1:
waveform = np.mean(waveform, axis=1) waveform = np.mean(waveform, axis=1)
if sr != params.SAMPLE_RATE: if sr != params.sample_rate:
waveform = resampy.resample(waveform, sr, params.SAMPLE_RATE) waveform = resampy.resample(waveform, sr, params.sample_rate)
# Predict YAMNet classes. # Predict YAMNet classes.
scores, embeddings, spectrogram = yamnet(waveform) scores, embeddings, spectrogram = yamnet(waveform)
......
...@@ -15,25 +15,37 @@ ...@@ -15,25 +15,37 @@
"""Hyperparameters for YAMNet.""" """Hyperparameters for YAMNet."""
# The following hyperparameters (except PATCH_HOP_SECONDS) were used to train YAMNet, from dataclasses import dataclass
# The following hyperparameters (except patch_hop_seconds) were used to train YAMNet,
# so expect some variability in performance if you change these. The patch hop can # so expect some variability in performance if you change these. The patch hop can
# be changed arbitrarily: a smaller hop should give you more patches from the same # be changed arbitrarily: a smaller hop should give you more patches from the same
# clip and possibly better performance at a larger computational cost. # clip and possibly better performance at a larger computational cost.
SAMPLE_RATE = 16000 @dataclass(frozen=True) # Instances of this class are immutable.
STFT_WINDOW_SECONDS = 0.025 class Params:
STFT_HOP_SECONDS = 0.010 sample_rate: float = 16000.0
MEL_BANDS = 64 stft_window_seconds: float = 0.025
MEL_MIN_HZ = 125 stft_hop_seconds: float = 0.010
MEL_MAX_HZ = 7500 mel_bands: int = 64
LOG_OFFSET = 0.001 mel_min_hz: float = 125.0
PATCH_WINDOW_SECONDS = 0.96 mel_max_hz: float = 7500.0
PATCH_HOP_SECONDS = 0.48 log_offset: float = 0.001
patch_window_seconds: float = 0.96
patch_hop_seconds: float = 0.48
@property
def patch_frames(self):
return int(round(self.patch_window_seconds / self.stft_hop_seconds))
@property
def patch_bands(self):
return self.mel_bands
num_classes: int = 521
conv_padding: str = 'same'
batchnorm_center: bool = True
batchnorm_scale: bool = False
batchnorm_epsilon: float = 1e-4
classifier_activation: str = 'sigmoid'
PATCH_FRAMES = int(round(PATCH_WINDOW_SECONDS / STFT_HOP_SECONDS)) tflite_compatible: bool = False
PATCH_BANDS = MEL_BANDS
NUM_CLASSES = 521
CONV_PADDING = 'same'
BATCHNORM_CENTER = True
BATCHNORM_SCALE = False
BATCHNORM_EPSILON = 1e-4
CLASSIFIER_ACTIVATION = 'sigmoid'
...@@ -22,53 +22,52 @@ import tensorflow as tf ...@@ -22,53 +22,52 @@ import tensorflow as tf
from tensorflow.keras import Model, layers from tensorflow.keras import Model, layers
import features as features_lib import features as features_lib
import params
def _batch_norm(name): def _batch_norm(name, params):
def _bn_layer(layer_input): def _bn_layer(layer_input):
return layers.BatchNormalization( return layers.BatchNormalization(
name=name, name=name,
center=params.BATCHNORM_CENTER, center=params.batchnorm_center,
scale=params.BATCHNORM_SCALE, scale=params.batchnorm_scale,
epsilon=params.BATCHNORM_EPSILON)(layer_input) epsilon=params.batchnorm_epsilon)(layer_input)
return _bn_layer return _bn_layer
def _conv(name, kernel, stride, filters): def _conv(name, kernel, stride, filters, params):
def _conv_layer(layer_input): def _conv_layer(layer_input):
output = layers.Conv2D(name='{}/conv'.format(name), output = layers.Conv2D(name='{}/conv'.format(name),
filters=filters, filters=filters,
kernel_size=kernel, kernel_size=kernel,
strides=stride, strides=stride,
padding=params.CONV_PADDING, padding=params.conv_padding,
use_bias=False, use_bias=False,
activation=None)(layer_input) activation=None)(layer_input)
output = _batch_norm(name='{}/conv/bn'.format(name))(output) output = _batch_norm('{}/conv/bn'.format(name), params)(output)
output = layers.ReLU(name='{}/relu'.format(name))(output) output = layers.ReLU(name='{}/relu'.format(name))(output)
return output return output
return _conv_layer return _conv_layer
def _separable_conv(name, kernel, stride, filters): def _separable_conv(name, kernel, stride, filters, params):
def _separable_conv_layer(layer_input): def _separable_conv_layer(layer_input):
output = layers.DepthwiseConv2D(name='{}/depthwise_conv'.format(name), output = layers.DepthwiseConv2D(name='{}/depthwise_conv'.format(name),
kernel_size=kernel, kernel_size=kernel,
strides=stride, strides=stride,
depth_multiplier=1, depth_multiplier=1,
padding=params.CONV_PADDING, padding=params.conv_padding,
use_bias=False, use_bias=False,
activation=None)(layer_input) activation=None)(layer_input)
output = _batch_norm(name='{}/depthwise_conv/bn'.format(name))(output) output = _batch_norm('{}/depthwise_conv/bn'.format(name), params)(output)
output = layers.ReLU(name='{}/depthwise_conv/relu'.format(name))(output) output = layers.ReLU(name='{}/depthwise_conv/relu'.format(name))(output)
output = layers.Conv2D(name='{}/pointwise_conv'.format(name), output = layers.Conv2D(name='{}/pointwise_conv'.format(name),
filters=filters, filters=filters,
kernel_size=(1, 1), kernel_size=(1, 1),
strides=1, strides=1,
padding=params.CONV_PADDING, padding=params.conv_padding,
use_bias=False, use_bias=False,
activation=None)(output) activation=None)(output)
output = _batch_norm(name='{}/pointwise_conv/bn'.format(name))(output) output = _batch_norm('{}/pointwise_conv/bn'.format(name), params)(output)
output = layers.ReLU(name='{}/pointwise_conv/relu'.format(name))(output) output = layers.ReLU(name='{}/pointwise_conv/relu'.format(name))(output)
return output return output
return _separable_conv_layer return _separable_conv_layer
...@@ -93,25 +92,24 @@ _YAMNET_LAYER_DEFS = [ ...@@ -93,25 +92,24 @@ _YAMNET_LAYER_DEFS = [
] ]
def yamnet(features): def yamnet(features, params):
"""Define the core YAMNet mode in Keras.""" """Define the core YAMNet mode in Keras."""
net = layers.Reshape( net = layers.Reshape(
(params.PATCH_FRAMES, params.PATCH_BANDS, 1), (params.patch_frames, params.patch_bands, 1),
input_shape=(params.PATCH_FRAMES, params.PATCH_BANDS))(features) input_shape=(params.patch_frames, params.patch_bands))(features)
for (i, (layer_fun, kernel, stride, filters)) in enumerate(_YAMNET_LAYER_DEFS): for (i, (layer_fun, kernel, stride, filters)) in enumerate(_YAMNET_LAYER_DEFS):
net = layer_fun('layer{}'.format(i + 1), kernel, stride, filters)(net) net = layer_fun('layer{}'.format(i + 1), kernel, stride, filters, params)(net)
embeddings = layers.GlobalAveragePooling2D()(net) embeddings = layers.GlobalAveragePooling2D()(net)
logits = layers.Dense(units=params.NUM_CLASSES, use_bias=True)(embeddings) logits = layers.Dense(units=params.num_classes, use_bias=True)(embeddings)
predictions = layers.Activation(activation=params.CLASSIFIER_ACTIVATION)(logits) predictions = layers.Activation(activation=params.classifier_activation)(logits)
return predictions, embeddings return predictions, embeddings
def yamnet_frames_model(feature_params): def yamnet_frames_model(params):
"""Defines the YAMNet waveform-to-class-scores model. """Defines the YAMNet waveform-to-class-scores model.
Args: Args:
feature_params: An object with parameter fields to control the feature params: An instance of Params containing hyperparameters.
calculation.
Returns: Returns:
A model accepting (num_samples,) waveform input and emitting: A model accepting (num_samples,) waveform input and emitting:
...@@ -120,10 +118,10 @@ def yamnet_frames_model(feature_params): ...@@ -120,10 +118,10 @@ def yamnet_frames_model(feature_params):
- log_mel_spectrogram: (num_spectrogram_frames, num_mel_bins) spectrogram feature matrix - log_mel_spectrogram: (num_spectrogram_frames, num_mel_bins) spectrogram feature matrix
""" """
waveform = layers.Input(batch_shape=(None,), dtype=tf.float32) waveform = layers.Input(batch_shape=(None,), dtype=tf.float32)
waveform_padded = features_lib.pad_waveform(waveform, feature_params) waveform_padded = features_lib.pad_waveform(waveform, params)
log_mel_spectrogram, features = features_lib.waveform_to_log_mel_spectrogram_patches( log_mel_spectrogram, features = features_lib.waveform_to_log_mel_spectrogram_patches(
waveform_padded, feature_params) waveform_padded, params)
predictions, embeddings = yamnet(features) predictions, embeddings = yamnet(features, params)
frames_model = Model( frames_model = Model(
name='yamnet_frames', inputs=waveform, name='yamnet_frames', inputs=waveform,
outputs=[predictions, embeddings, log_mel_spectrogram]) outputs=[predictions, embeddings, log_mel_spectrogram])
...@@ -132,6 +130,8 @@ def yamnet_frames_model(feature_params): ...@@ -132,6 +130,8 @@ def yamnet_frames_model(feature_params):
def class_names(class_map_csv): def class_names(class_map_csv):
"""Read the class name definition file and return a list of strings.""" """Read the class name definition file and return a list of strings."""
if tf.is_tensor(class_map_csv):
class_map_csv = class_map_csv.numpy()
with open(class_map_csv) as csv_file: with open(class_map_csv) as csv_file:
reader = csv.reader(csv_file) reader = csv.reader(csv_file)
next(reader) # Skip header next(reader) # Skip header
......
...@@ -23,46 +23,46 @@ import yamnet ...@@ -23,46 +23,46 @@ import yamnet
class YAMNetTest(tf.test.TestCase): class YAMNetTest(tf.test.TestCase):
_yamnet_graph = None _params = None
_yamnet = None _yamnet = None
_yamnet_classes = None _yamnet_classes = None
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
super(YAMNetTest, cls).setUpClass() super().setUpClass()
cls._yamnet_graph = tf.Graph() cls._params = params.Params()
with cls._yamnet_graph.as_default(): cls._yamnet = yamnet.yamnet_frames_model(cls._params)
cls._yamnet = yamnet.yamnet_frames_model(params) cls._yamnet.load_weights('yamnet.h5')
cls._yamnet.load_weights('yamnet.h5') cls._yamnet_classes = yamnet.class_names('yamnet_class_map.csv')
cls._yamnet_classes = yamnet.class_names('yamnet_class_map.csv')
def clip_test(self, waveform, expected_class_name, top_n=10): def clip_test(self, waveform, expected_class_name, top_n=10):
"""Run the model on the waveform, check that expected class is in top-n.""" """Run the model on the waveform, check that expected class is in top-n."""
with YAMNetTest._yamnet_graph.as_default(): predictions, embeddings, log_mel_spectrogram = YAMNetTest._yamnet(waveform)
prediction = np.mean(YAMNetTest._yamnet.predict( clip_predictions = np.mean(predictions, axis=0)
np.reshape(waveform, [1, -1]), steps=1)[0], axis=0) top_n_indices = np.argsort(clip_predictions)[-top_n:]
top_n_class_names = YAMNetTest._yamnet_classes[ top_n_scores = clip_predictions[top_n_indices]
np.argsort(prediction)[-top_n:]] top_n_class_names = YAMNetTest._yamnet_classes[top_n_indices]
self.assertIn(expected_class_name, top_n_class_names) top_n_predictions = list(zip(top_n_class_names, top_n_scores))
self.assertIn(expected_class_name, top_n_class_names,
'Did not find expected class {} in top {} predictions: {}'.format(
expected_class_name, top_n, top_n_predictions))
def testZeros(self): def testZeros(self):
self.clip_test( self.clip_test(
waveform=np.zeros((1, int(3 * params.SAMPLE_RATE))), waveform=np.zeros((int(3 * YAMNetTest._params.sample_rate),)),
expected_class_name='Silence') expected_class_name='Silence')
def testRandom(self): def testRandom(self):
np.random.seed(51773) # Ensure repeatability. np.random.seed(51773) # Ensure repeatability.
self.clip_test( self.clip_test(
waveform=np.random.uniform(-1.0, +1.0, waveform=np.random.uniform(-1.0, +1.0,
(1, int(3 * params.SAMPLE_RATE))), (int(3 * YAMNetTest._params.sample_rate),)),
expected_class_name='White noise') expected_class_name='White noise')
def testSine(self): def testSine(self):
self.clip_test( self.clip_test(
waveform=np.reshape( waveform=np.sin(2 * np.pi * 440 *
np.sin(2 * np.pi * 440 * np.linspace( np.arange(0, 3, 1 / YAMNetTest._params.sample_rate)),
0, 3, int(3 *params.SAMPLE_RATE))),
[1, -1]),
expected_class_name='Sine wave') expected_class_name='Sine wave')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment