Unverified commit 6d9a42c3, authored by Philip Meier and committed by GitHub
Browse files

add test split for imagenet (#4866)

* add test split for imagenet

* add infinite buffer size to shuffler
parent f093d082
...@@ -452,11 +452,7 @@ def caltech256(info, root, config): ...@@ -452,11 +452,7 @@ def caltech256(info, root, config):
@dataset_mocks.register_mock_data_fn @dataset_mocks.register_mock_data_fn
def imagenet(info, root, config): def imagenet(info, root, config):
devkit_root = root / "ILSVRC2012_devkit_t12"
devkit_root.mkdir()
wnids = tuple(info.extra.wnid_to_category.keys()) wnids = tuple(info.extra.wnid_to_category.keys())
if config.split == "train": if config.split == "train":
images_root = root / "ILSVRC2012_img_train" images_root = root / "ILSVRC2012_img_train"
...@@ -470,7 +466,7 @@ def imagenet(info, root, config): ...@@ -470,7 +466,7 @@ def imagenet(info, root, config):
num_examples=1, num_examples=1,
) )
make_tar(images_root, f"{wnid}.tar", files[0].parent) make_tar(images_root, f"{wnid}.tar", files[0].parent)
else: elif config.split == "val":
num_samples = 3 num_samples = 3
files = create_image_folder( files = create_image_folder(
root=root, root=root,
...@@ -479,14 +475,26 @@ def imagenet(info, root, config): ...@@ -479,14 +475,26 @@ def imagenet(info, root, config):
num_examples=num_samples, num_examples=num_samples,
) )
images_root = files[0].parent images_root = files[0].parent
else: # config.split == "test"
images_root = root / "ILSVRC2012_img_test_v10102019"
data_root = devkit_root / "data" num_samples = 3
data_root.mkdir()
with open(data_root / "ILSVRC2012_validation_ground_truth.txt", "w") as file:
for label in torch.randint(0, len(wnids), (num_samples,)).tolist():
file.write(f"{label}\n")
create_image_folder(
root=images_root,
name="test",
file_name_fn=lambda image_idx: f"ILSVRC2012_test_{image_idx + 1:08d}.JPEG",
num_examples=num_samples,
)
make_tar(root, f"{images_root.name}.tar", images_root) make_tar(root, f"{images_root.name}.tar", images_root)
devkit_root = root / "ILSVRC2012_devkit_t12"
devkit_root.mkdir()
data_root = devkit_root / "data"
data_root.mkdir()
with open(data_root / "ILSVRC2012_validation_ground_truth.txt", "w") as file:
for label in torch.randint(0, len(wnids), (num_samples,)).tolist():
file.write(f"{label}\n")
make_tar(root, f"{devkit_root}.tar.gz", devkit_root, compression="gz") make_tar(root, f"{devkit_root}.tar.gz", devkit_root, compression="gz")
return num_samples return num_samples
...@@ -34,11 +34,17 @@ class ImageNet(Dataset): ...@@ -34,11 +34,17 @@ class ImageNet(Dataset):
type=DatasetType.IMAGE, type=DatasetType.IMAGE,
categories=categories, categories=categories,
homepage="https://www.image-net.org/", homepage="https://www.image-net.org/",
valid_options=dict(split=("train", "val")), valid_options=dict(split=("train", "val", "test")),
extra=dict( extra=dict(
wnid_to_category=FrozenMapping(zip(wnids, categories)), wnid_to_category=FrozenMapping(zip(wnids, categories)),
category_to_wnid=FrozenMapping(zip(categories, wnids)), category_to_wnid=FrozenMapping(zip(categories, wnids)),
sizes=FrozenMapping([(DatasetConfig(split="train"), 1281167), (DatasetConfig(split="val"), 50000)]), sizes=FrozenMapping(
[
(DatasetConfig(split="train"), 1_281_167),
(DatasetConfig(split="val"), 50_000),
(DatasetConfig(split="test"), 100_000),
]
),
), ),
) )
...@@ -53,17 +59,15 @@ class ImageNet(Dataset): ...@@ -53,17 +59,15 @@ class ImageNet(Dataset):
def wnid_to_category(self) -> Dict[str, str]: def wnid_to_category(self) -> Dict[str, str]:
return cast(Dict[str, str], self.info.extra.wnid_to_category) return cast(Dict[str, str], self.info.extra.wnid_to_category)
_IMAGES_CHECKSUMS = {
"train": "b08200a27a8e34218a0e58fde36b0fe8f73bc377f4acea2d91602057c3ca45bb",
"val": "c7e06a6c0baccf06d8dbeb6577d71efff84673a5dbdd50633ab44f8ea0456ae0",
"test_v10102019": "9cf7f8249639510f17d3d8a0deb47cd22a435886ba8e29e2b3223e65a4079eb4",
}
def resources(self, config: DatasetConfig) -> List[OnlineResource]: def resources(self, config: DatasetConfig) -> List[OnlineResource]:
if config.split == "train": name = "test_v10102019" if config.split == "test" else config.split
images = HttpResource( images = HttpResource(f"ILSVRC2012_img_{name}.tar", sha256=self._IMAGES_CHECKSUMS[name])
"ILSVRC2012_img_train.tar",
sha256="b08200a27a8e34218a0e58fde36b0fe8f73bc377f4acea2d91602057c3ca45bb",
)
else: # config.split == "val"
images = HttpResource(
"ILSVRC2012_img_val.tar",
sha256="c7e06a6c0baccf06d8dbeb6577d71efff84673a5dbdd50633ab44f8ea0456ae0",
)
devkit = HttpResource( devkit = HttpResource(
"ILSVRC2012_devkit_t12.tar.gz", "ILSVRC2012_devkit_t12.tar.gz",
...@@ -81,11 +85,11 @@ class ImageNet(Dataset): ...@@ -81,11 +85,11 @@ class ImageNet(Dataset):
label = self.categories.index(category) label = self.categories.index(category)
return (label, category, wnid), data return (label, category, wnid), data
_VAL_IMAGE_NAME_PATTERN = re.compile(r"ILSVRC2012_val_(?P<id>\d{8})[.]JPEG") _VAL_TEST_IMAGE_NAME_PATTERN = re.compile(r"ILSVRC2012_(val|test)_(?P<id>\d{8})[.]JPEG")
def _val_image_key(self, data: Tuple[str, Any]) -> int: def _val_test_image_key(self, data: Tuple[str, Any]) -> int:
path = pathlib.Path(data[0]) path = pathlib.Path(data[0])
return int(self._VAL_IMAGE_NAME_PATTERN.match(path.name).group("id")) # type: ignore[union-attr] return int(self._VAL_TEST_IMAGE_NAME_PATTERN.match(path.name).group("id")) # type: ignore[union-attr]
def _collate_val_data( def _collate_val_data(
self, data: Tuple[Tuple[int, int], Tuple[str, io.IOBase]] self, data: Tuple[Tuple[int, int], Tuple[str, io.IOBase]]
...@@ -96,9 +100,12 @@ class ImageNet(Dataset): ...@@ -96,9 +100,12 @@ class ImageNet(Dataset):
wnid = self.category_to_wnid[category] wnid = self.category_to_wnid[category]
return (label, category, wnid), image_data return (label, category, wnid), image_data
def _collate_test_data(self, data: Tuple[str, io.IOBase]) -> Tuple[Tuple[None, None, None], Tuple[str, io.IOBase]]:
return (None, None, None), data
def _collate_and_decode_sample( def _collate_and_decode_sample(
self, self,
data: Tuple[Tuple[int, str, str], Tuple[str, io.IOBase]], data: Tuple[Tuple[Optional[int], Optional[str], Optional[str]], Tuple[str, io.IOBase]],
*, *,
decoder: Optional[Callable[[io.IOBase], torch.Tensor]], decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
) -> Dict[str, Any]: ) -> Dict[str, Any]:
...@@ -108,7 +115,7 @@ class ImageNet(Dataset): ...@@ -108,7 +115,7 @@ class ImageNet(Dataset):
return dict( return dict(
path=path, path=path,
image=decoder(buffer) if decoder else buffer, image=decoder(buffer) if decoder else buffer,
label=torch.tensor(label), label=label,
category=category, category=category,
wnid=wnid, wnid=wnid,
) )
...@@ -129,7 +136,7 @@ class ImageNet(Dataset): ...@@ -129,7 +136,7 @@ class ImageNet(Dataset):
dp = TarArchiveReader(images_dp) dp = TarArchiveReader(images_dp)
dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE) dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE)
dp = Mapper(dp, self._collate_train_data) dp = Mapper(dp, self._collate_train_data)
else: elif config.split == "val":
devkit_dp = TarArchiveReader(devkit_dp) devkit_dp = TarArchiveReader(devkit_dp)
devkit_dp = Filter(devkit_dp, path_comparator("name", "ILSVRC2012_validation_ground_truth.txt")) devkit_dp = Filter(devkit_dp, path_comparator("name", "ILSVRC2012_validation_ground_truth.txt"))
devkit_dp = LineReader(devkit_dp, return_path=False) devkit_dp = LineReader(devkit_dp, return_path=False)
...@@ -141,10 +148,13 @@ class ImageNet(Dataset): ...@@ -141,10 +148,13 @@ class ImageNet(Dataset):
devkit_dp, devkit_dp,
images_dp, images_dp,
key_fn=getitem(0), key_fn=getitem(0),
ref_key_fn=self._val_image_key, ref_key_fn=self._val_test_image_key,
buffer_size=INFINITE_BUFFER_SIZE, buffer_size=INFINITE_BUFFER_SIZE,
) )
dp = Mapper(dp, self._collate_val_data) dp = Mapper(dp, self._collate_val_data)
else: # config.split == "test"
dp = Shuffler(images_dp, buffer_size=INFINITE_BUFFER_SIZE)
dp = Mapper(dp, self._collate_test_data)
return Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(decoder=decoder)) return Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(decoder=decoder))
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment