Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
ab97afa0
Commit
ab97afa0
authored
Oct 11, 2021
by
moto
Browse files
Clean up constructor of CMUDict (#1852)
parent
f18d01a0
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
34 additions
and
26 deletions
+34
-26
torchaudio/datasets/cmudict.py
torchaudio/datasets/cmudict.py
+34
-26
No files found.
torchaudio/datasets/cmudict.py
View file @
ab97afa0
...
@@ -108,49 +108,57 @@ class CMUDict(Dataset):
...
@@ -108,49 +108,57 @@ class CMUDict(Dataset):
Args:
Args:
root (str or Path): Path to the directory where the dataset is found or downloaded.
root (str or Path): Path to the directory where the dataset is found or downloaded.
exclude_punctuations (bool, optional):
When enabled, exclude the pronunciation of punctuations, such as
`!EXCLAMATION-POINT` and `#HASH-MARK`.
download (bool, optional):
Whether to download the dataset if it is not found at root path. (default: ``False``).
url (str, optional):
url (str, optional):
The URL to download the dictionary from.
The URL to download the dictionary from.
(default: ``"http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-0.7b"``)
(default: ``"http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-0.7b"``)
url_symbols (str, optional):
url_symbols (str, optional):
The URL to download the list of symbols from.
The URL to download the list of symbols from.
(default: ``"http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-0.7b.symbols"``)
(default: ``"http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-0.7b.symbols"``)
download (bool, optional):
Whether to download the dataset if it is not found at root path. (default: ``False``).
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
root
:
Union
[
str
,
Path
],
root
:
Union
[
str
,
Path
],
exclude_punctuations
:
bool
=
True
,
*
,
download
:
bool
=
False
,
url
:
str
=
"http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-0.7b"
,
url
:
str
=
"http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-0.7b"
,
url_symbols
:
str
=
"http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-0.7b.symbols"
,
url_symbols
:
str
=
"http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-0.7b.symbols"
,
download
:
bool
=
False
,
)
->
None
:
exclude_punctuations
:
bool
=
True
)
->
None
:
self
.
exclude_punctuations
=
exclude_punctuations
self
.
exclude_punctuations
=
exclude_punctuations
root
=
Path
(
root
)
self
.
_root_path
=
Path
(
root
)
if
not
os
.
path
.
isdir
(
root
):
if
not
os
.
path
.
isdir
(
self
.
_root_path
):
os
.
mkdir
(
root
)
raise
RuntimeError
(
f
'The root directory does not exist;
{
root
}
'
)
if
download
:
dict_file
=
self
.
_root_path
/
os
.
path
.
basename
(
url
)
if
os
.
path
.
isdir
(
root
):
symbol_file
=
self
.
_root_path
/
os
.
path
.
basename
(
url_symbols
)
checksum
=
_CHECKSUMS
.
get
(
url
,
None
)
if
not
os
.
path
.
exists
(
dict_file
):
download_url
(
url
,
root
,
hash_value
=
checksum
,
hash_type
=
"md5"
)
if
not
download
:
checksum
=
_CHECKSUMS
.
get
(
url_symbols
,
None
)
raise
RuntimeError
(
download_url
(
url_symbols
,
root
,
hash_value
=
checksum
,
hash_type
=
"md5"
)
'The dictionary file is not found in the following location. '
else
:
f
'Set `download=True` to download it.
{
dict_file
}
'
)
RuntimeError
(
"The argument `root` must be a path to directory, "
checksum
=
_CHECKSUMS
.
get
(
url
,
None
)
f
"but '
{
root
}
' is passed in instead."
)
download_url
(
url
,
root
,
hash_value
=
checksum
,
hash_type
=
"md5"
)
if
not
os
.
path
.
exists
(
symbol_file
):
self
.
_root_path
=
root
if
not
download
:
basename
=
os
.
path
.
basename
(
url
)
raise
RuntimeError
(
basename_symbols
=
os
.
path
.
basename
(
url_symbols
)
'The symbol file is not found in the following location. '
f
'Set `download=True` to download it.
{
symbol_file
}
'
)
with
open
(
os
.
path
.
join
(
self
.
_root_path
,
basename_symbols
),
"r"
)
as
text
:
checksum
=
_CHECKSUMS
.
get
(
url_symbols
,
None
)
download_url
(
url_symbols
,
root
,
hash_value
=
checksum
,
hash_type
=
"md5"
)
with
open
(
symbol_file
,
"r"
)
as
text
:
self
.
_symbols
=
[
line
.
strip
()
for
line
in
text
.
readlines
()]
self
.
_symbols
=
[
line
.
strip
()
for
line
in
text
.
readlines
()]
with
open
(
os
.
path
.
join
(
self
.
_root_path
,
basename
)
,
"r"
,
encoding
=
'latin-1'
)
as
text
:
with
open
(
dict_file
,
"r"
,
encoding
=
'latin-1'
)
as
text
:
self
.
_dictionary
=
_parse_dictionary
(
text
.
readlines
(),
self
.
_dictionary
=
_parse_dictionary
(
exclude_punctuations
=
self
.
exclude_punctuations
)
text
.
readlines
(),
exclude_punctuations
=
self
.
exclude_punctuations
)
def
__getitem__
(
self
,
n
:
int
)
->
Tuple
[
str
,
List
[
str
]]:
def
__getitem__
(
self
,
n
:
int
)
->
Tuple
[
str
,
List
[
str
]]:
"""Load the n-th sample from the dataset.
"""Load the n-th sample from the dataset.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment