# data_item.py — curates training data items from multiple data sources.
import json
import yaml
import torch
import random
import os
import glob
import pickle
from datasets import load_dataset
from .openx import OpenXDataItem
from tqdm import tqdm

class DataItem:
    """
    Curate data items from all data sources
    """
    def __init__(self, training_size=-1, local_run=False):
        self.training_size = training_size
        self.local_run = local_run

    def _get_dataset_tag(self, data_path):
        if "epic" in data_path.lower():
            return "epic"
        elif "open-x" in data_path or "openx" in data_path:
            if 'traces' in data_path:
                return "openx_magma"
            else:
                return "openx"
        elif "sthv2" in data_path.lower():
            return "sthv2"
        elif "exoego4d" in data_path.lower():
            return "exoego4d"
        elif 'ego4d' in data_path.lower():
            return "ego4d"
        elif 'aitw' in data_path.lower():
            return "aitw"
        elif 'seeclick' in data_path.lower() and 'ocr' in data_path.lower():
            return "seeclick_ocr"            
        elif 'seeclick' in data_path.lower():
            return "seeclick"
        elif 'mind2web' in data_path.lower():
            return "mind2web"
        elif 'vision2ui' in data_path.lower():
            return "vision2ui"
        elif 'llava' in data_path.lower():
            return "llava"
        elif 'magma' in data_path.lower():
            return "magma"
        elif 'sharegpt4v' in data_path.lower():
            return "sharegpt4v"
        else:
            raise ValueError(f"Dataset tag not found for {data_path}")
    
    def _get_items(self, data_path, image_folder=None, processor=None, conversation_lib=None):
        if data_path.endswith(".json"):
            list_data_dict = json.load(open(data_path, "r"))
        elif data_path.endswith(".jsonl"):
            list_data_dict = [json.loads(line) for line in open(data_path, "r")]
        elif data_path.endswith(".pth"):
            list_data_dict = torch.load(data_path, map_location="cpu")
            # random.shuffle(list_data_dict)
        else:
            if self._get_dataset_tag(data_path) == "openx":
                list_data_dict = OpenXDataItem()(data_path, image_folder, processor=processor, conversation_lib=conversation_lib, local_run=self.local_run)
            elif self._get_dataset_tag(data_path) == "pixelprose":
                # Load the dataset
                list_data_dict = load_dataset(
                    data_path, 
                    cache_dir=image_folder
                )
            else:
                data_folder = os.path.dirname(data_path)
                # get file name from data_path
                data_files = data_path.split('/')[-1].split('+')
                list_data_dict = []
                for file in data_files:
                    json_path = os.path.join(data_folder, file + '.json')      
                    list_data_dict.extend(json.load(open(json_path, "r")))                
        return list_data_dict
    
    def __call__(self, data_path, processor=None, conversation_lib=None, is_eval=False):
        assert data_path is not None, "Data path is not provided"
        if data_path.endswith(".yaml"):
            data_dict = yaml.load(open(data_path, "r"), Loader=yaml.FullLoader)    
            data_path_key = 'DATA_PATH' if not is_eval else 'DATA_PATH_VAL'
            image_folder_key = 'IMAGE_FOLDER' if not is_eval else 'IMAGE_FOLDER_VAL'
            assert len(data_dict[data_path_key]) == len(data_dict[image_folder_key]), "Data path and image folder mismatch"
            items = {}
            dataset_names = []
            dataset_folders = []
            for i, (data_path, image_folder) in enumerate(zip(data_dict[data_path_key], data_dict[image_folder_key])):
                items_temp = self._get_items(data_path, image_folder, processor, conversation_lib)                
                dataset_tag = self._get_dataset_tag(data_path)                
                if dataset_tag != "openx":
                    # if self.training_size > 0:
                    #     items_temp = items_temp[:self.training_size]             
                    if dataset_tag in ['sthv2', "ego4d", "exoego4d"]: 
                        for item in items_temp:
                            item['image_folder'] = image_folder
                            item['dataset_tag'] = dataset_tag
                            item['gpt_response'] = ''
                            item['global_instructions'] = item['annotations']
                    elif dataset_tag in ["openx_magma"]:
                        items_dict_temp = []
                        for item in items_temp:
                            items_dict_temp.append(
                                {
                                    'image': item.replace('traces', 'images').replace('.pth', '.jpg'),
                                    'trace': item,
                                    'image_folder': image_folder,
                                    'dataset_tag': dataset_tag
                                }
                            ) 
                        items_temp = items_dict_temp         
                    else:
                        # add image_foler to each item
                        for item in items_temp:
                            item['image_folder'] = image_folder
                        # add dataset tag to each item
                        for item in items_temp:
                            item['dataset_tag'] = dataset_tag
                if dataset_tag in items:
                    items[dataset_tag].extend(items_temp)
                else:
                    items[dataset_tag] = items_temp
                    dataset_names.append(dataset_tag)
                    dataset_folders.append(image_folder)
        else:
            items = self._get_items(data_path)
            dataset_names = None
            dataset_folders = None  
        return items, dataset_names, dataset_folders