Commit e7eb0be2 authored by Soumith Chintala

fix lint

parent 6e8045ea
@@ -4,14 +4,17 @@ import torchaudio
 import math
 import os


 class Test_LoadSave(unittest.TestCase):
     test_dirpath = os.path.dirname(os.path.realpath(__file__))
-    test_filepath = os.path.join(test_dirpath, "assets", "steam-train-whistle-daniel_simon.mp3")
+    test_filepath = os.path.join(
+        test_dirpath, "assets", "steam-train-whistle-daniel_simon.mp3")

     def test_load(self):
         # check normal loading
         x, sr = torchaudio.load(self.test_filepath)
         self.assertEqual(sr, 44100)
-        self.assertEqual(x.size(), (278756,2))
+        self.assertEqual(x.size(), (278756, 2))
         # check normalizing
         x, sr = torchaudio.load(self.test_filepath, normalization=True)
@@ -23,7 +26,8 @@ class Test_LoadSave(unittest.TestCase):
             torchaudio.load("file-does-not-exist.mp3")
         with self.assertRaises(OSError):
-            tdir = os.path.join(os.path.dirname(self.test_dirpath), "torchaudio")
+            tdir = os.path.join(os.path.dirname(
+                self.test_dirpath), "torchaudio")
             torchaudio.load(tdir)

     def test_save(self):
@@ -43,21 +47,21 @@ class Test_LoadSave(unittest.TestCase):
         os.unlink(new_filepath)

         # test save 1d tensor
-        x = x[:, 0] # get mono signal
-        x.squeeze_() # remove channel dim
+        x = x[:, 0]  # get mono signal
+        x.squeeze_()  # remove channel dim
         torchaudio.save(new_filepath, x, sr)
         self.assertTrue(os.path.isfile(new_filepath))
         os.unlink(new_filepath)

         # don't allow invalid sizes as inputs
         with self.assertRaises(ValueError):
-            x.unsqueeze_(0) # N x L not L x N
+            x.unsqueeze_(0)  # N x L not L x N
             torchaudio.save(new_filepath, x, sr)
         with self.assertRaises(ValueError):
             x.squeeze_()
             x.unsqueeze_(1)
-            x.unsqueeze_(0) # 1 x L x 1
+            x.unsqueeze_(0)  # 1 x L x 1
             torchaudio.save(new_filepath, x, sr)

         # automatically convert sr from floating point to int
@@ -74,21 +78,24 @@ class Test_LoadSave(unittest.TestCase):
         # don't save to folders that don't exist
         with self.assertRaises(OSError):
-            new_filepath = os.path.join(self.test_dirpath, "no-path", "test.wav")
+            new_filepath = os.path.join(
+                self.test_dirpath, "no-path", "test.wav")
             torchaudio.save(new_filepath, x, sr)

         # save created file
-        sinewave_filepath = os.path.join(self.test_dirpath, "assets", "sinewave.wav")
+        sinewave_filepath = os.path.join(
+            self.test_dirpath, "assets", "sinewave.wav")
         sr = 16000
         freq = 440
         volume = 0.3
-        y = (torch.cos(2*math.pi*torch.arange(0, 4*sr) * freq/sr)).float()
+        y = (torch.cos(2 * math.pi * torch.arange(0, 4 * sr) * freq / sr)).float()
         y.unsqueeze_(1)
         # y is between -1 and 1, so must scale
-        y = (y*volume*2**31).long()
+        y = (y * volume * 2**31).long()
         torchaudio.save(sinewave_filepath, y, sr)
         self.assertTrue(os.path.isfile(sinewave_filepath))


 if __name__ == '__main__':
     unittest.main()
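For reference, the sine-wave fixture above is a 440 Hz cosine sampled at 16 kHz for four seconds, scaled from `[-1, 1]` into the integer range that `torchaudio.save` expects (the `2**31` factor mirrors the `1 << 31` constant that `torchaudio.load` divides by). A minimal standalone sketch of the same construction:

>>> import math
>>> import torch
>>> sr, freq, volume = 16000, 440, 0.3
>>> y = torch.cos(2 * math.pi * torch.arange(0, 4 * sr) * freq / sr).float()
>>> y.unsqueeze_(1)                  # shape 64000 x 1, i.e. four seconds of mono
>>> y = (y * volume * 2**31).long()  # scale [-1, 1] floats to the integer range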
@@ -5,14 +5,15 @@ import torchaudio.transforms as transforms
 import numpy as np
 import unittest


 class Tester(unittest.TestCase):
     sr = 16000
     freq = 440
     volume = 0.3
-    sig = (torch.cos(2*np.pi*torch.arange(0, 4*sr) * freq/sr)).float()
+    sig = (torch.cos(2 * np.pi * torch.arange(0, 4 * sr) * freq / sr)).float()
     sig.unsqueeze_(1)
-    sig = (sig*volume*2**31).long()
+    sig = (sig * volume * 2**31).long()

     def test_scale(self):
@@ -21,7 +22,8 @@ class Tester(unittest.TestCase):
         self.assertTrue(result.min() >= -1. and result.max() <= 1.,
                         print("min: {}, max: {}".format(result.min(), result.max())))

-        maxminmax = np.abs([audio_orig.min(), audio_orig.max()]).max().astype(np.float)
+        maxminmax = np.abs(
+            [audio_orig.min(), audio_orig.max()]).max().astype(np.float)
         result = transforms.Scale(factor=maxminmax)(audio_orig)
         self.assertTrue((result.min() == -1. or result.max() == 1.) and
                         result.min() >= -1. and result.max() <= 1.,
@@ -47,7 +49,6 @@ class Tester(unittest.TestCase):
         self.assertTrue(result.size(0) == length_new,
                         print("old size: {}, new size: {}".format(audio_orig.size(0), result.size(0))))

     def test_downmix_mono(self):
         audio_L = self.sig.clone()
@@ -84,7 +85,8 @@ class Tester(unittest.TestCase):
         audio_orig = self.sig.clone()
         length_orig = audio_orig.size(0)
         length_new = int(length_orig * 1.2)
-        maxminmax = np.abs([audio_orig.min(), audio_orig.max()]).max().astype(np.float)
+        maxminmax = np.abs(
+            [audio_orig.min(), audio_orig.max()]).max().astype(np.float)

         tset = (transforms.Scale(factor=maxminmax),
                 transforms.PadTrim(max_len=length_new))
@@ -109,10 +111,6 @@ class Tester(unittest.TestCase):
         sig_exp = transforms.MuLawExpanding(quantization_channels)(sig_mu)
         self.assertTrue(sig_exp.min() >= -1. and sig_exp.max() <= 1.)
-        #diff = sig - sig_exp
-        #mse = np.linalg.norm(diff) / diff.shape[0]
-        #self.assertTrue(mse, np.isclose(mse, 0., atol=1e-4)) # not always true

         sig = self.sig.clone()
         sig = sig / torch.abs(sig).max()
         self.assertTrue(sig.min() >= -1. and sig.max() <= 1.)
@@ -123,5 +121,6 @@ class Tester(unittest.TestCase):
         sig_exp = transforms.MuLawExpanding(quantization_channels)(sig_mu)
         self.assertTrue(sig_exp.min() >= -1. and sig_exp.max() <= 1.)


 if __name__ == '__main__':
     unittest.main()
@@ -14,7 +14,9 @@ from torchaudio import datasets
 if sys.version_info >= (3, 0):
     _bytes = bytes
 else:
-    _bytes = lambda s, e: s.encode(e)
+    def _bytes(s, e):
+        return s.encode(e)


 def check_input(src):
     if not torch.is_tensor(src):
@@ -22,13 +24,16 @@ def check_input(src):
     if not src.__module__ == 'torch':
         raise TypeError('Expected a CPU based tensor, got %s' % type(src))


 def load(filepath, out=None, normalization=None):
     """Loads an audio file from disk into a Tensor

     Args:
         filepath (string): path to audio file
         out (Tensor, optional): an output Tensor to use instead of creating one
-        normalization (bool or number, optional): If boolean `True`, then output is divided by `1 << 31` (assumes 16-bit depth audio, and normalizes to `[0, 1]`. If `number`, then output is divided by that number
+        normalization (bool or number, optional): If boolean `True`, then output is divided by `1 << 31`
+            (assumes 16-bit depth audio, and normalizes to `[0, 1]`).
+            If `number`, then output is divided by that number

     Returns: tuple(Tensor, int)
       - Tensor: output Tensor of size `[L x C]` where L is the number of audio frames, C is the number of channels
@@ -41,7 +46,7 @@ def load(filepath, out=None, normalization=None):
         torch.Size([278756, 2])
         >>> print(sample_rate)
         44100

     """
     # check if valid file
     if not os.path.isfile(filepath):
@@ -59,24 +64,26 @@ def load(filepath, out=None, normalization=None):
     sample_rate = sample_rate_p[0]

     # normalize if needed
     if isinstance(normalization, bool) and normalization:
-        out /= 1 << 31 # assuming 16-bit depth
+        out /= 1 << 31  # assuming 16-bit depth
     elif isinstance(normalization, (float, int)):
-        out /= normalization # normalize with custom value
+        out /= normalization  # normalize with custom value

     return out, sample_rate
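A brief usage sketch of the two normalization branches above ('foo.mp3' is an illustrative path, as in the docstring example):

>>> data, sample_rate = torchaudio.load('foo.mp3', normalization=True)  # divides by 1 << 31
>>> data, sample_rate = torchaudio.load('foo.mp3', normalization=float(1 << 16))  # divides by a custom value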
 def save(filepath, src, sample_rate):
     """Saves a Tensor with audio signal to disk as a standard format like mp3, wav, etc.

     Args:
         filepath (string): path to audio file
-        src (Tensor): an input 2D Tensor of shape `[L x C]` where L is the number of audio frames, C is the number of channels
+        src (Tensor): an input 2D Tensor of shape `[L x C]` where L is
+            the number of audio frames, C is the number of channels
         sample_rate (int): the sample-rate of the audio to be saved

     Example::

         >>> data, sample_rate = torchaudio.load('foo.mp3')
         >>> torchaudio.save('foo.wav', data, sample_rate)

     """
     # check if save directory exists
     abs_dirpath = os.path.dirname(os.path.abspath(filepath))
@@ -87,7 +94,8 @@ def save(filepath, src, sample_rate):
         # 1d tensors as assumed to be mono signals
         src.unsqueeze_(1)
     elif len(src.size()) > 2 or src.size(1) > 2:
-        raise ValueError("Expected format (L x N), N = 1 or 2, but found {}".format(src.size()))
+        raise ValueError(
+            "Expected format (L x N), N = 1 or 2, but found {}".format(src.size()))
     # check if sample_rate is an integer
     if not isinstance(sample_rate, int):
         if int(sample_rate) == sample_rate:
@@ -96,11 +104,12 @@ def save(filepath, src, sample_rate):
             raise TypeError('Sample rate should be a integer')
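This branch, together with the collapsed coercion lines above it, accepts an integral float for the sample rate (the test comment "automatically convert sr from floating point to int" confirms this) and raises otherwise. A sketch, with `y` standing in for any valid `L x C` signal tensor:

>>> torchaudio.save('foo.wav', y, 44100.0)    # integral float: coerced to int
>>> # torchaudio.save('foo.wav', y, 44100.5)  # would raise TypeError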
     # programs such as librosa normalize the signal, unnormalize if detected
     if src.min() >= -1.0 and src.max() <= 1.0:
-        src = src * (1 << 31) # assuming 16-bit depth
+        src = src * (1 << 31)  # assuming 16-bit depth
         src = src.long()
     # save data to file
     filename, extension = os.path.splitext(filepath)
     check_input(src)
     typename = type(src).__name__.replace('Tensor', '')
     func = getattr(th_sox, 'libthsox_{}_write_audio_file'.format(typename))
-    func(_bytes(filepath, "utf-8"), src, _bytes(extension[1:], "utf-8"), sample_rate)
+    func(_bytes(filepath, "utf-8"), src,
+         _bytes(extension[1:], "utf-8"), sample_rate)
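Note how save interacts with load's normalization: a float tensor lying entirely within `[-1.0, 1.0]` is treated as normalized and multiplied back by `1 << 31` before writing, so a round trip stays consistent. A sketch ('foo_copy.wav' is illustrative):

>>> data, sample_rate = torchaudio.load('foo.mp3', normalization=True)
>>> # data is within [-1, 1], so save() rescales it by 1 << 31 before writing
>>> torchaudio.save('foo_copy.wav', data, sample_rate)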
@@ -12,9 +12,11 @@ AUDIO_EXTENSIONS = [
     '.WAV', '.MP3', '.FLAC', '.SPH', '.OGG', '.OPUS',
 ]


 def is_audio_file(filename):
     return any(filename.endswith(extension) for extension in AUDIO_EXTENSIONS)


 def make_manifest(dir):
     audios = []
     dir = os.path.expanduser(dir)
@@ -31,6 +33,7 @@ def make_manifest(dir):
             audios.append(item)
     return audios


 def read_audio(fp, downsample=True):
     sig, sr = torchaudio.load(fp)
     if downsample:
@@ -41,6 +44,7 @@ def read_audio(fp, downsample=True):
         sig = sig[:-(sig.size(0) % 3):3].contiguous()
     return sig, sr
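The downsampling slice above is terse: `sig[:-(sig.size(0) % 3):3]` drops the trailing `len % 3` samples so the length divides evenly by three, then keeps every third sample, i.e. naive 3x decimation with no low-pass filter. (When the length is already a multiple of 3 the stop index is `-0`, which Python reads as `0` and yields an empty slice, so the expression only behaves as intended for lengths not divisible by 3.) A toy illustration:

>>> sig = torch.arange(0, 10)            # length 10, so one trailing sample is dropped
>>> down = sig[:-(sig.size(0) % 3):3]    # keeps indices 0, 3, 6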
 def load_txts(dir):
     """Create a dictionary with all the text of the audio transcriptions."""
     utterences = dict()
@@ -55,12 +59,12 @@ def load_txts(dir):
         for fname in fnames:
             if fname.endswith(".txt"):
                 with open(os.path.join(root, fname), "r") as f:
-                    fname_no_ext = os.path.basename(fname).rsplit(".", 1)[0]
+                    fname_no_ext = os.path.basename(
+                        fname).rsplit(".", 1)[0]
                     utterences[fname_no_ext] = f.readline()
-    #utterences += f.readlines()
-    #utterences = dict([tuple(u.strip().split(" ", 1)) for u in utterences])
     return utterences


 class VCTK(data.Dataset):
     """`VCTK <http://homepages.inf.ed.ac.uk/jyamagis/page3/page58/page58.html>`_ Dataset.
     `alternate url <http://datashare.is.ed.ac.uk/handle/10283/2651>`
@@ -103,7 +107,8 @@ class VCTK(data.Dataset):
             raise RuntimeError('Dataset not found.' +
                                ' You can use download=True to download it')
         self._read_info()
-        self.data, self.labels = torch.load(os.path.join(self.root, self.processed_folder, "vctk_{:04d}.pt".format(self.cached_pt)))
+        self.data, self.labels = torch.load(os.path.join(
+            self.root, self.processed_folder, "vctk_{:04d}.pt".format(self.cached_pt)))

     def __getitem__(self, index):
         """
@@ -115,7 +120,8 @@ class VCTK(data.Dataset):
         """
         if self.cached_pt != index // self.chunk_size:
             self.cached_pt = int(index // self.chunk_size)
-            self.data, self.labels = torch.load(os.path.join(self.root, self.processed_folder, "vctk_{:04d}.pt".format(self.cached_pt)))
+            self.data, self.labels = torch.load(os.path.join(
+                self.root, self.processed_folder, "vctk_{:04d}.pt".format(self.cached_pt)))
         index = index % self.chunk_size
         audio, target = self.data[index], self.labels[index]
@@ -134,13 +140,15 @@ class VCTK(data.Dataset):
         return os.path.exists(os.path.join(self.root, self.processed_folder, "vctk_info.txt"))

     def _write_info(self, num_items):
-        info_path = os.path.join(self.root, self.processed_folder, "vctk_info.txt")
+        info_path = os.path.join(
+            self.root, self.processed_folder, "vctk_info.txt")
         with open(info_path, "w") as f:
             f.write("num_samples,{}\n".format(num_items))
             f.write("max_len,{}\n".format(self.max_len))

     def _read_info(self):
-        info_path = os.path.join(self.root, self.processed_folder, "vctk_info.txt")
+        info_path = os.path.join(
+            self.root, self.processed_folder, "vctk_info.txt")
         with open(info_path, "r") as f:
             self.num_samples = int(f.readline().split(",")[1])
             self.max_len = int(f.readline().split(",")[1])
@@ -155,7 +163,8 @@ class VCTK(data.Dataset):
         raw_abs_dir = os.path.join(self.root, self.raw_folder)
         processed_abs_dir = os.path.join(self.root, self.processed_folder)
-        dset_abs_path = os.path.join(self.root, self.raw_folder, self.dset_path)
+        dset_abs_path = os.path.join(
+            self.root, self.raw_folder, self.dset_path)

         # download files
         try:
@@ -192,7 +201,8 @@ class VCTK(data.Dataset):
         audios = make_manifest(dset_abs_path)
         utterences = load_txts(dset_abs_path)
         self.max_len = 0
-        print("Found {} audio files and {} utterences".format(len(audios), len(utterences)))
+        print("Found {} audio files and {} utterences".format(
+            len(audios), len(utterences)))
         for n in range(len(audios) // self.chunk_size + 1):
             tensors = []
             labels = []
@@ -207,9 +217,11 @@ class VCTK(data.Dataset):
                 tensors.append(sig)
                 lengths.append(sig.size(0))
                 labels.append(utterences[f_rel_no_ext])
-                self.max_len = sig.size(0) if sig.size(0) > self.max_len else self.max_len
+                self.max_len = sig.size(0) if sig.size(
+                    0) > self.max_len else self.max_len
             # sort sigs/labels: longest -> shortest
-            tensors, labels = zip(*[(b, c) for (a,b,c) in sorted(zip(lengths, tensors, labels), key=lambda x: x[0], reverse=True)])
+            tensors, labels = zip(*[(b, c) for (a, b, c) in sorted(
+                zip(lengths, tensors, labels), key=lambda x: x[0], reverse=True)])
             data = (tensors, labels)
             torch.save(
                 data,
@@ -219,7 +231,7 @@ class VCTK(data.Dataset):
                     "vctk_{:04d}.pt".format(n)
                 )
             )
-            self._write_info((n*self.chunk_size)+i+1)
+            self._write_info((n * self.chunk_size) + i + 1)
         if not self.dev_mode:
             shutil.rmtree(raw_abs_dir, ignore_errors=True)
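The rewrapped sort one-liner is easier to follow unrolled: it pairs each signal and label with its length, sorts the triples by length, longest first, and unzips, discarding the lengths. An equivalent sketch (building lists rather than tuples):

>>> triples = sorted(zip(lengths, tensors, labels), key=lambda x: x[0], reverse=True)
>>> tensors = [b for (a, b, c) in triples]   # longest signal first
>>> labels = [c for (a, b, c) in triples]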
...
@@ -7,6 +7,7 @@ import errno
 import torch
 import torchaudio


 class YESNO(data.Dataset):
     """`YesNo Hebrew <http://www.openslr.org/1/>`_ Dataset.
@@ -45,7 +46,8 @@ class YESNO(data.Dataset):
         if not self._check_exists():
             raise RuntimeError('Dataset not found.' +
                                ' You can use download=True to download it')
-        self.data, self.labels = torch.load(os.path.join(self.root, self.processed_folder, self.processed_file))
+        self.data, self.labels = torch.load(os.path.join(
+            self.root, self.processed_folder, self.processed_file))

     def __getitem__(self, index):
         """
@@ -81,7 +83,8 @@ class YESNO(data.Dataset):
         raw_abs_dir = os.path.join(self.root, self.raw_folder)
         processed_abs_dir = os.path.join(self.root, self.processed_folder)
-        dset_abs_path = os.path.join(self.root, self.raw_folder, self.dset_path)
+        dset_abs_path = os.path.join(
+            self.root, self.raw_folder, self.dset_path)

         # download files
         try:
@@ -130,7 +133,8 @@ class YESNO(data.Dataset):
             lengths.append(sig.size(0))
             labels.append(os.path.basename(f).split(".", 1)[0].split("_"))
         # sort sigs/labels: longest -> shortest
-        tensors, labels = zip(*[(b, c) for (a,b,c) in sorted(zip(lengths, tensors, labels), key=lambda x: x[0], reverse=True)])
+        tensors, labels = zip(*[(b, c) for (a, b, c) in sorted(
+            zip(lengths, tensors, labels), key=lambda x: x[0], reverse=True)])
         self.max_len = tensors[0].size(0)
         torch.save(
             (tensors, labels),
...
@@ -6,6 +6,7 @@ try:
 except ImportError:
     librosa = None


 class Compose(object):
     """Composes several transforms together.
@@ -27,6 +28,7 @@ class Compose(object):
             audio = t(audio)
         return audio
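A usage sketch for Compose, chaining two transforms from this file (the parameter values are illustrative, mirroring the `tset` tuple in the transforms tests):

>>> composed = transforms.Compose([
...     transforms.Scale(factor=2**31),
...     transforms.PadTrim(max_len=16000),
... ])
>>> out = composed(sig)   # sig: a 2D Tensor of shape L x C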
 class Scale(object):
     """Scale audio tensor from a 16-bit integer (represented as a FloatTensor)
     to a floating point number between -1.0 and 1.0. Note the 16-bit number is
@@ -55,6 +57,7 @@ class Scale(object):
         return tensor / self.factor


 class PadTrim(object):
     """Pad/Trim a 1d-Tensor (Signal or Labels)
@@ -76,7 +79,7 @@ class PadTrim(object):
         """
         if self.max_len > tensor.size(0):
-            pad = torch.ones((self.max_len-tensor.size(0),
+            pad = torch.ones((self.max_len - tensor.size(0),
                               tensor.size(1))) * self.fill_value
             pad = pad.type_as(tensor)
             tensor = torch.cat((tensor, pad), dim=0)
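The pad branch above builds a constant block of shape `(max_len - L) x C` and concatenates it along the time axis; the trim branch in the collapsed lines presumably slices the signal down to `max_len`. A usage sketch (names illustrative):

>>> sig2 = torch.ones(100, 1)                       # L x C signal with L = 100
>>> padded = transforms.PadTrim(max_len=160)(sig2)  # padded with fill_value up to 160 x 1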
@@ -107,6 +110,7 @@ class DownmixMono(object):
         tensor = torch.mean(tensor.float(), 1, True)
         return tensor


 class LC2CL(object):
     """Permute a 2d tensor from samples (Length) x Channels to Channels x
     samples (Length)
@@ -153,14 +157,16 @@ class MEL(object):
             return tensor

         L = []
         for i in range(tensor.size(1)):
-            nparr = tensor[:, i].numpy() # (samples, )
-            sgram = librosa.feature.melspectrogram(nparr, **self.kwargs) # (n_mels, hops)
+            nparr = tensor[:, i].numpy()  # (samples, )
+            sgram = librosa.feature.melspectrogram(
+                nparr, **self.kwargs)  # (n_mels, hops)
             L.append(sgram)
-        L = np.stack(L, 2) # (n_mels, hops, channels)
+        L = np.stack(L, 2)  # (n_mels, hops, channels)
         tensor = torch.from_numpy(L).type_as(tensor)
         return tensor
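MEL hands each channel to librosa and stacks the per-channel `(n_mels, hops)` spectrograms into `(n_mels, hops, channels)`, which is why the BLC2CBL permute transform exists below. A usage sketch, assuming constructor keyword arguments are forwarded unchanged to `librosa.feature.melspectrogram`, as the `self.kwargs` unpacking suggests:

>>> spec = transforms.MEL(sr=16000, n_mels=40)(sig.float())  # (n_mels, hops, channels) Tensor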
 class BLC2CBL(object):
     """Permute a 3d tensor from Bands x samples (Length) x Channels to Channels x
     Bands x samples (Length)
@@ -179,6 +185,7 @@ class BLC2CBL(object):
         return tensor.permute(2, 0, 1).contiguous()


 class MuLawEncoding(object):
     """Encode signal based on mu-law companding. For more info see the
     `Wikipedia Entry <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_
@@ -212,10 +219,12 @@ class MuLawEncoding(object):
         if isinstance(x, torch.LongTensor):
             x = x.float()
         mu = torch.FloatTensor([mu])
-        x_mu = torch.sign(x) * torch.log1p(mu * torch.abs(x)) / torch.log1p(mu)
+        x_mu = torch.sign(x) * torch.log1p(mu *
+                                           torch.abs(x)) / torch.log1p(mu)
         x_mu = ((x_mu + 1) / 2 * mu + 0.5).long()
         return x_mu
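The rewrapped expression is the standard mu-law companding formula, `f(x) = sign(x) * log1p(mu * |x|) / log1p(mu)` for `x` in `[-1, 1]`, followed by a shift from `[-1, 1]` onto the integer bins `0..mu`. With `quantization_channels = 256` (so `mu = 255`), `x = 0` encodes to bin 128 (`(0 + 1) / 2 * 255 + 0.5 = 128.0`), `x = 1` to bin 255, and `x = -1` to bin 0. A round-trip sketch on a signal already normalized to `[-1, 1]`, as in the tests:

>>> qc = 256
>>> sig_mu = transforms.MuLawEncoding(qc)(sig)       # LongTensor of bins in [0, 255]
>>> sig_exp = transforms.MuLawExpanding(qc)(sig_mu)  # floats back in [-1., 1.]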
 class MuLawExpanding(object):
     """Decode mu-law encoded signal. For more info see the
     `Wikipedia Entry <https://en.wikipedia.org/wiki/%CE%9C-law_algorithm>`_
...
[flake8]
max-line-length = 120
ignore = E305,E402,E721,E741,F401,F403,F405,F821,F841,F999
exclude = docs/src,venv,torch/lib/gloo,torch/lib/pybind11,torch/lib/nanopb