Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
dabb6d52
Unverified
Commit
dabb6d52
authored
Mar 23, 2023
by
Shu
Committed by
GitHub
Mar 23, 2023
Browse files
MovingMNIST split fix (#7449)
Co-authored-by:
Philip Meier
<
github.pmeier@posteo.de
>
parent
76144bad
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
9 additions
and
8 deletions
+9
-8
test/test_datasets.py
test/test_datasets.py
+8
-7
torchvision/datasets/moving_mnist.py
torchvision/datasets/moving_mnist.py
+1
-1
No files found.
test/test_datasets.py
View file @
dabb6d52
...
...
@@ -1504,14 +1504,16 @@ class MovingMNISTTestCase(datasets_utils.DatasetTestCase):
ADDITIONAL_CONFIGS
=
combinations_grid
(
split
=
(
None
,
"train"
,
"test"
),
split_ratio
=
(
10
,
1
,
19
))
_NUM_FRAMES
=
20
def
inject_fake_data
(
self
,
tmpdir
,
config
):
base_folder
=
os
.
path
.
join
(
tmpdir
,
self
.
DATASET_CLASS
.
__name__
)
os
.
makedirs
(
base_folder
,
exist_ok
=
True
)
num_samples
=
20
num_samples
=
5
data
=
np
.
concatenate
(
[
np
.
zeros
((
config
[
"split_ratio"
],
num_samples
,
64
,
64
)),
np
.
ones
((
20
-
config
[
"split_ratio"
],
num_samples
,
64
,
64
)),
np
.
ones
((
self
.
_NUM_FRAMES
-
config
[
"split_ratio"
],
num_samples
,
64
,
64
)),
]
)
np
.
save
(
os
.
path
.
join
(
base_folder
,
"mnist_test_seq.npy"
),
data
)
...
...
@@ -1519,14 +1521,13 @@ class MovingMNISTTestCase(datasets_utils.DatasetTestCase):
@
datasets_utils
.
test_all_configs
def
test_split
(
self
,
config
):
if
config
[
"split"
]
is
None
:
return
with
self
.
create_dataset
(
config
)
as
(
dataset
,
info
):
with
self
.
create_dataset
(
config
)
as
(
dataset
,
_
):
if
config
[
"split"
]
==
"train"
:
assert
(
dataset
.
data
==
0
).
all
()
el
se
:
el
if
config
[
"split"
]
==
"test"
:
assert
(
dataset
.
data
==
1
).
all
()
else
:
assert
dataset
.
data
.
size
()[
1
]
==
self
.
_NUM_FRAMES
class
DatasetFolderTestCase
(
datasets_utils
.
ImageDatasetTestCase
):
...
...
torchvision/datasets/moving_mnist.py
View file @
dabb6d52
...
...
@@ -58,7 +58,7 @@ class MovingMNIST(VisionDataset):
data
=
torch
.
from_numpy
(
np
.
load
(
os
.
path
.
join
(
self
.
_base_folder
,
self
.
_filename
)))
if
self
.
split
==
"train"
:
data
=
data
[:
self
.
split_ratio
]
el
se
:
el
if
self
.
split
==
"test"
:
data
=
data
[
self
.
split_ratio
:]
self
.
data
=
data
.
transpose
(
0
,
1
).
unsqueeze
(
2
).
contiguous
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment