Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
0cf4b8a9
Unverified
Commit
0cf4b8a9
authored
Nov 14, 2020
by
Kshiteej K
Committed by
GitHub
Nov 13, 2020
Browse files
Add pathlib.Path support to `commonvoice` (#1027)
parent
f1142e65
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
17 additions
and
5 deletions
+17
-5
test/torchaudio_unittest/datasets/commonvoice_test.py
test/torchaudio_unittest/datasets/commonvoice_test.py
+10
-2
torchaudio/datasets/commonvoice.py
torchaudio/datasets/commonvoice.py
+7
-3
No files found.
test/torchaudio_unittest/datasets/commonvoice_test.py
View file @
0cf4b8a9
import
os
import
os
import
csv
import
csv
import
random
import
random
from
pathlib
import
Path
from
torchaudio.datasets
import
commonvoice
from
torchaudio.datasets
import
commonvoice
from
torchaudio_unittest.common_utils
import
(
from
torchaudio_unittest.common_utils
import
(
...
@@ -59,8 +60,7 @@ class TestCommonVoice(TempDirMixin, TorchaudioTestCase):
...
@@ -59,8 +60,7 @@ class TestCommonVoice(TempDirMixin, TorchaudioTestCase):
# Append data entry
# Append data entry
cls
.
data
.
append
((
normalize_wav
(
data
),
cls
.
sample_rate
,
dict
(
zip
(
cls
.
_headers
,
content
))))
cls
.
data
.
append
((
normalize_wav
(
data
),
cls
.
sample_rate
,
dict
(
zip
(
cls
.
_headers
,
content
))))
def
test_commonvoice
(
self
):
def
_test_commonvoice
(
self
,
dataset
):
dataset
=
commonvoice
.
COMMONVOICE
(
self
.
root_dir
)
n_ite
=
0
n_ite
=
0
for
i
,
(
waveform
,
sample_rate
,
dictionary
)
in
enumerate
(
dataset
):
for
i
,
(
waveform
,
sample_rate
,
dictionary
)
in
enumerate
(
dataset
):
expected_dictionary
=
self
.
data
[
i
][
2
]
expected_dictionary
=
self
.
data
[
i
][
2
]
...
@@ -70,3 +70,11 @@ class TestCommonVoice(TempDirMixin, TorchaudioTestCase):
...
@@ -70,3 +70,11 @@ class TestCommonVoice(TempDirMixin, TorchaudioTestCase):
assert
dictionary
==
expected_dictionary
assert
dictionary
==
expected_dictionary
n_ite
+=
1
n_ite
+=
1
assert
n_ite
==
len
(
self
.
data
)
assert
n_ite
==
len
(
self
.
data
)
def
test_commonvoice_str
(
self
):
dataset
=
commonvoice
.
COMMONVOICE
(
self
.
root_dir
)
self
.
_test_commonvoice
(
dataset
)
def
test_commonvoice_path
(
self
):
dataset
=
commonvoice
.
COMMONVOICE
(
Path
(
self
.
root_dir
))
self
.
_test_commonvoice
(
dataset
)
torchaudio/datasets/commonvoice.py
View file @
0cf4b8a9
import
os
import
os
from
typing
import
List
,
Dict
,
Tuple
from
pathlib
import
Path
from
typing
import
List
,
Dict
,
Tuple
,
Union
import
torchaudio
import
torchaudio
from
torchaudio.datasets.utils
import
download_url
,
extract_archive
,
unicode_csv_reader
from
torchaudio.datasets.utils
import
download_url
,
extract_archive
,
unicode_csv_reader
...
@@ -103,7 +104,7 @@ class COMMONVOICE(Dataset):
...
@@ -103,7 +104,7 @@ class COMMONVOICE(Dataset):
"""Create a Dataset for CommonVoice.
"""Create a Dataset for CommonVoice.
Args:
Args:
root (str): Path to the directory where the dataset is found or downloaded.
root (str
or Path
): Path to the directory where the dataset is found or downloaded.
tsv (str, optional): The name of the tsv file used to construct the metadata.
tsv (str, optional): The name of the tsv file used to construct the metadata.
(default: ``"train.tsv"``)
(default: ``"train.tsv"``)
url (str, optional): The URL to download the dataset from, or the language of
url (str, optional): The URL to download the dataset from, or the language of
...
@@ -129,7 +130,7 @@ class COMMONVOICE(Dataset):
...
@@ -129,7 +130,7 @@ class COMMONVOICE(Dataset):
_folder_audio
=
"clips"
_folder_audio
=
"clips"
def
__init__
(
self
,
def
__init__
(
self
,
root
:
str
,
root
:
Union
[
str
,
Path
]
,
tsv
:
str
=
TSV
,
tsv
:
str
=
TSV
,
url
:
str
=
URL
,
url
:
str
=
URL
,
folder_in_archive
:
str
=
FOLDER_IN_ARCHIVE
,
folder_in_archive
:
str
=
FOLDER_IN_ARCHIVE
,
...
@@ -186,6 +187,9 @@ class COMMONVOICE(Dataset):
...
@@ -186,6 +187,9 @@ class COMMONVOICE(Dataset):
base_url
=
"https://voice-prod-bundler-ee1969a6ce8178826482b88e843c335139bd3fb4.s3.amazonaws.com"
base_url
=
"https://voice-prod-bundler-ee1969a6ce8178826482b88e843c335139bd3fb4.s3.amazonaws.com"
url
=
os
.
path
.
join
(
base_url
,
version
,
language
+
ext_archive
)
url
=
os
.
path
.
join
(
base_url
,
version
,
language
+
ext_archive
)
# Get string representation of 'root' in case Path object is passed
root
=
os
.
fspath
(
root
)
basename
=
os
.
path
.
basename
(
url
)
basename
=
os
.
path
.
basename
(
url
)
archive
=
os
.
path
.
join
(
root
,
basename
)
archive
=
os
.
path
.
join
(
root
,
basename
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment