Commit 2c7fdcc8 authored by Cami Williams's avatar Cami Williams Committed by Vincent QB
Browse files

Add functionals gain, dither, scale_to_interval (#319)

* Initial commit for SoX logic in VCTK

* change to train whistle file for tests

* apply probability
parent bdf92553
...@@ -5,6 +5,7 @@ import os ...@@ -5,6 +5,7 @@ import os
import torch import torch
import torchaudio import torchaudio
import torchaudio.functional as F import torchaudio.functional as F
import torchaudio.transforms as T
import pytest import pytest
import unittest import unittest
import common_utils import common_utils
...@@ -31,8 +32,10 @@ class TestFunctional(unittest.TestCase): ...@@ -31,8 +32,10 @@ class TestFunctional(unittest.TestCase):
specgram = torch.tensor([1., 2., 3., 4.]) specgram = torch.tensor([1., 2., 3., 4.])
test_dirpath, test_dir = common_utils.create_temp_assets_dir() test_dirpath, test_dir = common_utils.create_temp_assets_dir()
test_filepath = os.path.join(test_dirpath, 'assets', test_filepath = os.path.join(test_dirpath, 'assets',
'steam-train-whistle-daniel_simon.mp3') 'steam-train-whistle-daniel_simon.mp3')
waveform_train, sr_train = torchaudio.load(test_filepath)
def test_torchscript_spectrogram(self): def test_torchscript_spectrogram(self):
...@@ -365,8 +368,63 @@ class TestFunctional(unittest.TestCase): ...@@ -365,8 +368,63 @@ class TestFunctional(unittest.TestCase):
self._test_create_fb(n_mels=56, fmin=1900.0, fmax=900.0) self._test_create_fb(n_mels=56, fmin=1900.0, fmax=900.0)
self._test_create_fb(n_mels=10, fmin=1900.0, fmax=900.0) self._test_create_fb(n_mels=10, fmin=1900.0, fmax=900.0)
def test_gain(self):
    """Check ``F.gain`` on the train-whistle waveform against SoX's ``gain`` effect."""
    waveform_gain = F.gain(self.waveform_train, 3)
    # The previous ``assertTrue(value, 1.)`` passed ``1.`` as the failure
    # *message*, so it only checked that the peak was non-zero. Compare the
    # peak against full scale explicitly instead.
    self.assertLessEqual(waveform_gain.abs().max().item(), 1.)

    # Build the reference by applying the same 3 dB gain through SoX.
    E = torchaudio.sox_effects.SoxEffectsChain()
    E.set_input_file(self.test_filepath)
    E.append_effect_to_chain("gain", [3])
    sox_gain_waveform = E.sox_build_flow_effects()[0]

    self.assertTrue(torch.allclose(waveform_gain, sox_gain_waveform, atol=1e-04))
def test_scale_to_interval(self):
    """Check that ``F._scale_to_interval`` keeps every sample within ``[-scaled, scaled]``."""
    scaled = 5.5  # [-5.5, 5.5]
    # ``waveform_train`` is a class attribute shared by every test in this
    # class; hand over a copy so that any in-place modification made while
    # scaling cannot leak into the other tests.
    waveform_scaled = F._scale_to_interval(self.waveform_train.clone(), scaled)

    self.assertTrue(torch.max(waveform_scaled) <= scaled)
    self.assertTrue(torch.min(waveform_scaled) >= -scaled)
def test_dither(self):
    """Compare ``F.dither`` (plain and noise-shaped) with SoX's ``dither`` effect."""
    # Compute both functional outputs first, then build the SoX references.
    plain = F.dither(self.waveform_train)
    shaped = F.dither(self.waveform_train, noise_shaping=True)

    chain = torchaudio.sox_effects.SoxEffectsChain()
    chain.set_input_file(self.test_filepath)

    # Reference 1: SoX dither without options.
    chain.append_effect_to_chain("dither", [])
    sox_plain = chain.sox_build_flow_effects()[0]
    plain_matches = torch.allclose(plain, sox_plain, atol=1e-04)
    self.assertTrue(plain_matches)

    # Reference 2: SoX dither with the "-s" option on a freshly cleared chain.
    chain.clear_chain()
    chain.append_effect_to_chain("dither", ["-s"])
    sox_shaped = chain.sox_build_flow_effects()[0]
    shaped_matches = torch.allclose(shaped, sox_shaped, atol=1e-02)
    self.assertTrue(shaped_matches)
def test_vctk_transform_pipeline(self):
    """Compare resample + noise-shaped dither in torchaudio with a SoX effects chain."""
    vctk_path = os.path.join(self.test_dirpath, "assets/VCTK-Corpus/wav48/p224/", "p224_002.wav")
    loaded, source_rate = torchaudio.load(vctk_path)

    # rate
    resampler = T.Resample(source_rate, 16000, resampling_method='sinc_interpolation')
    resampled = resampler(loaded)

    # dither
    processed = F.dither(resampled, noise_shaping=True)

    # Reference pipeline built from SoX effects, appended in this exact order.
    chain = torchaudio.sox_effects.SoxEffectsChain()
    chain.set_input_file(vctk_path)
    sox_effects = (
        ("gain", ["-h"]),
        ("channels", [1]),
        ("rate", [16000]),
        ("gain", ["-rh"]),
        ("dither", ["-s"]),
    )
    for effect_name, effect_args in sox_effects:
        chain.append_effect_to_chain(effect_name, effect_args)
    reference = chain.sox_build_flow_effects()[0]

    self.assertTrue(torch.allclose(processed, reference, rtol=1e-03, atol=1e-03))
def test_pitch(self):
test_dirpath, test_dir = common_utils.create_temp_assets_dir() test_dirpath, test_dir = common_utils.create_temp_assets_dir()
test_filepath_100 = os.path.join(test_dirpath, 'assets', "100Hz_44100Hz_16bit_05sec.wav") test_filepath_100 = os.path.join(test_dirpath, 'assets', "100Hz_44100Hz_16bit_05sec.wav")
test_filepath_440 = os.path.join(test_dirpath, 'assets', "440Hz_44100Hz_16bit_05sec.wav") test_filepath_440 = os.path.join(test_dirpath, 'assets', "440Hz_44100Hz_16bit_05sec.wav")
...@@ -518,6 +576,25 @@ def test_phase_vocoder(complex_specgrams, rate, hop_length): ...@@ -518,6 +576,25 @@ def test_phase_vocoder(complex_specgrams, rate, hop_length):
_test_torchscript_functional(F.mask_along_axis_iid, specgrams, mask_param, mask_value, axis) _test_torchscript_functional(F.mask_along_axis_iid, specgrams, mask_param, mask_value, axis)
def test_torchscript_gain(self):
    """Pass ``F.gain`` with a random (1, 1000) input to ``_test_torchscript_functional``."""
    # NOTE(review): the helper is defined outside this view; presumably it
    # checks the scripted function against eager mode - confirm.
    random_waveform = torch.rand((1, 1000))
    gain_in_db = 2.0

    _test_torchscript_functional(F.gain, random_waveform, gain_in_db)
def test_torchscript_scale_to_interval(self):
    """Pass ``F._scale_to_interval`` with a random (1, 1000) input to ``_test_torchscript_functional``."""
    # NOTE(review): the helper is defined outside this view; presumably it
    # checks the scripted function against eager mode - confirm.
    random_waveform = torch.rand((1, 1000))
    interval_bound = 3.5

    _test_torchscript_functional(F._scale_to_interval, random_waveform, interval_bound)
def test_torchscript_dither(self):
    """Pass ``F.dither`` to ``_test_torchscript_functional`` with the default and explicit density functions."""
    # NOTE(review): the helper is defined outside this view; presumably it
    # checks the scripted function against eager mode - confirm.
    random_waveform = torch.rand((1, 1000))

    # Default density function first, then the two explicit alternatives.
    _test_torchscript_functional(F.dither, random_waveform)
    for density_name in ("RPDF", "GPDF"):
        _test_torchscript_functional(F.dither, random_waveform, density_name)
@pytest.mark.parametrize('complex_tensor', [ @pytest.mark.parametrize('complex_tensor', [
torch.randn(1, 2, 1025, 400, 2), torch.randn(1, 2, 1025, 400, 2),
......
...@@ -21,18 +21,16 @@ def load_vctk_item( ...@@ -21,18 +21,16 @@ def load_vctk_item(
# Read wav # Read wav
file_audio = os.path.join(path, folder_audio, speaker_id, fileid + ext_audio) file_audio = os.path.join(path, folder_audio, speaker_id, fileid + ext_audio)
waveform, sample_rate = torchaudio.load(file_audio)
if downsample: if downsample:
# Legacy # TODO Remove this parameter after deprecation
E = torchaudio.sox_effects.SoxEffectsChain() F = torchaudio.functional
E.set_input_file(file_audio) T = torchaudio.transforms
E.append_effect_to_chain("gain", ["-h"]) # rate
E.append_effect_to_chain("channels", [1]) sample = T.Resample(sample_rate, 16000, resampling_method='sinc_interpolation')
E.append_effect_to_chain("rate", [16000]) waveform = sample(waveform)
E.append_effect_to_chain("gain", ["-rh"]) # dither
E.append_effect_to_chain("dither", ["-s"]) waveform = F.dither(waveform, noise_shaping=True)
waveform, sample_rate = E.sox_build_flow_effects()
else:
waveform, sample_rate = torchaudio.load(file_audio)
return waveform, sample_rate, utterance, speaker_id, utterance_id return waveform, sample_rate, utterance, speaker_id, utterance_id
......
...@@ -858,6 +858,162 @@ def compute_deltas(specgram, win_length=5, mode="replicate"): ...@@ -858,6 +858,162 @@ def compute_deltas(specgram, win_length=5, mode="replicate"):
return output return output
def gain(waveform, gain_db=1.0):
    # type: (Tensor, float) -> Tensor
    r"""Amplify or attenuate an entire waveform by a gain expressed in decibels.

    Args:
        waveform (torch.Tensor): Tensor of audio of dimension (channel, time).
        gain_db (float): Gain adjustment in decibels (dB) (Default: `1.0`).

    Returns:
        torch.Tensor: the waveform multiplied by ``10 ** (gain_db / 20)``. When
        ``gain_db`` is zero the input tensor itself is returned.
    """
    if gain_db != 0:
        # Convert the decibel value into a linear amplitude factor.
        amplitude_factor = 10 ** (gain_db / 20)
        return waveform * amplitude_factor

    return waveform
def _scale_to_interval(waveform, interval_max=1.0):
# type: (Tensor, float) -> Tensor
r"""Scale the waveform to the interval [-interval_max, interval_max] across all dimensions.
Args:
waveform (torch.Tensor): Tensor of audio of dimension (channel, time).
interval_max (float): The bounds of the interval, where the float indicates
the upper bound and the negative of the float indicates the lower
bound (Default: `1.0`).
Example: interval=1.0 -> [-1.0, 1.0]
Returns:
torch.Tensor: the whole waveform scaled to interval.
"""
abs_max = torch.max(torch.abs(waveform))
ratio = abs_max / interval_max
waveform /= ratio
return waveform
def _add_noise_shaping(dithered_waveform, waveform):
r"""Noise shaping is calculated by error:
error[n] = dithered[n] - original[n]
noise_shaped_waveform[n] = dithered[n] + error[n-1]
"""
wf_shape = waveform.size()
waveform = waveform.reshape(-1, wf_shape[-1])
dithered_shape = dithered_waveform.size()
dithered_waveform = dithered_waveform.reshape(-1, dithered_shape[-1])
error = dithered_waveform - waveform
# add error[n-1] to dithered_waveform[n], so offset the error by 1 index
for index in range(error.size()[0]):
err = error[index]
error_offset = torch.cat((torch.zeros(1), err))
error[index] = error_offset[:waveform.size()[1]]
noise_shaped = dithered_waveform + error
return noise_shaped.reshape(dithered_shape[:-1] + noise_shaped.shape[-1:])
def _apply_probability_distribution(waveform, density_function="TPDF"):
# type: (Tensor, str) -> Tensor
r"""Apply a probability distribution function on a waveform.
Triangular probability density function (TPDF) dither noise has a
triangular distribution; values in the center of the range have a higher
probability of occurring.
Rectangular probability density function (RPDF) dither noise has a
uniform distribution; any value in the specified range has the same
probability of occurring.
Gaussian probability density function (GPDF) has a normal distribution.
The relationship of probabilities of results follows a bell-shaped,
or Gaussian curve, typical of dither generated by analog sources.
Args:
waveform (torch.Tensor): Tensor of audio of dimension (channel, time)
probability_density_function (string): The density function of a
continuous random variable (Default: `TPDF`)
Options: Triangular Probability Density Function - `TPDF`
Rectangular Probability Density Function - `RPDF`
Gaussian Probability Density Function - `GPDF`
Returns:
torch.Tensor: waveform dithered with TPDF
"""
shape = waveform.size()
waveform = waveform.reshape(-1, shape[-1])
channel_size = waveform.size()[0] - 1
time_size = waveform.size()[-1] - 1
random_channel = int(torch.randint(channel_size, [1, ]).item()) if channel_size > 0 else 0
random_time = int(torch.randint(time_size, [1, ]).item()) if time_size > 0 else 0
number_of_bits = 16
up_scaling = 2 ** (number_of_bits - 1) - 2
signal_scaled = waveform * up_scaling
down_scaling = 2 ** (number_of_bits - 1)
signal_scaled_dis = waveform
if (density_function == "RPDF"):
RPDF = waveform[random_channel][random_time] - 0.5
signal_scaled_dis = signal_scaled + RPDF
elif (density_function == "GPDF"):
# TODO Replace by distribution code once
# https://github.com/pytorch/pytorch/issues/29843 is resolved
# gaussian = torch.distributions.normal.Normal(torch.mean(waveform, -1), 1).sample()
num_rand_variables = 6
gaussian = waveform[random_channel][random_time]
for ws in num_rand_variables * [time_size]:
rand_chan = int(torch.randint(channel_size, [1, ]).item())
gaussian += waveform[rand_chan][int(torch.randint(ws, [1, ]).item())]
signal_scaled_dis = signal_scaled + gaussian
else:
TPDF = torch.bartlett_window(time_size + 1)
TPDF = TPDF.repeat((channel_size + 1), 1)
signal_scaled_dis = signal_scaled + TPDF
quantised_signal_scaled = torch.round(signal_scaled_dis)
quantised_signal = quantised_signal_scaled / down_scaling
return quantised_signal.reshape(shape[:-1] + quantised_signal.shape[-1:])
def dither(waveform, density_function="TPDF", noise_shaping=False):
    # type: (Tensor, str, bool) -> Tensor
    r"""Dither a waveform, optionally followed by noise shaping.

    Dither increases the perceived dynamic range of audio stored at a
    particular bit-depth by eliminating nonlinear truncation distortion
    (i.e. adding minimally perceived noise to mask distortion caused by quantization).

    Args:
        waveform (torch.Tensor): Tensor of audio of dimension (channel, time)
        density_function (string): The density function of a
            continuous random variable (Default: `TPDF`)
            Options: Triangular Probability Density Function - `TPDF`
                     Rectangular Probability Density Function - `RPDF`
                     Gaussian Probability Density Function - `GPDF`
        noise_shaping (boolean): a filtering process that shapes the spectral
            energy of quantisation error (Default: `False`)

    Returns:
        torch.Tensor: waveform dithered
    """
    quantised = _apply_probability_distribution(waveform, density_function=density_function)

    if not noise_shaping:
        return quantised

    # Shape the quantisation error relative to the original signal.
    return _add_noise_shaping(quantised, waveform)
def _compute_nccf(waveform, sample_rate, frame_time, freq_low): def _compute_nccf(waveform, sample_rate, frame_time, freq_low):
# type: (Tensor, int, float, int) -> Tensor # type: (Tensor, int, float, int) -> Tensor
r""" r"""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment