Commit f5fc733a authored by Byzantine

Removing research/community models

parent 09bc9f54
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Script to train the Attention OCR model.
A simple usage example:
python train.py
"""
import collections
import logging
import tensorflow as tf
from tensorflow.contrib import slim
from tensorflow import app
from tensorflow.python.platform import flags
from tensorflow.contrib.tfprof import model_analyzer
import data_provider
import common_flags
FLAGS = flags.FLAGS
common_flags.define()
# yapf: disable
flags.DEFINE_integer('task', 0,
'The Task ID. This value is used when training with '
'multiple workers to identify each worker.')
flags.DEFINE_integer('ps_tasks', 0,
'The number of parameter servers. If the value is 0, then'
' the parameters are handled locally by the worker.')
flags.DEFINE_integer('save_summaries_secs', 60,
'The frequency with which summaries are saved, in '
'seconds.')
flags.DEFINE_integer('save_interval_secs', 600,
'Frequency in seconds of saving the model.')
flags.DEFINE_integer('max_number_of_steps', int(1e10),
'The maximum number of gradient steps.')
flags.DEFINE_string('checkpoint_inception', '',
'Checkpoint to recover inception weights from.')
flags.DEFINE_float('clip_gradient_norm', 2.0,
                   'If greater than 0, gradients will be clipped to this '
                   'norm.')
flags.DEFINE_bool('sync_replicas', False,
'If True will synchronize replicas during training.')
flags.DEFINE_integer('replicas_to_aggregate', 1,
                     'The number of gradient updates to aggregate before '
                     'updating params.')
flags.DEFINE_integer('total_num_replicas', 1,
'Total number of worker replicas.')
flags.DEFINE_integer('startup_delay_steps', 15,
'Number of training steps between replicas startup.')
flags.DEFINE_boolean('reset_train_dir', False,
'If true will delete all files in the train_log_dir')
flags.DEFINE_boolean('show_graph_stats', False,
'Output model size stats to stderr.')
# yapf: enable
TrainingHParams = collections.namedtuple('TrainingHParams', [
'learning_rate',
'optimizer',
'momentum',
'use_augment_input',
])
def get_training_hparams():
return TrainingHParams(
learning_rate=FLAGS.learning_rate,
optimizer=FLAGS.optimizer,
momentum=FLAGS.momentum,
use_augment_input=FLAGS.use_augment_input)
def create_optimizer(hparams):
"""Creates optimized based on the specified flags."""
if hparams.optimizer == 'momentum':
optimizer = tf.train.MomentumOptimizer(
hparams.learning_rate, momentum=hparams.momentum)
elif hparams.optimizer == 'adam':
optimizer = tf.train.AdamOptimizer(hparams.learning_rate)
elif hparams.optimizer == 'adadelta':
optimizer = tf.train.AdadeltaOptimizer(hparams.learning_rate)
elif hparams.optimizer == 'adagrad':
optimizer = tf.train.AdagradOptimizer(hparams.learning_rate)
elif hparams.optimizer == 'rmsprop':
optimizer = tf.train.RMSPropOptimizer(
hparams.learning_rate, momentum=hparams.momentum)
return optimizer
def train(loss, init_fn, hparams):
"""Wraps slim.learning.train to run a training loop.
Args:
loss: a loss tensor
init_fn: A callable to be executed after all other initialization is done.
    hparams: model hyperparameters (a TrainingHParams namedtuple).
"""
optimizer = create_optimizer(hparams)
if FLAGS.sync_replicas:
replica_id = tf.constant(FLAGS.task, tf.int32, shape=())
optimizer = tf.LegacySyncReplicasOptimizer(
opt=optimizer,
replicas_to_aggregate=FLAGS.replicas_to_aggregate,
replica_id=replica_id,
total_num_replicas=FLAGS.total_num_replicas)
sync_optimizer = optimizer
startup_delay_steps = 0
else:
startup_delay_steps = 0
sync_optimizer = None
train_op = slim.learning.create_train_op(
loss,
optimizer,
summarize_gradients=True,
clip_gradient_norm=FLAGS.clip_gradient_norm)
slim.learning.train(
train_op=train_op,
logdir=FLAGS.train_log_dir,
graph=loss.graph,
master=FLAGS.master,
is_chief=(FLAGS.task == 0),
number_of_steps=FLAGS.max_number_of_steps,
save_summaries_secs=FLAGS.save_summaries_secs,
save_interval_secs=FLAGS.save_interval_secs,
startup_delay_steps=startup_delay_steps,
sync_optimizer=sync_optimizer,
init_fn=init_fn)
def prepare_training_dir():
if not tf.gfile.Exists(FLAGS.train_log_dir):
logging.info('Create a new training directory %s', FLAGS.train_log_dir)
tf.gfile.MakeDirs(FLAGS.train_log_dir)
else:
if FLAGS.reset_train_dir:
logging.info('Reset the training directory %s', FLAGS.train_log_dir)
tf.gfile.DeleteRecursively(FLAGS.train_log_dir)
tf.gfile.MakeDirs(FLAGS.train_log_dir)
else:
logging.info('Use already existing training directory %s',
FLAGS.train_log_dir)
def calculate_graph_metrics():
param_stats = model_analyzer.print_model_analysis(
tf.get_default_graph(),
tfprof_options=model_analyzer.TRAINABLE_VARS_PARAMS_STAT_OPTIONS)
return param_stats.total_parameters
def main(_):
prepare_training_dir()
dataset = common_flags.create_dataset(split_name=FLAGS.split_name)
model = common_flags.create_model(dataset.num_char_classes,
dataset.max_sequence_length,
dataset.num_of_views, dataset.null_code)
hparams = get_training_hparams()
# If ps_tasks is zero, the local device is used. When using multiple
# (non-local) replicas, the ReplicaDeviceSetter distributes the variables
# across the different devices.
device_setter = tf.train.replica_device_setter(
FLAGS.ps_tasks, merge_devices=True)
with tf.device(device_setter):
data = data_provider.get_data(
dataset,
FLAGS.batch_size,
augment=hparams.use_augment_input,
central_crop_size=common_flags.get_crop_size())
endpoints = model.create_base(data.images, data.labels_one_hot)
total_loss = model.create_loss(data, endpoints)
model.create_summaries(data, endpoints, dataset.charset, is_training=True)
init_fn = model.create_init_fn_to_restore(FLAGS.checkpoint,
FLAGS.checkpoint_inception)
if FLAGS.show_graph_stats:
logging.info('Total number of weights in the graph: %s',
calculate_graph_metrics())
train(total_loss, init_fn, hparams)
if __name__ == '__main__':
app.run()
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions to support building models for StreetView text transcription."""
import tensorflow as tf
from tensorflow.contrib import slim
def logits_to_log_prob(logits):
"""Computes log probabilities using numerically stable trick.
This uses two numerical stability tricks:
1) softmax(x) = softmax(x - c) where c is a constant applied to all
arguments. If we set c = max(x) then the softmax is more numerically
stable.
2) log softmax(x) is not numerically stable, but we can stabilize it
by using the identity log softmax(x) = x - log sum exp(x)
Args:
logits: Tensor of arbitrary shape whose last dimension contains logits.
Returns:
A tensor of the same shape as the input, but with corresponding log
probabilities.
"""
with tf.variable_scope('log_probabilities'):
reduction_indices = len(logits.shape.as_list()) - 1
max_logits = tf.reduce_max(
logits, reduction_indices=reduction_indices, keep_dims=True)
safe_logits = tf.subtract(logits, max_logits)
sum_exp = tf.reduce_sum(
tf.exp(safe_logits),
reduction_indices=reduction_indices,
keep_dims=True)
log_probs = tf.subtract(safe_logits, tf.log(sum_exp))
return log_probs
def variables_to_restore(scope=None, strip_scope=False):
"""Returns a list of variables to restore for the specified list of methods.
It is supposed that variable name starts with the method's scope (a prefix
returned by _method_scope function).
Args:
methods_names: a list of names of configurable methods.
strip_scope: if True will return variable names without method's scope.
If methods_names is None will return names unchanged.
model_scope: a scope for a whole model.
Returns:
a dictionary mapping variable names to variables for restore.
"""
if scope:
variable_map = {}
method_variables = slim.get_variables_to_restore(include=[scope])
for var in method_variables:
if strip_scope:
var_name = var.op.name[len(scope) + 1:]
else:
var_name = var.op.name
variable_map[var_name] = var
return variable_map
else:
return {v.op.name: v for v in slim.get_variables_to_restore()}
![TensorFlow Requirement: 1.x](https://img.shields.io/badge/TensorFlow%20Requirement-1.x-brightgreen)
![TensorFlow 2 Not Supported](https://img.shields.io/badge/TensorFlow%202%20Not%20Supported-%E2%9C%95-red.svg)
# Models for AudioSet: A Large Scale Dataset of Audio Events
This repository provides models and supporting code associated with
[AudioSet](http://g.co/audioset), a dataset of over 2 million human-labeled
10-second YouTube video soundtracks, with labels taken from an ontology of
more than 600 audio event classes.
AudioSet was
[released](https://research.googleblog.com/2017/03/announcing-audioset-dataset-for-audio.html)
in March 2017 by Google's Sound Understanding team to provide a common
large-scale evaluation task for audio event detection as well as a starting
point for a comprehensive vocabulary of sound events.
For more details about AudioSet and the various models we have trained, please
visit the [AudioSet website](http://g.co/audioset) and read our papers:
* Gemmeke, J. et al.,
[AudioSet: An ontology and human-labelled dataset for audio events](https://research.google.com/pubs/pub45857.html),
ICASSP 2017
* Hershey, S. et al.,
[CNN Architectures for Large-Scale Audio Classification](https://research.google.com/pubs/pub45611.html),
ICASSP 2017
If you use any of our pre-trained models in your published research, we ask that
you cite [CNN Architectures for Large-Scale Audio Classification](https://research.google.com/pubs/pub45611.html).
If you use the AudioSet dataset or the released embeddings of AudioSet segments,
please cite
[AudioSet: An ontology and human-labelled dataset for audio events](https://research.google.com/pubs/pub45857.html).
## Contact
For general questions about AudioSet and these models, please use the
[audioset-users@googlegroups.com](https://groups.google.com/forum/#!forum/audioset-users)
mailing list.
For technical problems with the released model and code, please open an issue on
the [tensorflow/models issue tracker](https://github.com/tensorflow/models/issues)
and __*assign to @plakal and @dpwe*__. Please note that because the issue tracker
is shared across all models released by Google, we won't be notified about an
issue unless you explicitly @-mention us (@plakal and @dpwe) or assign the issue
to us.
## Credits
Original authors and reviewers of the code in this package include (in
alphabetical order):
* DAn Ellis
* Shawn Hershey
* Aren Jansen
* Manoj Plakal
# VGGish
The initial AudioSet release included 128-dimensional embeddings of each
AudioSet segment produced from a VGG-like audio classification model that was
trained on a large YouTube dataset (a preliminary version of what later became
[YouTube-8M](https://research.google.com/youtube8m)).
We provide a TensorFlow definition of this model, which we call __*VGGish*__, as
well as supporting code to extract input features for the model from audio
waveforms and to post-process the model embedding output into the same format as
the released embedding features.
## Installation
VGGish depends on the following Python packages:
* [`numpy`](http://www.numpy.org/)
* [`resampy`](http://resampy.readthedocs.io/en/latest/)
* [`tensorflow`](http://www.tensorflow.org/) (currently, only TF v1.x)
* [`tf_slim`](https://github.com/google-research/tf-slim)
* [`six`](https://pythonhosted.org/six/)
* [`soundfile`](https://pysoundfile.readthedocs.io/)
These are all easily installable via, e.g., `pip install numpy` (as in the
sample installation session below).
Any reasonably recent version of these packages should work. Note that we currently only support
TensorFlow v1.x due to a [`tf_slim` limitation](https://github.com/google-research/tf-slim/pull/1).
TensorFlow v1.15 (the latest version as of Jan 2020) has been tested to work.
VGGish also requires downloading two data files:
* [VGGish model checkpoint](https://storage.googleapis.com/audioset/vggish_model.ckpt),
in TensorFlow checkpoint format.
* [Embedding PCA parameters](https://storage.googleapis.com/audioset/vggish_pca_params.npz),
in NumPy compressed archive format.
After downloading these files into the same directory as this README, the
installation can be tested by running `python vggish_smoke_test.py` which
runs a known signal through the model and checks the output.
Here's a sample installation and test session:
```shell
# You can optionally install and test VGGish within a Python virtualenv, which
# is useful for isolating changes from the rest of your system. For example, you
# may have an existing version of some packages that you do not want to upgrade,
# or you want to try Python 3 instead of Python 2. If you decide to use a
# virtualenv, you can create one by running
# $ virtualenv vggish # For Python 2
# or
# $ python3 -m venv vggish # For Python 3
# and then enter the virtual environment by running
# $ source vggish/bin/activate # Assuming you use bash
# Leave the virtual environment at the end of the session by running
# $ deactivate
# Within the virtual environment, do not use 'sudo'.
# Upgrade pip first. Also make sure wheel is installed.
$ sudo python -m pip install --upgrade pip wheel
# Install all dependencies.
$ sudo pip install numpy resampy tensorflow==1.15 tf_slim six soundfile
# Clone TensorFlow models repo into a 'models' directory.
$ git clone https://github.com/tensorflow/models.git
$ cd models/research/audioset/vggish
# Download data files into same directory as code.
$ curl -O https://storage.googleapis.com/audioset/vggish_model.ckpt
$ curl -O https://storage.googleapis.com/audioset/vggish_pca_params.npz
# Installation ready, let's test it.
$ python vggish_smoke_test.py
# If we see "Looks Good To Me", then we're all set.
```
## Usage
VGGish can be used in two ways:
* *As a feature extractor*: VGGish converts audio input features into a
semantically meaningful, high-level 128-D embedding which can be fed as input
to a downstream classification model. The downstream model can be shallower
than usual because the VGGish embedding is more semantically compact than raw
audio features.
So, for example, you could train a classifier for 10 of the AudioSet classes
by using the released embeddings as features. Then, you could use that
trained classifier with any arbitrary audio input by running the audio through
the audio feature extractor and VGGish model provided here, passing the
resulting embedding features as input to your trained model.
  `vggish_inference_demo.py` shows how to produce VGGish embeddings from
  arbitrary audio; a condensed sketch of that flow appears right after this
  list.
* *As part of a larger model*: Here, we treat VGGish as a "warm start" for the
lower layers of a model that takes audio features as input and adds more
layers on top of the VGGish embedding. This can be used to fine-tune VGGish
(or parts thereof) if you have large datasets that might be very different
from the typical YouTube video clip. `vggish_train_demo.py` shows how to add
layers on top of VGGish and train the whole model.
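
For orientation, here is a condensed sketch of the feature-extractor path,
adapted from `vggish_inference_demo.py` (the WAV path is a placeholder, and the
checkpoint/PCA filenames are the defaults downloaded in the installation step
above):

```python
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

import vggish_input
import vggish_params
import vggish_postprocess
import vggish_slim

# Convert an arbitrary WAV file into a batch of log mel spectrogram examples.
examples = vggish_input.wavfile_to_examples('path/to/audio.wav')

with tf.Graph().as_default(), tf.Session() as sess:
  # Define VGGish in inference mode and restore the released checkpoint.
  vggish_slim.define_vggish_slim(training=False)
  vggish_slim.load_vggish_slim_checkpoint(sess, 'vggish_model.ckpt')
  features = sess.graph.get_tensor_by_name(vggish_params.INPUT_TENSOR_NAME)
  embedding = sess.graph.get_tensor_by_name(vggish_params.OUTPUT_TENSOR_NAME)
  # Raw 128-D embeddings, one row per 0.96 s example.
  [raw_embeddings] = sess.run([embedding],
                              feed_dict={features: examples})

# Optionally PCA/whiten/quantize to match the released AudioSet embeddings.
pproc = vggish_postprocess.Postprocessor('vggish_pca_params.npz')
audioset_style_embeddings = pproc.postprocess(raw_embeddings)
```

The resulting rows can then be fed as features to whatever downstream
classifier you train.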
## About the Model
The VGGish code layout is as follows:
* `vggish_slim.py`: Model definition in TensorFlow Slim notation.
* `vggish_params.py`: Hyperparameters.
* `vggish_input.py`: Converter from audio waveform into input examples.
* `mel_features.py`: Audio feature extraction helpers.
* `vggish_postprocess.py`: Embedding postprocessing.
* `vggish_inference_demo.py`: Demo of VGGish in inference mode.
* `vggish_train_demo.py`: Demo of VGGish in training mode.
* `vggish_smoke_test.py`: Simple test of a VGGish installation.
### Architecture
See `vggish_slim.py` and `vggish_params.py`.
VGGish is a variant of the [VGG](https://arxiv.org/abs/1409.1556) model, in
particular Configuration A with 11 weight layers. Specifically, here are the
changes we made:
* The input size was changed to 96x64 for log mel spectrogram audio inputs.
* We drop the last group of convolutional and maxpool layers, so we now have
only four groups of convolution/maxpool layers instead of five.
* Instead of a 1000-wide fully connected layer at the end, we use a 128-wide
fully connected layer. This acts as a compact embedding layer.
The model definition provided here defines layers up to and including the
128-wide embedding layer.
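
For reference, a sketch of the layer stack this description implies, written
with `tf_slim`. The filter counts and the 4096-wide fully connected layers are
carried over from VGG Configuration A and are illustrative only;
`vggish_slim.py` remains the authoritative definition:

```python
import tensorflow.compat.v1 as tf
import tf_slim as slim

def vggish_like(features):
  """Illustrative VGGish-style stack: 96x64 inputs, four conv/maxpool groups,
  and a 128-wide embedding layer (see vggish_slim.py for the real model)."""
  # features: [batch, 96, 64] log mel patches; add a channels dimension.
  net = tf.reshape(features, [-1, 96, 64, 1])
  net = slim.conv2d(net, 64, [3, 3], scope='conv1')
  net = slim.max_pool2d(net, [2, 2], scope='pool1')
  net = slim.conv2d(net, 128, [3, 3], scope='conv2')
  net = slim.max_pool2d(net, [2, 2], scope='pool2')
  net = slim.repeat(net, 2, slim.conv2d, 256, [3, 3], scope='conv3')
  net = slim.max_pool2d(net, [2, 2], scope='pool3')
  net = slim.repeat(net, 2, slim.conv2d, 512, [3, 3], scope='conv4')
  net = slim.max_pool2d(net, [2, 2], scope='pool4')
  net = slim.flatten(net)
  net = slim.repeat(net, 2, slim.fully_connected, 4096, scope='fc1')
  # 128-wide embedding layer in place of VGG's 1000-way classifier.
  return slim.fully_connected(net, 128, scope='embedding')
```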
### Input: Audio Features
See `vggish_input.py` and `mel_features.py`.
VGGish was trained with audio features computed as follows:
* All audio is resampled to 16 kHz mono.
* A spectrogram is computed using magnitudes of the Short-Time Fourier Transform
with a window size of 25 ms, a window hop of 10 ms, and a periodic Hann
window.
* A mel spectrogram is computed by mapping the spectrogram to 64 mel bins
covering the range 125-7500 Hz.
* A stabilized log mel spectrogram is computed by applying
log(mel-spectrum + 0.01) where the offset is used to avoid taking a logarithm
of zero.
* These features are then framed into non-overlapping examples of 0.96 seconds,
where each example covers 64 mel bands and 96 frames of 10 ms each.
We provide our own NumPy implementation that produces features that are very
similar to those produced by our internal production code. This results in
embedding outputs that closely match the embeddings that we have already
released. Note that these embeddings will *not* be bit-for-bit identical to the
released embeddings due to small differences between the feature computation
code paths, and even between two different installations of VGGish with
different underlying libraries and hardware. However, we expect that the
embeddings will be equivalent in the context of a downstream classification
task.
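
This recipe is implemented by `vggish_input.waveform_to_examples()` on top of
`mel_features.py`. Below is a minimal sketch with the parameter values listed
above written out as literals; the helper name is only for illustration, and
the canonical constants live in `vggish_params.py`:

```python
import mel_features

def waveform_to_log_mel_examples(samples_16khz_mono):
  """Sketch: 16 kHz mono float samples -> [num_examples, 96, 64] patches."""
  # Stabilized log mel spectrogram: 25 ms windows, 10 ms hop, 64 mel bins
  # covering 125-7500 Hz, log offset 0.01.
  log_mel = mel_features.log_mel_spectrogram(
      samples_16khz_mono,
      audio_sample_rate=16000,
      log_offset=0.01,
      window_length_secs=0.025,
      hop_length_secs=0.010,
      num_mel_bins=64,
      lower_edge_hertz=125.0,
      upper_edge_hertz=7500.0)
  # Frame into non-overlapping 0.96 s examples: 96 frames of 10 ms each.
  return mel_features.frame(log_mel, window_length=96, hop_length=96)
```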
### Output: Embeddings
See `vggish_postprocess.py`.
The released AudioSet embeddings were postprocessed before release by applying a
PCA transformation (which performs both PCA and whitening) as well as
quantization to 8 bits per embedding element. This was done to be compatible
with the [YouTube-8M](https://research.google.com/youtube8m) project which has
released visual and audio embeddings for millions of YouTube videos in the same
PCA/whitened/quantized format.
We provide a Python implementation of the postprocessing which can be applied to
batches of embeddings produced by VGGish. `vggish_inference_demo.py` shows how
the postprocessor can be run after inference.
If you don't need to use the released embeddings or YouTube-8M, then you could
skip postprocessing and use raw embeddings.
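
In code, the postprocessing is a single call. A minimal sketch (the random
array is just a stand-in for a batch of raw VGGish embeddings, and
`vggish_pca_params.npz` is the parameters file downloaded during installation):

```python
import numpy as np
import vggish_postprocess

# Stand-in for a [num_examples, 128] batch of raw float embeddings from VGGish.
raw_embeddings = np.random.rand(10, 128).astype(np.float32)

# Apply the released PCA/whitening transform and quantize to 8 bits,
# matching the format of the released AudioSet / YouTube-8M embeddings.
pproc = vggish_postprocess.Postprocessor('vggish_pca_params.npz')
postprocessed = pproc.postprocess(raw_embeddings)
```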
A [Colab](https://colab.research.google.com/)
showing how to download the model and calculate the embeddings on your
own sound data is available here:
[AudioSet Embedding Colab](https://colab.research.google.com/drive/1TbX92UL9sYWbdwdGE0rJ9owmezB-Rl1C).
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Defines routines to compute mel spectrogram features from audio waveform."""
import numpy as np
def frame(data, window_length, hop_length):
"""Convert array into a sequence of successive possibly overlapping frames.
An n-dimensional array of shape (num_samples, ...) is converted into an
(n+1)-D array of shape (num_frames, window_length, ...), where each frame
starts hop_length points after the preceding one.
This is accomplished using stride_tricks, so the original data is not
copied. However, there is no zero-padding, so any incomplete frames at the
end are not included.
Args:
data: np.array of dimension N >= 1.
window_length: Number of samples in each frame.
hop_length: Advance (in samples) between each window.
Returns:
(N+1)-D np.array with as many rows as there are complete frames that can be
extracted.
"""
num_samples = data.shape[0]
num_frames = 1 + int(np.floor((num_samples - window_length) / hop_length))
shape = (num_frames, window_length) + data.shape[1:]
strides = (data.strides[0] * hop_length,) + data.strides
return np.lib.stride_tricks.as_strided(data, shape=shape, strides=strides)
def periodic_hann(window_length):
"""Calculate a "periodic" Hann window.
The classic Hann window is defined as a raised cosine that starts and
ends on zero, and where every value appears twice, except the middle
point for an odd-length window. Matlab calls this a "symmetric" window
and np.hanning() returns it. However, for Fourier analysis, this
actually represents just over one cycle of a period N-1 cosine, and
thus is not compactly expressed on a length-N Fourier basis. Instead,
it's better to use a raised cosine that ends just before the final
zero value - i.e. a complete cycle of a period-N cosine. Matlab
calls this a "periodic" window. This routine calculates it.
Args:
window_length: The number of points in the returned window.
Returns:
A 1D np.array containing the periodic hann window.
"""
return 0.5 - (0.5 * np.cos(2 * np.pi / window_length *
np.arange(window_length)))
def stft_magnitude(signal, fft_length,
hop_length=None,
window_length=None):
"""Calculate the short-time Fourier transform magnitude.
Args:
signal: 1D np.array of the input time-domain signal.
fft_length: Size of the FFT to apply.
hop_length: Advance (in samples) between each frame passed to FFT.
window_length: Length of each block of samples to pass to FFT.
Returns:
2D np.array where each row contains the magnitudes of the fft_length/2+1
unique values of the FFT for the corresponding frame of input samples.
"""
frames = frame(signal, window_length, hop_length)
# Apply frame window to each frame. We use a periodic Hann (cosine of period
# window_length) instead of the symmetric Hann of np.hanning (period
# window_length-1).
window = periodic_hann(window_length)
windowed_frames = frames * window
return np.abs(np.fft.rfft(windowed_frames, int(fft_length)))
# Mel spectrum constants and functions.
_MEL_BREAK_FREQUENCY_HERTZ = 700.0
_MEL_HIGH_FREQUENCY_Q = 1127.0
def hertz_to_mel(frequencies_hertz):
"""Convert frequencies to mel scale using HTK formula.
Args:
frequencies_hertz: Scalar or np.array of frequencies in hertz.
Returns:
Object of same size as frequencies_hertz containing corresponding values
on the mel scale.
"""
return _MEL_HIGH_FREQUENCY_Q * np.log(
1.0 + (frequencies_hertz / _MEL_BREAK_FREQUENCY_HERTZ))
def spectrogram_to_mel_matrix(num_mel_bins=20,
num_spectrogram_bins=129,
audio_sample_rate=8000,
lower_edge_hertz=125.0,
upper_edge_hertz=3800.0):
"""Return a matrix that can post-multiply spectrogram rows to make mel.
Returns a np.array matrix A that can be used to post-multiply a matrix S of
spectrogram values (STFT magnitudes) arranged as frames x bins to generate a
"mel spectrogram" M of frames x num_mel_bins. M = S A.
The classic HTK algorithm exploits the complementarity of adjacent mel bands
to multiply each FFT bin by only one mel weight, then add it, with positive
and negative signs, to the two adjacent mel bands to which that bin
contributes. Here, by expressing this operation as a matrix multiply, we go
from num_fft multiplies per frame (plus around 2*num_fft adds) to around
num_fft^2 multiplies and adds. However, because these are all presumably
accomplished in a single call to np.dot(), it's not clear which approach is
faster in Python. The matrix multiplication has the attraction of being more
general and flexible, and much easier to read.
Args:
num_mel_bins: How many bands in the resulting mel spectrum. This is
the number of columns in the output matrix.
num_spectrogram_bins: How many bins there are in the source spectrogram
data, which is understood to be fft_size/2 + 1, i.e. the spectrogram
only contains the nonredundant FFT bins.
audio_sample_rate: Samples per second of the audio at the input to the
spectrogram. We need this to figure out the actual frequencies for
each spectrogram bin, which dictates how they are mapped into mel.
lower_edge_hertz: Lower bound on the frequencies to be included in the mel
spectrum. This corresponds to the lower edge of the lowest triangular
band.
upper_edge_hertz: The desired top edge of the highest frequency band.
Returns:
An np.array with shape (num_spectrogram_bins, num_mel_bins).
Raises:
ValueError: if frequency edges are incorrectly ordered or out of range.
"""
nyquist_hertz = audio_sample_rate / 2.
if lower_edge_hertz < 0.0:
raise ValueError("lower_edge_hertz %.1f must be >= 0" % lower_edge_hertz)
if lower_edge_hertz >= upper_edge_hertz:
raise ValueError("lower_edge_hertz %.1f >= upper_edge_hertz %.1f" %
(lower_edge_hertz, upper_edge_hertz))
if upper_edge_hertz > nyquist_hertz:
raise ValueError("upper_edge_hertz %.1f is greater than Nyquist %.1f" %
(upper_edge_hertz, nyquist_hertz))
spectrogram_bins_hertz = np.linspace(0.0, nyquist_hertz, num_spectrogram_bins)
spectrogram_bins_mel = hertz_to_mel(spectrogram_bins_hertz)
# The i'th mel band (starting from i=1) has center frequency
# band_edges_mel[i], lower edge band_edges_mel[i-1], and higher edge
# band_edges_mel[i+1]. Thus, we need num_mel_bins + 2 values in
# the band_edges_mel arrays.
band_edges_mel = np.linspace(hertz_to_mel(lower_edge_hertz),
hertz_to_mel(upper_edge_hertz), num_mel_bins + 2)
# Matrix to post-multiply feature arrays whose rows are num_spectrogram_bins
# of spectrogram values.
mel_weights_matrix = np.empty((num_spectrogram_bins, num_mel_bins))
for i in range(num_mel_bins):
lower_edge_mel, center_mel, upper_edge_mel = band_edges_mel[i:i + 3]
# Calculate lower and upper slopes for every spectrogram bin.
# Line segments are linear in the *mel* domain, not hertz.
lower_slope = ((spectrogram_bins_mel - lower_edge_mel) /
(center_mel - lower_edge_mel))
upper_slope = ((upper_edge_mel - spectrogram_bins_mel) /
(upper_edge_mel - center_mel))
# .. then intersect them with each other and zero.
mel_weights_matrix[:, i] = np.maximum(0.0, np.minimum(lower_slope,
upper_slope))
# HTK excludes the spectrogram DC bin; make sure it always gets a zero
# coefficient.
mel_weights_matrix[0, :] = 0.0
return mel_weights_matrix
def log_mel_spectrogram(data,
audio_sample_rate=8000,
log_offset=0.0,
window_length_secs=0.025,
hop_length_secs=0.010,
**kwargs):
"""Convert waveform to a log magnitude mel-frequency spectrogram.
Args:
data: 1D np.array of waveform data.
audio_sample_rate: The sampling rate of data.
log_offset: Add this to values when taking log to avoid -Infs.
window_length_secs: Duration of each window to analyze.
hop_length_secs: Advance between successive analysis windows.
**kwargs: Additional arguments to pass to spectrogram_to_mel_matrix.
Returns:
2D np.array of (num_frames, num_mel_bins) consisting of log mel filterbank
magnitudes for successive frames.
"""
window_length_samples = int(round(audio_sample_rate * window_length_secs))
hop_length_samples = int(round(audio_sample_rate * hop_length_secs))
fft_length = 2 ** int(np.ceil(np.log(window_length_samples) / np.log(2.0)))
spectrogram = stft_magnitude(
data,
fft_length=fft_length,
hop_length=hop_length_samples,
window_length=window_length_samples)
mel_spectrogram = np.dot(spectrogram, spectrogram_to_mel_matrix(
num_spectrogram_bins=spectrogram.shape[1],
audio_sample_rate=audio_sample_rate, **kwargs))
return np.log(mel_spectrogram + log_offset)
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""A simple demonstration of running VGGish in inference mode.
This is intended as a toy example that demonstrates how the various building
blocks (feature extraction, model definition and loading, postprocessing) work
together in an inference context.
A WAV file (assumed to contain signed 16-bit PCM samples) is read in, converted
into log mel spectrogram examples, fed into VGGish, the raw embedding output is
whitened and quantized, and the postprocessed embeddings are optionally written
in a SequenceExample to a TFRecord file (using the same format as the embedding
features released in AudioSet).
Usage:
# Run a WAV file through the model and print the embeddings. The model
# checkpoint is loaded from vggish_model.ckpt and the PCA parameters are
# loaded from vggish_pca_params.npz in the current directory.
$ python vggish_inference_demo.py --wav_file /path/to/a/wav/file
# Run a WAV file through the model and also write the embeddings to
# a TFRecord file. The model checkpoint and PCA parameters are explicitly
# passed in as well.
$ python vggish_inference_demo.py --wav_file /path/to/a/wav/file \
--tfrecord_file /path/to/tfrecord/file \
--checkpoint /path/to/model/checkpoint \
--pca_params /path/to/pca/params
  # Run a built-in input (a sine wave) through the model and print the
# embeddings. Associated model files are read from the current directory.
$ python vggish_inference_demo.py
"""
from __future__ import print_function
import numpy as np
import six
import soundfile
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
import vggish_input
import vggish_params
import vggish_postprocess
import vggish_slim
flags = tf.app.flags
flags.DEFINE_string(
'wav_file', None,
'Path to a wav file. Should contain signed 16-bit PCM samples. '
'If none is provided, a synthetic sound is used.')
flags.DEFINE_string(
'checkpoint', 'vggish_model.ckpt',
'Path to the VGGish checkpoint file.')
flags.DEFINE_string(
'pca_params', 'vggish_pca_params.npz',
'Path to the VGGish PCA parameters file.')
flags.DEFINE_string(
'tfrecord_file', None,
'Path to a TFRecord file where embeddings will be written.')
FLAGS = flags.FLAGS
def main(_):
# In this simple example, we run the examples from a single audio file through
# the model. If none is provided, we generate a synthetic input.
if FLAGS.wav_file:
wav_file = FLAGS.wav_file
else:
    # Write a WAV of a sine wave into an in-memory file object.
num_secs = 5
freq = 1000
sr = 44100
t = np.linspace(0, num_secs, int(num_secs * sr))
x = np.sin(2 * np.pi * freq * t)
# Convert to signed 16-bit samples.
samples = np.clip(x * 32768, -32768, 32767).astype(np.int16)
wav_file = six.BytesIO()
soundfile.write(wav_file, samples, sr, format='WAV', subtype='PCM_16')
wav_file.seek(0)
examples_batch = vggish_input.wavfile_to_examples(wav_file)
print(examples_batch)
# Prepare a postprocessor to munge the model embeddings.
pproc = vggish_postprocess.Postprocessor(FLAGS.pca_params)
# If needed, prepare a record writer to store the postprocessed embeddings.
writer = tf.python_io.TFRecordWriter(
FLAGS.tfrecord_file) if FLAGS.tfrecord_file else None
with tf.Graph().as_default(), tf.Session() as sess:
# Define the model in inference mode, load the checkpoint, and
# locate input and output tensors.
vggish_slim.define_vggish_slim(training=False)
vggish_slim.load_vggish_slim_checkpoint(sess, FLAGS.checkpoint)
features_tensor = sess.graph.get_tensor_by_name(
vggish_params.INPUT_TENSOR_NAME)
embedding_tensor = sess.graph.get_tensor_by_name(
vggish_params.OUTPUT_TENSOR_NAME)
# Run inference and postprocessing.
[embedding_batch] = sess.run([embedding_tensor],
feed_dict={features_tensor: examples_batch})
print(embedding_batch)
postprocessed_batch = pproc.postprocess(embedding_batch)
print(postprocessed_batch)
# Write the postprocessed embeddings as a SequenceExample, in a similar
# format as the features released in AudioSet. Each row of the batch of
# embeddings corresponds to roughly a second of audio (96 10ms frames), and
# the rows are written as a sequence of bytes-valued features, where each
# feature value contains the 128 bytes of the whitened quantized embedding.
seq_example = tf.train.SequenceExample(
feature_lists=tf.train.FeatureLists(
feature_list={
vggish_params.AUDIO_EMBEDDING_FEATURE_NAME:
tf.train.FeatureList(
feature=[
tf.train.Feature(
bytes_list=tf.train.BytesList(
value=[embedding.tobytes()]))
for embedding in postprocessed_batch
]
)
}
)
)
print(seq_example)
if writer:
writer.write(seq_example.SerializeToString())
if writer:
writer.close()
if __name__ == '__main__':
tf.app.run()
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Compute input examples for VGGish from audio waveform."""
import numpy as np
import resampy
import mel_features
import vggish_params
try:
import soundfile as sf
def wav_read(wav_file):
wav_data, sr = sf.read(wav_file, dtype='int16')
return wav_data, sr
except ImportError:
def wav_read(wav_file):
raise NotImplementedError('WAV file reading requires soundfile package.')
def waveform_to_examples(data, sample_rate):
"""Converts audio waveform into an array of examples for VGGish.
Args:
data: np.array of either one dimension (mono) or two dimensions
(multi-channel, with the outer dimension representing channels).
Each sample is generally expected to lie in the range [-1.0, +1.0],
although this is not required.
sample_rate: Sample rate of data.
Returns:
3-D np.array of shape [num_examples, num_frames, num_bands] which represents
a sequence of examples, each of which contains a patch of log mel
spectrogram, covering num_frames frames of audio and num_bands mel frequency
bands, where the frame length is vggish_params.STFT_HOP_LENGTH_SECONDS.
"""
# Convert to mono.
if len(data.shape) > 1:
data = np.mean(data, axis=1)
# Resample to the rate assumed by VGGish.
if sample_rate != vggish_params.SAMPLE_RATE:
data = resampy.resample(data, sample_rate, vggish_params.SAMPLE_RATE)
# Compute log mel spectrogram features.
log_mel = mel_features.log_mel_spectrogram(
data,
audio_sample_rate=vggish_params.SAMPLE_RATE,
log_offset=vggish_params.LOG_OFFSET,
window_length_secs=vggish_params.STFT_WINDOW_LENGTH_SECONDS,
hop_length_secs=vggish_params.STFT_HOP_LENGTH_SECONDS,
num_mel_bins=vggish_params.NUM_MEL_BINS,
lower_edge_hertz=vggish_params.MEL_MIN_HZ,
upper_edge_hertz=vggish_params.MEL_MAX_HZ)
# Frame features into examples.
features_sample_rate = 1.0 / vggish_params.STFT_HOP_LENGTH_SECONDS
example_window_length = int(round(
vggish_params.EXAMPLE_WINDOW_SECONDS * features_sample_rate))
example_hop_length = int(round(
vggish_params.EXAMPLE_HOP_SECONDS * features_sample_rate))
log_mel_examples = mel_features.frame(
log_mel,
window_length=example_window_length,
hop_length=example_hop_length)
return log_mel_examples
def wavfile_to_examples(wav_file):
"""Convenience wrapper around waveform_to_examples() for a common WAV format.
Args:
wav_file: String path to a file, or a file-like object. The file
is assumed to contain WAV audio data with signed 16-bit PCM samples.
Returns:
See waveform_to_examples.
"""
wav_data, sr = wav_read(wav_file)
assert wav_data.dtype == np.int16, 'Bad sample type: %r' % wav_data.dtype
samples = wav_data / 32768.0 # Convert to [-1.0, +1.0]
return waveform_to_examples(samples, sr)