"examples/custom_diffusion/train_custom_diffusion.py" did not exist on "46def7265fe43a7d3d67bac1872593acbdb0b61f"
Unverified Commit 31e503f1 authored by Yassine Alouini's avatar Yassine Alouini Committed by GitHub
Browse files

Food101 new dataset API (#5584)



* [FEAT] Start implementing Food101 using the new datasets API. WIP.

* [FEAT] Generate Food101 categories and start the test mock.

* [FEAT] food101 dataset code seems to work now.

* [TEST] food101 mock update.

* [FIX] Some fixes thanks to running food101 tests.

* [FIX] Fix mypy checks for the food101 file.

* [FIX] Remove unused numpy.

* [FIX] Some changes thanks to code review.

* [ENH] More idiomatic dataset code thanks to code review.

* [FIX] Remove unused cast.

* [ENH] Set decompress and extract to True for some performance gains.

* [FEAT] Use the preprocess=decompress keyword.

* [ENH] Use the train.txt and test.txt files instead of the .json variants and simplify code + update mock data.

* [ENH] Better food101 mock data generation.

* [FIX] Remove a useless print.
Co-authored-by: Philip Meier <github.pmeier@posteo.de>
parent 79e4985a
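
The net effect of this commit is that Food101 becomes loadable through the prototype datasets API. A minimal usage sketch, not part of the commit, assuming the prototype load() entry point available at the time:

from torchvision.prototype import datasets

# Yields dicts with "label", "path" and "image" keys, as built by
# Food101._prepare_sample in the diff below.
food101 = datasets.load("food101", split="train")
sample = next(iter(food101))
print(sample["path"], sample["label"])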
@@ -911,6 +911,44 @@ def country211(info, root, config):
    return num_examples * len(classes)
@register_mock
def food101(info, root, config):
    data_folder = root / "food-101"

    num_images_per_class = 3
    image_folder = data_folder / "images"
    categories = ["apple_pie", "baby_back_ribs", "waffles"]
    image_ids = []
    for category in categories:
        image_files = create_image_folder(
            image_folder,
            category,
            file_name_fn=lambda idx: f"{idx:04d}.jpg",
            num_examples=num_images_per_class,
        )
        image_ids.extend(path.relative_to(path.parents[1]).with_suffix("").as_posix() for path in image_files)

    meta_folder = data_folder / "meta"
    meta_folder.mkdir()

    with open(meta_folder / "classes.txt", "w") as file:
        for category in categories:
            file.write(f"{category}\n")

    splits = ["train", "test"]
    num_samples_map = {}
    for offset, split in enumerate(splits):
        image_ids_in_split = image_ids[offset :: len(splits)]
        num_samples_map[split] = len(image_ids_in_split)
        with open(meta_folder / f"{split}.txt", "w") as file:
            for image_id in image_ids_in_split:
                file.write(f"{image_id}\n")

    make_tar(root, f"{data_folder.name}.tar.gz", compression="gz")

    return num_samples_map[config.split]
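
A side note on the mock above (a sketch, not part of the diff): the image ids are dealt out round-robin across the splits via the image_ids[offset :: len(splits)] slice, so the 9 ids produced by 3 categories x 3 images split 5/4 between train and test:

image_ids = list(range(9))                # stand-in for 3 categories x 3 mock images
splits = ["train", "test"]
train_ids = image_ids[0 :: len(splits)]   # offset 0 -> indices 0, 2, 4, 6, 8
test_ids = image_ids[1 :: len(splits)]    # offset 1 -> indices 1, 3, 5, 7
assert (len(train_ids), len(test_ids)) == (5, 4)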
@register_mock
def dtd(info, root, config):
    data_folder = root / "dtd"
...
@@ -8,6 +8,7 @@ from .cub200 import CUB200
from .dtd import DTD
from .eurosat import EuroSAT
from .fer2013 import FER2013
from .food101 import Food101
from .gtsrb import GTSRB
from .imagenet import ImageNet
from .mnist import MNIST, FashionMNIST, KMNIST, EMNIST, QMNIST
...
apple_pie
baby_back_ribs
baklava
beef_carpaccio
beef_tartare
beet_salad
beignets
bibimbap
bread_pudding
breakfast_burrito
bruschetta
caesar_salad
cannoli
caprese_salad
carrot_cake
ceviche
cheesecake
cheese_plate
chicken_curry
chicken_quesadilla
chicken_wings
chocolate_cake
chocolate_mousse
churros
clam_chowder
club_sandwich
crab_cakes
creme_brulee
croque_madame
cup_cakes
deviled_eggs
donuts
dumplings
edamame
eggs_benedict
escargots
falafel
filet_mignon
fish_and_chips
foie_gras
french_fries
french_onion_soup
french_toast
fried_calamari
fried_rice
frozen_yogurt
garlic_bread
gnocchi
greek_salad
grilled_cheese_sandwich
grilled_salmon
guacamole
gyoza
hamburger
hot_and_sour_soup
hot_dog
huevos_rancheros
hummus
ice_cream
lasagna
lobster_bisque
lobster_roll_sandwich
macaroni_and_cheese
macarons
miso_soup
mussels
nachos
omelette
onion_rings
oysters
pad_thai
paella
pancakes
panna_cotta
peking_duck
pho
pizza
pork_chop
poutine
prime_rib
pulled_pork_sandwich
ramen
ravioli
red_velvet_cake
risotto
samosa
sashimi
scallops
seaweed_salad
shrimp_and_grits
spaghetti_bolognese
spaghetti_carbonara
spring_rolls
steak
strawberry_shortcake
sushi
tacos
takoyaki
tiramisu
tuna_tartare
waffles
from pathlib import Path
from typing import Any, Tuple, List, Dict, Optional, BinaryIO

from torchdata.datapipes.iter import (
    IterDataPipe,
    Filter,
    Mapper,
    LineReader,
    Demultiplexer,
    IterKeyZipper,
)
from torchvision.prototype.datasets.utils import Dataset, DatasetInfo, DatasetConfig, HttpResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import (
    hint_shuffling,
    hint_sharding,
    path_comparator,
    getitem,
    INFINITE_BUFFER_SIZE,
)
from torchvision.prototype.features import Label, EncodedImage


class Food101(Dataset):
    def _make_info(self) -> DatasetInfo:
        return DatasetInfo(
            "food101",
            homepage="https://data.vision.ee.ethz.ch/cvl/datasets_extra/food-101",
            valid_options=dict(split=("train", "test")),
        )

    def resources(self, config: DatasetConfig) -> List[OnlineResource]:
        return [
            HttpResource(
                url="http://data.vision.ee.ethz.ch/cvl/food-101.tar.gz",
                sha256="d97d15e438b7f4498f96086a4f7e2fa42a32f2712e87d3295441b2b6314053a4",
                preprocess="decompress",
            )
        ]

    def _classify_archive(self, data: Tuple[str, Any]) -> Optional[int]:
        path = Path(data[0])
        if path.parents[1].name == "images":
            return 0
        elif path.parents[0].name == "meta":
            return 1
        else:
            return None

    def _prepare_sample(self, data: Tuple[str, Tuple[str, BinaryIO]]) -> Dict[str, Any]:
        id, (path, buffer) = data
        return dict(
            label=Label.from_category(id.split("/", 1)[0], categories=self.categories),
            path=path,
            image=EncodedImage.from_file(buffer),
        )

    def _image_key(self, data: Tuple[str, Any]) -> str:
        path = Path(data[0])
        return path.relative_to(path.parents[1]).with_suffix("").as_posix()

    def _make_datapipe(
        self,
        resource_dps: List[IterDataPipe],
        *,
        config: DatasetConfig,
    ) -> IterDataPipe[Dict[str, Any]]:
        archive_dp = resource_dps[0]
        images_dp, split_dp = Demultiplexer(
            archive_dp, 2, self._classify_archive, drop_none=True, buffer_size=INFINITE_BUFFER_SIZE
        )

        split_dp = Filter(split_dp, path_comparator("name", f"{config.split}.txt"))
        split_dp = LineReader(split_dp, decode=True, return_path=False)
        split_dp = hint_sharding(split_dp)
        split_dp = hint_shuffling(split_dp)

        dp = IterKeyZipper(
            split_dp,
            images_dp,
            key_fn=getitem(),
            ref_key_fn=self._image_key,
            buffer_size=INFINITE_BUFFER_SIZE,
        )
        return Mapper(dp, self._prepare_sample)

    def _generate_categories(self, root: Path) -> List[str]:
        resources = self.resources(self.default_config)

        dp = resources[0].load(root)
        dp = Filter(dp, path_comparator("name", "classes.txt"))
        dp = LineReader(dp, decode=True, return_path=False)

        return list(dp)
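
For reference, a standalone sketch (not part of the diff, using a hypothetical archive path) of why the IterKeyZipper join in _make_datapipe lines up: _image_key reduces an image path to the same "category/image_id" form that the lines of train.txt and test.txt use, while _classify_archive routes archive members by their parent folders:

from pathlib import Path

path = Path("food-101/images/apple_pie/0001.jpg")
# _classify_archive sends this to branch 0 because path.parents[1].name == "images";
# "food-101/meta/train.txt" would go to branch 1 via path.parents[0].name == "meta".

key = path.relative_to(path.parents[1]).with_suffix("").as_posix()
assert key == "apple_pie/0001"  # matches an id line from train.txt / test.txt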