Unverified Commit dabb6d52 authored by Shu's avatar Shu Committed by GitHub
Browse files

MovingMNIST split fix (#7449)


Co-authored-by: default avatarPhilip Meier <github.pmeier@posteo.de>
parent 76144bad
......@@ -1504,14 +1504,16 @@ class MovingMNISTTestCase(datasets_utils.DatasetTestCase):
ADDITIONAL_CONFIGS = combinations_grid(split=(None, "train", "test"), split_ratio=(10, 1, 19))
_NUM_FRAMES = 20
def inject_fake_data(self, tmpdir, config):
base_folder = os.path.join(tmpdir, self.DATASET_CLASS.__name__)
os.makedirs(base_folder, exist_ok=True)
num_samples = 20
num_samples = 5
data = np.concatenate(
[
np.zeros((config["split_ratio"], num_samples, 64, 64)),
np.ones((20 - config["split_ratio"], num_samples, 64, 64)),
np.ones((self._NUM_FRAMES - config["split_ratio"], num_samples, 64, 64)),
]
)
np.save(os.path.join(base_folder, "mnist_test_seq.npy"), data)
......@@ -1519,14 +1521,13 @@ class MovingMNISTTestCase(datasets_utils.DatasetTestCase):
@datasets_utils.test_all_configs
def test_split(self, config):
if config["split"] is None:
return
with self.create_dataset(config) as (dataset, info):
with self.create_dataset(config) as (dataset, _):
if config["split"] == "train":
assert (dataset.data == 0).all()
else:
elif config["split"] == "test":
assert (dataset.data == 1).all()
else:
assert dataset.data.size()[1] == self._NUM_FRAMES
class DatasetFolderTestCase(datasets_utils.ImageDatasetTestCase):
......
......@@ -58,7 +58,7 @@ class MovingMNIST(VisionDataset):
data = torch.from_numpy(np.load(os.path.join(self._base_folder, self._filename)))
if self.split == "train":
data = data[: self.split_ratio]
else:
elif self.split == "test":
data = data[self.split_ratio :]
self.data = data.transpose(0, 1).unsqueeze(2).contiguous()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment