Unverified Commit 01f07eeb authored by Drishti Bhasin's avatar Drishti Bhasin Committed by GitHub
Browse files

add Country211 prototype dataset (#5506)



* add country211

* remove unused import

* map val to valid and use path comparator

* remove unused import

* resolve keyerror

* map split names in dataset mock
Co-authored-by: default avatarDbhasin1 <drishti_b@me.iitr.c.in>
Co-authored-by: default avatarNicolas Hug <nicolashug@fb.com>
parent 71d2bb0b
......@@ -878,6 +878,34 @@ def celeba(info, root, config):
return CelebAMockData.generate(root)[config.split]
@register_mock
def country211(info, root, config):
split_name_mapper = {
"train": "train",
"val": "valid",
"test": "test",
}
split_folder = pathlib.Path(root, "country211", split_name_mapper[config["split"]])
split_folder.mkdir(parents=True, exist_ok=True)
num_examples = {
"train": 3,
"val": 4,
"test": 5,
}[config["split"]]
classes = ("AD", "BS", "GR")
for cls in classes:
create_image_folder(
split_folder,
name=cls,
file_name_fn=lambda idx: f"{idx}.jpg",
num_examples=num_examples,
)
make_tar(root, f"{split_folder.parent.name}.tgz", split_folder.parent, compression="gz")
return num_examples * len(classes)
@register_mock
def dtd(info, root, config):
data_folder = root / "dtd"
......
......@@ -3,6 +3,7 @@ from .celeba import CelebA
from .cifar import Cifar10, Cifar100
from .clevr import CLEVR
from .coco import Coco
from .country211 import Country211
from .cub200 import CUB200
from .dtd import DTD
from .fer2013 import FER2013
......
AD
AE
AF
AG
AI
AL
AM
AO
AQ
AR
AT
AU
AW
AX
AZ
BA
BB
BD
BE
BF
BG
BH
BJ
BM
BN
BO
BQ
BR
BS
BT
BW
BY
BZ
CA
CD
CF
CH
CI
CK
CL
CM
CN
CO
CR
CU
CV
CW
CY
CZ
DE
DK
DM
DO
DZ
EC
EE
EG
ES
ET
FI
FJ
FK
FO
FR
GA
GB
GD
GE
GF
GG
GH
GI
GL
GM
GP
GR
GS
GT
GU
GY
HK
HN
HR
HT
HU
ID
IE
IL
IM
IN
IQ
IR
IS
IT
JE
JM
JO
JP
KE
KG
KH
KN
KP
KR
KW
KY
KZ
LA
LB
LC
LI
LK
LR
LT
LU
LV
LY
MA
MC
MD
ME
MF
MG
MK
ML
MM
MN
MO
MQ
MR
MT
MU
MV
MW
MX
MY
MZ
NA
NC
NG
NI
NL
NO
NP
NZ
OM
PA
PE
PF
PG
PH
PK
PL
PR
PS
PT
PW
PY
QA
RE
RO
RS
RU
RW
SA
SB
SC
SD
SE
SG
SH
SI
SJ
SK
SL
SM
SN
SO
SS
SV
SX
SY
SZ
TG
TH
TJ
TL
TM
TN
TO
TR
TT
TW
TZ
UA
UG
US
UY
UZ
VA
VE
VG
VI
VN
VU
WS
XK
YE
ZA
ZM
ZW
import pathlib
from typing import Any, Dict, List, Tuple
from torchdata.datapipes.iter import IterDataPipe, Mapper, Filter
from torchvision.prototype.datasets.utils import Dataset, DatasetConfig, DatasetInfo, HttpResource, OnlineResource
from torchvision.prototype.datasets.utils._internal import path_comparator, hint_sharding, hint_shuffling
from torchvision.prototype.features import EncodedImage, Label
class Country211(Dataset):
def _make_info(self) -> DatasetInfo:
return DatasetInfo(
"country211",
homepage="https://github.com/openai/CLIP/blob/main/data/country211.md",
valid_options=dict(split=("train", "val", "test")),
)
def resources(self, config: DatasetConfig) -> List[OnlineResource]:
return [
HttpResource(
"https://openaipublic.azureedge.net/clip/data/country211.tgz",
sha256="c011343cdc1296a8c31ff1d7129cf0b5e5b8605462cffd24f89266d6e6f4da3c",
)
]
_SPLIT_NAME_MAPPER = {
"train": "train",
"val": "valid",
"test": "test",
}
def _prepare_sample(self, data: Tuple[str, Any]) -> Dict[str, Any]:
path, buffer = data
category = pathlib.Path(path).parent.name
return dict(
label=Label.from_category(category, categories=self.categories),
path=path,
image=EncodedImage.from_file(buffer),
)
def _filter_split(self, data: Tuple[str, Any], *, split: str) -> bool:
return pathlib.Path(data[0]).parent.parent.name == split
def _make_datapipe(
self, resource_dps: List[IterDataPipe], *, config: DatasetConfig
) -> IterDataPipe[Dict[str, Any]]:
dp = resource_dps[0]
dp = Filter(dp, path_comparator("parent.parent.name", self._SPLIT_NAME_MAPPER[config.split]))
dp = hint_sharding(dp)
dp = hint_shuffling(dp)
return Mapper(dp, self._prepare_sample)
def _generate_categories(self, root: pathlib.Path) -> List[str]:
resources = self.resources(self.default_config)
dp = resources[0].load(root)
return sorted({pathlib.Path(path).parent.name for path, _ in dp})
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment