Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
f5aced81
Unverified
Commit
f5aced81
authored
Jan 19, 2021
by
Krishna Kalyan
Committed by
GitHub
Jan 19, 2021
Browse files
Refactor YesNo dataset (#1127)
Co-authored-by:
krishnakalyan3
<
skalyan@cloudera.com
>
parent
e43a8e76
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
29 additions
and
36 deletions
+29
-36
torchaudio/datasets/yesno.py
torchaudio/datasets/yesno.py
+29
-36
No files found.
torchaudio/datasets/yesno.py
View file @
f5aced81
...
...
@@ -11,23 +11,14 @@ from torchaudio.datasets.utils import (
extract_archive
,
)
URL
=
"http://www.openslr.org/resources/1/waves_yesno.tar.gz"
FOLDER_IN_ARCHIVE
=
"waves_yesno"
_CHECKSUMS
=
{
"http://www.openslr.org/resources/1/waves_yesno.tar.gz"
:
"962ff6e904d2df1126132ecec6978786"
}
def
load_yesno_item
(
fileid
:
str
,
path
:
str
,
ext_audio
:
str
)
->
Tuple
[
Tensor
,
int
,
List
[
int
]]:
# Read label
labels
=
[
int
(
c
)
for
c
in
fileid
.
split
(
"_"
)]
# Read wav
file_audio
=
os
.
path
.
join
(
path
,
fileid
+
ext_audio
)
waveform
,
sample_rate
=
torchaudio
.
load
(
file_audio
)
return
waveform
,
sample_rate
,
labels
_RELEASE_CONFIGS
=
{
"release1"
:
{
"folder_in_archive"
:
"waves_yesno"
,
"url"
:
"http://www.openslr.org/resources/1/waves_yesno.tar.gz"
,
"checksum"
:
"30301975fd8c5cac4040c261c0852f57cfa8adbbad2ce78e77e4986957445f27"
,
}
}
class
YESNO
(
Dataset
):
...
...
@@ -43,25 +34,26 @@ class YESNO(Dataset):
Whether to download the dataset if it is not found at root path. (default: ``False``).
"""
_ext_audio
=
".wav"
def
__init__
(
self
,
root
:
Union
[
str
,
Path
],
url
:
str
=
URL
,
folder_in_archive
:
str
=
FOLDER_IN_ARCHIVE
,
download
:
bool
=
False
)
->
None
:
def
__init__
(
self
,
root
:
Union
[
str
,
Path
]
,
url
:
str
=
_RELEASE_CONFIGS
[
"release1"
][
"url"
],
folder_in_archive
:
str
=
_RELEASE_CONFIGS
[
"release1"
][
"folder_in_archive"
]
,
download
:
bool
=
False
)
->
None
:
# Get string representation of 'root' in case Path object is passed
root
=
os
.
fspath
(
root
)
self
.
_parse_filesystem
(
root
,
url
,
folder_in_archive
,
download
)
def
_parse_filesystem
(
self
,
root
:
str
,
url
:
str
,
folder_in_archive
:
str
,
download
:
bool
)
->
None
:
root
=
Path
(
root
)
archive
=
os
.
path
.
basename
(
url
)
archive
=
os
.
path
.
join
(
root
,
archive
)
self
.
_path
=
os
.
path
.
join
(
root
,
folder_in_archive
)
archive
=
root
/
archive
self
.
_path
=
root
/
folder_in_archive
if
download
:
if
not
os
.
path
.
isdir
(
self
.
_path
):
if
not
os
.
path
.
isfile
(
archive
):
checksum
=
_
CHECKSUMS
.
get
(
url
,
None
)
checksum
=
_
RELEASE_CONFIGS
[
"release1"
][
"checksum"
]
download_url
(
url
,
root
,
hash_value
=
checksum
,
hash_type
=
"md5"
)
extract_archive
(
archive
)
...
...
@@ -70,7 +62,13 @@ class YESNO(Dataset):
"Dataset not found. Please use `download=True` to download it."
)
self
.
_walker
=
sorted
(
str
(
p
.
stem
)
for
p
in
Path
(
self
.
_path
).
glob
(
'*'
+
self
.
_ext_audio
))
self
.
_walker
=
sorted
(
str
(
p
.
stem
)
for
p
in
Path
(
self
.
_path
).
glob
(
"*.wav"
))
def
_load_item
(
self
,
fileid
:
str
,
path
:
str
):
labels
=
[
int
(
c
)
for
c
in
fileid
.
split
(
"_"
)]
file_audio
=
os
.
path
.
join
(
path
,
fileid
+
".wav"
)
waveform
,
sample_rate
=
torchaudio
.
load
(
file_audio
)
return
waveform
,
sample_rate
,
labels
def
__getitem__
(
self
,
n
:
int
)
->
Tuple
[
Tensor
,
int
,
List
[
int
]]:
"""Load the n-th sample from the dataset.
...
...
@@ -82,13 +80,8 @@ class YESNO(Dataset):
tuple: ``(waveform, sample_rate, labels)``
"""
fileid
=
self
.
_walker
[
n
]
item
=
load_yesno_item
(
fileid
,
self
.
_path
,
self
.
_ext_audio
)
# TODO Upon deprecation, uncomment line below and remove following code
# return item
waveform
,
sample_rate
,
labels
=
item
return
waveform
,
sample_rate
,
labels
item
=
self
.
_load_item
(
fileid
,
self
.
_path
)
return
item
def
__len__
(
self
)
->
int
:
return
len
(
self
.
_walker
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment