Unverified Commit 077a5f4a authored by yangarbiter's avatar yangarbiter Committed by GitHub
Browse files

Add CMUDict dataset (#1627)

parent 83dc5ec7
...@@ -33,6 +33,14 @@ CMUARCTIC ...@@ -33,6 +33,14 @@ CMUARCTIC
:special-members: __getitem__ :special-members: __getitem__
CMUDict
~~~~~~~~~
.. autoclass:: CMUDict
:members:
:special-members: __getitem__
COMMONVOICE COMMONVOICE
~~~~~~~~~~~ ~~~~~~~~~~~
......
import os
from pathlib import Path
from torchaudio.datasets import CMUDict
from torchaudio_unittest.common_utils import (
TempDirMixin,
TorchaudioTestCase,
)
def get_mock_dataset(root_dir, return_punc=False):
"""
root_dir: directory to the mocked dataset
"""
header = [
";;; # CMUdict -- Major Version: 0.07",
";;; ",
";;; # $HeadURL$",
]
puncs = [
"!EXCLAMATION-POINT EH2 K S K L AH0 M EY1 SH AH0 N P OY2 N T",
"\"CLOSE-QUOTE K L OW1 Z K W OW1 T",
"#HASH-MARK HH AE1 M AA2 R K",
"%PERCENT P ER0 S EH1 N T",
"&AMPERSAND AE1 M P ER0 S AE2 N D",
"'END-INNER-QUOTE EH1 N D IH1 N ER0 K W OW1 T",
"(BEGIN-PARENS B IH0 G IH1 N P ER0 EH1 N Z",
")CLOSE-PAREN K L OW1 Z P ER0 EH1 N",
"+PLUS P L UH1 S",
",COMMA K AA1 M AH0",
"--DASH D AE1 SH",
"!EXCLAMATION-POINT EH2 K S K L AH0 M EY1 SH AH0 N P OY2 N T",
"/SLASH S L AE1 SH",
":COLON K OW1 L AH0 N",
";SEMI-COLON S EH1 M IY0 K OW1 L AH0 N",
"?QUESTION-MARK K W EH1 S CH AH0 N M AA1 R K",
"{BRACE B R EY1 S",
"}CLOSE-BRACE K L OW1 Z B R EY1 S",
"...ELLIPSIS IH2 L IH1 P S IH0 S",
]
punc_outputs = [
"!",
"\"",
"#",
"%",
"&",
"'",
"(",
")",
"+",
",",
"--",
"!",
"/",
":",
";",
"?",
"{",
"}",
"...",
]
words = [
"3-D TH R IY1 D IY2",
"'BOUT B AW1 T",
"'CAUSE K AH0 Z",
"'TWAS T W AH1 Z",
"A AH0",
"B B IY1",
"C S IY1",
"D D IY1",
"E IY1",
"F EH1 F",
"G JH IY1",
"H EY1 CH",
"I AY1",
"J JH EY1",
"K K EY1",
"L EH1 L",
"M EH1 M",
"N EH1 N",
"O OW1",
"P P IY1",
"Q K Y UW1",
"R AA1 R",
"S EH1 S",
"T T IY1",
"U Y UW1",
"V V IY1",
"X EH1 K S",
"Y W AY1",
"Z Z IY1",
]
mocked_symbols = [
"AA1",
"AA2",
"AE1",
"AE2",
"AH0",
"AH1",
"AY1",
"B",
"CH",
"D",
"EH1",
"EH2",
"ER0",
"EY1",
"F",
"G",
"HH",
"IH0",
"IH1",
"IY0",
"IY1",
"IY2",
"JH",
"K",
"L",
"M",
"N",
"OW1",
"OY2",
"P",
"R",
"S",
"SH",
"T",
"TH",
"UH1",
"UW0",
"UW1",
"V",
"W",
"Y",
"Z",
]
dict_file = os.path.join(root_dir, "cmudict-0.7b")
symbol_file = os.path.join(root_dir, "cmudict-0.7b.symbols")
with open(dict_file, "w") as fileobj:
for section in [header, puncs, words]:
for line in section:
fileobj.write(line)
fileobj.write("\n")
with open(symbol_file, "w") as txt:
txt.write("\n".join(mocked_symbols))
mocked_data = []
if return_punc:
for i, ent in enumerate(puncs):
_, phones = ent.split(" ")
mocked_data.append((punc_outputs[i], phones.split(" ")))
for ent in words:
word, phones = ent.split(" ")
mocked_data.append((word, phones.split(" ")))
return mocked_data
class TestCMUDict(TempDirMixin, TorchaudioTestCase):
root_dir = None
root_punc_dir = None
samples = []
punc_samples = []
@classmethod
def setUpClass(cls):
cls.root_dir = os.path.join(cls.get_base_temp_dir(), "normal")
os.mkdir(cls.root_dir)
cls.samples = get_mock_dataset(cls.root_dir)
cls.root_punc_dir = os.path.join(cls.get_base_temp_dir(), "punc")
os.mkdir(cls.root_punc_dir)
cls.punc_samples = get_mock_dataset(cls.root_punc_dir, return_punc=True)
def _test_cmudict(self, dataset):
"""Test if the dataset is reading the mocked data correctly."""
n_item = 0
for i, (word, phones) in enumerate(dataset):
expected_word, expected_phones = self.samples[i]
assert word == expected_word
assert phones == expected_phones
n_item += 1
assert n_item == len(self.samples)
def _test_punc_cmudict(self, dataset):
"""Test if the dataset is reading the mocked data with punctuations correctly."""
n_item = 0
for i, (word, phones) in enumerate(dataset):
expected_word, expected_phones = self.punc_samples[i]
assert word == expected_word
assert phones == expected_phones
n_item += 1
assert n_item == len(self.punc_samples)
def test_cmuarctic_path_with_punctuation(self):
dataset = CMUDict(Path(self.root_punc_dir), exclude_punctuations=False)
self._test_punc_cmudict(dataset)
def test_cmuarctic_str_with_punctuation(self):
dataset = CMUDict(self.root_punc_dir, exclude_punctuations=False)
self._test_punc_cmudict(dataset)
def test_cmuarctic_path(self):
dataset = CMUDict(Path(self.root_punc_dir), exclude_punctuations=True)
self._test_cmudict(dataset)
def test_cmuarctic_str(self):
dataset = CMUDict(self.root_punc_dir, exclude_punctuations=True)
self._test_cmudict(dataset)
...@@ -7,9 +7,11 @@ from .gtzan import GTZAN ...@@ -7,9 +7,11 @@ from .gtzan import GTZAN
from .yesno import YESNO from .yesno import YESNO
from .ljspeech import LJSPEECH from .ljspeech import LJSPEECH
from .cmuarctic import CMUARCTIC from .cmuarctic import CMUARCTIC
from .cmudict import CMUDict
from .libritts import LIBRITTS from .libritts import LIBRITTS
from .tedlium import TEDLIUM from .tedlium import TEDLIUM
__all__ = [ __all__ = [
"COMMONVOICE", "COMMONVOICE",
"LIBRISPEECH", "LIBRISPEECH",
...@@ -20,6 +22,7 @@ __all__ = [ ...@@ -20,6 +22,7 @@ __all__ = [
"LJSPEECH", "LJSPEECH",
"GTZAN", "GTZAN",
"CMUARCTIC", "CMUARCTIC",
"CMUDict",
"LIBRITTS", "LIBRITTS",
"diskcache_iterator", "diskcache_iterator",
"bg_iterator", "bg_iterator",
......
import os
import re
from pathlib import Path
from typing import Iterable, Tuple, Union, List
from torch.utils.data import Dataset
from torchaudio.datasets.utils import download_url
_CHECKSUMS = {
"http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-0.7b":
"825f4ebd9183f2417df9f067a9cabe86",
"http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-0.7b.symbols":
"385e490aabc71b48e772118e3d02923e",
}
_PUNCTUATIONS = set([
"!EXCLAMATION-POINT",
"\"CLOSE-QUOTE",
"\"DOUBLE-QUOTE",
"\"END-OF-QUOTE",
"\"END-QUOTE",
"\"IN-QUOTES",
"\"QUOTE",
"\"UNQUOTE",
"#HASH-MARK",
"#POUND-SIGN",
"#SHARP-SIGN",
"%PERCENT",
"&AMPERSAND",
"'END-INNER-QUOTE",
"'END-QUOTE",
"'INNER-QUOTE",
"'QUOTE",
"'SINGLE-QUOTE",
"(BEGIN-PARENS",
"(IN-PARENTHESES",
"(LEFT-PAREN",
"(OPEN-PARENTHESES",
"(PAREN",
"(PARENS",
"(PARENTHESES",
")CLOSE-PAREN",
")CLOSE-PARENTHESES",
")END-PAREN",
")END-PARENS",
")END-PARENTHESES",
")END-THE-PAREN",
")PAREN",
")PARENS",
")RIGHT-PAREN",
")UN-PARENTHESES",
"+PLUS",
",COMMA",
"--DASH",
"-DASH",
"-HYPHEN",
"...ELLIPSIS",
".DECIMAL",
".DOT",
".FULL-STOP",
".PERIOD",
".POINT",
"/SLASH",
":COLON",
";SEMI-COLON",
";SEMI-COLON(1)",
"?QUESTION-MARK",
"{BRACE",
"{LEFT-BRACE",
"{OPEN-BRACE",
"}CLOSE-BRACE",
"}RIGHT-BRACE",
])
def _parse_dictionary(lines: Iterable[str], exclude_punctuations: bool) -> List[str]:
_alt_re = re.compile(r'\([0-9]+\)')
cmudict: List[Tuple[str, List[str]]] = list()
for line in lines:
if not line or line.startswith(';;;'): # ignore comments
continue
word, phones = line.strip().split(' ')
if word in _PUNCTUATIONS:
if exclude_punctuations:
continue
# !EXCLAMATION-POINT -> !
# --DASH -> --
# ...ELLIPSIS -> ...
if word.startswith("..."):
word = "..."
elif word.startswith("--"):
word = "--"
else:
word = word[0]
# if a word have multiple pronunciations, there will be (number) appended to it
# for example, DATAPOINTS and DATAPOINTS(1),
# the regular expression `_alt_re` removes the '(1)' and change the word DATAPOINTS(1) to DATAPOINTS
word = re.sub(_alt_re, '', word)
phones = phones.split(" ")
cmudict.append((word, phones))
return cmudict
class CMUDict(Dataset):
"""Create a Dataset for CMU Pronouncing Dictionary (CMUDict).
Args:
root (str or Path): Path to the directory where the dataset is found or downloaded.
url (str, optional):
The URL to download the dictionary from.
(default: ``"http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-0.7b"``)
url_symbols (str, optional):
The URL to download the list of symbols from.
(default: ``"http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-0.7b.symbols"``)
download (bool, optional):
Whether to download the dataset if it is not found at root path. (default: ``False``).
"""
def __init__(self,
root: Union[str, Path],
url: str = "http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-0.7b",
url_symbols: str = "http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-0.7b.symbols",
download: bool = False,
exclude_punctuations: bool = True) -> None:
self.exclude_punctuations = exclude_punctuations
root = Path(root)
if not os.path.isdir(root):
os.mkdir(root)
if download:
if os.path.isdir(root):
checksum = _CHECKSUMS.get(url, None)
download_url(url, root, hash_value=checksum, hash_type="md5")
checksum = _CHECKSUMS.get(url_symbols, None)
download_url(url_symbols, root, hash_value=checksum, hash_type="md5")
else:
RuntimeError("The argument `root` must be a path to directory, "
f"but '{root}' is passed in instead.")
self._root_path = root
basename = os.path.basename(url)
basename_symbols = os.path.basename(url_symbols)
with open(os.path.join(self._root_path, basename_symbols), "r") as text:
self._symbols = [line.strip() for line in text.readlines()]
with open(os.path.join(self._root_path, basename), "r") as text:
self._dictionary = _parse_dictionary(text.readlines(),
exclude_punctuations=self.exclude_punctuations)
def __getitem__(self, n: int) -> Tuple[str, List[str]]:
"""Load the n-th sample from the dataset.
Args:
n (int): The index of the sample to be loaded.
Returns:
tuple: The corresponding word and phonemes ``(word, [phonemes])``.
"""
return self._dictionary[n]
def __len__(self) -> int:
return len(self._dictionary)
@property
def symbols(self) -> List[str]:
return self._symbols.copy()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment