Commit 3433b9b6 authored by David Pollack

basic transforms

vctk dataset

add label processing

added tests

chunk dataset

simple caching scheme

caching fixed

vctk downsample

yesno dataset
parent ecb538df
.gitignore:

@@ -5,6 +5,7 @@ __pycache__/
 # C extensions
 *.so
+_ext/
 
 # Distribution / packaging
 .Python
@@ -68,7 +69,7 @@ docs/_build/
 target/
 
 # Jupyter Notebook
-.ipynb_checkpoints
+.ipynb_checkpoints/
 
 # pyenv
 .python-version
@@ -100,5 +101,5 @@ ENV/
 # mypy
 .mypy_cache/
 
-# Jupyter Notebooks
-.ipynb_checkpoints/
+# Generated Files
+test/assets/sinewave.wav
@@ -2,17 +2,32 @@
 import torch
 import torch.nn as nn
 import torchaudio
+import math
 
-x, sample_rate = torchaudio.load("steam-train-whistle-daniel_simon.mp3")
+steam_train = "assets/steam-train-whistle-daniel_simon.mp3"
+x, sample_rate = torchaudio.load(steam_train)
 print(sample_rate)
 print(x.size())
 print(x[10000])
 print(x.min(), x.max())
 print(x.mean(), x.std())
 
-x, sample_rate = torchaudio.load("steam-train-whistle-daniel_simon.mp3",
-                                 out=torch.LongTensor())
+x, sample_rate = torchaudio.load(steam_train,
+                                 out=torch.LongTensor())
 print(sample_rate)
 print(x.size())
 print(x[10000])
 print(x.min(), x.max())
+
+sine_wave = "assets/sinewave.wav"
+sr = 16000
+freq = 440
+volume = 0.3
+y = (torch.cos(2 * math.pi * torch.arange(0, 4 * sr) * freq / sr)).float()
+y.unsqueeze_(1)
+# y is between -1 and 1, so scale it to the 32-bit integer range before saving
+y = (y * volume * 2 ** 31).long()
+torchaudio.save(sine_wave, y, sr)
+print(y.min(), y.max())
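
The modified test script now writes assets/sinewave.wav, which is also newly ignored via .gitignore above. A minimal round-trip sketch, assuming it is run from the same directory as the script:

    # reload the generated sine wave and sanity-check it
    import torchaudio

    y2, sr2 = torchaudio.load("assets/sinewave.wav")
    print(sr2)        # 16000
    print(y2.size())  # 64000 x 1 (4 seconds at 16 kHz)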
import torch
import torchaudio
import torchaudio.transforms as transforms
import numpy as np
import unittest

STEAM_TRAIN = "assets/steam-train-whistle-daniel_simon.mp3"


class Tester(unittest.TestCase):

    # synthesize a 4-second, 440 Hz test tone scaled to the 32-bit integer range
    sr = 16000
    freq = 440
    volume = 0.3
    sig = (torch.cos(2 * np.pi * torch.arange(0, 4 * sr) * freq / sr)).float()
    sig.unsqueeze_(1)
    sig = (sig * volume * 2 ** 31).long()

    def test_scale(self):
        audio_orig = self.sig.clone()
        result = transforms.Scale()(audio_orig)
        self.assertTrue(result.min() >= -1. and result.max() <= 1.,
                        "min: {}, max: {}".format(result.min(), result.max()))

        # scaling by the largest absolute value must hit -1.0 or 1.0 exactly
        maxminmax = float(np.abs([audio_orig.min(), audio_orig.max()]).max())
        result = transforms.Scale(factor=maxminmax)(audio_orig)
        self.assertTrue((result.min() == -1. or result.max() == 1.) and
                        result.min() >= -1. and result.max() <= 1.,
                        "min: {}, max: {}".format(result.min(), result.max()))

    def test_pad_trim(self):
        # pad a signal to 120% of its length
        audio_orig = self.sig.clone()
        length_orig = audio_orig.size(0)
        length_new = int(length_orig * 1.2)
        result = transforms.PadTrim(max_len=length_new)(audio_orig)
        self.assertTrue(result.size(0) == length_new,
                        "old size: {}, new size: {}".format(
                            audio_orig.size(0), result.size(0)))

        # trim a signal to 80% of its length
        audio_orig = self.sig.clone()
        length_orig = audio_orig.size(0)
        length_new = int(length_orig * 0.8)
        result = transforms.PadTrim(max_len=length_new)(audio_orig)
        self.assertTrue(result.size(0) == length_new,
                        "old size: {}, new size: {}".format(
                            audio_orig.size(0), result.size(0)))

    def test_downmix_mono(self):
        audio_L = self.sig.clone()
        audio_R = self.sig.clone()
        # rotate the right channel by 10% so the two channels differ
        R_idx = int(audio_R.size(0) * 0.1)
        audio_R = torch.cat((audio_R[R_idx:], audio_R[:R_idx]))
        audio_Stereo = torch.cat((audio_L, audio_R), dim=1)
        self.assertTrue(audio_Stereo.size(1) == 2)
        result = transforms.DownmixMono()(audio_Stereo)
        self.assertTrue(result.size(1) == 1)

    def test_compose(self):
        audio_orig = self.sig.clone()
        length_orig = audio_orig.size(0)
        length_new = int(length_orig * 1.2)
        maxminmax = float(np.abs([audio_orig.min(), audio_orig.max()]).max())
        tset = (transforms.Scale(factor=maxminmax),
                transforms.PadTrim(max_len=length_new))
        result = transforms.Compose(tset)(audio_orig)
        self.assertTrue(np.abs([result.min(), result.max()]).max() == 1.)
        self.assertTrue(result.size(0) == length_new)


if __name__ == '__main__':
    unittest.main()
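
The same transforms compose outside the test harness too. A minimal sketch, assuming the mp3 asset used above is present:

    import torchaudio
    import torchaudio.transforms as transforms

    sig, sr = torchaudio.load("assets/steam-train-whistle-daniel_simon.mp3")
    pipeline = transforms.Compose([
        transforms.Scale(),                  # 32-bit integer range -> [-1.0, 1.0]
        transforms.PadTrim(max_len=5 * sr),  # pad or trim to exactly 5 seconds
    ])
    out = pipeline(sig)
    print(out.size())  # (5 * sr) x channels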
torchaudio/__init__.py:

@@ -7,6 +7,8 @@ from cffi import FFI
 ffi = FFI()
 
 from ._ext import th_sox
+from torchaudio import transforms
+from torchaudio import datasets
 
 def check_input(src):
     if not torch.is_tensor(src):
@@ -23,7 +25,7 @@ def load(filename, out=None):
     typename = type(out).__name__.replace('Tensor', '')
     func = getattr(th_sox, 'libthsox_{}_read_audio_file'.format(typename))
     sample_rate_p = ffi.new('int*')
-    func(str(filename).encode("ascii"), out, sample_rate_p)
+    func(str(filename).encode("utf-8"), out, sample_rate_p)
     sample_rate = sample_rate_p[0]
     return out, sample_rate
@@ -37,4 +39,4 @@ def save(filepath, src, sample_rate):
     typename = type(src).__name__.replace('Tensor', '')
     func = getattr(th_sox, 'libthsox_{}_write_audio_file'.format(typename))
-    func(bytes(filepath, "ascii"), src, extension[1:], sample_rate)
+    func(bytes(filepath, "utf-8"), src, bytes(extension[1:], "utf-8"), sample_rate)
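
The "ascii" to "utf-8" switch matters as soon as a path contains non-ASCII characters. A minimal sketch with a hypothetical filename:

    import torch
    import torchaudio

    sig = torch.zeros(16000, 1).long()
    # under the previous "ascii" encoding this path would raise UnicodeEncodeError
    torchaudio.save("assets/señal.wav", sig, 16000)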
torchaudio/datasets/__init__.py (new file):

from .yesno import YESNO
from .vctk import VCTK

__all__ = ('YESNO', 'VCTK')
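
Since torchaudio/__init__.py now imports the datasets package, both corpora are reachable from the top-level namespace. A minimal sketch with a hypothetical ./data root:

    import torchaudio

    yesno = torchaudio.datasets.YESNO("./data", download=True)
    print(len(yesno))  # number of recordings in the corpus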
torchaudio/datasets/vctk.py (new file):

from __future__ import print_function
import torch.utils.data as data
import os
import os.path
import shutil
import errno
import torch
import torchaudio

AUDIO_EXTENSIONS = [
    '.wav', '.mp3', '.flac', '.sph', '.ogg', '.opus',
    '.WAV', '.MP3', '.FLAC', '.SPH', '.OGG', '.OPUS',
]


def is_audio_file(filename):
    return any(filename.endswith(extension) for extension in AUDIO_EXTENSIONS)


def make_manifest(dir):
    audios = []
    dir = os.path.expanduser(dir)
    for target in sorted(os.listdir(dir)):
        d = os.path.join(dir, target)
        if not os.path.isdir(d):
            continue
        for root, _, fnames in sorted(os.walk(d)):
            for fname in fnames:
                if is_audio_file(fname):
                    audios.append(os.path.join(root, fname))
    return audios


def read_audio(fp, downsample=True):
    sig, sr = torchaudio.load(fp)
    if downsample:
        # 48 kHz -> 16 kHz: drop the remainder, then average each group of
        # three consecutive samples, e.g. [0, 1, 2, 3, 4, 5] -> [1., 4.]
        if sig.size(0) % 3 != 0:
            sig = sig[:-(sig.size(0) % 3)]
        sig = sig.view(-1, 3, sig.size(1)).mean(1)
    return sig, sr


def load_txts(dir):
    """Create a dictionary with all the text of the audio transcriptions,
    keyed by the transcription file name without its extension."""
    utterences = dict()
    dir = os.path.expanduser(dir)
    for target in sorted(os.listdir(dir)):
        d = os.path.join(dir, target)
        if not os.path.isdir(d):
            continue
        for root, _, fnames in sorted(os.walk(d)):
            for fname in fnames:
                if fname.endswith(".txt"):
                    with open(os.path.join(root, fname), "r") as f:
                        fname_no_ext = os.path.basename(fname).rsplit(".", 1)[0]
                        utterences[fname_no_ext] = f.readline()
    return utterences


class VCTK(data.Dataset):
    """`VCTK <http://homepages.inf.ed.ac.uk/jyamagis/page3/page58/page58.html>`_ Dataset.

    `alternate url <http://datashare.is.ed.ac.uk/handle/10283/2651>`_

    Args:
        root (string): Root directory of the dataset, where ``vctk/raw`` and
            ``vctk/processed`` are created.
        downsample (bool, optional): If true, 48 kHz signals are reduced to
            16 kHz by averaging consecutive triples of samples.
        transform (callable, optional): A function/transform that takes in a
            raw audio tensor and returns a transformed version.
            E.g., ``transforms.Scale``
        target_transform (callable, optional): A function/transform that takes
            in the target (utterance text) and transforms it.
        download (bool, optional): If true, downloads the dataset from the
            internet and puts it in the root directory. If the dataset is
            already downloaded, it is not downloaded again.
        dev_mode (bool, optional): If true, clean-up is not performed on the
            downloaded files. Useful to keep raw audio and transcriptions.
    """
    raw_folder = 'vctk/raw'
    processed_folder = 'vctk/processed'
    url = 'http://homepages.inf.ed.ac.uk/jyamagis/release/VCTK-Corpus.tar.gz'
    dset_path = 'VCTK-Corpus'

    def __init__(self, root, downsample=True, transform=None,
                 target_transform=None, download=False, dev_mode=False):
        self.root = os.path.expanduser(root)
        self.downsample = downsample
        self.transform = transform
        self.target_transform = target_transform
        self.dev_mode = dev_mode
        self.data = []
        self.labels = []
        self.chunk_size = 1000
        self.num_samples = 0
        self.max_len = 0
        self.cached_pt = 0

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError('Dataset not found.' +
                               ' You can use download=True to download it')

        self._read_info()
        self.data, self.labels = torch.load(os.path.join(
            self.root, self.processed_folder,
            "vctk_{:04d}.pt".format(self.cached_pt)))

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (audio, target) where target is the utterance text.
        """
        # lazily swap in the chunk file that contains this index
        if self.cached_pt != index // self.chunk_size:
            self.cached_pt = int(index // self.chunk_size)
            self.data, self.labels = torch.load(os.path.join(
                self.root, self.processed_folder,
                "vctk_{:04d}.pt".format(self.cached_pt)))
        index = index % self.chunk_size
        audio, target = self.data[index], self.labels[index]
        if self.transform is not None:
            audio = self.transform(audio)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return audio, target

    def __len__(self):
        return self.num_samples

    def _check_exists(self):
        return os.path.exists(os.path.join(
            self.root, self.processed_folder, "vctk_info.txt"))

    def _write_info(self, num_items):
        info_path = os.path.join(
            self.root, self.processed_folder, "vctk_info.txt")
        with open(info_path, "w") as f:
            f.write("num_samples,{}\n".format(num_items))
            f.write("max_len,{}\n".format(self.max_len))

    def _read_info(self):
        info_path = os.path.join(
            self.root, self.processed_folder, "vctk_info.txt")
        with open(info_path, "r") as f:
            self.num_samples = int(f.readline().split(",")[1])
            self.max_len = int(f.readline().split(",")[1])

    def download(self):
        """Download the VCTK data if it doesn't exist in processed_folder already."""
        from six.moves import urllib
        import tarfile

        if self._check_exists():
            return

        raw_abs_dir = os.path.join(self.root, self.raw_folder)
        processed_abs_dir = os.path.join(self.root, self.processed_folder)
        dset_abs_path = os.path.join(self.root, self.raw_folder, self.dset_path)

        # download files
        try:
            os.makedirs(os.path.join(self.root, self.raw_folder))
            os.makedirs(os.path.join(self.root, self.processed_folder))
        except OSError as e:
            if e.errno == errno.EEXIST:
                pass
            else:
                raise

        url = self.url
        print('Downloading ' + url)
        filename = url.rpartition('/')[2]
        file_path = os.path.join(self.root, self.raw_folder, filename)
        if not os.path.isfile(file_path):
            data = urllib.request.urlopen(url)
            with open(file_path, 'wb') as f:
                f.write(data.read())
        if not os.path.exists(dset_abs_path):
            with tarfile.open(file_path) as zip_f:
                zip_f.extractall(raw_abs_dir)
        else:
            print("Using existing raw folder")
        if not self.dev_mode:
            os.unlink(file_path)

        # process and save as torch files
        print('Processing...')
        shutil.copyfile(
            os.path.join(dset_abs_path, "COPYING"),
            os.path.join(processed_abs_dir, "VCTK_COPYING")
        )
        audios = make_manifest(dset_abs_path)
        utterences = load_txts(dset_abs_path)
        self.max_len = 0
        print("Found {} audio files and {} utterances".format(
            len(audios), len(utterences)))
        for n in range(len(audios) // self.chunk_size + 1):
            tensors = []
            labels = []
            lengths = []
            st_idx = n * self.chunk_size
            end_idx = st_idx + self.chunk_size
            for i, f in enumerate(audios[st_idx:end_idx]):
                txt_dir = os.path.dirname(f).replace("wav48", "txt")
                if os.path.exists(txt_dir):
                    f_rel_no_ext = os.path.basename(f).rsplit(".", 1)[0]
                    sig = read_audio(f, downsample=self.downsample)[0]
                    tensors.append(sig)
                    lengths.append(sig.size(0))
                    labels.append(utterences[f_rel_no_ext])
                    self.max_len = max(self.max_len, sig.size(0))
            if not tensors:  # guard: the final chunk can be empty
                break
            # sort sigs/labels: longest -> shortest
            tensors, labels = zip(*[(b, c) for (a, b, c) in sorted(
                zip(lengths, tensors, labels),
                key=lambda x: x[0], reverse=True)])
            torch.save(
                (tensors, labels),
                os.path.join(
                    self.root,
                    self.processed_folder,
                    "vctk_{:04d}.pt".format(n)
                )
            )
            self._write_info((n * self.chunk_size) + i + 1)
        if not self.dev_mode:
            shutil.rmtree(raw_abs_dir, ignore_errors=True)

        print('Done!')
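
Each vctk_XXXX.pt file holds a 1000-item chunk, and __getitem__ reloads a chunk only when the index crosses a chunk boundary, so sequential iteration reads each file from disk once while fully random access can reload chunks constantly. A minimal usage sketch (hypothetical ./data root):

    import torch
    import torchaudio.transforms as transforms
    from torchaudio.datasets import VCTK

    vctk = VCTK("./data", download=True)
    # pad/trim everything to the corpus-wide max length from vctk_info.txt
    vctk.transform = transforms.Compose([
        transforms.Scale(),
        transforms.PadTrim(max_len=vctk.max_len),
    ])
    # shuffle=False keeps indices sequential, so each 1000-item
    # vctk_XXXX.pt chunk is deserialized from disk only once per epoch
    loader = torch.utils.data.DataLoader(vctk, batch_size=1, shuffle=False)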
torchaudio/datasets/yesno.py (new file):

from __future__ import print_function
import torch.utils.data as data
import os
import os.path
import shutil
import errno
import torch
import torchaudio


class YESNO(data.Dataset):
    """`YesNo Hebrew <http://www.openslr.org/1/>`_ Dataset.

    Args:
        root (string): Root directory of the dataset, where ``yesno/raw`` and
            ``yesno/processed`` are created.
        transform (callable, optional): A function/transform that takes in a
            raw audio tensor and returns a transformed version.
            E.g., ``transforms.Scale``
        target_transform (callable, optional): A function/transform that takes
            in the target and transforms it.
        download (bool, optional): If true, downloads the dataset from the
            internet and puts it in the root directory. If the dataset is
            already downloaded, it is not downloaded again.
        dev_mode (bool, optional): If true, clean-up is not performed on the
            downloaded files. Useful to keep raw audio and transcriptions.
    """
    raw_folder = 'yesno/raw'
    processed_folder = 'yesno/processed'
    url = 'http://www.openslr.org/resources/1/waves_yesno.tar.gz'
    dset_path = 'waves_yesno'
    processed_file = 'yesno.pt'

    def __init__(self, root, transform=None, target_transform=None,
                 download=False, dev_mode=False):
        self.root = os.path.expanduser(root)
        self.transform = transform
        self.target_transform = target_transform
        self.dev_mode = dev_mode
        self.data = []
        self.labels = []
        self.num_samples = 0
        self.max_len = 0

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError('Dataset not found.' +
                               ' You can use download=True to download it')

        self.data, self.labels = torch.load(os.path.join(
            self.root, self.processed_folder, self.processed_file))

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (audio, target) where target is the list of "0"/"1"
                tokens parsed from the file name.
        """
        audio, target = self.data[index], self.labels[index]
        if self.transform is not None:
            audio = self.transform(audio)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return audio, target

    def __len__(self):
        return len(self.data)

    def _check_exists(self):
        return os.path.exists(os.path.join(
            self.root, self.processed_folder, self.processed_file))

    def download(self):
        """Download the yesno data if it doesn't exist in processed_folder already."""
        from six.moves import urllib
        import tarfile

        if self._check_exists():
            return

        raw_abs_dir = os.path.join(self.root, self.raw_folder)
        processed_abs_dir = os.path.join(self.root, self.processed_folder)
        dset_abs_path = os.path.join(self.root, self.raw_folder, self.dset_path)

        # download files
        try:
            os.makedirs(os.path.join(self.root, self.raw_folder))
            os.makedirs(os.path.join(self.root, self.processed_folder))
        except OSError as e:
            if e.errno == errno.EEXIST:
                pass
            else:
                raise

        url = self.url
        print('Downloading ' + url)
        filename = url.rpartition('/')[2]
        file_path = os.path.join(self.root, self.raw_folder, filename)
        if not os.path.isfile(file_path):
            data = urllib.request.urlopen(url)
            with open(file_path, 'wb') as f:
                f.write(data.read())
        else:
            print("Tar file already downloaded")
        if not os.path.exists(dset_abs_path):
            with tarfile.open(file_path) as zip_f:
                zip_f.extractall(raw_abs_dir)
        else:
            print("Tar file already extracted")
        if not self.dev_mode:
            os.unlink(file_path)

        # process and save as torch files
        print('Processing...')
        shutil.copyfile(
            os.path.join(dset_abs_path, "README"),
            os.path.join(processed_abs_dir, "YESNO_README")
        )
        audios = [x for x in os.listdir(dset_abs_path) if ".wav" in x]
        print("Found {} audio files".format(len(audios)))
        tensors = []
        labels = []
        lengths = []
        for i, f in enumerate(audios):
            full_path = os.path.join(dset_abs_path, f)
            sig, sr = torchaudio.load(full_path)
            tensors.append(sig)
            lengths.append(sig.size(0))
            # the label is encoded in the file name, e.g. "0_1_0_...".wav
            labels.append(os.path.basename(f).split(".", 1)[0].split("_"))
        # sort sigs/labels: longest -> shortest
        tensors, labels = zip(*[(b, c) for (a, b, c) in sorted(
            zip(lengths, tensors, labels), key=lambda x: x[0], reverse=True)])
        self.max_len = tensors[0].size(0)
        torch.save(
            (tensors, labels),
            os.path.join(
                self.root,
                self.processed_folder,
                self.processed_file
            )
        )
        if not self.dev_mode:
            shutil.rmtree(raw_abs_dir, ignore_errors=True)

        print('Done!')
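
Because the targets are parsed straight from the wave file names, each label is a list of "0"/"1" strings. A quick inspection sketch (hypothetical ./data root):

    from torchaudio.datasets import YESNO

    yesno = YESNO("./data", download=True)
    sig, label = yesno[0]
    print(sig.size())  # samples x 1
    print(label)       # e.g. ['0', '1', '0', '0', '1', '1', '0', '1']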
torchaudio/transforms.py (new file):

from __future__ import division
import torch


class Compose(object):
    """Composes several transforms together.

    Args:
        transforms (list of ``Transform`` objects): list of transforms to compose.

    Example:
        >>> transforms.Compose([
        >>>     transforms.Scale(),
        >>>     transforms.PadTrim(max_len=16000),
        >>> ])
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, audio):
        for t in self.transforms:
            audio = t(audio)
        return audio


class Scale(object):
    """Scale an audio tensor of integer samples (stored in an integer tensor
    or as integer values in a FloatTensor) to floating point numbers between
    -1.0 and 1.0. Note that this integer width is called the "bit depth" or
    "precision", not to be confused with "bit rate".

    Args:
        factor (float): maximum value of the input tensor.
            Default: 2**31, i.e. a 32-bit depth.
    """

    def __init__(self, factor=2**31):
        self.factor = factor

    def __call__(self, tensor):
        """
        Args:
            tensor (Tensor): Tensor of audio of size (Samples x Channels)

        Returns:
            Tensor: Scaled by the scale factor (by default between -1.0 and 1.0)
        """
        if isinstance(tensor, (torch.LongTensor, torch.IntTensor)):
            tensor = tensor.float()
        return tensor / self.factor


class PadTrim(object):
    """Pad or trim a 1d-Tensor (signal or labels) to ``max_len`` samples.

    Args:
        max_len (int): length to which the tensor will be padded or trimmed.
        fill_value (float): value used for padding. Default: 0.
    """

    def __init__(self, max_len, fill_value=0):
        self.max_len = max_len
        self.fill_value = fill_value

    def __call__(self, tensor):
        """
        Args:
            tensor (Tensor): Tensor of audio of size (Samples x Channels)

        Returns:
            Tensor: (max_len x Channels)
        """
        if self.max_len > tensor.size(0):
            # pad at the end with fill_value up to max_len
            pad = torch.ones((self.max_len - tensor.size(0),
                              tensor.size(1))) * self.fill_value
            pad = pad.type_as(tensor)
            tensor = torch.cat((tensor, pad), dim=0)
        elif self.max_len < tensor.size(0):
            tensor = tensor[:self.max_len, :]
        return tensor


class DownmixMono(object):
    """Downmix any stereo signal to mono by averaging the two channels."""

    def __init__(self):
        pass

    def __call__(self, tensor):
        """
        Args:
            tensor (Tensor): Tensor of audio of size (Samples x Channels)

        Returns:
            Tensor: (Samples x 1)
        """
        if isinstance(tensor, (torch.LongTensor, torch.IntTensor)):
            tensor = tensor.float()
        if tensor.size(1) > 1:
            tensor = torch.mean(tensor.float(), 1, True)
        return tensor
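
The transforms are plain callables, so they can also be applied one at a time. A minimal sketch chaining DownmixMono and Scale on a synthetic stereo signal in the 32-bit integer range:

    import torch
    import torchaudio.transforms as transforms

    # fake stereo signal with integer samples in [-2**31, 2**31)
    stereo = ((torch.rand(16000, 2) * 2 - 1) * 2 ** 31).long()
    mono = transforms.DownmixMono()(stereo)  # FloatTensor, 16000 x 1
    scaled = transforms.Scale()(mono)        # values in [-1.0, 1.0]
    print(mono.size(), scaled.min(), scaled.max())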