Unverified Commit 88fccd14 authored by moto's avatar moto Committed by GitHub
Browse files

Add TorchScript-able "info" func to sox_io backend (#728)

This is a part of PRs to add new "sox_io" backend #726, and depends on #718.

This PR adds `info` function to "sox_io" backend, which allows users to fetch some metadata of an audio file. 
At this moment, the information retrieved are;

 - Number of samples in the audio file
 - Sampling rate
 - Number of channels
parent f8eac89b
...@@ -20,4 +20,4 @@ printf "Installing PyTorch with %s\n" "${cudatoolkit}" ...@@ -20,4 +20,4 @@ printf "Installing PyTorch with %s\n" "${cudatoolkit}"
conda install -y -c pytorch-nightly pytorch "${cudatoolkit}" conda install -y -c pytorch-nightly pytorch "${cudatoolkit}"
printf "* Installing torchaudio\n" printf "* Installing torchaudio\n"
BUILD_SOX=1 python setup.py develop python setup.py develop
...@@ -34,4 +34,7 @@ printf "* Installing dependencies (except PyTorch)\n" ...@@ -34,4 +34,7 @@ printf "* Installing dependencies (except PyTorch)\n"
conda env update --file "${this_dir}/environment.yml" --prune conda env update --file "${this_dir}/environment.yml" --prune
# 4. Build codecs # 4. Build codecs
build_tools/setup_helpers/build_third_party.sh # build_tools/setup_helpers/build_third_party.sh
# 4. Install codecs
apt update -q
apt install -y -q sox libsox-dev libsox-fmt-all
import os import os
import shutil
import tempfile import tempfile
import unittest import unittest
from typing import Union from typing import Union
...@@ -7,6 +8,7 @@ from shutil import copytree ...@@ -7,6 +8,7 @@ from shutil import copytree
import torch import torch
from torch.testing._internal.common_utils import TestCase as PytorchTestCase from torch.testing._internal.common_utils import TestCase as PytorchTestCase
import torchaudio import torchaudio
from torchaudio._internal.module_utils import is_module_available
_TEST_DIR_PATH = os.path.dirname(os.path.realpath(__file__)) _TEST_DIR_PATH = os.path.dirname(os.path.realpath(__file__))
BACKENDS = torchaudio.list_audio_backends() BACKENDS = torchaudio.list_audio_backends()
...@@ -87,6 +89,33 @@ def set_audio_backend(backend): ...@@ -87,6 +89,33 @@ def set_audio_backend(backend):
torchaudio.set_audio_backend(be) torchaudio.set_audio_backend(be)
class TempDirMixin:
"""Mixin to provide easy access to temp dir"""
temp_dir_ = None
temp_dir = None
def setUp(self):
super().setUp()
self._init_temp_dir()
def tearDown(self):
super().tearDownClass()
self._clean_up_temp_dir()
def _init_temp_dir(self):
self.temp_dir_ = tempfile.TemporaryDirectory()
self.temp_dir = self.temp_dir_.name
def _clean_up_temp_dir(self):
if self.temp_dir_ is not None:
self.temp_dir_.cleanup()
self.temp_dir_ = None
self.temp_dir = None
def get_temp_path(self, *paths):
return os.path.join(self.temp_dir, *paths)
class TestBaseMixin: class TestBaseMixin:
"""Mixin to provide consistent way to define device/dtype/backend aware TestCase""" """Mixin to provide consistent way to define device/dtype/backend aware TestCase"""
dtype = None dtype = None
...@@ -102,8 +131,18 @@ class TorchaudioTestCase(TestBaseMixin, PytorchTestCase): ...@@ -102,8 +131,18 @@ class TorchaudioTestCase(TestBaseMixin, PytorchTestCase):
pass pass
def skipIfNoExec(cmd):
return unittest.skipIf(shutil.which(cmd) is None, f'`{cmd}` is not available')
def skipIfNoModule(module, display_name=None):
display_name = display_name or module
return unittest.skipIf(not is_module_available(module), f'"{display_name}" is not available')
skipIfNoSoxBackend = unittest.skipIf('sox' not in BACKENDS, 'Sox backend not available') skipIfNoSoxBackend = unittest.skipIf('sox' not in BACKENDS, 'Sox backend not available')
skipIfNoCuda = unittest.skipIf(not torch.cuda.is_available(), reason='CUDA not available') skipIfNoCuda = unittest.skipIf(not torch.cuda.is_available(), reason='CUDA not available')
skipIfNoExtension = skipIfNoModule('torchaudio._torchaudio', 'torchaudio C++ extension')
def get_whitenoise( def get_whitenoise(
......
"""Test suites for checking numerical compatibility against Kaldi""" """Test suites for checking numerical compatibility against Kaldi"""
import json import json
import shutil
import unittest
import subprocess import subprocess
import kaldi_io import kaldi_io
...@@ -13,10 +11,6 @@ from . import common_utils ...@@ -13,10 +11,6 @@ from . import common_utils
from parameterized import parameterized, param from parameterized import parameterized, param
def _not_available(cmd):
return shutil.which(cmd) is None
def _convert_args(**kwargs): def _convert_args(**kwargs):
args = [] args = []
for key, value in kwargs.items(): for key, value in kwargs.items():
...@@ -61,7 +55,7 @@ class Kaldi(common_utils.TestBaseMixin): ...@@ -61,7 +55,7 @@ class Kaldi(common_utils.TestBaseMixin):
expected = expected.to(dtype=self.dtype, device=self.device) expected = expected.to(dtype=self.dtype, device=self.device)
self.assertEqual(output, expected, rtol=rtol, atol=atol) self.assertEqual(output, expected, rtol=rtol, atol=atol)
@unittest.skipIf(_not_available('apply-cmvn-sliding'), '`apply-cmvn-sliding` not available') @common_utils.skipIfNoExec('apply-cmvn-sliding')
def test_sliding_window_cmn(self): def test_sliding_window_cmn(self):
"""sliding_window_cmn should be numerically compatible with apply-cmvn-sliding""" """sliding_window_cmn should be numerically compatible with apply-cmvn-sliding"""
kwargs = { kwargs = {
...@@ -78,7 +72,7 @@ class Kaldi(common_utils.TestBaseMixin): ...@@ -78,7 +72,7 @@ class Kaldi(common_utils.TestBaseMixin):
self.assert_equal(result, expected=kaldi_result) self.assert_equal(result, expected=kaldi_result)
@parameterized.expand(_load_params(common_utils.get_asset_path('kaldi_test_fbank_args.json'))) @parameterized.expand(_load_params(common_utils.get_asset_path('kaldi_test_fbank_args.json')))
@unittest.skipIf(_not_available('compute-fbank-feats'), '`compute-fbank-feats` not available') @common_utils.skipIfNoExec('compute-fbank-feats')
def test_fbank(self, kwargs): def test_fbank(self, kwargs):
"""fbank should be numerically compatible with compute-fbank-feats""" """fbank should be numerically compatible with compute-fbank-feats"""
wave_file = common_utils.get_asset_path('kaldi_file.wav') wave_file = common_utils.get_asset_path('kaldi_file.wav')
...@@ -89,7 +83,7 @@ class Kaldi(common_utils.TestBaseMixin): ...@@ -89,7 +83,7 @@ class Kaldi(common_utils.TestBaseMixin):
self.assert_equal(result, expected=kaldi_result, rtol=1e-4, atol=1e-8) self.assert_equal(result, expected=kaldi_result, rtol=1e-4, atol=1e-8)
@parameterized.expand(_load_params(common_utils.get_asset_path('kaldi_test_spectrogram_args.json'))) @parameterized.expand(_load_params(common_utils.get_asset_path('kaldi_test_spectrogram_args.json')))
@unittest.skipIf(_not_available('compute-spectrogram-feats'), '`compute-spectrogram-feats` not available') @common_utils.skipIfNoExec('compute-spectrogram-feats')
def test_spectrogram(self, kwargs): def test_spectrogram(self, kwargs):
"""spectrogram should be numerically compatible with compute-spectrogram-feats""" """spectrogram should be numerically compatible with compute-spectrogram-feats"""
wave_file = common_utils.get_asset_path('kaldi_file.wav') wave_file = common_utils.get_asset_path('kaldi_file.wav')
...@@ -100,7 +94,7 @@ class Kaldi(common_utils.TestBaseMixin): ...@@ -100,7 +94,7 @@ class Kaldi(common_utils.TestBaseMixin):
self.assert_equal(result, expected=kaldi_result, rtol=1e-4, atol=1e-8) self.assert_equal(result, expected=kaldi_result, rtol=1e-4, atol=1e-8)
@parameterized.expand(_load_params(common_utils.get_asset_path('kaldi_test_mfcc_args.json'))) @parameterized.expand(_load_params(common_utils.get_asset_path('kaldi_test_mfcc_args.json')))
@unittest.skipIf(_not_available('compute-mfcc-feats'), '`compute-mfcc-feats` not available') @common_utils.skipIfNoExec('compute-mfcc-feats')
def test_mfcc(self, kwargs): def test_mfcc(self, kwargs):
"""mfcc should be numerically compatible with compute-mfcc-feats""" """mfcc should be numerically compatible with compute-mfcc-feats"""
wave_file = common_utils.get_asset_path('kaldi_file.wav') wave_file = common_utils.get_asset_path('kaldi_file.wav')
......
def get_test_name(func, _, params):
return f'{func.__name__}_{"_".join(str(p) for p in params.args)}'
import subprocess
def get_encoding(dtype):
encodings = {
'float32': 'floating-point',
'int32': 'signed-integer',
'int16': 'signed-integer',
'uint8': 'unsigned-integer',
}
return encodings[dtype]
def get_bit_depth(dtype):
bit_depths = {
'float32': 32,
'int32': 32,
'int16': 16,
'uint8': 8,
}
return bit_depths[dtype]
def gen_audio_file(
path, sample_rate, num_channels,
*, encoding=None, bit_depth=None, compression=None, attenuation=None, duration=1,
):
"""Generate synthetic audio file with `sox` command."""
command = [
'sox',
'-V', # verbose
'--rate', str(sample_rate),
'--null', # no input
'--channels', str(num_channels),
]
if compression is not None:
command += ['--compression', str(compression)]
if bit_depth is not None:
command += ['--bits', str(bit_depth)]
if encoding is not None:
command += ['--encoding', str(encoding)]
command += [
str(path),
'synth', str(duration), # synthesizes for the given duration [sec]
'sawtooth', '1',
# saw tooth covers the both ends of value range, which is a good property for test.
# similar to linspace(-1., 1.)
# this introduces bigger boundary effect than sine when converted to mp3
]
if attenuation is not None:
command += ['vol', f'-{attenuation}dB']
print(' '.join(command))
subprocess.run(command, check=True)
subprocess.run(['soxi', path], check=True)
import itertools
from parameterized import parameterized
from torchaudio.backend import sox_io_backend
from ..common_utils import (
TempDirMixin,
PytorchTestCase,
skipIfNoExec,
skipIfNoExtension,
)
from .common import (
get_test_name
)
from . import sox_utils
@skipIfNoExec('sox')
@skipIfNoExtension
class TestInfo(TempDirMixin, PytorchTestCase):
@parameterized.expand(list(itertools.product(
['float32', 'int32', 'int16', 'uint8'],
[8000, 16000],
[1, 2],
)), name_func=get_test_name)
def test_wav(self, dtype, sample_rate, num_channels):
"""`sox_io_backend.info` can check wav file correctly"""
duration = 1
path = self.get_temp_path(f'{dtype}_{sample_rate}_{num_channels}.wav')
sox_utils.gen_audio_file(
path, sample_rate, num_channels,
bit_depth=sox_utils.get_bit_depth(dtype),
encoding=sox_utils.get_encoding(dtype),
duration=duration,
)
info = sox_io_backend.info(path)
assert info.get_sample_rate() == sample_rate
assert info.get_num_samples() == sample_rate * duration
assert info.get_num_channels() == num_channels
@parameterized.expand(list(itertools.product(
['float32', 'int32', 'int16', 'uint8'],
[8000, 16000],
[4, 8, 16, 32],
)), name_func=get_test_name)
def test_wav_multiple_channels(self, dtype, sample_rate, num_channels):
"""`sox_io_backend.info` can check wav file with channels more than 2 correctly"""
duration = 1
path = self.get_temp_path(f'{dtype}_{sample_rate}_{num_channels}.wav')
sox_utils.gen_audio_file(
path, sample_rate, num_channels,
bit_depth=sox_utils.get_bit_depth(dtype),
encoding=sox_utils.get_encoding(dtype),
duration=duration,
)
info = sox_io_backend.info(path)
assert info.get_sample_rate() == sample_rate
assert info.get_num_samples() == sample_rate * duration
assert info.get_num_channels() == num_channels
@parameterized.expand(list(itertools.product(
[8000, 16000],
[1, 2],
[96, 128, 160, 192, 224, 256, 320],
)), name_func=get_test_name)
def test_mp3(self, sample_rate, num_channels, bit_rate):
"""`sox_io_backend.info` can check mp3 file correctly"""
duration = 1
path = self.get_temp_path(f'{sample_rate}_{num_channels}_{bit_rate}k.mp3')
sox_utils.gen_audio_file(
path, sample_rate, num_channels,
compression=bit_rate, duration=duration,
)
info = sox_io_backend.info(path)
assert info.get_sample_rate() == sample_rate
# mp3 does not preserve the number of samples
# assert info.get_num_samples() == sample_rate * duration
assert info.get_num_channels() == num_channels
@parameterized.expand(list(itertools.product(
[8000, 16000],
[1, 2],
list(range(9)),
)), name_func=get_test_name)
def test_flac(self, sample_rate, num_channels, compression_level):
"""`sox_io_backend.info` can check flac file correctly"""
duration = 1
path = self.get_temp_path(f'{sample_rate}_{num_channels}_{compression_level}.flac')
sox_utils.gen_audio_file(
path, sample_rate, num_channels,
compression=compression_level, duration=duration,
)
info = sox_io_backend.info(path)
assert info.get_sample_rate() == sample_rate
assert info.get_num_samples() == sample_rate * duration
assert info.get_num_channels() == num_channels
@parameterized.expand(list(itertools.product(
[8000, 16000],
[1, 2],
[-1, 0, 1, 2, 3, 3.6, 5, 10],
)), name_func=get_test_name)
def test_vorbis(self, sample_rate, num_channels, quality_level):
"""`sox_io_backend.info` can check vorbis file correctly"""
duration = 1
path = self.get_temp_path(f'{sample_rate}_{num_channels}_{quality_level}.vorbis')
sox_utils.gen_audio_file(
path, sample_rate, num_channels,
compression=quality_level, duration=duration,
)
info = sox_io_backend.info(path)
assert info.get_sample_rate() == sample_rate
assert info.get_num_samples() == sample_rate * duration
assert info.get_num_channels() == num_channels
import itertools
import torch
from torchaudio.backend import sox_io_backend
from parameterized import parameterized
from ..common_utils import (
TempDirMixin,
TorchaudioTestCase,
skipIfNoExec,
skipIfNoExtension,
)
from .common import (
get_test_name,
)
from . import sox_utils
def py_info_func(filepath: str) -> torch.classes.torchaudio.SignalInfo:
return sox_io_backend.info(filepath)
@skipIfNoExec('sox')
@skipIfNoExtension
class SoxIO(TempDirMixin, TorchaudioTestCase):
@parameterized.expand(list(itertools.product(
['float32', 'int32', 'int16', 'uint8'],
[8000, 16000],
[1, 2],
)), name_func=get_test_name)
def test_info_wav(self, dtype, sample_rate, num_channels):
audio_path = self.get_temp_path(f'{dtype}_{sample_rate}_{num_channels}.wav')
sox_utils.gen_audio_file(
audio_path, sample_rate, num_channels,
bit_depth=sox_utils.get_bit_depth(dtype),
encoding=sox_utils.get_encoding(dtype),
)
script_path = self.get_temp_path('info_func')
torch.jit.script(py_info_func).save(script_path)
ts_info_func = torch.jit.load(script_path)
py_info = py_info_func(audio_path)
ts_info = ts_info_func(audio_path)
assert py_info.get_sample_rate() == ts_info.get_sample_rate()
assert py_info.get_num_samples() == ts_info.get_num_samples()
assert py_info.get_num_channels() == ts_info.get_num_channels()
import unittest import unittest
import torchaudio import torchaudio
from torchaudio._internal.module_utils import is_module_available
from . import common_utils from . import common_utils
...@@ -28,15 +27,13 @@ class TestBackendSwitch_NoBackend(BackendSwitchMixin, common_utils.TorchaudioTes ...@@ -28,15 +27,13 @@ class TestBackendSwitch_NoBackend(BackendSwitchMixin, common_utils.TorchaudioTes
backend_module = torchaudio.backend.no_backend backend_module = torchaudio.backend.no_backend
@unittest.skipIf( @common_utils.skipIfNoExtension
not is_module_available('torchaudio._torchaudio'),
'torchaudio C++ extension not available')
class TestBackendSwitch_SoX(BackendSwitchMixin, common_utils.TorchaudioTestCase): class TestBackendSwitch_SoX(BackendSwitchMixin, common_utils.TorchaudioTestCase):
backend = 'sox' backend = 'sox'
backend_module = torchaudio.backend.sox_backend backend_module = torchaudio.backend.sox_backend
@unittest.skipIf(not is_module_available('soundfile'), '"soundfile" not available') @common_utils.skipIfNoModule('soundfile')
class TestBackendSwitch_soundfile(BackendSwitchMixin, common_utils.TorchaudioTestCase): class TestBackendSwitch_soundfile(BackendSwitchMixin, common_utils.TorchaudioTestCase):
backend = 'soundfile' backend = 'soundfile'
backend_module = torchaudio.backend.soundfile_backend backend_module = torchaudio.backend.soundfile_backend
import torch
from torchaudio._internal import (
module_utils as _mod_utils,
)
@_mod_utils.requires_module('torchaudio._torchaudio')
def info(filepath: str) -> torch.classes.torchaudio.SignalInfo:
"""Get signal information of an audio file."""
return torch.ops.torchaudio.sox_io_get_info(filepath)
#ifndef TORCHAUDIO_REGISTER_H #ifndef TORCHAUDIO_REGISTER_H
#define TORCHAUDIO_REGISTER_H #define TORCHAUDIO_REGISTER_H
#include <torchaudio/csrc/sox_io.h>
#include <torchaudio/csrc/typedefs.h> #include <torchaudio/csrc/typedefs.h>
namespace torchaudio { namespace torchaudio {
...@@ -13,6 +14,12 @@ static auto registerSignalInfo = ...@@ -13,6 +14,12 @@ static auto registerSignalInfo =
.def("get_num_channels", &SignalInfo::getNumChannels) .def("get_num_channels", &SignalInfo::getNumChannels)
.def("get_num_samples", &SignalInfo::getNumSamples); .def("get_num_samples", &SignalInfo::getNumSamples);
static auto registerGetInfo = torch::RegisterOperators().op(
torch::RegisterOperators::options()
.schema(
"torchaudio::sox_io_get_info(str path) -> __torch__.torch.classes.torchaudio.SignalInfo info")
.catchAllKernel<decltype(sox_io::get_info), &sox_io::get_info>());
} // namespace } // namespace
} // namespace torchaudio } // namespace torchaudio
#endif #endif
#include <sox.h>
#include <torchaudio/csrc/sox_io.h>
using namespace torch::indexing;
namespace torchaudio {
namespace sox_io {
namespace {
/// Helper struct to safely close the sox_format_t descriptor.
struct SoxDescriptor {
explicit SoxDescriptor(sox_format_t* fd) noexcept : fd_(fd) {}
SoxDescriptor(const SoxDescriptor& other) = delete;
SoxDescriptor(SoxDescriptor&& other) = delete;
SoxDescriptor& operator=(const SoxDescriptor& other) = delete;
SoxDescriptor& operator=(SoxDescriptor&& other) = delete;
~SoxDescriptor() {
if (fd_ != nullptr) {
sox_close(fd_);
}
}
sox_format_t* operator->() noexcept {
return fd_;
}
sox_format_t* get() noexcept {
return fd_;
}
private:
sox_format_t* fd_;
};
} // namespace
c10::intrusive_ptr<::torchaudio::SignalInfo> get_info(
const std::string& file_name) {
SoxDescriptor sd(sox_open_read(
file_name.c_str(),
/*signal=*/nullptr,
/*encoding=*/nullptr,
/*filetype=*/nullptr));
if (sd.get() == nullptr) {
throw std::runtime_error("Error opening audio file");
}
return c10::make_intrusive<::torchaudio::SignalInfo>(
static_cast<int64_t>(sd->signal.rate),
static_cast<int64_t>(sd->signal.channels),
static_cast<int64_t>(sd->signal.length / sd->signal.channels));
}
} // namespace sox_io
} // namespace torchaudio
#include <torch/script.h>
#include <torchaudio/csrc/typedefs.h>
namespace torchaudio {
namespace sox_io {
c10::intrusive_ptr<::torchaudio::SignalInfo> get_info(
const std::string& file_name);
} // namespace sox_io
} // namespace torchaudio
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment