Unverified Commit 0325fdd6 authored by Nicolas Hug's avatar Nicolas Hug Committed by GitHub
Browse files

Return labels for FER2013 if possible (#8452)

parent ab0b9a43
...@@ -2442,7 +2442,35 @@ class FER2013TestCase(datasets_utils.ImageDatasetTestCase): ...@@ -2442,7 +2442,35 @@ class FER2013TestCase(datasets_utils.ImageDatasetTestCase):
base_folder = os.path.join(tmpdir, "fer2013") base_folder = os.path.join(tmpdir, "fer2013")
os.makedirs(base_folder) os.makedirs(base_folder)
use_icml = config.pop("use_icml", False)
use_fer = config.pop("use_fer", False)
num_samples = 5 num_samples = 5
if use_icml or use_fer:
pixels_key, usage_key = (" pixels", " Usage") if use_icml else ("pixels", "Usage")
fieldnames = ("emotion", usage_key, pixels_key) if use_icml else ("emotion", pixels_key, usage_key)
filename = "icml_face_data.csv" if use_icml else "fer2013.csv"
with open(os.path.join(base_folder, filename), "w", newline="") as file:
writer = csv.DictWriter(
file,
fieldnames=fieldnames,
quoting=csv.QUOTE_NONNUMERIC,
quotechar='"',
)
writer.writeheader()
for i in range(num_samples):
row = {
"emotion": str(int(torch.randint(0, 7, ()))),
usage_key: "Training" if i % 2 else "PublicTest",
pixels_key: " ".join(
str(pixel)
for pixel in datasets_utils.create_image_or_video_tensor((48, 48)).view(-1).tolist()
),
}
writer.writerow(row)
else:
with open(os.path.join(base_folder, f"{config['split']}.csv"), "w", newline="") as file: with open(os.path.join(base_folder, f"{config['split']}.csv"), "w", newline="") as file:
writer = csv.DictWriter( writer = csv.DictWriter(
file, file,
...@@ -2454,7 +2482,8 @@ class FER2013TestCase(datasets_utils.ImageDatasetTestCase): ...@@ -2454,7 +2482,8 @@ class FER2013TestCase(datasets_utils.ImageDatasetTestCase):
for _ in range(num_samples): for _ in range(num_samples):
row = dict( row = dict(
pixels=" ".join( pixels=" ".join(
str(pixel) for pixel in datasets_utils.create_image_or_video_tensor((48, 48)).view(-1).tolist() str(pixel)
for pixel in datasets_utils.create_image_or_video_tensor((48, 48)).view(-1).tolist()
) )
) )
if config["split"] == "train": if config["split"] == "train":
...@@ -2464,6 +2493,17 @@ class FER2013TestCase(datasets_utils.ImageDatasetTestCase): ...@@ -2464,6 +2493,17 @@ class FER2013TestCase(datasets_utils.ImageDatasetTestCase):
return num_samples return num_samples
def test_icml_file(self):
config = {"split": "test"}
with self.create_dataset(config=config) as (dataset, _):
assert all(s[1] is None for s in dataset)
for split in ("train", "test"):
for d in ({"use_icml": True}, {"use_fer": True}):
config = {"split": split, **d}
with self.create_dataset(config=config) as (dataset, _):
assert all(s[1] is not None for s in dataset)
class GTSRBTestCase(datasets_utils.ImageDatasetTestCase): class GTSRBTestCase(datasets_utils.ImageDatasetTestCase):
DATASET_CLASS = datasets.GTSRB DATASET_CLASS = datasets.GTSRB
......
...@@ -13,9 +13,21 @@ class FER2013(VisionDataset): ...@@ -13,9 +13,21 @@ class FER2013(VisionDataset):
"""`FER2013 """`FER2013
<https://www.kaggle.com/c/challenges-in-representation-learning-facial-expression-recognition-challenge>`_ Dataset. <https://www.kaggle.com/c/challenges-in-representation-learning-facial-expression-recognition-challenge>`_ Dataset.
.. note::
This dataset can return test labels only if ``fer2013.csv`` OR
``icml_face_data.csv`` are present in ``root/fer2013/``. If only
``train.csv`` and ``test.csv`` are present, the test labels are set to
``None``.
Args: Args:
root (str or ``pathlib.Path``): Root directory of dataset where directory root (str or ``pathlib.Path``): Root directory of dataset where directory
``root/fer2013`` exists. ``root/fer2013`` exists. This directory may contain either
``fer2013.csv``, ``icml_face_data.csv``, or both ``train.csv`` and
``test.csv``. Precendence is given in that order, i.e. if
``fer2013.csv`` is present then the rest of the files will be
ignored. All these (combinations of) files contain the same data and
are supported for convenience, but only ``fer2013.csv`` and
``icml_face_data.csv`` are able to return non-None test labels.
split (string, optional): The dataset split, supports ``"train"`` (default), or ``"test"``. split (string, optional): The dataset split, supports ``"train"`` (default), or ``"test"``.
transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
version. E.g, ``transforms.RandomCrop`` version. E.g, ``transforms.RandomCrop``
...@@ -25,6 +37,25 @@ class FER2013(VisionDataset): ...@@ -25,6 +37,25 @@ class FER2013(VisionDataset):
_RESOURCES = { _RESOURCES = {
"train": ("train.csv", "3f0dfb3d3fd99c811a1299cb947e3131"), "train": ("train.csv", "3f0dfb3d3fd99c811a1299cb947e3131"),
"test": ("test.csv", "b02c2298636a634e8c2faabbf3ea9a23"), "test": ("test.csv", "b02c2298636a634e8c2faabbf3ea9a23"),
# The fer2013.csv and icml_face_data.csv files contain both train and
# tests instances, and unlike test.csv they contain the labels for the
# test instances. We give these 2 files precedence over train.csv and
# test.csv. And yes, they both contain the same data, but with different
# column names (note the spaces) and ordering:
# $ head -n 1 fer2013.csv icml_face_data.csv train.csv test.csv
# ==> fer2013.csv <==
# emotion,pixels,Usage
#
# ==> icml_face_data.csv <==
# emotion, Usage, pixels
#
# ==> train.csv <==
# emotion,pixels
#
# ==> test.csv <==
# pixels
"fer": ("fer2013.csv", "f8428a1edbd21e88f42c73edd2a14f95"),
"icml": ("icml_face_data.csv", "b114b9e04e6949e5fe8b6a98b3892b1d"),
} }
def __init__( def __init__(
...@@ -34,11 +65,13 @@ class FER2013(VisionDataset): ...@@ -34,11 +65,13 @@ class FER2013(VisionDataset):
transform: Optional[Callable] = None, transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None, target_transform: Optional[Callable] = None,
) -> None: ) -> None:
self._split = verify_str_arg(split, "split", self._RESOURCES.keys()) self._split = verify_str_arg(split, "split", ("train", "test"))
super().__init__(root, transform=transform, target_transform=target_transform) super().__init__(root, transform=transform, target_transform=target_transform)
base_folder = pathlib.Path(self.root) / "fer2013" base_folder = pathlib.Path(self.root) / "fer2013"
file_name, md5 = self._RESOURCES[self._split] use_fer_file = (base_folder / self._RESOURCES["fer"][0]).exists()
use_icml_file = not use_fer_file and (base_folder / self._RESOURCES["icml"][0]).exists()
file_name, md5 = self._RESOURCES["fer" if use_fer_file else "icml" if use_icml_file else self._split]
data_file = base_folder / file_name data_file = base_folder / file_name
if not check_integrity(str(data_file), md5=md5): if not check_integrity(str(data_file), md5=md5):
raise RuntimeError( raise RuntimeError(
...@@ -47,14 +80,26 @@ class FER2013(VisionDataset): ...@@ -47,14 +80,26 @@ class FER2013(VisionDataset):
f"https://www.kaggle.com/c/challenges-in-representation-learning-facial-expression-recognition-challenge" f"https://www.kaggle.com/c/challenges-in-representation-learning-facial-expression-recognition-challenge"
) )
pixels_key = " pixels" if use_icml_file else "pixels"
usage_key = " Usage" if use_icml_file else "Usage"
def get_img(row):
return torch.tensor([int(idx) for idx in row[pixels_key].split()], dtype=torch.uint8).reshape(48, 48)
def get_label(row):
if use_fer_file or use_icml_file or self._split == "train":
return int(row["emotion"])
else:
return None
with open(data_file, "r", newline="") as file: with open(data_file, "r", newline="") as file:
self._samples = [ rows = (row for row in csv.DictReader(file))
(
torch.tensor([int(idx) for idx in row["pixels"].split()], dtype=torch.uint8).reshape(48, 48), if use_fer_file or use_icml_file:
int(row["emotion"]) if "emotion" in row else None, valid_keys = ("Training",) if self._split == "train" else ("PublicTest", "PrivateTest")
) rows = (row for row in rows if row[usage_key] in valid_keys)
for row in csv.DictReader(file)
] self._samples = [(get_img(row), get_label(row)) for row in rows]
def __len__(self) -> int: def __len__(self) -> int:
return len(self._samples) return len(self._samples)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment