Unverified commit 27031755, authored by Nicolas Hug, committed by GitHub
Browse files

Add SpectralCentroid transform (#1167)

parent 5547f204
...@@ -283,3 +283,14 @@ class TestTransforms(common_utils.TorchaudioTestCase): ...@@ -283,3 +283,14 @@ class TestTransforms(common_utils.TorchaudioTestCase):
# Batch then transform # Batch then transform
computed = torchaudio.transforms.Vol(gain=1.1)(waveform.repeat(3, 1, 1)) computed = torchaudio.transforms.Vol(gain=1.1)(waveform.repeat(3, 1, 1))
self.assertEqual(computed, expected) self.assertEqual(computed, expected)
def test_batch_spectral_centroid(self):
    """Check that SpectralCentroid gives equal results when batching happens before or after it."""
    sample_rate = 44100
    waveform = common_utils.get_whitenoise(sample_rate=sample_rate)

    # Reference: transform the single waveform, then tile the result three times.
    single_transform = torchaudio.transforms.SpectralCentroid(sample_rate)
    expected = single_transform(waveform).repeat(3, 1, 1)

    # Candidate: tile the waveform three times first, then transform the whole batch.
    batched_waveform = waveform.repeat(3, 1, 1)
    batch_transform = torchaudio.transforms.SpectralCentroid(sample_rate)
    computed = batch_transform(batched_waveform)

    self.assertEqual(computed, expected)
...@@ -231,6 +231,19 @@ class TestTransforms(common_utils.TorchaudioTestCase): ...@@ -231,6 +231,19 @@ class TestTransforms(common_utils.TorchaudioTestCase):
self.assertEqual( self.assertEqual(
torch_mfcc.type(librosa_mfcc_tensor.dtype), librosa_mfcc_tensor, atol=5e-3, rtol=1e-5) torch_mfcc.type(librosa_mfcc_tensor.dtype), librosa_mfcc_tensor, atol=5e-3, rtol=1e-5)
self.assert_compatibilities_spectral_centroid(sample_rate, n_fft, hop_length, sound, sound_librosa)
def assert_compatibilities_spectral_centroid(self, sample_rate, n_fft, hop_length, sound, sound_librosa):
    """Assert that torchaudio's SpectralCentroid agrees with librosa's spectral_centroid.

    Args:
        sample_rate: Sample rate passed to both implementations.
        n_fft: FFT size passed to both implementations.
        hop_length: Hop length passed to both implementations.
        sound: Input handed to the torchaudio transform.
        sound_librosa: Input handed to librosa (as ``y``).
    """
    # Both libraries receive the same STFT settings.
    stft_kwargs = dict(n_fft=n_fft, hop_length=hop_length)

    transform = torchaudio.transforms.SpectralCentroid(sample_rate=sample_rate, **stft_kwargs)
    centroid_torch = transform(sound).squeeze().cpu()

    centroid_np = librosa.feature.spectral_centroid(y=sound_librosa, sr=sample_rate, **stft_kwargs)
    # Only the first row of the librosa result is compared.
    centroid_librosa = torch.from_numpy(centroid_np)[0]

    # Cast the torch result to librosa's dtype so the comparison is like-for-like.
    self.assertEqual(centroid_torch.type(centroid_librosa.dtype), centroid_librosa, atol=1e-5, rtol=1e-5)
def test_basics1(self): def test_basics1(self):
kwargs = { kwargs = {
'n_fft': 400, 'n_fft': 400,
......
...@@ -535,6 +535,20 @@ class Functional(common_utils.TestBaseMixin): ...@@ -535,6 +535,20 @@ class Functional(common_utils.TestBaseMixin):
self._assert_consistency(func, waveform) self._assert_consistency(func, waveform)
def test_spectral_centroid(self):
    """Run the consistency check on F.spectral_centroid using a Hann window and white noise."""
    def func(tensor):
        # Every setting is a local of ``func`` so the function is fully self-contained.
        sample_rate = 44100
        pad = 0
        n_fft = 400
        win_length = 400
        hop_length = 200
        # Build the window on the same device and dtype as the input.
        hann = torch.hann_window(win_length, device=tensor.device, dtype=tensor.dtype)
        return F.spectral_centroid(tensor, sample_rate, pad, hann, n_fft, hop_length, win_length)

    noise = common_utils.get_whitenoise(sample_rate=44100)
    self._assert_consistency(func, noise)
class Transforms(common_utils.TestBaseMixin): class Transforms(common_utils.TestBaseMixin):
"""Implements test for Transforms that are performed for different devices""" """Implements test for Transforms that are performed for different devices"""
...@@ -624,3 +638,8 @@ class Transforms(common_utils.TestBaseMixin): ...@@ -624,3 +638,8 @@ class Transforms(common_utils.TestBaseMixin):
filepath = common_utils.get_asset_path("vad-go-mono-32000.wav") filepath = common_utils.get_asset_path("vad-go-mono-32000.wav")
waveform, sample_rate = common_utils.load_wav(filepath) waveform, sample_rate = common_utils.load_wav(filepath)
self._assert_consistency(T.Vad(sample_rate=sample_rate), waveform) self._assert_consistency(T.Vad(sample_rate=sample_rate), waveform)
def test_SpectralCentroid(self):
    """Run the consistency check on the SpectralCentroid transform with white noise."""
    sample_rate = 44100
    noise = common_utils.get_whitenoise(sample_rate=sample_rate)
    transform = T.SpectralCentroid(sample_rate=sample_rate)
    self._assert_consistency(transform, noise)
...@@ -16,6 +16,7 @@ from .functional import ( ...@@ -16,6 +16,7 @@ from .functional import (
phase_vocoder, phase_vocoder,
sliding_window_cmn, sliding_window_cmn,
spectrogram, spectrogram,
spectral_centroid,
) )
from .filtering import ( from .filtering import (
allpass_biquad, allpass_biquad,
......
...@@ -27,6 +27,7 @@ __all__ = [ ...@@ -27,6 +27,7 @@ __all__ = [
'mask_along_axis', 'mask_along_axis',
'mask_along_axis_iid', 'mask_along_axis_iid',
'sliding_window_cmn', 'sliding_window_cmn',
"spectral_centroid",
] ]
...@@ -935,3 +936,38 @@ def sliding_window_cmn( ...@@ -935,3 +936,38 @@ def sliding_window_cmn(
if len(input_shape) == 2: if len(input_shape) == 2:
cmn_waveform = cmn_waveform.squeeze(0) cmn_waveform = cmn_waveform.squeeze(0)
return cmn_waveform return cmn_waveform
def spectral_centroid(
        waveform: Tensor,
        sample_rate: int,
        pad: int,
        window: Tensor,
        n_fft: int,
        hop_length: int,
        win_length: int,
) -> Tensor:
    r"""
    Compute the spectral centroid for each channel along the time axis.

    The spectral centroid is defined as the weighted average of the
    frequency values, weighted by their magnitude.

    Args:
        waveform (Tensor): Tensor of audio of dimension (..., time)
        sample_rate (int): Sample rate of the audio waveform
        pad (int): Two sided padding of signal
        window (Tensor): Window tensor that is applied/multiplied to each frame/window
        n_fft (int): Size of FFT
        hop_length (int): Length of hop between STFT windows
        win_length (int): Window size

    Returns:
        Tensor: Dimension (..., time). A frame whose magnitudes sum to zero
        yields ``nan`` (0 / 0), because no guard is applied to the denominator.
    """
    # Magnitude (power=1) spectrogram; the frequency axis is the second-to-last dimension.
    specgram = spectrogram(waveform, pad=pad, window=window, n_fft=n_fft, hop_length=hop_length,
                           win_length=win_length, power=1., normalized=False)

    # Centre frequency of each one-sided FFT bin: k * sample_rate / n_fft for k = 0 .. n_fft // 2.
    # True division keeps the exact Nyquist value for odd sample rates, and deriving the
    # frequencies from the bin index (instead of stretching a linspace up to sample_rate / 2)
    # also gives the right top bin when n_fft is odd.
    n_bins = 1 + n_fft // 2
    bin_index = torch.arange(n_bins, device=specgram.device, dtype=specgram.dtype)
    freqs = (bin_index * (sample_rate / n_fft)).reshape((-1, 1))

    freq_dim = -2
    weighted_sum = (freqs * specgram).sum(dim=freq_dim)
    total_magnitude = specgram.sum(dim=freq_dim)
    return weighted_sum / total_magnitude
...@@ -28,6 +28,7 @@ __all__ = [ ...@@ -28,6 +28,7 @@ __all__ = [
'TimeMasking', 'TimeMasking',
'SlidingWindowCmn', 'SlidingWindowCmn',
'Vad', 'Vad',
'SpectralCentroid',
] ]
...@@ -1037,3 +1038,54 @@ class Vad(torch.nn.Module): ...@@ -1037,3 +1038,54 @@ class Vad(torch.nn.Module):
hp_lifter_freq=self.hp_lifter_freq, hp_lifter_freq=self.hp_lifter_freq,
lp_lifter_freq=self.lp_lifter_freq, lp_lifter_freq=self.lp_lifter_freq,
) )
class SpectralCentroid(torch.nn.Module):
    r"""Compute the spectral centroid for each channel along the time axis.

    The spectral centroid is the magnitude-weighted average of the frequency values.

    Args:
        sample_rate (int): Sample rate of audio signal.
        n_fft (int, optional): Size of FFT, creates ``n_fft // 2 + 1`` bins. (Default: ``400``)
        win_length (int or None, optional): Window size. (Default: ``n_fft``)
        hop_length (int or None, optional): Length of hop between STFT windows. (Default: ``win_length // 2``)
        pad (int, optional): Two sided padding of signal. (Default: ``0``)
        window (Tensor, optional): A window tensor that is applied/multiplied to each frame.
            (Default: ``torch.hann_window(win_length)``)

    Example
        >>> waveform, sample_rate = torchaudio.load('test.wav', normalization=True)
        >>> spectral_centroid = transforms.SpectralCentroid(sample_rate)(waveform)  # (channel, time)
    """
    __constants__ = ['sample_rate', 'n_fft', 'win_length', 'hop_length', 'pad']

    def __init__(self,
                 sample_rate: int,
                 n_fft: int = 400,
                 win_length: Optional[int] = None,
                 hop_length: Optional[int] = None,
                 pad: int = 0,
                 window: Optional[Tensor] = None) -> None:
        super(SpectralCentroid, self).__init__()

        # Resolve the dependent defaults first: the window size falls back to n_fft,
        # and the hop falls back to half of the (resolved) window size.
        if win_length is None:
            win_length = n_fft
        if hop_length is None:
            hop_length = win_length // 2

        self.sample_rate = sample_rate
        self.n_fft = n_fft
        self.win_length = win_length
        self.hop_length = hop_length
        self.pad = pad

        # The window is registered as a buffer rather than stored as a plain attribute;
        # a Hann window of the resolved size is used when the caller supplies none.
        self.register_buffer('window', torch.hann_window(win_length) if window is None else window)

    def forward(self, waveform: Tensor) -> Tensor:
        r"""
        Args:
            waveform (Tensor): Tensor of audio of dimension (..., time).

        Returns:
            Tensor: Spectral Centroid of size (..., time).
        """
        return F.spectral_centroid(
            waveform,
            self.sample_rate,
            self.pad,
            self.window,
            self.n_fft,
            self.hop_length,
            self.win_length,
        )
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment