Unverified Commit ec1b3e36 authored by moto's avatar moto Committed by GitHub
Browse files

Support file-like object in apply_effects_file (#1166)

parent 5085aeb9
......@@ -43,7 +43,7 @@ class SmokeTest(TempDirMixin, TorchaudioTestCase):
load_params("sox_effect_test_args.json"),
name_func=lambda f, i, p: f'{f.__name__}_{i}_{p.args[0]["effects"][0][0]}',
)
def test_apply_effects(self, args):
def test_apply_effects_file(self, args):
"""`apply_effects_file` should return identical data as sox command"""
dtype = 'int32'
channels_first = True
......@@ -57,3 +57,23 @@ class SmokeTest(TempDirMixin, TorchaudioTestCase):
_found, _sr = sox_effects.apply_effects_file(
input_path, effects, normalize=False, channels_first=channels_first)
@parameterized.expand(
    load_params("sox_effect_test_args.json"),
    name_func=lambda f, i, p: f'{f.__name__}_{i}_{p.args[0]["effects"][0][0]}',
)
def test_apply_effects_fileobj(self, args):
    """`apply_effects_file` should run when given a file object opened in binary mode.

    A WAV file is written to a temporary path, reopened with ``open(..., 'rb')``
    and handed to ``apply_effects_file`` in place of a path. The returned
    values are not inspected; the test only exercises the call.

    Args:
        args (dict): One parameter set loaded from ``sox_effect_test_args.json``.
            ``effects`` is required; ``num_channels`` (default 2) and
            ``input_sample_rate`` (default 8000) are optional.
    """
    channels_first = True
    num_channels = args.get("num_channels", 2)
    sample_rate = args.get("input_sample_rate", 8000)
    effects = args['effects']

    # Write the input audio to disk first so that it can be reopened as a file object.
    wav_path = self.get_temp_path('input.wav')
    waveform = get_wav_data('int32', num_channels, channels_first=channels_first)
    save_wav(wav_path, waveform, sample_rate, channels_first=channels_first)

    with open(wav_path, 'rb') as stream:
        _found, _sr = sox_effects.apply_effects_file(
            stream, effects, normalize=False, channels_first=channels_first)
import io
import itertools
from pathlib import Path
import tarfile
from torchaudio import sox_effects
from parameterized import parameterized
from torchaudio import sox_effects
from torchaudio._internal import module_utils as _mod_utils
from torchaudio_unittest.common_utils import (
TempDirMixin,
HttpServerMixin,
PytorchTestCase,
skipIfNoExtension,
skipIfNoModule,
skipIfNoExec,
get_asset_path,
get_sinusoid,
get_wav_data,
......@@ -21,6 +27,10 @@ from .common import (
)
if _mod_utils.is_module_available("requests"):
import requests
@skipIfNoExtension
class TestSoxEffects(PytorchTestCase):
def test_init(self):
......@@ -262,3 +272,152 @@ class TestApplyEffectFileWithoutExtension(PytorchTestCase):
path = get_asset_path("mp3_without_ext")
_, sr = sox_effects.apply_effects_file(path, effects, format="mp3")
assert sr == 16000
@skipIfNoExec('sox')
@skipIfNoExtension
class TestFileObject(TempDirMixin, PytorchTestCase):
    """Check ``apply_effects_file`` with different kinds of file-like objects.

    Every test follows the same recipe: generate an input audio file, build a
    reference result with ``sox_utils.run_sox_effect``, apply the same effects
    through ``apply_effects_file`` fed with a file-like object, and compare.
    The shared steps live in the private helpers below; each test only
    differs in how the file-like object is obtained.
    """

    # (extension, compression) pairs every test in this class is expanded over.
    _FORMATS = [
        ('wav', None),
        ('mp3', 128),
        ('mp3', 320),
        ('flac', 0),
        ('flac', 5),
        ('flac', 8),
        ('vorbis', -1),
        ('vorbis', 10),
        ('amb', None),
    ]
    _SAMPLE_RATE = 16000
    _CHANNELS_FIRST = True
    _EFFECTS = [['band', '300', '10']]

    def _prepare(self, ext, compression):
        """Generate the input file and the reference result for one test case.

        Args:
            ext (str): Extension of the input audio file.
            compression: Compression value forwarded to ``sox_utils.gen_audio_file``.

        Returns:
            tuple: ``(input_path, expected, expected_sr)`` where ``expected`` and
            ``expected_sr`` are what ``load_wav`` returns for the reference file.
        """
        input_path = self.get_temp_path(f'input.{ext}')
        reference_path = self.get_temp_path('reference.wav')
        sox_utils.gen_audio_file(
            input_path, self._SAMPLE_RATE, num_channels=2, compression=compression)
        sox_utils.run_sox_effect(
            input_path, reference_path, self._EFFECTS, output_bitdepth=32)
        expected, expected_sr = load_wav(reference_path)
        return input_path, expected, expected_sr

    def _apply(self, fileobj, ext):
        """Run ``apply_effects_file`` on ``fileobj`` and return its result.

        ``format`` is passed explicitly only for mp3; every other extension
        is left as ``None``.
        NOTE(review): presumably mp3 cannot be detected from the stream
        content alone - confirm against the extension's implementation.
        """
        format_ = ext if ext == 'mp3' else None
        return sox_effects.apply_effects_file(
            fileobj, self._EFFECTS, channels_first=self._CHANNELS_FIRST, format=format_)

    def _verify(self, found, sr, expected, expected_sr):
        """Save the obtained result and compare it with the reference."""
        # The result is written to the temp directory before the comparison.
        save_wav(
            self.get_temp_path('result.wav'), found, sr,
            channels_first=self._CHANNELS_FIRST)
        # Use the TestCase assertion so the check is not stripped under ``python -O``.
        self.assertEqual(sr, expected_sr)
        self.assertEqual(found, expected)

    @parameterized.expand(_FORMATS)
    def test_fileobj(self, ext, compression):
        """Applying effects via file object works"""
        input_path, expected, expected_sr = self._prepare(ext, compression)

        with open(input_path, 'rb') as fileobj:
            found, sr = self._apply(fileobj, ext)

        self._verify(found, sr, expected, expected_sr)

    @parameterized.expand(_FORMATS)
    def test_bytesio(self, ext, compression):
        """Applying effects via BytesIO object works"""
        input_path, expected, expected_sr = self._prepare(ext, compression)

        with open(input_path, 'rb') as file_:
            fileobj = io.BytesIO(file_.read())
        found, sr = self._apply(fileobj, ext)

        self._verify(found, sr, expected, expected_sr)

    @parameterized.expand(_FORMATS)
    def test_tarfile(self, ext, compression):
        """Applying effects via file-like object extracted from a tar archive works"""
        audio_file = f'input.{ext}'
        input_path, expected, expected_sr = self._prepare(ext, compression)
        archive_path = self.get_temp_path('archive.tar.gz')

        with tarfile.TarFile(archive_path, 'w') as tarobj:
            tarobj.add(input_path, arcname=audio_file)
        # The extracted member is only readable while the archive is open,
        # so the effects are applied inside the ``with`` block.
        with tarfile.TarFile(archive_path, 'r') as tarobj:
            fileobj = tarobj.extractfile(audio_file)
            found, sr = self._apply(fileobj, ext)

        self._verify(found, sr, expected, expected_sr)
@skipIfNoExtension
@skipIfNoExec('sox')
@skipIfNoModule("requests")
class TestFileObjectHttp(HttpServerMixin, PytorchTestCase):
    """Check ``apply_effects_file`` with the raw stream of an HTTP response."""

    @parameterized.expand([
        ('wav', None),
        ('mp3', 128),
        ('mp3', 320),
        ('flac', 0),
        ('flac', 5),
        ('flac', 8),
        ('vorbis', -1),
        ('vorbis', 10),
        ('amb', None),
    ])
    def test_requests(self, ext, compression):
        """Applying effects via the raw body of a streamed `requests` response works"""
        channels_first = True
        sample_rate = 16000
        effects = [['band', '300', '10']]
        # ``format`` is given explicitly only for mp3 input.
        format_ = 'mp3' if ext == 'mp3' else None

        # Generate the input audio and the reference result.
        audio_file = f'input.{ext}'
        src_path = self.get_temp_path(audio_file)
        ref_path = self.get_temp_path('reference.wav')
        sox_utils.gen_audio_file(
            src_path, sample_rate, num_channels=2, compression=compression)
        sox_utils.run_sox_effect(
            src_path, ref_path, effects, output_bitdepth=32)
        expected, expected_sr = load_wav(ref_path)

        # Fetch the same file through the URL from ``get_url`` and feed the
        # undecoded response stream to ``apply_effects_file``.
        with requests.get(self.get_url(audio_file), stream=True) as resp:
            found, sr = sox_effects.apply_effects_file(
                resp.raw, effects, channels_first=channels_first, format=format_)

        save_wav(self.get_temp_path('result.wav'), found, sr, channels_first=channels_first)
        assert sr == expected_sr
        self.assertEqual(found, expected)
#include <torch/extension.h>
#include <torchaudio/csrc/sox/effects.h>
#include <torchaudio/csrc/sox/io.h>
#include <torchaudio/csrc/sox/legacy.h>
......@@ -112,4 +113,8 @@ PYBIND11_MODULE(_torchaudio, m) {
"save_audio_fileobj",
&torchaudio::sox_io::save_audio_fileobj,
"Save audio to file obj.");
m.def(
"apply_effects_fileobj",
&torchaudio::sox_effects::apply_effects_fileobj,
"Decode audio data from file-like obj and apply effects.");
}
import os
from pathlib import Path
from typing import List, Tuple, Optional
import torch
import torchaudio
from torchaudio._internal import module_utils as _mod_utils
from torchaudio.utils.sox_utils import list_effects
......@@ -170,8 +173,16 @@ def apply_effects_file(
rate and leave samples untouched.
Args:
path (str or pathlib.Path): Path to the audio file. This function also handles ``pathlib.Path`` objects, but is
annotated as ``str`` for TorchScript compiler compatibility.
path (path-like object or file-like object):
Source of audio data. When the function is not compiled by TorchScript
(e.g. ``torch.jit.script``), the following types are accepted:
* ``path-like``: file path
* ``file-like``: Object with ``read(size: int) -> bytes`` method,
which returns byte string of at most ``size`` length.
When the function is compiled by TorchScript, only ``str`` type is allowed.
Note:
* This argument is intentionally annotated as ``str`` only for
TorchScript compiler compatibility.
effects (List[List[str]]): List of effects.
normalize (bool):
When ``True``, this function always returns ``float32``, and sample values are
......@@ -252,8 +263,11 @@ def apply_effects_file(
>>> for batch in loader:
>>> pass
"""
# Get string representation of 'path' in case Path object is passed
path = str(path)
if not torch.jit.is_scripting():
if hasattr(path, 'read'):
return torchaudio._torchaudio.apply_effects_fileobj(
path, effects, normalize, channels_first, format)
path = os.fspath(path)
signal = torch.ops.torchaudio.sox_effects_apply_effects_file(
path, effects, normalize, channels_first, format)
return signal.get_tensor(), signal.get_sample_rate()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment