"docs/git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "b7d44d97af5778012817bce06da7eec08ec2ffc3"
Unverified Commit dabb6d52 authored by Shu's avatar Shu Committed by GitHub
Browse files

MovingMNIST split fix (#7449)


Co-authored-by: default avatarPhilip Meier <github.pmeier@posteo.de>
parent 76144bad
...@@ -1504,14 +1504,16 @@ class MovingMNISTTestCase(datasets_utils.DatasetTestCase): ...@@ -1504,14 +1504,16 @@ class MovingMNISTTestCase(datasets_utils.DatasetTestCase):
ADDITIONAL_CONFIGS = combinations_grid(split=(None, "train", "test"), split_ratio=(10, 1, 19)) ADDITIONAL_CONFIGS = combinations_grid(split=(None, "train", "test"), split_ratio=(10, 1, 19))
_NUM_FRAMES = 20
def inject_fake_data(self, tmpdir, config): def inject_fake_data(self, tmpdir, config):
base_folder = os.path.join(tmpdir, self.DATASET_CLASS.__name__) base_folder = os.path.join(tmpdir, self.DATASET_CLASS.__name__)
os.makedirs(base_folder, exist_ok=True) os.makedirs(base_folder, exist_ok=True)
num_samples = 20 num_samples = 5
data = np.concatenate( data = np.concatenate(
[ [
np.zeros((config["split_ratio"], num_samples, 64, 64)), np.zeros((config["split_ratio"], num_samples, 64, 64)),
np.ones((20 - config["split_ratio"], num_samples, 64, 64)), np.ones((self._NUM_FRAMES - config["split_ratio"], num_samples, 64, 64)),
] ]
) )
np.save(os.path.join(base_folder, "mnist_test_seq.npy"), data) np.save(os.path.join(base_folder, "mnist_test_seq.npy"), data)
...@@ -1519,14 +1521,13 @@ class MovingMNISTTestCase(datasets_utils.DatasetTestCase): ...@@ -1519,14 +1521,13 @@ class MovingMNISTTestCase(datasets_utils.DatasetTestCase):
@datasets_utils.test_all_configs @datasets_utils.test_all_configs
def test_split(self, config): def test_split(self, config):
if config["split"] is None: with self.create_dataset(config) as (dataset, _):
return
with self.create_dataset(config) as (dataset, info):
if config["split"] == "train": if config["split"] == "train":
assert (dataset.data == 0).all() assert (dataset.data == 0).all()
else: elif config["split"] == "test":
assert (dataset.data == 1).all() assert (dataset.data == 1).all()
else:
assert dataset.data.size()[1] == self._NUM_FRAMES
class DatasetFolderTestCase(datasets_utils.ImageDatasetTestCase): class DatasetFolderTestCase(datasets_utils.ImageDatasetTestCase):
......
...@@ -58,7 +58,7 @@ class MovingMNIST(VisionDataset): ...@@ -58,7 +58,7 @@ class MovingMNIST(VisionDataset):
data = torch.from_numpy(np.load(os.path.join(self._base_folder, self._filename))) data = torch.from_numpy(np.load(os.path.join(self._base_folder, self._filename)))
if self.split == "train": if self.split == "train":
data = data[: self.split_ratio] data = data[: self.split_ratio]
else: elif self.split == "test":
data = data[self.split_ratio :] data = data[self.split_ratio :]
self.data = data.transpose(0, 1).unsqueeze(2).contiguous() self.data = data.transpose(0, 1).unsqueeze(2).contiguous()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment