"src/libtorio/ffmpeg/pybind/pybind.cpp" did not exist on "4eac61a3a6c908bd85ab45c3cba26217afaab55e"
Commit 70968293 authored by Jeff Hwang's avatar Jeff Hwang Committed by Facebook GitHub Bot
Browse files

Introduce chroma spectrogram transform (#3427)

Summary:
Pull Request resolved: https://github.com/pytorch/audio/pull/3427

Adds transform `ChromaSpectrogram` for generating chromagrams from waveforms as well as transform `ChromaScale` for generating chromagrams from linear-frequency spectrograms.

Reviewed By: mthrok

Differential Revision: D46547418

fbshipit-source-id: 250f298b8e11d8cf82f05536c29d51cf8d77a960
parent 627c37a9
......@@ -10,5 +10,7 @@ torchaudio.prototype.transforms
:nosignatures:
BarkScale
BarkSpectrogram
ChromaScale
ChromaSpectrogram
InverseBarkScale
......@@ -44,3 +44,19 @@ class Autograd(TestBaseMixin):
get_whitenoise(sample_rate=sample_rate, duration=0.05, n_channels=2), n_fft=n_fft, power=1
)
self.assert_grad(transform, [spec])
def test_chroma_spectrogram(self):
    """Gradients propagate through ChromaSpectrogram on white noise input."""
    sr = 8000
    noise = get_whitenoise(sample_rate=sr, duration=0.05, n_channels=2)
    xform = T.ChromaSpectrogram(sample_rate=sr, n_fft=400)
    self.assert_grad(xform, [noise], nondet_tol=1e-10)
def test_chroma_scale(self):
    """Gradients propagate through ChromaScale on a white-noise spectrogram."""
    sr = 8000
    fft_size = 400
    num_chroma = 12
    noise = get_whitenoise(sample_rate=sr, duration=0.05, n_channels=2)
    spec = get_spectrogram(noise, n_fft=fft_size, power=1)
    xform = T.ChromaScale(sample_rate=sr, n_freqs=fft_size // 2 + 1, n_chroma=num_chroma)
    self.assert_grad(xform, [spec], nondet_tol=1e-10)
......@@ -39,3 +39,20 @@ class BatchConsistencyTest(TorchaudioTestCase):
# Because InverseBarkScale runs SGD on randomly initialized values so they do not yield
# exactly same result. For this reason, tolerance is very relaxed here.
self.assert_batch_consistency(transform, bark_spec, atol=1.0, rtol=1e-5)
def test_batch_chroma_scale(self):
    """Batched input through ChromaScale matches per-item application."""
    num_freqs = 201
    batch = torch.randn(3, 2, num_freqs, 256)
    # Windows needs a looser tolerance here.
    tolerance = 1e-6 if os.name == "nt" else 1e-8
    xform = T.ChromaScale(16000, num_freqs, n_chroma=12)
    self.assert_batch_consistency(xform, batch, atol=tolerance)
def test_batch_chroma_spectrogram(self):
    """Batched waveforms through ChromaSpectrogram match per-item application."""
    batch = torch.randn(3, 2, 4000)
    # Windows needs a looser tolerance here.
    tolerance = 1e-6 if os.name == "nt" else 1e-8
    xform = T.ChromaSpectrogram(16000, 512, n_chroma=12)
    self.assert_batch_consistency(xform, batch, atol=tolerance)
import torch
from torchaudio_unittest.common_utils import PytorchTestCase
from .librosa_compatibility_test_impl import TransformsTestBase
class TestTransforms(TransformsTestBase, PytorchTestCase):
    # Runs the shared librosa-compatibility suite on CPU in float64.
    dtype = torch.float64
    device = torch.device("cpu")
import torch
from torchaudio_unittest.common_utils import PytorchTestCase, skipIfNoCuda
from .librosa_compatibility_test_impl import TransformsTestBase
@skipIfNoCuda
class TestTransforms(TransformsTestBase, PytorchTestCase):
    # Runs the shared librosa-compatibility suite on CUDA in float64;
    # skipped entirely when no CUDA device is available.
    dtype = torch.float64
    device = torch.device("cuda")
import unittest
import torch
import torchaudio.prototype.transforms as T
from parameterized import param
from torchaudio._internal.module_utils import is_module_available
from torchaudio_unittest.common_utils import get_sinusoid, nested_params, TestBaseMixin
LIBROSA_AVAILABLE = is_module_available("librosa")
if LIBROSA_AVAILABLE:
import librosa
@unittest.skipIf(not LIBROSA_AVAILABLE, "Librosa not available")
class TransformsTestBase(TestBaseMixin):
    """Compares prototype transforms against librosa reference implementations."""

    @nested_params(
        [
            param(n_fft=400, hop_length=200, n_chroma=13),
            param(n_fft=600, hop_length=100, n_chroma=24),
            param(n_fft=200, hop_length=50, n_chroma=12),
        ],
    )
    def test_chroma_spectrogram(self, n_fft, hop_length, n_chroma):
        """ChromaSpectrogram output matches ``librosa.feature.chroma_stft``."""
        sr = 16000
        signal = get_sinusoid(sample_rate=sr, n_channels=1).to(self.device, self.dtype)

        # Reference chromagram from librosa, with tuning pinned to 0 and
        # normalization disabled so both sides compute the same quantity.
        reference = librosa.feature.chroma_stft(
            y=signal[0].cpu().numpy(),
            sr=sr,
            n_fft=n_fft,
            hop_length=hop_length,
            n_chroma=n_chroma,
            norm=None,
            pad_mode="reflect",
            tuning=0.0,
        )

        transform = T.ChromaSpectrogram(
            sample_rate=sr,
            n_fft=n_fft,
            hop_length=hop_length,
            n_chroma=n_chroma,
            window_fn=torch.hann_window,
            tuning=0.0,
        ).to(self.device, self.dtype)
        actual = transform(signal)[0]
        self.assertEqual(actual, reference, atol=5e-4, rtol=1e-4)
from ._transforms import BarkScale, BarkSpectrogram, InverseBarkScale
from ._transforms import BarkScale, BarkSpectrogram, ChromaScale, ChromaSpectrogram, InverseBarkScale
# Public API of this subpackage, kept in alphabetical order.
__all__ = [
    "BarkScale",
    "BarkSpectrogram",
    "ChromaScale",
    "ChromaSpectrogram",
    "InverseBarkScale",
]
from typing import Callable, Optional
import torch
from torchaudio.prototype.functional import barkscale_fbanks
from torchaudio.prototype.functional import barkscale_fbanks, chroma_filterbank
from torchaudio.transforms import Spectrogram
......@@ -295,3 +295,162 @@ class BarkSpectrogram(torch.nn.Module):
specgram = self.spectrogram(waveform)
bark_specgram = self.bark_scale(specgram)
return bark_specgram
class ChromaScale(torch.nn.Module):
    r"""Converts spectrogram to chromagram.

    .. devices:: CPU CUDA

    .. properties:: Autograd

    Args:
        sample_rate (int): Sample rate of audio signal.
        n_freqs (int): Number of frequency bins in STFT. See ``n_fft`` in :class:`Spectrogram`.
        n_chroma (int, optional): Number of chroma. (Default: ``12``)
        tuning (float, optional): Tuning deviation from A440 in fractions of a chroma bin. (Default: 0.0)
        ctroct (float, optional): Center of Gaussian dominance window to weight filters by, in octaves. (Default: 5.0)
        octwidth (float or None, optional): Width of Gaussian dominance window to weight filters by, in octaves.
            If ``None``, then disable weighting altogether. (Default: 2.0)
        norm (int, optional): order of norm to normalize filter bank by. (Default: 2)
        base_c (bool, optional): If True, then start filter bank at C. Otherwise, start at A. (Default: True)

    Example
        >>> waveform, sample_rate = torchaudio.load("test.wav", normalize=True)
        >>> spectrogram_transform = transforms.Spectrogram(n_fft=1024)
        >>> spectrogram = spectrogram_transform(waveform)
        >>> chroma_transform = transforms.ChromaScale(sample_rate=sample_rate, n_freqs=1024 // 2 + 1)
        >>> chroma_spectrogram = chroma_transform(spectrogram)

    See also:
        :py:func:`torchaudio.prototype.functional.chroma_filterbank` — function used to
        generate the filter bank.
    """

    def __init__(
        self,
        sample_rate: int,
        n_freqs: int,
        *,
        n_chroma: int = 12,
        tuning: float = 0.0,
        ctroct: float = 5.0,
        octwidth: Optional[float] = 2.0,
        norm: int = 2,
        base_c: bool = True,
    ):
        super().__init__()
        # Precompute the frequency-to-chroma filter bank once; registering it as
        # a buffer lets it follow the module across device/dtype moves and be
        # included in ``state_dict`` without being a trainable parameter.
        filterbank = chroma_filterbank(
            sample_rate,
            n_freqs,
            n_chroma,
            tuning=tuning,
            ctroct=ctroct,
            octwidth=octwidth,
            norm=norm,
            base_c=base_c,
        )
        self.register_buffer("fb", filterbank)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        r"""
        Args:
            specgram (torch.Tensor): Spectrogram of dimension (..., ``n_freqs``, time).

        Returns:
            torch.Tensor: Chroma spectrogram of size (..., ``n_chroma``, time).
        """
        # Move the frequency axis last so the filter bank applies along it,
        # then restore the (..., chroma, time) layout.
        freq_last = x.transpose(-1, -2)
        chroma_last = freq_last @ self.fb
        return chroma_last.transpose(-1, -2)
class ChromaSpectrogram(torch.nn.Module):
    r"""Generates chromagram for audio signal.

    .. devices:: CPU CUDA

    .. properties:: Autograd

    Composes :py:class:`torchaudio.transforms.Spectrogram` and
    :py:class:`torchaudio.prototype.transforms.ChromaScale`.

    Args:
        sample_rate (int): Sample rate of audio signal.
        n_fft (int): Size of FFT, creates ``n_fft // 2 + 1`` bins.
        win_length (int or None, optional): Window size. (Default: ``n_fft``)
        hop_length (int or None, optional): Length of hop between STFT windows. (Default: ``win_length // 2``)
        pad (int, optional): Two sided padding of signal. (Default: ``0``)
        window_fn (Callable[..., torch.Tensor], optional): A function to create a window tensor
            that is applied/multiplied to each frame/window. (Default: ``torch.hann_window``)
        power (float, optional): Exponent for the magnitude spectrogram,
            (must be > 0) e.g., 1 for energy, 2 for power, etc. (Default: ``2``)
        normalized (bool, optional): Whether to normalize by magnitude after stft. (Default: ``False``)
        wkwargs (Dict[..., ...] or None, optional): Arguments for window function. (Default: ``None``)
        center (bool, optional): whether to pad :attr:`waveform` on both sides so
            that the :math:`t`-th frame is centered at time :math:`t \times \text{hop\_length}`.
            (Default: ``True``)
        pad_mode (string, optional): controls the padding method used when
            :attr:`center` is ``True``. (Default: ``"reflect"``)
        n_chroma (int, optional): Number of chroma. (Default: ``12``)
        tuning (float, optional): Tuning deviation from A440 in fractions of a chroma bin. (Default: 0.0)
        ctroct (float, optional): Center of Gaussian dominance window to weight filters by, in octaves. (Default: 5.0)
        octwidth (float or None, optional): Width of Gaussian dominance window to weight filters by, in octaves.
            If ``None``, then disable weighting altogether. (Default: 2.0)
        norm (int, optional): order of norm to normalize filter bank by. (Default: 2)
        base_c (bool, optional): If True, then start filter bank at C. Otherwise, start at A. (Default: True)

    Example
        >>> waveform, sample_rate = torchaudio.load("test.wav", normalize=True)
        >>> transform = transforms.ChromaSpectrogram(sample_rate=sample_rate, n_fft=400)
        >>> chromagram = transform(waveform)  # (channel, n_chroma, time)
    """

    def __init__(
        self,
        sample_rate: int,
        n_fft: int,
        *,
        win_length: Optional[int] = None,
        hop_length: Optional[int] = None,
        pad: int = 0,
        window_fn: Callable[..., torch.Tensor] = torch.hann_window,
        power: float = 2.0,
        normalized: bool = False,
        wkwargs: Optional[dict] = None,
        center: bool = True,
        pad_mode: str = "reflect",
        n_chroma: int = 12,
        tuning: float = 0.0,
        ctroct: float = 5.0,
        octwidth: Optional[float] = 2.0,
        norm: int = 2,
        base_c: bool = True,
    ):
        super().__init__()
        # Onesided STFT yields n_fft // 2 + 1 frequency bins, which must match
        # the n_freqs the chroma filter bank below is built for.
        self.spectrogram = Spectrogram(
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            pad=pad,
            window_fn=window_fn,
            power=power,
            normalized=normalized,
            wkwargs=wkwargs,
            center=center,
            pad_mode=pad_mode,
            onesided=True,
        )
        self.chroma_scale = ChromaScale(
            sample_rate,
            n_fft // 2 + 1,
            n_chroma=n_chroma,
            tuning=tuning,
            base_c=base_c,
            ctroct=ctroct,
            octwidth=octwidth,
            norm=norm,
        )

    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
        r"""
        Args:
            waveform (Tensor): Tensor of audio of dimension (..., time).

        Returns:
            Tensor: Chromagram of size (..., ``n_chroma``, time).
        """
        # Waveform -> linear-frequency spectrogram -> chroma bins.
        spectrogram = self.spectrogram(waveform)
        chroma_spectrogram = self.chroma_scale(spectrogram)
        return chroma_spectrogram
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment