Unverified Commit d11ad6bb authored by engineerchuan's avatar engineerchuan Committed by GitHub
Browse files

Stop using whitenoise.wav, mp3 and torchaudio.load in sox effect test

Part of #764

 - Replace `whitenoise.wav` with on-the-fly data generation
 - Replace `torchaudio.load` with `common_utils.load_wav`
 - Replace `steam-train-whistle-daniel_simon.mp3` with `.wav`
parent 4b3e9052
...@@ -21,9 +21,10 @@ def get_whitenoise( ...@@ -21,9 +21,10 @@ def get_whitenoise(
seed: int = 0, seed: int = 0,
dtype: Union[str, torch.dtype] = "float32", dtype: Union[str, torch.dtype] = "float32",
device: Union[str, torch.device] = "cpu", device: Union[str, torch.device] = "cpu",
channels_first=True,
scale_factor: float = 1,
): ):
"""Generate pseudo audio data with whitenoise """Generate pseudo audio data with whitenoise
Args: Args:
sample_rate: Sampling rate sample_rate: Sampling rate
duration: Length of the resulting Tensor in seconds. duration: Length of the resulting Tensor in seconds.
...@@ -32,19 +33,34 @@ def get_whitenoise( ...@@ -32,19 +33,34 @@ def get_whitenoise(
Note that this function does not modify global random generator state. Note that this function does not modify global random generator state.
dtype: Torch dtype dtype: Torch dtype
device: device device: device
channels_first: whether first dimension is n_channels
scale_factor: scale the Tensor before clamping and quantization
Returns: Returns:
Tensor: shape of (n_channels, sample_rate * duration) Tensor: shape of (n_channels, sample_rate * duration)
""" """
if isinstance(dtype, str): if isinstance(dtype, str):
dtype = getattr(torch, dtype) dtype = getattr(torch, dtype)
shape = [n_channels, sample_rate * duration] if dtype not in [torch.float32, torch.int32, torch.int16, torch.uint8]:
raise NotImplementedError(f'dtype {dtype} is not supported.')
# According to the doc, folking rng on all CUDA devices is slow when there are many CUDA devices, # According to the doc, folking rng on all CUDA devices is slow when there are many CUDA devices,
# so we only folk on CPU, generate values and move the data to the given device # so we only folk on CPU, generate values and move the data to the given device
with torch.random.fork_rng([]): with torch.random.fork_rng([]):
torch.random.manual_seed(seed) torch.random.manual_seed(seed)
tensor = torch.randn(shape, dtype=dtype, device='cpu') tensor = torch.randn([sample_rate * duration], dtype=torch.float32, device='cpu')
tensor /= 2.0 tensor /= 2.0
tensor *= scale_factor
tensor.clamp_(-1.0, 1.0) tensor.clamp_(-1.0, 1.0)
if dtype == torch.int32:
tensor *= (tensor > 0) * 2147483647 + (tensor < 0) * 2147483648
if dtype == torch.int16:
tensor *= (tensor > 0) * 32767 + (tensor < 0) * 32768
if dtype == torch.uint8:
tensor *= (tensor > 0) * 127 + (tensor < 0) * 128
tensor += 128
tensor = tensor.to(dtype)
tensor = tensor.repeat([n_channels, 1])
if not channels_first:
tensor = tensor.t()
return tensor.to(device=device) return tensor.to(device=device)
......
This diff is collapsed.
...@@ -11,7 +11,7 @@ from . import common_utils ...@@ -11,7 +11,7 @@ from . import common_utils
class Test_SoxEffectsChain(common_utils.TorchaudioTestCase): class Test_SoxEffectsChain(common_utils.TorchaudioTestCase):
backend = 'sox' backend = 'sox'
test_filepath = common_utils.get_asset_path("steam-train-whistle-daniel_simon.mp3") test_filepath = common_utils.get_asset_path("steam-train-whistle-daniel_simon.wav")
def test_single_channel(self): def test_single_channel(self):
fn_sine = common_utils.get_asset_path("sinewave.wav") fn_sine = common_utils.get_asset_path("sinewave.wav")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment