Unverified Commit 313f4f5c authored by moto's avatar moto Committed by GitHub
Browse files

Add util to generate whitenoise (#654)

* Add util to generate whitenoise

* Use sinusoid for pitch and revert vad
parent 2f2319d6
import os
import tempfile
import unittest
from typing import Type, Iterable
from typing import Type, Iterable, Union
from contextlib import contextmanager
from shutil import copytree
......@@ -117,3 +117,68 @@ def define_test_suites(
for dtype in dtypes:
t = define_test_suite(suite, dtype, device)
scope[t.__name__] = t
def get_whitenoise(
    *,
    sample_rate: int = 16000,
    duration: float = 1,  # seconds
    n_channels: int = 1,
    seed: int = 0,
    dtype: Union[str, torch.dtype] = "float32",
    device: Union[str, torch.device] = "cpu",
):
    """Generate pseudo audio data with whitenoise

    Samples are drawn from a standard normal distribution, halved, and then
    clamped to the range [-1.0, 1.0].

    Args:
        sample_rate: Sampling rate
        duration: Length of the resulting Tensor in seconds. May be fractional;
            the number of frames is ``int(sample_rate * duration)``.
        n_channels: Number of channels
        seed: Seed value used for random number generation.
            The global CPU random generator state is saved before seeding and
            restored afterwards, so callers' CPU RNG state is left untouched.
            NOTE(review): only the CPU generator is forked here; confirm that
            seeding does not matter for CUDA generator state in your tests.
        dtype: Torch dtype, or the name of one (e.g. ``"float32"``)
        device: device on which the returned Tensor is placed
    Returns:
        Tensor: shape of (n_channels, int(sample_rate * duration))
    """
    if isinstance(dtype, str):
        dtype = getattr(torch, dtype)
    # Tensor sizes must be integers; ``sample_rate * duration`` is a float
    # whenever ``duration`` is fractional, so convert explicitly.
    shape = [n_channels, int(sample_rate * duration)]
    # According to the doc, forking rng on all CUDA devices is slow when there are many CUDA devices,
    # so we only fork on CPU, generate values and move the data to the given device
    with torch.random.fork_rng([]):
        torch.random.manual_seed(seed)
        tensor = torch.randn(shape, dtype=dtype, device='cpu')
    # Scale down so that most samples fall inside [-1, 1] before clamping the tails.
    tensor /= 2.0
    tensor.clamp_(-1.0, 1.0)
    return tensor.to(device=device)
def get_sinusoid(
    *,
    frequency: float = 300,
    sample_rate: int = 16000,
    duration: float = 1,  # seconds
    n_channels: int = 1,
    dtype: Union[str, torch.dtype] = "float32",
    device: Union[str, torch.device] = "cpu",
):
    """Generate pseudo audio data with sine wave.

    The phase is sampled at evenly spaced points from 0 to
    ``2 * pi * frequency * duration`` (both ends included), and the same
    waveform is repeated on every channel.

    Args:
        frequency: Frequency of sine wave
        sample_rate: Sampling rate
        duration: Length of the resulting Tensor in seconds. May be fractional;
            the number of frames is ``int(sample_rate * duration)``.
        n_channels: Number of channels
        dtype: Torch dtype, or the name of one (e.g. ``"float32"``)
        device: device on which the Tensor is created
    Returns:
        Tensor: shape of (n_channels, int(sample_rate * duration))
    """
    if isinstance(dtype, str):
        dtype = getattr(torch, dtype)
    two_pi = 2 * 3.141592653589793
    end = two_pi * frequency * duration
    # ``torch.linspace`` requires an integer number of steps; the product is a
    # float whenever ``duration`` is fractional, so convert explicitly.
    num_frames = int(sample_rate * duration)
    theta = torch.linspace(0, end, num_frames, dtype=dtype, device=device)
    # Duplicate the single-channel waveform along a new leading channel axis.
    return torch.sin(theta).repeat([n_channels, 1])
......@@ -33,7 +33,7 @@ class Functional(common_utils.TestBaseMixin):
normalize = False
return F.spectrogram(tensor, pad, window, n_fft, hop, ws, power, normalize)
tensor = torch.rand((1, 1000))
tensor = common_utils.get_whitenoise()
self._assert_consistency(func, tensor)
def test_griffinlim(self):
......@@ -65,8 +65,7 @@ class Functional(common_utils.TestBaseMixin):
self._assert_consistency(func, tensor)
def test_detect_pitch_frequency(self):
filepath = common_utils.get_asset_path('steam-train-whistle-daniel_simon.wav')
waveform, _ = torchaudio.load(filepath)
waveform = common_utils.get_sinusoid(sample_rate=44100)
def func(tensor):
sample_rate = 44100
......@@ -128,8 +127,8 @@ class Functional(common_utils.TestBaseMixin):
qc = 256
return F.mu_law_encoding(tensor, qc)
tensor = torch.rand((1, 10))
self._assert_consistency(func, tensor)
waveform = common_utils.get_whitenoise()
self._assert_consistency(func, waveform)
def test_mu_law_decoding(self):
def func(tensor):
......@@ -179,29 +178,28 @@ class Functional(common_utils.TestBaseMixin):
def func(tensor):
return F.dither(tensor, 'TPDF')
tensor = torch.rand((2, 1000))
tensor = common_utils.get_whitenoise(n_channels=2)
self._assert_consistency(func, tensor, shape_only=True)
def test_dither_RPDF(self):
def func(tensor):
return F.dither(tensor, 'RPDF')
tensor = torch.rand((2, 1000))
tensor = common_utils.get_whitenoise(n_channels=2)
self._assert_consistency(func, tensor, shape_only=True)
def test_dither_GPDF(self):
def func(tensor):
return F.dither(tensor, 'GPDF')
tensor = torch.rand((2, 1000))
tensor = common_utils.get_whitenoise(n_channels=2)
self._assert_consistency(func, tensor, shape_only=True)
def test_lfilter(self):
if self.dtype == torch.float64:
raise unittest.SkipTest("This test is known to fail for float64")
filepath = common_utils.get_asset_path('whitenoise.wav')
waveform, _ = torchaudio.load(filepath, normalization=True)
waveform = common_utils.get_whitenoise()
def func(tensor):
# Design an IIR lowpass filter using scipy.signal filter design
......@@ -244,8 +242,7 @@ class Functional(common_utils.TestBaseMixin):
if self.dtype == torch.float64:
raise unittest.SkipTest("This test is known to fail for float64")
filepath = common_utils.get_asset_path('whitenoise.wav')
waveform, _ = torchaudio.load(filepath, normalization=True)
waveform = common_utils.get_whitenoise(sample_rate=44100)
def func(tensor):
sample_rate = 44100
......@@ -258,8 +255,7 @@ class Functional(common_utils.TestBaseMixin):
if self.dtype == torch.float64:
raise unittest.SkipTest("This test is known to fail for float64")
filepath = common_utils.get_asset_path('whitenoise.wav')
waveform, _ = torchaudio.load(filepath, normalization=True)
waveform = common_utils.get_whitenoise(sample_rate=44100)
def func(tensor):
sample_rate = 44100
......@@ -272,8 +268,7 @@ class Functional(common_utils.TestBaseMixin):
if self.dtype == torch.float64:
raise unittest.SkipTest("This test is known to fail for float64")
filepath = common_utils.get_asset_path('whitenoise.wav')
waveform, _ = torchaudio.load(filepath, normalization=True)
waveform = common_utils.get_whitenoise(sample_rate=44100)
def func(tensor):
sample_rate = 44100
......@@ -287,8 +282,7 @@ class Functional(common_utils.TestBaseMixin):
if self.dtype == torch.float64:
raise unittest.SkipTest("This test is known to fail for float64")
filepath = common_utils.get_asset_path("whitenoise.wav")
waveform, _ = torchaudio.load(filepath, normalization=True)
waveform = common_utils.get_whitenoise(sample_rate=44100)
def func(tensor):
sample_rate = 44100
......@@ -303,8 +297,7 @@ class Functional(common_utils.TestBaseMixin):
if self.dtype == torch.float64:
raise unittest.SkipTest("This test is known to fail for float64")
filepath = common_utils.get_asset_path("whitenoise.wav")
waveform, _ = torchaudio.load(filepath, normalization=True)
waveform = common_utils.get_whitenoise(sample_rate=44100)
def func(tensor):
sample_rate = 44100
......@@ -319,8 +312,7 @@ class Functional(common_utils.TestBaseMixin):
if self.dtype == torch.float64:
raise unittest.SkipTest("This test is known to fail for float64")
filepath = common_utils.get_asset_path("whitenoise.wav")
waveform, _ = torchaudio.load(filepath, normalization=True)
waveform = common_utils.get_whitenoise(sample_rate=44100)
def func(tensor):
sample_rate = 44100
......@@ -334,8 +326,7 @@ class Functional(common_utils.TestBaseMixin):
if self.dtype == torch.float64:
raise unittest.SkipTest("This test is known to fail for float64")
filepath = common_utils.get_asset_path("whitenoise.wav")
waveform, _ = torchaudio.load(filepath, normalization=True)
waveform = common_utils.get_whitenoise(sample_rate=44100)
def func(tensor):
sample_rate = 44100
......@@ -350,8 +341,7 @@ class Functional(common_utils.TestBaseMixin):
if self.dtype == torch.float64:
raise unittest.SkipTest("This test is known to fail for float64")
filepath = common_utils.get_asset_path("whitenoise.wav")
waveform, _ = torchaudio.load(filepath, normalization=True)
waveform = common_utils.get_whitenoise(sample_rate=44100)
def func(tensor):
sample_rate = 44100
......@@ -366,8 +356,7 @@ class Functional(common_utils.TestBaseMixin):
if self.dtype == torch.float64:
raise unittest.SkipTest("This test is known to fail for float64")
filepath = common_utils.get_asset_path("whitenoise.wav")
waveform, _ = torchaudio.load(filepath, normalization=True)
waveform = common_utils.get_whitenoise(sample_rate=44100)
def func(tensor):
sample_rate = 44100
......@@ -382,8 +371,7 @@ class Functional(common_utils.TestBaseMixin):
if self.dtype == torch.float64:
raise unittest.SkipTest("This test is known to fail for float64")
filepath = common_utils.get_asset_path("whitenoise.wav")
waveform, _ = torchaudio.load(filepath, normalization=True)
waveform = common_utils.get_whitenoise(sample_rate=44100)
def func(tensor):
sample_rate = 44100
......@@ -395,8 +383,7 @@ class Functional(common_utils.TestBaseMixin):
if self.dtype == torch.float64:
raise unittest.SkipTest("This test is known to fail for float64")
filepath = common_utils.get_asset_path("whitenoise.wav")
waveform, _ = torchaudio.load(filepath, normalization=True)
waveform = common_utils.get_whitenoise(sample_rate=44100)
def func(tensor):
sample_rate = 44100
......@@ -408,8 +395,7 @@ class Functional(common_utils.TestBaseMixin):
if self.dtype == torch.float64:
raise unittest.SkipTest("This test is known to fail for float64")
filepath = common_utils.get_asset_path("whitenoise.wav")
waveform, _ = torchaudio.load(filepath, normalization=True)
waveform = common_utils.get_whitenoise(sample_rate=44100)
def func(tensor):
sample_rate = 44100
......@@ -424,8 +410,7 @@ class Functional(common_utils.TestBaseMixin):
if self.dtype == torch.float64:
raise unittest.SkipTest("This test is known to fail for float64")
filepath = common_utils.get_asset_path("whitenoise.wav")
waveform, _ = torchaudio.load(filepath, normalization=True)
waveform = common_utils.get_whitenoise()
def func(tensor):
a = torch.tensor([0.7, 0.2, 0.6], device=tensor.device, dtype=tensor.dtype)
......@@ -470,8 +455,7 @@ class Functional(common_utils.TestBaseMixin):
self._assert_consistency(func, b)
def test_contrast(self):
filepath = common_utils.get_asset_path("whitenoise.wav")
waveform, _ = torchaudio.load(filepath, normalization=True)
waveform = common_utils.get_whitenoise()
def func(tensor):
enhancement_amount = 80.
......@@ -480,8 +464,7 @@ class Functional(common_utils.TestBaseMixin):
self._assert_consistency(func, waveform)
def test_dcshift(self):
filepath = common_utils.get_asset_path("whitenoise.wav")
waveform, _ = torchaudio.load(filepath, normalization=True)
waveform = common_utils.get_whitenoise()
def func(tensor):
shift = 0.5
......@@ -491,8 +474,7 @@ class Functional(common_utils.TestBaseMixin):
self._assert_consistency(func, waveform)
def test_overdrive(self):
filepath = common_utils.get_asset_path("whitenoise.wav")
waveform, _ = torchaudio.load(filepath, normalization=True)
waveform = common_utils.get_whitenoise()
def func(tensor):
gain = 30.
......@@ -502,8 +484,7 @@ class Functional(common_utils.TestBaseMixin):
self._assert_consistency(func, waveform)
def test_phaser(self):
filepath = common_utils.get_asset_path("whitenoise.wav")
waveform, _ = torchaudio.load(filepath, normalization=True)
waveform = common_utils.get_whitenoise(sample_rate=44100)
def func(tensor):
gain_in = 0.5
......@@ -553,17 +534,16 @@ class Transforms(common_utils.TestBaseMixin):
self._assert_consistency(T.MFCC(), tensor)
def test_Resample(self):
tensor = torch.rand((2, 1000))
sample_rate = 100.
sample_rate_2 = 50.
self._assert_consistency(T.Resample(sample_rate, sample_rate_2), tensor)
sr1, sr2 = 16000, 8000
tensor = common_utils.get_whitenoise(sample_rate=sr1)
self._assert_consistency(T.Resample(float(sr1), float(sr2)), tensor)
def test_ComplexNorm(self):
tensor = torch.rand((1, 2, 201, 2))
self._assert_consistency(T.ComplexNorm(), tensor)
def test_MuLawEncoding(self):
tensor = torch.rand((1, 10))
tensor = common_utils.get_whitenoise()
self._assert_consistency(T.MuLawEncoding(), tensor)
def test_MuLawDecoding(self):
......@@ -581,8 +561,7 @@ class Transforms(common_utils.TestBaseMixin):
)
def test_Fade(self):
test_filepath = common_utils.get_asset_path('steam-train-whistle-daniel_simon.wav')
waveform, _ = torchaudio.load(test_filepath)
waveform = common_utils.get_whitenoise()
fade_in_len = 3000
fade_out_len = 3000
self._assert_consistency(T.Fade(fade_in_len, fade_out_len), waveform)
......@@ -596,8 +575,7 @@ class Transforms(common_utils.TestBaseMixin):
self._assert_consistency(T.TimeMasking(time_mask_param=30, iid_masks=False), tensor)
def test_Vol(self):
test_filepath = common_utils.get_asset_path('steam-train-whistle-daniel_simon.wav')
waveform, _ = torchaudio.load(test_filepath)
waveform = common_utils.get_whitenoise()
self._assert_consistency(T.Vol(1.1), waveform)
def test_SlidingWindowCmn(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment