wikibk_dataset.py 568 Bytes
Newer Older
chenych's avatar
chenych committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
from glob import glob
from .base_dataset import BaseDataset


class WikibkDataset(BaseDataset):
    """Text-only dataset over the ``wikibk`` shards.

    The constructor turns a split name into a list of shard names and hands
    that list to ``BaseDataset`` together with ``text_column_name="caption"``.
    Item access delegates to ``self.get_text_suite`` (not defined in this
    class; presumably provided by ``BaseDataset`` -- verify there).
    """

    _VALID_SPLITS = ("train", "val", "test")

    def __init__(self, *args, split="", **kwargs):
        """Select the shard names for ``split`` and initialise the base class.

        Args:
            *args: Positional arguments forwarded unchanged to
                ``BaseDataset.__init__``.
            split: One of ``"train"``, ``"val"`` or ``"test"``. The default
                ``""`` is not a valid split, so callers must pass this
                explicitly. It is consumed here and NOT forwarded to the
                base class.
            **kwargs: Keyword arguments forwarded unchanged to
                ``BaseDataset.__init__``.

        Raises:
            ValueError: If ``split`` is not one of the accepted names.
        """
        # Explicit check instead of ``assert``: asserts are removed under
        # ``python -O``, which would let an invalid split fall through and
        # surface later as an UnboundLocalError on ``names``.
        if split not in self._VALID_SPLITS:
            raise ValueError(
                f"split must be one of {self._VALID_SPLITS}, got {split!r}"
            )

        if split == "train":
            names = [f"wikibk_train_{i}" for i in range(50)]
        else:
            # "test" deliberately reuses the single validation shard; no
            # separate test shard name is used by this class.
            names = ["wikibk_val_0"]

        super().__init__(*args, **kwargs, names=names, text_column_name="caption")

    def __getitem__(self, index):
        """Return ``self.get_text_suite(index)`` for the given ``index``."""
        return self.get_text_suite(index)