Commit 3c09fd44 authored by Soumith Chintala, committed by GitHub

Merge pull request #11 from dhpollack/VCTK

basic transforms and two datasets
parents ecb538df ff47955c
.gitignore
@@ -5,6 +5,7 @@ __pycache__/
# C extensions
*.so
+_ext/
# Distribution / packaging
.Python
@@ -68,7 +69,7 @@ docs/_build/
target/
# Jupyter Notebook
-.ipynb_checkpoints
+.ipynb_checkpoints/
# pyenv
.python-version
@@ -100,5 +101,5 @@ ENV/
# mypy
.mypy_cache/
-# Jupyter Notebooks
-.ipynb_checkpoints/
+# Generated Files
+test/assets/sinewave.wav
test/test.py
@@ -2,17 +2,32 @@
import torch
import torch.nn as nn
import torchaudio
+import math
-x, sample_rate = torchaudio.load("steam-train-whistle-daniel_simon.mp3")
+steam_train = "assets/steam-train-whistle-daniel_simon.mp3"
+x, sample_rate = torchaudio.load(steam_train)
print(sample_rate)
print(x.size())
print(x[10000])
print(x.min(), x.max())
print(x.mean(), x.std())
x, sample_rate = torchaudio.load("steam-train-whistle-daniel_simon.mp3",
x, sample_rate = torchaudio.load(steam_train,
out=torch.LongTensor())
print(sample_rate)
print(x.size())
print(x[10000])
print(x.min(), x.max())
+sine_wave = "assets/sinewave.wav"
+sr = 16000
+freq = 440
+volume = 0.3
+y = (torch.cos(2*math.pi*torch.arange(0, 4*sr) * freq/sr)).float()
+y.unsqueeze_(1)
+# y is between -1 and 1, so scale it into the 32-bit integer range before saving
+y = (y*volume*2**31).long()
+torchaudio.save(sine_wave, y, sr)
+print(y.min(), y.max())
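# Illustrative round-trip check (a hedged sketch; it reuses the sine_wave, sr
# and y names defined above): reload the file just written and confirm that
# the sample rate and length survive torchaudio.save/torchaudio.load.
y2, sr2 = torchaudio.load(sine_wave)
assert sr2 == sr
assert y2.size(0) == y.size(0)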
test/test_transforms.py (new file)
import torch
import torchaudio
import torchaudio.transforms as transforms
import numpy as np
import unittest
STEAM_TRAIN = "assets/steam-train-whistle-daniel_simon.mp3"
class Tester(unittest.TestCase):
sr = 16000
freq = 440
volume = 0.3
sig = (torch.cos(2*np.pi*torch.arange(0, 4*sr) * freq/sr)).float()
sig.unsqueeze_(1)
sig = (sig*volume*2**31).long()
def test_scale(self):
audio_orig = self.sig.clone()
result = transforms.Scale()(audio_orig)
self.assertTrue(result.min() >= -1. and result.max() <= 1.,
"min: {}, max: {}".format(result.min(), result.max()))
maxminmax = np.abs([audio_orig.min(), audio_orig.max()]).max().astype(np.float)
result = transforms.Scale(factor=maxminmax)(audio_orig)
self.assertTrue((result.min() == -1. or result.max() == 1.) and
result.min() >= -1. and result.max() <= 1.,
"min: {}, max: {}".format(result.min(), result.max()))
def test_pad_trim(self):
audio_orig = self.sig.clone()
length_orig = audio_orig.size(0)
length_new = int(length_orig * 1.2)
result = transforms.PadTrim(max_len=length_new)(audio_orig)
self.assertTrue(result.size(0) == length_new,
"old size: {}, new size: {}".format(audio_orig.size(0), result.size(0)))
audio_orig = self.sig.clone()
length_orig = audio_orig.size(0)
length_new = int(length_orig * 0.8)
result = transforms.PadTrim(max_len=length_new)(audio_orig)
self.assertTrue(result.size(0) == length_new,
"old size: {}, new size: {}".format(audio_orig.size(0), result.size(0)))
def test_downmix_mono(self):
audio_L = self.sig.clone()
audio_R = self.sig.clone()
R_idx = int(audio_R.size(0) * 0.1)
audio_R = torch.cat((audio_R[R_idx:], audio_R[:R_idx]))
audio_Stereo = torch.cat((audio_L, audio_R), dim=1)
self.assertTrue(audio_Stereo.size(1) == 2)
result = transforms.DownmixMono()(audio_Stereo)
self.assertTrue(result.size(1) == 1)
def test_compose(self):
audio_orig = self.sig.clone()
length_orig = audio_orig.size(0)
length_new = int(length_orig * 1.2)
maxminmax = np.abs([audio_orig.min(), audio_orig.max()]).max().astype(np.float)
tset = (transforms.Scale(factor=maxminmax),
transforms.PadTrim(max_len=length_new))
result = transforms.Compose(tset)(audio_orig)
self.assertTrue(np.abs([result.min(), result.max()]).max() == 1.)
self.assertTrue(result.size(0) == length_new)
if __name__ == '__main__':
unittest.main()
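# Illustrative sketch: the suite above can also be driven programmatically;
# the path test/test_transforms.py is an assumption about where this file lives.
import unittest
suite = unittest.defaultTestLoader.discover("test", pattern="test_transforms.py")
unittest.TextTestRunner(verbosity=2).run(suite)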
torchaudio/__init__.py
@@ -7,6 +7,8 @@ from cffi import FFI
ffi = FFI()
from ._ext import th_sox
+from torchaudio import transforms
+from torchaudio import datasets
def check_input(src):
if not torch.is_tensor(src):
@@ -23,7 +25,7 @@ def load(filename, out=None):
typename = type(out).__name__.replace('Tensor', '')
func = getattr(th_sox, 'libthsox_{}_read_audio_file'.format(typename))
sample_rate_p = ffi.new('int*')
func(str(filename).encode("ascii"), out, sample_rate_p)
func(str(filename).encode("utf-8"), out, sample_rate_p)
sample_rate = sample_rate_p[0]
return out, sample_rate
@@ -37,4 +39,4 @@ def save(filepath, src, sample_rate):
typename = type(src).__name__.replace('Tensor', '')
func = getattr(th_sox, 'libthsox_{}_write_audio_file'.format(typename))
func(bytes(filepath, "ascii"), src, extension[1:], sample_rate)
func(bytes(filepath, "utf-8"), src, bytes(extension[1:], "utf-8"), sample_rate)
torchaudio/datasets/__init__.py (new file)
from .yesno import YESNO
from .vctk import VCTK
__all__ = ('YESNO', 'VCTK')
torchaudio/datasets/vctk.py (new file)
from __future__ import print_function
import torch.utils.data as data
import os
import os.path
import shutil
import errno
import torch
import torchaudio
AUDIO_EXTENSIONS = [
'.wav', '.mp3', '.flac', '.sph', '.ogg', '.opus',
'.WAV', '.MP3', '.FLAC', '.SPH', '.OGG', '.OPUS',
]
def is_audio_file(filename):
return any(filename.endswith(extension) for extension in AUDIO_EXTENSIONS)
def make_manifest(dir):
audios = []
dir = os.path.expanduser(dir)
for target in sorted(os.listdir(dir)):
d = os.path.join(dir, target)
if not os.path.isdir(d):
continue
for root, _, fnames in sorted(os.walk(d)):
for fname in fnames:
if is_audio_file(fname):
path = os.path.join(root, fname)
item = path
audios.append(item)
return audios
def read_audio(fp, downsample=True):
sig, sr = torchaudio.load(fp)
if downsample:
# 48 kHz -> 16 kHz: keep every 3rd sample; both branches yield exactly
# sig.size(0) // 3 samples. Note that sr is returned unchanged.
if sig.size(0) % 3 == 0:
sig = sig[::3].contiguous()
else:
sig = sig[:-(sig.size(0) % 3):3].contiguous()
return sig, sr
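# Illustrative check of the stride-3 decimation above (a hedged sketch; it
# assumes this module's torch import): whichever branch runs, read_audio keeps
# every 3rd sample and returns exactly len // 3 of them.
for n in (9, 10, 11):
    t = torch.arange(0, n)
    out = t[::3] if n % 3 == 0 else t[:-(n % 3):3]
    assert out.size(0) == n // 3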
def load_txts(dir):
"""Create a dictionary with all the text of the audio transcriptions."""
utterences = dict()
txts = []
dir = os.path.expanduser(dir)
for target in sorted(os.listdir(dir)):
d = os.path.join(dir, target)
if not os.path.isdir(d):
continue
for root, _, fnames in sorted(os.walk(d)):
for fname in fnames:
if fname.endswith(".txt"):
with open(os.path.join(root, fname), "r") as f:
fname_no_ext = os.path.basename(fname).rsplit(".", 1)[0]
utterences[fname_no_ext] = f.readline()
#utterences += f.readlines()
#utterences = dict([tuple(u.strip().split(" ", 1)) for u in utterences])
return utterences
class VCTK(data.Dataset):
"""`VCTK <http://homepages.inf.ed.ac.uk/jyamagis/page3/page58/page58.html>`_ Dataset.
`alternate url <http://datashare.is.ed.ac.uk/handle/10283/2651>`_
Args:
root (string): Root directory of the dataset, under which the processed
``vctk/processed/vctk_NNNN.pt`` chunk files are stored (or will be created).
downsample (bool, optional): if true, decimate the 48 kHz recordings to
16 kHz during processing (see ``read_audio``).
transform (callable, optional): A function/transform that takes in a raw audio
tensor and returns a transformed version. E.g., ``transforms.Scale``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
download (bool, optional): If true, downloads the dataset from the internet and
puts it in the root directory. If the dataset is already downloaded, it is not
downloaded again.
dev_mode (bool, optional): if true, clean-up is not performed on downloaded
files. Useful for keeping raw audio and transcriptions.
"""
raw_folder = 'vctk/raw'
processed_folder = 'vctk/processed'
url = 'http://homepages.inf.ed.ac.uk/jyamagis/release/VCTK-Corpus.tar.gz'
dset_path = 'VCTK-Corpus'
def __init__(self, root, downsample=True, transform=None, target_transform=None, download=False, dev_mode=False):
self.root = os.path.expanduser(root)
self.downsample = downsample
self.transform = transform
self.target_transform = target_transform
self.dev_mode = dev_mode
self.data = []
self.labels = []
self.chunk_size = 1000
self.num_samples = 0
self.max_len = 0
self.cached_pt = 0
if download:
self.download()
if not self._check_exists():
raise RuntimeError('Dataset not found.' +
' You can use download=True to download it')
self._read_info()
self.data, self.labels = torch.load(os.path.join(self.root, self.processed_folder, "vctk_{:04d}.pt".format(self.cached_pt)))
def __getitem__(self, index):
"""
Args:
index (int): Index
Returns:
tuple: (audio, target) where target is the utterance transcription (a string).
"""
# lazily load the chunk file that holds this index; e.g. with
# chunk_size=1000, index 2503 lives in vctk_0002.pt at local index 503
if self.cached_pt != index // self.chunk_size:
self.cached_pt = int(index // self.chunk_size)
self.data, self.labels = torch.load(os.path.join(self.root, self.processed_folder, "vctk_{:04d}.pt".format(self.cached_pt)))
index = index % self.chunk_size
audio, target = self.data[index], self.labels[index]
if self.transform is not None:
audio = self.transform(audio)
if self.target_transform is not None:
target = self.target_transform(target)
return audio, target
def __len__(self):
return self.num_samples
def _check_exists(self):
return os.path.exists(os.path.join(self.root, self.processed_folder, "vctk_info.txt"))
def _write_info(self, num_items):
info_path = os.path.join(self.root, self.processed_folder, "vctk_info.txt")
with open(info_path, "w") as f:
f.write("num_samples,{}\n".format(num_items))
f.write("max_len,{}\n".format(self.max_len))
def _read_info(self):
info_path = os.path.join(self.root, self.processed_folder, "vctk_info.txt")
with open(info_path, "r") as f:
self.num_samples = int(f.readline().split(",")[1])
self.max_len = int(f.readline().split(",")[1])
def download(self):
"""Download the VCTK data if it doesn't exist in processed_folder already."""
from six.moves import urllib
import tarfile
if self._check_exists():
return
raw_abs_dir = os.path.join(self.root, self.raw_folder)
processed_abs_dir = os.path.join(self.root, self.processed_folder)
dset_abs_path = os.path.join(self.root, self.raw_folder, self.dset_path)
# download files
try:
os.makedirs(os.path.join(self.root, self.raw_folder))
os.makedirs(os.path.join(self.root, self.processed_folder))
except OSError as e:
if e.errno == errno.EEXIST:
pass
else:
raise
url = self.url
print('Downloading ' + url)
filename = url.rpartition('/')[2]
file_path = os.path.join(self.root, self.raw_folder, filename)
if not os.path.isfile(file_path):
data = urllib.request.urlopen(url)
with open(file_path, 'wb') as f:
f.write(data.read())
if not os.path.exists(dset_abs_path):
with tarfile.open(file_path) as zip_f:
zip_f.extractall(raw_abs_dir)
else:
print("Using existing raw folder")
if not self.dev_mode:
os.unlink(file_path)
# process and save as torch files
print('Processing...')
shutil.copyfile(
os.path.join(dset_abs_path, "COPYING"),
os.path.join(processed_abs_dir, "VCTK_COPYING")
)
audios = make_manifest(dset_abs_path)
utterences = load_txts(dset_abs_path)
self.max_len = 0
print("Found {} audio files and {} utterences".format(len(audios), len(utterences)))
# process the manifest in chunks of self.chunk_size files
for n in range(len(audios) // self.chunk_size + 1):
tensors = []
labels = []
lengths = []
st_idx = n * self.chunk_size
end_idx = st_idx + self.chunk_size
for i, f in enumerate(audios[st_idx:end_idx]):
txt_dir = os.path.dirname(f).replace("wav48", "txt")
if os.path.exists(txt_dir):
f_rel_no_ext = os.path.basename(f).rsplit(".", 1)[0]
sig = read_audio(f, downsample=self.downsample)[0]
tensors.append(sig)
lengths.append(sig.size(0))
labels.append(utterences[f_rel_no_ext])
self.max_len = sig.size(0) if sig.size(0) > self.max_len else self.max_len
# sort sigs/labels: longest -> shortest
tensors, labels = zip(*[(b, c) for (a,b,c) in sorted(zip(lengths, tensors, labels), key=lambda x: x[0], reverse=True)])
data = (tensors, labels)
torch.save(
data,
os.path.join(
self.root,
self.processed_folder,
"vctk_{:04d}.pt".format(n)
)
)
self._write_info((n*self.chunk_size)+i+1)
if not self.dev_mode:
shutil.rmtree(raw_abs_dir, ignore_errors=True)
print('Done!')
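# Illustrative usage of the class above (a hedged sketch; "./data" is an
# assumed root and max_len is arbitrary). Each item is an (audio, utterance)
# pair: a (max_len x 1) FloatTensor and a transcription string.
import torchaudio.transforms as T
from torchaudio.datasets import VCTK

vctk = VCTK("./data", download=True,
            transform=T.Compose([T.Scale(), T.PadTrim(max_len=100000)]))
audio, utterance = vctk[0]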
torchaudio/datasets/yesno.py (new file)
from __future__ import print_function
import torch.utils.data as data
import os
import os.path
import shutil
import errno
import torch
import torchaudio
class YESNO(data.Dataset):
"""`YesNo Hebrew <http://www.openslr.org/1/>`_ Dataset.
Args:
root (string): Root directory of the dataset, under which the processed
``yesno/processed/yesno.pt`` file is stored (or will be created).
download (bool, optional): If true, downloads the dataset from the internet and
puts it in the root directory. If the dataset is already downloaded, it is not
downloaded again.
transform (callable, optional): A function/transform that takes in a raw audio
tensor and returns a transformed version. E.g., ``transforms.Scale``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
dev_mode (bool, optional): if true, clean-up is not performed on downloaded
files. Useful for keeping raw audio and transcriptions.
"""
raw_folder = 'yesno/raw'
processed_folder = 'yesno/processed'
url = 'http://www.openslr.org/resources/1/waves_yesno.tar.gz'
dset_path = 'waves_yesno'
processed_file = 'yesno.pt'
def __init__(self, root, transform=None, target_transform=None, download=False, dev_mode=False):
self.root = os.path.expanduser(root)
self.transform = transform
self.target_transform = target_transform
self.dev_mode = dev_mode
self.data = []
self.labels = []
self.num_samples = 0
self.max_len = 0
if download:
self.download()
if not self._check_exists():
raise RuntimeError('Dataset not found.' +
' You can use download=True to download it')
self.data, self.labels = torch.load(os.path.join(self.root, self.processed_folder, self.processed_file))
def __getitem__(self, index):
"""
Args:
index (int): Index
Returns:
tuple: (audio, target) where target is the list of "0"/"1" tokens parsed from the file name.
"""
audio, target = self.data[index], self.labels[index]
if self.transform is not None:
audio = self.transform(audio)
if self.target_transform is not None:
target = self.target_transform(target)
return audio, target
def __len__(self):
return len(self.data)
def _check_exists(self):
return os.path.exists(os.path.join(self.root, self.processed_folder, self.processed_file))
def download(self):
"""Download the yesno data if it doesn't exist in processed_folder already."""
from six.moves import urllib
import tarfile
if self._check_exists():
return
raw_abs_dir = os.path.join(self.root, self.raw_folder)
processed_abs_dir = os.path.join(self.root, self.processed_folder)
dset_abs_path = os.path.join(self.root, self.raw_folder, self.dset_path)
# download files
try:
os.makedirs(os.path.join(self.root, self.raw_folder))
os.makedirs(os.path.join(self.root, self.processed_folder))
except OSError as e:
if e.errno == errno.EEXIST:
pass
else:
raise
url = self.url
print('Downloading ' + url)
filename = url.rpartition('/')[2]
file_path = os.path.join(self.root, self.raw_folder, filename)
if not os.path.isfile(file_path):
data = urllib.request.urlopen(url)
with open(file_path, 'wb') as f:
f.write(data.read())
else:
print("Tar file already downloaded")
if not os.path.exists(dset_abs_path):
with tarfile.open(file_path) as zip_f:
zip_f.extractall(raw_abs_dir)
else:
print("Tar file already extracted")
if not self.dev_mode:
os.unlink(file_path)
# process and save as torch files
print('Processing...')
shutil.copyfile(
os.path.join(dset_abs_path, "README"),
os.path.join(processed_abs_dir, "YESNO_README")
)
audios = [x for x in os.listdir(dset_abs_path) if ".wav" in x]
print("Found {} audio files".format(len(audios)))
tensors = []
labels = []
lengths = []
for i, f in enumerate(audios):
full_path = os.path.join(dset_abs_path, f)
sig, sr = torchaudio.load(full_path)
tensors.append(sig)
lengths.append(sig.size(0))
labels.append(os.path.basename(f).split(".", 1)[0].split("_"))
# sort sigs/labels: longest -> shortest
tensors, labels = zip(*[(b, c) for (a,b,c) in sorted(zip(lengths, tensors, labels), key=lambda x: x[0], reverse=True)])
self.max_len = tensors[0].size(0)
torch.save(
(tensors, labels),
os.path.join(
self.root,
self.processed_folder,
self.processed_file
)
)
if not self.dev_mode:
shutil.rmtree(raw_abs_dir, ignore_errors=True)
print('Done!')
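# Illustrative usage of the class above (a hedged sketch; "./data" is an
# assumed root). Each target is the list of "0"/"1" tokens that download()
# parses out of the wav file name.
from torchaudio.datasets import YESNO

yesno = YESNO("./data", download=True)
audio, target = yesno[0]
print(audio.size(), target)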
torchaudio/transforms.py (new file)
from __future__ import division
import torch
import numpy as np
class Compose(object):
"""Composes several transforms together.
Args:
transforms (list of ``Transform`` objects): list of transforms to compose.
Example:
>>> transforms.Compose([
>>> transforms.Scale(),
>>> transforms.PadTrim(max_len=16000),
>>> ])
"""
def __init__(self, transforms):
self.transforms = transforms
def __call__(self, audio):
for t in self.transforms:
audio = t(audio)
return audio
class Scale(object):
"""Scale audio tensor from a 32-bit integer (represented as a FloatTensor)
to a floating point number between -1.0 and 1.0. The 32-bit number is the
"bit depth" or "precision", not to be confused with "bit rate".
Args:
factor (float): maximum value of input tensor. default: 32-bit depth, i.e. 2**31
"""
def __init__(self, factor=2**31):
self.factor = factor
def __call__(self, tensor):
"""
Args:
tensor (Tensor): Tensor of audio of size (Samples x Channels)
Returns:
Tensor: Scaled by the scale factor. (default between -1.0 and 1.0)
"""
if isinstance(tensor, (torch.LongTensor, torch.IntTensor)):
tensor = tensor.float()
return tensor / self.factor
class PadTrim(object):
"""Pad or trim a (Samples x Channels) tensor to ``max_len`` samples along the time axis.
Args:
max_len (int): target length in samples
fill_value (number): value used for padding. default: 0
"""
def __init__(self, max_len, fill_value=0):
self.max_len = max_len
self.fill_value = fill_value
def __call__(self, tensor):
"""
Args:
tensor (Tensor): Tensor of audio of size (Samples x Channels)
Returns:
Tensor: (max_len x Channels)
"""
if self.max_len > tensor.size(0):
pad = torch.ones((self.max_len-tensor.size(0),
tensor.size(1))) * self.fill_value
pad = pad.type_as(tensor)
tensor = torch.cat((tensor, pad), dim=0)
elif self.max_len < tensor.size(0):
tensor = tensor[:self.max_len, :]
return tensor
class DownmixMono(object):
"""Downmix any stereo signals to mono
"""
def __init__(self):
pass
def __call__(self, tensor):
"""
Args:
tensor (Tensor): Tensor of audio of size (Samples x Channels)
Returns:
Tensor: (Samples x 1)
"""
if isinstance(tensor, (torch.LongTensor, torch.IntTensor)):
tensor = tensor.float()
if tensor.size(1) > 1:
tensor = torch.mean(tensor.float(), 1, True)
return tensor
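# Illustrative end-to-end run of the three transforms above on a synthetic
# stereo signal (a hedged sketch; the sizes are arbitrary).
import torch
import torchaudio.transforms as T

stereo = ((torch.rand(32000, 2) * 2 - 1) * 2**31).long()   # fake 32-bit samples
pipeline = T.Compose([T.Scale(),                  # ints -> floats in [-1.0, 1.0]
                      T.DownmixMono(),            # (N x 2) -> (N x 1)
                      T.PadTrim(max_len=16000)])  # trim to a fixed length
out = pipeline(stereo)
print(out.size())    # torch.Size([16000, 1])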