Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
077a5f4a
Unverified
Commit
077a5f4a
authored
Aug 02, 2021
by
yangarbiter
Committed by
GitHub
Aug 02, 2021
Browse files
Add CMUDict dataset (#1627)
parent
83dc5ec7
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
401 additions
and
0 deletions
+401
-0
docs/source/datasets.rst
docs/source/datasets.rst
+8
-0
test/torchaudio_unittest/datasets/cmudict_test.py
test/torchaudio_unittest/datasets/cmudict_test.py
+218
-0
torchaudio/datasets/__init__.py
torchaudio/datasets/__init__.py
+3
-0
torchaudio/datasets/cmudict.py
torchaudio/datasets/cmudict.py
+172
-0
No files found.
docs/source/datasets.rst
View file @
077a5f4a
...
@@ -33,6 +33,14 @@ CMUARCTIC
...
@@ -33,6 +33,14 @@ CMUARCTIC
:special-members: __getitem__
:special-members: __getitem__
CMUDict
~~~~~~~~~
.. autoclass:: CMUDict
:members:
:special-members: __getitem__
COMMONVOICE
COMMONVOICE
~~~~~~~~~~~
~~~~~~~~~~~
...
...
test/torchaudio_unittest/datasets/cmudict_test.py
0 → 100644
View file @
077a5f4a
import
os
from
pathlib
import
Path
from
torchaudio.datasets
import
CMUDict
from
torchaudio_unittest.common_utils
import
(
TempDirMixin
,
TorchaudioTestCase
,
)
def get_mock_dataset(root_dir, return_punc=False):
    """Write a miniature CMUDict fixture into ``root_dir`` and return the expected samples.

    Two files are created inside ``root_dir``:

    * ``cmudict-0.7b``: the header comments, then the punctuation entries, then the
      word entries, one entry per line.
    * ``cmudict-0.7b.symbols``: one phoneme symbol per line.

    Args:
        root_dir: directory to the mocked dataset
        return_punc: if true, the returned samples start with the punctuation entries
            (head-word replaced by the bare punctuation mark) followed by the word
            entries; otherwise only the word entries are returned.

    Returns:
        list of ``(word, [phonemes])`` tuples, in the order they appear in the file.
    """
    header = [
        ";;; # CMUdict -- Major Version: 0.07",
        ";;; ",
        ";;; # $HeadURL$",
    ]

    puncs = [
        "!EXCLAMATION-POINT EH2 K S K L AH0 M EY1 SH AH0 N P OY2 N T",
        "\"CLOSE-QUOTE K L OW1 Z K W OW1 T",
        "#HASH-MARK HH AE1 M AA2 R K",
        "%PERCENT P ER0 S EH1 N T",
        "&ERSAND AE1 M P ER0 S AE2 N D",
        "'END-INNER-QUOTE EH1 N D IH1 N ER0 K W OW1 T",
        "(BEGIN-PARENS B IH0 G IH1 N P ER0 EH1 N Z",
        ")CLOSE-PAREN K L OW1 Z P ER0 EH1 N",
        "+PLUS P L UH1 S",
        ",COMMA K AA1 M AH0",
        "--DASH D AE1 SH",
        "!EXCLAMATION-POINT EH2 K S K L AH0 M EY1 SH AH0 N P OY2 N T",
        "/SLASH S L AE1 SH",
        ":COLON K OW1 L AH0 N",
        ";SEMI-COLON S EH1 M IY0 K OW1 L AH0 N",
        "?QUESTION-MARK K W EH1 S CH AH0 N M AA1 R K",
        "{BRACE B R EY1 S",
        "}CLOSE-BRACE K L OW1 Z B R EY1 S",
        "...ELLIPSIS IH2 L IH1 P S IH0 S",
    ]

    # Expected head-word for each entry of `puncs`, index for index.
    punc_outputs = [
        "!",
        "\"",
        "#",
        "%",
        "&",
        "'",
        "(",
        ")",
        "+",
        ",",
        "--",
        "!",
        "/",
        ":",
        ";",
        "?",
        "{",
        "}",
        "...",
    ]

    words = [
        "3-D TH R IY1 D IY2",
        "'BOUT B AW1 T",
        "'CAUSE K AH0 Z",
        "'TWAS T W AH1 Z",
        "A AH0",
        "B B IY1",
        "C S IY1",
        "D D IY1",
        "E IY1",
        "F EH1 F",
        "G JH IY1",
        "H EY1 CH",
        "I AY1",
        "J JH EY1",
        "K K EY1",
        "L EH1 L",
        "M EH1 M",
        "N EH1 N",
        "O OW1",
        "P P IY1",
        "Q K Y UW1",
        "R AA1 R",
        "S EH1 S",
        "T T IY1",
        "U Y UW1",
        "V V IY1",
        "X EH1 K S",
        "Y W AY1",
        "Z Z IY1",
    ]

    mocked_symbols = [
        "AA1",
        "AA2",
        "AE1",
        "AE2",
        "AH0",
        "AH1",
        "AY1",
        "B",
        "CH",
        "D",
        "EH1",
        "EH2",
        "ER0",
        "EY1",
        "F",
        "G",
        "HH",
        "IH0",
        "IH1",
        "IY0",
        "IY1",
        "IY2",
        "JH",
        "K",
        "L",
        "M",
        "N",
        "OW1",
        "OY2",
        "P",
        "R",
        "S",
        "SH",
        "T",
        "TH",
        "UH1",
        "UW0",
        "UW1",
        "V",
        "W",
        "Y",
        "Z",
    ]

    dict_file = os.path.join(root_dir, "cmudict-0.7b")
    symbol_file = os.path.join(root_dir, "cmudict-0.7b.symbols")

    # Every dictionary line is newline-terminated; emit them in a single batched write.
    with open(dict_file, "w") as fileobj:
        fileobj.writelines(line + "\n" for line in header + puncs + words)

    # The symbol file has no trailing newline after the last symbol.
    with open(symbol_file, "w") as txt:
        txt.write("\n".join(mocked_symbols))

    mocked_data = []

    if return_punc:
        for expected_mark, ent in zip(punc_outputs, puncs):
            # Split on the FIRST space only: everything after the head-word is the
            # phoneme sequence, which itself contains spaces.
            _, phones = ent.split(" ", 1)
            mocked_data.append((expected_mark, phones.split(" ")))

    for ent in words:
        word, phones = ent.split(" ", 1)
        mocked_data.append((word, phones.split(" ")))

    return mocked_data
class TestCMUDict(TempDirMixin, TorchaudioTestCase):
    """Checks that ``CMUDict`` reads the mocked dictionary files correctly.

    Two fixture directories are written once per class: ``normal`` (expected samples
    without punctuation entries) and ``punc`` (expected samples including them).
    """

    root_dir = None
    root_punc_dir = None
    samples = []
    punc_samples = []

    @classmethod
    def setUpClass(cls):
        cls.root_dir = os.path.join(cls.get_base_temp_dir(), "normal")
        os.mkdir(cls.root_dir)
        cls.samples = get_mock_dataset(cls.root_dir)

        cls.root_punc_dir = os.path.join(cls.get_base_temp_dir(), "punc")
        os.mkdir(cls.root_punc_dir)
        cls.punc_samples = get_mock_dataset(cls.root_punc_dir, return_punc=True)

    def _assert_dataset_matches(self, dataset, expected):
        """Assert that ``dataset`` yields exactly the ``(word, phones)`` pairs in ``expected``."""
        # Compare lengths first so a dataset that is too long or too short is reported
        # as a test failure instead of an IndexError or a silently truncated comparison.
        self.assertEqual(len(dataset), len(expected))
        for (word, phones), (expected_word, expected_phones) in zip(dataset, expected):
            self.assertEqual(word, expected_word)
            self.assertEqual(phones, expected_phones)

    def _test_cmudict(self, dataset):
        """Test if the dataset is reading the mocked data correctly."""
        self._assert_dataset_matches(dataset, self.samples)

    def _test_punc_cmudict(self, dataset):
        """Test if the dataset is reading the mocked data with punctuations correctly."""
        self._assert_dataset_matches(dataset, self.punc_samples)

    # NOTE: the ``test_cmuarctic_*`` method names are kept unchanged so existing test
    # IDs stay stable, even though the dataset under test is CMUDict.

    def test_cmuarctic_path_with_punctuation(self):
        dataset = CMUDict(Path(self.root_punc_dir), exclude_punctuations=False)
        self._test_punc_cmudict(dataset)

    def test_cmuarctic_str_with_punctuation(self):
        dataset = CMUDict(self.root_punc_dir, exclude_punctuations=False)
        self._test_punc_cmudict(dataset)

    def test_cmuarctic_path(self):
        # Read from the "normal" fixture, which is the one `self.samples` was built for.
        dataset = CMUDict(Path(self.root_dir), exclude_punctuations=True)
        self._test_cmudict(dataset)

    def test_cmuarctic_str(self):
        dataset = CMUDict(self.root_dir, exclude_punctuations=True)
        self._test_cmudict(dataset)
torchaudio/datasets/__init__.py
View file @
077a5f4a
...
@@ -7,9 +7,11 @@ from .gtzan import GTZAN
...
@@ -7,9 +7,11 @@ from .gtzan import GTZAN
from
.yesno
import
YESNO
from
.yesno
import
YESNO
from
.ljspeech
import
LJSPEECH
from
.ljspeech
import
LJSPEECH
from
.cmuarctic
import
CMUARCTIC
from
.cmuarctic
import
CMUARCTIC
from
.cmudict
import
CMUDict
from
.libritts
import
LIBRITTS
from
.libritts
import
LIBRITTS
from
.tedlium
import
TEDLIUM
from
.tedlium
import
TEDLIUM
__all__
=
[
__all__
=
[
"COMMONVOICE"
,
"COMMONVOICE"
,
"LIBRISPEECH"
,
"LIBRISPEECH"
,
...
@@ -20,6 +22,7 @@ __all__ = [
...
@@ -20,6 +22,7 @@ __all__ = [
"LJSPEECH"
,
"LJSPEECH"
,
"GTZAN"
,
"GTZAN"
,
"CMUARCTIC"
,
"CMUARCTIC"
,
"CMUDict"
,
"LIBRITTS"
,
"LIBRITTS"
,
"diskcache_iterator"
,
"diskcache_iterator"
,
"bg_iterator"
,
"bg_iterator"
,
...
...
torchaudio/datasets/cmudict.py
0 → 100644
View file @
077a5f4a
import
os
import
re
from
pathlib
import
Path
from
typing
import
Iterable
,
Tuple
,
Union
,
List
from
torch.utils.data
import
Dataset
from
torchaudio.datasets.utils
import
download_url
_CHECKSUMS
=
{
"http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-0.7b"
:
"825f4ebd9183f2417df9f067a9cabe86"
,
"http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-0.7b.symbols"
:
"385e490aabc71b48e772118e3d02923e"
,
}
_PUNCTUATIONS
=
set
([
"!EXCLAMATION-POINT"
,
"
\"
CLOSE-QUOTE"
,
"
\"
DOUBLE-QUOTE"
,
"
\"
END-OF-QUOTE"
,
"
\"
END-QUOTE"
,
"
\"
IN-QUOTES"
,
"
\"
QUOTE"
,
"
\"
UNQUOTE"
,
"#HASH-MARK"
,
"#POUND-SIGN"
,
"#SHARP-SIGN"
,
"%PERCENT"
,
"&ERSAND"
,
"'END-INNER-QUOTE"
,
"'END-QUOTE"
,
"'INNER-QUOTE"
,
"'QUOTE"
,
"'SINGLE-QUOTE"
,
"(BEGIN-PARENS"
,
"(IN-PARENTHESES"
,
"(LEFT-PAREN"
,
"(OPEN-PARENTHESES"
,
"(PAREN"
,
"(PARENS"
,
"(PARENTHESES"
,
")CLOSE-PAREN"
,
")CLOSE-PARENTHESES"
,
")END-PAREN"
,
")END-PARENS"
,
")END-PARENTHESES"
,
")END-THE-PAREN"
,
")PAREN"
,
")PARENS"
,
")RIGHT-PAREN"
,
")UN-PARENTHESES"
,
"+PLUS"
,
",COMMA"
,
"--DASH"
,
"-DASH"
,
"-HYPHEN"
,
"...ELLIPSIS"
,
".DECIMAL"
,
".DOT"
,
".FULL-STOP"
,
".PERIOD"
,
".POINT"
,
"/SLASH"
,
":COLON"
,
";SEMI-COLON"
,
";SEMI-COLON(1)"
,
"?QUESTION-MARK"
,
"{BRACE"
,
"{LEFT-BRACE"
,
"{OPEN-BRACE"
,
"}CLOSE-BRACE"
,
"}RIGHT-BRACE"
,
])
def
_parse_dictionary
(
lines
:
Iterable
[
str
],
exclude_punctuations
:
bool
)
->
List
[
str
]:
_alt_re
=
re
.
compile
(
r
'\([0-9]+\)'
)
cmudict
:
List
[
Tuple
[
str
,
List
[
str
]]]
=
list
()
for
line
in
lines
:
if
not
line
or
line
.
startswith
(
';;;'
):
# ignore comments
continue
word
,
phones
=
line
.
strip
().
split
(
' '
)
if
word
in
_PUNCTUATIONS
:
if
exclude_punctuations
:
continue
# !EXCLAMATION-POINT -> !
# --DASH -> --
# ...ELLIPSIS -> ...
if
word
.
startswith
(
"..."
):
word
=
"..."
elif
word
.
startswith
(
"--"
):
word
=
"--"
else
:
word
=
word
[
0
]
# if a word have multiple pronunciations, there will be (number) appended to it
# for example, DATAPOINTS and DATAPOINTS(1),
# the regular expression `_alt_re` removes the '(1)' and change the word DATAPOINTS(1) to DATAPOINTS
word
=
re
.
sub
(
_alt_re
,
''
,
word
)
phones
=
phones
.
split
(
" "
)
cmudict
.
append
((
word
,
phones
))
return
cmudict
class CMUDict(Dataset):
    """Create a Dataset for CMU Pronouncing Dictionary (CMUDict).

    Args:
        root (str or Path): Path to the directory where the dataset is found or downloaded.
        url (str, optional):
            The URL to download the dictionary from.
            (default: ``"http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-0.7b"``)
        url_symbols (str, optional):
            The URL to download the list of symbols from.
            (default: ``"http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-0.7b.symbols"``)
        download (bool, optional):
            Whether to download the dataset if it is not found at root path. (default: ``False``).
        exclude_punctuations (bool, optional):
            Whether to leave out the entries that spell out punctuation marks
            (e.g. ``!EXCLAMATION-POINT``). When ``False`` those entries are kept with the
            word reduced to the punctuation mark itself. (default: ``True``).
    """

    def __init__(self,
                 root: Union[str, Path],
                 url: str = "http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-0.7b",
                 url_symbols: str = "http://svn.code.sf.net/p/cmusphinx/code/trunk/cmudict/cmudict-0.7b.symbols",
                 download: bool = False,
                 exclude_punctuations: bool = True) -> None:

        self.exclude_punctuations = exclude_punctuations

        root = Path(root)
        if not os.path.isdir(root):
            os.mkdir(root)

        if download:
            if os.path.isdir(root):
                checksum = _CHECKSUMS.get(url, None)
                download_url(url, root, hash_value=checksum, hash_type="md5")
                checksum = _CHECKSUMS.get(url_symbols, None)
                download_url(url_symbols, root, hash_value=checksum, hash_type="md5")
            else:
                # The exception used to be constructed but never raised, so the
                # error was silently discarded.
                raise RuntimeError("The argument `root` must be a path to directory, "
                                   f"but '{root}' is passed in instead.")

        self._root_path = root
        # Local file names are taken from the last path component of each URL.
        basename = os.path.basename(url)
        basename_symbols = os.path.basename(url_symbols)

        with open(os.path.join(self._root_path, basename_symbols), "r") as text:
            self._symbols = [line.strip() for line in text]

        # NOTE(review): cmudict-0.7b is understood to contain Latin-1 encoded bytes in
        # a few entries, which a UTF-8 default would fail to decode - confirm against
        # the upstream file. Latin-1 decoding never fails and maps ASCII unchanged.
        with open(os.path.join(self._root_path, basename), "r", encoding="latin-1") as text:
            self._dictionary = _parse_dictionary(
                text, exclude_punctuations=self.exclude_punctuations)

    def __getitem__(self, n: int) -> Tuple[str, List[str]]:
        """Load the n-th sample from the dataset.

        Args:
            n (int): The index of the sample to be loaded.

        Returns:
            tuple: The corresponding word and phonemes ``(word, [phonemes])``.
        """
        return self._dictionary[n]

    def __len__(self) -> int:
        """Return the number of entries in the parsed dictionary."""
        return len(self._dictionary)

    @property
    def symbols(self) -> List[str]:
        """list[str]: A copy of the symbols read from the symbols file."""
        return self._symbols.copy()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment