Unverified Commit 3e4d062c authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

Expand tests for prototype datasets (#5187)

* refactor prototype datasets tests

* skip tests with insufficient third party dependencies

* cleanup

* add tests for SBD prototype dataset

* add tests for SEMEION prototype dataset

* add tests for VOC prototype dataset

* add tests for CelebA prototype dataset

* add tests for DTD prototype dataset

* add tests for FER2013 prototype dataset

* add tests for CLEVR prototype dataset

* add tests for oxford-iiit-pet prototype dataset

* enforce tests for new datasets

* add missing archive generation for oxford-iiit-pet tests

* add tests for CUB200 prototype datasets

* fix split generation

* add capability to mark parametrization and xfail cub200 traverse tests
parent bf073e78
import collections.abc
import contextlib import contextlib
import csv
import functools import functools
import gzip import gzip
import itertools import itertools
...@@ -6,14 +8,17 @@ import json ...@@ -6,14 +8,17 @@ import json
import lzma import lzma
import pathlib import pathlib
import pickle import pickle
import random
import tempfile import tempfile
from collections import defaultdict, UserList import xml.etree.ElementTree as ET
from collections import defaultdict, Counter, UserDict
import numpy as np import numpy as np
import PIL.Image import PIL.Image
import pytest import pytest
import torch import torch
from datasets_utils import make_zip, make_tar, create_image_folder from datasets_utils import make_zip, make_tar, create_image_folder, create_image_file
from torch.nn.functional import one_hot
from torch.testing import make_tensor as _make_tensor from torch.testing import make_tensor as _make_tensor
from torchvision.prototype import datasets from torchvision.prototype import datasets
from torchvision.prototype.datasets._api import DEFAULT_DECODER_MAP, DEFAULT_DECODER, find from torchvision.prototype.datasets._api import DEFAULT_DECODER_MAP, DEFAULT_DECODER, find
...@@ -129,10 +134,53 @@ class DatasetMock: ...@@ -129,10 +134,53 @@ class DatasetMock:
return datapipe, mock_info return datapipe, mock_info
def config_id(name, config):
    """Return a pytest parametrization id like ``"<name>-<value>-..."``.

    Boolean options are rendered as the option name itself when true and as
    ``"no_<option>"`` when false; every other value is rendered with ``str``.
    """
    parts = [name]
    # Fix: the loop variable previously rebound (shadowed) the `name` parameter.
    for option, value in config.items():
        if isinstance(value, bool):
            part = ("" if value else "no_") + option
        else:
            part = str(value)
        parts.append(part)
    return "-".join(parts)
def parametrize_dataset_mocks(*dataset_mocks, marks=None):
    """Build a ``pytest.mark.parametrize`` decorator over (dataset_mock, config) pairs.

    Args:
        *dataset_mocks: Each positional argument may be a single ``DatasetMock``,
            a sequence of ``DatasetMock``'s, or a mapping of names to ``DatasetMock``'s.
        marks: Optional mapping from dataset name to pytest mark(s) that are
            attached to every parametrization of that dataset.

    Raises:
        pytest.UsageError: If a positional argument or ``marks`` has an
            unsupported type.
    """
    mocks = {}
    for mock in dataset_mocks:
        if isinstance(mock, DatasetMock):
            mocks[mock.name] = mock
        elif isinstance(mock, collections.abc.Sequence):
            mocks.update({mock_.name: mock_ for mock_ in mock})
        elif isinstance(mock, collections.abc.Mapping):
            mocks.update(mock)
        else:
            raise pytest.UsageError(
                f"The positional arguments passed to `parametrize_dataset_mocks` can either be a `DatasetMock`, "
                f"a sequence of `DatasetMock`'s, or a mapping of names to `DatasetMock`'s, "
                f"but got {mock} instead."
            )
    dataset_mocks = mocks

    if marks is None:
        marks = {}
    elif not isinstance(marks, collections.abc.Mapping):
        # Fix: this previously raised a bare `pytest.UsageError()` with no
        # message, giving the user no hint about what went wrong.
        raise pytest.UsageError(
            f"Argument `marks` can either be `None` or a mapping of dataset mock names to pytest marks, "
            f"but got {marks} instead."
        )

    return pytest.mark.parametrize(
        ("dataset_mock", "config"),
        [
            pytest.param(dataset_mock, config, id=config_id(name, config), marks=marks.get(name, ()))
            for name, dataset_mock in dataset_mocks.items()
            for config in dataset_mock.configs
        ],
    )
class DatasetMocks(UserDict):
    """Mapping of dataset name -> ``DatasetMock``, usable as a registration decorator."""

    def set_from_named_callable(self, fn):
        """Register *fn* under its dash-ified ``__name__`` and return it unchanged."""
        dataset_name = fn.__name__.replace("_", "-")
        self[dataset_name] = DatasetMock(dataset_name, fn)
        return fn
...@@ -214,7 +262,7 @@ class MNISTMockData: ...@@ -214,7 +262,7 @@ class MNISTMockData:
return num_samples return num_samples
@DATASET_MOCKS.append_named_callable @DATASET_MOCKS.set_from_named_callable
def mnist(info, root, config): def mnist(info, root, config):
train = config.split == "train" train = config.split == "train"
images_file = f"{'train' if train else 't10k'}-images-idx3-ubyte.gz" images_file = f"{'train' if train else 't10k'}-images-idx3-ubyte.gz"
...@@ -227,10 +275,10 @@ def mnist(info, root, config): ...@@ -227,10 +275,10 @@ def mnist(info, root, config):
) )
DATASET_MOCKS.extend([DatasetMock(name, mnist) for name in ["fashionmnist", "kmnist"]]) DATASET_MOCKS.update({name: DatasetMock(name, mnist) for name in ["fashionmnist", "kmnist"]})
@DATASET_MOCKS.append_named_callable @DATASET_MOCKS.set_from_named_callable
def emnist(info, root, _): def emnist(info, root, _):
# The image sets that merge some lower case letters in their respective upper case variant, still use dense # The image sets that merge some lower case letters in their respective upper case variant, still use dense
# labels in the data files. Thus, num_categories != len(categories) there. # labels in the data files. Thus, num_categories != len(categories) there.
...@@ -259,7 +307,7 @@ def emnist(info, root, _): ...@@ -259,7 +307,7 @@ def emnist(info, root, _):
return mock_infos return mock_infos
@DATASET_MOCKS.append_named_callable @DATASET_MOCKS.set_from_named_callable
def qmnist(info, root, config): def qmnist(info, root, config):
num_categories = len(info.categories) num_categories = len(info.categories)
if config.split == "train": if config.split == "train":
...@@ -338,7 +386,7 @@ class CIFARMockData: ...@@ -338,7 +386,7 @@ class CIFARMockData:
make_tar(root, name, folder, compression="gz") make_tar(root, name, folder, compression="gz")
@DATASET_MOCKS.append_named_callable @DATASET_MOCKS.set_from_named_callable
def cifar10(info, root, config): def cifar10(info, root, config):
train_files = [f"data_batch_{idx}" for idx in range(1, 6)] train_files = [f"data_batch_{idx}" for idx in range(1, 6)]
test_files = ["test_batch"] test_files = ["test_batch"]
...@@ -356,7 +404,7 @@ def cifar10(info, root, config): ...@@ -356,7 +404,7 @@ def cifar10(info, root, config):
return len(train_files if config.split == "train" else test_files) return len(train_files if config.split == "train" else test_files)
@DATASET_MOCKS.append_named_callable @DATASET_MOCKS.set_from_named_callable
def cifar100(info, root, config): def cifar100(info, root, config):
train_files = ["train"] train_files = ["train"]
test_files = ["test"] test_files = ["test"]
...@@ -374,7 +422,7 @@ def cifar100(info, root, config): ...@@ -374,7 +422,7 @@ def cifar100(info, root, config):
return len(train_files if config.split == "train" else test_files) return len(train_files if config.split == "train" else test_files)
@DATASET_MOCKS.append_named_callable @DATASET_MOCKS.set_from_named_callable
def caltech101(info, root, config): def caltech101(info, root, config):
def create_ann_file(root, name): def create_ann_file(root, name):
import scipy.io import scipy.io
...@@ -424,7 +472,7 @@ def caltech101(info, root, config): ...@@ -424,7 +472,7 @@ def caltech101(info, root, config):
return num_images_per_category * len(info.categories) return num_images_per_category * len(info.categories)
@DATASET_MOCKS.append_named_callable @DATASET_MOCKS.set_from_named_callable
def caltech256(info, root, config): def caltech256(info, root, config):
dir = root / "256_ObjectCategories" dir = root / "256_ObjectCategories"
num_images_per_category = 2 num_images_per_category = 2
...@@ -444,7 +492,7 @@ def caltech256(info, root, config): ...@@ -444,7 +492,7 @@ def caltech256(info, root, config):
return num_images_per_category * len(info.categories) return num_images_per_category * len(info.categories)
@DATASET_MOCKS.append_named_callable @DATASET_MOCKS.set_from_named_callable
def imagenet(info, root, config): def imagenet(info, root, config):
wnids = tuple(info.extra.wnid_to_category.keys()) wnids = tuple(info.extra.wnid_to_category.keys())
if config.split == "train": if config.split == "train":
...@@ -599,7 +647,7 @@ class CocoMockData: ...@@ -599,7 +647,7 @@ class CocoMockData:
return num_samples return num_samples
@DATASET_MOCKS.append_named_callable @DATASET_MOCKS.set_from_named_callable
def coco(info, root, config): def coco(info, root, config):
return dict( return dict(
zip( zip(
...@@ -609,23 +657,644 @@ def coco(info, root, config): ...@@ -609,23 +657,644 @@ def coco(info, root, config):
) )
class SBDMockData:
    """Generates mock data for the SBD (Semantic Boundaries Dataset) prototype dataset."""

    # SBD annotates 20 object categories.
    _NUM_CATEGORIES = 20

    @classmethod
    def _make_split_files(cls, root_map):
        """Write one ``<split>.txt`` id-list file per split into ``root_map[split]``.

        Returns the sorted union of all ids and a ``split -> num_ids`` mapping.
        """
        ids_map = {
            split: [f"2008_{idx:06d}" for idx in idcs]
            for split, idcs in (
                ("train", [0, 1, 2]),
                ("train_noval", [0, 2]),  # "train_noval" reuses a subset of the train ids
                ("val", [3]),
            )
        }

        for split, ids in ids_map.items():
            with open(root_map[split] / f"{split}.txt", "w") as fh:
                fh.writelines(f"{id}\n" for id in ids)

        return sorted(set(itertools.chain(*ids_map.values()))), {split: len(ids) for split, ids in ids_map.items()}

    @classmethod
    def _make_anns_folder(cls, root, name, ids):
        """Create one ``<id>.mat`` annotation per id and return the random image sizes used."""
        from scipy.io import savemat

        anns_folder = root / name
        anns_folder.mkdir()
        # One random (height, width) per id; the image folder reuses these sizes.
        sizes = torch.randint(1, 9, size=(len(ids), 2)).tolist()
        for id, size in zip(ids, sizes):
            savemat(
                anns_folder / f"{id}.mat",
                {
                    "GTcls": {
                        "Boundaries": cls._make_boundaries(size),
                        "Segmentation": cls._make_segmentation(size),
                    }
                },
            )
        return sizes

    @classmethod
    def _make_boundaries(cls, size):
        """Return one sparse binary boundary matrix per category (as nested lists for savemat)."""
        from scipy.sparse import csc_matrix

        return [
            [csc_matrix(torch.randint(0, 2, size=size, dtype=torch.uint8).numpy())] for _ in range(cls._NUM_CATEGORIES)
        ]

    @classmethod
    def _make_segmentation(cls, size):
        """Return a random dense segmentation map with values in [0, num_categories]."""
        return torch.randint(0, cls._NUM_CATEGORIES + 1, size=size, dtype=torch.uint8).numpy()

    @classmethod
    def generate(cls, root):
        """Create the ``benchmark.tgz`` archive under *root* and return per-split sample counts."""
        archive_folder = root / "benchmark_RELEASE"
        dataset_folder = archive_folder / "dataset"
        dataset_folder.mkdir(parents=True, exist_ok=True)

        # All split files live inside the archive except "train_noval", which
        # is a separate download placed directly under root.
        ids, num_samples_map = cls._make_split_files(defaultdict(lambda: dataset_folder, {"train_noval": root}))
        sizes = cls._make_anns_folder(dataset_folder, "cls", ids)
        create_image_folder(
            dataset_folder, "img", lambda idx: f"{ids[idx]}.jpg", num_examples=len(ids), size=lambda idx: sizes[idx]
        )
        make_tar(root, "benchmark.tgz", archive_folder, compression="gz")

        return num_samples_map
@DATASET_MOCKS.set_from_named_callable
def sbd(info, root, _):
    """Generate SBD mock data once and map every config to its split's sample count."""
    samples_per_split = SBDMockData.generate(root)
    return {cfg: samples_per_split[cfg.split] for cfg in info._configs}
@DATASET_MOCKS.set_from_named_callable
def semeion(info, root, config):
    """Write a tiny fake ``semeion.data`` file and return its number of rows.

    Each row is 256 float pixel values followed by a one-hot label vector.
    """
    num_samples = 3

    pixels = torch.rand(num_samples, 256)
    targets = one_hot(torch.randint(len(info.categories), size=(num_samples,)))

    with open(root / "semeion.data", "w") as fh:
        for image_row, label_row in zip(pixels, targets):
            image_part = " ".join(f"{value.item():.4f}" for value in image_row)
            label_part = " ".join(str(value.item()) for value in label_row)
            fh.write(f"{image_part} {label_part}\n")

    return num_samples
class VOCMockData:
    """Generates mock data for the Pascal VOC prototype dataset."""

    # Archive file names as published per year.
    _TRAIN_VAL_FILE_NAMES = {
        "2007": "VOCtrainval_06-Nov-2007.tar",
        "2008": "VOCtrainval_14-Jul-2008.tar",
        "2009": "VOCtrainval_11-May-2009.tar",
        "2010": "VOCtrainval_03-May-2010.tar",
        "2011": "VOCtrainval_25-May-2011.tar",
        "2012": "VOCtrainval_11-May-2012.tar",
    }
    # Only the 2007 test set is available for download.
    _TEST_FILE_NAMES = {
        "2007": "VOCtest_06-Nov-2007.tar",
    }

    @classmethod
    def _make_split_files(cls, root, *, year, trainval):
        """Write ``ImageSets/{Main,Segmentation}/<split>.txt`` id lists.

        Returns the sorted union of all ids and a ``split -> num_ids`` mapping.
        """
        split_folder = root / "ImageSets"

        if trainval:
            idcs_map = {
                "train": [0, 1, 2],
                "val": [3, 4],
            }
            # "trainval" is the concatenation of "train" and "val".
            idcs_map["trainval"] = [*idcs_map["train"], *idcs_map["val"]]
        else:
            idcs_map = {
                "test": [5],
            }
        ids_map = {split: [f"{year}_{idx:06d}" for idx in idcs] for split, idcs in idcs_map.items()}

        for task_sub_folder in ("Main", "Segmentation"):
            task_folder = split_folder / task_sub_folder
            task_folder.mkdir(parents=True, exist_ok=True)
            for split, ids in ids_map.items():
                with open(task_folder / f"{split}.txt", "w") as fh:
                    fh.writelines(f"{id}\n" for id in ids)

        return sorted(set(itertools.chain(*ids_map.values()))), {split: len(ids) for split, ids in ids_map.items()}

    @classmethod
    def _make_detection_anns_folder(cls, root, name, *, file_name_fn, num_examples):
        """Create ``num_examples`` detection annotation XML files under ``root/name``."""
        folder = root / name
        folder.mkdir(parents=True, exist_ok=True)

        for idx in range(num_examples):
            cls._make_detection_ann_file(folder, file_name_fn(idx))

    @classmethod
    def _make_detection_ann_file(cls, root, name):
        """Write a minimal VOC detection annotation (one object with a bounding box)."""

        def add_child(parent, name, text=None):
            child = ET.SubElement(parent, name)
            child.text = text
            return child

        def add_name(obj, name="dog"):
            add_child(obj, "name", name)
            return name

        def add_bndbox(obj, bndbox=None):
            if bndbox is None:
                bndbox = {"xmin": "1", "xmax": "2", "ymin": "3", "ymax": "4"}

            obj = add_child(obj, "bndbox")
            for name, text in bndbox.items():
                add_child(obj, name, text)

            return bndbox

        annotation = ET.Element("annotation")
        obj = add_child(annotation, "object")
        data = dict(name=add_name(obj), bndbox=add_bndbox(obj))

        with open(root / name, "wb") as fh:
            fh.write(ET.tostring(annotation))

        return data

    @classmethod
    def generate(cls, root, *, year, trainval):
        """Create the year's archive under *root* and return per-split sample counts."""
        archive_folder = root
        # The 2011 archive nests everything inside an extra "TrainVal" folder.
        if year == "2011":
            archive_folder /= "TrainVal"
        data_folder = archive_folder / "VOCdevkit" / f"VOC{year}"
        data_folder.mkdir(parents=True, exist_ok=True)

        ids, num_samples_map = cls._make_split_files(data_folder, year=year, trainval=trainval)
        for make_folder_fn, name, suffix in [
            (create_image_folder, "JPEGImages", ".jpg"),
            (create_image_folder, "SegmentationClass", ".png"),
            (cls._make_detection_anns_folder, "Annotations", ".xml"),
        ]:
            make_folder_fn(data_folder, name, file_name_fn=lambda idx: ids[idx] + suffix, num_examples=len(ids))
        make_tar(root, (cls._TRAIN_VAL_FILE_NAMES if trainval else cls._TEST_FILE_NAMES)[year], data_folder)

        return num_samples_map
@DATASET_MOCKS.set_from_named_callable
def voc(info, root, config):
    """Generate VOC mock data for the requested year and report sample counts.

    Counts are reported for every config of the same year that the generated
    archive actually covers: the trainval archive serves the non-test splits,
    the test archive serves only the "test" split.
    """
    is_trainval = config.split != "test"
    counts = VOCMockData.generate(root, year=config.year, trainval=is_trainval)

    covered = {}
    for other in info._configs:
        if other.year != config.year:
            continue
        if (other.split == "test") ^ is_trainval:
            covered[other] = counts[other.split]
    return covered
class CelebAMockData:
    """Generates mock data for the CelebA prototype dataset."""

    @classmethod
    def _make_ann_file(cls, root, name, data, *, field_names=None):
        """Write a CelebA-style annotation file (CRLF line endings).

        When *field_names* is given, a row count line and a header line are
        written before the data rows, matching the real annotation files.
        """
        with open(root / name, "w") as file:
            if field_names:
                file.write(f"{len(data)}\r\n")
                file.write(" ".join(field_names) + "\r\n")
            file.writelines(" ".join(str(item) for item in row) + "\r\n" for row in data)

    # Split encoding used by list_eval_partition.txt.
    _SPLIT_TO_IDX = {
        "train": 0,
        "val": 1,
        "test": 2,
    }

    @classmethod
    def _make_split_file(cls, root):
        """Write ``list_eval_partition.txt`` and return (image file names, per-split counts)."""
        num_samples_map = {"train": 4, "val": 3, "test": 2}

        data = [
            (f"{idx:06d}.jpg", cls._SPLIT_TO_IDX[split])
            for split, num_samples in num_samples_map.items()
            for idx in range(num_samples)
        ]
        cls._make_ann_file(root, "list_eval_partition.txt", data)

        image_file_names, _ = zip(*data)
        return image_file_names, num_samples_map

    @classmethod
    def _make_identity_file(cls, root, image_file_names):
        """Write ``identity_CelebA.txt`` with one random positive identity per image."""
        cls._make_ann_file(
            root, "identity_CelebA.txt", [(name, int(make_scalar(low=1, dtype=torch.int))) for name in image_file_names]
        )

    @classmethod
    def _make_attributes_file(cls, root, image_file_names):
        """Write ``list_attr_celeba.txt`` with random +1/-1 attribute columns."""
        field_names = ("5_o_Clock_Shadow", "Young")
        data = [
            [name, *[" 1" if attr else "-1" for attr in make_tensor((len(field_names),), dtype=torch.bool)]]
            for name in image_file_names
        ]
        # The real header line ends with a trailing separator, hence the empty field.
        cls._make_ann_file(root, "list_attr_celeba.txt", data, field_names=(*field_names, ""))

    @classmethod
    def _make_bounding_boxes_file(cls, root, image_file_names):
        """Write ``list_bbox_celeba.txt`` with random non-negative box coordinates."""
        field_names = ("image_id", "x_1", "y_1", "width", "height")
        data = [
            [f"{name} ", *[f"{coord:3d}" for coord in make_tensor((4,), low=0, dtype=torch.int).tolist()]]
            for name in image_file_names
        ]
        cls._make_ann_file(root, "list_bbox_celeba.txt", data, field_names=field_names)

    @classmethod
    def _make_landmarks_file(cls, root, image_file_names):
        """Write ``list_landmarks_align_celeba.txt`` with random landmark coordinates."""
        field_names = ("lefteye_x", "lefteye_y", "rightmouth_x", "rightmouth_y")
        data = [
            [
                name,
                *[
                    # The first coordinate is written unpadded, the rest width-4 padded.
                    f"{coord:4d}" if idx else coord
                    for idx, coord in enumerate(make_tensor((len(field_names),), low=0, dtype=torch.int).tolist())
                ],
            ]
            for name in image_file_names
        ]
        cls._make_ann_file(root, "list_landmarks_align_celeba.txt", data, field_names=field_names)

    @classmethod
    def generate(cls, root):
        """Create the image archive and all annotation files; return per-split counts."""
        image_file_names, num_samples_map = cls._make_split_file(root)

        image_files = create_image_folder(
            root, "img_align_celeba", file_name_fn=lambda idx: image_file_names[idx], num_examples=len(image_file_names)
        )
        make_zip(root, image_files[0].parent.with_suffix(".zip").name)

        for make_ann_file_fn in (
            cls._make_identity_file,
            cls._make_attributes_file,
            cls._make_bounding_boxes_file,
            cls._make_landmarks_file,
        ):
            make_ann_file_fn(root, image_file_names)

        return num_samples_map
@DATASET_MOCKS.set_from_named_callable
def celeba(info, root, _):
    """Generate CelebA mock data once and map every config to its split's sample count."""
    samples_per_split = CelebAMockData.generate(root)
    return {cfg: samples_per_split[cfg.split] for cfg in info._configs}
@DATASET_MOCKS.set_from_named_callable
def dtd(info, root, _):
    """Generate mock data for the DTD prototype dataset.

    Creates a small image folder per category, the joint annotation file, and
    ten random train/val/test fold splits; returns a ``config -> num_samples``
    mapping covering every (split, fold) combination.
    """
    data_folder = root / "dtd"

    num_images_per_class = 3
    image_folder = data_folder / "images"
    categories = {"banded", "marbled", "zigzagged"}
    # Image ids are paths relative to the images folder, e.g. "banded/banded_0000.jpg".
    image_ids_per_category = {
        category: [
            str(path.relative_to(path.parents[1]).as_posix())
            for path in create_image_folder(
                image_folder,
                category,
                file_name_fn=lambda idx: f"{category}_{idx:04d}.jpg",
                num_examples=num_images_per_class,
            )
        ]
        for category in categories
    }

    meta_folder = data_folder / "labels"
    meta_folder.mkdir()

    with open(meta_folder / "labels_joint_anno.txt", "w") as file:
        for cls, image_ids in image_ids_per_category.items():
            for image_id in image_ids:
                # Each image may carry 0..N-1 additional "joint" categories
                # besides its primary one.
                joint_categories = random.choices(
                    list(categories - {cls}), k=int(torch.randint(len(categories) - 1, ()))
                )
                file.write(" ".join([image_id, *sorted([cls, *joint_categories])]) + "\n")

    image_ids = list(itertools.chain(*image_ids_per_category.values()))
    splits = ("train", "val", "test")
    num_samples_map = {}
    # DTD ships ten folds; every fold gets its own reshuffled split files.
    for fold in range(1, 11):
        random.shuffle(image_ids)
        for offset, split in enumerate(splits):
            # Round-robin partition of the shuffled ids over the three splits.
            image_ids_in_config = image_ids[offset :: len(splits)]
            with open(meta_folder / f"{split}{fold}.txt", "w") as file:
                file.write("\n".join(image_ids_in_config) + "\n")

            num_samples_map[info.make_config(split=split, fold=str(fold))] = len(image_ids_in_config)

    make_tar(root, "dtd-r1.0.1.tar.gz", data_folder, compression="gz")

    return num_samples_map
@DATASET_MOCKS.set_from_named_callable
def fer2013(info, root, config):
    """Write a fake FER2013 CSV (zipped) and return its number of rows.

    The train split carries an "emotion" column in addition to the "pixels"
    column; the other split has pixels only.
    """
    is_train = config.split == "train"
    num_samples = 5 if is_train else 3

    path = root / f"{config.split}.txt"
    with open(path, "w", newline="") as file:
        field_names = (["emotion"] if is_train else []) + ["pixels"]
        file.write(",".join(field_names) + "\n")

        writer = csv.DictWriter(file, fieldnames=field_names, quotechar='"', quoting=csv.QUOTE_NONNUMERIC)
        for _ in range(num_samples):
            row = {"pixels": " ".join(str(int(p)) for p in torch.randint(256, (48 * 48,), dtype=torch.uint8))}
            if is_train:
                row["emotion"] = int(torch.randint(7, ()))
            writer.writerow(row)

    make_zip(root, f"{path.name}.zip", path)

    return num_samples
@DATASET_MOCKS.set_from_named_callable
def clevr(info, root, config):
    """Generate mock data for the CLEVR prototype dataset.

    Creates the ``CLEVR_v1.0`` archive with per-split image folders and scene
    annotation JSON files, and returns a ``config -> num_samples`` mapping for
    all configs.
    """
    data_folder = root / "CLEVR_v1.0"

    num_samples_map = {
        "train": 3,
        "val": 2,
        "test": 1,
    }

    images_folder = data_folder / "images"
    image_files = {
        split: create_image_folder(
            images_folder,
            split,
            file_name_fn=lambda idx: f"CLEVR_{split}_{idx:06d}.jpg",
            num_examples=num_samples,
        )
        for split, num_samples in num_samples_map.items()
    }

    scenes_folder = data_folder / "scenes"
    scenes_folder.mkdir()
    # Scene annotations are only generated for "train" and "val" here.
    for split in ["train", "val"]:
        with open(scenes_folder / f"CLEVR_{split}_scenes.json", "w") as file:
            json.dump(
                {
                    "scenes": [
                        {
                            "image_filename": image_file.name,
                            # We currently only return the number of objects in a scene.
                            # Thus, it is sufficient for now to only mock the number of elements.
                            "objects": [None] * int(torch.randint(1, 5, ())),
                        }
                        for image_file in image_files[split]
                    ]
                },
                file,
            )

    make_zip(root, f"{data_folder.name}.zip")

    return {config_: num_samples_map[config_.split] for config_ in info._configs}
class OxfordIIITPetMockData:
    """Generates mock data for the Oxford-IIIT Pet prototype dataset."""

    @classmethod
    def _meta_to_split_and_classification_ann(cls, meta, idx):
        """Build one ``(image_id, class_id, species, breed_id)`` annotation row.

        Cat image ids are Title_Cased and dog image ids are lower_cased,
        mirroring the naming scheme of the real dataset files.
        """
        image_id = "_".join(
            [
                *[(str.title if meta["species"] == "cat" else str.lower)(part) for part in meta["cls"].split()],
                str(idx),
            ]
        )
        class_id = str(meta["label"] + 1)  # labels are 0-based, the files are 1-based
        species = "1" if meta["species"] == "cat" else "2"
        breed_id = "-1"
        return (image_id, class_id, species, breed_id)

    @classmethod
    def generate(cls, root):
        """Create images, split annotation files, and trimaps under *root*.

        Returns a ``split -> num_samples`` mapping for "trainval" and "test".

        Fix: this classmethod previously named its first parameter ``self``
        although it receives the class; it is now conventionally named ``cls``.
        """
        classification_anns_meta = (
            dict(cls="Abyssinian", label=0, species="cat"),
            dict(cls="Keeshond", label=18, species="dog"),
            dict(cls="Yorkshire Terrier", label=36, species="dog"),
        )
        split_and_classification_anns = [
            cls._meta_to_split_and_classification_ann(meta, idx)
            for meta, idx in itertools.product(classification_anns_meta, (1, 2, 10))
        ]
        image_ids, *_ = zip(*split_and_classification_anns)

        image_files = create_image_folder(
            root, "images", file_name_fn=lambda idx: f"{image_ids[idx]}.jpg", num_examples=len(image_ids)
        )

        anns_folder = root / "annotations"
        anns_folder.mkdir()
        random.shuffle(split_and_classification_anns)
        splits = ("trainval", "test")
        num_samples_map = {}
        for offset, split in enumerate(splits):
            # Round-robin partition of the shuffled annotations over the splits.
            split_and_classification_anns_in_split = split_and_classification_anns[offset :: len(splits)]
            with open(anns_folder / f"{split}.txt", "w") as file:
                writer = csv.writer(file, delimiter=" ")
                for split_and_classification_ann in split_and_classification_anns_in_split:
                    writer.writerow(split_and_classification_ann)

            num_samples_map[split] = len(split_and_classification_anns_in_split)

        segmentation_files = create_image_folder(
            anns_folder, "trimaps", file_name_fn=lambda idx: f"{image_ids[idx]}.png", num_examples=len(image_ids)
        )

        # The dataset has some rogue files
        for path in image_files[:3]:
            path.with_suffix(".mat").touch()
        for path in segmentation_files:
            path.with_name(f".{path.name}").touch()

        make_tar(root, "images.tar")
        make_tar(root, anns_folder.with_suffix(".tar").name)

        return num_samples_map
@DATASET_MOCKS.set_from_named_callable
def oxford_iiit_pet(info, root, config):
    """Generate Oxford-IIIT Pet mock data once and report counts for all configs."""
    samples_per_split = OxfordIIITPetMockData.generate(root)
    return {cfg: samples_per_split[cfg.split] for cfg in info._configs}
class _CUB200MockData:
    """Shared image-generation helpers for the CUB-200 mock datasets (2010 and 2011)."""

    @classmethod
    def _category_folder(cls, category, idx):
        # e.g. "001.Black_footed_Albatross"
        return f"{idx:03d}.{category}"

    @classmethod
    def _file_stem(cls, category, idx):
        # e.g. "Black_footed_Albatross_0001"
        return f"{category}_{idx:04d}"

    @classmethod
    def _make_images(cls, images_folder):
        """Create three representative category folders with five images each.

        Returns the list of all created image file paths.
        """
        image_files = []
        for category_idx, category in [
            (1, "Black_footed_Albatross"),
            (100, "Brown_Pelican"),
            (200, "Common_Yellowthroat"),
        ]:
            # The lambda closes over the loop variable `category`, which is safe
            # here because create_image_folder consumes it within this iteration.
            image_files.extend(
                create_image_folder(
                    images_folder,
                    cls._category_folder(category, category_idx),
                    lambda image_idx: f"{cls._file_stem(category, image_idx)}.jpg",
                    num_examples=5,
                )
            )

        return image_files
class CUB2002011MockData(_CUB200MockData):
    """Generates mock data for the 2011 variant of the CUB-200 prototype dataset."""

    @classmethod
    def _make_archive(cls, root):
        """Create the main ``CUB_200_2011.tgz`` archive (images plus metadata files).

        Returns the created image files and a ``split -> num_samples`` mapping.
        """
        archive_folder = root / "CUB_200_2011"

        images_folder = archive_folder / "images"
        image_files = cls._make_images(images_folder)
        image_ids = list(range(1, len(image_files) + 1))

        with open(archive_folder / "images.txt", "w") as file:
            file.write(
                "\n".join(
                    f"{id} {path.relative_to(images_folder).as_posix()}" for id, path in zip(image_ids, image_files)
                )
            )

        # Random 0/1 assignment per image: 1 -> train, 0 -> test.
        split_ids = torch.randint(2, (len(image_ids),)).tolist()
        counts = Counter(split_ids)
        num_samples_map = {"train": counts[1], "test": counts[0]}
        with open(archive_folder / "train_test_split.txt", "w") as file:
            file.write("\n".join(f"{image_id} {split_id}" for image_id, split_id in zip(image_ids, split_ids)))

        with open(archive_folder / "bounding_boxes.txt", "w") as file:
            file.write(
                "\n".join(
                    " ".join(
                        str(item)
                        for item in [image_id, *make_tensor((4,), dtype=torch.int, low=0).to(torch.float).tolist()]
                    )
                    for image_id in image_ids
                )
            )

        make_tar(root, archive_folder.with_suffix(".tgz").name, compression="gz")

        return image_files, num_samples_map

    @classmethod
    def _make_segmentations(cls, root, image_files):
        """Create one PNG segmentation per image, mirroring the images' folder layout."""
        segmentations_folder = root / "segmentations"
        for image_file in image_files:
            folder = segmentations_folder.joinpath(image_file.relative_to(image_file.parents[1]))
            folder.mkdir(exist_ok=True, parents=True)
            create_image_file(
                folder,
                image_file.with_suffix(".png").name,
                size=[1, *make_tensor((2,), low=3, dtype=torch.int).tolist()],
            )

        # NOTE(review): unlike the archives above, no compression="gz" is passed
        # here although the name ends in .tgz — confirm this is intended.
        make_tar(root, segmentations_folder.with_suffix(".tgz").name)

    @classmethod
    def generate(cls, root):
        """Create the archive and segmentations under *root*; return per-split counts."""
        image_files, num_samples_map = cls._make_archive(root)
        cls._make_segmentations(root, image_files)
        return num_samples_map
class CUB2002010MockData(_CUB200MockData):
    """Generates mock data for the 2010 variant of the CUB-200 prototype dataset."""

    @classmethod
    def _make_hidden_rouge_file(cls, *files):
        # NOTE(review): "rouge" is presumably a typo for "rogue" — these mimic
        # the stray "._"-prefixed files present in the real archives.
        for file in files:
            (file.parent / f"._{file.name}").touch()

    @classmethod
    def _make_splits(cls, root, image_files):
        """Create the ``lists.tgz`` archive with train/test split files.

        Returns a ``split -> num_samples`` mapping.
        """
        split_folder = root / "lists"
        split_folder.mkdir()
        random.shuffle(image_files)
        splits = ("train", "test")

        num_samples_map = {}
        for offset, split in enumerate(splits):
            # Round-robin partition of the shuffled image files over the splits.
            image_files_in_split = image_files[offset :: len(splits)]

            split_file = split_folder / f"{split}.txt"
            with open(split_file, "w") as file:
                file.write(
                    "\n".join(
                        sorted(
                            str(image_file.relative_to(image_file.parents[1]).as_posix())
                            for image_file in image_files_in_split
                        )
                    )
                )

            cls._make_hidden_rouge_file(split_file)
            num_samples_map[split] = len(image_files_in_split)

        make_tar(root, split_folder.with_suffix(".tgz").name, compression="gz")

        return num_samples_map

    @classmethod
    def _make_anns(cls, root, image_files):
        """Create the ``annotations.tgz`` archive with one ``.mat`` annotation per image."""
        from scipy.io import savemat

        anns_folder = root / "annotations-mat"
        for image_file in image_files:
            ann_file = anns_folder / image_file.with_suffix(".mat").relative_to(image_file.parents[1])
            ann_file.parent.mkdir(parents=True, exist_ok=True)

            savemat(
                ann_file,
                {
                    "seg": torch.randint(
                        256, make_tensor((2,), low=3, dtype=torch.int).tolist(), dtype=torch.uint8
                    ).numpy(),
                    "bbox": dict(
                        zip(("left", "top", "right", "bottom"), make_tensor((4,), dtype=torch.uint8).tolist())
                    ),
                },
            )

        readme_file = anns_folder / "README.txt"
        readme_file.touch()
        cls._make_hidden_rouge_file(readme_file)

        make_tar(root, "annotations.tgz", anns_folder, compression="gz")

    @classmethod
    def generate(cls, root):
        """Create images, splits, and annotations under *root*; return per-split counts."""
        images_folder = root / "images"
        image_files = cls._make_images(images_folder)
        cls._make_hidden_rouge_file(*image_files)
        make_tar(root, images_folder.with_suffix(".tgz").name, compression="gz")

        num_samples_map = cls._make_splits(root, image_files)
        cls._make_anns(root, image_files)

        return num_samples_map
@DATASET_MOCKS.set_from_named_callable
def cub200(info, root, config):
    """Generate CUB-200 mock data for the requested year and report its configs' counts."""
    mock_data_cls = CUB2002011MockData if config.year == "2011" else CUB2002010MockData
    samples_per_split = mock_data_cls.generate(root)
    return {cfg: samples_per_split[cfg.split] for cfg in info._configs if cfg.year == config.year}
...@@ -6,17 +6,28 @@ from builtin_dataset_mocks import parametrize_dataset_mocks, DATASET_MOCKS ...@@ -6,17 +6,28 @@ from builtin_dataset_mocks import parametrize_dataset_mocks, DATASET_MOCKS
from torch.utils.data.datapipes.iter.grouping import ShardingFilterIterDataPipe as ShardingFilter from torch.utils.data.datapipes.iter.grouping import ShardingFilterIterDataPipe as ShardingFilter
from torch.utils.data.graph import traverse from torch.utils.data.graph import traverse
from torchdata.datapipes.iter import IterDataPipe, Shuffler from torchdata.datapipes.iter import IterDataPipe, Shuffler
from torchvision.prototype import transforms from torchvision.prototype import transforms, datasets
from torchvision.prototype.utils._internal import sequence_to_str from torchvision.prototype.utils._internal import sequence_to_str
def test_coverage():
    """Fail if a dataset exposed through the prototype API has no registered mock."""
    untested_datasets = set(datasets.list()) - DATASET_MOCKS.keys()
    if untested_datasets:
        raise AssertionError(
            f"The dataset(s) {sequence_to_str(sorted(untested_datasets), separate_last='and ')} "
            f"are exposed through `torchvision.prototype.datasets.load()`, but are not tested. "
            f"Please add mock data to `test/builtin_dataset_mocks.py`."
        )
class TestCommon: class TestCommon:
@parametrize_dataset_mocks(DATASET_MOCKS)
def test_smoke(self, dataset_mock, config): def test_smoke(self, dataset_mock, config):
dataset, _ = dataset_mock.load(config) dataset, _ = dataset_mock.load(config)
if not isinstance(dataset, IterDataPipe): if not isinstance(dataset, IterDataPipe):
raise AssertionError(f"Loading the dataset should return an IterDataPipe, but got {type(dataset)} instead.") raise AssertionError(f"Loading the dataset should return an IterDataPipe, but got {type(dataset)} instead.")
@parametrize_dataset_mocks(DATASET_MOCKS)
def test_sample(self, dataset_mock, config): def test_sample(self, dataset_mock, config):
dataset, _ = dataset_mock.load(config) dataset, _ = dataset_mock.load(config)
...@@ -31,6 +42,7 @@ class TestCommon: ...@@ -31,6 +42,7 @@ class TestCommon:
if not sample: if not sample:
raise AssertionError("Sample dictionary is empty.") raise AssertionError("Sample dictionary is empty.")
@parametrize_dataset_mocks(DATASET_MOCKS)
def test_num_samples(self, dataset_mock, config): def test_num_samples(self, dataset_mock, config):
dataset, mock_info = dataset_mock.load(config) dataset, mock_info = dataset_mock.load(config)
...@@ -40,6 +52,7 @@ class TestCommon: ...@@ -40,6 +52,7 @@ class TestCommon:
assert num_samples == mock_info["num_samples"] assert num_samples == mock_info["num_samples"]
@parametrize_dataset_mocks(DATASET_MOCKS)
def test_decoding(self, dataset_mock, config): def test_decoding(self, dataset_mock, config):
dataset, _ = dataset_mock.load(config) dataset, _ = dataset_mock.load(config)
...@@ -50,6 +63,7 @@ class TestCommon: ...@@ -50,6 +63,7 @@ class TestCommon:
f"{sequence_to_str(sorted(undecoded_features), separate_last='and ')} were not decoded." f"{sequence_to_str(sorted(undecoded_features), separate_last='and ')} were not decoded."
) )
@parametrize_dataset_mocks(DATASET_MOCKS)
def test_no_vanilla_tensors(self, dataset_mock, config): def test_no_vanilla_tensors(self, dataset_mock, config):
dataset, _ = dataset_mock.load(config) dataset, _ = dataset_mock.load(config)
...@@ -60,16 +74,33 @@ class TestCommon: ...@@ -60,16 +74,33 @@ class TestCommon:
f"{sequence_to_str(sorted(vanilla_tensors), separate_last='and ')} contained vanilla tensors." f"{sequence_to_str(sorted(vanilla_tensors), separate_last='and ')} contained vanilla tensors."
) )
@parametrize_dataset_mocks(DATASET_MOCKS)
def test_transformable(self, dataset_mock, config): def test_transformable(self, dataset_mock, config):
dataset, _ = dataset_mock.load(config) dataset, _ = dataset_mock.load(config)
next(iter(dataset.map(transforms.Identity()))) next(iter(dataset.map(transforms.Identity())))
@parametrize_dataset_mocks(
DATASET_MOCKS,
marks={
"cub200": pytest.mark.xfail(
reason="See https://github.com/pytorch/vision/pull/5187#issuecomment-1015479165"
)
},
)
def test_traversable(self, dataset_mock, config): def test_traversable(self, dataset_mock, config):
dataset, _ = dataset_mock.load(config) dataset, _ = dataset_mock.load(config)
traverse(dataset) traverse(dataset)
@parametrize_dataset_mocks(
DATASET_MOCKS,
marks={
"cub200": pytest.mark.xfail(
reason="See https://github.com/pytorch/vision/pull/5187#issuecomment-1015479165"
)
},
)
@pytest.mark.parametrize("annotation_dp_type", (Shuffler, ShardingFilter), ids=lambda type: type.__name__) @pytest.mark.parametrize("annotation_dp_type", (Shuffler, ShardingFilter), ids=lambda type: type.__name__)
def test_has_annotations(self, dataset_mock, config, annotation_dp_type): def test_has_annotations(self, dataset_mock, config, annotation_dp_type):
def scan(graph): def scan(graph):
...@@ -86,8 +117,8 @@ class TestCommon: ...@@ -86,8 +117,8 @@ class TestCommon:
raise AssertionError(f"The dataset doesn't comprise a {annotation_dp_type.__name__}() datapipe.") raise AssertionError(f"The dataset doesn't comprise a {annotation_dp_type.__name__}() datapipe.")
@parametrize_dataset_mocks(DATASET_MOCKS["qmnist"])
class TestQMNIST: class TestQMNIST:
@parametrize_dataset_mocks([mock for mock in DATASET_MOCKS if mock.name == "qmnist"])
def test_extra_label(self, dataset_mock, config): def test_extra_label(self, dataset_mock, config):
dataset, _ = dataset_mock.load(config) dataset, _ = dataset_mock.load(config)
......
...@@ -26,6 +26,7 @@ from torchvision.prototype.datasets.utils._internal import ( ...@@ -26,6 +26,7 @@ from torchvision.prototype.datasets.utils._internal import (
hint_sharding, hint_sharding,
hint_shuffling, hint_shuffling,
) )
from torchvision.prototype.features import Feature, Label, BoundingBox
csv.register_dialect("celeba", delimiter=" ", skipinitialspace=True) csv.register_dialect("celeba", delimiter=" ", skipinitialspace=True)
...@@ -67,6 +68,7 @@ class CelebA(Dataset): ...@@ -67,6 +68,7 @@ class CelebA(Dataset):
"celeba", "celeba",
type=DatasetType.IMAGE, type=DatasetType.IMAGE,
homepage="https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html", homepage="https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html",
valid_options=dict(split=("train", "val", "test")),
) )
def resources(self, config: DatasetConfig) -> List[OnlineResource]: def resources(self, config: DatasetConfig) -> List[OnlineResource]:
...@@ -104,7 +106,7 @@ class CelebA(Dataset): ...@@ -104,7 +106,7 @@ class CelebA(Dataset):
_SPLIT_ID_TO_NAME = { _SPLIT_ID_TO_NAME = {
"0": "train", "0": "train",
"1": "valid", "1": "val",
"2": "test", "2": "test",
} }
...@@ -117,22 +119,22 @@ class CelebA(Dataset): ...@@ -117,22 +119,22 @@ class CelebA(Dataset):
def _collate_and_decode_sample( def _collate_and_decode_sample(
self, self,
data: Tuple[Tuple[str, Tuple[str, List[str]], Tuple[str, io.IOBase]], Tuple[str, Dict[str, Any]]], data: Tuple[Tuple[str, Tuple[Tuple[str, Dict[str, Any]], Tuple[str, io.IOBase]]], Tuple[str, Dict[str, Any]]],
*, *,
decoder: Optional[Callable[[io.IOBase], torch.Tensor]], decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
) -> Dict[str, Any]: ) -> Dict[str, Any]:
split_and_image_data, ann_data = data split_and_image_data, ann_data = data
_, _, image_data = split_and_image_data _, (_, image_data) = split_and_image_data
path, buffer = image_data path, buffer = image_data
_, ann = ann_data _, ann = ann_data
image = decoder(buffer) if decoder else buffer image = decoder(buffer) if decoder else buffer
identity = int(ann["identity"]["identity"]) identity = Label(int(ann["identity"]["identity"]))
attributes = {attr: value == "1" for attr, value in ann["attributes"].items()} attributes = {attr: value == "1" for attr, value in ann["attributes"].items()}
bbox = torch.tensor([int(ann["bbox"][key]) for key in ("x_1", "y_1", "width", "height")]) bbox = BoundingBox([int(ann["bbox"][key]) for key in ("x_1", "y_1", "width", "height")])
landmarks = { landmarks = {
landmark: torch.tensor((int(ann["landmarks"][f"{landmark}_x"]), int(ann["landmarks"][f"{landmark}_y"]))) landmark: Feature((int(ann["landmarks"][f"{landmark}_x"]), int(ann["landmarks"][f"{landmark}_y"])))
for landmark in {key[:-2] for key in ann["landmarks"].keys()} for landmark in {key[:-2] for key in ann["landmarks"].keys()}
} }
......
...@@ -105,7 +105,7 @@ class CUB200(Dataset): ...@@ -105,7 +105,7 @@ class CUB200(Dataset):
path = pathlib.Path(data[0]) path = pathlib.Path(data[0])
return path.with_suffix(".jpg").name return path.with_suffix(".jpg").name
def _2011_decode_ann( def _2011_load_ann(
self, self,
data: Tuple[str, Tuple[List[str], Tuple[str, io.IOBase]]], data: Tuple[str, Tuple[List[str], Tuple[str, io.IOBase]]],
*, *,
...@@ -126,7 +126,7 @@ class CUB200(Dataset): ...@@ -126,7 +126,7 @@ class CUB200(Dataset):
path = pathlib.Path(data[0]) path = pathlib.Path(data[0])
return path.with_suffix(".jpg").name, data return path.with_suffix(".jpg").name, data
def _2010_decode_ann( def _2010_load_ann(
self, data: Tuple[str, Tuple[str, io.IOBase]], *, decoder: Optional[Callable[[io.IOBase], torch.Tensor]] self, data: Tuple[str, Tuple[str, io.IOBase]], *, decoder: Optional[Callable[[io.IOBase], torch.Tensor]]
) -> Dict[str, Any]: ) -> Dict[str, Any]:
_, (path, buffer) = data _, (path, buffer) = data
...@@ -154,7 +154,7 @@ class CUB200(Dataset): ...@@ -154,7 +154,7 @@ class CUB200(Dataset):
label_str, category = dir_name.split(".") label_str, category = dir_name.split(".")
return dict( return dict(
(self._2011_decode_ann if year == "2011" else self._2010_decode_ann)(anns_data, decoder=decoder), (self._2011_load_ann if year == "2011" else self._2010_load_ann)(anns_data, decoder=decoder),
image=decoder(buffer) if decoder else buffer, image=decoder(buffer) if decoder else buffer,
label=Label(int(label_str), category=category), label=Label(int(label_str), category=category),
) )
...@@ -196,7 +196,7 @@ class CUB200(Dataset): ...@@ -196,7 +196,7 @@ class CUB200(Dataset):
else: # config.year == "2010" else: # config.year == "2010"
split_dp, images_dp, anns_dp = resource_dps split_dp, images_dp, anns_dp = resource_dps
split_dp = Filter(split_dp, path_comparator("stem", config.split)) split_dp = Filter(split_dp, path_comparator("name", f"{config.split}.txt"))
split_dp = LineReader(split_dp, decode=True, return_path=False) split_dp = LineReader(split_dp, decode=True, return_path=False)
split_dp = Mapper(split_dp, self._2010_split_key) split_dp = Mapper(split_dp, self._2010_split_key)
......
import enum import enum
import functools
import io import io
import pathlib import pathlib
from typing import Any, Callable, Dict, List, Optional, Tuple from typing import Any, Callable, Dict, List, Optional, Tuple
...@@ -126,7 +127,7 @@ class DTD(Dataset): ...@@ -126,7 +127,7 @@ class DTD(Dataset):
ref_key_fn=self._image_key_fn, ref_key_fn=self._image_key_fn,
buffer_size=INFINITE_BUFFER_SIZE, buffer_size=INFINITE_BUFFER_SIZE,
) )
return Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(decoder=decoder)) return Mapper(dp, functools.partial(self._collate_and_decode_sample, decoder=decoder))
def _filter_images(self, data: Tuple[str, Any]) -> bool: def _filter_images(self, data: Tuple[str, Any]) -> bool:
return self._classify_archive(data) == DTDDemux.IMAGES return self._classify_archive(data) == DTDDemux.IMAGES
......
...@@ -31,6 +31,7 @@ from torchvision.prototype.datasets.utils._internal import ( ...@@ -31,6 +31,7 @@ from torchvision.prototype.datasets.utils._internal import (
hint_sharding, hint_sharding,
hint_shuffling, hint_shuffling,
) )
from torchvision.prototype.features import Feature
class SBD(Dataset): class SBD(Dataset):
...@@ -83,11 +84,11 @@ class SBD(Dataset): ...@@ -83,11 +84,11 @@ class SBD(Dataset):
# the boundaries are stored in sparse CSC format, which is not supported by PyTorch # the boundaries are stored in sparse CSC format, which is not supported by PyTorch
boundaries = ( boundaries = (
torch.as_tensor(np.stack([raw_boundary[0].toarray() for raw_boundary in raw_boundaries])) Feature(np.stack([raw_boundary[0].toarray() for raw_boundary in raw_boundaries]))
if decode_boundaries if decode_boundaries
else None else None
) )
segmentation = torch.as_tensor(raw_segmentation) if decode_segmentation else None segmentation = Feature(raw_segmentation) if decode_segmentation else None
return boundaries, segmentation return boundaries, segmentation
...@@ -140,6 +141,7 @@ class SBD(Dataset): ...@@ -140,6 +141,7 @@ class SBD(Dataset):
if config.split == "train_noval": if config.split == "train_noval":
split_dp = extra_split_dp split_dp = extra_split_dp
split_dp = Filter(split_dp, path_comparator("stem", config.split))
split_dp = LineReader(split_dp, decode=True) split_dp = LineReader(split_dp, decode=True)
split_dp = hint_sharding(split_dp) split_dp = hint_sharding(split_dp)
split_dp = hint_shuffling(split_dp) split_dp = hint_shuffling(split_dp)
......
...@@ -18,6 +18,7 @@ from torchvision.prototype.datasets.utils import ( ...@@ -18,6 +18,7 @@ from torchvision.prototype.datasets.utils import (
DatasetType, DatasetType,
) )
from torchvision.prototype.datasets.utils._internal import image_buffer_from_array, hint_sharding, hint_shuffling from torchvision.prototype.datasets.utils._internal import image_buffer_from_array, hint_sharding, hint_shuffling
from torchvision.prototype.features import Image, Label
class SEMEION(Dataset): class SEMEION(Dataset):
...@@ -46,14 +47,13 @@ class SEMEION(Dataset): ...@@ -46,14 +47,13 @@ class SEMEION(Dataset):
label_data = [int(label) for label in data[256:] if label] label_data = [int(label) for label in data[256:] if label]
if decoder is raw: if decoder is raw:
image = image_data.unsqueeze(0) image = Image(image_data.unsqueeze(0))
else: else:
image_buffer = image_buffer_from_array(image_data.numpy()) image_buffer = image_buffer_from_array(image_data.numpy())
image = decoder(image_buffer) if decoder else image_buffer # type: ignore[assignment] image = decoder(image_buffer) if decoder else image_buffer # type: ignore[assignment]
label = next((idx for idx, one_hot_label in enumerate(label_data) if one_hot_label)) label_idx = next((idx for idx, one_hot_label in enumerate(label_data) if one_hot_label))
category = self.info.categories[label] return dict(image=image, label=Label(label_idx, category=self.info.categories[label_idx]))
return dict(image=image, label=label, category=category)
def _make_datapipe( def _make_datapipe(
self, self,
......
...@@ -30,34 +30,50 @@ from torchvision.prototype.datasets.utils._internal import ( ...@@ -30,34 +30,50 @@ from torchvision.prototype.datasets.utils._internal import (
hint_sharding, hint_sharding,
hint_shuffling, hint_shuffling,
) )
from torchvision.prototype.features import BoundingBox
HERE = pathlib.Path(__file__).parent
class VOCDatasetInfo(DatasetInfo):
def __init__(self, *args: Any, **kwargs: Any):
super().__init__(*args, **kwargs)
self._configs = tuple(config for config in self._configs if config.split != "test" or config.year == "2007")
def make_config(self, **options: Any) -> DatasetConfig:
config = super().make_config(**options)
if config.split == "test" and config.year != "2007":
raise ValueError("`split='test'` is only available for `year='2007'`")
return config
class VOC(Dataset): class VOC(Dataset):
def _make_info(self) -> DatasetInfo: def _make_info(self) -> DatasetInfo:
return DatasetInfo( return VOCDatasetInfo(
"voc", "voc",
type=DatasetType.IMAGE, type=DatasetType.IMAGE,
homepage="http://host.robots.ox.ac.uk/pascal/VOC/", homepage="http://host.robots.ox.ac.uk/pascal/VOC/",
valid_options=dict( valid_options=dict(
split=("train", "val", "test"), split=("train", "val", "trainval", "test"),
year=("2012",), year=("2012", "2007", "2008", "2009", "2010", "2011"),
task=("detection", "segmentation"), task=("detection", "segmentation"),
), ),
) )
_TRAIN_VAL_ARCHIVES = {
"2007": ("VOCtrainval_06-Nov-2007.tar", "7d8cd951101b0957ddfd7a530bdc8a94f06121cfc1e511bb5937e973020c7508"),
"2008": ("VOCtrainval_14-Jul-2008.tar", "7f0ca53c1b5a838fbe946965fc106c6e86832183240af5c88e3f6c306318d42e"),
"2009": ("VOCtrainval_11-May-2009.tar", "11cbe1741fb5bdadbbca3c08e9ec62cd95c14884845527d50847bc2cf57e7fd6"),
"2010": ("VOCtrainval_03-May-2010.tar", "1af4189cbe44323ab212bff7afbc7d0f55a267cc191eb3aac911037887e5c7d4"),
"2011": ("VOCtrainval_25-May-2011.tar", "0a7f5f5d154f7290ec65ec3f78b72ef72c6d93ff6d79acd40dc222a9ee5248ba"),
"2012": ("VOCtrainval_11-May-2012.tar", "e14f763270cf193d0b5f74b169f44157a4b0c6efa708f4dd0ff78ee691763bcb"),
}
_TEST_ARCHIVES = {
"2007": ("VOCtest_06-Nov-2007.tar", "6836888e2e01dca84577a849d339fa4f73e1e4f135d312430c4856b5609b4892")
}
def resources(self, config: DatasetConfig) -> List[OnlineResource]: def resources(self, config: DatasetConfig) -> List[OnlineResource]:
if config.year == "2012": file_name, sha256 = (self._TEST_ARCHIVES if config.split == "test" else self._TRAIN_VAL_ARCHIVES)[config.year]
if config.split == "train": archive = HttpResource(f"http://host.robots.ox.ac.uk/pascal/VOC/voc{config.year}/{file_name}", sha256=sha256)
archive = HttpResource(
"http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar",
sha256="e14f763270cf193d0b5f74b169f44157a4b0c6efa708f4dd0ff78ee691763bcb",
)
else:
raise RuntimeError("FIXME")
else:
raise RuntimeError("FIXME")
return [archive] return [archive]
_ANNS_FOLDER = dict( _ANNS_FOLDER = dict(
...@@ -88,7 +104,7 @@ class VOC(Dataset): ...@@ -88,7 +104,7 @@ class VOC(Dataset):
objects = result["annotation"]["object"] objects = result["annotation"]["object"]
bboxes = [obj["bndbox"] for obj in objects] bboxes = [obj["bndbox"] for obj in objects]
bboxes = [[int(bbox[part]) for part in ("xmin", "ymin", "xmax", "ymax")] for bbox in bboxes] bboxes = [[int(bbox[part]) for part in ("xmin", "ymin", "xmax", "ymax")] for bbox in bboxes]
return torch.tensor(bboxes) return BoundingBox(bboxes)
def _collate_and_decode_sample( def _collate_and_decode_sample(
self, self,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment