Commit b29a4639 authored by jamarshon's avatar jamarshon Committed by cpuhrsch

[BC] Standardization of Transforms/Functionals (#152)

parent 2271a7ae
......@@ -2,6 +2,8 @@ import math
import torch
import torchaudio
import torchaudio.functional as F
import pytest
import unittest
import test.common_utils
......@@ -11,10 +13,6 @@ if IMPORT_LIBROSA:
import numpy as np
import librosa
import pytest
import torchaudio.functional as F
xfail = pytest.mark.xfail
class TestFunctional(unittest.TestCase):
data_sizes = [(2, 20), (3, 15), (4, 10)]
......@@ -197,54 +195,6 @@ def _num_stft_bins(signal_len, fft_len, hop_length, pad):
return (signal_len + 2 * pad - fft_len + hop_length) // hop_length
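# Editor's note, illustrative only (not part of this commit): with the
# hypothetical values signal_len=100000, fft_len=512, hop_length=256 and
# pad = fft_len // 2 = 256, the formula above gives
# (100000 + 512 - 512 + 256) // 256 == 391 frames.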
@pytest.mark.parametrize('fft_length', [512])
@pytest.mark.parametrize('hop_length', [256])
@pytest.mark.parametrize('waveform', [
(torch.randn(1, 100000)),
(torch.randn(1, 2, 100000)),
pytest.param(torch.randn(1, 100), marks=xfail(raises=RuntimeError)),
])
@pytest.mark.parametrize('pad_mode', [
# 'constant',
'reflect',
])
@unittest.skipIf(not IMPORT_LIBROSA, 'Librosa is not available')
def test_stft(waveform, fft_length, hop_length, pad_mode):
"""
Test STFT for multi-channel signals.
Padding: is there value in having padding outside of torch.stft?
"""
pad = fft_length // 2
window = torch.hann_window(fft_length)
complex_spec = F.stft(waveform,
fft_length=fft_length,
hop_length=hop_length,
window=window,
pad_mode=pad_mode)
mag_spec, phase_spec = F.magphase(complex_spec)
# == Test shape
expected_size = list(waveform.size()[:-1])
expected_size += [fft_length // 2 + 1, _num_stft_bins(
waveform.size(-1), fft_length, hop_length, pad), 2]
assert complex_spec.dim() == waveform.dim() + 2
assert complex_spec.size() == torch.Size(expected_size)
# == Test values
fft_config = dict(n_fft=fft_length, hop_length=hop_length, pad_mode=pad_mode)
# note that librosa *automatically* pads with fft_length // 2.
expected_complex_spec = np.apply_along_axis(librosa.stft, -1,
waveform.numpy(), **fft_config)
expected_mag_spec, _ = librosa.magphase(expected_complex_spec)
# Convert torch to np.complex
complex_spec = complex_spec.numpy()
complex_spec = complex_spec[..., 0] + 1j * complex_spec[..., 1]
assert np.allclose(complex_spec, expected_complex_spec, atol=1e-5)
assert np.allclose(mag_spec.numpy(), expected_mag_spec, atol=1e-5)
@pytest.mark.parametrize('rate', [0.5, 1.01, 1.3])
@pytest.mark.parametrize('complex_specgrams', [
torch.randn(1, 2, 1025, 400, 2),
......
......@@ -30,40 +30,18 @@ class Test_JIT(unittest.TestCase):
self.assertTrue(torch.allclose(jit_out, py_out))
def test_torchscript_scale(self):
@torch.jit.script
def jit_method(tensor, factor):
# type: (Tensor, int) -> Tensor
return F.scale(tensor, factor)
tensor = torch.rand((10, 1))
factor = 2
jit_out = jit_method(tensor, factor)
py_out = F.scale(tensor, factor)
self.assertTrue(torch.allclose(jit_out, py_out))
@unittest.skipIf(not RUN_CUDA, "no CUDA")
def test_scriptmodule_scale(self):
tensor = torch.rand((10, 1), device="cuda")
self._test_script_module(tensor, transforms.Scale)
def test_torchscript_pad_trim(self):
@torch.jit.script
def jit_method(tensor, ch_dim, max_len, len_dim, fill_value):
# type: (Tensor, int, int, int, float) -> Tensor
return F.pad_trim(tensor, ch_dim, max_len, len_dim, fill_value)
def jit_method(tensor, max_len, fill_value):
# type: (Tensor, int, float) -> Tensor
return F.pad_trim(tensor, max_len, fill_value)
tensor = torch.rand((10, 1))
ch_dim = 1
tensor = torch.rand((1, 10))
max_len = 5
len_dim = 0
fill_value = 3.
jit_out = jit_method(tensor, ch_dim, max_len, len_dim, fill_value)
py_out = F.pad_trim(tensor, ch_dim, max_len, len_dim, fill_value)
jit_out = jit_method(tensor, max_len, fill_value)
py_out = F.pad_trim(tensor, max_len, fill_value)
self.assertTrue(torch.allclose(jit_out, py_out))
......@@ -74,45 +52,6 @@ class Test_JIT(unittest.TestCase):
self._test_script_module(tensor, transforms.PadTrim, max_len)
def test_torchscript_downmix_mono(self):
@torch.jit.script
def jit_method(tensor, ch_dim):
# type: (Tensor, int) -> Tensor
return F.downmix_mono(tensor, ch_dim)
tensor = torch.rand((10, 1))
ch_dim = 1
jit_out = jit_method(tensor, ch_dim)
py_out = F.downmix_mono(tensor, ch_dim)
self.assertTrue(torch.allclose(jit_out, py_out))
@unittest.skipIf(not RUN_CUDA, "no CUDA")
def test_scriptmodule_downmix_mono(self):
tensor = torch.rand((1, 10), device="cuda")
self._test_script_module(tensor, transforms.DownmixMono)
def test_torchscript_LC2CL(self):
@torch.jit.script
def jit_method(tensor):
# type: (Tensor) -> Tensor
return F.LC2CL(tensor)
tensor = torch.rand((10, 1))
jit_out = jit_method(tensor)
py_out = F.LC2CL(tensor)
self.assertTrue(torch.allclose(jit_out, py_out))
@unittest.skipIf(not RUN_CUDA, "no CUDA")
def test_scriptmodule_LC2CL(self):
tensor = torch.rand((10, 1), device="cuda")
self._test_script_module(tensor, transforms.LC2CL)
def test_torchscript_spectrogram(self):
@torch.jit.script
def jit_method(sig, pad, window, n_fft, hop, ws, power, normalize):
......@@ -167,7 +106,7 @@ class Test_JIT(unittest.TestCase):
# type: (Tensor, float, float, float, Optional[float]) -> Tensor
return F.spectrogram_to_DB(spec, multiplier, amin, db_multiplier, top_db)
spec = torch.rand((10, 1))
spec = torch.rand((6, 201))
multiplier = 10.
amin = 1e-10
db_multiplier = 0.
......@@ -180,7 +119,7 @@ class Test_JIT(unittest.TestCase):
@unittest.skipIf(not RUN_CUDA, "no CUDA")
def test_scriptmodule_SpectrogramToDB(self):
spec = torch.rand((10, 1), device="cuda")
spec = torch.rand((6, 201), device="cuda")
self._test_script_module(spec, transforms.SpectrogramToDB)
......@@ -211,32 +150,13 @@ class Test_JIT(unittest.TestCase):
self._test_script_module(tensor, transforms.MelSpectrogram)
def test_torchscript_BLC2CBL(self):
@torch.jit.script
def jit_method(tensor):
# type: (Tensor) -> Tensor
return F.BLC2CBL(tensor)
tensor = torch.rand((10, 1000, 1))
jit_out = jit_method(tensor)
py_out = F.BLC2CBL(tensor)
self.assertTrue(torch.allclose(jit_out, py_out))
@unittest.skipIf(not RUN_CUDA, "no CUDA")
def test_scriptmodule_BLC2CBL(self):
tensor = torch.rand((10, 1000, 1), device="cuda")
self._test_script_module(tensor, transforms.BLC2CBL)
def test_torchscript_mu_law_encoding(self):
@torch.jit.script
def jit_method(tensor, qc):
# type: (Tensor, int) -> Tensor
return F.mu_law_encoding(tensor, qc)
tensor = torch.rand((10, 1))
tensor = torch.rand((1, 10))
qc = 256
jit_out = jit_method(tensor, qc)
......@@ -246,7 +166,7 @@ class Test_JIT(unittest.TestCase):
@unittest.skipIf(not RUN_CUDA, "no CUDA")
def test_scriptmodule_MuLawEncoding(self):
tensor = torch.rand((10, 1), device="cuda")
tensor = torch.rand((1, 10), device="cuda")
self._test_script_module(tensor, transforms.MuLawEncoding)
......@@ -256,7 +176,7 @@ class Test_JIT(unittest.TestCase):
# type: (Tensor, int) -> Tensor
return F.mu_law_expanding(tensor, qc)
tensor = torch.rand((10, 1))
tensor = torch.rand((1, 10))
qc = 256
jit_out = jit_method(tensor, qc)
......@@ -266,7 +186,7 @@ class Test_JIT(unittest.TestCase):
@unittest.skipIf(not RUN_CUDA, "no CUDA")
def test_scriptmodule_MuLawExpanding(self):
tensor = torch.rand((10, 1), device="cuda")
tensor = torch.rand((1, 10), device="cuda")
self._test_script_module(tensor, transforms.MuLawExpanding)
......
......@@ -19,191 +19,123 @@ if IMPORT_SCIPY:
class Tester(unittest.TestCase):
# create a sinewave signal for testing
sr = 16000
sample_rate = 16000
freq = 440
volume = .3
sig = (torch.cos(2 * math.pi * torch.arange(0, 4 * sr).float() * freq / sr))
sig.unsqueeze_(1) # (64000, 1)
sig = (sig * volume * 2**31).long()
waveform = (torch.cos(2 * math.pi * torch.arange(0, 4 * sample_rate).float() * freq / sample_rate))
waveform.unsqueeze_(0) # (1, 64000)
waveform = (waveform * volume * 2**31).long()
# file for stereo stft test
test_dirpath, test_dir = test.common_utils.create_temp_assets_dir()
test_filepath = os.path.join(test_dirpath, "assets",
"steam-train-whistle-daniel_simon.mp3")
test_filepath = os.path.join(test_dirpath, 'assets',
'steam-train-whistle-daniel_simon.mp3')
def test_scale(self):
audio_orig = self.sig.clone()
result = transforms.Scale()(audio_orig)
self.assertTrue(result.min() >= -1. and result.max() <= 1.)
maxminmax = max(abs(audio_orig.min()), abs(audio_orig.max())).item()
result = transforms.Scale(factor=maxminmax)(audio_orig)
self.assertTrue((result.min() == -1. or result.max() == 1.) and
result.min() >= -1. and result.max() <= 1.)
repr_test = transforms.Scale()
self.assertTrue(repr_test.__repr__())
def scale(self, waveform, factor=float(2**31)):
# scales a waveform by a factor
if not waveform.is_floating_point():
waveform = waveform.to(torch.get_default_dtype())
return waveform / factor
def test_pad_trim(self):
audio_orig = self.sig.clone()
length_orig = audio_orig.size(0)
waveform = self.waveform.clone()
length_orig = waveform.size(1)
length_new = int(length_orig * 1.2)
result = transforms.PadTrim(max_len=length_new, channels_first=False)(audio_orig)
self.assertEqual(result.size(0), length_new)
result = transforms.PadTrim(max_len=length_new, channels_first=True)(audio_orig.transpose(0, 1))
result = transforms.PadTrim(max_len=length_new)(waveform)
self.assertEqual(result.size(1), length_new)
audio_orig = self.sig.clone()
length_orig = audio_orig.size(0)
length_new = int(length_orig * 0.8)
result = transforms.PadTrim(max_len=length_new, channels_first=False)(audio_orig)
self.assertEqual(result.size(0), length_new)
repr_test = transforms.PadTrim(max_len=length_new, channels_first=False)
self.assertTrue(repr_test.__repr__())
def test_downmix_mono(self):
audio_L = self.sig.clone()
audio_R = self.sig.clone()
R_idx = int(audio_R.size(0) * 0.1)
audio_R = torch.cat((audio_R[R_idx:], audio_R[:R_idx]))
audio_Stereo = torch.cat((audio_L, audio_R), dim=1)
self.assertTrue(audio_Stereo.size(1) == 2)
result = transforms.DownmixMono(channels_first=False)(audio_Stereo)
self.assertTrue(result.size(1) == 1)
repr_test = transforms.DownmixMono(channels_first=False)
self.assertTrue(repr_test.__repr__())
def test_lc2cl(self):
audio = self.sig.clone()
result = transforms.LC2CL()(audio)
self.assertTrue(result.size()[::-1] == audio.size())
repr_test = transforms.LC2CL()
self.assertTrue(repr_test.__repr__())
def test_compose(self):
audio_orig = self.sig.clone()
length_orig = audio_orig.size(0)
length_new = int(length_orig * 1.2)
maxminmax = max(abs(audio_orig.min()), abs(audio_orig.max())).item()
tset = (transforms.Scale(factor=maxminmax),
transforms.PadTrim(max_len=length_new, channels_first=False))
result = transforms.Compose(tset)(audio_orig)
self.assertTrue(max(abs(result.min()), abs(result.max())) == 1.)
self.assertTrue(result.size(0) == length_new)
repr_test = transforms.Compose(tset)
self.assertTrue(repr_test.__repr__())
result = transforms.PadTrim(max_len=length_new)(waveform)
self.assertEqual(result.size(1), length_new)
def test_mu_law_companding(self):
quantization_channels = 256
sig = self.sig.clone()
sig = sig / torch.abs(sig).max()
self.assertTrue(sig.min() >= -1. and sig.max() <= 1.)
sig_mu = transforms.MuLawEncoding(quantization_channels)(sig)
self.assertTrue(sig_mu.min() >= 0. and sig.max() <= quantization_channels)
waveform = self.waveform.clone()
waveform /= torch.abs(waveform).max()
self.assertTrue(waveform.min() >= -1. and waveform.max() <= 1.)
sig_exp = transforms.MuLawExpanding(quantization_channels)(sig_mu)
self.assertTrue(sig_exp.min() >= -1. and sig_exp.max() <= 1.)
waveform_mu = transforms.MuLawEncoding(quantization_channels)(waveform)
self.assertTrue(waveform_mu.min() >= 0. and waveform_mu.max() <= quantization_channels)
repr_test = transforms.MuLawEncoding(quantization_channels)
self.assertTrue(repr_test.__repr__())
repr_test = transforms.MuLawExpanding(quantization_channels)
self.assertTrue(repr_test.__repr__())
waveform_exp = transforms.MuLawExpanding(quantization_channels)(waveform_mu)
self.assertTrue(waveform_exp.min() >= -1. and waveform_exp.max() <= 1.)
def test_mel2(self):
top_db = 80.
s2db = transforms.SpectrogramToDB("power", top_db)
s2db = transforms.SpectrogramToDB('power', top_db)
audio_orig = self.sig.clone() # (16000, 1)
audio_scaled = transforms.Scale()(audio_orig) # (16000, 1)
audio_scaled = transforms.LC2CL()(audio_scaled) # (1, 16000)
waveform = self.waveform.clone() # (1, 16000)
waveform_scaled = self.scale(waveform) # (1, 16000)
mel_transform = transforms.MelSpectrogram()
# check defaults
spectrogram_torch = s2db(mel_transform(audio_scaled)) # (1, 319, 40)
spectrogram_torch = s2db(mel_transform(waveform_scaled)) # (1, 128, 321)
self.assertTrue(spectrogram_torch.dim() == 3)
self.assertTrue(spectrogram_torch.ge(spectrogram_torch.max() - top_db).all())
self.assertEqual(spectrogram_torch.size(-1), mel_transform.n_mels)
self.assertEqual(spectrogram_torch.size(1), mel_transform.n_mels)
# check correctness of filterbank conversion matrix
self.assertTrue(mel_transform.fm.fb.sum(1).le(1.).all())
self.assertTrue(mel_transform.fm.fb.sum(1).ge(0.).all())
self.assertTrue(mel_transform.mel_scale.fb.sum(1).le(1.).all())
self.assertTrue(mel_transform.mel_scale.fb.sum(1).ge(0.).all())
# check options
kwargs = {"window": torch.hamming_window, "pad": 10, "ws": 500, "hop": 125, "n_fft": 800, "n_mels": 50}
kwargs = {'window_fn': torch.hamming_window, 'pad': 10, 'win_length': 500,
'hop_length': 125, 'n_fft': 800, 'n_mels': 50}
mel_transform2 = transforms.MelSpectrogram(**kwargs)
spectrogram2_torch = s2db(mel_transform2(audio_scaled)) # (1, 506, 50)
spectrogram2_torch = s2db(mel_transform2(waveform_scaled)) # (1, 50, 513)
self.assertTrue(spectrogram2_torch.dim() == 3)
self.assertTrue(spectrogram_torch.ge(spectrogram_torch.max() - top_db).all())
self.assertEqual(spectrogram2_torch.size(-1), mel_transform2.n_mels)
self.assertTrue(mel_transform2.fm.fb.sum(1).le(1.).all())
self.assertTrue(mel_transform2.fm.fb.sum(1).ge(0.).all())
self.assertEqual(spectrogram2_torch.size(1), mel_transform2.n_mels)
self.assertTrue(mel_transform2.mel_scale.fb.sum(1).le(1.).all())
self.assertTrue(mel_transform2.mel_scale.fb.sum(1).ge(0.).all())
# check on multi-channel audio
x_stereo, sr_stereo = torchaudio.load(self.test_filepath)
spectrogram_stereo = s2db(mel_transform(x_stereo))
x_stereo, sr_stereo = torchaudio.load(self.test_filepath) # (2, 278756), 44100
spectrogram_stereo = s2db(mel_transform(x_stereo)) # (2, 128, 1394)
self.assertTrue(spectrogram_stereo.dim() == 3)
self.assertTrue(spectrogram_stereo.size(0) == 2)
self.assertTrue(spectrogram_torch.ge(spectrogram_torch.max() - top_db).all())
self.assertEqual(spectrogram_stereo.size(-1), mel_transform.n_mels)
self.assertEqual(spectrogram_stereo.size(1), mel_transform.n_mels)
# check filterbank matrix creation
fb_matrix_transform = transforms.MelScale(n_mels=100, sr=16000, f_max=None, f_min=0., n_stft=400)
fb_matrix_transform = transforms.MelScale(
n_mels=100, sample_rate=16000, f_min=0., f_max=None, n_stft=400)
self.assertTrue(fb_matrix_transform.fb.sum(1).le(1.).all())
self.assertTrue(fb_matrix_transform.fb.sum(1).ge(0.).all())
self.assertEqual(fb_matrix_transform.fb.size(), (400, 100))
def test_mfcc(self):
audio_orig = self.sig.clone()
audio_scaled = transforms.Scale()(audio_orig) # (16000, 1)
audio_scaled = transforms.LC2CL()(audio_scaled) # (1, 16000)
audio_orig = self.waveform.clone()
audio_scaled = self.scale(audio_orig) # (1, 16000)
sample_rate = 16000
n_mfcc = 40
n_mels = 128
mfcc_transform = torchaudio.transforms.MFCC(sr=sample_rate,
mfcc_transform = torchaudio.transforms.MFCC(sample_rate=sample_rate,
n_mfcc=n_mfcc,
norm='ortho')
# check defaults
torch_mfcc = mfcc_transform(audio_scaled)
torch_mfcc = mfcc_transform(audio_scaled) # (1, 40, 321)
self.assertTrue(torch_mfcc.dim() == 3)
self.assertTrue(torch_mfcc.shape[2] == n_mfcc)
self.assertTrue(torch_mfcc.shape[1] == 321)
self.assertTrue(torch_mfcc.shape[1] == n_mfcc)
self.assertTrue(torch_mfcc.shape[2] == 321)
# check melkwargs are passed through
melkwargs = {'ws': 200}
mfcc_transform2 = torchaudio.transforms.MFCC(sr=sample_rate,
melkwargs = {'win_length': 200}
mfcc_transform2 = torchaudio.transforms.MFCC(sample_rate=sample_rate,
n_mfcc=n_mfcc,
norm='ortho',
melkwargs=melkwargs)
torch_mfcc2 = mfcc_transform2(audio_scaled)
self.assertTrue(torch_mfcc2.shape[1] == 641)
torch_mfcc2 = mfcc_transform2(audio_scaled) # (1, 40, 641)
self.assertTrue(torch_mfcc2.shape[2] == 641)
# check norms work correctly
mfcc_transform_norm_none = torchaudio.transforms.MFCC(sr=sample_rate,
mfcc_transform_norm_none = torchaudio.transforms.MFCC(sample_rate=sample_rate,
n_mfcc=n_mfcc,
norm=None)
torch_mfcc_norm_none = mfcc_transform_norm_none(audio_scaled)
torch_mfcc_norm_none = mfcc_transform_norm_none(audio_scaled) # (1, 40, 321)
norm_check = torch_mfcc.clone()
norm_check[:, :, 0] *= math.sqrt(n_mels) * 2
norm_check[:, :, 1:] *= math.sqrt(n_mels / 2) * 2
norm_check[:, 0, :] *= math.sqrt(n_mels) * 2
norm_check[:, 1:, :] *= math.sqrt(n_mels / 2) * 2
self.assertTrue(torch_mfcc_norm_none.allclose(norm_check))
......@@ -212,45 +144,45 @@ class Tester(unittest.TestCase):
def _test_librosa_consistency_helper(n_fft, hop_length, power, n_mels, n_mfcc, sample_rate):
input_path = os.path.join(self.test_dirpath, 'assets', 'sinewave.wav')
sound, sample_rate = torchaudio.load(input_path)
sound_librosa = sound.cpu().numpy().squeeze().T # squeeze batch and channel first
sound_librosa = sound.cpu().numpy().squeeze() # (64000)
# test core spectrogram
spect_transform = torchaudio.transforms.Spectrogram(n_fft=n_fft, hop=hop_length, power=2)
spect_transform = torchaudio.transforms.Spectrogram(n_fft=n_fft, hop_length=hop_length, power=2)
out_librosa, _ = librosa.core.spectrum._spectrogram(y=sound_librosa,
n_fft=n_fft,
hop_length=hop_length,
power=2)
out_torch = spect_transform(sound).squeeze().cpu().t()
out_torch = spect_transform(sound).squeeze().cpu()
self.assertTrue(torch.allclose(out_torch, torch.from_numpy(out_librosa), atol=1e-5))
# test mel spectrogram
melspect_transform = torchaudio.transforms.MelSpectrogram(sr=sample_rate, window=torch.hann_window,
hop=hop_length, n_mels=n_mels, n_fft=n_fft)
melspect_transform = torchaudio.transforms.MelSpectrogram(
sample_rate=sample_rate, window_fn=torch.hann_window,
hop_length=hop_length, n_mels=n_mels, n_fft=n_fft)
librosa_mel = librosa.feature.melspectrogram(y=sound_librosa, sr=sample_rate,
n_fft=n_fft, hop_length=hop_length, n_mels=n_mels,
htk=True, norm=None)
librosa_mel_tensor = torch.from_numpy(librosa_mel)
torch_mel = melspect_transform(sound).squeeze().cpu().t()
torch_mel = melspect_transform(sound).squeeze().cpu()
self.assertTrue(torch.allclose(torch_mel.type(librosa_mel_tensor.dtype), librosa_mel_tensor, atol=5e-3))
# test s2db
db_transform = torchaudio.transforms.SpectrogramToDB("power", 80.)
db_torch = db_transform(spect_transform(sound)).squeeze().cpu().t()
db_transform = torchaudio.transforms.SpectrogramToDB('power', 80.)
db_torch = db_transform(spect_transform(sound)).squeeze().cpu()
db_librosa = librosa.core.spectrum.power_to_db(out_librosa)
self.assertTrue(torch.allclose(db_torch, torch.from_numpy(db_librosa), atol=5e-3))
db_torch = db_transform(melspect_transform(sound)).squeeze().cpu().t()
db_torch = db_transform(melspect_transform(sound)).squeeze().cpu()
db_librosa = librosa.core.spectrum.power_to_db(librosa_mel)
db_librosa_tensor = torch.from_numpy(db_librosa)
self.assertTrue(torch.allclose(db_torch.type(db_librosa_tensor.dtype), db_librosa_tensor, atol=5e-3))
# test MFCC
melkwargs = {'hop': hop_length, 'n_fft': n_fft}
mfcc_transform = torchaudio.transforms.MFCC(sr=sample_rate,
melkwargs = {'hop_length': hop_length, 'n_fft': n_fft}
mfcc_transform = torchaudio.transforms.MFCC(sample_rate=sample_rate,
n_mfcc=n_mfcc,
norm='ortho',
melkwargs=melkwargs)
......@@ -271,7 +203,7 @@ class Tester(unittest.TestCase):
librosa_mfcc = scipy.fftpack.dct(db_librosa, axis=0, type=2, norm='ortho')[:n_mfcc]
librosa_mfcc_tensor = torch.from_numpy(librosa_mfcc)
torch_mfcc = mfcc_transform(sound).squeeze().cpu().t()
torch_mfcc = mfcc_transform(sound).squeeze().cpu()
self.assertTrue(torch.allclose(torch_mfcc.type(librosa_mfcc_tensor.dtype), librosa_mfcc_tensor, atol=5e-3))
......@@ -308,27 +240,27 @@ class Tester(unittest.TestCase):
def test_resample_size(self):
input_path = os.path.join(self.test_dirpath, 'assets', 'sinewave.wav')
sound, sample_rate = torchaudio.load(input_path)
waveform, sample_rate = torchaudio.load(input_path)
upsample_rate = sample_rate * 2
downsample_rate = sample_rate // 2
invalid_resample = torchaudio.transforms.Resample(sample_rate, upsample_rate, resampling_method='foo')
self.assertRaises(ValueError, invalid_resample, sound)
self.assertRaises(ValueError, invalid_resample, waveform)
upsample_resample = torchaudio.transforms.Resample(
sample_rate, upsample_rate, resampling_method='sinc_interpolation')
up_sampled = upsample_resample(sound)
up_sampled = upsample_resample(waveform)
# we expect the upsampled signal to have twice as many samples
self.assertTrue(up_sampled.size(-1) == sound.size(-1) * 2)
self.assertTrue(up_sampled.size(-1) == waveform.size(-1) * 2)
downsample_resample = torchaudio.transforms.Resample(
sample_rate, downsample_rate, resampling_method='sinc_interpolation')
down_sampled = downsample_resample(sound)
down_sampled = downsample_resample(waveform)
# we expect the downsampled signal to have half as many samples
self.assertTrue(down_sampled.size(-1) == sound.size(-1) // 2)
self.assertTrue(down_sampled.size(-1) == waveform.size(-1) // 2)
if __name__ == '__main__':
unittest.main()
......@@ -3,109 +3,48 @@ import torch
__all__ = [
'scale',
'pad_trim',
'downmix_mono',
'LC2CL',
'istft',
'spectrogram',
'create_fb_matrix',
'spectrogram_to_DB',
'create_dct',
'BLC2CBL',
'mu_law_encoding',
'mu_law_expanding'
'mu_law_expanding',
'complex_norm',
'angle',
'magphase',
'phase_vocoder',
]
@torch.jit.script
def scale(tensor, factor):
# type: (Tensor, int) -> Tensor
r"""Scale audio tensor from a 16-bit integer (represented as a
:class:`torch.FloatTensor`) to a floating point number between -1.0 and 1.0.
Note the 16-bit number is called the "bit depth" or "precision", not to be
confused with "bit rate".
Args:
tensor (torch.Tensor): Tensor of audio of size (n, c) or (c, n)
factor (int): Maximum value of input tensor
Returns:
torch.Tensor: Scaled by the scale factor
"""
if not tensor.is_floating_point():
tensor = tensor.to(torch.float32)
return tensor / factor
@torch.jit.script
def pad_trim(tensor, ch_dim, max_len, len_dim, fill_value):
# type: (Tensor, int, int, int, float) -> Tensor
r"""Pad/trim a 2D tensor (signal or labels).
def pad_trim(waveform, max_len, fill_value):
# type: (Tensor, int, float) -> Tensor
r"""Pad/trim a 2D tensor
Args:
tensor (torch.Tensor): Tensor of audio of size (n, c) or (c, n)
ch_dim (int): Dimension of channel (not size)
max_len (int): Length to which the tensor will be padded
len_dim (int): Dimension of length (not size)
waveform (torch.Tensor): Tensor of audio of size (c, n)
max_len (int): Length to which the waveform will be padded
fill_value (float): Value to fill in
Returns:
torch.Tensor: Padded/trimmed tensor
"""
if max_len > tensor.size(len_dim):
# array of [padding_left, padding_right, padding_top, padding_bottom]
# so pad similar to append (aka only right/bottom) and do not pad
# the length dimension. assumes equal sizes of padding.
padding = [max_len - tensor.size(len_dim)
if (i % 2 == 1) and (i // 2 != len_dim)
else 0
for i in [0, 1, 2, 3]]
n = waveform.size(1)
if max_len > n:
# TODO add "with torch.no_grad():" back when JIT supports it
tensor = torch.nn.functional.pad(tensor, padding, "constant", fill_value)
elif max_len < tensor.size(len_dim):
tensor = tensor.narrow(len_dim, 0, max_len)
return tensor
@torch.jit.script
def downmix_mono(tensor, ch_dim):
# type: (Tensor, int) -> Tensor
r"""Downmix any stereo signals to mono.
Args:
tensor (torch.Tensor): Tensor of audio of size (c, n) or (n, c)
ch_dim (int): Dimension of channel (not size)
Returns:
torch.Tensor: Mono signal
"""
if not tensor.is_floating_point():
tensor = tensor.to(torch.float32)
tensor = torch.mean(tensor, ch_dim, True)
return tensor
@torch.jit.script
def LC2CL(tensor):
# type: (Tensor) -> Tensor
r"""Permute a 2D tensor from samples (n, c) to (c, n).
Args:
tensor (torch.Tensor): Tensor of audio signal with shape (n, c)
waveform = torch.nn.functional.pad(waveform, (0, max_len - n), 'constant', fill_value)
else:
waveform = waveform[:, :max_len]
return waveform
Returns:
torch.Tensor: Tensor of audio signal with shape (c, n)
"""
return tensor.transpose(0, 1).contiguous()
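# Editor's illustrative sketch, not part of this commit: expected use of the
# new pad_trim signature above, i.e. pad_trim(waveform, max_len, fill_value).
import torch
import torchaudio.functional as F

waveform = torch.zeros(1, 10)               # (c, n)
padded = F.pad_trim(waveform, 12, 3.)       # (1, 12), last two samples filled with 3.
trimmed = F.pad_trim(waveform, 5, 3.)       # (1, 5)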
# TODO: remove this once https://github.com/pytorch/pytorch/issues/21478 gets solved
@torch.jit.ignore
def _stft(input, n_fft, hop_length, win_length, window, center, pad_mode, normalized, onesided):
def _stft(waveform, n_fft, hop_length, win_length, window, center, pad_mode, normalized, onesided):
# type: (Tensor, int, Optional[int], Optional[int], Optional[Tensor], bool, str, bool, bool) -> Tensor
return torch.stft(input, n_fft, hop_length, win_length, window, center, pad_mode, normalized, onesided)
return torch.stft(waveform, n_fft, hop_length, win_length, window, center, pad_mode, normalized, onesided)
def istft(stft_matrix, # type: Tensor
......@@ -149,8 +88,8 @@ def istft(stft_matrix, # type: Tensor
IEEE Trans. ASSP, vol.32, no.2, pp.236–243, Apr. 1984.
Args:
stft_matrix (torch.Tensor): Output of stft where each row of a batch is a frequency and each
column is a window. it has a shape of either (batch, fft_size, n_frames, 2) or (
stft_matrix (torch.Tensor): Output of stft where each row of a channel is a frequency and each
column is a window. It has a shape of either (channel, fft_size, n_frames, 2) or (
fft_size, n_frames, 2)
n_fft (int): Size of Fourier transform
hop_length (Optional[int]): The distance between neighboring sliding window frames.
......@@ -168,20 +107,20 @@ def istft(stft_matrix, # type: Tensor
Returns:
torch.Tensor: Least squares estimation of the original signal of size
(batch, signal_length) or (signal_length)
(channel, signal_length) or (signal_length)
"""
stft_matrix_dim = stft_matrix.dim()
assert 3 <= stft_matrix_dim <= 4, ('Incorrect stft dimension: %d' % (stft_matrix_dim))
if stft_matrix_dim == 3:
# add a batch dimension
# add a channel dimension
stft_matrix = stft_matrix.unsqueeze(0)
device = stft_matrix.device
fft_size = stft_matrix.size(1)
assert (onesided and n_fft // 2 + 1 == fft_size) or (not onesided and n_fft == fft_size), (
'one_sided implies that n_fft // 2 + 1 == fft_size and not one_sided implies n_fft == fft_size. '
+ 'Given values were onesided: %s, n_fft: %d, fft_size: %d' % ('True' if onesided else False, n_fft, fft_size))
'one_sided implies that n_fft // 2 + 1 == fft_size and not one_sided implies n_fft == fft_size. ' +
'Given values were onesided: %s, n_fft: %d, fft_size: %d' % ('True' if onesided else False, n_fft, fft_size))
# use stft defaults for Optionals
if win_length is None:
......@@ -206,16 +145,16 @@ def istft(stft_matrix, # type: Tensor
assert window.size(0) == n_fft
# win_length and n_fft are synonymous from here on
stft_matrix = stft_matrix.transpose(1, 2) # size (batch, n_frames, fft_size, 2)
stft_matrix = stft_matrix.transpose(1, 2) # size (channel, n_frames, fft_size, 2)
stft_matrix = torch.irfft(stft_matrix, 1, normalized,
onesided, signal_sizes=(n_fft,)) # size (batch, n_frames, n_fft)
onesided, signal_sizes=(n_fft,)) # size (channel, n_frames, n_fft)
assert stft_matrix.size(2) == n_fft
n_frames = stft_matrix.size(1)
ytmp = stft_matrix * window.view(1, 1, n_fft) # size (batch, n_frames, n_fft)
# each column of a batch is a frame which needs to be overlap added at the right place
ytmp = ytmp.transpose(1, 2) # size (batch, n_fft, n_frames)
ytmp = stft_matrix * window.view(1, 1, n_fft) # size (channel, n_frames, n_fft)
# each column of a channel is a frame which needs to be overlap added at the right place
ytmp = ytmp.transpose(1, 2) # size (channel, n_fft, n_frames)
eye = torch.eye(n_fft, requires_grad=False,
device=device).unsqueeze(1) # size (n_fft, 1, n_fft)
......@@ -223,7 +162,7 @@ def istft(stft_matrix, # type: Tensor
# this does overlap add where the frames of ytmp are added such that the i'th frame of
# ytmp is added starting at i*hop_length in the output
y = torch.nn.functional.conv_transpose1d(
ytmp, eye, stride=hop_length, padding=0) # size (batch, 1, expected_signal_len)
ytmp, eye, stride=hop_length, padding=0) # size (channel, 1, expected_signal_len)
# do the same for the window function
window_sq = window.pow(2).view(n_fft, 1).repeat((1, n_frames)).unsqueeze(0) # size (1, n_fft, n_frames)
......@@ -246,67 +185,70 @@ def istft(stft_matrix, # type: Tensor
window_envelop_lowest = window_envelop.abs().min()
assert window_envelop_lowest > 1e-11, ('window overlap add min: %f' % (window_envelop_lowest))
y = (y / window_envelop).squeeze(1) # size (batch, expected_signal_len)
y = (y / window_envelop).squeeze(1) # size (channel, expected_signal_len)
if stft_matrix_dim == 3: # remove the batch dimension
if stft_matrix_dim == 3: # remove the channel dimension
y = y.squeeze(0)
return y
@torch.jit.script
def spectrogram(sig, pad, window, n_fft, hop, ws, power, normalize):
def spectrogram(waveform, pad, window, n_fft, hop_length, win_length, power, normalized):
# type: (Tensor, int, Tensor, int, int, int, int, bool) -> Tensor
r"""Create a spectrogram from a raw audio signal.
Args:
sig (torch.Tensor): Tensor of audio of size (c, n)
waveform (torch.Tensor): Tensor of audio of size (c, n)
pad (int): Two sided padding of signal
window (torch.Tensor): Window_tensor
window (torch.Tensor): Window tensor that is applied/multiplied to each frame/window
n_fft (int): Size of fft
hop (int): Length of hop between STFT windows
ws (int): Window size
power (int) : Exponent for the magnitude spectrogram,
hop_length (int): Length of hop between STFT windows
win_length (int): Window size
power (int): Exponent for the magnitude spectrogram,
(must be > 0) e.g., 1 for energy, 2 for power, etc.
normalize (bool) : Whether to normalize by magnitude after stft
normalized (bool): Whether to normalize by magnitude after stft
Returns:
torch.Tensor: Channels x hops x n_fft (c, l, f), where channels
is unchanged, hops is the number of hops, and n_fft is the
number of fourier bins, which should be the window size divided
by 2 plus 1.
torch.Tensor: Channels x frequency x time (c, f, t), where channels
is unchanged, frequency is `n_fft // 2 + 1` where `n_fft` is the number of
fourier bins, and time is the number of window hops (n_frames).
"""
assert sig.dim() == 2
assert waveform.dim() == 2
if pad > 0:
# TODO add "with torch.no_grad():" back when JIT supports it
sig = torch.nn.functional.pad(sig, (pad, pad), "constant")
waveform = torch.nn.functional.pad(waveform, (pad, pad), "constant")
# default values are consistent with librosa.core.spectrum._spectrogram
spec_f = _stft(sig, n_fft, hop, ws, window,
True, 'reflect', False, True).transpose(1, 2)
spec_f = _stft(waveform, n_fft, hop_length, win_length, window,
True, 'reflect', False, True)
if normalize:
if normalized:
spec_f /= window.pow(2).sum().sqrt()
spec_f = spec_f.pow(power).sum(-1) # get power of "complex" tensor (c, l, n_fft)
spec_f = spec_f.pow(power).sum(-1) # get power of "complex" tensor
return spec_f
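# Editor's illustrative sketch, not part of this commit: expected use of the
# spectrogram signature above; the output shape is an assumption based on the
# docstring, i.e. (c, n_fft // 2 + 1, n_frames).
import torch
import torchaudio.functional as F

waveform = torch.randn(1, 16000)            # (c, n)
window = torch.hann_window(400)
specgram = F.spectrogram(waveform, 0, window, 400, 200, 400, 2, False)
print(specgram.shape)                       # expected: torch.Size([1, 201, 81])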
@torch.jit.script
def create_fb_matrix(n_stft, f_min, f_max, n_mels):
def create_fb_matrix(n_freqs, f_min, f_max, n_mels):
# type: (int, float, float, int) -> Tensor
r""" Create a frequency bin conversion matrix.
Args:
n_stft (int): Number of filter banks from spectrogram
n_freqs (int): Number of frequencies to highlight/apply
f_min (float): Minimum frequency
f_max (float): Maximum frequency
n_mels (int): Number of mel bins
n_mels (int): Number of mel filterbanks
Returns:
torch.Tensor: Triangular filter banks (fb matrix)
torch.Tensor: Triangular filter banks (fb matrix) of size (`n_freqs`, `n_mels`)
i.e. the number of frequencies to highlight/apply by the number of filterbanks.
Each column is a filterbank so that assuming there is a matrix A of
size (..., `n_freqs`), the applied result would be
`A * create_fb_matrix(A.size(-1), ...)`.
"""
# get stft freq bins
stft_freqs = torch.linspace(f_min, f_max, n_stft)
# freq bins
freqs = torch.linspace(f_min, f_max, n_freqs)
# calculate mel freq bins
# hertz to mel(f) is 2595. * math.log10(1. + (f / 700.))
m_min = 0. if f_min == 0 else 2595. * math.log10(1. + (f_min / 700.))
......@@ -316,17 +258,17 @@ def create_fb_matrix(n_stft, f_min, f_max, n_mels):
f_pts = 700. * (10**(m_pts / 2595.) - 1.)
# calculate the difference between each mel point and each stft freq point in hertz
f_diff = f_pts[1:] - f_pts[:-1] # (n_mels + 1)
slopes = f_pts.unsqueeze(0) - stft_freqs.unsqueeze(1) # (n_stft, n_mels + 2)
slopes = f_pts.unsqueeze(0) - freqs.unsqueeze(1) # (n_freqs, n_mels + 2)
# create overlapping triangles
z = torch.zeros(1)
down_slopes = (-1. * slopes[:, :-2]) / f_diff[:-1] # (n_stft, n_mels)
up_slopes = slopes[:, 2:] / f_diff[1:] # (n_stft, n_mels)
fb = torch.max(z, torch.min(down_slopes, up_slopes))
zero = torch.zeros(1)
down_slopes = (-1. * slopes[:, :-2]) / f_diff[:-1] # (n_freqs, n_mels)
up_slopes = slopes[:, 2:] / f_diff[1:] # (n_freqs, n_mels)
fb = torch.max(zero, torch.min(down_slopes, up_slopes))
return fb
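# Editor's illustrative sketch, not part of this commit: building and applying
# the filter bank described above. The 201 frequency bins correspond to a
# hypothetical n_fft=400 (400 // 2 + 1); all values here are assumptions.
import torch
import torchaudio.functional as F

fb = F.create_fb_matrix(201, 0., 8000., 128)                   # (n_freqs, n_mels) = (201, 128)
specgram = torch.randn(1, 201, 81).pow(2)                      # (c, f, t) power spectrogram
mel_specgram = torch.matmul(specgram.transpose(1, 2), fb).transpose(1, 2)  # (1, 128, 81)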
@torch.jit.script
def spectrogram_to_DB(spec, multiplier, amin, db_multiplier, top_db=None):
def spectrogram_to_DB(specgram, multiplier, amin, db_multiplier, top_db=None):
# type: (Tensor, float, float, float, Optional[float]) -> Tensor
r"""Turns a spectrogram from the power/amplitude scale to the decibel scale.
......@@ -335,72 +277,57 @@ def spectrogram_to_DB(spec, multiplier, amin, db_multiplier, top_db=None):
a full clip.
Args:
spec (torch.Tensor): Normal STFT
specgram (torch.Tensor): Normal STFT of size (c, f, t)
multiplier (float): Use 10. for power and 20. for amplitude
amin (float): Number to clamp spec
amin (float): Number to clamp specgram
db_multiplier (float): Log10(max(reference value and amin))
top_db (Optional[float]): Minimum negative cut-off in decibels. A reasonable number
top_db (Optional[float]): Minimum negative cut-off in decibels. A reasonable number
is 80.
Returns:
torch.Tensor: Spectrogram in DB
torch.Tensor: Spectrogram in DB of size (c, f, t)
"""
spec_db = multiplier * torch.log10(torch.clamp(spec, min=amin))
spec_db -= multiplier * db_multiplier
specgram_db = multiplier * torch.log10(torch.clamp(specgram, min=amin))
specgram_db -= multiplier * db_multiplier
if top_db is not None:
new_spec_db_max = torch.tensor(float(spec_db.max()) - top_db, dtype=spec_db.dtype, device=spec_db.device)
spec_db = torch.max(spec_db, new_spec_db_max)
new_spec_db_max = torch.tensor(float(specgram_db.max()) - top_db,
dtype=specgram_db.dtype, device=specgram_db.device)
specgram_db = torch.max(specgram_db, new_spec_db_max)
return spec_db
return specgram_db
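# Editor's illustrative sketch, not part of this commit: power-to-dB with the
# arguments documented above (multiplier=10. for power, db_multiplier=0.,
# top_db=None); values are hypothetical.
import torch
import torchaudio.functional as F

power_specgram = torch.tensor([[[1e-12, 1., 100.]]])           # (c, f, t)
db = F.spectrogram_to_DB(power_specgram, 10., 1e-10, 0., None)
# clamp(1e-12, min=1e-10) -> -100 dB, 1. -> 0 dB, 100. -> 20 dB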
@torch.jit.script
def create_dct(n_mfcc, n_mels, norm):
# type: (int, int, Optional[str]) -> Tensor
r"""Creates a DCT transformation matrix with shape (num_mels, num_mfcc),
r"""Creates a DCT transformation matrix with shape (`n_mels`, `n_mfcc`),
normalized depending on norm.
Args:
n_mfcc (int) : Number of mfc coefficients to retain
n_mels (int): Number of MEL bins
norm (Optional[str]) : Norm to use (either 'ortho' or None)
n_mfcc (int): Number of mfc coefficients to retain
n_mels (int): Number of mel filterbanks
norm (Optional[str]): Norm to use (either 'ortho' or None)
Returns:
torch.Tensor: The transformation matrix, to be right-multiplied to row-wise data.
torch.Tensor: The transformation matrix, to be right-multiplied to
row-wise data of size (`n_mels`, `n_mfcc`).
"""
outdim = n_mfcc
dim = n_mels
# http://en.wikipedia.org/wiki/Discrete_cosine_transform#DCT-II
n = torch.arange(dim)
k = torch.arange(outdim)[:, None]
dct = torch.cos(math.pi / float(dim) * (n + 0.5) * k)
n = torch.arange(float(n_mels))
k = torch.arange(float(n_mfcc)).unsqueeze(1)
dct = torch.cos(math.pi / float(n_mels) * (n + 0.5) * k) # size (n_mfcc, n_mels)
if norm is None:
dct *= 2.0
else:
assert norm == 'ortho'
dct[0] *= 1.0 / math.sqrt(2.0)
dct *= math.sqrt(2.0 / float(dim))
dct *= math.sqrt(2.0 / float(n_mels))
return dct.t()
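# Editor's illustrative sketch, not part of this commit: the matrix returned
# above is (n_mels, n_mfcc) and is right-multiplied against row-wise mel data.
import torch
import torchaudio.functional as F

dct_mat = F.create_dct(40, 128, 'ortho')                       # (128, 40)
mel_specgram = torch.randn(1, 128, 81)                         # (c, n_mels, t)
mfcc = torch.matmul(mel_specgram.transpose(1, 2), dct_mat).transpose(1, 2)  # (1, 40, 81)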
@torch.jit.script
def BLC2CBL(tensor):
# type: (Tensor) -> Tensor
r"""Permute a 3D tensor from Bands x Sample length x Channels to Channels x
Bands x Sample length.
Args:
tensor (torch.Tensor): Tensor of spectrogram with shape (b, l, c)
Returns:
torch.Tensor: Tensor of spectrogram with shape (c, b, l)
"""
return tensor.permute(2, 0, 1).contiguous()
@torch.jit.script
def mu_law_encoding(x, qc):
def mu_law_encoding(x, quantization_channels):
# type: (Tensor, int) -> Tensor
r"""Encode signal based on mu-law companding. For more info see the
`Wikipedia Entry <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_
......@@ -410,13 +337,12 @@ def mu_law_encoding(x, qc):
Args:
x (torch.Tensor): Input tensor
qc (int): Number of channels (i.e. quantization channels)
quantization_channels (int): Number of channels
Returns:
torch.Tensor: Input after mu-law companding
"""
assert isinstance(x, torch.Tensor), 'mu_law_encoding expects a Tensor'
mu = qc - 1.
mu = quantization_channels - 1.
if not x.is_floating_point():
x = x.to(torch.float)
mu = torch.tensor(mu, dtype=x.dtype)
......@@ -427,7 +353,7 @@ def mu_law_encoding(x, qc):
@torch.jit.script
def mu_law_expanding(x_mu, qc):
def mu_law_expanding(x_mu, quantization_channels):
# type: (Tensor, int) -> Tensor
r"""Decode mu-law encoded signal. For more info see the
`Wikipedia Entry <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_
......@@ -437,13 +363,12 @@ def mu_law_expanding(x_mu, qc):
Args:
x_mu (torch.Tensor): Input tensor
qc (int): Number of channels (i.e. quantization channels)
quantization_channels (int): Number of channels
Returns:
torch.Tensor: Input after decoding
"""
assert isinstance(x_mu, torch.Tensor), 'mu_law_expanding expects a Tensor'
mu = qc - 1.
mu = quantization_channels - 1.
if not x_mu.is_floating_point():
x_mu = x_mu.to(torch.float)
mu = torch.tensor(mu, dtype=x_mu.dtype)
......@@ -452,71 +377,15 @@ def mu_law_expanding(x_mu, qc):
return x
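# Editor's illustrative sketch, not part of this commit: mu-law round trip with
# the renamed quantization_channels argument; assumes input already in [-1., 1.].
import torch
import torchaudio.functional as F

x = torch.rand(1, 10) * 2 - 1
x_mu = F.mu_law_encoding(x, 256)            # integer codes in [0, 255]
x_rec = F.mu_law_expanding(x_mu, 256)       # close to x, up to quantization error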
def stft(waveforms, fft_length, hop_length=None, win_length=None, window=None,
center=True, pad_mode='reflect', normalized=False, onesided=True):
"""Compute a short time Fourier transform of the input waveform(s).
It wraps `torch.stft` after reshaping the input audio to allow for `waveforms` with `.dim()` >= 3.
It follows most of the `torch.stft` default values, except for `window`, which defaults to a Hann window.
Args:
waveforms (torch.Tensor): Audio signal of size `(*, channel, time)`
fft_length (int): FFT size [sample].
hop_length (int): Hop size [sample] between STFT frames.
(Defaults to `fft_length // 4`, 75%-overlapping windows by `torch.stft`).
win_length (int): Size of STFT window. (Defaults to `fft_length` by `torch.stft`).
window (torch.Tensor): window function. (Defaults to Hann Window of size `win_length` *unlike* `torch.stft`).
center (bool): Whether to pad `waveforms` on both sides so that the `t`-th frame is centered
at time `t * hop_length`. (Defaults to `True` by `torch.stft`)
pad_mode (str): padding method (see `torch.nn.functional.pad`). (Defaults to `'reflect'` by `torch.stft`).
normalized (bool): Whether the results are normalized. (Defaults to `False` by `torch.stft`).
onesided (bool): Whether the half + 1 frequency bins are returned to remove the symmetric part of the STFT
of a real-valued signal. (Defaults to `True` by `torch.stft`).
Returns:
torch.Tensor: `(*, channel, num_freqs, time, complex=2)`
Example:
>>> waveforms = torch.randn(16, 2, 10000) # (batch, channel, time)
>>> x = stft(waveforms, 2048, 512)
>>> x.shape
torch.Size([16, 2, 1025, 20, 2])
"""
leading_dims = waveforms.shape[:-1]
waveforms = waveforms.reshape(-1, waveforms.size(-1))
if window is None:
if win_length is None:
window = torch.hann_window(fft_length)
else:
window = torch.hann_window(win_length)
complex_specgrams = torch.stft(waveforms,
n_fft=fft_length,
hop_length=hop_length,
win_length=win_length,
window=window,
center=center,
pad_mode=pad_mode,
normalized=normalized,
onesided=onesided)
complex_specgrams = complex_specgrams.reshape(
leading_dims +
complex_specgrams.shape[1:])
return complex_specgrams
def complex_norm(complex_tensor, power=1.0):
"""Compute the norm of complex tensor input
r"""Compute the norm of complex tensor input.
Args:
complex_tensor (Tensor): Tensor shape of `(*, complex=2)`
power (float): Power of the norm. Defaults to `1.0`.
complex_tensor (torch.Tensor): Tensor shape of `(*, complex=2)`
power (float): Power of the norm. (Default: `1.0`).
Returns:
Tensor: power of the normed input tensor, shape of `(*, )`
torch.Tensor: Power of the normed input tensor. Shape of `(*, )`
"""
if power == 1.0:
return torch.norm(complex_tensor, 2, -1)
......@@ -524,16 +393,26 @@ def complex_norm(complex_tensor, power=1.0):
def angle(complex_tensor):
"""
Return angle of a complex tensor with shape (*, 2).
r"""Compute the angle of complex tensor input.
Args:
complex_tensor (torch.Tensor): Tensor shape of `(*, complex=2)`
Return:
torch.Tensor: Angle of a complex tensor. Shape of `(*, )`
"""
return torch.atan2(complex_tensor[..., 1], complex_tensor[..., 0])
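# Editor's illustrative sketch, not part of this commit: a "complex" tensor is
# stored as (..., 2) = (real, imag), so for 3 + 4j:
import torch
import torchaudio.functional as F

z = torch.tensor([[3., 4.]])
print(F.complex_norm(z))                    # tensor([5.])
print(F.angle(z))                           # tensor([0.9273]) == atan2(4, 3)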
def magphase(complex_tensor, power=1.):
"""
Separate a complex-valued spectrogram with shape (*,2)
into its magnitude and phase.
r"""Separate a complex-valued spectrogram with shape (*,2) into its magnitude and phase.
Args:
complex_tensor (torch.Tensor): Tensor shape of `(*, complex=2)`
power (float): Power of the norm. (Default: `1.0`)
Returns:
Tuple[torch.Tensor, torch.Tensor]: The magnitude and phase of the complex_tensor
"""
mag = complex_norm(complex_tensor, power)
phase = angle(complex_tensor)
......@@ -541,20 +420,16 @@ def magphase(complex_tensor, power=1.):
def phase_vocoder(complex_specgrams, rate, phase_advance):
"""
Phase vocoder. Given a STFT tensor, speed up in time
without modifying pitch by a factor of `rate`.
r"""Given a STFT tensor, speed up in time without modifying pitch by a
factor of `rate`.
Args:
complex_specgrams (Tensor):
(*, channel, num_freqs, time, complex=2)
rate (float): Speed-up factor.
phase_advance (Tensor): Expected phase advance in
each bin. (num_freqs, 1).
complex_specgrams (torch.Tensor): Size of (*, c, f, t, complex=2)
rate (float): Speed-up factor
phase_advance (torch.Tensor): Expected phase advance in each bin. Size of (f, 1)
Returns:
complex_specgrams_stretch (Tensor):
(*, channel, num_freqs, ceil(time/rate), complex=2).
complex_specgrams_stretch (torch.Tensor): Size of (*, c, f, ceil(t/rate), complex=2)
Example:
>>> num_freqs, hop_length = 1025, 512
......
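# Editor's illustrative sketch, not part of this commit, based only on the
# phase_vocoder docstring above: the time axis goes from t to ceil(t / rate).
# The phase_advance construction is an assumption (hop_length=512).
import math
import torch
import torchaudio.functional as F

complex_specgrams = torch.randn(1, 2, 1025, 400, 2)
phase_advance = torch.linspace(0, math.pi * 512, 1025)[..., None]   # (f, 1)
stretched = F.phase_vocoder(complex_specgrams, 1.3, phase_advance)
print(stretched.shape)                      # expected: (1, 2, 1025, 308, 2); ceil(400 / 1.3) == 308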
......@@ -7,314 +7,205 @@ from . import functional as F
from .compliance import kaldi
# TODO remove this class
class Compose(object):
"""Composes several transforms together.
Args:
transforms (list of ``Transform`` objects): list of transforms to compose.
Example:
>>> transforms.Compose([
>>> transforms.Scale(),
>>> transforms.PadTrim(max_len=16000),
>>> ])
"""
def __init__(self, transforms):
self.transforms = transforms
def __call__(self, audio):
for t in self.transforms:
audio = t(audio)
return audio
def __repr__(self):
format_string = self.__class__.__name__ + '('
for t in self.transforms:
format_string += '\n'
format_string += ' {0}'.format(t)
format_string += '\n)'
return format_string
class Scale(torch.jit.ScriptModule):
"""Scale audio tensor from a 16-bit integer (represented as a FloatTensor)
to a floating point number between -1.0 and 1.0. Note the 16-bit number is
called the "bit depth" or "precision", not to be confused with "bit rate".
Args:
factor (int): maximum value of input tensor. default: 16-bit depth
"""
__constants__ = ['factor']
def __init__(self, factor=2**31):
super(Scale, self).__init__()
self.factor = factor
@torch.jit.script_method
def forward(self, tensor):
"""
Args:
tensor (Tensor): Tensor of audio of size (Samples x Channels)
Returns:
Tensor: Scaled by the scale factor. (default between -1.0 and 1.0)
"""
return F.scale(tensor, self.factor)
def __repr__(self):
return self.__class__.__name__ + '()'
class PadTrim(torch.jit.ScriptModule):
"""Pad/Trim a 2d-Tensor (Signal or Labels)
r"""Pad/Trim a 2D tensor
Args:
tensor (Tensor): Tensor of audio of size (n x c) or (c x n)
max_len (int): Length to which the tensor will be padded
channels_first (bool): Pad for channels first tensors. Default: `True`
max_len (int): Length to which the waveform will be padded
fill_value (float): Value to fill in
"""
__constants__ = ['max_len', 'fill_value', 'len_dim', 'ch_dim']
__constants__ = ['max_len', 'fill_value']
def __init__(self, max_len, fill_value=0., channels_first=True):
def __init__(self, max_len, fill_value=0.):
super(PadTrim, self).__init__()
self.max_len = max_len
self.fill_value = fill_value
self.len_dim, self.ch_dim = int(channels_first), int(not channels_first)
@torch.jit.script_method
def forward(self, tensor):
"""
Returns:
Tensor: (c x n) or (n x c)
"""
return F.pad_trim(tensor, self.ch_dim, self.max_len, self.len_dim, self.fill_value)
def __repr__(self):
return self.__class__.__name__ + '(max_len={0})'.format(self.max_len)
class DownmixMono(torch.jit.ScriptModule):
"""Downmix any stereo signals to mono. Consider using a `SoxEffectsChain` with
the `channels` effect instead of this transformation.
Inputs:
tensor (Tensor): Tensor of audio of size (c x n) or (n x c)
channels_first (bool): Downmix across channels dimension. Default: `True`
Returns:
tensor (Tensor) (Samples x 1):
"""
__constants__ = ['ch_dim']
def __init__(self, channels_first=None):
super(DownmixMono, self).__init__()
self.ch_dim = int(not channels_first)
@torch.jit.script_method
def forward(self, tensor):
return F.downmix_mono(tensor, self.ch_dim)
def __repr__(self):
return self.__class__.__name__ + '()'
class LC2CL(torch.jit.ScriptModule):
"""Permute a 2d tensor from samples (n x c) to (c x n)
"""
def __init__(self):
super(LC2CL, self).__init__()
@torch.jit.script_method
def forward(self, tensor):
"""
def forward(self, waveform):
r"""
Args:
tensor (Tensor): Tensor of audio signal with shape (LxC)
waveform (torch.Tensor): Tensor of audio of size (c, n)
Returns:
tensor (Tensor): Tensor of audio signal with shape (CxL)
Tensor: Tensor of size (c, `max_len`)
"""
return F.LC2CL(tensor)
def __repr__(self):
return self.__class__.__name__ + '()'
def SPECTROGRAM(*args, **kwargs):
warn("SPECTROGRAM has been renamed to Spectrogram")
return Spectrogram(*args, **kwargs)
return F.pad_trim(waveform, self.max_len, self.fill_value)
class Spectrogram(torch.jit.ScriptModule):
"""Create a spectrogram from a raw audio signal
r"""Create a spectrogram from a audio signal
Args:
n_fft (int, optional): size of fft, creates n_fft // 2 + 1 bins
ws (int): window size. default: n_fft
hop (int, optional): length of hop between STFT windows. default: ws // 2
pad (int): two sided padding of signal
window (torch windowing function): default: torch.hann_window
power (int > 0 ) : Exponent for the magnitude spectrogram,
e.g., 1 for energy, 2 for power, etc.
normalize (bool) : whether to normalize by magnitude after stft
wkwargs (dict, optional): arguments for window function
n_fft (int, optional): Size of fft, creates `n_fft // 2 + 1` bins
win_length (int): Window size. (Default: `n_fft`)
hop_length (int, optional): Length of hop between STFT windows. (
Default: `win_length // 2`)
pad (int): Two sided padding of signal. (Default: 0)
window_fn (Callable[[...], torch.Tensor]): A function to create a window tensor
that is applied/multiplied to each frame/window. (Default: `torch.hann_window`)
power (int): Exponent for the magnitude spectrogram,
(must be > 0) e.g., 1 for energy, 2 for power, etc.
normalized (bool): Whether to normalize by magnitude after stft. (Default: `False`)
wkwargs (Dict[..., ...]): Arguments for window function. (Default: `None`)
"""
__constants__ = ['n_fft', 'ws', 'hop', 'pad', 'power', 'normalize']
__constants__ = ['n_fft', 'win_length', 'hop_length', 'pad', 'power', 'normalized']
def __init__(self, n_fft=400, ws=None, hop=None,
pad=0, window=torch.hann_window,
power=2, normalize=False, wkwargs=None):
def __init__(self, n_fft=400, win_length=None, hop_length=None,
pad=0, window_fn=torch.hann_window,
power=2, normalized=False, wkwargs=None):
super(Spectrogram, self).__init__()
self.n_fft = n_fft
# number of FFT bins. The returned STFT result will have n_fft // 2 + 1
# frequencies due to onesided=True in torch.stft
self.ws = ws if ws is not None else n_fft
self.hop = hop if hop is not None else self.ws // 2
window = window(self.ws) if wkwargs is None else window(self.ws, **wkwargs)
self.win_length = win_length if win_length is not None else n_fft
self.hop_length = hop_length if hop_length is not None else self.win_length // 2
window = window_fn(self.win_length) if wkwargs is None else window_fn(self.win_length, **wkwargs)
self.window = torch.jit.Attribute(window, torch.Tensor)
self.pad = pad
self.power = power
self.normalize = normalize
self.normalized = normalized
@torch.jit.script_method
def forward(self, sig):
"""
def forward(self, waveform):
r"""
Args:
sig (Tensor): Tensor of audio of size (c, n)
waveform (torch.Tensor): Tensor of audio of size (c, n)
Returns:
spec_f (Tensor): channels x hops x n_fft (c, l, f), where channels
is unchanged, hops is the number of hops, and n_fft is the
number of fourier bins, which should be the window size divided
by 2 plus 1.
torch.Tensor: Channels x frequency x time (c, f, t), where channels
is unchanged, frequency is `n_fft // 2 + 1` where `n_fft` is the number of
fourier bins, and time is the number of window hops (n_frames).
"""
return F.spectrogram(sig, self.pad, self.window, self.n_fft, self.hop,
self.ws, self.power, self.normalize)
def F2M(*args, **kwargs):
warn("F2M has been renamed to MelScale")
return MelScale(*args, **kwargs)
return F.spectrogram(waveform, self.pad, self.window, self.n_fft, self.hop_length,
self.win_length, self.power, self.normalized)
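# Editor's illustrative sketch, not part of this commit: expected use of the
# renamed Spectrogram arguments (win_length / hop_length / window_fn / normalized);
# the output shape assumes the defaults n_fft=400, power=2.
import torch
import torchaudio

waveform = torch.randn(1, 16000)
specgram = torchaudio.transforms.Spectrogram(n_fft=400, hop_length=200)(waveform)
print(specgram.shape)                       # expected: torch.Size([1, 201, 81])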
class MelScale(torch.jit.ScriptModule):
"""This turns a normal STFT into a mel frequency STFT, using a conversion
r"""This turns a normal STFT into a mel frequency STFT, using a conversion
matrix. This uses triangular filter banks.
The user can control which device the filter bank (`fb`) is on (e.g. `fb.to(spec_f.device)`).
Args:
n_mels (int): number of mel bins
sr (int): sample rate of audio signal
f_max (float, optional): maximum frequency. default: `sr` // 2
f_min (float): minimum frequency. default: 0
n_stft (int, optional): number of filter banks from stft. Calculated from first input
n_mels (int): Number of mel filterbanks. (Default: 128)
sample_rate (int): Sample rate of audio signal. (Default: 16000)
f_min (float): Minimum frequency. (Default: 0.)
f_max (float, optional): Maximum frequency. (Default: `sample_rate // 2`)
n_stft (int, optional): Number of bins in STFT. Calculated from first input
if `None` is given. See `n_fft` in `Spectrogram`.
"""
__constants__ = ['n_mels', 'sr', 'f_min', 'f_max']
__constants__ = ['n_mels', 'sample_rate', 'f_min', 'f_max']
def __init__(self, n_mels=128, sr=16000, f_max=None, f_min=0., n_stft=None):
def __init__(self, n_mels=128, sample_rate=16000, f_min=0., f_max=None, n_stft=None):
super(MelScale, self).__init__()
self.n_mels = n_mels
self.sr = sr
self.f_max = f_max if f_max is not None else float(sr // 2)
self.sample_rate = sample_rate
self.f_max = f_max if f_max is not None else float(sample_rate // 2)
assert f_min <= self.f_max, 'Require f_min: %f < f_max: %f' % (f_min, self.f_max)
self.f_min = f_min
fb = torch.empty(0) if n_stft is None else F.create_fb_matrix(
n_stft, self.f_min, self.f_max, self.n_mels)
self.fb = torch.jit.Attribute(fb, torch.Tensor)
@torch.jit.script_method
def forward(self, spec_f):
def forward(self, specgram):
r"""
Args:
specgram (torch.Tensor): a spectrogram STFT of size (c, f, t)
Returns:
torch.Tensor: mel frequency spectrogram of size (c, `n_mels`, t)
"""
if self.fb.numel() == 0:
tmp_fb = F.create_fb_matrix(spec_f.size(2), self.f_min, self.f_max, self.n_mels)
tmp_fb = F.create_fb_matrix(specgram.size(1), self.f_min, self.f_max, self.n_mels)
# Attributes cannot be reassigned outside __init__ so workaround
self.fb.resize_(tmp_fb.size())
self.fb.copy_(tmp_fb)
spec_m = torch.matmul(spec_f, self.fb) # (c, l, n_fft) dot (n_fft, n_mels) -> (c, l, n_mels)
return spec_m
# (c, f, t).transpose(...) dot (f, n_mels) -> (c, t, n_mels).transpose(...)
mel_specgram = torch.matmul(specgram.transpose(1, 2), self.fb).transpose(1, 2)
return mel_specgram
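# Editor's illustrative sketch, not part of this commit: MelScale turns a
# (c, f, t) spectrogram into (c, n_mels, t); `fb` is built lazily from f.
import torch
import torchaudio

spec = torchaudio.transforms.Spectrogram(n_fft=400)(torch.randn(1, 16000))   # (1, 201, 81)
mel = torchaudio.transforms.MelScale(n_mels=128, sample_rate=16000)(spec)    # (1, 128, 81)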
class SpectrogramToDB(torch.jit.ScriptModule):
"""Turns a spectrogram from the power/amplitude scale to the decibel scale.
r"""Turns a spectrogram from the power/amplitude scale to the decibel scale.
This output depends on the maximum value in the input spectrogram, and so
may return different values for an audio clip split into snippets vs. a full clip.
Args:
stype (str): scale of input spectrogram ("power" or "magnitude"). The
power being the elementwise square of the magnitude. default: "power"
stype (str): scale of input spectrogram ('power' or 'magnitude'). The
power being the elementwise square of the magnitude. (Default: 'power')
top_db (float, optional): minimum negative cut-off in decibels. A reasonable number
is 80.
"""
__constants__ = ['multiplier', 'amin', 'ref_value', 'db_multiplier']
def __init__(self, stype="power", top_db=None):
def __init__(self, stype='power', top_db=None):
super(SpectrogramToDB, self).__init__()
self.stype = torch.jit.Attribute(stype, str)
if top_db is not None and top_db < 0:
raise ValueError('top_db must be positive value')
self.top_db = torch.jit.Attribute(top_db, Optional[float])
self.multiplier = 10. if stype == "power" else 20.
self.multiplier = 10.0 if stype == 'power' else 20.0
self.amin = 1e-10
self.ref_value = 1.
self.ref_value = 1.0
self.db_multiplier = math.log10(max(self.amin, self.ref_value))
@torch.jit.script_method
def forward(self, spec):
# numerically stable implementation from librosa
# https://librosa.github.io/librosa/_modules/librosa/core/spectrum.html
return F.spectrogram_to_DB(spec, self.multiplier, self.amin, self.db_multiplier, self.top_db)
def forward(self, specgram):
r"""Numerically stable implementation from Librosa
https://librosa.github.io/librosa/_modules/librosa/core/spectrum.html
Args:
specgram (torch.Tensor): STFT of size (c, f, t)
Returns:
torch.Tensor: STFT after changing scale of size (c, f, t)
"""
return F.spectrogram_to_DB(specgram, self.multiplier, self.amin, self.db_multiplier, self.top_db)
class MFCC(torch.jit.ScriptModule):
"""Create the Mel-frequency cepstrum coefficients from an audio signal
r"""Create the Mel-frequency cepstrum coefficients from an audio signal
By default, this calculates the MFCC on the DB-scaled Mel spectrogram.
This is not the textbook implementation, but is implemented here to
give consistency with librosa.
This output depends on the maximum value in the input spectrogram, and so
may return different values for an audio clip split into snippets vs. a full clip.
Args:
sr (int) : sample rate of audio signal
n_mfcc (int) : number of mfc coefficients to retain
dct_type (int) : type of DCT (discrete cosine transform) to use
norm (string, optional) : norm to use
log_mels (bool) : whether to use log-mel spectrograms instead of db-scaled
Args:
sample_rate (int): Sample rate of audio signal. (Default: 16000)
n_mfcc (int): Number of mfc coefficients to retain
dct_type (int): type of DCT (discrete cosine transform) to use
norm (string, optional): norm to use
log_mels (bool): whether to use log-mel spectrograms instead of db-scaled
melkwargs (dict, optional): arguments for MelSpectrogram
"""
__constants__ = ['sr', 'n_mfcc', 'dct_type', 'top_db', 'log_mels']
__constants__ = ['sample_rate', 'n_mfcc', 'dct_type', 'top_db', 'log_mels']
def __init__(self, sr=16000, n_mfcc=40, dct_type=2, norm='ortho', log_mels=False,
def __init__(self, sample_rate=16000, n_mfcc=40, dct_type=2, norm='ortho', log_mels=False,
melkwargs=None):
super(MFCC, self).__init__()
supported_dct_types = [2]
if dct_type not in supported_dct_types:
raise ValueError('DCT type not supported: {}'.format(dct_type))
self.sr = sr
self.sample_rate = sample_rate
self.n_mfcc = n_mfcc
self.dct_type = dct_type
self.norm = torch.jit.Attribute(norm, Optional[str])
self.top_db = 80.
self.s2db = SpectrogramToDB("power", self.top_db)
self.top_db = 80.0
self.spectrogram_to_DB = SpectrogramToDB('power', self.top_db)
if melkwargs is not None:
self.MelSpectrogram = MelSpectrogram(sr=self.sr, **melkwargs)
self.MelSpectrogram = MelSpectrogram(sample_rate=self.sample_rate, **melkwargs)
else:
self.MelSpectrogram = MelSpectrogram(sr=self.sr)
self.MelSpectrogram = MelSpectrogram(sample_rate=self.sample_rate)
if self.n_mfcc > self.MelSpectrogram.n_mels:
raise ValueError('Cannot select more MFCC coefficients than # mel bins')
......@@ -323,29 +214,28 @@ class MFCC(torch.jit.ScriptModule):
self.log_mels = log_mels
@torch.jit.script_method
def forward(self, sig):
"""
def forward(self, waveform):
r"""
Args:
sig (Tensor): Tensor of audio of size (channels [c], samples [n])
waveform (torch.Tensor): Tensor of audio of size (c, n)
Returns:
spec_mel_db (Tensor): channels x hops x n_mels (c, l, m), where channels
is unchanged, hops is the number of hops, and n_mels is the
number of mel bins.
torch.Tensor: specgram_mel_db of size (c, `n_mfcc`, t)
"""
mel_spect = self.MelSpectrogram(sig)
mel_specgram = self.MelSpectrogram(waveform)
if self.log_mels:
log_offset = 1e-6
mel_spect = torch.log(mel_spect + log_offset)
mel_specgram = torch.log(mel_specgram + log_offset)
else:
mel_spect = self.s2db(mel_spect)
mfcc = torch.matmul(mel_spect, self.dct_mat)
mel_specgram = self.spectrogram_to_DB(mel_specgram)
# (c, `n_mels`, t).transpose(...) dot (`n_mels`, `n_mfcc`) -> (c, t, `n_mfcc`).transpose(...)
mfcc = torch.matmul(mel_specgram.transpose(1, 2), self.dct_mat).transpose(1, 2)
return mfcc
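A hypothetical usage sketch of the renamed MFCC transform above (assuming it is exposed as torchaudio.transforms.MFCC; parameter values are illustrative):

import torch
import torchaudio.transforms as transforms

# one second of fake mono audio, size (c, n)
waveform = torch.randn(1, 16000)
mfcc_transform = transforms.MFCC(sample_rate=16000, n_mfcc=40)
mfcc = mfcc_transform(waveform)  # size (c, n_mfcc, t) per the forward docstring
print(mfcc.shape)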
class MelSpectrogram(torch.jit.ScriptModule):
"""Create MEL Spectrograms from a raw audio signal using the stft
function in PyTorch.
r"""Create MelSpectrogram for a raw audio signal. This is a composition of Spectrogram
and MelScale.
Sources:
* https://gist.github.com/kastnerkyle/179d6e9a88202ab0a2fe
......@@ -353,87 +243,58 @@ class MelSpectrogram(torch.jit.ScriptModule):
* http://haythamfayek.com/2016/04/21/speech-processing-for-machine-learning.html
Args:
sr (int): sample rate of audio signal
ws (int): window size
hop (int, optional): length of hop between STFT windows. default: `ws` // 2
n_fft (int, optional): number of fft bins. default: `ws` // 2 + 1
f_max (float, optional): maximum frequency. default: `sr` // 2
f_min (float): minimum frequency. default: 0
pad (int): two sided padding of signal
n_mels (int): number of MEL bins
window (torch windowing function): default: `torch.hann_window`
wkwargs (dict, optional): arguments for window function
sample_rate (int): Sample rate of audio signal. (Default: 16000)
win_length (int): Window size. (Default: `n_fft`)
hop_length (int, optional): Length of hop between STFT windows. (Default: `win_length // 2`)
n_fft (int, optional): Size of fft, creates `n_fft // 2 + 1` bins
f_min (float): Minimum frequency. (Default: 0.)
f_max (float, optional): Maximum frequency. (Default: `None`)
pad (int): Two sided padding of signal. (Default: 0)
n_mels (int): Number of mel filterbanks. (Default: 128)
window_fn (Callable[[...], torch.Tensor]): A function to create a window tensor
that is applied/multiplied to each frame/window. (Default: `torch.hann_window`)
wkwargs (Dict[..., ...]): Arguments for window function. (Default: `None`)
Example:
>>> sig, sr = torchaudio.load("test.wav", normalization=True)
>>> spec_mel = transforms.MelSpectrogram(sr)(sig) # (c, l, m)
>>> waveform, sample_rate = torchaudio.load('test.wav', normalization=True)
>>> mel_specgram = transforms.MelSpectrogram(sample_rate)(waveform) # (c, n_mels, t)
"""
__constants__ = ['sr', 'n_fft', 'ws', 'hop', 'pad', 'n_mels', 'f_min']
__constants__ = ['sample_rate', 'n_fft', 'win_length', 'hop_length', 'pad', 'n_mels', 'f_min']
def __init__(self, sr=16000, n_fft=400, ws=None, hop=None, f_min=0., f_max=None,
pad=0, n_mels=128, window=torch.hann_window, wkwargs=None):
def __init__(self, sample_rate=16000, n_fft=400, win_length=None, hop_length=None, f_min=0., f_max=None,
pad=0, n_mels=128, window_fn=torch.hann_window, wkwargs=None):
super(MelSpectrogram, self).__init__()
self.sr = sr
self.sample_rate = sample_rate
self.n_fft = n_fft
self.ws = ws if ws is not None else n_fft
self.hop = hop if hop is not None else self.ws // 2
self.win_length = win_length if win_length is not None else n_fft
self.hop_length = hop_length if hop_length is not None else self.win_length // 2
self.pad = pad
self.n_mels = n_mels # number of mel frequency bins
self.f_max = torch.jit.Attribute(f_max, Optional[float])
self.f_min = f_min
self.spec = Spectrogram(n_fft=self.n_fft, ws=self.ws, hop=self.hop,
pad=self.pad, window=window, power=2,
normalize=False, wkwargs=wkwargs)
self.fm = MelScale(self.n_mels, self.sr, self.f_max, self.f_min)
self.spectrogram = Spectrogram(n_fft=self.n_fft, win_length=self.win_length,
hop_length=self.hop_length,
pad=self.pad, window_fn=window_fn, power=2,
normalized=False, wkwargs=wkwargs)
self.mel_scale = MelScale(self.n_mels, self.sample_rate, self.f_min, self.f_max)
@torch.jit.script_method
def forward(self, sig):
"""
def forward(self, waveform):
r"""
Args:
sig (Tensor): Tensor of audio of size (channels [c], samples [n])
waveform (torch.Tensor): Tensor of audio of size (c, n)
Returns:
spec_mel (Tensor): channels x hops x n_mels (c, l, m), where channels
is unchanged, hops is the number of hops, and n_mels is the
number of mel bins.
torch.Tensor: mel frequency spectrogram of size (c, `n_mels`, t)
"""
spec = self.spec(sig)
spec_mel = self.fm(spec)
return spec_mel
def MEL(*args, **kwargs):
raise DeprecationWarning("MEL has been removed from the library; please use MelSpectrogram or librosa")
class BLC2CBL(torch.jit.ScriptModule):
"""Permute a 3d tensor from Bands x Sample length x Channels to Channels x
Bands x Sample length
"""
def __init__(self):
super(BLC2CBL, self).__init__()
@torch.jit.script_method
def forward(self, tensor):
"""
Args:
tensor (Tensor): Tensor of spectrogram with shape (BxLxC)
Returns:
tensor (Tensor): Tensor of spectrogram with shape (CxBxL)
"""
return F.BLC2CBL(tensor)
def __repr__(self):
return self.__class__.__name__ + '()'
specgram = self.spectrogram(waveform)
mel_specgram = self.mel_scale(specgram)
return mel_specgram
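A hypothetical usage sketch of the MelSpectrogram composition above (Spectrogram followed by MelScale); the keyword names follow the renamed __init__ signature and the values are illustrative:

import torch
import torchaudio.transforms as transforms

waveform = torch.randn(2, 16000)  # (c, n) stereo clip
mel_transform = transforms.MelSpectrogram(sample_rate=16000, n_fft=400,
                                          hop_length=200, n_mels=128)
mel_specgram = mel_transform(waveform)  # size (c, n_mels, t)
print(mel_specgram.shape)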
class MuLawEncoding(torch.jit.ScriptModule):
"""Encode signal based on mu-law companding. For more info see the
r"""Encode signal based on mu-law companding. For more info see the
`Wikipedia Entry <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_
This algorithm assumes the signal has been scaled to between -1 and 1 and
......@@ -441,33 +302,27 @@ class MuLawEncoding(torch.jit.ScriptModule):
Args:
quantization_channels (int): Number of channels. default: 256
"""
__constants__ = ['qc']
__constants__ = ['quantization_channels']
def __init__(self, quantization_channels=256):
super(MuLawEncoding, self).__init__()
self.qc = quantization_channels
self.quantization_channels = quantization_channels
@torch.jit.script_method
def forward(self, x):
"""
r"""
Args:
x (FloatTensor/LongTensor)
x (torch.Tensor): A signal to be encoded
Returns:
x_mu (LongTensor)
x_mu (torch.Tensor): An encoded signal
"""
return F.mu_law_encoding(x, self.qc)
def __repr__(self):
return self.__class__.__name__ + '()'
return F.mu_law_encoding(x, self.quantization_channels)
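For context, a standalone sketch of the textbook mu-law companding the docstring refers to; F.mu_law_encoding may differ in details such as rounding:

import math
import torch

def mu_law_encode_sketch(x, quantization_channels=256):
    # assumes x is already scaled to [-1, 1], as the class docstring requires
    mu = quantization_channels - 1.0
    x_mu = torch.sign(x) * torch.log1p(mu * torch.abs(x)) / math.log1p(mu)
    # map the compressed signal to integer codes in [0, quantization_channels - 1]
    return ((x_mu + 1) / 2 * mu + 0.5).to(torch.int64)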
class MuLawExpanding(torch.jit.ScriptModule):
"""Decode mu-law encoded signal. For more info see the
r"""Decode mu-law encoded signal. For more info see the
`Wikipedia Entry <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_
This expects an input with values between 0 and quantization_channels - 1
......@@ -475,33 +330,27 @@ class MuLawExpanding(torch.jit.ScriptModule):
Args:
quantization_channels (int): Number of channels. default: 256
"""
__constants__ = ['qc']
__constants__ = ['quantization_channels']
def __init__(self, quantization_channels=256):
super(MuLawExpanding, self).__init__()
self.qc = quantization_channels
self.quantization_channels = quantization_channels
@torch.jit.script_method
def forward(self, x_mu):
"""
r"""
Args:
x_mu (Tensor)
x_mu (torch.Tensor): A mu-law encoded signal which needs to be decoded
Returns:
x (Tensor)
torch.Tensor: The decoded signal
"""
return F.mu_law_expanding(x_mu, self.qc)
def __repr__(self):
return self.__class__.__name__ + '()'
return F.mu_law_expanding(x_mu, self.quantization_channels)
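A hypothetical round trip through the two transforms above (assuming they are exposed via torchaudio.transforms); the decoded signal only approximates the input because of quantization:

import torch
import torchaudio.transforms as transforms

x = torch.rand(1, 100) * 2 - 1  # signal scaled to [-1, 1]
encode = transforms.MuLawEncoding(quantization_channels=256)
decode = transforms.MuLawExpanding(quantization_channels=256)
x_hat = decode(encode(x))
print((x - x_hat).abs().max())  # small quantization error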
class Resample(torch.nn.Module):
"""Resamples a signal from one frequency to another. A resampling method can
r"""Resamples a signal from one frequency to another. A resampling method can
be given.
Args:
......@@ -516,15 +365,15 @@ class Resample(torch.nn.Module):
self.new_freq = new_freq
self.resampling_method = resampling_method
def forward(self, sig):
"""
def forward(self, waveform):
r"""
Args:
sig (Tensor): the input signal of size (c, n)
waveform (torch.Tensor): The input signal of size (c, n)
Returns:
Tensor: output signal of size (c, m)
torch.Tensor: Output signal of size (c, m)
"""
if self.resampling_method == 'sinc_interpolation':
return kaldi.resample_waveform(sig, self.orig_freq, self.new_freq)
return kaldi.resample_waveform(waveform, self.orig_freq, self.new_freq)
raise ValueError('Invalid resampling method: %s' % (self.resampling_method))
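A hypothetical usage sketch of Resample; the constructor argument names are inferred from the attributes set in __init__ above:

import torch
import torchaudio.transforms as transforms

waveform = torch.randn(1, 44100)  # (c, n), one second at 44.1 kHz
resample = transforms.Resample(orig_freq=44100, new_freq=16000,
                               resampling_method='sinc_interpolation')
resampled = resample(waveform)  # (c, m) at 16 kHz
print(resampled.shape)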