"git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "c46a00c2370abbb42cc5a42c0f98e498787ad9e5"
Unverified Commit 46b7e271 authored by Akira Noda's avatar Akira Noda Committed by GitHub
Browse files

Add MovingMNIST dataset (#7042)



* add moving mnist dataset
Signed-off-by: default avatartsugumi-sys <tidemark0105@gmail.com>

* remove unused modules
Signed-off-by: default avatartsugumi-sys <tidemark0105@gmail.com>

* modify docstring
Signed-off-by: default avatartsugumi-sys <tidemark0105@gmail.com>

* modify docstring and docs
Signed-off-by: default avatartsugumi-sys <tidemark0105@gmail.com>

* add split and split ratio kwargs
Signed-off-by: default avatartsugumi-sys <tidemark0105@gmail.com>

* fix checking split argument
Signed-off-by: default avatartsugumi-sys <tidemark0105@gmail.com>

* remove unused package

* delete lines
Signed-off-by: default avatartsugumi-sys <tidemark0105@gmail.com>

* fix filename property
Signed-off-by: default avatartsugumi-sys <tidemark0105@gmail.com>

* fix reviews
Signed-off-by: default avatartsugumi-sys <tidemark0105@gmail.com>

* modify docstrings
Signed-off-by: default avatartsugumi-sys <tidemark0105@gmail.com>

* add split tests and etc
Signed-off-by: default avatartsugumi-sys <tidemark0105@gmail.com>

* fix tests
Signed-off-by: default avatartsugumi-sys <tidemark0105@gmail.com>
Signed-off-by: default avatartsugumi-sys <tidemark0105@gmail.com>
parent 32d254bb
...@@ -149,6 +149,14 @@ Video classification ...@@ -149,6 +149,14 @@ Video classification
Kinetics Kinetics
UCF101 UCF101
Video prediction
~~~~~~~~~~~~~~~~~~~~
.. autosummary::
:toctree: generated/
:template: class_dataset.rst
MovingMNIST
.. _base_classes_datasets: .. _base_classes_datasets:
......
...@@ -1494,6 +1494,37 @@ class QMNISTTestCase(MNISTTestCase): ...@@ -1494,6 +1494,37 @@ class QMNISTTestCase(MNISTTestCase):
assert len(dataset) == info["num_examples"] - 10000 assert len(dataset) == info["num_examples"] - 10000
class MovingMNISTTestCase(datasets_utils.DatasetTestCase):
    """Test case for ``datasets.MovingMNIST`` run over a grid of ``split`` / ``split_ratio`` values."""

    DATASET_CLASS = datasets.MovingMNIST
    FEATURE_TYPES = (torch.Tensor,)

    ADDITIONAL_CONFIGS = datasets_utils.combinations_grid(split=(None, "train", "test"), split_ratio=(10, 1, 19))

    def inject_fake_data(self, tmpdir, config):
        """Write a fake ``mnist_test_seq.npy`` into ``tmpdir`` and return the number of samples.

        The first ``split_ratio`` entries along axis 0 are all zeros and the remaining
        ones are all ones, so ``test_split`` can tell the two splits apart by value.
        """
        num_frames = 20
        num_samples = 20
        split_ratio = config["split_ratio"]

        base_folder = os.path.join(tmpdir, self.DATASET_CLASS.__name__)
        os.makedirs(base_folder, exist_ok=True)

        leading = np.zeros((split_ratio, num_samples, 64, 64))
        trailing = np.ones((num_frames - split_ratio, num_samples, 64, 64))
        np.save(os.path.join(base_folder, "mnist_test_seq.npy"), np.concatenate([leading, trailing]))
        return num_samples

    @datasets_utils.test_all_configs
    def test_split(self, config):
        """Check that the "train" split holds only zeros and the "test" split only ones."""
        split = config["split"]
        if split is None:
            # Nothing is asserted for the unsplit dataset.
            return

        expected_value = 0 if split == "train" else 1
        with self.create_dataset(config) as (dataset, info):
            assert (dataset.data == expected_value).all()
class DatasetFolderTestCase(datasets_utils.ImageDatasetTestCase): class DatasetFolderTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.DatasetFolder DATASET_CLASS = datasets.DatasetFolder
......
...@@ -296,6 +296,10 @@ def qmnist(): ...@@ -296,6 +296,10 @@ def qmnist():
) )
def moving_mnist():
    """Return the download configs collected while constructing ``MovingMNIST`` with ``download=True``."""

    def make_dataset():
        return datasets.MovingMNIST(ROOT, download=True)

    return collect_download_configs(make_dataset, name="MovingMNIST")
def omniglot(): def omniglot():
return itertools.chain( return itertools.chain(
*[ *[
......
...@@ -36,6 +36,7 @@ from .kitti import Kitti ...@@ -36,6 +36,7 @@ from .kitti import Kitti
from .lfw import LFWPairs, LFWPeople from .lfw import LFWPairs, LFWPeople
from .lsun import LSUN, LSUNClass from .lsun import LSUN, LSUNClass
from .mnist import EMNIST, FashionMNIST, KMNIST, MNIST, QMNIST from .mnist import EMNIST, FashionMNIST, KMNIST, MNIST, QMNIST
from .moving_mnist import MovingMNIST
from .omniglot import Omniglot from .omniglot import Omniglot
from .oxford_iiit_pet import OxfordIIITPet from .oxford_iiit_pet import OxfordIIITPet
from .pcam import PCAM from .pcam import PCAM
......
import os.path
from typing import Callable, Optional
import numpy as np
import torch
from torchvision.datasets.utils import download_url, verify_str_arg
from torchvision.datasets.vision import VisionDataset
class MovingMNIST(VisionDataset):
    """`MovingMNIST <http://www.cs.toronto.edu/~nitish/unsupervised_video/>`_ Dataset.

    Args:
        root (string): Root directory of dataset where ``MovingMNIST/mnist_test_seq.npy`` exists.
        split (string, optional): The dataset split, supports ``None`` (default), ``"train"`` and ``"test"``.
            If ``split=None``, the full data is returned.
        split_ratio (int, optional): The split ratio of number of frames, an integer in ``[1, 19]``.
            If ``split="train"``, the first split frames ``data[:, :split_ratio]`` are returned.
            If ``split="test"``, the last split frames ``data[:, split_ratio:]`` are returned.
            If ``split=None``, this parameter is ignored and all frames are returned.
        transform (callable, optional): A function/transform that takes in a torch Tensor
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.

    Raises:
        TypeError: If ``split_ratio`` is not an integer.
        ValueError: If ``split_ratio`` is outside ``[1, 19]``.
        RuntimeError: If the data file is not found (and was not downloaded).
    """

    _URL = "http://www.cs.toronto.edu/~nitish/unsupervised_video/mnist_test_seq.npy"

    def __init__(
        self,
        root: str,
        split: Optional[str] = None,
        split_ratio: int = 10,
        download: bool = False,
        transform: Optional[Callable] = None,
    ) -> None:
        super().__init__(root, transform=transform)

        self._base_folder = os.path.join(self.root, self.__class__.__name__)
        self._filename = self._URL.split("/")[-1]

        if split is not None:
            verify_str_arg(split, "split", ("train", "test"))
        self.split = split

        if not isinstance(split_ratio, int):
            raise TypeError(f"`split_ratio` should be an integer, but got {type(split_ratio)}")
        elif not (1 <= split_ratio <= 19):
            raise ValueError(f"`split_ratio` should be `1 <= split_ratio <= 19`, but got {split_ratio} instead.")
        self.split_ratio = split_ratio

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError("Dataset not found. You can use download=True to download it.")

        data = torch.from_numpy(np.load(os.path.join(self._base_folder, self._filename)))
        # Slice along the first axis of the stored array *before* the transpose below.
        # With ``split=None`` neither branch runs, so the full data is kept as documented
        # (previously a bare ``else`` dropped the first ``split_ratio`` entries in that case).
        if self.split == "train":
            data = data[: self.split_ratio]
        elif self.split == "test":
            data = data[self.split_ratio :]
        # Swap the first two axes so indexing selects a sample, and insert a
        # singleton axis at position 2 (the ``C`` of ``[T, C, H, W]``).
        self.data = data.transpose(0, 1).unsqueeze(2).contiguous()

    def __getitem__(self, idx: int) -> torch.Tensor:
        """
        Args:
            idx (int): Index

        Returns:
            torch.Tensor: Video frames (torch Tensor[T, C, H, W]). The `T` is the number of frames.
        """
        data = self.data[idx]
        if self.transform is not None:
            data = self.transform(data)

        return data

    def __len__(self) -> int:
        """Return the number of entries along the first axis of ``self.data``."""
        return len(self.data)

    def _check_exists(self) -> bool:
        """Return whether the data file is present in the dataset's base folder."""
        return os.path.exists(os.path.join(self._base_folder, self._filename))

    def download(self) -> None:
        """Download the data file into the base folder unless it already exists."""
        if self._check_exists():
            return

        download_url(
            url=self._URL,
            root=self._base_folder,
            filename=self._filename,
            md5="be083ec986bfe91a449d63653c411eb2",
        )
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment