dataset.py 1.68 KB
Newer Older
yangzhong's avatar
yangzhong committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
"""
Dataset-related classes and methods.
"""

# pylint: disable=unused-argument,missing-docstring

import logging
import sys
import time

import cv2
import numpy as np


# Import-time side effect: request INFO-level root logging. Per the stdlib,
# basicConfig() does nothing if the root logger already has handlers.
logging.basicConfig(level=logging.INFO)
# Module-wide logger, registered under the fixed name "dataset".
log = logging.getLogger("dataset")


class Dataset:
    """Base class holding a list of samples and an in-memory cache of items.

    Subclasses are expected to override ``preprocess``, ``get_list`` and
    ``get_item``; the versions here only raise ``NotImplementedError``.

    Attributes set by ``__init__``:
        arrival: initialised to ``None``; not used by this base class.
        image_list: list whose length is reported by ``get_item_count``.
        caption_list: initialised to an empty list; not used by this base
            class.
        items_inmemory: dict mapping a sample id to whatever ``get_item``
            returned for it.
        last_loaded: ``-1`` until ``load_query_samples`` runs, afterwards the
            ``time.time()`` value taken when the most recent load finished.
    """

    def __init__(self):
        self.arrival = None
        self.image_list = []
        self.caption_list = []
        self.items_inmemory = {}
        self.last_loaded = -1

    def preprocess(self, use_cache=True):
        """Abstract hook; always raises ``NotImplementedError`` here."""
        raise NotImplementedError("Dataset:preprocess")

    def get_item_count(self):
        """Return the number of entries in ``image_list``."""
        return len(self.image_list)

    def get_list(self):
        """Abstract hook; always raises ``NotImplementedError`` here."""
        raise NotImplementedError("Dataset:get_list")

    def load_query_samples(self, sample_list):
        """Replace the in-memory cache with the items for ``sample_list``.

        The cache is emptied first, then ``get_item`` is called once per
        sample in iteration order. ``last_loaded`` is stamped with
        ``time.time()`` after every sample has been fetched.
        """
        self.items_inmemory = {}
        for sample in sample_list:
            item = self.get_item(sample)
            self.items_inmemory[sample] = item
        self.last_loaded = time.time()

    def unload_query_samples(self, sample_list):
        """Drop samples from the in-memory cache.

        A falsy ``sample_list`` (``None``, empty) clears the whole cache;
        otherwise only the listed samples are removed, and ids that are not
        cached are ignored.
        """
        if not sample_list:
            self.items_inmemory = {}
            return
        for sample in sample_list:
            # pop() with a default silently skips ids that were never loaded.
            self.items_inmemory.pop(sample, None)

    def get_samples(self, id_list):
        """Return ``(data, images)`` for the cached samples in ``id_list``.

        ``data`` is a list of dicts carrying the ``"input_tokens"``,
        ``"input_tokens_2"`` and ``"latents"`` entries of each cached item;
        ``images`` is the list of the items' ``"file_name"`` entries. Both
        follow the order of ``id_list``. A sample that is not cached raises
        ``KeyError``.
        """
        wanted_keys = ("input_tokens", "input_tokens_2", "latents")

        data = []
        for sample_id in id_list:
            data.append(
                {key: self.items_inmemory[sample_id][key] for key in wanted_keys}
            )

        # Second, separate pass over id_list, mirroring the two-step build.
        images = []
        for sample_id in id_list:
            images.append(self.items_inmemory[sample_id]["file_name"])

        return data, images

    def get_item(self, id):
        """Abstract hook; always raises ``NotImplementedError`` here."""
        raise NotImplementedError("Dataset:get_item")


def preprocess(img):
    """Return ``img`` unchanged; no processing is applied."""
    unchanged = img
    return unchanged