Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
hehl2
Torchaudio
Commits
37b4e136
"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "b756ec6e80b3d94c3ae7dc356bdbbdb426a05dca"
Unverified
Commit
37b4e136
authored
Nov 19, 2020
by
Bhargav Kathivarapu
Committed by
GitHub
Nov 18, 2020
Browse files
Add pathlib support for TEDLIUM (#1045)
parent
f3b9208f
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
26 additions
and
38 deletions
+26
-38
test/torchaudio_unittest/datasets/tedlium_test.py
test/torchaudio_unittest/datasets/tedlium_test.py
+14
-35
torchaudio/datasets/tedlium.py
torchaudio/datasets/tedlium.py
+12
-3
No files found.
test/torchaudio_unittest/datasets/tedlium_test.py
View file @
37b4e136
import
os
import
os
import
platform
import
platform
import
unittest
import
unittest
from
pathlib
import
Path
from
torchaudio.datasets
import
tedlium
from
torchaudio.datasets
import
tedlium
...
@@ -93,9 +94,7 @@ class Tedlium(TempDirMixin):
...
@@ -93,9 +94,7 @@ class Tedlium(TempDirMixin):
cls
.
samples
[
release
].
append
(
sample
)
cls
.
samples
[
release
].
append
(
sample
)
seed
+=
1
seed
+=
1
def
test_tedlium_release1
(
self
):
def
_test_tedlium
(
self
,
dataset
,
release
):
release
=
"release1"
dataset
=
tedlium
.
TEDLIUM
(
self
.
root_dir
,
release
=
release
)
num_samples
=
0
num_samples
=
0
for
i
,
(
data
,
sample_rate
,
transcript
,
talk_id
,
speaker_id
,
identifier
)
in
enumerate
(
dataset
):
for
i
,
(
data
,
sample_rate
,
transcript
,
talk_id
,
speaker_id
,
identifier
)
in
enumerate
(
dataset
):
self
.
assertEqual
(
data
,
self
.
samples
[
release
][
i
][
0
],
atol
=
5e-5
,
rtol
=
1e-8
)
self
.
assertEqual
(
data
,
self
.
samples
[
release
][
i
][
0
],
atol
=
5e-5
,
rtol
=
1e-8
)
...
@@ -113,45 +112,25 @@ class Tedlium(TempDirMixin):
...
@@ -113,45 +112,25 @@ class Tedlium(TempDirMixin):
phoenemes
=
[
f
"
{
key
}
{
' '
.
join
(
value
)
}
"
for
key
,
value
in
phoneme_dict
.
items
()]
phoenemes
=
[
f
"
{
key
}
{
' '
.
join
(
value
)
}
"
for
key
,
value
in
phoneme_dict
.
items
()]
assert
phoenemes
==
PHONEME
assert
phoenemes
==
PHONEME
def
test_tedlium_release
2
(
self
):
def
test_tedlium_release
1_str
(
self
):
release
=
"release
2
"
release
=
"release
1
"
dataset
=
tedlium
.
TEDLIUM
(
self
.
root_dir
,
release
=
release
)
dataset
=
tedlium
.
TEDLIUM
(
self
.
root_dir
,
release
=
release
)
num_samples
=
0
self
.
_test_tedlium
(
dataset
,
release
)
for
i
,
(
data
,
sample_rate
,
transcript
,
talk_id
,
speaker_id
,
identifier
)
in
enumerate
(
dataset
):
self
.
assertEqual
(
data
,
self
.
samples
[
release
][
i
][
0
],
atol
=
5e-5
,
rtol
=
1e-8
)
assert
sample_rate
==
self
.
samples
[
release
][
i
][
1
]
assert
transcript
==
self
.
samples
[
release
][
i
][
2
]
assert
talk_id
==
self
.
samples
[
release
][
i
][
3
]
assert
speaker_id
==
self
.
samples
[
release
][
i
][
4
]
assert
identifier
==
self
.
samples
[
release
][
i
][
5
]
num_samples
+=
1
assert
num_samples
==
len
(
self
.
samples
[
release
])
def
test_tedlium_release1_path
(
self
):
release
=
"release1"
dataset
=
tedlium
.
TEDLIUM
(
Path
(
self
.
root_dir
),
release
=
release
)
self
.
_test_tedlium
(
dataset
,
release
)
d
ataset
.
_dict_path
=
os
.
path
.
join
(
dataset
.
_path
,
f
"
{
release
}
.dic"
)
d
ef
test_tedlium_release2
(
self
):
phoneme_dict
=
dataset
.
phoneme_dict
release
=
"release2"
phoenemes
=
[
f
"
{
key
}
{
' '
.
join
(
value
)
}
"
for
key
,
value
in
phoneme_dict
.
items
()]
dataset
=
tedlium
.
TEDLIUM
(
self
.
root_dir
,
release
=
release
)
assert
phoenemes
==
PHONEME
self
.
_test_tedlium
(
dataset
,
release
)
def
test_tedlium_release3
(
self
):
def
test_tedlium_release3
(
self
):
release
=
"release3"
release
=
"release3"
dataset
=
tedlium
.
TEDLIUM
(
self
.
root_dir
,
release
=
release
)
dataset
=
tedlium
.
TEDLIUM
(
self
.
root_dir
,
release
=
release
)
num_samples
=
0
self
.
_test_tedlium
(
dataset
,
release
)
for
i
,
(
data
,
sample_rate
,
transcript
,
talk_id
,
speaker_id
,
identifier
)
in
enumerate
(
dataset
):
self
.
assertEqual
(
data
,
self
.
samples
[
release
][
i
][
0
],
atol
=
5e-5
,
rtol
=
1e-8
)
assert
sample_rate
==
self
.
samples
[
release
][
i
][
1
]
assert
transcript
==
self
.
samples
[
release
][
i
][
2
]
assert
talk_id
==
self
.
samples
[
release
][
i
][
3
]
assert
speaker_id
==
self
.
samples
[
release
][
i
][
4
]
assert
identifier
==
self
.
samples
[
release
][
i
][
5
]
num_samples
+=
1
assert
num_samples
==
len
(
self
.
samples
[
release
])
dataset
.
_dict_path
=
os
.
path
.
join
(
dataset
.
_path
,
f
"
{
release
}
.dic"
)
phoneme_dict
=
dataset
.
phoneme_dict
phoenemes
=
[
f
"
{
key
}
{
' '
.
join
(
value
)
}
"
for
key
,
value
in
phoneme_dict
.
items
()]
assert
phoenemes
==
PHONEME
class
TestTedliumSoundfile
(
Tedlium
,
TorchaudioTestCase
):
class
TestTedliumSoundfile
(
Tedlium
,
TorchaudioTestCase
):
...
...
torchaudio/datasets/tedlium.py
View file @
37b4e136
import
os
import
os
from
typing
import
Tuple
from
typing
import
Tuple
,
Union
from
pathlib
import
Path
import
torchaudio
import
torchaudio
from
torch
import
Tensor
from
torch
import
Tensor
...
@@ -46,7 +47,7 @@ class TEDLIUM(Dataset):
...
@@ -46,7 +47,7 @@ class TEDLIUM(Dataset):
Create a Dataset for Tedlium. It supports releases 1,2 and 3.
Create a Dataset for Tedlium. It supports releases 1,2 and 3.
Args:
Args:
root (str): Path to the directory where the dataset is found or downloaded.
root (str
or Path
): Path to the directory where the dataset is found or downloaded.
release (str, optional): Release version.
release (str, optional): Release version.
Allowed values are ``"release1"``, ``"release2"`` or ``"release3"``.
Allowed values are ``"release1"``, ``"release2"`` or ``"release3"``.
(default: ``"release1"``).
(default: ``"release1"``).
...
@@ -56,7 +57,12 @@ class TEDLIUM(Dataset):
...
@@ -56,7 +57,12 @@ class TEDLIUM(Dataset):
Whether to download the dataset if it is not found at root path. (default: ``False``).
Whether to download the dataset if it is not found at root path. (default: ``False``).
"""
"""
def
__init__
(
def
__init__
(
self
,
root
:
str
,
release
:
str
=
"release1"
,
subset
:
str
=
None
,
download
:
bool
=
False
,
audio_ext
=
".sph"
self
,
root
:
Union
[
str
,
Path
],
release
:
str
=
"release1"
,
subset
:
str
=
None
,
download
:
bool
=
False
,
audio_ext
=
".sph"
)
->
None
:
)
->
None
:
self
.
_ext_audio
=
audio_ext
self
.
_ext_audio
=
audio_ext
if
release
in
_RELEASE_CONFIGS
.
keys
():
if
release
in
_RELEASE_CONFIGS
.
keys
():
...
@@ -78,6 +84,9 @@ class TEDLIUM(Dataset):
...
@@ -78,6 +84,9 @@ class TEDLIUM(Dataset):
)
)
)
)
# Get string representation of 'root' in case Path object is passed
root
=
os
.
fspath
(
root
)
basename
=
os
.
path
.
basename
(
url
)
basename
=
os
.
path
.
basename
(
url
)
archive
=
os
.
path
.
join
(
root
,
basename
)
archive
=
os
.
path
.
join
(
root
,
basename
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment