"vscode:/vscode.git/clone" did not exist on "c480d4e4e213a850cced7758f7b62c20caad8820"
Commit 9dcc7a15 authored by flyingdown's avatar flyingdown
Browse files

init v0.10.0

parent db2b0b79
Pipeline #254 failed with stages
in 0 seconds
#include <torch/script.h>
namespace torchaudio {
namespace {
bool is_sox_available() {
#ifdef INCLUDE_SOX
return true;
#else
return false;
#endif
}
bool is_kaldi_available() {
#ifdef INCLUDE_KALDI
return true;
#else
return false;
#endif
}
} // namespace
TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
m.def("torchaudio::is_sox_available", &is_sox_available);
m.def("torchaudio::is_kaldi_available", &is_kaldi_available);
}
} // namespace torchaudio
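# Once this C++ extension is compiled and loaded (importing torchaudio loads it),
# the ops registered above become callable from Python through torch.ops.
# A minimal sketch (the printed values depend on how the extension was built):
#
#   import torchaudio  # loads the extension
#   import torch
#   print(torch.ops.torchaudio.is_sox_available())
#   print(torch.ops.torchaudio.is_kaldi_available())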
from .commonvoice import COMMONVOICE
from .librispeech import LIBRISPEECH
from .speechcommands import SPEECHCOMMANDS
from .utils import bg_iterator, diskcache_iterator
from .vctk import VCTK, VCTK_092
from .gtzan import GTZAN
from .yesno import YESNO
from .ljspeech import LJSPEECH
from .cmuarctic import CMUARCTIC
from .cmudict import CMUDict
from .librimix import LibriMix
from .libritts import LIBRITTS
from .tedlium import TEDLIUM
__all__ = [
"COMMONVOICE",
"LIBRISPEECH",
"SPEECHCOMMANDS",
"VCTK",
"VCTK_092",
"YESNO",
"LJSPEECH",
"GTZAN",
"CMUARCTIC",
"CMUDict",
"LibriMix",
"LIBRITTS",
"diskcache_iterator",
"bg_iterator",
"TEDLIUM",
]
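# Example usage of this package (an illustrative sketch, not part of the
# original file; "_data" is a hypothetical directory and `download=True`
# requires network access):
#
#   from torchaudio.datasets import YESNO
#   dataset = YESNO("_data", download=True)
#   waveform, sample_rate, labels = dataset[0]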
import os
import csv
from pathlib import Path
from typing import Tuple, Union
import torchaudio
from torch import Tensor
from torch.utils.data import Dataset
from torchaudio.datasets.utils import (
download_url,
extract_archive,
)
URL = "aew"
FOLDER_IN_ARCHIVE = "ARCTIC"
_CHECKSUMS = {
"http://festvox.org/cmu_arctic/packed/cmu_us_aew_arctic.tar.bz2":
"4382b116efcc8339c37e01253cb56295",
"http://festvox.org/cmu_arctic/packed/cmu_us_ahw_arctic.tar.bz2":
"b072d6e961e3f36a2473042d097d6da9",
"http://festvox.org/cmu_arctic/packed/cmu_us_aup_arctic.tar.bz2":
"5301c7aee8919d2abd632e2667adfa7f",
"http://festvox.org/cmu_arctic/packed/cmu_us_awb_arctic.tar.bz2":
"280fdff1e9857119d9a2c57b50e12db7",
"http://festvox.org/cmu_arctic/packed/cmu_us_axb_arctic.tar.bz2":
"5e21cb26c6529c533df1d02ccde5a186",
"http://festvox.org/cmu_arctic/packed/cmu_us_bdl_arctic.tar.bz2":
"b2c3e558f656af2e0a65da0ac0c3377a",
"http://festvox.org/cmu_arctic/packed/cmu_us_clb_arctic.tar.bz2":
"3957c503748e3ce17a3b73c1b9861fb0",
"http://festvox.org/cmu_arctic/packed/cmu_us_eey_arctic.tar.bz2":
"59708e932d27664f9eda3e8e6859969b",
"http://festvox.org/cmu_arctic/packed/cmu_us_fem_arctic.tar.bz2":
"dba4f992ff023347c07c304bf72f4c73",
"http://festvox.org/cmu_arctic/packed/cmu_us_gka_arctic.tar.bz2":
"24a876ea7335c1b0ff21460e1241340f",
"http://festvox.org/cmu_arctic/packed/cmu_us_jmk_arctic.tar.bz2":
"afb69d95f02350537e8a28df5ab6004b",
"http://festvox.org/cmu_arctic/packed/cmu_us_ksp_arctic.tar.bz2":
"4ce5b3b91a0a54b6b685b1b05aa0b3be",
"http://festvox.org/cmu_arctic/packed/cmu_us_ljm_arctic.tar.bz2":
"6f45a3b2c86a4ed0465b353be291f77d",
"http://festvox.org/cmu_arctic/packed/cmu_us_lnh_arctic.tar.bz2":
"c6a15abad5c14d27f4ee856502f0232f",
"http://festvox.org/cmu_arctic/packed/cmu_us_rms_arctic.tar.bz2":
"71072c983df1e590d9e9519e2a621f6e",
"http://festvox.org/cmu_arctic/packed/cmu_us_rxr_arctic.tar.bz2":
"3771ff03a2f5b5c3b53aa0a68b9ad0d5",
"http://festvox.org/cmu_arctic/packed/cmu_us_slp_arctic.tar.bz2":
"9cbf984a832ea01b5058ba9a96862850",
"http://festvox.org/cmu_arctic/packed/cmu_us_slt_arctic.tar.bz2":
"959eecb2cbbc4ac304c6b92269380c81",
}
def load_cmuarctic_item(line: str,
path: str,
folder_audio: str,
ext_audio: str) -> Tuple[Tensor, int, str, str]:
utterance_id, transcript = line[0].strip().split(" ", 2)[1:]
# Strip the leading double quote and the trailing `" )` from the transcript
transcript = transcript[1:-3]
file_audio = os.path.join(path, folder_audio, utterance_id + ext_audio)
# Load audio
waveform, sample_rate = torchaudio.load(file_audio)
return (
waveform,
sample_rate,
transcript,
utterance_id.split("_")[1]
)
class CMUARCTIC(Dataset):
"""Create a Dataset for CMU_ARCTIC.
Args:
root (str or Path): Path to the directory where the dataset is found or downloaded.
url (str, optional):
The URL to download the dataset from or the type of the dataset to download.
(default: ``"aew"``)
Allowed type values are ``"aew"``, ``"ahw"``, ``"aup"``, ``"awb"``, ``"axb"``, ``"bdl"``,
``"clb"``, ``"eey"``, ``"fem"``, ``"gka"``, ``"jmk"``, ``"ksp"``, ``"ljm"``, ``"lnh"``,
``"rms"``, ``"rxr"``, ``"slp"`` or ``"slt"``.
folder_in_archive (str, optional):
The top-level directory of the dataset. (default: ``"ARCTIC"``)
download (bool, optional):
Whether to download the dataset if it is not found at root path. (default: ``False``).
"""
_file_text = "txt.done.data"
_folder_text = "etc"
_ext_audio = ".wav"
_folder_audio = "wav"
def __init__(self,
root: Union[str, Path],
url: str = URL,
folder_in_archive: str = FOLDER_IN_ARCHIVE,
download: bool = False) -> None:
if url in [
"aew",
"ahw",
"aup",
"awb",
"axb",
"bdl",
"clb",
"eey",
"fem",
"gka",
"jmk",
"ksp",
"ljm",
"lnh",
"rms",
"rxr",
"slp",
"slt"
]:
url = "cmu_us_" + url + "_arctic"
ext_archive = ".tar.bz2"
# Host must match the keys of _CHECKSUMS for the checksum lookup below to succeed
base_url = "http://festvox.org/cmu_arctic/packed/"
url = os.path.join(base_url, url + ext_archive)
# Get string representation of 'root' in case Path object is passed
root = os.fspath(root)
basename = os.path.basename(url)
root = os.path.join(root, folder_in_archive)
if not os.path.isdir(root):
os.mkdir(root)
archive = os.path.join(root, basename)
basename = basename.split(".")[0]
self._path = os.path.join(root, basename)
if download:
if not os.path.isdir(self._path):
if not os.path.isfile(archive):
checksum = _CHECKSUMS.get(url, None)
download_url(url, root, hash_value=checksum, hash_type="md5")
extract_archive(archive)
self._text = os.path.join(self._path, self._folder_text, self._file_text)
with open(self._text, "r") as text:
walker = csv.reader(text, delimiter="\n")
self._walker = list(walker)
def __getitem__(self, n: int) -> Tuple[Tensor, int, str, str]:
"""Load the n-th sample from the dataset.
Args:
n (int): The index of the sample to be loaded
Returns:
(Tensor, int, str, str): ``(waveform, sample_rate, transcript, utterance_id)``
"""
line = self._walker[n]
return load_cmuarctic_item(line, self._path, self._folder_audio, self._ext_audio)
def __len__(self) -> int:
return len(self._walker)
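# Example usage (an illustrative sketch; "_data" is a hypothetical, existing
# directory and `download=True` requires network access):
#
#   dataset = CMUARCTIC("_data", url="aew", download=True)
#   waveform, sample_rate, transcript, utterance_id = dataset[0]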
import os
import re
from pathlib import Path
from typing import Iterable, Tuple, Union, List
from torch.utils.data import Dataset
from torchaudio.datasets.utils import download_url
_CHECKSUMS = {
"http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-0.7b":
"825f4ebd9183f2417df9f067a9cabe86",
"http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-0.7b.symbols":
"385e490aabc71b48e772118e3d02923e",
}
_PUNCTUATIONS = set([
"!EXCLAMATION-POINT",
"\"CLOSE-QUOTE",
"\"DOUBLE-QUOTE",
"\"END-OF-QUOTE",
"\"END-QUOTE",
"\"IN-QUOTES",
"\"QUOTE",
"\"UNQUOTE",
"#HASH-MARK",
"#POUND-SIGN",
"#SHARP-SIGN",
"%PERCENT",
"&AMPERSAND",
"'END-INNER-QUOTE",
"'END-QUOTE",
"'INNER-QUOTE",
"'QUOTE",
"'SINGLE-QUOTE",
"(BEGIN-PARENS",
"(IN-PARENTHESES",
"(LEFT-PAREN",
"(OPEN-PARENTHESES",
"(PAREN",
"(PARENS",
"(PARENTHESES",
")CLOSE-PAREN",
")CLOSE-PARENTHESES",
")END-PAREN",
")END-PARENS",
")END-PARENTHESES",
")END-THE-PAREN",
")PAREN",
")PARENS",
")RIGHT-PAREN",
")UN-PARENTHESES",
"+PLUS",
",COMMA",
"--DASH",
"-DASH",
"-HYPHEN",
"...ELLIPSIS",
".DECIMAL",
".DOT",
".FULL-STOP",
".PERIOD",
".POINT",
"/SLASH",
":COLON",
";SEMI-COLON",
";SEMI-COLON(1)",
"?QUESTION-MARK",
"{BRACE",
"{LEFT-BRACE",
"{OPEN-BRACE",
"}CLOSE-BRACE",
"}RIGHT-BRACE",
])
def _parse_dictionary(lines: Iterable[str], exclude_punctuations: bool) -> List[str]:
_alt_re = re.compile(r'\([0-9]+\)')
cmudict: List[Tuple[str, List[str]]] = list()
for line in lines:
if not line or line.startswith(';;;'): # ignore comments
continue
# The word and its phones are separated by two spaces in cmudict-0.7b
word, phones = line.strip().split('  ')
if word in _PUNCTUATIONS:
if exclude_punctuations:
continue
# !EXCLAMATION-POINT -> !
# --DASH -> --
# ...ELLIPSIS -> ...
if word.startswith("..."):
word = "..."
elif word.startswith("--"):
word = "--"
else:
word = word[0]
# If a word has multiple pronunciations, a number in parentheses is appended to it,
# for example, DATAPOINTS and DATAPOINTS(1).
# The regular expression `_alt_re` removes the '(1)', changing DATAPOINTS(1) to DATAPOINTS.
word = re.sub(_alt_re, '', word)
phones = phones.split(" ")
cmudict.append((word, phones))
return cmudict
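# A sketch of the expected input format (cmudict-0.7b separates a word and its
# phones with two spaces):
#
#   _parse_dictionary(['ABATE  AH0 B EY1 T'], exclude_punctuations=True)
#   # -> [('ABATE', ['AH0', 'B', 'EY1', 'T'])]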
class CMUDict(Dataset):
"""Create a Dataset for CMU Pronouncing Dictionary (CMUDict).
Args:
root (str or Path): Path to the directory where the dataset is found or downloaded.
exclude_punctuations (bool, optional):
When enabled, exclude the pronunciations of punctuation marks, such as
`!EXCLAMATION-POINT` and `#HASH-MARK`. (default: ``True``)
download (bool, optional):
Whether to download the dataset if it is not found at root path. (default: ``False``).
url (str, optional):
The URL to download the dictionary from.
(default: ``"http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-0.7b"``)
url_symbols (str, optional):
The URL to download the list of symbols from.
(default: ``"http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-0.7b.symbols"``)
"""
def __init__(self,
root: Union[str, Path],
exclude_punctuations: bool = True,
*,
download: bool = False,
url: str = "http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-0.7b",
url_symbols: str = "http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-0.7b.symbols",
) -> None:
self.exclude_punctuations = exclude_punctuations
self._root_path = Path(root)
if not os.path.isdir(self._root_path):
raise RuntimeError(f'The root directory does not exist; {root}')
dict_file = self._root_path / os.path.basename(url)
symbol_file = self._root_path / os.path.basename(url_symbols)
if not os.path.exists(dict_file):
if not download:
raise RuntimeError(
'The dictionary file is not found in the following location. '
f'Set `download=True` to download it. {dict_file}')
checksum = _CHECKSUMS.get(url, None)
download_url(url, root, hash_value=checksum, hash_type="md5")
if not os.path.exists(symbol_file):
if not download:
raise RuntimeError(
'The symbol file is not found in the following location. '
f'Set `download=True` to download it. {symbol_file}')
checksum = _CHECKSUMS.get(url_symbols, None)
download_url(url_symbols, root, hash_value=checksum, hash_type="md5")
with open(symbol_file, "r") as text:
self._symbols = [line.strip() for line in text.readlines()]
with open(dict_file, "r", encoding='latin-1') as text:
self._dictionary = _parse_dictionary(
text.readlines(), exclude_punctuations=self.exclude_punctuations)
def __getitem__(self, n: int) -> Tuple[str, List[str]]:
"""Load the n-th sample from the dataset.
Args:
n (int): The index of the sample to be loaded.
Returns:
(str, List[str]): The corresponding word and phonemes ``(word, [phonemes])``.
"""
return self._dictionary[n]
def __len__(self) -> int:
return len(self._dictionary)
@property
def symbols(self) -> List[str]:
"""list[str]: A list of phonemes symbols, such as `AA`, `AE`, `AH`.
"""
return self._symbols.copy()
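# Example usage (a sketch; "_data" must already exist and `download=True`
# requires network access):
#
#   cmudict = CMUDict("_data", download=True)
#   word, phonemes = cmudict[0]
#   symbols = cmudict.symbols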
import csv
import os
from pathlib import Path
from typing import List, Dict, Tuple, Union
from torch import Tensor
from torch.utils.data import Dataset
import torchaudio
def load_commonvoice_item(line: List[str],
header: List[str],
path: str,
folder_audio: str,
ext_audio: str) -> Tuple[Tensor, int, Dict[str, str]]:
# Each line has the following data:
# client_id, path, sentence, up_votes, down_votes, age, gender, accent
assert header[1] == "path"
fileid = line[1]
filename = os.path.join(path, folder_audio, fileid)
if not filename.endswith(ext_audio):
filename += ext_audio
waveform, sample_rate = torchaudio.load(filename)
dic = dict(zip(header, line))
return waveform, sample_rate, dic
class COMMONVOICE(Dataset):
"""Create a Dataset for CommonVoice.
Args:
root (str or Path): Path to the directory where the dataset is located.
(Where the ``tsv`` file is present.)
tsv (str, optional):
The name of the tsv file used to construct the metadata, such as
``"train.tsv"``, ``"test.tsv"``, ``"dev.tsv"``, ``"invalidated.tsv"``,
``"validated.tsv"`` and ``"other.tsv"``. (default: ``"train.tsv"``)
"""
_ext_txt = ".txt"
_ext_audio = ".mp3"
_folder_audio = "clips"
def __init__(self,
root: Union[str, Path],
tsv: str = "train.tsv") -> None:
# Get string representation of 'root' in case Path object is passed
self._path = os.fspath(root)
self._tsv = os.path.join(self._path, tsv)
with open(self._tsv, "r") as tsv_:
walker = csv.reader(tsv_, delimiter="\t")
self._header = next(walker)
self._walker = list(walker)
def __getitem__(self, n: int) -> Tuple[Tensor, int, Dict[str, str]]:
"""Load the n-th sample from the dataset.
Args:
n (int): The index of the sample to be loaded
Returns:
(Tensor, int, Dict[str, str]): ``(waveform, sample_rate, dictionary)``, where dictionary
is built from the TSV file with the following keys: ``client_id``, ``path``, ``sentence``,
``up_votes``, ``down_votes``, ``age``, ``gender`` and ``accent``.
"""
line = self._walker[n]
return load_commonvoice_item(line, self._header, self._path, self._folder_audio, self._ext_audio)
def __len__(self) -> int:
return len(self._walker)
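# Example usage (a sketch; the root path is hypothetical and must contain the
# TSV file and the "clips" folder from a manually obtained CommonVoice release):
#
#   dataset = COMMONVOICE("_data/cv-corpus/en", tsv="train.tsv")
#   waveform, sample_rate, meta = dataset[0]
#   print(meta["sentence"])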
import os
from pathlib import Path
from typing import Tuple, Optional, Union
import torchaudio
from torch import Tensor
from torch.utils.data import Dataset
from torchaudio.datasets.utils import (
download_url,
extract_archive,
)
# The following lists prefixed with `filtered_` provide a filtered split
# that:
#
# a. Mitigates a known issue with GTZAN (duplication)
#
# b. Provides a standard split for testing against other
# methods (e.g. the one in jordipons/sklearn-audio-transfer-learning).
#
# These lists are used when GTZAN is initialised with the `subset` keyword.
# The split was taken from (github) jordipons/sklearn-audio-transfer-learning.
gtzan_genres = [
"blues",
"classical",
"country",
"disco",
"hiphop",
"jazz",
"metal",
"pop",
"reggae",
"rock",
]
filtered_test = [
"blues.00012",
"blues.00013",
"blues.00014",
"blues.00015",
"blues.00016",
"blues.00017",
"blues.00018",
"blues.00019",
"blues.00020",
"blues.00021",
"blues.00022",
"blues.00023",
"blues.00024",
"blues.00025",
"blues.00026",
"blues.00027",
"blues.00028",
"blues.00061",
"blues.00062",
"blues.00063",
"blues.00064",
"blues.00065",
"blues.00066",
"blues.00067",
"blues.00068",
"blues.00069",
"blues.00070",
"blues.00071",
"blues.00072",
"blues.00098",
"blues.00099",
"classical.00011",
"classical.00012",
"classical.00013",
"classical.00014",
"classical.00015",
"classical.00016",
"classical.00017",
"classical.00018",
"classical.00019",
"classical.00020",
"classical.00021",
"classical.00022",
"classical.00023",
"classical.00024",
"classical.00025",
"classical.00026",
"classical.00027",
"classical.00028",
"classical.00029",
"classical.00034",
"classical.00035",
"classical.00036",
"classical.00037",
"classical.00038",
"classical.00039",
"classical.00040",
"classical.00041",
"classical.00049",
"classical.00077",
"classical.00078",
"classical.00079",
"country.00030",
"country.00031",
"country.00032",
"country.00033",
"country.00034",
"country.00035",
"country.00036",
"country.00037",
"country.00038",
"country.00039",
"country.00040",
"country.00043",
"country.00044",
"country.00046",
"country.00047",
"country.00048",
"country.00050",
"country.00051",
"country.00053",
"country.00054",
"country.00055",
"country.00056",
"country.00057",
"country.00058",
"country.00059",
"country.00060",
"country.00061",
"country.00062",
"country.00063",
"country.00064",
"disco.00001",
"disco.00021",
"disco.00058",
"disco.00062",
"disco.00063",
"disco.00064",
"disco.00065",
"disco.00066",
"disco.00069",
"disco.00076",
"disco.00077",
"disco.00078",
"disco.00079",
"disco.00080",
"disco.00081",
"disco.00082",
"disco.00083",
"disco.00084",
"disco.00085",
"disco.00086",
"disco.00087",
"disco.00088",
"disco.00091",
"disco.00092",
"disco.00093",
"disco.00094",
"disco.00096",
"disco.00097",
"disco.00099",
"hiphop.00000",
"hiphop.00026",
"hiphop.00027",
"hiphop.00030",
"hiphop.00040",
"hiphop.00043",
"hiphop.00044",
"hiphop.00045",
"hiphop.00051",
"hiphop.00052",
"hiphop.00053",
"hiphop.00054",
"hiphop.00062",
"hiphop.00063",
"hiphop.00064",
"hiphop.00065",
"hiphop.00066",
"hiphop.00067",
"hiphop.00068",
"hiphop.00069",
"hiphop.00070",
"hiphop.00071",
"hiphop.00072",
"hiphop.00073",
"hiphop.00074",
"hiphop.00075",
"hiphop.00099",
"jazz.00073",
"jazz.00074",
"jazz.00075",
"jazz.00076",
"jazz.00077",
"jazz.00078",
"jazz.00079",
"jazz.00080",
"jazz.00081",
"jazz.00082",
"jazz.00083",
"jazz.00084",
"jazz.00085",
"jazz.00086",
"jazz.00087",
"jazz.00088",
"jazz.00089",
"jazz.00090",
"jazz.00091",
"jazz.00092",
"jazz.00093",
"jazz.00094",
"jazz.00095",
"jazz.00096",
"jazz.00097",
"jazz.00098",
"jazz.00099",
"metal.00012",
"metal.00013",
"metal.00014",
"metal.00015",
"metal.00022",
"metal.00023",
"metal.00025",
"metal.00026",
"metal.00027",
"metal.00028",
"metal.00029",
"metal.00030",
"metal.00031",
"metal.00032",
"metal.00033",
"metal.00038",
"metal.00039",
"metal.00067",
"metal.00070",
"metal.00073",
"metal.00074",
"metal.00075",
"metal.00078",
"metal.00083",
"metal.00085",
"metal.00087",
"metal.00088",
"pop.00000",
"pop.00001",
"pop.00013",
"pop.00014",
"pop.00043",
"pop.00063",
"pop.00064",
"pop.00065",
"pop.00066",
"pop.00069",
"pop.00070",
"pop.00071",
"pop.00072",
"pop.00073",
"pop.00074",
"pop.00075",
"pop.00076",
"pop.00077",
"pop.00078",
"pop.00079",
"pop.00082",
"pop.00088",
"pop.00089",
"pop.00090",
"pop.00091",
"pop.00092",
"pop.00093",
"pop.00094",
"pop.00095",
"pop.00096",
"reggae.00034",
"reggae.00035",
"reggae.00036",
"reggae.00037",
"reggae.00038",
"reggae.00039",
"reggae.00040",
"reggae.00046",
"reggae.00047",
"reggae.00048",
"reggae.00052",
"reggae.00053",
"reggae.00064",
"reggae.00065",
"reggae.00066",
"reggae.00067",
"reggae.00068",
"reggae.00071",
"reggae.00079",
"reggae.00082",
"reggae.00083",
"reggae.00084",
"reggae.00087",
"reggae.00088",
"reggae.00089",
"reggae.00090",
"rock.00010",
"rock.00011",
"rock.00012",
"rock.00013",
"rock.00014",
"rock.00015",
"rock.00027",
"rock.00028",
"rock.00029",
"rock.00030",
"rock.00031",
"rock.00032",
"rock.00033",
"rock.00034",
"rock.00035",
"rock.00036",
"rock.00037",
"rock.00039",
"rock.00040",
"rock.00041",
"rock.00042",
"rock.00043",
"rock.00044",
"rock.00045",
"rock.00046",
"rock.00047",
"rock.00048",
"rock.00086",
"rock.00087",
"rock.00088",
"rock.00089",
"rock.00090",
]
filtered_train = [
"blues.00029",
"blues.00030",
"blues.00031",
"blues.00032",
"blues.00033",
"blues.00034",
"blues.00035",
"blues.00036",
"blues.00037",
"blues.00038",
"blues.00039",
"blues.00040",
"blues.00041",
"blues.00042",
"blues.00043",
"blues.00044",
"blues.00045",
"blues.00046",
"blues.00047",
"blues.00048",
"blues.00049",
"blues.00073",
"blues.00074",
"blues.00075",
"blues.00076",
"blues.00077",
"blues.00078",
"blues.00079",
"blues.00080",
"blues.00081",
"blues.00082",
"blues.00083",
"blues.00084",
"blues.00085",
"blues.00086",
"blues.00087",
"blues.00088",
"blues.00089",
"blues.00090",
"blues.00091",
"blues.00092",
"blues.00093",
"blues.00094",
"blues.00095",
"blues.00096",
"blues.00097",
"classical.00030",
"classical.00031",
"classical.00032",
"classical.00033",
"classical.00043",
"classical.00044",
"classical.00045",
"classical.00046",
"classical.00047",
"classical.00048",
"classical.00050",
"classical.00051",
"classical.00052",
"classical.00053",
"classical.00054",
"classical.00055",
"classical.00056",
"classical.00057",
"classical.00058",
"classical.00059",
"classical.00060",
"classical.00061",
"classical.00062",
"classical.00063",
"classical.00064",
"classical.00065",
"classical.00066",
"classical.00067",
"classical.00080",
"classical.00081",
"classical.00082",
"classical.00083",
"classical.00084",
"classical.00085",
"classical.00086",
"classical.00087",
"classical.00088",
"classical.00089",
"classical.00090",
"classical.00091",
"classical.00092",
"classical.00093",
"classical.00094",
"classical.00095",
"classical.00096",
"classical.00097",
"classical.00098",
"classical.00099",
"country.00019",
"country.00020",
"country.00021",
"country.00022",
"country.00023",
"country.00024",
"country.00025",
"country.00026",
"country.00028",
"country.00029",
"country.00065",
"country.00066",
"country.00067",
"country.00068",
"country.00069",
"country.00070",
"country.00071",
"country.00072",
"country.00073",
"country.00074",
"country.00075",
"country.00076",
"country.00077",
"country.00078",
"country.00079",
"country.00080",
"country.00081",
"country.00082",
"country.00083",
"country.00084",
"country.00085",
"country.00086",
"country.00087",
"country.00088",
"country.00089",
"country.00090",
"country.00091",
"country.00092",
"country.00093",
"country.00094",
"country.00095",
"country.00096",
"country.00097",
"country.00098",
"country.00099",
"disco.00005",
"disco.00015",
"disco.00016",
"disco.00017",
"disco.00018",
"disco.00019",
"disco.00020",
"disco.00022",
"disco.00023",
"disco.00024",
"disco.00025",
"disco.00026",
"disco.00027",
"disco.00028",
"disco.00029",
"disco.00030",
"disco.00031",
"disco.00032",
"disco.00033",
"disco.00034",
"disco.00035",
"disco.00036",
"disco.00037",
"disco.00039",
"disco.00040",
"disco.00041",
"disco.00042",
"disco.00043",
"disco.00044",
"disco.00045",
"disco.00047",
"disco.00049",
"disco.00053",
"disco.00054",
"disco.00056",
"disco.00057",
"disco.00059",
"disco.00061",
"disco.00070",
"disco.00073",
"disco.00074",
"disco.00089",
"hiphop.00002",
"hiphop.00003",
"hiphop.00004",
"hiphop.00005",
"hiphop.00006",
"hiphop.00007",
"hiphop.00008",
"hiphop.00009",
"hiphop.00010",
"hiphop.00011",
"hiphop.00012",
"hiphop.00013",
"hiphop.00014",
"hiphop.00015",
"hiphop.00016",
"hiphop.00017",
"hiphop.00018",
"hiphop.00019",
"hiphop.00020",
"hiphop.00021",
"hiphop.00022",
"hiphop.00023",
"hiphop.00024",
"hiphop.00025",
"hiphop.00028",
"hiphop.00029",
"hiphop.00031",
"hiphop.00032",
"hiphop.00033",
"hiphop.00034",
"hiphop.00035",
"hiphop.00036",
"hiphop.00037",
"hiphop.00038",
"hiphop.00041",
"hiphop.00042",
"hiphop.00055",
"hiphop.00056",
"hiphop.00057",
"hiphop.00058",
"hiphop.00059",
"hiphop.00060",
"hiphop.00061",
"hiphop.00077",
"hiphop.00078",
"hiphop.00079",
"hiphop.00080",
"jazz.00000",
"jazz.00001",
"jazz.00011",
"jazz.00012",
"jazz.00013",
"jazz.00014",
"jazz.00015",
"jazz.00016",
"jazz.00017",
"jazz.00018",
"jazz.00019",
"jazz.00020",
"jazz.00021",
"jazz.00022",
"jazz.00023",
"jazz.00024",
"jazz.00041",
"jazz.00047",
"jazz.00048",
"jazz.00049",
"jazz.00050",
"jazz.00051",
"jazz.00052",
"jazz.00053",
"jazz.00054",
"jazz.00055",
"jazz.00056",
"jazz.00057",
"jazz.00058",
"jazz.00059",
"jazz.00060",
"jazz.00061",
"jazz.00062",
"jazz.00063",
"jazz.00064",
"jazz.00065",
"jazz.00066",
"jazz.00067",
"jazz.00068",
"jazz.00069",
"jazz.00070",
"jazz.00071",
"jazz.00072",
"metal.00002",
"metal.00003",
"metal.00005",
"metal.00021",
"metal.00024",
"metal.00035",
"metal.00046",
"metal.00047",
"metal.00048",
"metal.00049",
"metal.00050",
"metal.00051",
"metal.00052",
"metal.00053",
"metal.00054",
"metal.00055",
"metal.00056",
"metal.00057",
"metal.00059",
"metal.00060",
"metal.00061",
"metal.00062",
"metal.00063",
"metal.00064",
"metal.00065",
"metal.00066",
"metal.00069",
"metal.00071",
"metal.00072",
"metal.00079",
"metal.00080",
"metal.00084",
"metal.00086",
"metal.00089",
"metal.00090",
"metal.00091",
"metal.00092",
"metal.00093",
"metal.00094",
"metal.00095",
"metal.00096",
"metal.00097",
"metal.00098",
"metal.00099",
"pop.00002",
"pop.00003",
"pop.00004",
"pop.00005",
"pop.00006",
"pop.00007",
"pop.00008",
"pop.00009",
"pop.00011",
"pop.00012",
"pop.00016",
"pop.00017",
"pop.00018",
"pop.00019",
"pop.00020",
"pop.00023",
"pop.00024",
"pop.00025",
"pop.00026",
"pop.00027",
"pop.00028",
"pop.00029",
"pop.00031",
"pop.00032",
"pop.00033",
"pop.00034",
"pop.00035",
"pop.00036",
"pop.00038",
"pop.00039",
"pop.00040",
"pop.00041",
"pop.00042",
"pop.00044",
"pop.00046",
"pop.00049",
"pop.00050",
"pop.00080",
"pop.00097",
"pop.00098",
"pop.00099",
"reggae.00000",
"reggae.00001",
"reggae.00002",
"reggae.00004",
"reggae.00006",
"reggae.00009",
"reggae.00011",
"reggae.00012",
"reggae.00014",
"reggae.00015",
"reggae.00016",
"reggae.00017",
"reggae.00018",
"reggae.00019",
"reggae.00020",
"reggae.00021",
"reggae.00022",
"reggae.00023",
"reggae.00024",
"reggae.00025",
"reggae.00026",
"reggae.00027",
"reggae.00028",
"reggae.00029",
"reggae.00030",
"reggae.00031",
"reggae.00032",
"reggae.00042",
"reggae.00043",
"reggae.00044",
"reggae.00045",
"reggae.00049",
"reggae.00050",
"reggae.00051",
"reggae.00054",
"reggae.00055",
"reggae.00056",
"reggae.00057",
"reggae.00058",
"reggae.00059",
"reggae.00060",
"reggae.00063",
"reggae.00069",
"rock.00000",
"rock.00001",
"rock.00002",
"rock.00003",
"rock.00004",
"rock.00005",
"rock.00006",
"rock.00007",
"rock.00008",
"rock.00009",
"rock.00016",
"rock.00017",
"rock.00018",
"rock.00019",
"rock.00020",
"rock.00021",
"rock.00022",
"rock.00023",
"rock.00024",
"rock.00025",
"rock.00026",
"rock.00057",
"rock.00058",
"rock.00059",
"rock.00060",
"rock.00061",
"rock.00062",
"rock.00063",
"rock.00064",
"rock.00065",
"rock.00066",
"rock.00067",
"rock.00068",
"rock.00069",
"rock.00070",
"rock.00091",
"rock.00092",
"rock.00093",
"rock.00094",
"rock.00095",
"rock.00096",
"rock.00097",
"rock.00098",
"rock.00099",
]
filtered_valid = [
"blues.00000",
"blues.00001",
"blues.00002",
"blues.00003",
"blues.00004",
"blues.00005",
"blues.00006",
"blues.00007",
"blues.00008",
"blues.00009",
"blues.00010",
"blues.00011",
"blues.00050",
"blues.00051",
"blues.00052",
"blues.00053",
"blues.00054",
"blues.00055",
"blues.00056",
"blues.00057",
"blues.00058",
"blues.00059",
"blues.00060",
"classical.00000",
"classical.00001",
"classical.00002",
"classical.00003",
"classical.00004",
"classical.00005",
"classical.00006",
"classical.00007",
"classical.00008",
"classical.00009",
"classical.00010",
"classical.00068",
"classical.00069",
"classical.00070",
"classical.00071",
"classical.00072",
"classical.00073",
"classical.00074",
"classical.00075",
"classical.00076",
"country.00000",
"country.00001",
"country.00002",
"country.00003",
"country.00004",
"country.00005",
"country.00006",
"country.00007",
"country.00009",
"country.00010",
"country.00011",
"country.00012",
"country.00013",
"country.00014",
"country.00015",
"country.00016",
"country.00017",
"country.00018",
"country.00027",
"country.00041",
"country.00042",
"country.00045",
"country.00049",
"disco.00000",
"disco.00002",
"disco.00003",
"disco.00004",
"disco.00006",
"disco.00007",
"disco.00008",
"disco.00009",
"disco.00010",
"disco.00011",
"disco.00012",
"disco.00013",
"disco.00014",
"disco.00046",
"disco.00048",
"disco.00052",
"disco.00067",
"disco.00068",
"disco.00072",
"disco.00075",
"disco.00090",
"disco.00095",
"hiphop.00081",
"hiphop.00082",
"hiphop.00083",
"hiphop.00084",
"hiphop.00085",
"hiphop.00086",
"hiphop.00087",
"hiphop.00088",
"hiphop.00089",
"hiphop.00090",
"hiphop.00091",
"hiphop.00092",
"hiphop.00093",
"hiphop.00094",
"hiphop.00095",
"hiphop.00096",
"hiphop.00097",
"hiphop.00098",
"jazz.00002",
"jazz.00003",
"jazz.00004",
"jazz.00005",
"jazz.00006",
"jazz.00007",
"jazz.00008",
"jazz.00009",
"jazz.00010",
"jazz.00025",
"jazz.00026",
"jazz.00027",
"jazz.00028",
"jazz.00029",
"jazz.00030",
"jazz.00031",
"jazz.00032",
"metal.00000",
"metal.00001",
"metal.00006",
"metal.00007",
"metal.00008",
"metal.00009",
"metal.00010",
"metal.00011",
"metal.00016",
"metal.00017",
"metal.00018",
"metal.00019",
"metal.00020",
"metal.00036",
"metal.00037",
"metal.00068",
"metal.00076",
"metal.00077",
"metal.00081",
"metal.00082",
"pop.00010",
"pop.00053",
"pop.00055",
"pop.00058",
"pop.00059",
"pop.00060",
"pop.00061",
"pop.00062",
"pop.00081",
"pop.00083",
"pop.00084",
"pop.00085",
"pop.00086",
"reggae.00061",
"reggae.00062",
"reggae.00070",
"reggae.00072",
"reggae.00074",
"reggae.00076",
"reggae.00077",
"reggae.00078",
"reggae.00085",
"reggae.00092",
"reggae.00093",
"reggae.00094",
"reggae.00095",
"reggae.00096",
"reggae.00097",
"reggae.00098",
"reggae.00099",
"rock.00038",
"rock.00049",
"rock.00050",
"rock.00051",
"rock.00052",
"rock.00053",
"rock.00054",
"rock.00055",
"rock.00056",
"rock.00071",
"rock.00072",
"rock.00073",
"rock.00074",
"rock.00075",
"rock.00076",
"rock.00077",
"rock.00078",
"rock.00079",
"rock.00080",
"rock.00081",
"rock.00082",
"rock.00083",
"rock.00084",
"rock.00085",
]
URL = "http://opihi.cs.uvic.ca/sound/genres.tar.gz"
FOLDER_IN_ARCHIVE = "genres"
_CHECKSUMS = {
"http://opihi.cs.uvic.ca/sound/genres.tar.gz": "5b3d6dddb579ab49814ab86dba69e7c7"
}
def load_gtzan_item(fileid: str, path: str, ext_audio: str) -> Tuple[Tensor, int, str]:
"""
Loads a file from the dataset and returns the raw waveform
as a Torch Tensor, its sample rate as an integer, and its
genre as a string.
"""
# Filenames are of the form label.id, e.g. blues.00078
label, _ = fileid.split(".")
# Read wav
file_audio = os.path.join(path, label, fileid + ext_audio)
waveform, sample_rate = torchaudio.load(file_audio)
return waveform, sample_rate, label
class GTZAN(Dataset):
"""Create a Dataset for GTZAN.
Note:
Please see http://marsyas.info/downloads/datasets.html if you are planning to use
this dataset to publish results.
Args:
root (str or Path): Path to the directory where the dataset is found or downloaded.
url (str, optional): The URL to download the dataset from.
(default: ``"http://opihi.cs.uvic.ca/sound/genres.tar.gz"``)
folder_in_archive (str, optional): The top-level directory of the dataset.
download (bool, optional):
Whether to download the dataset if it is not found at root path. (default: ``False``).
subset (str or None, optional): Which subset of the dataset to use.
One of ``"training"``, ``"validation"``, ``"testing"`` or ``None``.
If ``None``, the entire dataset is used. (default: ``None``).
"""
_ext_audio = ".wav"
def __init__(
self,
root: Union[str, Path],
url: str = URL,
folder_in_archive: str = FOLDER_IN_ARCHIVE,
download: bool = False,
subset: Optional[str] = None,
) -> None:
# Get string representation of 'root' in case Path object is passed
root = os.fspath(root)
self.root = root
self.url = url
self.folder_in_archive = folder_in_archive
self.download = download
self.subset = subset
assert subset is None or subset in ["training", "validation", "testing"], (
"When `subset` not None, it must take a value from "
+ "{'training', 'validation', 'testing'}."
)
archive = os.path.basename(url)
archive = os.path.join(root, archive)
self._path = os.path.join(root, folder_in_archive)
if download:
if not os.path.isdir(self._path):
if not os.path.isfile(archive):
checksum = _CHECKSUMS.get(url, None)
download_url(url, root, hash_value=checksum, hash_type="md5")
extract_archive(archive)
if not os.path.isdir(self._path):
raise RuntimeError(
"Dataset not found. Please use `download=True` to download it."
)
if self.subset is None:
# Check every subdirectory under the dataset root
# that has the same name as one of the GTZAN genres
# (e.g. root_dir/blues/, root_dir/rock/, etc.).
# This lets users remove or move around song files,
# which is useful when e.g. they want to use only some
# of the files in a genre or want to label other files
# with a different genre.
self._walker = []
root = os.path.expanduser(self._path)
for directory in gtzan_genres:
fulldir = os.path.join(root, directory)
if not os.path.exists(fulldir):
continue
songs_in_genre = os.listdir(fulldir)
songs_in_genre.sort()
for fname in songs_in_genre:
name, ext = os.path.splitext(fname)
if ext.lower() == ".wav" and "." in name:
# Check whether the file is of the form
# `gtzan_genre`.`5 digit number`.wav
genre, num = name.split(".")
if genre in gtzan_genres and len(num) == 5 and num.isdigit():
self._walker.append(name)
else:
if self.subset == "training":
self._walker = filtered_train
elif self.subset == "validation":
self._walker = filtered_valid
elif self.subset == "testing":
self._walker = filtered_test
def __getitem__(self, n: int) -> Tuple[Tensor, int, str]:
"""Load the n-th sample from the dataset.
Args:
n (int): The index of the sample to be loaded
Returns:
(Tensor, int, str): ``(waveform, sample_rate, label)``
"""
fileid = self._walker[n]
item = load_gtzan_item(fileid, self._path, self._ext_audio)
waveform, sample_rate, label = item
return waveform, sample_rate, label
def __len__(self) -> int:
return len(self._walker)
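# Example usage (an illustrative sketch; "_data" is hypothetical and
# `download=True` requires network access):
#
#   dataset = GTZAN("_data", download=True, subset="training")
#   waveform, sample_rate, label = dataset[0]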
from pathlib import Path
from typing import Union, Tuple, List
import torch
from torch.utils.data import Dataset
import torchaudio
SampleType = Tuple[int, torch.Tensor, List[torch.Tensor]]
class LibriMix(Dataset):
r"""Create the LibriMix dataset.
Args:
root (str or Path): The path to the directory where the directory ``Libri2Mix`` or
``Libri3Mix`` is stored.
subset (str, optional): The subset to use. Options: [``train-360``, ``train-100``,
``dev``, and ``test``] (Default: ``train-360``).
num_speakers (int, optional): The number of speakers, which determines the directories
to traverse. The Dataset will traverse ``s1`` to ``sN`` directories to collect
N source audios. (Default: 2)
sample_rate (int, optional): sample rate of audio files. The ``sample_rate`` determines
which subdirectory the audio files are fetched from. If any audio file has a different
sample rate, a ``ValueError`` is raised. Options: [8000, 16000] (Default: 8000)
task (str, optional): the task of LibriMix.
Options: [``enh_single``, ``enh_both``, ``sep_clean``, ``sep_noisy``]
(Default: ``sep_clean``)
Note:
The LibriMix dataset needs to be manually generated. Please check https://github.com/JorisCos/LibriMix
"""
def __init__(
self,
root: Union[str, Path],
subset: str = "train-360",
num_speakers: int = 2,
sample_rate: int = 8000,
task: str = "sep_clean",
):
self.root = Path(root) / f"Libri{num_speakers}Mix"
if sample_rate == 8000:
self.root = self.root / "wav8k/min" / subset
elif sample_rate == 16000:
self.root = self.root / "wav16k/min" / subset
else:
raise ValueError(
f"Unsupported sample rate. Found {sample_rate}."
)
self.sample_rate = sample_rate
self.task = task
self.mix_dir = (self.root / f"mix_{task.split('_')[1]}").resolve()
self.src_dirs = [(self.root / f"s{i+1}").resolve() for i in range(num_speakers)]
self.files = [p.name for p in self.mix_dir.glob("*wav")]
self.files.sort()
def _load_audio(self, path) -> torch.Tensor:
waveform, sample_rate = torchaudio.load(path)
if sample_rate != self.sample_rate:
raise ValueError(
f"The dataset contains audio file of sample rate {sample_rate}, "
f"but the requested sample rate is {self.sample_rate}."
)
return waveform
def _load_sample(self, filename) -> SampleType:
mixed = self._load_audio(str(self.mix_dir / filename))
srcs = []
for i, dir_ in enumerate(self.src_dirs):
src = self._load_audio(str(dir_ / filename))
if mixed.shape != src.shape:
raise ValueError(
f"Different waveform shapes. mixed: {mixed.shape}, src[{i}]: {src.shape}"
)
srcs.append(src)
return self.sample_rate, mixed, srcs
def __len__(self) -> int:
return len(self.files)
def __getitem__(self, key: int) -> SampleType:
"""Load the n-th sample from the dataset.
Args:
key (int): The index of the sample to be loaded
Returns:
(int, Tensor, List[Tensor]): ``(sample_rate, mix_waveform, list_of_source_waveforms)``
"""
return self._load_sample(self.files[key])
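# Example usage (a sketch; LibriMix has to be generated manually first, so the
# root path below is hypothetical):
#
#   dataset = LibriMix("_data", subset="dev", num_speakers=2, sample_rate=8000, task="sep_clean")
#   sample_rate, mixture, sources = dataset[0]  # `sources` holds one tensor per speaker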
import os
from typing import Tuple, Union
from pathlib import Path
import torchaudio
from torch import Tensor
from torch.utils.data import Dataset
from torchaudio.datasets.utils import (
download_url,
extract_archive,
)
URL = "train-clean-100"
FOLDER_IN_ARCHIVE = "LibriSpeech"
_CHECKSUMS = {
"http://www.openslr.org/resources/12/dev-clean.tar.gz":
"76f87d090650617fca0cac8f88b9416e0ebf80350acb97b343a85fa903728ab3",
"http://www.openslr.org/resources/12/dev-other.tar.gz":
"12661c48e8c3fe1de2c1caa4c3e135193bfb1811584f11f569dd12645aa84365",
"http://www.openslr.org/resources/12/test-clean.tar.gz":
"39fde525e59672dc6d1551919b1478f724438a95aa55f874b576be21967e6c23",
"http://www.openslr.org/resources/12/test-other.tar.gz":
"d09c181bba5cf717b3dee7d4d592af11a3ee3a09e08ae025c5506f6ebe961c29",
"http://www.openslr.org/resources/12/train-clean-100.tar.gz":
"d4ddd1d5a6ab303066f14971d768ee43278a5f2a0aa43dc716b0e64ecbbbf6e2",
"http://www.openslr.org/resources/12/train-clean-360.tar.gz":
"146a56496217e96c14334a160df97fffedd6e0a04e66b9c5af0d40be3c792ecf",
"http://www.openslr.org/resources/12/train-other-500.tar.gz":
"ddb22f27f96ec163645d53215559df6aa36515f26e01dd70798188350adcb6d2"
}
def load_librispeech_item(fileid: str,
path: str,
ext_audio: str,
ext_txt: str) -> Tuple[Tensor, int, str, int, int, int]:
speaker_id, chapter_id, utterance_id = fileid.split("-")
file_text = speaker_id + "-" + chapter_id + ext_txt
file_text = os.path.join(path, speaker_id, chapter_id, file_text)
fileid_audio = speaker_id + "-" + chapter_id + "-" + utterance_id
file_audio = fileid_audio + ext_audio
file_audio = os.path.join(path, speaker_id, chapter_id, file_audio)
# Load audio
waveform, sample_rate = torchaudio.load(file_audio)
# Load text
with open(file_text) as ft:
for line in ft:
fileid_text, transcript = line.strip().split(" ", 1)
if fileid_audio == fileid_text:
break
else:
# Transcript not found
raise FileNotFoundError("Transcript not found for " + fileid_audio)
return (
waveform,
sample_rate,
transcript,
int(speaker_id),
int(chapter_id),
int(utterance_id),
)
class LIBRISPEECH(Dataset):
"""Create a Dataset for LibriSpeech.
Args:
root (str or Path): Path to the directory where the dataset is found or downloaded.
url (str, optional): The URL to download the dataset from,
or the type of the dataset to download.
Allowed type values are ``"dev-clean"``, ``"dev-other"``, ``"test-clean"``,
``"test-other"``, ``"train-clean-100"``, ``"train-clean-360"`` and
``"train-other-500"``. (default: ``"train-clean-100"``)
folder_in_archive (str, optional):
The top-level directory of the dataset. (default: ``"LibriSpeech"``)
download (bool, optional):
Whether to download the dataset if it is not found at root path. (default: ``False``).
"""
_ext_txt = ".trans.txt"
_ext_audio = ".flac"
def __init__(self,
root: Union[str, Path],
url: str = URL,
folder_in_archive: str = FOLDER_IN_ARCHIVE,
download: bool = False) -> None:
if url in [
"dev-clean",
"dev-other",
"test-clean",
"test-other",
"train-clean-100",
"train-clean-360",
"train-other-500",
]:
ext_archive = ".tar.gz"
base_url = "http://www.openslr.org/resources/12/"
url = os.path.join(base_url, url + ext_archive)
# Get string representation of 'root' in case Path object is passed
root = os.fspath(root)
basename = os.path.basename(url)
archive = os.path.join(root, basename)
basename = basename.split(".")[0]
folder_in_archive = os.path.join(folder_in_archive, basename)
self._path = os.path.join(root, folder_in_archive)
if download:
if not os.path.isdir(self._path):
if not os.path.isfile(archive):
checksum = _CHECKSUMS.get(url, None)
download_url(url, root, hash_value=checksum)
extract_archive(archive)
self._walker = sorted(str(p.stem) for p in Path(self._path).glob('*/*/*' + self._ext_audio))
def __getitem__(self, n: int) -> Tuple[Tensor, int, str, int, int, int]:
"""Load the n-th sample from the dataset.
Args:
n (int): The index of the sample to be loaded
Returns:
(Tensor, int, str, int, int, int):
``(waveform, sample_rate, transcript, speaker_id, chapter_id, utterance_id)``
"""
fileid = self._walker[n]
return load_librispeech_item(fileid, self._path, self._ext_audio, self._ext_txt)
def __len__(self) -> int:
return len(self._walker)
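# Example usage (an illustrative sketch; "_data" is hypothetical and
# `download=True` requires network access):
#
#   dataset = LIBRISPEECH("_data", url="dev-clean", download=True)
#   waveform, sample_rate, transcript, speaker_id, chapter_id, utterance_id = dataset[0]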
import os
from typing import Tuple, Union
from pathlib import Path
import torchaudio
from torch import Tensor
from torch.utils.data import Dataset
from torchaudio.datasets.utils import (
download_url,
extract_archive,
)
URL = "train-clean-100"
FOLDER_IN_ARCHIVE = "LibriTTS"
_CHECKSUMS = {
"http://www.openslr.org/60/dev-clean.tar.gz": "0c3076c1e5245bb3f0af7d82087ee207",
"http://www.openslr.org/60/dev-other.tar.gz": "815555d8d75995782ac3ccd7f047213d",
"http://www.openslr.org/60/test-clean.tar.gz": "7bed3bdb047c4c197f1ad3bc412db59f",
"http://www.openslr.org/60/test-other.tar.gz": "ae3258249472a13b5abef2a816f733e4",
"http://www.openslr.org/60/train-clean-100.tar.gz": "4a8c202b78fe1bc0c47916a98f3a2ea8",
"http://www.openslr.org/60/train-clean-360.tar.gz": "a84ef10ddade5fd25df69596a2767b2d",
"http://www.openslr.org/60/train-other-500.tar.gz": "7b181dd5ace343a5f38427999684aa6f",
}
def load_libritts_item(
fileid: str,
path: str,
ext_audio: str,
ext_original_txt: str,
ext_normalized_txt: str,
) -> Tuple[Tensor, int, str, str, int, int, str]:
speaker_id, chapter_id, segment_id, utterance_id = fileid.split("_")
utterance_id = fileid
normalized_text = utterance_id + ext_normalized_txt
normalized_text = os.path.join(path, speaker_id, chapter_id, normalized_text)
original_text = utterance_id + ext_original_txt
original_text = os.path.join(path, speaker_id, chapter_id, original_text)
file_audio = utterance_id + ext_audio
file_audio = os.path.join(path, speaker_id, chapter_id, file_audio)
# Load audio
waveform, sample_rate = torchaudio.load(file_audio)
# Load original text
with open(original_text) as ft:
original_text = ft.readline()
# Load normalized text
with open(normalized_text, "r") as ft:
normalized_text = ft.readline()
return (
waveform,
sample_rate,
original_text,
normalized_text,
int(speaker_id),
int(chapter_id),
utterance_id,
)
class LIBRITTS(Dataset):
"""Create a Dataset for LibriTTS.
Args:
root (str or Path): Path to the directory where the dataset is found or downloaded.
url (str, optional): The URL to download the dataset from,
or the type of the dataset to download.
Allowed type values are ``"dev-clean"``, ``"dev-other"``, ``"test-clean"``,
``"test-other"``, ``"train-clean-100"``, ``"train-clean-360"`` and
``"train-other-500"``. (default: ``"train-clean-100"``)
folder_in_archive (str, optional):
The top-level directory of the dataset. (default: ``"LibriTTS"``)
download (bool, optional):
Whether to download the dataset if it is not found at root path. (default: ``False``).
"""
_ext_original_txt = ".original.txt"
_ext_normalized_txt = ".normalized.txt"
_ext_audio = ".wav"
def __init__(
self,
root: Union[str, Path],
url: str = URL,
folder_in_archive: str = FOLDER_IN_ARCHIVE,
download: bool = False,
) -> None:
if url in [
"dev-clean",
"dev-other",
"test-clean",
"test-other",
"train-clean-100",
"train-clean-360",
"train-other-500",
]:
ext_archive = ".tar.gz"
base_url = "http://www.openslr.org/resources/60/"
url = os.path.join(base_url, url + ext_archive)
# Get string representation of 'root' in case Path object is passed
root = os.fspath(root)
basename = os.path.basename(url)
archive = os.path.join(root, basename)
basename = basename.split(".")[0]
folder_in_archive = os.path.join(folder_in_archive, basename)
self._path = os.path.join(root, folder_in_archive)
if download:
if not os.path.isdir(self._path):
if not os.path.isfile(archive):
checksum = _CHECKSUMS.get(url, None)
download_url(url, root, hash_value=checksum)
extract_archive(archive)
self._walker = sorted(str(p.stem) for p in Path(self._path).glob('*/*/*' + self._ext_audio))
def __getitem__(self, n: int) -> Tuple[Tensor, int, str, str, int, int, str]:
"""Load the n-th sample from the dataset.
Args:
n (int): The index of the sample to be loaded
Returns:
(Tensor, int, str, str, int, int, str):
``(waveform, sample_rate, original_text, normalized_text, speaker_id, chapter_id, utterance_id)``
"""
fileid = self._walker[n]
return load_libritts_item(
fileid,
self._path,
self._ext_audio,
self._ext_original_txt,
self._ext_normalized_txt,
)
def __len__(self) -> int:
return len(self._walker)
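# Example usage (a sketch along the same lines; "_data" is hypothetical and
# `download=True` requires network access):
#
#   dataset = LIBRITTS("_data", url="dev-clean", download=True)
#   (waveform, sample_rate, original_text, normalized_text,
#    speaker_id, chapter_id, utterance_id) = dataset[0]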
import os
import csv
from typing import Tuple, Union
from pathlib import Path
import torchaudio
from torchaudio.datasets.utils import download_url, extract_archive
from torch import Tensor
from torch.utils.data import Dataset
_RELEASE_CONFIGS = {
"release1": {
"folder_in_archive": "wavs",
"url": "https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2",
"checksum": "be1a30453f28eb8dd26af4101ae40cbf2c50413b1bb21936cbcdc6fae3de8aa5",
}
}
class LJSPEECH(Dataset):
"""Create a Dataset for LJSpeech-1.1.
Args:
root (str or Path): Path to the directory where the dataset is found or downloaded.
url (str, optional): The URL to download the dataset from.
(default: ``"https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2"``)
folder_in_archive (str, optional):
The top-level directory of the dataset. (default: ``"wavs"``)
download (bool, optional):
Whether to download the dataset if it is not found at root path. (default: ``False``).
"""
def __init__(self,
root: Union[str, Path],
url: str = _RELEASE_CONFIGS["release1"]["url"],
folder_in_archive: str = _RELEASE_CONFIGS["release1"]["folder_in_archive"],
download: bool = False) -> None:
self._parse_filesystem(root, url, folder_in_archive, download)
def _parse_filesystem(self, root: str, url: str, folder_in_archive: str, download: bool) -> None:
root = Path(root)
basename = os.path.basename(url)
archive = root / basename
basename = Path(basename.split(".tar.bz2")[0])
folder_in_archive = basename / folder_in_archive
self._path = root / folder_in_archive
self._metadata_path = root / basename / 'metadata.csv'
if download:
if not os.path.isdir(self._path):
if not os.path.isfile(archive):
checksum = _RELEASE_CONFIGS["release1"]["checksum"]
download_url(url, root, hash_value=checksum)
extract_archive(archive)
with open(self._metadata_path, "r", newline='') as metadata:
flist = csv.reader(metadata, delimiter="|", quoting=csv.QUOTE_NONE)
self._flist = list(flist)
def __getitem__(self, n: int) -> Tuple[Tensor, int, str, str]:
"""Load the n-th sample from the dataset.
Args:
n (int): The index of the sample to be loaded
Returns:
(Tensor, int, str, str):
``(waveform, sample_rate, transcript, normalized_transcript)``
"""
line = self._flist[n]
fileid, transcript, normalized_transcript = line
fileid_audio = self._path / (fileid + ".wav")
# Load audio
waveform, sample_rate = torchaudio.load(fileid_audio)
return (
waveform,
sample_rate,
transcript,
normalized_transcript,
)
def __len__(self) -> int:
return len(self._flist)
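# Example usage (an illustrative sketch; "_data" is hypothetical and
# `download=True` requires network access):
#
#   dataset = LJSPEECH("_data", download=True)
#   waveform, sample_rate, transcript, normalized_transcript = dataset[0]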
import os
from typing import Tuple, Optional, Union
from pathlib import Path
import torchaudio
from torch.utils.data import Dataset
from torch import Tensor
from torchaudio.datasets.utils import (
download_url,
extract_archive,
)
FOLDER_IN_ARCHIVE = "SpeechCommands"
URL = "speech_commands_v0.02"
HASH_DIVIDER = "_nohash_"
EXCEPT_FOLDER = "_background_noise_"
_CHECKSUMS = {
"https://storage.googleapis.com/download.tensorflow.org/data/speech_commands_v0.01.tar.gz":
"3cd23799cb2bbdec517f1cc028f8d43c",
"https://storage.googleapis.com/download.tensorflow.org/data/speech_commands_v0.02.tar.gz":
"6b74f3901214cb2c2934e98196829835",
}
def _load_list(root, *filenames):
output = []
for filename in filenames:
filepath = os.path.join(root, filename)
with open(filepath) as fileobj:
output += [os.path.normpath(os.path.join(root, line.strip())) for line in fileobj]
return output
def load_speechcommands_item(filepath: str, path: str) -> Tuple[Tensor, int, str, str, int]:
relpath = os.path.relpath(filepath, path)
label, filename = os.path.split(relpath)
# Besides the officially supported split method for datasets defined by "validation_list.txt"
# and "testing_list.txt" over "speech_commands_v0.0x.tar.gz" archives, an alternative split
# method referred to in paragraph 2-3 of Section 7.1, references 13 and 14 of the original
# paper, and the checksums file from the tensorflow_datasets package [1] is also supported.
# Some filenames in those "speech_commands_test_set_v0.0x.tar.gz" archives have the form
# "xxx.wav.wav", so the file extension needs to be stripped twice.
# [1] https://github.com/tensorflow/datasets/blob/master/tensorflow_datasets/url_checksums/speech_commands.txt
speaker, _ = os.path.splitext(filename)
speaker, _ = os.path.splitext(speaker)
speaker_id, utterance_number = speaker.split(HASH_DIVIDER)
utterance_number = int(utterance_number)
# Load audio
waveform, sample_rate = torchaudio.load(filepath)
return waveform, sample_rate, label, speaker_id, utterance_number
class SPEECHCOMMANDS(Dataset):
"""Create a Dataset for Speech Commands.
Args:
root (str or Path): Path to the directory where the dataset is found or downloaded.
url (str, optional): The URL to download the dataset from,
or the type of the dataset to download.
Allowed type values are ``"speech_commands_v0.01"`` and ``"speech_commands_v0.02"``
(default: ``"speech_commands_v0.02"``)
folder_in_archive (str, optional):
The top-level directory of the dataset. (default: ``"SpeechCommands"``)
download (bool, optional):
Whether to download the dataset if it is not found at root path. (default: ``False``).
subset (str or None, optional):
Select a subset of the dataset [None, "training", "validation", "testing"]. None means
the whole dataset. "validation" and "testing" are defined in "validation_list.txt" and
"testing_list.txt", respectively, and "training" is the rest. Details for the files
"validation_list.txt" and "testing_list.txt" are explained in the README of the dataset
and in the introduction of Section 7 of the original paper and its reference 12. The
original paper can be found `here <https://arxiv.org/pdf/1804.03209.pdf>`_. (Default: ``None``)
"""
def __init__(self,
root: Union[str, Path],
url: str = URL,
folder_in_archive: str = FOLDER_IN_ARCHIVE,
download: bool = False,
subset: Optional[str] = None,
) -> None:
assert subset is None or subset in ["training", "validation", "testing"], (
"When `subset` not None, it must take a value from "
+ "{'training', 'validation', 'testing'}."
)
if url in [
"speech_commands_v0.01",
"speech_commands_v0.02",
]:
base_url = "https://storage.googleapis.com/download.tensorflow.org/data/"
ext_archive = ".tar.gz"
url = os.path.join(base_url, url + ext_archive)
# Get string representation of 'root' in case Path object is passed
root = os.fspath(root)
basename = os.path.basename(url)
archive = os.path.join(root, basename)
basename = basename.rsplit(".", 2)[0]
folder_in_archive = os.path.join(folder_in_archive, basename)
self._path = os.path.join(root, folder_in_archive)
if download:
if not os.path.isdir(self._path):
if not os.path.isfile(archive):
checksum = _CHECKSUMS.get(url, None)
download_url(url, root, hash_value=checksum, hash_type="md5")
extract_archive(archive, self._path)
if subset == "validation":
self._walker = _load_list(self._path, "validation_list.txt")
elif subset == "testing":
self._walker = _load_list(self._path, "testing_list.txt")
elif subset == "training":
excludes = set(_load_list(self._path, "validation_list.txt", "testing_list.txt"))
walker = sorted(str(p) for p in Path(self._path).glob('*/*.wav'))
self._walker = [
w for w in walker
if HASH_DIVIDER in w
and EXCEPT_FOLDER not in w
and os.path.normpath(w) not in excludes
]
else:
walker = sorted(str(p) for p in Path(self._path).glob('*/*.wav'))
self._walker = [w for w in walker if HASH_DIVIDER in w and EXCEPT_FOLDER not in w]
def __getitem__(self, n: int) -> Tuple[Tensor, int, str, str, int]:
"""Load the n-th sample from the dataset.
Args:
n (int): The index of the sample to be loaded
Returns:
(Tensor, int, str, str, int):
``(waveform, sample_rate, label, speaker_id, utterance_number)``
"""
fileid = self._walker[n]
return load_speechcommands_item(fileid, self._path)
def __len__(self) -> int:
return len(self._walker)
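# Example usage (a sketch; "_data" is hypothetical and `download=True`
# requires network access):
#
#   dataset = SPEECHCOMMANDS("_data", download=True, subset="testing")
#   waveform, sample_rate, label, speaker_id, utterance_number = dataset[0]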
import os
from typing import Optional, Tuple, Union
from pathlib import Path
import torchaudio
from torch import Tensor
from torch.utils.data import Dataset
from torchaudio.datasets.utils import (
download_url,
extract_archive,
)
_RELEASE_CONFIGS = {
"release1": {
"folder_in_archive": "TEDLIUM_release1",
"url": "http://www.openslr.org/resources/7/TEDLIUM_release1.tar.gz",
"checksum": "30301975fd8c5cac4040c261c0852f57cfa8adbbad2ce78e77e4986957445f27",
"data_path": "",
"subset": "train",
"supported_subsets": ["train", "test", "dev"],
"dict": "TEDLIUM.150K.dic",
},
"release2": {
"folder_in_archive": "TEDLIUM_release2",
"url": "http://www.openslr.org/resources/19/TEDLIUM_release2.tar.gz",
"checksum": "93281b5fcaaae5c88671c9d000b443cb3c7ea3499ad12010b3934ca41a7b9c58",
"data_path": "",
"subset": "train",
"supported_subsets": ["train", "test", "dev"],
"dict": "TEDLIUM.152k.dic",
},
"release3": {
"folder_in_archive": "TEDLIUM_release-3",
"url": "http://www.openslr.org/resources/51/TEDLIUM_release-3.tgz",
"checksum": "ad1e454d14d1ad550bc2564c462d87c7a7ec83d4dc2b9210f22ab4973b9eccdb",
"data_path": "data/",
"subset": None,
"supported_subsets": [None],
"dict": "TEDLIUM.152k.dic",
},
}
class TEDLIUM(Dataset):
"""
Create a Dataset for Tedlium. It supports releases 1, 2 and 3.
Args:
root (str or Path): Path to the directory where the dataset is found or downloaded.
release (str, optional): Release version.
Allowed values are ``"release1"``, ``"release2"`` or ``"release3"``.
(default: ``"release1"``).
subset (str, optional): The subset of the dataset to use. Valid options are ``"train"``,
``"dev"`` and ``"test"`` for releases 1 and 2, and ``None`` for release 3.
(default: ``"train"`` for releases 1 and 2, ``None`` for release 3).
download (bool, optional):
Whether to download the dataset if it is not found at root path. (default: ``False``).
audio_ext (str, optional): extension for audio files. (default: ``".sph"``)
"""
def __init__(
self,
root: Union[str, Path],
release: str = "release1",
subset: Optional[str] = None,
download: bool = False,
audio_ext: str = ".sph"
) -> None:
self._ext_audio = audio_ext
if release in _RELEASE_CONFIGS.keys():
folder_in_archive = _RELEASE_CONFIGS[release]["folder_in_archive"]
url = _RELEASE_CONFIGS[release]["url"]
subset = subset if subset else _RELEASE_CONFIGS[release]["subset"]
else:
raise RuntimeError(
"The release {} does not match any of the supported tedlium releases: {}".format(
release, _RELEASE_CONFIGS.keys(),
)
)
if subset not in _RELEASE_CONFIGS[release]["supported_subsets"]:
raise RuntimeError(
"The subset {} does not match any of the supported tedlium subsets: {}".format(
subset, _RELEASE_CONFIGS[release]["supported_subsets"],
)
)
# Get string representation of 'root' in case Path object is passed
root = os.fspath(root)
basename = os.path.basename(url)
archive = os.path.join(root, basename)
basename = basename.split(".")[0]
self._path = os.path.join(root, folder_in_archive, _RELEASE_CONFIGS[release]["data_path"])
if subset in ["train", "dev", "test"]:
self._path = os.path.join(self._path, subset)
if download:
if not os.path.isdir(self._path):
if not os.path.isfile(archive):
checksum = _RELEASE_CONFIGS[release]["checksum"]
download_url(url, root, hash_value=checksum)
extract_archive(archive)
# Create list for all samples
self._filelist = []
stm_dir = os.path.join(self._path, "stm")
for file in sorted(os.listdir(stm_dir)):
if file.endswith(".stm"):
stm_path = os.path.join(stm_dir, file)
with open(stm_path) as f:
num_lines = len(f.readlines())
file = file.replace(".stm", "")
self._filelist.extend((file, line) for line in range(num_lines))
# Create dict path for later read
self._dict_path = os.path.join(root, folder_in_archive, _RELEASE_CONFIGS[release]["dict"])
self._phoneme_dict = None
def _load_tedlium_item(self, fileid: str, line: int, path: str) -> Tuple[Tensor, int, str, str, str, str]:
"""Loads a TEDLIUM dataset sample given a file name and corresponding sentence name.
Args:
fileid (str): File id to identify both text and audio files corresponding to the sample
line (int): Line identifier for the sample inside the text file
path (str): Dataset root path
Returns:
(Tensor, int, str, str, str, str):
``(waveform, sample_rate, transcript, talk_id, speaker_id, identifier)``
"""
transcript_path = os.path.join(path, "stm", fileid)
with open(transcript_path + ".stm") as f:
transcript = f.readlines()[line]
talk_id, _, speaker_id, start_time, end_time, identifier, transcript = transcript.split(" ", 6)
wave_path = os.path.join(path, "sph", fileid)
waveform, sample_rate = self._load_audio(wave_path + self._ext_audio, start_time=start_time, end_time=end_time)
return (waveform, sample_rate, transcript, talk_id, speaker_id, identifier)
def _load_audio(self, path: str, start_time: float, end_time: float, sample_rate: int = 16000) -> Tuple[Tensor, int]:
"""Default load function used in TEDLIUM dataset, you can overwrite this function to customize functionality
and load individual sentences from a full ted audio talk file.
Args:
path (str): Path to audio file
start_time (float): Time in seconds where the sample sentence starts
end_time (float): Time in seconds where the sample sentence finishes
sample_rate (int, optional): Sampling rate (default: ``16000``)
Returns:
(Tensor, int): Audio tensor representation and sample rate
"""
start_time = int(float(start_time) * sample_rate)
end_time = int(float(end_time) * sample_rate)
kwargs = {"frame_offset": start_time, "num_frames": end_time - start_time}
return torchaudio.load(path, **kwargs)
def __getitem__(self, n: int) -> Tuple[Tensor, int, str, str, str, str]:
"""Load the n-th sample from the dataset.
Args:
n (int): The index of the sample to be loaded
Returns:
tuple: ``(waveform, sample_rate, transcript, talk_id, speaker_id, identifier)``
"""
fileid, line = self._filelist[n]
return self._load_tedlium_item(fileid, line, self._path)
def __len__(self) -> int:
"""TEDLIUM dataset custom function overwritting len default behaviour.
Returns:
int: TEDLIUM dataset length
"""
return len(self._filelist)
@property
def phoneme_dict(self):
"""dict[str, tuple[str]]: Phonemes. Mapping from word to tuple of phonemes.
Note that some words have empty phonemes.
"""
# Read phoneme dictionary
if not self._phoneme_dict:
self._phoneme_dict = {}
with open(self._dict_path, "r", encoding="utf-8") as f:
for line in f.readlines():
content = line.strip().split()
self._phoneme_dict[content[0]] = tuple(content[1:]) # content[1:] can be empty list
return self._phoneme_dict.copy()
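# ---------------------------------------------------------------------------
# Usage sketch (illustration only, not part of the original source), assuming
# "./data" is a writable directory; the helper name and the looked-up word
# below are hypothetical.
def _example_tedlium_usage():
    dataset = TEDLIUM(root="./data", release="release1", subset="train", download=True)
    waveform, sample_rate, transcript, talk_id, speaker_id, identifier = dataset[0]
    print(talk_id, sample_rate, transcript.strip())
    # The phoneme dictionary maps words to tuples of phonemes (possibly empty).
    print(dataset.phoneme_dict.get("HELLO"))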
import hashlib
import logging
import os
import tarfile
import threading
import urllib
import urllib.request
import zipfile
from queue import Queue
from typing import Any, Iterable, List, Optional
import torch
from torch.utils.data import Dataset
from torch.utils.model_zoo import tqdm
from torchaudio._internal.module_utils import deprecated
def stream_url(url: str,
start_byte: Optional[int] = None,
block_size: int = 32 * 1024,
progress_bar: bool = True) -> Iterable:
"""Stream url by chunk
Args:
url (str): Url.
start_byte (int or None, optional): Start streaming at that point (Default: ``None``).
block_size (int, optional): Size of chunks to stream (Default: ``32 * 1024``).
progress_bar (bool, optional): Display a progress bar (Default: ``True``).
"""
# If we already have the whole file, there is no need to download it again
req = urllib.request.Request(url, method="HEAD")
with urllib.request.urlopen(req) as response:
url_size = int(response.info().get("Content-Length", -1))
if url_size == start_byte:
return
req = urllib.request.Request(url)
if start_byte:
req.headers["Range"] = "bytes={}-".format(start_byte)
with urllib.request.urlopen(req) as upointer, tqdm(
unit="B",
unit_scale=True,
unit_divisor=1024,
total=url_size,
disable=not progress_bar,
) as pbar:
num_bytes = 0
while True:
chunk = upointer.read(block_size)
if not chunk:
break
yield chunk
num_bytes += len(chunk)
pbar.update(len(chunk))
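# Usage sketch (illustration only, not part of the original source): consuming
# the generator above chunk by chunk to save a remote file. The URL and output
# file name are placeholders.
def _example_stream_url():
    with open("download.bin", "wb") as fpointer:
        for chunk in stream_url("https://example.com/file.bin", progress_bar=False):
            fpointer.write(chunk)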
def download_url(url: str,
download_folder: str,
filename: Optional[str] = None,
hash_value: Optional[str] = None,
hash_type: str = "sha256",
progress_bar: bool = True,
resume: bool = False) -> None:
"""Download file to disk.
Args:
url (str): Url.
download_folder (str): Folder to download file.
filename (str or None, optional): Name of downloaded file. If None, it is inferred from the url
(Default: ``None``).
hash_value (str or None, optional): Hash for url (Default: ``None``).
hash_type (str, optional): Hash type, among "sha256" and "md5" (Default: ``"sha256"``).
progress_bar (bool, optional): Display a progress bar (Default: ``True``).
resume (bool, optional): Enable resuming download (Default: ``False``).
"""
req = urllib.request.Request(url, method="HEAD")
req_info = urllib.request.urlopen(req).info()
# Detect filename
filename = filename or req_info.get_filename() or os.path.basename(url)
filepath = os.path.join(download_folder, filename)
if resume and os.path.exists(filepath):
mode = "ab"
local_size: Optional[int] = os.path.getsize(filepath)
elif not resume and os.path.exists(filepath):
raise RuntimeError(
"{} already exists. Delete the file manually and retry.".format(filepath)
)
else:
mode = "wb"
local_size = None
if hash_value and local_size == int(req_info.get("Content-Length", -1)):
with open(filepath, "rb") as file_obj:
if validate_file(file_obj, hash_value, hash_type):
return
raise RuntimeError(
"The hash of {} does not match. Delete the file manually and retry.".format(
filepath
)
)
with open(filepath, mode) as fpointer:
for chunk in stream_url(url, start_byte=local_size, progress_bar=progress_bar):
fpointer.write(chunk)
with open(filepath, "rb") as file_obj:
if hash_value and not validate_file(file_obj, hash_value, hash_type):
raise RuntimeError(
"The hash of {} does not match. Delete the file manually and retry.".format(
filepath
)
)
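# Usage sketch (illustration only, not part of the original source): a resumable
# download with hash verification. The URL and digest below are placeholders; a
# real call needs the actual sha256 of the target file.
def _example_download_url():
    download_url(
        "https://example.com/archive.tar.gz",
        download_folder=".",
        hash_value="0123456789abcdef" * 4,  # placeholder sha256 digest
        hash_type="sha256",
        resume=True,
    )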
def validate_file(file_obj: Any, hash_value: str, hash_type: str = "sha256") -> bool:
"""Validate a given file object with its hash.
Args:
file_obj: File object to read from.
hash_value (str): Hash for url.
hash_type (str, optional): Hash type, among "sha256" and "md5" (Default: ``"sha256"``).
Returns:
        bool: returns True if it's a valid file, else False.
"""
if hash_type == "sha256":
hash_func = hashlib.sha256()
elif hash_type == "md5":
hash_func = hashlib.md5()
    else:
        raise ValueError("Unsupported hash type: {}".format(hash_type))
while True:
# Read by chunk to avoid filling memory
chunk = file_obj.read(1024 ** 2)
if not chunk:
break
hash_func.update(chunk)
return hash_func.hexdigest() == hash_value
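# Usage sketch (illustration only, not part of the original source): computing a
# reference digest with hashlib, then checking the same file against it via
# validate_file. The path is a placeholder.
def _example_validate_file(path="archive.tar.gz"):
    with open(path, "rb") as file_obj:
        expected = hashlib.sha256(file_obj.read()).hexdigest()
    with open(path, "rb") as file_obj:
        assert validate_file(file_obj, expected, "sha256")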
def extract_archive(from_path: str, to_path: Optional[str] = None, overwrite: bool = False) -> List[str]:
"""Extract archive.
Args:
from_path (str): the path of the archive.
        to_path (str or None, optional): the root path of the extracted files (directory of from_path)
(Default: ``None``)
overwrite (bool, optional): overwrite existing files (Default: ``False``)
Returns:
List[str]: List of paths to extracted files even if not overwritten.
Examples:
>>> url = 'http://www.quest.dcs.shef.ac.uk/wmt16_files_mmt/validation.tar.gz'
>>> from_path = './validation.tar.gz'
>>> to_path = './'
>>> torchaudio.datasets.utils.download_from_url(url, from_path)
>>> torchaudio.datasets.utils.extract_archive(from_path, to_path)
"""
if to_path is None:
to_path = os.path.dirname(from_path)
try:
with tarfile.open(from_path, "r") as tar:
logging.info("Opened tar file {}.".format(from_path))
files = []
for file_ in tar: # type: Any
file_path = os.path.join(to_path, file_.name)
if file_.isfile():
files.append(file_path)
if os.path.exists(file_path):
logging.info("{} already extracted.".format(file_path))
if not overwrite:
continue
tar.extract(file_, to_path)
return files
except tarfile.ReadError:
pass
try:
with zipfile.ZipFile(from_path, "r") as zfile:
logging.info("Opened zip file {}.".format(from_path))
files = zfile.namelist()
for file_ in files:
file_path = os.path.join(to_path, file_)
if os.path.exists(file_path):
logging.info("{} already extracted.".format(file_path))
if not overwrite:
continue
zfile.extract(file_, to_path)
return files
except zipfile.BadZipFile:
pass
    raise NotImplementedError("We currently only support tar.gz, tgz, and zip archives.")
class _DiskCache(Dataset):
"""
Wrap a dataset so that, whenever a new item is returned, it is saved to disk.
"""
def __init__(self, dataset: Dataset, location: str = ".cached") -> None:
self.dataset = dataset
self.location = location
self._id = id(self)
self._cache: List = [None] * len(dataset)
def __getitem__(self, n: int) -> Any:
if self._cache[n]:
f = self._cache[n]
return torch.load(f)
f = str(self._id) + "-" + str(n)
f = os.path.join(self.location, f)
item = self.dataset[n]
self._cache[n] = f
os.makedirs(self.location, exist_ok=True)
torch.save(item, f)
return item
def __len__(self) -> int:
return len(self.dataset)
@deprecated('', version='0.11')
def diskcache_iterator(dataset: Dataset, location: str = ".cached") -> Dataset:
return _DiskCache(dataset, location)
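# Usage sketch (illustration only, not part of the original source): wrapping a
# toy dataset so each item is saved to disk on first access and loaded back
# afterwards. Note the wrapper is deprecated and slated for removal in 0.11.
def _example_diskcache():
    class _ToyDataset(Dataset):
        def __getitem__(self, n):
            return torch.full((2, 4), float(n))
        def __len__(self):
            return 8
    cached = diskcache_iterator(_ToyDataset(), location=".cached")
    first = cached[0]   # computed by the wrapped dataset, then saved to disk
    again = cached[0]   # loaded back from the on-disk copy
    assert torch.equal(first, again)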
class _ThreadedIterator(threading.Thread):
"""
    Prefetch the next ``maxsize`` items from the iterator in a background thread.
    Example:
    >>> for i in bg_iterator(range(10)):
    ...     print(i)
"""
class _End:
pass
def __init__(self, generator: Iterable, maxsize: int) -> None:
threading.Thread.__init__(self)
self.queue: Queue = Queue(maxsize)
self.generator = generator
self.daemon = True
self.start()
def run(self) -> None:
for item in self.generator:
self.queue.put(item)
self.queue.put(self._End)
def __iter__(self) -> Any:
return self
def __next__(self) -> Any:
next_item = self.queue.get()
if next_item == self._End:
raise StopIteration
return next_item
# Required for Python 2.7 compatibility
def next(self) -> Any:
return self.__next__()
@deprecated('', version='0.11')
def bg_iterator(iterable: Iterable, maxsize: int) -> Any:
return _ThreadedIterator(iterable, maxsize=maxsize)
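# Usage sketch (illustration only, not part of the original source): prefetching
# items from an iterable in a background thread while the main thread consumes
# them; like diskcache_iterator, this helper is deprecated.
def _example_bg_iterator():
    for item in bg_iterator(range(10), maxsize=2):
        print(item)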
import os
import warnings
from pathlib import Path
from typing import Tuple, Union
from torch import Tensor
from torch.utils.data import Dataset
import torchaudio
from torchaudio.datasets.utils import (
download_url,
extract_archive,
)
URL = "https://datashare.is.ed.ac.uk/bitstream/handle/10283/3443/VCTK-Corpus-0.92.zip"
FOLDER_IN_ARCHIVE = "VCTK-Corpus"
_CHECKSUMS = {
"https://datashare.is.ed.ac.uk/bitstream/handle/10283/3443/VCTK-Corpus-0.92.zip": "8a6ba2946b36fcbef0212cad601f4bfa"
}
def load_vctk_item(fileid: str,
path: str,
ext_audio: str,
ext_txt: str,
folder_audio: str,
folder_txt: str,
downsample: bool = False) -> Tuple[Tensor, int, str, str, str]:
speaker_id, utterance_id = fileid.split("_")
# Read text
file_txt = os.path.join(path, folder_txt, speaker_id, fileid + ext_txt)
with open(file_txt) as file_text:
utterance = file_text.readlines()[0]
# Read wav
file_audio = os.path.join(path, folder_audio, speaker_id, fileid + ext_audio)
waveform, sample_rate = torchaudio.load(file_audio)
if downsample:
# TODO Remove this parameter after deprecation
F = torchaudio.functional
T = torchaudio.transforms
# rate
sample = T.Resample(sample_rate, 16000, resampling_method='sinc_interpolation')
waveform = sample(waveform)
# dither
waveform = F.dither(waveform, noise_shaping=True)
return waveform, sample_rate, utterance, speaker_id, utterance_id
class VCTK(Dataset):
"""Create a Dataset for VCTK.
Note:
* **This dataset is no longer publicly available.** Please use :py:class:`VCTK_092`
        * Directory ``p315`` is ignored because there are no corresponding text files.
For more information about the dataset visit: https://datashare.is.ed.ac.uk/handle/10283/3443
Args:
root (str or Path): Path to the directory where the dataset is found or downloaded.
url (str, optional): Not used as the dataset is no longer publicly available.
folder_in_archive (str, optional):
The top-level directory of the dataset. (default: ``"VCTK-Corpus"``)
download (bool, optional):
Whether to download the dataset if it is not found at root path. (default: ``False``).
            Giving ``download=True`` will result in an error as the dataset is no longer
publicly available.
downsample (bool, optional): Not used.
"""
_folder_txt = "txt"
_folder_audio = "wav48"
_ext_txt = ".txt"
_ext_audio = ".wav"
_except_folder = "p315"
def __init__(self,
root: Union[str, Path],
url: str = URL,
folder_in_archive: str = FOLDER_IN_ARCHIVE,
download: bool = False,
downsample: bool = False) -> None:
warnings.warn(
'VCTK class has been deprecated and will be removed in 0.11 release. '
'Please use VCTK_092.'
)
if downsample:
warnings.warn(
"In the next version, transforms will not be part of the dataset. "
"Please use `downsample=False` to enable this behavior now, "
"and suppress this warning."
)
self.downsample = downsample
# Get string representation of 'root' in case Path object is passed
root = os.fspath(root)
archive = os.path.basename(url)
archive = os.path.join(root, archive)
self._path = os.path.join(root, folder_in_archive)
if download:
raise RuntimeError(
"This Dataset is no longer available. "
"Please use `VCTK_092` class to download the latest version."
)
if not os.path.isdir(self._path):
raise RuntimeError(
"Dataset not found. Please use `VCTK_092` class "
"with `download=True` to donwload the latest version."
)
walker = sorted(str(p.stem) for p in Path(self._path).glob('**/*' + self._ext_audio))
walker = filter(lambda w: self._except_folder not in w, walker)
self._walker = list(walker)
def __getitem__(self, n: int) -> Tuple[Tensor, int, str, str, str]:
"""Load the n-th sample from the dataset.
Args:
n (int): The index of the sample to be loaded
Returns:
tuple: ``(waveform, sample_rate, utterance, speaker_id, utterance_id)``
"""
fileid = self._walker[n]
item = load_vctk_item(
fileid,
self._path,
self._ext_audio,
self._ext_txt,
self._folder_audio,
self._folder_txt,
)
# TODO Upon deprecation, uncomment line below and remove following code
# return item
waveform, sample_rate, utterance, speaker_id, utterance_id = item
return waveform, sample_rate, utterance, speaker_id, utterance_id
def __len__(self) -> int:
return len(self._walker)
SampleType = Tuple[Tensor, int, str, str, str]
class VCTK_092(Dataset):
"""Create VCTK 0.92 Dataset
Args:
root (str): Root directory where the dataset's top level directory is found.
mic_id (str, optional): Microphone ID. Either ``"mic1"`` or ``"mic2"``. (default: ``"mic2"``)
download (bool, optional):
Whether to download the dataset if it is not found at root path. (default: ``False``).
url (str, optional): The URL to download the dataset from.
(default: ``"https://datashare.is.ed.ac.uk/bitstream/handle/10283/3443/VCTK-Corpus-0.92.zip"``)
        audio_ext (str, optional): Custom audio extension if the dataset has been converted to a non-default audio format.
Note:
* All the speeches from speaker ``p315`` will be skipped due to the lack of the corresponding text files.
* All the speeches from ``p280`` will be skipped for ``mic_id="mic2"`` due to the lack of the audio files.
* Some of the speeches from speaker ``p362`` will be skipped due to the lack of the audio files.
* See Also: https://datashare.is.ed.ac.uk/handle/10283/3443
"""
def __init__(
self,
root: str,
mic_id: str = "mic2",
download: bool = False,
url: str = URL,
audio_ext=".flac",
):
if mic_id not in ["mic1", "mic2"]:
raise RuntimeError(
f'`mic_id` has to be either "mic1" or "mic2". Found: {mic_id}'
)
archive = os.path.join(root, "VCTK-Corpus-0.92.zip")
self._path = os.path.join(root, "VCTK-Corpus-0.92")
self._txt_dir = os.path.join(self._path, "txt")
self._audio_dir = os.path.join(self._path, "wav48_silence_trimmed")
self._mic_id = mic_id
self._audio_ext = audio_ext
if download:
if not os.path.isdir(self._path):
if not os.path.isfile(archive):
checksum = _CHECKSUMS.get(url, None)
download_url(url, root, hash_value=checksum, hash_type="md5")
extract_archive(archive, self._path)
if not os.path.isdir(self._path):
raise RuntimeError(
"Dataset not found. Please use `download=True` to download it."
)
# Extracting speaker IDs from the folder structure
self._speaker_ids = sorted(os.listdir(self._txt_dir))
self._sample_ids = []
"""
Due to some insufficient data complexity in the 0.92 version of this dataset,
we start traversing the audio folder structure in accordance with the text folder.
As some of the audio files are missing of either ``mic_1`` or ``mic_2`` but the
text is present for the same, we first check for the existence of the audio file
before adding it to the ``sample_ids`` list.
Once the ``audio_ids`` are loaded into memory we can quickly access the list for
different parameters required by the user.
"""
for speaker_id in self._speaker_ids:
if speaker_id == "p280" and mic_id == "mic2":
continue
utterance_dir = os.path.join(self._txt_dir, speaker_id)
for utterance_file in sorted(
f for f in os.listdir(utterance_dir) if f.endswith(".txt")
):
utterance_id = os.path.splitext(utterance_file)[0]
audio_path_mic = os.path.join(
self._audio_dir,
speaker_id,
f"{utterance_id}_{mic_id}{self._audio_ext}",
)
if speaker_id == "p362" and not os.path.isfile(audio_path_mic):
continue
self._sample_ids.append(utterance_id.split("_"))
    def _load_text(self, file_path) -> str:
        with open(file_path) as f:
            return f.readlines()[0]
def _load_audio(self, file_path) -> Tuple[Tensor, int]:
return torchaudio.load(file_path)
def _load_sample(self, speaker_id: str, utterance_id: str, mic_id: str) -> SampleType:
transcript_path = os.path.join(
self._txt_dir, speaker_id, f"{speaker_id}_{utterance_id}.txt"
)
audio_path = os.path.join(
self._audio_dir,
speaker_id,
f"{speaker_id}_{utterance_id}_{mic_id}{self._audio_ext}",
)
# Reading text
transcript = self._load_text(transcript_path)
# Reading FLAC
waveform, sample_rate = self._load_audio(audio_path)
return (waveform, sample_rate, transcript, speaker_id, utterance_id)
def __getitem__(self, n: int) -> SampleType:
"""Load the n-th sample from the dataset.
Args:
n (int): The index of the sample to be loaded
Returns:
(Tensor, int, str, str, str):
``(waveform, sample_rate, transcript, speaker_id, utterance_id)``
"""
speaker_id, utterance_id = self._sample_ids[n]
return self._load_sample(speaker_id, utterance_id, self._mic_id)
def __len__(self) -> int:
return len(self._sample_ids)
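# ---------------------------------------------------------------------------
# Usage sketch (illustration only, not part of the original source), assuming
# "./data" is a writable directory.
def _example_vctk_092():
    dataset = VCTK_092(root="./data", mic_id="mic2", download=True)
    waveform, sample_rate, transcript, speaker_id, utterance_id = dataset[0]
    print(speaker_id, utterance_id, sample_rate, transcript.strip())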
import os
from pathlib import Path
from typing import List, Tuple, Union
from torch import Tensor
from torch.utils.data import Dataset
import torchaudio
from torchaudio.datasets.utils import (
download_url,
extract_archive,
)
_RELEASE_CONFIGS = {
"release1": {
"folder_in_archive": "waves_yesno",
"url": "http://www.openslr.org/resources/1/waves_yesno.tar.gz",
"checksum": "c3f49e0cca421f96b75b41640749167b52118f232498667ca7a5f9416aef8e73",
}
}
class YESNO(Dataset):
"""Create a Dataset for YesNo.
Args:
root (str or Path): Path to the directory where the dataset is found or downloaded.
url (str, optional): The URL to download the dataset from.
(default: ``"http://www.openslr.org/resources/1/waves_yesno.tar.gz"``)
folder_in_archive (str, optional):
The top-level directory of the dataset. (default: ``"waves_yesno"``)
download (bool, optional):
Whether to download the dataset if it is not found at root path. (default: ``False``).
"""
def __init__(
self,
root: Union[str, Path],
url: str = _RELEASE_CONFIGS["release1"]["url"],
folder_in_archive: str = _RELEASE_CONFIGS["release1"]["folder_in_archive"],
download: bool = False
) -> None:
self._parse_filesystem(root, url, folder_in_archive, download)
def _parse_filesystem(self, root: str, url: str, folder_in_archive: str, download: bool) -> None:
root = Path(root)
archive = os.path.basename(url)
archive = root / archive
self._path = root / folder_in_archive
if download:
if not os.path.isdir(self._path):
if not os.path.isfile(archive):
checksum = _RELEASE_CONFIGS["release1"]["checksum"]
download_url(url, root, hash_value=checksum)
extract_archive(archive)
if not os.path.isdir(self._path):
raise RuntimeError(
"Dataset not found. Please use `download=True` to download it."
)
self._walker = sorted(str(p.stem) for p in Path(self._path).glob("*.wav"))
def _load_item(self, fileid: str, path: str):
labels = [int(c) for c in fileid.split("_")]
file_audio = os.path.join(path, fileid + ".wav")
waveform, sample_rate = torchaudio.load(file_audio)
return waveform, sample_rate, labels
def __getitem__(self, n: int) -> Tuple[Tensor, int, List[int]]:
"""Load the n-th sample from the dataset.
Args:
n (int): The index of the sample to be loaded
Returns:
(Tensor, int, List[int]): ``(waveform, sample_rate, labels)``
"""
fileid = self._walker[n]
item = self._load_item(fileid, self._path)
return item
def __len__(self) -> int:
return len(self._walker)
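# ---------------------------------------------------------------------------
# Usage sketch (illustration only, not part of the original source), assuming
# "./data" is a writable directory. Each YesNo file name encodes its labels,
# e.g. "0_1_0_1_1_0_1_1.wav" yields labels [0, 1, 0, 1, 1, 0, 1, 1].
def _example_yesno():
    dataset = YESNO(root="./data", download=True)
    waveform, sample_rate, labels = dataset[0]
    print(sample_rate, labels)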
from .functional import (
amplitude_to_DB,
angle,
complex_norm,
compute_deltas,
compute_kaldi_pitch,
create_dct,
create_fb_matrix,
melscale_fbanks,
linear_fbanks,
DB_to_amplitude,
detect_pitch_frequency,
inverse_spectrogram,
griffinlim,
magphase,
mask_along_axis,
mask_along_axis_iid,
mu_law_encoding,
mu_law_decoding,
phase_vocoder,
sliding_window_cmn,
spectrogram,
spectral_centroid,
apply_codec,
resample,
edit_distance,
pitch_shift,
rnnt_loss,
)
from .filtering import (
allpass_biquad,
band_biquad,
bandpass_biquad,
bandreject_biquad,
bass_biquad,
biquad,
contrast,
dither,
dcshift,
deemph_biquad,
equalizer_biquad,
filtfilt,
flanger,
gain,
highpass_biquad,
lfilter,
lowpass_biquad,
overdrive,
phaser,
riaa_biquad,
treble_biquad,
vad,
)
__all__ = [
'amplitude_to_DB',
'angle',
'complex_norm',
'compute_deltas',
'compute_kaldi_pitch',
'create_dct',
'create_fb_matrix',
'melscale_fbanks',
'linear_fbanks',
'DB_to_amplitude',
'detect_pitch_frequency',
'griffinlim',
'magphase',
'mask_along_axis',
'mask_along_axis_iid',
'mu_law_encoding',
'mu_law_decoding',
'phase_vocoder',
'sliding_window_cmn',
'spectrogram',
'inverse_spectrogram',
'spectral_centroid',
'allpass_biquad',
'band_biquad',
'bandpass_biquad',
'bandreject_biquad',
'bass_biquad',
'biquad',
'contrast',
'dither',
'dcshift',
'deemph_biquad',
'equalizer_biquad',
'filtfilt',
'flanger',
'gain',
'highpass_biquad',
'lfilter',
'lowpass_biquad',
'overdrive',
'phaser',
'riaa_biquad',
'treble_biquad',
'vad',
'apply_codec',
'resample',
'edit_distance',
'pitch_shift',
'rnnt_loss',
]
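# ---------------------------------------------------------------------------
# Usage sketch (illustration only, not part of the original source): a couple
# of the functions re-exported above applied to a synthetic signal. The helper
# name is hypothetical.
def _example_functional_usage():
    import torch
    waveform = torch.rand(1, 16000) * 2 - 1   # one second of noise in [-1, 1]
    quieter = gain(waveform, gain_db=-6.0)    # attenuate by 6 dB
    filtered = lowpass_biquad(quieter, sample_rate=16000, cutoff_freq=4000.0)
    print(filtered.shape)                     # torch.Size([1, 16000])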