import torch
from datasets import load_dataset
from torch.utils.data import Dataset


class BeansDataset(Dataset):
    """Image-classification dataset wrapping the HuggingFace 'beans' dataset.

    Every example is preprocessed eagerly at construction time, so
    ``__getitem__`` is a cheap list lookup.
    """

    def __init__(self, image_processor, tp_size=1, split='train'):
        """Load and preprocess one split of the 'beans' dataset.

        Args:
            image_processor: callable mapping a raw image to model inputs;
                invoked as ``image_processor(image, return_tensors='pt')``
                (e.g. a HuggingFace image processor).
            tp_size: tensor-parallel degree; the label list is padded with
                dummy labels until its length is a multiple of ``tp_size``.
            split: which dataset split to load ('train', 'validation', 'test').
        """
        super().__init__()
        self.image_processor = image_processor
        self.ds = load_dataset('beans')[split]
        # Copy the names list: the padding loop below appends to it, and we
        # must not mutate the dataset's own feature metadata in place.
        self.label_names = list(self.ds.features['labels'].names)

        while len(self.label_names) % tp_size != 0:
            # ensure that the number of labels is multiple of tp_size
            self.label_names.append(f"pad_label_{len(self.label_names)}")

        self.num_labels = len(self.label_names)
        # Preprocess everything up front; the dataset is small enough to
        # hold fully in memory.
        self.inputs = [self.process_example(example) for example in self.ds]

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        return self.inputs[idx]

    def process_example(self, example):
        """Run the image processor on one raw example and attach its label."""
        encoded = self.image_processor(example['image'], return_tensors='pt')
        encoded['labels'] = example['labels']
        return encoded
def beans_collator(batch):
    """Collate a list of processed examples into a single batch dict.

    Concatenates per-example pixel tensors along the batch dimension and
    packs the integer labels into one int64 tensor.
    """
    pixel_values = [example['pixel_values'] for example in batch]
    labels = [example['labels'] for example in batch]
    return {
        'pixel_values': torch.cat(pixel_values, dim=0),
        'labels': torch.tensor(labels, dtype=torch.int64),
    }