Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
edaeda4f
Unverified
Commit
edaeda4f
authored
Nov 16, 2020
by
Kshiteej K
Committed by
GitHub
Nov 16, 2020
Browse files
Add pathlib.Path support to `gtzan` (#1032)
Co-authored-by:
Vincent QB
<
vincentqb@users.noreply.github.com
>
parent
55175003
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
39 additions
and
12 deletions
+39
-12
test/torchaudio_unittest/datasets/gtzan_test.py
test/torchaudio_unittest/datasets/gtzan_test.py
+31
-9
torchaudio/datasets/gtzan.py
torchaudio/datasets/gtzan.py
+8
-3
No files found.
test/torchaudio_unittest/datasets/gtzan_test.py
View file @
edaeda4f
import
os
import
os
from
pathlib
import
Path
from
torchaudio.datasets
import
gtzan
from
torchaudio.datasets
import
gtzan
...
@@ -54,9 +55,7 @@ class TestGTZAN(TempDirMixin, TorchaudioTestCase):
...
@@ -54,9 +55,7 @@ class TestGTZAN(TempDirMixin, TorchaudioTestCase):
n_ite
+=
1
n_ite
+=
1
assert
n_ite
==
len
(
self
.
samples
)
assert
n_ite
==
len
(
self
.
samples
)
def
test_training
(
self
):
def
_test_training
(
self
,
dataset
):
dataset
=
gtzan
.
GTZAN
(
self
.
root_dir
,
subset
=
'training'
)
n_ite
=
0
n_ite
=
0
for
i
,
(
waveform
,
sample_rate
,
label
)
in
enumerate
(
dataset
):
for
i
,
(
waveform
,
sample_rate
,
label
)
in
enumerate
(
dataset
):
self
.
assertEqual
(
waveform
,
self
.
training
[
i
][
0
],
atol
=
5e-5
,
rtol
=
1e-8
)
self
.
assertEqual
(
waveform
,
self
.
training
[
i
][
0
],
atol
=
5e-5
,
rtol
=
1e-8
)
...
@@ -65,9 +64,7 @@ class TestGTZAN(TempDirMixin, TorchaudioTestCase):
...
@@ -65,9 +64,7 @@ class TestGTZAN(TempDirMixin, TorchaudioTestCase):
n_ite
+=
1
n_ite
+=
1
assert
n_ite
==
len
(
self
.
training
)
assert
n_ite
==
len
(
self
.
training
)
def
test_validation
(
self
):
def
_test_validation
(
self
,
dataset
):
dataset
=
gtzan
.
GTZAN
(
self
.
root_dir
,
subset
=
'validation'
)
n_ite
=
0
n_ite
=
0
for
i
,
(
waveform
,
sample_rate
,
label
)
in
enumerate
(
dataset
):
for
i
,
(
waveform
,
sample_rate
,
label
)
in
enumerate
(
dataset
):
self
.
assertEqual
(
waveform
,
self
.
validation
[
i
][
0
],
atol
=
5e-5
,
rtol
=
1e-8
)
self
.
assertEqual
(
waveform
,
self
.
validation
[
i
][
0
],
atol
=
5e-5
,
rtol
=
1e-8
)
...
@@ -76,9 +73,7 @@ class TestGTZAN(TempDirMixin, TorchaudioTestCase):
...
@@ -76,9 +73,7 @@ class TestGTZAN(TempDirMixin, TorchaudioTestCase):
n_ite
+=
1
n_ite
+=
1
assert
n_ite
==
len
(
self
.
validation
)
assert
n_ite
==
len
(
self
.
validation
)
def
test_testing
(
self
):
def
_test_testing
(
self
,
dataset
):
dataset
=
gtzan
.
GTZAN
(
self
.
root_dir
,
subset
=
'testing'
)
n_ite
=
0
n_ite
=
0
for
i
,
(
waveform
,
sample_rate
,
label
)
in
enumerate
(
dataset
):
for
i
,
(
waveform
,
sample_rate
,
label
)
in
enumerate
(
dataset
):
self
.
assertEqual
(
waveform
,
self
.
testing
[
i
][
0
],
atol
=
5e-5
,
rtol
=
1e-8
)
self
.
assertEqual
(
waveform
,
self
.
testing
[
i
][
0
],
atol
=
5e-5
,
rtol
=
1e-8
)
...
@@ -86,3 +81,30 @@ class TestGTZAN(TempDirMixin, TorchaudioTestCase):
...
@@ -86,3 +81,30 @@ class TestGTZAN(TempDirMixin, TorchaudioTestCase):
assert
label
==
self
.
testing
[
i
][
2
]
assert
label
==
self
.
testing
[
i
][
2
]
n_ite
+=
1
n_ite
+=
1
assert
n_ite
==
len
(
self
.
testing
)
assert
n_ite
==
len
(
self
.
testing
)
def
test_training_str
(
self
):
train_dataset
=
gtzan
.
GTZAN
(
self
.
root_dir
,
subset
=
'training'
)
self
.
_test_training
(
train_dataset
)
def
test_validation_str
(
self
):
val_dataset
=
gtzan
.
GTZAN
(
self
.
root_dir
,
subset
=
'validation'
)
self
.
_test_validation
(
val_dataset
)
def
test_testing_str
(
self
):
test_dataset
=
gtzan
.
GTZAN
(
self
.
root_dir
,
subset
=
'testing'
)
self
.
_test_testing
(
test_dataset
)
def
test_training_path
(
self
):
root_dir
=
Path
(
self
.
root_dir
)
train_dataset
=
gtzan
.
GTZAN
(
root_dir
,
subset
=
'training'
)
self
.
_test_training
(
train_dataset
)
def
test_validation_path
(
self
):
root_dir
=
Path
(
self
.
root_dir
)
val_dataset
=
gtzan
.
GTZAN
(
root_dir
,
subset
=
'validation'
)
self
.
_test_validation
(
val_dataset
)
def
test_testing_path
(
self
):
root_dir
=
Path
(
self
.
root_dir
)
test_dataset
=
gtzan
.
GTZAN
(
root_dir
,
subset
=
'testing'
)
self
.
_test_testing
(
test_dataset
)
torchaudio/datasets/gtzan.py
View file @
edaeda4f
import
os
import
os
import
warnings
import
warnings
from
typing
import
Any
,
Tuple
,
Optional
from
pathlib
import
Path
from
typing
import
Any
,
Tuple
,
Optional
,
Union
import
torchaudio
import
torchaudio
from
torch
import
Tensor
from
torch
import
Tensor
...
@@ -1005,7 +1006,7 @@ class GTZAN(Dataset):
...
@@ -1005,7 +1006,7 @@ class GTZAN(Dataset):
this dataset to publish results.
this dataset to publish results.
Args:
Args:
root (str): Path to the directory where the dataset is found or downloaded.
root (str
or Path
): Path to the directory where the dataset is found or downloaded.
url (str, optional): The URL to download the dataset from.
url (str, optional): The URL to download the dataset from.
(default: ``"http://opihi.cs.uvic.ca/sound/genres.tar.gz"``)
(default: ``"http://opihi.cs.uvic.ca/sound/genres.tar.gz"``)
folder_in_archive (str, optional): The top-level directory of the dataset.
folder_in_archive (str, optional): The top-level directory of the dataset.
...
@@ -1020,7 +1021,7 @@ class GTZAN(Dataset):
...
@@ -1020,7 +1021,7 @@ class GTZAN(Dataset):
def
__init__
(
def
__init__
(
self
,
self
,
root
:
str
,
root
:
Union
[
str
,
Path
]
,
url
:
str
=
URL
,
url
:
str
=
URL
,
folder_in_archive
:
str
=
FOLDER_IN_ARCHIVE
,
folder_in_archive
:
str
=
FOLDER_IN_ARCHIVE
,
download
:
bool
=
False
,
download
:
bool
=
False
,
...
@@ -1028,6 +1029,10 @@ class GTZAN(Dataset):
...
@@ -1028,6 +1029,10 @@ class GTZAN(Dataset):
)
->
None
:
)
->
None
:
# super(GTZAN, self).__init__()
# super(GTZAN, self).__init__()
# Get string representation of 'root' in case Path object is passed
root
=
os
.
fspath
(
root
)
self
.
root
=
root
self
.
root
=
root
self
.
url
=
url
self
.
url
=
url
self
.
folder_in_archive
=
folder_in_archive
self
.
folder_in_archive
=
folder_in_archive
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment