Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
55175003
Unverified
Commit
55175003
authored
Nov 16, 2020
by
Bhargav Kathivarapu
Committed by
GitHub
Nov 16, 2020
Browse files
Pathlib support for VCTK and LJSPEECH (#1028)
parent
0cf4b8a9
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
34 additions
and
10 deletions
+34
-10
test/torchaudio_unittest/datasets/ljspeech_test.py
test/torchaudio_unittest/datasets/ljspeech_test.py
+10
-2
test/torchaudio_unittest/datasets/vctk_test.py
test/torchaudio_unittest/datasets/vctk_test.py
+10
-2
torchaudio/datasets/ljspeech.py
torchaudio/datasets/ljspeech.py
+7
-3
torchaudio/datasets/vctk.py
torchaudio/datasets/vctk.py
+7
-3
No files found.
test/torchaudio_unittest/datasets/ljspeech_test.py
View file @
55175003
import
csv
import
csv
import
os
import
os
from
pathlib
import
Path
from
torchaudio.datasets
import
ljspeech
from
torchaudio.datasets
import
ljspeech
...
@@ -57,8 +58,7 @@ class TestLJSpeech(TempDirMixin, TorchaudioTestCase):
...
@@ -57,8 +58,7 @@ class TestLJSpeech(TempDirMixin, TorchaudioTestCase):
save_wav
(
path
,
data
,
sample_rate
)
save_wav
(
path
,
data
,
sample_rate
)
cls
.
data
.
append
(
normalize_wav
(
data
))
cls
.
data
.
append
(
normalize_wav
(
data
))
def
test_ljspeech
(
self
):
def
_test_ljspeech
(
self
,
dataset
):
dataset
=
ljspeech
.
LJSPEECH
(
self
.
root_dir
)
n_ite
=
0
n_ite
=
0
for
i
,
(
waveform
,
sample_rate
,
transcript
,
normalized_transcript
)
in
enumerate
(
for
i
,
(
waveform
,
sample_rate
,
transcript
,
normalized_transcript
)
in
enumerate
(
dataset
dataset
...
@@ -72,3 +72,11 @@ class TestLJSpeech(TempDirMixin, TorchaudioTestCase):
...
@@ -72,3 +72,11 @@ class TestLJSpeech(TempDirMixin, TorchaudioTestCase):
assert
normalized_transcript
==
expected_normalized_transcript
assert
normalized_transcript
==
expected_normalized_transcript
n_ite
+=
1
n_ite
+=
1
assert
n_ite
==
len
(
self
.
data
)
assert
n_ite
==
len
(
self
.
data
)
def
test_ljspeech_str
(
self
):
dataset
=
ljspeech
.
LJSPEECH
(
self
.
root_dir
)
self
.
_test_ljspeech
(
dataset
)
def
test_ljspeech_path
(
self
):
dataset
=
ljspeech
.
LJSPEECH
(
Path
(
self
.
root_dir
))
self
.
_test_ljspeech
(
dataset
)
test/torchaudio_unittest/datasets/vctk_test.py
View file @
55175003
import
os
import
os
from
pathlib
import
Path
from
torchaudio.datasets
import
vctk
from
torchaudio.datasets
import
vctk
...
@@ -77,8 +78,7 @@ class TestVCTK(TempDirMixin, TorchaudioTestCase):
...
@@ -77,8 +78,7 @@ class TestVCTK(TempDirMixin, TorchaudioTestCase):
seed
+=
1
seed
+=
1
def
test_vctk
(
self
):
def
_test_vctk
(
self
,
dataset
):
dataset
=
vctk
.
VCTK_092
(
self
.
root_dir
,
audio_ext
=
".wav"
)
num_samples
=
0
num_samples
=
0
for
i
,
(
data
,
sample_rate
,
utterance
,
speaker_id
,
utterance_id
)
in
enumerate
(
dataset
):
for
i
,
(
data
,
sample_rate
,
utterance
,
speaker_id
,
utterance_id
)
in
enumerate
(
dataset
):
self
.
assertEqual
(
data
,
self
.
samples
[
i
][
0
],
atol
=
5e-5
,
rtol
=
1e-8
)
self
.
assertEqual
(
data
,
self
.
samples
[
i
][
0
],
atol
=
5e-5
,
rtol
=
1e-8
)
...
@@ -89,3 +89,11 @@ class TestVCTK(TempDirMixin, TorchaudioTestCase):
...
@@ -89,3 +89,11 @@ class TestVCTK(TempDirMixin, TorchaudioTestCase):
num_samples
+=
1
num_samples
+=
1
assert
num_samples
==
len
(
self
.
samples
)
assert
num_samples
==
len
(
self
.
samples
)
def
test_vctk_str
(
self
):
dataset
=
vctk
.
VCTK_092
(
self
.
root_dir
,
audio_ext
=
".wav"
)
self
.
_test_vctk
(
dataset
)
def
test_vctk_path
(
self
):
dataset
=
vctk
.
VCTK_092
(
Path
(
self
.
root_dir
),
audio_ext
=
".wav"
)
self
.
_test_vctk
(
dataset
)
torchaudio/datasets/ljspeech.py
View file @
55175003
import
os
import
os
import
csv
import
csv
from
typing
import
List
,
Tuple
from
typing
import
List
,
Tuple
,
Union
from
pathlib
import
Path
import
torchaudio
import
torchaudio
from
torchaudio.datasets.utils
import
download_url
,
extract_archive
,
unicode_csv_reader
from
torchaudio.datasets.utils
import
download_url
,
extract_archive
,
unicode_csv_reader
...
@@ -36,7 +37,7 @@ class LJSPEECH(Dataset):
...
@@ -36,7 +37,7 @@ class LJSPEECH(Dataset):
"""Create a Dataset for LJSpeech-1.1.
"""Create a Dataset for LJSpeech-1.1.
Args:
Args:
root (str): Path to the directory where the dataset is found or downloaded.
root (str or Path): Path to the directory where the dataset is found or downloaded.
url (str, optional): The URL to download the dataset from.
url (str, optional): The URL to download the dataset from.
(default: ``"https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2"``)
(default: ``"https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2"``)
folder_in_archive (str, optional):
folder_in_archive (str, optional):
...
@@ -49,11 +50,14 @@ class LJSPEECH(Dataset):
...
@@ -49,11 +50,14 @@ class LJSPEECH(Dataset):
_ext_archive
=
'.tar.bz2'
_ext_archive
=
'.tar.bz2'
def
__init__
(
self
,
def
__init__
(
self
,
root
:
str
,
root
:
Union
[
str
,
Path
]
,
url
:
str
=
URL
,
url
:
str
=
URL
,
folder_in_archive
:
str
=
FOLDER_IN_ARCHIVE
,
folder_in_archive
:
str
=
FOLDER_IN_ARCHIVE
,
download
:
bool
=
False
)
->
None
:
download
:
bool
=
False
)
->
None
:
# Get string representation of 'root' in case Path object is passed
root
=
os
.
fspath
(
root
)
basename
=
os
.
path
.
basename
(
url
)
basename
=
os
.
path
.
basename
(
url
)
archive
=
os
.
path
.
join
(
root
,
basename
)
archive
=
os
.
path
.
join
(
root
,
basename
)
...
...
torchaudio/datasets/vctk.py
View file @
55175003
import
os
import
os
import
warnings
import
warnings
from
typing
import
Any
,
Tuple
from
typing
import
Any
,
Tuple
,
Union
from
pathlib
import
Path
import
torchaudio
import
torchaudio
from
torch
import
Tensor
from
torch
import
Tensor
...
@@ -57,7 +58,7 @@ class VCTK(Dataset):
...
@@ -57,7 +58,7 @@ class VCTK(Dataset):
For more information about the dataset visit: https://datashare.is.ed.ac.uk/handle/10283/3443
For more information about the dataset visit: https://datashare.is.ed.ac.uk/handle/10283/3443
Args:
Args:
root (str): Path to the directory where the dataset is found or downloaded.
root (str or Path): Path to the directory where the dataset is found or downloaded.
url (str, optional): Not used as the dataset is no longer publicly available.
url (str, optional): Not used as the dataset is no longer publicly available.
folder_in_archive (str, optional):
folder_in_archive (str, optional):
The top-level directory of the dataset. (default: ``"VCTK-Corpus"``)
The top-level directory of the dataset. (default: ``"VCTK-Corpus"``)
...
@@ -77,7 +78,7 @@ class VCTK(Dataset):
...
@@ -77,7 +78,7 @@ class VCTK(Dataset):
_except_folder
=
"p315"
_except_folder
=
"p315"
def
__init__
(
self
,
def
__init__
(
self
,
root
:
str
,
root
:
Union
[
str
,
Path
]
,
url
:
str
=
URL
,
url
:
str
=
URL
,
folder_in_archive
:
str
=
FOLDER_IN_ARCHIVE
,
folder_in_archive
:
str
=
FOLDER_IN_ARCHIVE
,
download
:
bool
=
False
,
download
:
bool
=
False
,
...
@@ -103,6 +104,9 @@ class VCTK(Dataset):
...
@@ -103,6 +104,9 @@ class VCTK(Dataset):
self
.
transform
=
transform
self
.
transform
=
transform
self
.
target_transform
=
target_transform
self
.
target_transform
=
target_transform
# Get string representation of 'root' in case Path object is passed
root
=
os
.
fspath
(
root
)
archive
=
os
.
path
.
basename
(
url
)
archive
=
os
.
path
.
basename
(
url
)
archive
=
os
.
path
.
join
(
root
,
archive
)
archive
=
os
.
path
.
join
(
root
,
archive
)
self
.
_path
=
os
.
path
.
join
(
root
,
folder_in_archive
)
self
.
_path
=
os
.
path
.
join
(
root
,
folder_in_archive
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment