Unverified Commit 3488f314 authored by SJ's avatar SJ Committed by GitHub
Browse files

Add HTK format support to sox_io's save & info (#1276)

parent a70931f1
...@@ -205,7 +205,7 @@ class TestInfo(TempDirMixin, PytorchTestCase): ...@@ -205,7 +205,7 @@ class TestInfo(TempDirMixin, PytorchTestCase):
assert info.encoding == "ULAW" assert info.encoding == "ULAW"
def test_alaw(self): def test_alaw(self):
"""`sox_io_backend.info` can check ulaw file correctly""" """`sox_io_backend.info` can check alaw file correctly"""
duration = 1 duration = 1
num_channels = 1 num_channels = 1
sample_rate = 8000 sample_rate = 8000
...@@ -221,6 +221,22 @@ class TestInfo(TempDirMixin, PytorchTestCase): ...@@ -221,6 +221,22 @@ class TestInfo(TempDirMixin, PytorchTestCase):
assert info.bits_per_sample == 8 assert info.bits_per_sample == 8
assert info.encoding == "ALAW" assert info.encoding == "ALAW"
def test_htk(self):
    """`sox_io_backend.info` reports the expected metadata for an HTK file"""
    sample_rate, num_channels, duration = 8000, 1, 1
    htk_path = self.get_temp_path('data.htk')
    # Create the fixture on disk first; the file is requested with 16-bit depth.
    sox_utils.gen_audio_file(
        htk_path,
        sample_rate=sample_rate,
        num_channels=num_channels,
        bit_depth=16,
        duration=duration,
    )
    metadata = sox_io_backend.info(htk_path)
    # The reported metadata must mirror the parameters used for generation.
    assert metadata.sample_rate == sample_rate
    assert metadata.num_frames == sample_rate * duration
    assert metadata.num_channels == num_channels
    assert metadata.bits_per_sample == 16
    assert metadata.encoding == "PCM_S"
@skipIfNoExtension @skipIfNoExtension
class TestInfoOpus(PytorchTestCase): class TestInfoOpus(PytorchTestCase):
......
...@@ -237,6 +237,12 @@ class SaveTest(SaveTestBase): ...@@ -237,6 +237,12 @@ class SaveTest(SaveTestBase):
"flac", compression=compression_level, "flac", compression=compression_level,
bits_per_sample=bits_per_sample, test_mode=test_mode) bits_per_sample=bits_per_sample, test_mode=test_mode)
@nested_params(["path", "fileobj", "bytesio"])
def test_save_htk(self, test_mode):
    """Run the save-consistency check for the ``htk`` format with one channel,
    for each ``test_mode`` value listed in ``nested_params``."""
    self.assert_save_consistency(
        "htk",
        num_channels=1,
        test_mode=test_mode,
    )
@nested_params( @nested_params(
["path", "fileobj", "bytesio"], ["path", "fileobj", "bytesio"],
[ [
......
...@@ -195,7 +195,8 @@ def save( ...@@ -195,7 +195,8 @@ def save(
When ``filepath`` argument is file-like object, this argument is required. When ``filepath`` argument is file-like object, this argument is required.
Valid values are ``"wav"``, ``"mp3"``, ``"ogg"``, ``"vorbis"``, ``"amr-nb"``, Valid values are ``"wav"``, ``"mp3"``, ``"ogg"``, ``"vorbis"``, ``"amr-nb"``,
``"amb"``, ``"flac"``, ``"sph"`` and ``"gsm"``. ``"amb"``, ``"flac"``, ``"sph"``, ``"gsm"``, and ``"htk"``.
encoding (str, optional): Changes the encoding for the supported formats. encoding (str, optional): Changes the encoding for the supported formats.
This argument is effective only for supported formats, such as ``"wav"``, ``"amb"`` This argument is effective only for supported formats, such as ``"wav"``, ``"amb"``
and ``"sph"``. Valid values are; and ``"sph"``. Valid values are;
...@@ -294,6 +295,9 @@ def save( ...@@ -294,6 +295,9 @@ def save(
``"gsm"`` ``"gsm"``
Lossy Speech Compression, CPU intensive. Lossy Speech Compression, CPU intensive.
``"htk"``
Uses a default single-channel 16-bit PCM format.
Note: Note:
To save into formats that ``libsox`` does not handle natively, (such as ``"mp3"``, To save into formats that ``libsox`` does not handle natively, (such as ``"mp3"``,
``"flac"``, ``"ogg"`` and ``"vorbis"``), your installation of ``torchaudio`` has ``"flac"``, ``"ogg"`` and ``"vorbis"``), your installation of ``torchaudio`` has
......
...@@ -20,6 +20,8 @@ Format get_format_from_string(const std::string& format) { ...@@ -20,6 +20,8 @@ Format get_format_from_string(const std::string& format) {
return Format::AMB; return Format::AMB;
if (format == "sph") if (format == "sph")
return Format::SPHERE; return Format::SPHERE;
if (format == "htk")
return Format::HTK;
if (format == "gsm") if (format == "gsm")
return Format::GSM; return Format::GSM;
std::ostringstream stream; std::ostringstream stream;
......
...@@ -16,6 +16,7 @@ enum class Format { ...@@ -16,6 +16,7 @@ enum class Format {
AMB, AMB,
SPHERE, SPHERE,
GSM, GSM,
HTK,
}; };
Format get_format_from_string(const std::string& format); Format get_format_from_string(const std::string& format);
......
...@@ -314,6 +314,13 @@ std::tuple<sox_encoding_t, unsigned> get_save_encoding( ...@@ -314,6 +314,13 @@ std::tuple<sox_encoding_t, unsigned> get_save_encoding(
throw std::runtime_error( throw std::runtime_error(
"mp3 does not support `bits_per_sample` option."); "mp3 does not support `bits_per_sample` option.");
return std::make_tuple<>(SOX_ENCODING_MP3, 16); return std::make_tuple<>(SOX_ENCODING_MP3, 16);
case Format::HTK:
if (enc != Encoding::NOT_PROVIDED)
throw std::runtime_error("htk does not support `encoding` option.");
if (bps != BitDepth::NOT_PROVIDED)
throw std::runtime_error(
"htk does not support `bits_per_sample` option.");
return std::make_tuple<>(SOX_ENCODING_SIGN2, 16);
case Format::VORBIS: case Format::VORBIS:
if (enc != Encoding::NOT_PROVIDED) if (enc != Encoding::NOT_PROVIDED)
throw std::runtime_error("vorbis does not support `encoding` option."); throw std::runtime_error("vorbis does not support `encoding` option.");
...@@ -417,8 +424,12 @@ unsigned get_precision(const std::string filetype, caffe2::TypeMeta dtype) { ...@@ -417,8 +424,12 @@ unsigned get_precision(const std::string filetype, caffe2::TypeMeta dtype) {
if (filetype == "amr-nb") { if (filetype == "amr-nb") {
return 16; return 16;
} }
if (filetype == "gsm") if (filetype == "gsm") {
return 16; return 16;
}
if (filetype == "htk") {
return 16;
}
throw std::runtime_error("Unsupported file type: " + filetype); throw std::runtime_error("Unsupported file type: " + filetype);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment