Unverified Commit dc3c5c51 authored by Bhargav Kathivarapu's avatar Bhargav Kathivarapu Committed by GitHub
Browse files

Add Pathlib support for 'apply_effects_file' (#1048)

parent 34696b14
import itertools import itertools
from pathlib import Path
from torchaudio import sox_effects from torchaudio import sox_effects
from parameterized import parameterized from parameterized import parameterized
...@@ -104,7 +105,7 @@ class TestSoxEffectsFile(TempDirMixin, PytorchTestCase): ...@@ -104,7 +105,7 @@ class TestSoxEffectsFile(TempDirMixin, PytorchTestCase):
load_params("sox_effect_test_args.json"), load_params("sox_effect_test_args.json"),
name_func=lambda f, i, p: f'{f.__name__}_{i}_{p.args[0]["effects"][0][0]}', name_func=lambda f, i, p: f'{f.__name__}_{i}_{p.args[0]["effects"][0][0]}',
) )
def test_apply_effects(self, args): def test_apply_effects_str(self, args):
"""`apply_effects_file` should return identical data as sox command""" """`apply_effects_file` should return identical data as sox command"""
dtype = 'int32' dtype = 'int32'
channels_first = True channels_first = True
...@@ -127,6 +128,29 @@ class TestSoxEffectsFile(TempDirMixin, PytorchTestCase): ...@@ -127,6 +128,29 @@ class TestSoxEffectsFile(TempDirMixin, PytorchTestCase):
assert sr == expected_sr assert sr == expected_sr
self.assertEqual(found, expected) self.assertEqual(found, expected)
def test_apply_effects_path(self):
"""`apply_effects_file` should return identical data as sox command when file path is given as a Path Object"""
dtype = 'int32'
channels_first = True
effects = [["hilbert"]]
num_channels = 2
input_sr = 8000
output_sr = 8000
input_path = self.get_temp_path('input.wav')
reference_path = self.get_temp_path('reference.wav')
data = get_wav_data(dtype, num_channels, channels_first=channels_first)
save_wav(input_path, data, input_sr, channels_first=channels_first)
sox_utils.run_sox_effect(
input_path, reference_path, effects, output_sample_rate=output_sr)
expected, expected_sr = load_wav(reference_path)
found, sr = sox_effects.apply_effects_file(
Path(input_path), effects, normalize=False, channels_first=channels_first)
assert sr == expected_sr
self.assertEqual(found, expected)
@skipIfNoExtension @skipIfNoExtension
class TestFileFormats(TempDirMixin, PytorchTestCase): class TestFileFormats(TempDirMixin, PytorchTestCase):
......
from typing import List, Tuple from typing import List, Tuple, Union
from pathlib import Path
import torch import torch
...@@ -169,7 +170,8 @@ def apply_effects_file( ...@@ -169,7 +170,8 @@ def apply_effects_file(
rate and leave samples untouched. rate and leave samples untouched.
Args: Args:
path (str): Path to the audio file. path (str or pathlib.Path): Path to the audio file. This function also handles ``pathlib.Path`` objects, but is
annotated as ``str`` for TorchScript compiler compatibility.
effects (List[List[str]]): List of effects. effects (List[List[str]]): List of effects.
normalize (bool): normalize (bool):
When ``True``, this function always return ``float32``, and sample values are When ``True``, this function always return ``float32``, and sample values are
...@@ -247,5 +249,7 @@ def apply_effects_file( ...@@ -247,5 +249,7 @@ def apply_effects_file(
>>> for batch in loader: >>> for batch in loader:
>>> pass >>> pass
""" """
# Get string representation of 'path' in case Path object is passed
path = str(path)
signal = torch.ops.torchaudio.sox_effects_apply_effects_file(path, effects, normalize, channels_first) signal = torch.ops.torchaudio.sox_effects_apply_effects_file(path, effects, normalize, channels_first)
return signal.get_tensor(), signal.get_sample_rate() return signal.get_tensor(), signal.get_sample_rate()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment