Commit cc26cd81 authored by panning

merge v0.16.0

parents f78f29f5 fbb4cc54
......@@ -177,7 +177,7 @@ class Cityscapes(VisionDataset):
index (int): Index
Returns:
tuple: (image, target) where target is a tuple of all target types if target_type is a list with more
than one item. Otherwise target is a json object if target_type="polygon", else the image segmentation.
than one item. Otherwise, target is a json object if target_type="polygon", else the image segmentation.
"""
image = Image.open(self.images[index]).convert("RGB")
......
......@@ -11,7 +11,7 @@ class Country211(ImageFolder):
This dataset was built by filtering the images from the YFCC100m dataset
that have GPS coordinates corresponding to an ISO-3166 country code. The
dataset is balanced by sampling 150 train images, 50 validation images, and
100 test images images for each country.
100 test images for each country.
Args:
root (string): Root directory of the dataset.
......
import os
import pathlib
from typing import Callable, Optional
from typing import Any, Callable, Optional, Tuple
import PIL.Image
......@@ -76,7 +76,7 @@ class DTD(VisionDataset):
def __len__(self) -> int:
return len(self._image_files)
def __getitem__(self, idx):
def __getitem__(self, idx: int) -> Tuple[Any, Any]:
image_file, label = self._image_files[idx], self._labels[idx]
image = PIL.Image.open(image_file).convert("RGB")
......
......@@ -90,7 +90,7 @@ class FGVCAircraft(VisionDataset):
def __len__(self) -> int:
return len(self._image_files)
def __getitem__(self, idx) -> Tuple[Any, Any]:
def __getitem__(self, idx: int) -> Tuple[Any, Any]:
image_file, label = self._image_files[idx], self._labels[idx]
image = PIL.Image.open(image_file).convert("RGB")
......
......@@ -76,7 +76,7 @@ class Flowers102(VisionDataset):
def __len__(self) -> int:
return len(self._image_files)
def __getitem__(self, idx) -> Tuple[Any, Any]:
def __getitem__(self, idx: int) -> Tuple[Any, Any]:
image_file, label = self._image_files[idx], self._labels[idx]
image = PIL.Image.open(image_file).convert("RGB")
......
......@@ -69,7 +69,7 @@ class Food101(VisionDataset):
def __len__(self) -> int:
return len(self._image_files)
def __getitem__(self, idx) -> Tuple[Any, Any]:
def __getitem__(self, idx: int) -> Tuple[Any, Any]:
image_file, label = self._image_files[idx], self._labels[idx]
image = PIL.Image.open(image_file).convert("RGB")
......
......@@ -102,7 +102,7 @@ class HMDB51(VisionDataset):
output_format=output_format,
)
# we bookkeep the full version of video clips because we want to be able
# to return the meta data of full version rather than the subset version of
# to return the metadata of full version rather than the subset version of
# video clips
self.full_video_clips = video_clips
self.fold = fold
......
......@@ -21,6 +21,12 @@ META_FILE = "meta.bin"
class ImageNet(ImageFolder):
"""`ImageNet <http://image-net.org/>`_ 2012 Classification Dataset.
.. note::
Before using this class, it is required to download the ImageNet 2012 dataset from
`here <https://image-net.org/challenges/LSVRC/2012/2012-downloads.php>`_ and
place the files ``ILSVRC2012_devkit_t12.tar.gz`` and ``ILSVRC2012_img_train.tar``
or ``ILSVRC2012_img_val.tar`` based on ``split`` in the root directory.
Args:
root (string): Root directory of the ImageNet Dataset.
split (string, optional): The dataset split, supports ``train``, or ``val``.
......
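As a minimal sketch of the note above (the root path is a placeholder, and the two archives are assumed to have been downloaded and placed there by hand):

# Minimal usage sketch, assuming the archives named in the note above were
# placed in `root` manually; "/data/imagenet" is a placeholder path.
from torchvision.datasets import ImageNet

train_set = ImageNet(root="/data/imagenet", split="train")
val_set = ImageNet(root="/data/imagenet", split="val")
img, class_index = val_set[0]  # PIL image and integer class index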
import os
from typing import Any, Callable, List, Optional, Tuple
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from PIL import Image
......@@ -38,7 +38,7 @@ class _LFW(VisionDataset):
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
download: bool = False,
):
) -> None:
super().__init__(os.path.join(root, self.base_folder), transform=transform, target_transform=target_transform)
self.image_set = verify_str_arg(image_set.lower(), "image_set", self.file_dict.keys())
......@@ -62,7 +62,7 @@ class _LFW(VisionDataset):
img = Image.open(f)
return img.convert("RGB")
def _check_integrity(self):
def _check_integrity(self) -> bool:
st1 = check_integrity(os.path.join(self.root, self.filename), self.md5)
st2 = check_integrity(os.path.join(self.root, self.labels_file), self.checksums[self.labels_file])
if not st1 or not st2:
......@@ -71,7 +71,7 @@ class _LFW(VisionDataset):
return check_integrity(os.path.join(self.root, self.names), self.checksums[self.names])
return True
def download(self):
def download(self) -> None:
if self._check_integrity():
print("Files already downloaded and verified")
return
......@@ -81,13 +81,13 @@ class _LFW(VisionDataset):
if self.view == "people":
download_url(f"{self.download_url_prefix}{self.names}", self.root)
def _get_path(self, identity, no):
def _get_path(self, identity: str, no: Union[int, str]) -> str:
return os.path.join(self.images_dir, identity, f"{identity}_{int(no):04d}.jpg")
def extra_repr(self) -> str:
return f"Alignment: {self.image_set}\nSplit: {self.split}"
def __len__(self):
def __len__(self) -> int:
return len(self.data)
......@@ -119,13 +119,13 @@ class LFWPeople(_LFW):
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
download: bool = False,
):
) -> None:
super().__init__(root, split, image_set, "people", transform, target_transform, download)
self.class_to_idx = self._get_classes()
self.data, self.targets = self._get_people()
def _get_people(self):
def _get_people(self) -> Tuple[List[str], List[int]]:
data, targets = [], []
with open(os.path.join(self.root, self.labels_file)) as f:
lines = f.readlines()
......@@ -143,7 +143,7 @@ class LFWPeople(_LFW):
return data, targets
def _get_classes(self):
def _get_classes(self) -> Dict[str, int]:
with open(os.path.join(self.root, self.names)) as f:
lines = f.readlines()
names = [line.strip().split()[0] for line in lines]
......@@ -201,12 +201,12 @@ class LFWPairs(_LFW):
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
download: bool = False,
):
) -> None:
super().__init__(root, split, image_set, "pairs", transform, target_transform, download)
self.pair_names, self.data, self.targets = self._get_pairs(self.images_dir)
def _get_pairs(self, images_dir):
def _get_pairs(self, images_dir: str) -> Tuple[List[Tuple[str, str]], List[Tuple[str, str]], List[int]]:
pair_names, data, targets = [], [], []
with open(os.path.join(self.root, self.labels_file)) as f:
lines = f.readlines()
......
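A quick illustration of the zero-padded naming scheme that `_get_path` above encodes (the identity and photo number here are placeholders):

# Illustrative only: the f"{identity}_{int(no):04d}.jpg" scheme from
# _get_path above, with a placeholder identity and photo number.
import os

identity, no = "Aaron_Eckhart", "1"
print(os.path.join("lfw", identity, f"{identity}_{int(no):04d}.jpg"))
# lfw/Aaron_Eckhart/Aaron_Eckhart_0001.jpg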
......@@ -12,7 +12,7 @@ import numpy as np
import torch
from PIL import Image
from .utils import check_integrity, download_and_extract_archive, extract_archive, verify_str_arg
from .utils import _flip_byte_order, check_integrity, download_and_extract_archive, extract_archive, verify_str_arg
from .vision import VisionDataset
......@@ -366,7 +366,7 @@ class QMNIST(MNIST):
that takes in the target and transforms it.
train (bool,optional,compatibility): When argument 'what' is
not specified, this boolean decides whether to load the
training set ot the testing set. Default: True.
training set or the testing set. Default: True.
"""
subsets = {"train": "train", "test": "test", "test10k": "test", "test50k": "test", "nist": "nist"}
......@@ -519,13 +519,12 @@ def read_sn3_pascalvincent_tensor(path: str, strict: bool = True) -> torch.Tensor
torch_type = SN3_PASCALVINCENT_TYPEMAP[ty]
s = [get_int(data[4 * (i + 1) : 4 * (i + 2)]) for i in range(nd)]
num_bytes_per_value = torch.iinfo(torch_type).bits // 8
# The MNIST format uses the big endian byte order. If the system uses little endian byte order by default,
# we need to reverse the bytes before we can read them with torch.frombuffer().
needs_byte_reversal = sys.byteorder == "little" and num_bytes_per_value > 1
parsed = torch.frombuffer(bytearray(data), dtype=torch_type, offset=(4 * (nd + 1)))
if needs_byte_reversal:
parsed = parsed.flip(0)
# The MNIST format uses the big endian byte order, while `torch.frombuffer` uses whatever the system uses. In case
# that is little endian and the dtype has more than one byte, we need to flip them.
if sys.byteorder == "little" and parsed.element_size() > 1:
parsed = _flip_byte_order(parsed)
assert parsed.shape[0] == np.prod(s) or not strict
return parsed.view(*s)
......
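The comments above concern the SN3/IDX layout used by the MNIST files: a 4-byte big-endian magic number whose third byte encodes the element type and whose fourth byte the number of dimensions, followed by one big-endian int32 size per dimension. A standalone sketch of that header parse (illustrative, not the library's parser):

# Illustrative sketch of the IDX/SN3 header layout described in the hunk
# above; the real parser lives in torchvision's mnist.py.
import struct

def read_idx_header(data: bytes):
    magic = struct.unpack(">i", data[0:4])[0]  # 4-byte big-endian magic number
    nd = magic % 256   # fourth byte: number of dimensions
    ty = magic // 256  # third byte: dtype code (8 = uint8, 13 = float32, ...)
    sizes = [struct.unpack(">i", data[4 * (i + 1) : 4 * (i + 2)])[0] for i in range(nd)]
    return ty, nd, sizes

header = struct.pack(">iiii", 0x00000803, 60000, 28, 28)  # MNIST train-images header
print(read_idx_header(header))  # (8, 3, [60000, 28, 28])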
import os.path
from typing import Callable, Optional
import numpy as np
import torch
from torchvision.datasets.utils import download_url, verify_str_arg
from torchvision.datasets.vision import VisionDataset
class MovingMNIST(VisionDataset):
"""`MovingMNIST <http://www.cs.toronto.edu/~nitish/unsupervised_video/>`_ Dataset.
Args:
root (string): Root directory of dataset where ``MovingMNIST/mnist_test_seq.npy`` exists.
split (string, optional): The dataset split, supports ``None`` (default), ``"train"`` and ``"test"``.
If ``split=None``, the full data is returned.
split_ratio (int, optional): The split ratio of number of frames. If ``split="train"``, the first
``split_ratio`` frames ``data[:, :split_ratio]`` are returned. If ``split="test"``, the last frames
``data[:, split_ratio:]`` are returned. If ``split=None``, this parameter is ignored and all frames
are returned.
transform (callable, optional): A function/transform that takes in a torch Tensor
and returns a transformed version. E.g., ``transforms.RandomCrop``
download (bool, optional): If true, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again.
"""
_URL = "http://www.cs.toronto.edu/~nitish/unsupervised_video/mnist_test_seq.npy"
def __init__(
self,
root: str,
split: Optional[str] = None,
split_ratio: int = 10,
download: bool = False,
transform: Optional[Callable] = None,
) -> None:
super().__init__(root, transform=transform)
self._base_folder = os.path.join(self.root, self.__class__.__name__)
self._filename = self._URL.split("/")[-1]
if split is not None:
verify_str_arg(split, "split", ("train", "test"))
self.split = split
if not isinstance(split_ratio, int):
raise TypeError(f"`split_ratio` should be an integer, but got {type(split_ratio)}")
elif not (1 <= split_ratio <= 19):
raise ValueError(f"`split_ratio` should be `1 <= split_ratio <= 19`, but got {split_ratio} instead.")
self.split_ratio = split_ratio
if download:
self.download()
if not self._check_exists():
raise RuntimeError("Dataset not found. You can use download=True to download it.")
data = torch.from_numpy(np.load(os.path.join(self._base_folder, self._filename)))
if self.split == "train":
data = data[: self.split_ratio]
elif self.split == "test":
data = data[self.split_ratio :]
self.data = data.transpose(0, 1).unsqueeze(2).contiguous()
def __getitem__(self, idx: int) -> torch.Tensor:
"""
Args:
idx (int): Index
Returns:
torch.Tensor: Video frames (torch Tensor[T, C, H, W]), where ``T`` is the number of frames.
"""
data = self.data[idx]
if self.transform is not None:
data = self.transform(data)
return data
def __len__(self) -> int:
return len(self.data)
def _check_exists(self) -> bool:
return os.path.exists(os.path.join(self._base_folder, self._filename))
def download(self) -> None:
if self._check_exists():
return
download_url(
url=self._URL,
root=self._base_folder,
filename=self._filename,
md5="be083ec986bfe91a449d63653c411eb2",
)
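A short usage sketch for the class above (the root directory is a placeholder; ``download=True`` fetches ``mnist_test_seq.npy`` on first use):

# Hedged usage sketch for the MovingMNIST dataset added above.
import torch
from torchvision.datasets import MovingMNIST

# Full 20-frame sequences: each sample has shape [20, 1, 64, 64].
full = MovingMNIST(root="data", download=True)
# First 10 frames of each sequence as the "train" portion.
train = MovingMNIST(root="data", split="train", split_ratio=10, download=True)
video = train[0]
assert isinstance(video, torch.Tensor) and video.shape[0] == 10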
......@@ -15,7 +15,7 @@ class Places365(VisionDataset):
root (string): Root directory of the Places365 dataset.
split (string, optional): The dataset split. Can be one of ``train-standard`` (default), ``train-challenge``,
``val``.
small (bool, optional): If ``True``, uses the small images, i. e. resized to 256 x 256 pixels, instead of the
small (bool, optional): If ``True``, uses the small images, i.e. resized to 256 x 256 pixels, instead of the
high resolution ones.
download (bool, optional): If ``True``, downloads the dataset components and places them in ``root``. Already
downloaded archives are not downloaded again.
......@@ -32,7 +32,7 @@ class Places365(VisionDataset):
targets (list): The class_index value for each image in the dataset
Raises:
RuntimeError: If ``download is False`` and the meta files, i. e. the devkit, are not present or corrupted.
RuntimeError: If ``download is False`` and the meta files, i.e. the devkit, are not present or corrupted.
RuntimeError: If ``download is True`` and the image archive is already extracted.
"""
_SPLITS = ("train-standard", "train-challenge", "val")
......
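A brief usage sketch of the options documented above (the root path is a placeholder):

# Hedged sketch of the Places365 options described in the docstring above.
from torchvision.datasets import Places365

# 256x256 validation images; raises RuntimeError if the devkit is missing
# and download=False, per the docstring above.
ds = Places365(root="/data/places365", split="val", small=True, download=False)
img, class_index = ds[0]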
......@@ -59,7 +59,7 @@ class RenderedSST2(VisionDataset):
def __len__(self) -> int:
return len(self._samples)
def __getitem__(self, idx) -> Tuple[Any, Any]:
def __getitem__(self, idx: int) -> Tuple[Any, Any]:
image_file, label = self._samples[idx]
image = PIL.Image.open(image_file).convert("RGB")
......
......@@ -3,7 +3,7 @@ from typing import Any, Callable, Optional, Tuple
from PIL import Image
from .utils import check_integrity, download_url
from .utils import check_integrity, download_and_extract_archive, download_url
from .vision import VisionDataset
......@@ -90,17 +90,12 @@ class SBU(VisionDataset):
def download(self) -> None:
"""Download and extract the tarball, and download each individual photo."""
import tarfile
if self._check_integrity():
print("Files already downloaded and verified")
return
download_url(self.url, self.root, self.filename, self.md5_checksum)
# Extract file
with tarfile.open(os.path.join(self.root, self.filename), "r:gz") as tar:
tar.extractall(path=self.root)
download_and_extract_archive(self.url, self.root, self.root, self.filename, self.md5_checksum)
# Download individual photos
with open(os.path.join(self.root, "dataset", "SBU_captioned_photo_dataset_urls.txt")) as fh:
......
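The hunk above folds the manual download-then-tarfile sequence into torchvision's download_and_extract_archive helper, which downloads into download_root, optionally verifies an md5, and extracts into extract_root. A standalone sketch with placeholder values:

# Hedged sketch of the helper the SBU hunk above switches to; the URL,
# paths, and checksum here are placeholders.
from torchvision.datasets.utils import download_and_extract_archive

download_and_extract_archive(
    url="https://example.com/archive.tar.gz",  # placeholder URL
    download_root="data",   # where the archive file is stored
    extract_root="data",    # where its contents are extracted
    filename="archive.tar.gz",
    md5=None,               # pass a checksum string to verify integrity
)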
......@@ -15,7 +15,7 @@ class STL10(VisionDataset):
root (string): Root directory of dataset where directory
``stl10_binary`` exists.
split (string): One of {'train', 'test', 'unlabeled', 'train+unlabeled'}.
Accordingly dataset is selected.
Accordingly, dataset is selected.
folds (int, optional): One of {0-9} or None.
For training, loads one of the 10 pre-defined folds of 1k samples for the
standard evaluation procedure. If no value is passed, loads the 5k samples.
......
......@@ -55,7 +55,7 @@ class SUN397(VisionDataset):
def __len__(self) -> int:
return len(self._image_files)
def __getitem__(self, idx) -> Tuple[Any, Any]:
def __getitem__(self, idx: int) -> Tuple[Any, Any]:
image_file, label = self._image_files[idx], self._labels[idx]
image = PIL.Image.open(image_file).convert("RGB")
......
......@@ -78,7 +78,7 @@ class SVHN(VisionDataset):
loaded_mat = sio.loadmat(os.path.join(self.root, self.filename))
self.data = loaded_mat["X"]
# loading from the .mat file gives an np array of type np.uint8
# loading from the .mat file gives an np.ndarray of type np.uint8
# converting to np.int64, so that we have a LongTensor after
# the conversion from the numpy array
# the squeeze is needed to obtain a 1D tensor
......
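The comment above is about dtype plumbing: the labels come out of the .mat file as a column of uint8, and a cast plus squeeze yields the 1-D LongTensor that loss functions expect. A tiny illustration with a synthetic array (not the actual file contents):

# Hedged illustration of the uint8 -> int64 label conversion noted above.
import numpy as np
import torch

labels = np.array([[1], [2], [10]], dtype=np.uint8)  # synthetic stand-in for loaded_mat["y"]
labels = labels.astype(np.int64).squeeze()           # 1-D int64 array
target = torch.from_numpy(labels)                    # LongTensor([1, 2, 10])
print(target.dtype, target.shape)                    # torch.int64 torch.Size([3])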
......@@ -93,7 +93,7 @@ class UCF101(VisionDataset):
output_format=output_format,
)
# we bookkeep the full version of video clips because we want to be able
# to return the meta data of full version rather than the subset version of
# to return the metadata of full version rather than the subset version of
# video clips
self.full_video_clips = video_clips
self.indices = self._select_fold(video_list, annotation_path, fold, train)
......
......@@ -48,19 +48,6 @@ def _urlretrieve(url: str, filename: str, chunk_size: int = 1024 * 32) -> None:
_save_response_content(iter(lambda: response.read(chunk_size), b""), filename, length=response.length)
def gen_bar_updater() -> Callable[[int, int, int], None]:
warnings.warn("The function `gen_bar_update` is deprecated since 0.13 and will be removed in 0.15.")
pbar = tqdm(total=None)
def bar_update(count, block_size, total_size):
if pbar.total is None and total_size:
pbar.total = total_size
progress_bytes = count * block_size
pbar.update(progress_bytes - pbar.n)
return bar_update
def calculate_md5(fpath: str, chunk_size: int = 1024 * 1024) -> str:
# Setting the `usedforsecurity` flag does not change anything about the functionality, but indicates that we are
# not using the MD5 checksum for cryptography. This enables its usage in restricted environments like FIPS. Without
......@@ -70,7 +57,7 @@ def calculate_md5(fpath: str, chunk_size: int = 1024 * 1024) -> str:
else:
md5 = hashlib.md5()
with open(fpath, "rb") as f:
for chunk in iter(lambda: f.read(chunk_size), b""):
while chunk := f.read(chunk_size):
md5.update(chunk)
return md5.hexdigest()
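The new loop above uses the Python 3.8 walrus operator: assignment and truthiness test in one expression, terminating when read() returns the falsy empty bytes object. The same pattern as a self-contained sketch (md5_of_file is an illustrative name, not a library function):

# Hedged standalone sketch of the chunked-MD5 pattern used above.
import hashlib

def md5_of_file(path: str, chunk_size: int = 1024 * 1024) -> str:
    md5 = hashlib.md5()
    with open(path, "rb") as f:
        while chunk := f.read(chunk_size):  # stops on b"", which is falsy
            md5.update(chunk)
    return md5.hexdigest()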
......@@ -464,7 +451,7 @@ def verify_str_arg(
valid_values: Optional[Iterable[T]] = None,
custom_msg: Optional[str] = None,
) -> T:
if not isinstance(value, torch._six.string_classes):
if not isinstance(value, str):
if arg is None:
msg = "Expected type str, but got type {type}."
else:
......@@ -520,3 +507,9 @@ def _read_pfm(file_name: str, slice_channels: int = 2) -> np.ndarray:
data = np.flip(data, axis=1) # flip on h dimension
data = data[:slice_channels, :, :]
return data.astype(np.float32)
def _flip_byte_order(t: torch.Tensor) -> torch.Tensor:
return (
t.contiguous().view(torch.uint8).view(*t.shape, t.element_size()).flip(-1).view(*t.shape[:-1], -1).view(t.dtype)
)
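_flip_byte_order works by viewing the tensor as raw uint8, reshaping so the last axis holds the bytes of a single element, reversing that axis, and viewing the bytes back as the original dtype. A quick check of that behavior (values are illustrative):

# Quick check of the byte flip above; 1 and 256 are byte-swapped images of
# each other in int16, so the flip should exchange them.
import torch
from torchvision.datasets.utils import _flip_byte_order

t = torch.tensor([1, 256], dtype=torch.int16)
print(_flip_byte_order(t))  # tensor([256, 1], dtype=torch.int16)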
......@@ -49,7 +49,7 @@ class _VideoTimestampsDataset:
Dataset used to parallelize the reading of the timestamps
of a list of videos, given their paths in the filesystem.
Used in VideoClips and defined at top level so it can be
Used in VideoClips and defined at top level, so it can be
pickled when forking.
"""
......@@ -187,9 +187,9 @@ class VideoClips:
}
return type(self)(
video_paths,
self.num_frames,
self.step,
self.frame_rate,
clip_length_in_frames=self.num_frames,
frames_between_clips=self.step,
frame_rate=self.frame_rate,
_precomputed_metadata=metadata,
num_workers=self.num_workers,
_video_width=self._video_width,
......
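The hunk above switches the subset constructor to keyword arguments, so the call no longer depends on the positional order of VideoClips' parameters. A hedged sketch of the resulting call shape (paths and numbers are placeholders; real video files are needed for metadata computation):

# Sketch only: placeholder paths; VideoClips reads the files to index clips.
from torchvision.datasets.video_utils import VideoClips

clips = VideoClips(
    ["a.mp4", "b.mp4"],
    clip_length_in_frames=16,
    frames_between_clips=8,
    frame_rate=None,
    num_workers=2,
)
video, audio, info, video_idx = clips.get_clip(0)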