Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
550b6a30
Unverified
Commit
550b6a30
authored
Nov 18, 2020
by
Bhargav Kathivarapu
Committed by
GitHub
Nov 17, 2020
Browse files
Add pathlib support for SPEECHCOMMANDS (#1039)
parent
619da1f2
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
24 additions
and
65 deletions
+24
-65
test/torchaudio_unittest/datasets/speechcommands_test.py
test/torchaudio_unittest/datasets/speechcommands_test.py
+17
-62
torchaudio/datasets/speechcommands.py
torchaudio/datasets/speechcommands.py
+7
-3
No files found.
test/torchaudio_unittest/datasets/speechcommands_test.py
View file @
550b6a30
import
os
from
pathlib
import
Path
from
torchaudio.datasets
import
speechcommands
...
...
@@ -105,85 +106,39 @@ class TestSpeechCommands(TempDirMixin, TorchaudioTestCase):
test
.
write
(
f
'
{
label
}
/
{
filename
}
\n
'
)
cls
.
test_samples
.
append
(
sample
)
def
testSpeechCommands
(
self
):
dataset
=
speechcommands
.
SPEECHCOMMANDS
(
self
.
root_dir
)
def
_testSpeechCommands
(
self
,
dataset
,
data_samples
):
num_samples
=
0
for
i
,
(
data
,
sample_rate
,
label
,
speaker_id
,
utterance_number
)
in
enumerate
(
dataset
):
self
.
assertEqual
(
data
,
self
.
samples
[
i
][
0
],
atol
=
5e-5
,
rtol
=
1e-8
)
assert
sample_rate
==
self
.
samples
[
i
][
1
]
assert
label
==
self
.
samples
[
i
][
2
]
assert
speaker_id
==
self
.
samples
[
i
][
3
]
assert
utterance_number
==
self
.
samples
[
i
][
4
]
self
.
assertEqual
(
data
,
data_
samples
[
i
][
0
],
atol
=
5e-5
,
rtol
=
1e-8
)
assert
sample_rate
==
data_
samples
[
i
][
1
]
assert
label
==
data_
samples
[
i
][
2
]
assert
speaker_id
==
data_
samples
[
i
][
3
]
assert
utterance_number
==
data_
samples
[
i
][
4
]
num_samples
+=
1
assert
num_samples
==
len
(
self
.
samples
)
assert
num_samples
==
len
(
data_
samples
)
def
testSpeechCommandsNone
(
self
):
dataset
=
speechcommands
.
SPEECHCOMMANDS
(
self
.
root_dir
,
subset
=
None
)
num_samples
=
0
for
i
,
(
data
,
sample_rate
,
label
,
speaker_id
,
utterance_number
)
in
enumerate
(
dataset
):
self
.
assertEqual
(
data
,
self
.
samples
[
i
][
0
],
atol
=
5e-5
,
rtol
=
1e-8
)
assert
sample_rate
==
self
.
samples
[
i
][
1
]
assert
label
==
self
.
samples
[
i
][
2
]
assert
speaker_id
==
self
.
samples
[
i
][
3
]
assert
utterance_number
==
self
.
samples
[
i
][
4
]
num_samples
+=
1
def
testSpeechCommands_str
(
self
):
dataset
=
speechcommands
.
SPEECHCOMMANDS
(
self
.
root_dir
)
self
.
_testSpeechCommands
(
dataset
,
self
.
samples
)
assert
num_samples
==
len
(
self
.
samples
)
def
testSpeechCommands_path
(
self
):
dataset
=
speechcommands
.
SPEECHCOMMANDS
(
Path
(
self
.
root_dir
))
self
.
_testSpeechCommands
(
dataset
,
self
.
samples
)
def
testSpeechCommandsSubsetTrain
(
self
):
dataset
=
speechcommands
.
SPEECHCOMMANDS
(
self
.
root_dir
,
subset
=
"training"
)
num_samples
=
0
for
i
,
(
data
,
sample_rate
,
label
,
speaker_id
,
utterance_number
)
in
enumerate
(
dataset
):
self
.
assertEqual
(
data
,
self
.
train_samples
[
i
][
0
],
atol
=
5e-5
,
rtol
=
1e-8
)
assert
sample_rate
==
self
.
train_samples
[
i
][
1
]
assert
label
==
self
.
train_samples
[
i
][
2
]
assert
speaker_id
==
self
.
train_samples
[
i
][
3
]
assert
utterance_number
==
self
.
train_samples
[
i
][
4
]
num_samples
+=
1
assert
num_samples
==
len
(
self
.
train_samples
)
self
.
_testSpeechCommands
(
dataset
,
self
.
train_samples
)
def
testSpeechCommandsSubsetValid
(
self
):
dataset
=
speechcommands
.
SPEECHCOMMANDS
(
self
.
root_dir
,
subset
=
"validation"
)
num_samples
=
0
for
i
,
(
data
,
sample_rate
,
label
,
speaker_id
,
utterance_number
)
in
enumerate
(
dataset
):
self
.
assertEqual
(
data
,
self
.
valid_samples
[
i
][
0
],
atol
=
5e-5
,
rtol
=
1e-8
)
assert
sample_rate
==
self
.
valid_samples
[
i
][
1
]
assert
label
==
self
.
valid_samples
[
i
][
2
]
assert
speaker_id
==
self
.
valid_samples
[
i
][
3
]
assert
utterance_number
==
self
.
valid_samples
[
i
][
4
]
num_samples
+=
1
assert
num_samples
==
len
(
self
.
valid_samples
)
self
.
_testSpeechCommands
(
dataset
,
self
.
valid_samples
)
def
testSpeechCommandsSubsetTest
(
self
):
dataset
=
speechcommands
.
SPEECHCOMMANDS
(
self
.
root_dir
,
subset
=
"testing"
)
num_samples
=
0
for
i
,
(
data
,
sample_rate
,
label
,
speaker_id
,
utterance_number
)
in
enumerate
(
dataset
):
self
.
assertEqual
(
data
,
self
.
test_samples
[
i
][
0
],
atol
=
5e-5
,
rtol
=
1e-8
)
assert
sample_rate
==
self
.
test_samples
[
i
][
1
]
assert
label
==
self
.
test_samples
[
i
][
2
]
assert
speaker_id
==
self
.
test_samples
[
i
][
3
]
assert
utterance_number
==
self
.
test_samples
[
i
][
4
]
num_samples
+=
1
assert
num_samples
==
len
(
self
.
test_samples
)
self
.
_testSpeechCommands
(
dataset
,
self
.
test_samples
)
def
testSpeechCommandsSum
(
self
):
dataset_all
=
speechcommands
.
SPEECHCOMMANDS
(
self
.
root_dir
)
...
...
torchaudio/datasets/speechcommands.py
View file @
550b6a30
import
os
from
typing
import
Tuple
,
Optional
from
typing
import
Tuple
,
Optional
,
Union
from
pathlib
import
Path
import
torchaudio
from
torch.utils.data
import
Dataset
...
...
@@ -48,7 +49,7 @@ class SPEECHCOMMANDS(Dataset):
"""Create a Dataset for Speech Commands.
Args:
root (str): Path to the directory where the dataset is found or downloaded.
root (str or Path): Path to the directory where the dataset is found or downloaded.
url (str, optional): The URL to download the dataset from,
or the type of the dataset to download.
Allowed type values are ``"speech_commands_v0.01"`` and ``"speech_commands_v0.02"``
...
...
@@ -64,7 +65,7 @@ class SPEECHCOMMANDS(Dataset):
"""
def
__init__
(
self
,
root
:
str
,
root
:
Union
[
str
,
Path
]
,
url
:
str
=
URL
,
folder_in_archive
:
str
=
FOLDER_IN_ARCHIVE
,
download
:
bool
=
False
,
...
...
@@ -85,6 +86,9 @@ class SPEECHCOMMANDS(Dataset):
url
=
os
.
path
.
join
(
base_url
,
url
+
ext_archive
)
# Get string representation of 'root' in case Path object is passed
root
=
os
.
fspath
(
root
)
basename
=
os
.
path
.
basename
(
url
)
archive
=
os
.
path
.
join
(
root
,
basename
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment