data.py 1.22 KB
Newer Older
1
2
import torch
from datasets import load_dataset
3
4
from torch.utils.data import Dataset

5
6

class BeansDataset(Dataset):
    """Eagerly pre-processed view of one split of the ``beans`` dataset.

    Every example of the requested split is passed through
    ``image_processor`` once, at construction time, and the results are
    kept in memory in ``self.inputs``.

    Attributes set by ``__init__``:
        image_processor: the callable supplied by the caller.
        ds: ``load_dataset("beans")[split]``.
        label_names: the split's label names, padded with
            ``pad_label_<index>`` placeholders so that their count is a
            multiple of ``tp_size``.
        num_labels: ``len(label_names)``, padding included.
        inputs: list of processed examples, in dataset order.
    """

    def __init__(self, image_processor, tp_size=1, split="train"):
        """Load ``split`` and pre-process all of its examples.

        Args:
            image_processor: callable invoked as
                ``image_processor(image, return_tensors="pt")``.
            tp_size: the number of labels is padded up to a multiple of
                this value; must be at least 1.
            split: name of the dataset split to load.

        Raises:
            ValueError: if ``tp_size`` is smaller than 1.
        """
        super().__init__()
        # Fail before touching the dataset: a zero tp_size would otherwise
        # only surface later as a ZeroDivisionError.
        if tp_size < 1:
            raise ValueError(f"tp_size must be >= 1, got {tp_size}")

        self.image_processor = image_processor
        self.ds = load_dataset("beans")[split]
        # Pad a *copy* of the names: appending to the list held by the
        # dataset's label feature would modify the dataset metadata itself.
        self.label_names = self._pad_label_names(
            self.ds.features["labels"].names, tp_size
        )
        self.num_labels = len(self.label_names)
        self.inputs = [self.process_example(example) for example in self.ds]

    @staticmethod
    def _pad_label_names(names, tp_size):
        """Return a new list: ``names`` padded to a multiple of ``tp_size``.

        Placeholders are named ``pad_label_<index>``, where ``<index>`` is
        the position the placeholder occupies in the returned list.
        ``names`` itself is never modified. ``tp_size`` must be >= 1.
        """
        base = list(names)
        # Number of entries missing to reach the next multiple of tp_size.
        missing = -len(base) % tp_size
        base.extend(
            f"pad_label_{index}"
            for index in range(len(base), len(base) + missing)
        )
        return base

    def __len__(self):
        """Return the number of pre-processed examples."""
        return len(self.inputs)

    def __getitem__(self, idx):
        """Return the pre-processed example stored at position ``idx``."""
        return self.inputs[idx]

    def process_example(self, example):
        """Run the image processor on one example and attach its label.

        Args:
            example: mapping with an ``"image"`` and a ``"labels"`` entry.

        Returns:
            The object returned by ``image_processor`` with an added
            ``"labels"`` entry copied from ``example``.
        """
        processed = self.image_processor(example["image"], return_tensors="pt")
        processed["labels"] = example["labels"]
        return processed
30

31
32

def beans_collator(batch):
    """Merge a sequence of processed examples into one batch dictionary.

    Args:
        batch: sequence of mappings, each with a ``"pixel_values"`` tensor
            and a ``"labels"`` value.

    Returns:
        A dict with ``"pixel_values"`` (the per-example tensors
        concatenated along dimension 0) and ``"labels"`` (an ``int64``
        tensor with one entry per example).
    """
    # NOTE(review): concatenating on dim 0 presumably relies on every
    # "pixel_values" tensor carrying a leading batch dimension - confirm
    # against the producer of these examples.
    pixel_chunks = [sample["pixel_values"] for sample in batch]
    collated = {"pixel_values": torch.cat(pixel_chunks, dim=0)}

    label_ids = [sample["labels"] for sample in batch]
    collated["labels"] = torch.tensor(label_ids, dtype=torch.int64)
    return collated