Commit ab97afa0 authored by moto's avatar moto
Browse files

Clean up constructor of CMUDict (#1852)

parent f18d01a0
...@@ -108,49 +108,57 @@ class CMUDict(Dataset): ...@@ -108,49 +108,57 @@ class CMUDict(Dataset):
Args: Args:
root (str or Path): Path to the directory where the dataset is found or downloaded. root (str or Path): Path to the directory where the dataset is found or downloaded.
exclude_punctuations (bool, optional):
When enabled, exclude the pronounciation of punctuations, such as
`!EXCLAMATION-POINT` and `#HASH-MARK`.
download (bool, optional):
Whether to download the dataset if it is not found at root path. (default: ``False``).
url (str, optional): url (str, optional):
The URL to download the dictionary from. The URL to download the dictionary from.
(default: ``"http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-0.7b"``) (default: ``"http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-0.7b"``)
url_symbols (str, optional): url_symbols (str, optional):
The URL to download the list of symbols from. The URL to download the list of symbols from.
(default: ``"http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-0.7b.symbols"``) (default: ``"http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-0.7b.symbols"``)
download (bool, optional):
Whether to download the dataset if it is not found at root path. (default: ``False``).
""" """
def __init__(self, def __init__(self,
root: Union[str, Path], root: Union[str, Path],
exclude_punctuations: bool = True,
*,
download: bool = False,
url: str = "http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-0.7b", url: str = "http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-0.7b",
url_symbols: str = "http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-0.7b.symbols", url_symbols: str = "http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-0.7b.symbols",
download: bool = False, ) -> None:
exclude_punctuations: bool = True) -> None:
self.exclude_punctuations = exclude_punctuations self.exclude_punctuations = exclude_punctuations
root = Path(root) self._root_path = Path(root)
if not os.path.isdir(root): if not os.path.isdir(self._root_path):
os.mkdir(root) raise RuntimeError(f'The root directory does not exist; {root}')
if download: dict_file = self._root_path / os.path.basename(url)
if os.path.isdir(root): symbol_file = self._root_path / os.path.basename(url_symbols)
checksum = _CHECKSUMS.get(url, None) if not os.path.exists(dict_file):
download_url(url, root, hash_value=checksum, hash_type="md5") if not download:
checksum = _CHECKSUMS.get(url_symbols, None) raise RuntimeError(
download_url(url_symbols, root, hash_value=checksum, hash_type="md5") 'The dictionary file is not found in the following location. '
else: f'Set `download=True` to download it. {dict_file}')
RuntimeError("The argument `root` must be a path to directory, " checksum = _CHECKSUMS.get(url, None)
f"but '{root}' is passed in instead.") download_url(url, root, hash_value=checksum, hash_type="md5")
if not os.path.exists(symbol_file):
self._root_path = root if not download:
basename = os.path.basename(url) raise RuntimeError(
basename_symbols = os.path.basename(url_symbols) 'The symbol file is not found in the following location. '
f'Set `download=True` to download it. {symbol_file}')
with open(os.path.join(self._root_path, basename_symbols), "r") as text: checksum = _CHECKSUMS.get(url_symbols, None)
download_url(url_symbols, root, hash_value=checksum, hash_type="md5")
with open(symbol_file, "r") as text:
self._symbols = [line.strip() for line in text.readlines()] self._symbols = [line.strip() for line in text.readlines()]
with open(os.path.join(self._root_path, basename), "r", encoding='latin-1') as text: with open(dict_file, "r", encoding='latin-1') as text:
self._dictionary = _parse_dictionary(text.readlines(), self._dictionary = _parse_dictionary(
exclude_punctuations=self.exclude_punctuations) text.readlines(), exclude_punctuations=self.exclude_punctuations)
def __getitem__(self, n: int) -> Tuple[str, List[str]]: def __getitem__(self, n: int) -> Tuple[str, List[str]]:
"""Load the n-th sample from the dataset. """Load the n-th sample from the dataset.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment