Unverified Commit 6d9a42c3 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

add test split for imagenet (#4866)

* add test split for imagenet

* add infinite buffer size to shuffler
parent f093d082
......@@ -452,11 +452,7 @@ def caltech256(info, root, config):
@dataset_mocks.register_mock_data_fn
def imagenet(info, root, config):
devkit_root = root / "ILSVRC2012_devkit_t12"
devkit_root.mkdir()
wnids = tuple(info.extra.wnid_to_category.keys())
if config.split == "train":
images_root = root / "ILSVRC2012_img_train"
......@@ -470,7 +466,7 @@ def imagenet(info, root, config):
num_examples=1,
)
make_tar(images_root, f"{wnid}.tar", files[0].parent)
else:
elif config.split == "val":
num_samples = 3
files = create_image_folder(
root=root,
......@@ -479,14 +475,26 @@ def imagenet(info, root, config):
num_examples=num_samples,
)
images_root = files[0].parent
else: # config.split == "test"
images_root = root / "ILSVRC2012_img_test_v10102019"
data_root = devkit_root / "data"
data_root.mkdir()
with open(data_root / "ILSVRC2012_validation_ground_truth.txt", "w") as file:
for label in torch.randint(0, len(wnids), (num_samples,)).tolist():
file.write(f"{label}\n")
num_samples = 3
create_image_folder(
root=images_root,
name="test",
file_name_fn=lambda image_idx: f"ILSVRC2012_test_{image_idx + 1:08d}.JPEG",
num_examples=num_samples,
)
make_tar(root, f"{images_root.name}.tar", images_root)
devkit_root = root / "ILSVRC2012_devkit_t12"
devkit_root.mkdir()
data_root = devkit_root / "data"
data_root.mkdir()
with open(data_root / "ILSVRC2012_validation_ground_truth.txt", "w") as file:
for label in torch.randint(0, len(wnids), (num_samples,)).tolist():
file.write(f"{label}\n")
make_tar(root, f"{devkit_root}.tar.gz", devkit_root, compression="gz")
return num_samples
......@@ -34,11 +34,17 @@ class ImageNet(Dataset):
type=DatasetType.IMAGE,
categories=categories,
homepage="https://www.image-net.org/",
valid_options=dict(split=("train", "val")),
valid_options=dict(split=("train", "val", "test")),
extra=dict(
wnid_to_category=FrozenMapping(zip(wnids, categories)),
category_to_wnid=FrozenMapping(zip(categories, wnids)),
sizes=FrozenMapping([(DatasetConfig(split="train"), 1281167), (DatasetConfig(split="val"), 50000)]),
sizes=FrozenMapping(
[
(DatasetConfig(split="train"), 1_281_167),
(DatasetConfig(split="val"), 50_000),
(DatasetConfig(split="test"), 100_000),
]
),
),
)
......@@ -53,17 +59,15 @@ class ImageNet(Dataset):
def wnid_to_category(self) -> Dict[str, str]:
return cast(Dict[str, str], self.info.extra.wnid_to_category)
_IMAGES_CHECKSUMS = {
"train": "b08200a27a8e34218a0e58fde36b0fe8f73bc377f4acea2d91602057c3ca45bb",
"val": "c7e06a6c0baccf06d8dbeb6577d71efff84673a5dbdd50633ab44f8ea0456ae0",
"test_v10102019": "9cf7f8249639510f17d3d8a0deb47cd22a435886ba8e29e2b3223e65a4079eb4",
}
def resources(self, config: DatasetConfig) -> List[OnlineResource]:
if config.split == "train":
images = HttpResource(
"ILSVRC2012_img_train.tar",
sha256="b08200a27a8e34218a0e58fde36b0fe8f73bc377f4acea2d91602057c3ca45bb",
)
else: # config.split == "val"
images = HttpResource(
"ILSVRC2012_img_val.tar",
sha256="c7e06a6c0baccf06d8dbeb6577d71efff84673a5dbdd50633ab44f8ea0456ae0",
)
name = "test_v10102019" if config.split == "test" else config.split
images = HttpResource(f"ILSVRC2012_img_{name}.tar", sha256=self._IMAGES_CHECKSUMS[name])
devkit = HttpResource(
"ILSVRC2012_devkit_t12.tar.gz",
......@@ -81,11 +85,11 @@ class ImageNet(Dataset):
label = self.categories.index(category)
return (label, category, wnid), data
_VAL_IMAGE_NAME_PATTERN = re.compile(r"ILSVRC2012_val_(?P<id>\d{8})[.]JPEG")
_VAL_TEST_IMAGE_NAME_PATTERN = re.compile(r"ILSVRC2012_(val|test)_(?P<id>\d{8})[.]JPEG")
def _val_image_key(self, data: Tuple[str, Any]) -> int:
def _val_test_image_key(self, data: Tuple[str, Any]) -> int:
path = pathlib.Path(data[0])
return int(self._VAL_IMAGE_NAME_PATTERN.match(path.name).group("id")) # type: ignore[union-attr]
return int(self._VAL_TEST_IMAGE_NAME_PATTERN.match(path.name).group("id")) # type: ignore[union-attr]
def _collate_val_data(
self, data: Tuple[Tuple[int, int], Tuple[str, io.IOBase]]
......@@ -96,9 +100,12 @@ class ImageNet(Dataset):
wnid = self.category_to_wnid[category]
return (label, category, wnid), image_data
def _collate_test_data(self, data: Tuple[str, io.IOBase]) -> Tuple[Tuple[None, None, None], Tuple[str, io.IOBase]]:
return (None, None, None), data
def _collate_and_decode_sample(
self,
data: Tuple[Tuple[int, str, str], Tuple[str, io.IOBase]],
data: Tuple[Tuple[Optional[int], Optional[str], Optional[str]], Tuple[str, io.IOBase]],
*,
decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
) -> Dict[str, Any]:
......@@ -108,7 +115,7 @@ class ImageNet(Dataset):
return dict(
path=path,
image=decoder(buffer) if decoder else buffer,
label=torch.tensor(label),
label=label,
category=category,
wnid=wnid,
)
......@@ -129,7 +136,7 @@ class ImageNet(Dataset):
dp = TarArchiveReader(images_dp)
dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE)
dp = Mapper(dp, self._collate_train_data)
else:
elif config.split == "val":
devkit_dp = TarArchiveReader(devkit_dp)
devkit_dp = Filter(devkit_dp, path_comparator("name", "ILSVRC2012_validation_ground_truth.txt"))
devkit_dp = LineReader(devkit_dp, return_path=False)
......@@ -141,10 +148,13 @@ class ImageNet(Dataset):
devkit_dp,
images_dp,
key_fn=getitem(0),
ref_key_fn=self._val_image_key,
ref_key_fn=self._val_test_image_key,
buffer_size=INFINITE_BUFFER_SIZE,
)
dp = Mapper(dp, self._collate_val_data)
else: # config.split == "test"
dp = Shuffler(images_dp, buffer_size=INFINITE_BUFFER_SIZE)
dp = Mapper(dp, self._collate_test_data)
return Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(decoder=decoder))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment