Commit 740c5a86 authored by David Pollack
Browse files

update save/load

.gitignore _ext/
parent ecb538df
...@@ -3,8 +3,9 @@ __pycache__/ ...@@ -3,8 +3,9 @@ __pycache__/
*.py[cod] *.py[cod]
*$py.class *$py.class
# C extensions # C extensions / folders
*.so *.so
_ext/
# Distribution / packaging # Distribution / packaging
.Python .Python
......
import unittest
import torch
import torch.nn as nn
import torchaudio import torchaudio
import os
class Test_LoadSave(unittest.TestCase):
    """Exercise ``torchaudio.load`` and ``torchaudio.save`` on an MP3 fixture.

    The fixture file is expected to live in the same directory as this
    test module.
    """

    test_dirpath = os.path.dirname(os.path.realpath(__file__))
    test_filepath = os.path.join(test_dirpath,
                                 "steam-train-whistle-daniel_simon.mp3")

    @staticmethod
    def _remove_if_present(filepath):
        """Delete *filepath* if it exists; used as a cleanup hook."""
        if os.path.isfile(filepath):
            os.unlink(filepath)

    def _assert_saved_and_remove(self, filepath):
        """Assert that *filepath* was written, then delete it.

        Removing the file between steps ensures the next ``isfile`` check
        observes the output of the next ``save`` call, not a stale file.
        """
        self.assertTrue(os.path.isfile(filepath))
        os.unlink(filepath)

    def test_load(self):
        """Loading returns the expected shape/rate and rejects bad paths."""
        # check normal loading
        x, sr = torchaudio.load(self.test_filepath)
        self.assertEqual(sr, 44100)
        self.assertEqual(x.size(), (278756, 2))

        # check normalizing
        x, sr = torchaudio.load(self.test_filepath, normalization=True)
        self.assertTrue(x.min() >= -1.0)
        self.assertTrue(x.max() <= 1.0)

        # check raising errors
        with self.assertRaises(OSError):
            torchaudio.load("file-does-not-exist.mp3")

        # Build the path outside the ``with`` block so only ``load`` itself
        # can satisfy the assertion.
        tdir = os.path.join(os.path.dirname(self.test_dirpath), "torchaudio")
        with self.assertRaises(OSError):
            torchaudio.load(tdir)

    def test_save(self):
        """Saving accepts valid signals and rejects malformed arguments."""
        # load signal
        x, sr = torchaudio.load(self.test_filepath)

        new_filepath = os.path.join(self.test_dirpath, "test.wav")
        # Make sure a failing assertion does not leave the file behind.
        self.addCleanup(self._remove_if_present, new_filepath)

        # check save
        torchaudio.save(new_filepath, x, sr)
        self._assert_saved_and_remove(new_filepath)

        # check automatic normalization
        x /= 1 << 31
        torchaudio.save(new_filepath, x, sr)
        self._assert_saved_and_remove(new_filepath)

        # test save 1d tensor
        mono = x[:, 0].clone()  # first channel as a 1d (L,) signal
        # Pass a copy: ``save`` may reshape its argument, and the shapes
        # built below must not depend on that.
        torchaudio.save(new_filepath, mono.clone(), sr)
        self._assert_saved_and_remove(new_filepath)

        # don't allow invalid sizes as inputs
        channels_first = mono.unsqueeze(0)  # 1 x L (N x L, not L x N)
        with self.assertRaises(ValueError):
            torchaudio.save(new_filepath, channels_first, sr)

        three_dims = mono.unsqueeze(1).unsqueeze(0)  # 1 x L x 1
        with self.assertRaises(ValueError):
            torchaudio.save(new_filepath, three_dims, sr)

        # automatically convert sr from floating point to int
        column = mono.unsqueeze(1)  # L x 1
        torchaudio.save(new_filepath, column, float(sr))
        self._assert_saved_and_remove(new_filepath)

        # don't allow uneven integers
        # (nothing may follow the call inside the block: it would never run)
        with self.assertRaises(TypeError):
            torchaudio.save(new_filepath, column, float(sr) + 0.5)

        # don't save to folders that don't exist
        missing_dir_filepath = os.path.join(self.test_dirpath, "no-path",
                                            "test.wav")
        with self.assertRaises(OSError):
            torchaudio.save(missing_dir_filepath, column, sr)
# Run the test suite when this module is executed as a script.
if __name__ == '__main__':
    unittest.main()
...@@ -14,27 +14,52 @@ def check_input(src): ...@@ -14,27 +14,52 @@ def check_input(src):
if not src.__module__ == 'torch': if not src.__module__ == 'torch':
raise TypeError('Expected a CPU based tensor, got %s' % type(src)) raise TypeError('Expected a CPU based tensor, got %s' % type(src))
def _normalization_divisor(normalization):
    """Translate the ``normalization`` argument of ``load`` into a divisor.

    Returns ``None`` when no scaling is requested (``None``, ``False`` or an
    unsupported type), ``1 << 31`` for ``True`` and the value itself for any
    other int/float.

    Raises:
        ValueError: if a numeric divisor of zero is given.
    """
    # bool has to be tested first: it is a subclass of int, so ``False``
    # would otherwise be taken as the custom divisor 0.
    if isinstance(normalization, bool):
        return (1 << 31) if normalization else None
    if isinstance(normalization, (float, int)):
        if normalization == 0:
            raise ValueError("normalization must be non-zero")
        return normalization
    return None


def load(filepath, out=None, normalization=None):
    """Load an audio file into a tensor.

    Args:
        filepath: path of the audio file to read.
        out: optional tensor that receives the samples; it is validated with
            ``check_input``. A new ``torch.FloatTensor`` is used when omitted.
        normalization: ``True`` divides the signal by ``1 << 31``; a non-zero
            int/float divides the signal by that value; ``None`` or ``False``
            leaves the samples untouched.

    Returns:
        A tuple ``(out, sample_rate)``.

    Raises:
        OSError: if ``filepath`` does not exist or is not a regular file.
        ValueError: if ``normalization`` is a numeric zero.
    """
    # check if valid file
    if not os.path.isfile(filepath):
        raise OSError("{} not found or is a directory".format(filepath))

    # Validate the scaling request before doing any decoding work.
    divisor = _normalization_divisor(normalization)

    # initialize output tensor
    if out is not None:
        check_input(out)
    else:
        out = torch.FloatTensor()

    # load audio signal; the C entry point is selected by tensor type name
    typename = type(out).__name__.replace('Tensor', '')
    func = getattr(th_sox, 'libthsox_{}_read_audio_file'.format(typename))
    sample_rate_p = ffi.new('int*')
    func(str(filepath).encode("utf-8"), out, sample_rate_p)
    sample_rate = sample_rate_p[0]

    # normalize if needed
    # NOTE(review): 1 << 31 is the full scale of a signed 32-bit sample, not
    # a 16-bit one; presumably th_sox returns SoX's 32-bit samples -- confirm.
    if divisor is not None:
        out /= divisor
    return out, sample_rate
def save(filepath, src, sample_rate):
    """Write an audio signal to ``filepath``.

    The output format is taken from the file extension of ``filepath``.

    Args:
        filepath: destination path; its directory must already exist.
        src: tensor of shape (L,) or (L x N) with N = 1 or 2. A 1d tensor is
            treated as a mono signal. The caller's tensor is not modified.
        sample_rate: an int, or a number with an integral value (for example
            ``44100.0``), which is converted to int.

    Raises:
        OSError: if the target directory does not exist.
        TypeError: if ``sample_rate`` is not an integral number.
        ValueError: if ``src`` does not have an accepted shape.

    Note:
        A signal whose values all lie in [-1.0, 1.0] is assumed to be
        normalized and is rescaled by ``1 << 31`` before writing. This is a
        heuristic: an un-normalized signal that happens to stay within that
        range is rescaled as well.
    """
    # check if save directory exists
    abs_dirpath = os.path.dirname(os.path.abspath(filepath))
    if not os.path.isdir(abs_dirpath):
        raise OSError("Directory does not exist: {}".format(abs_dirpath))

    # check if sample_rate is an integer
    if not isinstance(sample_rate, int):
        try:
            integral_rate = int(sample_rate)
        except (TypeError, ValueError, OverflowError):
            # non-numeric, NaN or infinite values end up here
            raise TypeError('Sample rate should be a integer')
        if integral_rate != sample_rate:
            raise TypeError('Sample rate should be a integer')
        sample_rate = integral_rate

    # Check/Fix shape of source data
    ndim = len(src.size())
    if ndim == 1:
        # 1d tensors are assumed to be mono signals; the out-of-place
        # unsqueeze leaves the caller's tensor untouched
        src = src.unsqueeze(1)
    elif ndim != 2 or not 1 <= src.size(1) <= 2:
        raise ValueError("Expected format (L x N), N = 1 or 2, but found {}".format(src.size()))

    # programs such as librosa normalize the signal, unnormalize if detected
    # NOTE(review): 1 << 31 is the full scale of a signed 32-bit sample, not
    # a 16-bit one -- confirm against what th_sox expects.
    if src.min() >= -1.0 and src.max() <= 1.0:
        src = src * (1 << 31)
        src = src.long()

    # save data to file
    extension = os.path.splitext(filepath)[1]
    check_input(src)
    typename = type(src).__name__.replace('Tensor', '')
    func = getattr(th_sox, 'libthsox_{}_write_audio_file'.format(typename))
    # encode the same way load() does
    func(str(filepath).encode("utf-8"), src,
         extension[1:].encode("utf-8"), sample_rate)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment