Commit cbfde17b authored by hwangjeff's avatar hwangjeff Committed by Facebook GitHub Bot
Browse files

Add pretrained VGGish inference pipeline (#3491)

Summary:
Adds pre-trained VGGish inference pipeline ported from https://github.com/harritaylor/torchvggish and https://github.com/tensorflow/models/tree/master/research/audioset.

Pull Request resolved: https://github.com/pytorch/audio/pull/3491

Reviewed By: mthrok

Differential Revision: D47738130

Pulled By: hwangjeff

fbshipit-source-id: 859c1ff1ec1b09dae4e26586169544571657cc67
parent 8a2e12d3
......@@ -13,6 +13,12 @@
{%- elif name == "Tacotron2TTSBundle.Vocoder" %}
{%- set attributes=["sample_rate"] %}
{%- set methods = ["__call__"] %}
{%- elif name == "VGGishBundle.VGGish" %}
{%- set attributes = [] %}
{%- set methods = ["forward"] %}
{%- elif name == "VGGishBundle.VGGishInputProcessor" %}
{%- set attributes = [] %}
{%- set methods = ["__call__"] %}
{% endif %}
..
......
......@@ -45,3 +45,28 @@ Pretrained Models
:template: autosummary/bundle_data.rst
HIFIGAN_VOCODER_V3_LJSPEECH
VGGish
------
Interface
~~~~~~~~~
.. autosummary::
:toctree: generated
:nosignatures:
:template: autosummary/bundle_class.rst
VGGishBundle
VGGishBundle.VGGish
VGGishBundle.VGGishInputProcessor
Pretrained Models
~~~~~~~~~~~~~~~~~
.. autosummary::
:toctree: generated
:nosignatures:
:template: autosummary/bundle_data.rst
VGGISH
......@@ -562,3 +562,11 @@ abstract = {End-to-end spoken language translation (SLT) has recently gained pop
journal={arXiv preprint arXiv:2304.01448},
year={2023}
}
@incollection{45611,
title = {CNN Architectures for Large-Scale Audio Classification},
author = {Shawn Hershey and Sourish Chaudhuri and Daniel P. W. Ellis and Jort F. Gemmeke and Aren Jansen and Channing Moore and Manoj Plakal and Devin Platt and Rif A. Saurous and Bryan Seybold and Malcolm Slaney and Ron Weiss and Kevin Wilson},
year = {2017},
URL = {https://arxiv.org/abs/1609.09430},
booktitle = {International Conference on Acoustics, Speech and Signal Processing (ICASSP)}
}
import torchaudio
from torchaudio.prototype.pipelines import VGGISH
def test_vggish():
    """Smoke test for the pretrained VGGish pipeline: processor and model output shapes."""
    bundle_sr = VGGISH.sample_rate
    vggish = VGGISH.get_model()
    processor = VGGISH.get_input_processor()

    asset_path = torchaudio.utils.download_asset("test-assets/Chopin_Ballade_-1_In_G_Minor,_Op._23_excerpt.mp3")
    audio, original_sr = torchaudio.load(asset_path, backend="ffmpeg")
    # Downmix to mono, then resample to the rate the bundle expects.
    mono = audio.mean(axis=0)
    mono = torchaudio.functional.resample(mono, original_sr, bundle_sr)

    examples = processor(mono)
    assert examples.shape == (62, 1, 96, 64)
    embeddings = vggish(examples)
    assert embeddings.shape == (62, 128)
from ._vggish import VGGISH, VGGishBundle
from .hifigan_pipeline import HIFIGAN_VOCODER_V3_LJSPEECH, HiFiGANVocoderBundle
from .rnnt_pipeline import EMFORMER_RNNT_BASE_MUSTC, EMFORMER_RNNT_BASE_TEDLIUM3
......@@ -6,4 +7,6 @@ __all__ = [
"EMFORMER_RNNT_BASE_TEDLIUM3",
"HIFIGAN_VOCODER_V3_LJSPEECH",
"HiFiGANVocoderBundle",
"VGGISH",
"VGGishBundle",
]
from ._vggish_pipeline import VGGISH, VGGishBundle
__all__ = ["VGGISH", "VGGishBundle"]
# Derived from torchvggish (https://github.com/harritaylor/torchvggish).
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import math
import torch
_MEL_BREAK_FREQUENCY_HERTZ = 700.0
_MEL_HIGH_FREQUENCY_Q = 1127.0
_SAMPLE_RATE = 16000
_STFT_WINDOW_LENGTH_SECONDS = 0.025
_STFT_HOP_LENGTH_SECONDS = 0.010
_MEL_MIN_HZ = 125
_MEL_MAX_HZ = 7500
_NUM_BANDS = 64
_LOG_OFFSET = 0.01
_EXAMPLE_WINDOW_SECONDS = 0.96 # Each example contains 96 10ms frames
_EXAMPLE_HOP_SECONDS = 0.96 # with zero overlap.
def _build_features_network():
layers = []
for input_dim, output_dim in [(1, 64), (64, 128)]:
layers += [
torch.nn.Conv2d(input_dim, output_dim, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
torch.nn.ReLU(inplace=True),
torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False),
]
for input_dim, output_dim in [(128, 256), (256, 512)]:
layers += [
torch.nn.Conv2d(input_dim, output_dim, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
torch.nn.ReLU(inplace=True),
torch.nn.Conv2d(
output_dim,
output_dim,
kernel_size=(3, 3),
stride=(1, 1),
padding=(1, 1),
),
torch.nn.ReLU(inplace=True),
torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False),
]
return torch.nn.Sequential(*layers)
def _build_embedding_network():
return torch.nn.Sequential(
torch.nn.Linear(512 * 4 * 6, 4096),
torch.nn.ReLU(True),
torch.nn.Linear(4096, 4096),
torch.nn.ReLU(True),
torch.nn.Linear(4096, 128),
torch.nn.ReLU(True),
)
def _frame(data, window_length, hop_length):
num_samples = data.shape[0]
num_frames = 1 + int(math.floor((num_samples - window_length) / hop_length))
shape = (num_frames, window_length) + data.shape[1:]
strides = (data.stride()[0] * hop_length,) + data.stride()
return torch.as_strided(data, shape, strides)
def _stft_magnitude(signal, fft_length, hop_length=None, window_length=None):
    """Computes the magnitude STFT of a 1-D ``signal``.

    Frames the signal, applies a periodic Hann window, and returns ``|rfft|``
    with shape ``(n_frames, fft_length // 2 + 1)``.
    """
    framed = _frame(signal, window_length, hop_length)
    hann = torch.hann_window(window_length, periodic=True).to(signal.device)
    return torch.fft.rfft(framed * hann, int(fft_length)).abs()
def _hertz_to_mel(frequencies_hertz):
    """Converts frequencies in hertz to the HTK mel scale."""
    ratio = frequencies_hertz / _MEL_BREAK_FREQUENCY_HERTZ
    return _MEL_HIGH_FREQUENCY_Q * torch.log(1.0 + ratio)
def _spectrogram_to_mel_matrix(
    num_mel_bins=20,
    num_spectrogram_bins=129,
    audio_sample_rate=8000,
    lower_edge_hertz=125.0,
    upper_edge_hertz=3800.0,
):
    """Returns a matrix mapping linear spectrogram bins to mel bands.

    The result has shape ``(num_spectrogram_bins, num_mel_bins)`` and is meant
    to post-multiply feature arrays whose rows are spectrogram frames.

    Raises:
        ValueError: if the edge frequencies are out of order or exceed Nyquist.
    """
    nyquist_hertz = audio_sample_rate / 2.0
    if lower_edge_hertz < 0.0:
        raise ValueError("lower_edge_hertz %.1f must be >= 0" % lower_edge_hertz)
    if lower_edge_hertz >= upper_edge_hertz:
        raise ValueError("lower_edge_hertz %.1f >= upper_edge_hertz %.1f" % (lower_edge_hertz, upper_edge_hertz))
    if upper_edge_hertz > nyquist_hertz:
        raise ValueError("upper_edge_hertz %.1f is greater than Nyquist %.1f" % (upper_edge_hertz, nyquist_hertz))

    bins_mel = _hertz_to_mel(torch.linspace(0.0, nyquist_hertz, num_spectrogram_bins))
    # num_mel_bins + 2 edges: band i has lower edge edges_mel[i], center
    # edges_mel[i + 1], and upper edge edges_mel[i + 2].
    edges_mel = torch.linspace(
        _hertz_to_mel(torch.tensor(lower_edge_hertz)),
        _hertz_to_mel(torch.tensor(upper_edge_hertz)),
        num_mel_bins + 2,
    )

    weights = torch.empty((num_spectrogram_bins, num_mel_bins))
    for band in range(num_mel_bins):
        lo, center, hi = edges_mel[band], edges_mel[band + 1], edges_mel[band + 2]
        # Triangular responses are linear in the *mel* domain, not hertz:
        # a rising slope up to the center, a falling slope beyond it,
        # intersected with each other and clipped at zero.
        rising = (bins_mel - lo) / (center - lo)
        falling = (hi - bins_mel) / (hi - center)
        weights[:, band] = torch.clamp(torch.minimum(rising, falling), min=0.0)
    # HTK excludes the spectrogram DC bin; force its coefficients to zero.
    weights[0, :] = 0.0
    return weights
def _log_mel_spectrogram(
    data,
    audio_sample_rate=8000,
    log_offset=0.0,
    window_length_secs=0.025,
    hop_length_secs=0.010,
    **kwargs,
):
    """Computes a log-magnitude mel spectrogram of 1-D waveform ``data``.

    ``log_offset`` is added before taking the log to avoid ``log(0)``.
    Remaining ``kwargs`` are forwarded to :func:`_spectrogram_to_mel_matrix`.
    """
    window_samples = int(round(audio_sample_rate * window_length_secs))
    hop_samples = int(round(audio_sample_rate * hop_length_secs))
    # Smallest power of two that covers the analysis window.
    fft_length = 2 ** int(math.ceil(math.log(window_samples) / math.log(2.0)))
    magnitudes = _stft_magnitude(
        data,
        fft_length=fft_length,
        hop_length=hop_samples,
        window_length=window_samples,
    )
    mel_matrix = _spectrogram_to_mel_matrix(
        num_spectrogram_bins=magnitudes.shape[1],
        audio_sample_rate=audio_sample_rate,
        **kwargs,
    ).to(magnitudes)
    return torch.log(torch.matmul(magnitudes, mel_matrix) + log_offset)
def _waveform_to_examples(data):
    """Converts a 1-D 16 kHz waveform into a batch of VGGish input examples.

    Returns:
        torch.Tensor: log-mel examples with shape ``(n_example, 1, n_frame, n_mel)``.
    """
    # Log mel spectrogram with shape (n_frame, n_mel).
    log_mel = _log_mel_spectrogram(
        data,
        audio_sample_rate=_SAMPLE_RATE,
        log_offset=_LOG_OFFSET,
        window_length_secs=_STFT_WINDOW_LENGTH_SECONDS,
        hop_length_secs=_STFT_HOP_LENGTH_SECONDS,
        num_mel_bins=_NUM_BANDS,
        lower_edge_hertz=_MEL_MIN_HZ,
        upper_edge_hertz=_MEL_MAX_HZ,
    )
    # Group frames into fixed-size, non-overlapping examples:
    # (n_example, n_frame, n_mel).
    frames_per_second = 1.0 / _STFT_HOP_LENGTH_SECONDS
    example_window = int(round(_EXAMPLE_WINDOW_SECONDS * frames_per_second))
    example_hop = int(round(_EXAMPLE_HOP_SECONDS * frames_per_second))
    examples = _frame(log_mel, window_length=example_window, hop_length=example_hop)
    # Insert a singleton channel dimension: (n_example, 1, n_frame, n_mel).
    return examples.unsqueeze(1)
class VGGish(torch.nn.Module):
    """Implementation of VGGish model :cite:`45611`."""

    def __init__(self):
        super().__init__()
        # Attribute names must stay as-is: they are the prefixes of the
        # pretrained checkpoint's state-dict keys.
        self.features_network = _build_features_network()
        self.embedding_network = _build_embedding_network()

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """
        Args:
            input (torch.Tensor): batch of spectrograms, with shape `(n_example, 1, n_frame, 64)`.

        Returns:
            torch.Tensor: model output, with shape `(n_example, 128)`.
        """
        features = self.features_network(input)
        # Move channels last before flattening, as in the original torchvggish port.
        features = features.permute(0, 2, 3, 1)
        flattened = features.reshape(features.size(0), -1)
        return self.embedding_network(flattened)
class VGGishInputProcessor:
    """Converts raw waveforms to batches of examples to use as inputs to VGGish."""

    def __call__(self, input: torch.Tensor) -> torch.Tensor:
        """
        Args:
            input (torch.Tensor): waveform, with shape `(T,)`, sampled at 16 kHz.

        Returns:
            torch.Tensor: batch of examples to pass to VGGish, with shape `(n_example, 1, n_frame, 64)`.

        Raises:
            ValueError: if ``input`` is not one-dimensional.
        """
        # Only mono 1-D waveforms are supported; callers must downmix first.
        if input.dim() != 1:
            raise ValueError("input waveform must have dimension of 1.")
        return _waveform_to_examples(input)
from dataclasses import dataclass
import torch
import torchaudio
from ._vggish_impl import _SAMPLE_RATE, VGGish as _VGGish, VGGishInputProcessor as _VGGishInputProcessor
@dataclass
class VGGishBundle:
    """VGGish :cite:`45611` inference pipeline ported from
    `torchvggish <https://github.com/harritaylor/torchvggish>`__
    and `tensorflow-models <https://github.com/tensorflow/models/tree/master/research/audioset>`__.

    Example:
        >>> import torchaudio
        >>> from torchaudio.prototype.pipelines import VGGISH
        >>>
        >>> input_sr = VGGISH.sample_rate
        >>> input_proc = VGGISH.get_input_processor()
        >>> model = VGGISH.get_model()
        >>>
        >>> waveform, sr = torchaudio.load(
        >>>     "Chopin_Ballade_-1_In_G_Minor,_Op._23.mp3",
        >>> )
        >>> waveform = waveform.squeeze(0)
        >>> waveform = torchaudio.functional.resample(waveform, sr, input_sr)
        >>> mono_output = model(input_proc(waveform))
    """

    # Nested aliases of the implementation classes with their docstrings copied
    # over, so the docs render them as VGGishBundle.VGGish /
    # VGGishBundle.VGGishInputProcessor (matching the autosummary templates).
    class VGGish(_VGGish):
        __doc__ = _VGGish.__doc__

    class VGGishInputProcessor(_VGGishInputProcessor):
        __doc__ = _VGGishInputProcessor.__doc__

    # Asset path of the pretrained weights, resolved via
    # torchaudio.utils.download_asset in get_model().
    _weights_path: str

    @property
    def sample_rate(self) -> int:
        """Sample rate of input waveform expected by input processor and model.

        :type: int
        """
        return _SAMPLE_RATE

    def get_model(self) -> VGGish:
        """Constructs pre-trained VGGish model. Downloads and caches weights as necessary.

        Returns:
            VGGish: VGGish model with pre-trained weights loaded.
        """
        model = self.VGGish()
        path = torchaudio.utils.download_asset(self._weights_path)
        # NOTE(review): torch.load here unpickles the checkpoint; weights are
        # assumed to come from torchaudio's trusted asset store — confirm.
        state_dict = torch.load(path)
        model.load_state_dict(state_dict)
        # Inference pipeline: returned model is always in eval mode.
        model.eval()
        return model

    def get_input_processor(self) -> VGGishInputProcessor:
        """Constructs input processor for VGGish.

        Returns:
            VGGishInputProcessor: input processor for VGGish.
        """
        return self.VGGishInputProcessor()
VGGISH = VGGishBundle("models/vggish.pt")
VGGISH.__doc__ = """Pre-trained VGGish :cite:`45611` inference pipeline ported from
`torchvggish <https://github.com/harritaylor/torchvggish>`__
and `tensorflow-models <https://github.com/tensorflow/models/tree/master/research/audioset>`__.
Per the `documentation <https://github.com/tensorflow/models/tree/master/research/audioset/vggish>`__
for the original model, the model is "trained on a large YouTube dataset (a preliminary version of
what later became YouTube-8M)".
"""
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment