You need to sign in or sign up before continuing.
Commit a422f3fe authored by jamarshon's avatar jamarshon Committed by cpuhrsch
Browse files

Add Kaldi IO as a dependency + put a wrapper to convert to Tensor + add test...

Add Kaldi IO as a dependency + put  a wrapper to convert to Tensor + add test to check correct type (#111)
parent 33dac6ac
...@@ -6,6 +6,7 @@ torchaudio: an audio library for PyTorch ...@@ -6,6 +6,7 @@ torchaudio: an audio library for PyTorch
- mp3, wav, aac, ogg, flac, avr, cdda, cvs/vms, - mp3, wav, aac, ogg, flac, avr, cdda, cvs/vms,
- aiff, au, amr, mp2, mp4, ac3, avi, wmv, - aiff, au, amr, mp2, mp4, ac3, avi, wmv,
- mpeg, ircam and any other format supported by libsox. - mpeg, ircam and any other format supported by libsox.
- [Kaldi (ark/scp)](http://pytorch.org/audio/kaldi_io.html)
- [Dataloaders for common audio datasets (VCTK, YesNo)](http://pytorch.org/audio/datasets.html) - [Dataloaders for common audio datasets (VCTK, YesNo)](http://pytorch.org/audio/datasets.html)
- Common audio transforms - Common audio transforms
- [Scale, PadTrim, DownmixMono, LC2CL, BLC2CBL, MuLawEncoding, MuLawExpanding](http://pytorch.org/audio/transforms.html) - [Scale, PadTrim, DownmixMono, LC2CL, BLC2CBL, MuLawEncoding, MuLawExpanding](http://pytorch.org/audio/transforms.html)
...@@ -13,6 +14,7 @@ torchaudio: an audio library for PyTorch ...@@ -13,6 +14,7 @@ torchaudio: an audio library for PyTorch
Dependencies Dependencies
------------ ------------
* libsox v14.3.2 or above * libsox v14.3.2 or above
* [optional] vesis84/kaldi-io-for-python commit cb46cb1f44318a5d04d4941cf39084c5b021241e or above
Quick install on Quick install on
OSX (Homebrew): OSX (Homebrew):
...@@ -31,7 +33,7 @@ Installation ...@@ -31,7 +33,7 @@ Installation
# Linux # Linux
python setup.py install python setup.py install
# OSX # OSX
MACOSX_DEPLOYMENT_TARGET=10.9 CC=clang CXX=clang++ python setup.py install MACOSX_DEPLOYMENT_TARGET=10.9 CC=clang CXX=clang++ python setup.py install
``` ```
......
...@@ -9,6 +9,7 @@ The :mod:`torchaudio` package consists of I/O, popular datasets and common audio ...@@ -9,6 +9,7 @@ The :mod:`torchaudio` package consists of I/O, popular datasets and common audio
sox_effects sox_effects
datasets datasets
kaldi_io
transforms transforms
legacy legacy
......
torchaudio.kaldi_io
======================
.. currentmodule:: torchaudio.kaldi_io
To use this module, the dependency kaldi_io_ needs to be installed.
This is a light wrapper around ``kaldi_io`` that returns :class:`torch.Tensor` objects.
.. _kaldi_io: https://github.com/vesis84/kaldi-io-for-python
Vectors
~~~~~~~
.. autodata:: read_vec_int_ark
.. autodata:: read_vec_flt_scp
.. autodata:: read_vec_flt_ark
Matrices
~~~~~~~~
.. autodata:: read_mat_scp
.. autodata:: read_mat_ark
import os
import torch
import torchaudio.kaldi_io as kio
import unittest
class KaldiIOTest(unittest.TestCase):
    """Checks that the torchaudio.kaldi_io readers yield correctly typed tensors."""

    # Row ``i`` (0-based) is the expected content stored under key ``key<i + 1>``.
    data1 = [[1, 2, 3], [11, 12, 13], [21, 22, 23]]
    data2 = [[31, 32, 33], [41, 42, 43], [51, 52, 53]]
    test_dirpath = os.path.dirname(os.path.realpath(__file__))

    def _test_helper(self, file_name, expected_data, fn, expected_dtype):
        """Read an asset with ``fn`` and compare it against ``expected_data``.

        Args:
            file_name: Name of a file inside the ``assets`` directory that sits
                next to this test module.
            expected_data: Sequence of nested lists; entry ``i`` is the expected
                value for key ``'key' + str(i + 1)``.
            fn: Reader under test. It is called with the asset path and must
                return an iterable of ``(key, tensor)`` pairs.
            expected_dtype: ``torch.dtype`` every yielded tensor must have.
        """
        test_filepath = os.path.join(self.test_dirpath, "assets", file_name)
        expected_output = {'key' + str(idx + 1): torch.tensor(val, dtype=expected_dtype)
                           for idx, val in enumerate(expected_data)}

        seen_keys = set()
        for key, vec in fn(test_filepath):
            self.assertIn(key, expected_output)
            self.assertIsInstance(vec, torch.Tensor)
            self.assertEqual(vec.dtype, expected_dtype)
            # torch.equal also compares shapes, so a broadcastable but
            # differently shaped result cannot slip through.
            self.assertTrue(torch.equal(vec, expected_output[key]))
            seen_keys.add(key)

        # Without this check a reader that yields nothing (or only a subset of
        # the entries) would make the test pass vacuously.
        self.assertEqual(seen_keys, set(expected_output))

    def test_read_vec_int_ark(self):
        self._test_helper("vec_int.ark", self.data1, kio.read_vec_int_ark, torch.int32)

    def test_read_vec_flt_ark(self):
        self._test_helper("vec_flt.ark", self.data1, kio.read_vec_flt_ark, torch.float32)

    def test_read_mat_ark(self):
        self._test_helper("mat.ark", [self.data1, self.data2], kio.read_mat_ark, torch.float32)


if __name__ == '__main__':
    unittest.main()
# To use this file, the dependency (https://github.com/vesis84/kaldi-io-for-python)
# needs to be installed. This is a light wrapper around kaldi_io that returns
# torch.Tensors.
import numpy as np
import torch
# Names exported by ``from <this module> import *``: the Tensor-returning
# counterparts of the kaldi_io readers defined further down in this file.
__all__ = [
    'read_vec_int_ark',
    'read_vec_flt_scp',
    'read_vec_flt_ark',
    'read_mat_scp',
    'read_mat_ark',
]
def _default_not_imported_method():
raise ImportError('Could not import kaldi_io. Did you install it?')
def _wrap_method(fn, convert_contiguous=False):
# type: (Function, bool) -> Function
""" Takes a method with the signature (file name/descriptor) -> generator(string, ndarray)
and converts it to (file name/descriptor) -> generator(string, Tensor).
convert_contiguous determines whether the array should be converted into a
contiguous layout.
"""
def _wrapped_fn(file_or_fd):
for key, np_arr in fn(file_or_fd):
if convert_contiguous:
np_arr = np.ascontiguousarray(np_arr)
yield key, torch.from_numpy(np_arr)
return _wrapped_fn
#: Create a generator of (key, vector<int>) tuples, read from an ark file/stream.
#: Each vector is returned as a :class:`torch.Tensor`.
#:
#: file_or_fd : ark, gzipped ark, pipe or opened file descriptor.
#:
#: Requires the optional ``kaldi_io`` package to be installed.
#:
#: Example, read ark to a 'dictionary':
#:
#: >>> # generator(key,vec) = torchaudio.kaldi_io.read_vec_int_ark(file_or_fd)
#: >>> d = {key: vec for key, vec in torchaudio.kaldi_io.read_vec_int_ark(file)}
read_vec_int_ark = _default_not_imported_method
#: Create a generator of (key, vector<float32/float64>) tuples, read according to a kaldi scp.
#: Each vector is returned as a :class:`torch.Tensor`.
#:
#: file_or_fd : scp, gzipped scp, pipe or opened file descriptor.
#:
#: Requires the optional ``kaldi_io`` package to be installed.
#:
#: Example, read scp to a 'dictionary':
#:
#: >>> # generator(key,vec) = torchaudio.kaldi_io.read_vec_flt_scp(file_or_fd)
#: >>> d = {key: vec for key, vec in torchaudio.kaldi_io.read_vec_flt_scp(file)}
read_vec_flt_scp = _default_not_imported_method
#: Create a generator of (key, vector<float32/float64>) tuples, read from an ark file/stream.
#: Each vector is returned as a :class:`torch.Tensor`.
#:
#: file_or_fd : ark, gzipped ark, pipe or opened file descriptor.
#:
#: Requires the optional ``kaldi_io`` package to be installed.
#:
#: Example, read ark to a 'dictionary':
#:
#: >>> # generator(key,vec) = torchaudio.kaldi_io.read_vec_flt_ark(file_or_fd)
#: >>> d = {key: vec for key, vec in torchaudio.kaldi_io.read_vec_flt_ark(file)}
read_vec_flt_ark = _default_not_imported_method
#: Create a generator of (key, matrix<float32/float64>) tuples, read according to a kaldi scp.
#: Each matrix is returned as a :class:`torch.Tensor`.
#:
#: file_or_fd : scp, gzipped scp, pipe or opened file descriptor.
#:
#: Requires the optional ``kaldi_io`` package to be installed.
#:
#: Example, read scp to a 'dictionary':
#:
#: >>> # generator(key,mat) = torchaudio.kaldi_io.read_mat_scp(file_or_fd)
#: >>> d = {key: mat for key, mat in torchaudio.kaldi_io.read_mat_scp(file)}
read_mat_scp = _default_not_imported_method
#: Create a generator of (key, matrix<float32/float64>) tuples, read from an ark file/stream.
#: Each matrix is returned as a :class:`torch.Tensor`.
#:
#: file_or_fd : ark, gzipped ark, pipe or opened file descriptor.
#:
#: Requires the optional ``kaldi_io`` package to be installed.
#:
#: Example, read ark to a 'dictionary':
#:
#: >>> # generator(key,mat) = torchaudio.kaldi_io.read_mat_ark(file_or_fd)
#: >>> d = {key: mat for key, mat in torchaudio.kaldi_io.read_mat_ark(file)}
read_mat_ark = _default_not_imported_method
try:
    import kaldi_io
except ImportError:
    # kaldi_io is an optional dependency: leave the placeholder readers
    # assigned earlier in this module untouched.
    pass
else:
    # kaldi_io is available: overwrite the placeholders with wrappers that
    # convert every yielded ndarray into a torch.Tensor.
    #
    # Elements from an int32 vector are stored in tuples: (sizeof(int32), value)
    # so strides are (5,) instead of (4,), which will throw an error in
    # from_numpy as it expects strides to be a multiple of 4 (int32). Hence the
    # int reader first converts each array to a contiguous layout.
    read_vec_int_ark = _wrap_method(kaldi_io.read_vec_int_ark, convert_contiguous=True)
    read_vec_flt_scp = _wrap_method(kaldi_io.read_vec_flt_scp)
    read_vec_flt_ark = _wrap_method(kaldi_io.read_vec_flt_ark)
    read_mat_scp = _wrap_method(kaldi_io.read_mat_scp)
    read_mat_ark = _wrap_method(kaldi_io.read_mat_ark)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment