Commit dfd0c5fd authored by Jeff Hwang, committed by Facebook GitHub Bot
Browse files

Introduce chroma filter bank function (#3395)

Summary:
Pull Request resolved: https://github.com/pytorch/audio/pull/3395

Adds chroma filter bank function `chroma_filterbank` to `torchaudio.prototype.functional`.

Reviewed By: mthrok

Differential Revision: D46307672

fbshipit-source-id: c5d8104a8bb03da70d0629b5cc224e0d897148d5
parent 25e96f42
......@@ -4,10 +4,15 @@ torchaudio.prototype.functional
.. py:module:: torchaudio.prototype.functional
.. currentmodule:: torchaudio.prototype.functional
barkscale_fbanks
~~~~~~~~~~~~~~~~
Utility
~~~~~~~
.. autofunction:: barkscale_fbanks
.. autosummary::
:toctree: generated
:nosignatures:
barkscale_fbanks
chroma_filterbank
DSP
~~~
......
from torchaudio_unittest.common_utils import PytorchTestCase
from .librosa_compatibility_test_impl import Functional
class TestFunctionalCPU(Functional, PytorchTestCase):
    """Run the shared ``Functional`` librosa-compatibility tests with ``device`` set to CPU."""

    # Device string picked up by the tests inherited from ``Functional``.
    device = "cpu"
from torchaudio_unittest.common_utils import PytorchTestCase, skipIfNoCuda
from .librosa_compatibility_test_impl import Functional
@skipIfNoCuda
class TestFunctionalCUDA(Functional, PytorchTestCase):
    """Run the shared ``Functional`` librosa-compatibility tests with ``device`` set to CUDA."""

    # Device string picked up by the tests inherited from ``Functional``.
    device = "cuda"
import unittest
import torch
import torchaudio.prototype.functional as F
from torchaudio._internal.module_utils import is_module_available
LIBROSA_AVAILABLE = is_module_available("librosa")
if LIBROSA_AVAILABLE:
import librosa
import numpy as np
from torchaudio_unittest.common_utils import TestBaseMixin
@unittest.skipIf(not LIBROSA_AVAILABLE, "Librosa not available")
class Functional(TestBaseMixin):
    """Test suite for functions in `functional` module."""

    dtype = torch.float64

    def test_chroma_filterbank(self):
        """Check that ``F.chroma_filterbank`` matches ``librosa.filters.chroma`` (transposed)."""
        sample_rate = 16_000
        n_stft = 400
        n_chroma = 12
        tuning = 0.0
        ctroct = 5.0
        octwidth = 2.0
        norm = 2
        base_c = True

        # NOTE: difference in convention with librosa.
        # librosa expects the full count of FFT frequency bins (``n_fft``), whereas TorchAudio expects
        # the count with the redundant bins -- those in the upper half of the frequency range --
        # already removed. This is consistent with other TorchAudio filter bank functions.
        n_freqs = n_stft // 2 + 1

        torchaudio_fbank = F.chroma_filterbank(
            sample_rate=sample_rate,
            n_freqs=n_freqs,
            n_chroma=n_chroma,
            tuning=tuning,
            ctroct=ctroct,
            octwidth=octwidth,
            norm=norm,
            base_c=base_c,
        )

        # Both implementations must receive the same ``base_c``; pass the shared variable rather than
        # a literal so that the two calls cannot drift apart.
        librosa_fbank = librosa.filters.chroma(
            sr=sample_rate,
            n_fft=n_stft,
            n_chroma=n_chroma,
            tuning=tuning,
            ctroct=ctroct,
            octwidth=octwidth,
            norm=norm,
            base_c=base_c,
            dtype=np.float32,
        )

        # The librosa result is transposed to line up with the TorchAudio layout before comparing.
        self.assertEqual(torchaudio_fbank, librosa_fbank.T)
......@@ -8,13 +8,14 @@ from ._dsp import (
sinc_impulse_response,
)
from ._rir import simulate_rir_ism
from .functional import barkscale_fbanks
from .functional import barkscale_fbanks, chroma_filterbank
__all__ = [
"adsr_envelope",
"exp_sigmoid",
"barkscale_fbanks",
"chroma_filterbank",
"extend_pitch",
"filter_waveform",
"frequency_impulse_response",
......
import math
import warnings
from typing import Optional
import torch
from torchaudio.functional.functional import _create_triangular_filterbank
......@@ -66,6 +67,11 @@ def _bark_to_hz(barks: torch.Tensor, bark_scale: str = "traunmuller") -> torch.T
return freqs
def _hz_to_octs(freqs, tuning=0.0, bins_per_octave=12):
a440 = 440.0 * 2.0 ** (tuning / bins_per_octave)
return torch.log2(freqs / (a440 / 16))
def barkscale_fbanks(
n_freqs: int,
f_min: float,
......@@ -121,3 +127,64 @@ def barkscale_fbanks(
)
return fb
def chroma_filterbank(
    sample_rate: int,
    n_freqs: int,
    n_chroma: int,
    *,
    tuning: float = 0.0,
    ctroct: float = 5.0,
    octwidth: Optional[float] = 2.0,
    norm: int = 2,
    base_c: bool = True,
) -> torch.Tensor:
    """Create a frequency-to-chroma conversion matrix. Implementation adapted from librosa.

    Args:
        sample_rate (int): Sample rate.
        n_freqs (int): Number of input frequencies. Must be at least 2.
        n_chroma (int): Number of output chroma.
        tuning (float, optional): Tuning deviation from A440 in fractions of a chroma bin. (Default: 0.0)
        ctroct (float, optional): Center of Gaussian dominance window to weight filters by, in octaves. (Default: 5.0)
        octwidth (float or None, optional): Width of Gaussian dominance window to weight filters by, in octaves.
            If ``None``, then disable weighting altogether. (Default: 2.0)
        norm (int, optional): order of norm to normalize filter bank by. (Default: 2)
        base_c (bool, optional): If True, then start filter bank at C. Otherwise, start at A. (Default: True)

    Returns:
        torch.Tensor: Chroma filter bank, with shape `(n_freqs, n_chroma)`.

    Raises:
        ValueError: If ``n_freqs`` is less than 2.
    """
    # With fewer than two bins nothing is left once the 0 Hz bin is dropped below, which would
    # otherwise surface as an opaque indexing error on an empty tensor.
    if n_freqs < 2:
        raise ValueError(f"n_freqs must be at least 2, but got {n_freqs}.")

    # Bin frequencies cover only the non-redundant lower half of the spectrum, i.e. 0 Hz up to
    # ``sample_rate // 2``. The 0 Hz bin is dropped here and re-inserted below.
    # NOTE(review): floor division lowers the upper edge by 0.5 Hz for odd sample rates -- confirm intended.
    freqs = torch.linspace(0, sample_rate // 2, n_freqs)[1:]

    # Fractional chroma-bin position of every remaining frequency.
    freq_bins = n_chroma * _hz_to_octs(freqs, bins_per_octave=n_chroma, tuning=tuning)

    # Stand-in position for the dropped 0 Hz bin: 1.5 octaves below the first non-zero bin.
    freq_bins = torch.cat((freq_bins[:1] - 1.5 * n_chroma, freq_bins))

    # Width of each frequency bin in chroma bins, floored at 1; the last bin has no successor and
    # is assigned a width of 1.
    freq_bin_widths = torch.cat(
        (
            torch.maximum(freq_bins[1:] - freq_bins[:-1], torch.tensor(1.0)),
            torch.tensor([1]),
        )
    )

    # (n_freqs, n_chroma): distance from each frequency bin to each chroma index.
    D = freq_bins.unsqueeze(1) - torch.arange(0, n_chroma)

    n_chroma2 = round(n_chroma / 2)

    # Project to range [-n_chroma/2, n_chroma/2 - 1]
    D = torch.remainder(D + n_chroma2, n_chroma) - n_chroma2

    # Gaussian bump per frequency bin; the (n_freqs, 1) widths broadcast across the chroma axis.
    fb = torch.exp(-0.5 * (2 * D / freq_bin_widths.unsqueeze(1)) ** 2)
    fb = torch.nn.functional.normalize(fb, p=norm, dim=1)

    if octwidth is not None:
        # Gaussian dominance window over octaves, one weight per frequency bin, broadcast over chroma.
        octave_weights = torch.exp(-0.5 * (((freq_bins.unsqueeze(1) / n_chroma - ctroct) / octwidth) ** 2))
        fb = fb * octave_weights

    if base_c:
        # Rotate the chroma axis so that the filter bank starts at C instead of A.
        fb = torch.roll(fb, -3 * (n_chroma // 12), dims=1)

    return fb
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment