Unverified Commit 077a5f4a authored by yangarbiter's avatar yangarbiter Committed by GitHub
Browse files

Add CMUDict dataset (#1627)

parent 83dc5ec7
......@@ -33,6 +33,14 @@ CMUARCTIC
:special-members: __getitem__
CMUDict
~~~~~~~~~

.. autoclass:: CMUDict
  :members:
  :special-members: __getitem__

COMMONVOICE
~~~~~~~~~~~
......
import os
from pathlib import Path
from torchaudio.datasets import CMUDict
from torchaudio_unittest.common_utils import (
TempDirMixin,
TorchaudioTestCase,
)
def get_mock_dataset(root_dir, return_punc=False):
    """Write a miniature CMUDict into ``root_dir`` and return the expected samples.

    Two files are created inside ``root_dir``:

    * ``cmudict-0.7b`` -- header comment lines, then the punctuation entries,
      then the word entries, one entry per line.
    * ``cmudict-0.7b.symbols`` -- one phoneme symbol per line.

    Args:
        root_dir (str): Existing directory in which the mocked files are written.
        return_punc (bool, optional): When ``True``, the punctuation entries are
            included in the returned samples (with the headword replaced by the
            bare punctuation mark), ahead of the word entries.
            (default: ``False``)

    Returns:
        list: ``(word, [phonemes])`` tuples, in the order the entries are written.
    """
    header = [
        ";;; # CMUdict -- Major Version: 0.07",
        ";;; ",
        ";;; # $HeadURL$",
    ]
    puncs = [
        "!EXCLAMATION-POINT EH2 K S K L AH0 M EY1 SH AH0 N P OY2 N T",
        "\"CLOSE-QUOTE K L OW1 Z K W OW1 T",
        "#HASH-MARK HH AE1 M AA2 R K",
        "%PERCENT P ER0 S EH1 N T",
        "&AMPERSAND AE1 M P ER0 S AE2 N D",
        "'END-INNER-QUOTE EH1 N D IH1 N ER0 K W OW1 T",
        "(BEGIN-PARENS B IH0 G IH1 N P ER0 EH1 N Z",
        ")CLOSE-PAREN K L OW1 Z P ER0 EH1 N",
        "+PLUS P L UH1 S",
        ",COMMA K AA1 M AH0",
        "--DASH D AE1 SH",
        "!EXCLAMATION-POINT EH2 K S K L AH0 M EY1 SH AH0 N P OY2 N T",
        "/SLASH S L AE1 SH",
        ":COLON K OW1 L AH0 N",
        ";SEMI-COLON S EH1 M IY0 K OW1 L AH0 N",
        "?QUESTION-MARK K W EH1 S CH AH0 N M AA1 R K",
        "{BRACE B R EY1 S",
        "}CLOSE-BRACE K L OW1 Z B R EY1 S",
        "...ELLIPSIS IH2 L IH1 P S IH0 S",
    ]
    # Punctuation mark expected for the entry at the same index in `puncs`.
    punc_outputs = [
        "!",
        "\"",
        "#",
        "%",
        "&",
        "'",
        "(",
        ")",
        "+",
        ",",
        "--",
        "!",
        "/",
        ":",
        ";",
        "?",
        "{",
        "}",
        "...",
    ]
    words = [
        "3-D TH R IY1 D IY2",
        "'BOUT B AW1 T",
        "'CAUSE K AH0 Z",
        "'TWAS T W AH1 Z",
        "A AH0",
        "B B IY1",
        "C S IY1",
        "D D IY1",
        "E IY1",
        "F EH1 F",
        "G JH IY1",
        "H EY1 CH",
        "I AY1",
        "J JH EY1",
        "K K EY1",
        "L EH1 L",
        "M EH1 M",
        "N EH1 N",
        "O OW1",
        "P P IY1",
        "Q K Y UW1",
        "R AA1 R",
        "S EH1 S",
        "T T IY1",
        "U Y UW1",
        "V V IY1",
        "X EH1 K S",
        "Y W AY1",
        "Z Z IY1",
    ]
    mocked_symbols = [
        "AA1",
        "AA2",
        "AE1",
        "AE2",
        "AH0",
        "AH1",
        "AY1",
        "B",
        "CH",
        "D",
        "EH1",
        "EH2",
        "ER0",
        "EY1",
        "F",
        "G",
        "HH",
        "IH0",
        "IH1",
        "IY0",
        "IY1",
        "IY2",
        "JH",
        "K",
        "L",
        "M",
        "N",
        "OW1",
        "OY2",
        "P",
        "R",
        "S",
        "SH",
        "T",
        "TH",
        "UH1",
        "UW0",
        "UW1",
        "V",
        "W",
        "Y",
        "Z",
    ]

    def _split_entry(entry):
        # Only the first run of whitespace separates the headword from its
        # pronunciation; splitting on every space and unpacking into two names
        # raises ValueError for any entry with more than one phoneme.
        headword, pronunciation = entry.split(maxsplit=1)
        return headword, pronunciation.split()

    dict_file = os.path.join(root_dir, "cmudict-0.7b")
    symbol_file = os.path.join(root_dir, "cmudict-0.7b.symbols")

    with open(dict_file, "w") as fileobj:
        for section in (header, puncs, words):
            fileobj.writelines(f"{line}\n" for line in section)
    with open(symbol_file, "w") as txt:
        txt.write("\n".join(mocked_symbols))

    mocked_data = []
    if return_punc:
        for mark, entry in zip(punc_outputs, puncs):
            _, phones = _split_entry(entry)
            mocked_data.append((mark, phones))
    mocked_data.extend(_split_entry(entry) for entry in words)
    return mocked_data
class TestCMUDict(TempDirMixin, TorchaudioTestCase):
    """Check that ``CMUDict`` reads the mocked dictionary files back correctly."""

    # Directory holding the mocked files whose expected samples skip punctuation.
    root_dir = None
    # Directory holding the mocked files whose expected samples include punctuation.
    root_punc_dir = None
    # Expected (word, phonemes) samples without / with the punctuation entries.
    samples = []
    punc_samples = []

    @classmethod
    def setUpClass(cls):
        """Create both mocked dataset directories and record their expected samples."""
        normal_dir = os.path.join(cls.get_base_temp_dir(), "normal")
        cls.root_dir = normal_dir
        os.mkdir(normal_dir)
        cls.samples = get_mock_dataset(normal_dir)

        punc_dir = os.path.join(cls.get_base_temp_dir(), "punc")
        cls.root_punc_dir = punc_dir
        os.mkdir(punc_dir)
        cls.punc_samples = get_mock_dataset(punc_dir, return_punc=True)

    def _assert_entries(self, dataset, expected):
        """Assert that iterating ``dataset`` yields exactly the ``expected`` samples."""
        seen = 0
        for idx, (word, phones) in enumerate(dataset):
            want_word, want_phones = expected[idx]
            assert word == want_word
            assert phones == want_phones
            seen += 1
        # Also catches a dataset that yields fewer entries than expected.
        assert seen == len(expected)

    def _test_cmudict(self, dataset):
        """Test if the dataset is reading the mocked data correctly."""
        self._assert_entries(dataset, self.samples)

    def _test_punc_cmudict(self, dataset):
        """Test if the dataset is reading the mocked data with punctuations correctly."""
        self._assert_entries(dataset, self.punc_samples)

    def test_cmuarctic_path_with_punctuation(self):
        """Root given as ``Path``; punctuation entries kept."""
        self._test_punc_cmudict(
            CMUDict(Path(self.root_punc_dir), exclude_punctuations=False)
        )

    def test_cmuarctic_str_with_punctuation(self):
        """Root given as ``str``; punctuation entries kept."""
        self._test_punc_cmudict(
            CMUDict(self.root_punc_dir, exclude_punctuations=False)
        )

    def test_cmuarctic_path(self):
        """Root given as ``Path``; punctuation entries excluded."""
        self._test_cmudict(
            CMUDict(Path(self.root_punc_dir), exclude_punctuations=True)
        )

    def test_cmuarctic_str(self):
        """Root given as ``str``; punctuation entries excluded."""
        self._test_cmudict(
            CMUDict(self.root_punc_dir, exclude_punctuations=True)
        )
......@@ -7,9 +7,11 @@ from .gtzan import GTZAN
from .yesno import YESNO
from .ljspeech import LJSPEECH
from .cmuarctic import CMUARCTIC
from .cmudict import CMUDict
from .libritts import LIBRITTS
from .tedlium import TEDLIUM
__all__ = [
"COMMONVOICE",
"LIBRISPEECH",
......@@ -20,6 +22,7 @@ __all__ = [
"LJSPEECH",
"GTZAN",
"CMUARCTIC",
"CMUDict",
"LIBRITTS",
"diskcache_iterator",
"bg_iterator",
......
import os
import re
from pathlib import Path
from typing import Iterable, Tuple, Union, List
from torch.utils.data import Dataset
from torchaudio.datasets.utils import download_url
# Hash value handed to the downloader for each known URL (used with hash_type="md5").
_CHECKSUMS = {
    "http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-0.7b": "825f4ebd9183f2417df9f067a9cabe86",
    "http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-0.7b.symbols": "385e490aabc71b48e772118e3d02923e",
}

# Dictionary headwords that spell out a punctuation mark rather than a word.
# Each one starts with the mark it names.
_PUNCTUATIONS = {
    # exclamation point
    "!EXCLAMATION-POINT",
    # double quote
    "\"CLOSE-QUOTE",
    "\"DOUBLE-QUOTE",
    "\"END-OF-QUOTE",
    "\"END-QUOTE",
    "\"IN-QUOTES",
    "\"QUOTE",
    "\"UNQUOTE",
    # hash
    "#HASH-MARK",
    "#POUND-SIGN",
    "#SHARP-SIGN",
    # percent / ampersand
    "%PERCENT",
    "&AMPERSAND",
    # single quote
    "'END-INNER-QUOTE",
    "'END-QUOTE",
    "'INNER-QUOTE",
    "'QUOTE",
    "'SINGLE-QUOTE",
    # opening parenthesis
    "(BEGIN-PARENS",
    "(IN-PARENTHESES",
    "(LEFT-PAREN",
    "(OPEN-PARENTHESES",
    "(PAREN",
    "(PARENS",
    "(PARENTHESES",
    # closing parenthesis
    ")CLOSE-PAREN",
    ")CLOSE-PARENTHESES",
    ")END-PAREN",
    ")END-PARENS",
    ")END-PARENTHESES",
    ")END-THE-PAREN",
    ")PAREN",
    ")PARENS",
    ")RIGHT-PAREN",
    ")UN-PARENTHESES",
    # plus / comma
    "+PLUS",
    ",COMMA",
    # dashes
    "--DASH",
    "-DASH",
    "-HYPHEN",
    # dots
    "...ELLIPSIS",
    ".DECIMAL",
    ".DOT",
    ".FULL-STOP",
    ".PERIOD",
    ".POINT",
    # slash / colon / semicolon / question mark
    "/SLASH",
    ":COLON",
    ";SEMI-COLON",
    ";SEMI-COLON(1)",
    "?QUESTION-MARK",
    # braces
    "{BRACE",
    "{LEFT-BRACE",
    "{OPEN-BRACE",
    "}CLOSE-BRACE",
    "}RIGHT-BRACE",
}
def _parse_dictionary(
    lines: Iterable[str],
    exclude_punctuations: bool,
) -> List[Tuple[str, List[str]]]:
    """Parse the lines of a CMUDict dictionary file into ``(word, phonemes)`` pairs.

    Args:
        lines (Iterable[str]): Raw lines of the dictionary file. Blank lines and
            lines starting with ``;;;`` (comments) are skipped.
        exclude_punctuations (bool): When ``True``, entries whose headword is
            listed in ``_PUNCTUATIONS`` are dropped. When ``False`` they are
            kept and the headword is reduced to the punctuation mark itself.

    Returns:
        List[Tuple[str, List[str]]]: One ``(word, [phonemes])`` tuple per kept
        entry, in input order.

    Raises:
        ValueError: If an entry consists of a headword with no pronunciation.
    """
    alt_re = re.compile(r'\([0-9]+\)')
    cmudict: List[Tuple[str, List[str]]] = []
    for raw_line in lines:
        # Strip first: lines read from a file still carry their newline, so an
        # unstripped blank line would otherwise be treated as an entry.
        line = raw_line.strip()
        if not line or line.startswith(';;;'):  # ignore blanks and comments
            continue
        # The headword ends at the first run of whitespace; the rest of the
        # line is the pronunciation (which itself contains spaces).
        word, pronunciation = line.split(maxsplit=1)
        if word in _PUNCTUATIONS:
            if exclude_punctuations:
                continue
            # !EXCLAMATION-POINT -> !
            # --DASH -> --
            # ...ELLIPSIS -> ...
            if word.startswith("..."):
                word = "..."
            elif word.startswith("--"):
                word = "--"
            else:
                word = word[0]
        # If a word has multiple pronunciations, a "(number)" suffix is appended
        # to it, for example DATAPOINTS and DATAPOINTS(1). Dropping the suffix
        # maps DATAPOINTS(1) back to DATAPOINTS.
        word = alt_re.sub('', word)
        cmudict.append((word, pronunciation.split()))
    return cmudict
class CMUDict(Dataset):
    """Create a Dataset for CMU Pronouncing Dictionary (CMUDict).

    Args:
        root (str or Path): Path to the directory where the dataset is found or downloaded.
        url (str, optional):
            The URL to download the dictionary from.
            (default: ``"http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-0.7b"``)
        url_symbols (str, optional):
            The URL to download the list of symbols from.
            (default: ``"http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-0.7b.symbols"``)
        download (bool, optional):
            Whether to download the dataset if it is not found at root path. (default: ``False``).
        exclude_punctuations (bool, optional):
            When enabled, entries that spell out a punctuation mark are left out of the dictionary.
            When disabled, they are kept with the punctuation mark itself as the word.
            (default: ``True``)
    """

    def __init__(self,
                 root: Union[str, Path],
                 url: str = "http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-0.7b",
                 url_symbols: str = "http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-0.7b.symbols",
                 download: bool = False,
                 exclude_punctuations: bool = True) -> None:

        self.exclude_punctuations = exclude_punctuations

        root = Path(root)
        if not os.path.isdir(root):
            os.mkdir(root)

        if download:
            if os.path.isdir(root):
                checksum = _CHECKSUMS.get(url, None)
                download_url(url, root, hash_value=checksum, hash_type="md5")
                checksum = _CHECKSUMS.get(url_symbols, None)
                download_url(url_symbols, root, hash_value=checksum, hash_type="md5")
            else:
                # The exception used to be constructed and then discarded;
                # it has to be raised to have any effect.
                raise RuntimeError("The argument `root` must be a path to directory, "
                                   f"but '{root}' is passed in instead.")

        self._root_path = root
        # Both files are stored under the last path component of their URL.
        basename = os.path.basename(url)
        basename_symbols = os.path.basename(url_symbols)

        with open(os.path.join(self._root_path, basename_symbols), "r") as text:
            self._symbols = [line.strip() for line in text]

        with open(os.path.join(self._root_path, basename), "r") as text:
            self._dictionary = _parse_dictionary(text,
                                                 exclude_punctuations=self.exclude_punctuations)

    def __getitem__(self, n: int) -> Tuple[str, List[str]]:
        """Load the n-th sample from the dataset.

        Args:
            n (int): The index of the sample to be loaded.

        Returns:
            tuple: The corresponding word and phonemes ``(word, [phonemes])``.
        """
        return self._dictionary[n]

    def __len__(self) -> int:
        """Return the number of entries in the parsed dictionary."""
        return len(self._dictionary)

    @property
    def symbols(self) -> List[str]:
        """list[str]: A copy of the symbols read from the symbols file, one per line of that file."""
        return self._symbols.copy()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment