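# eval_open_ended.py
"""Open-ended LVIS evaluation.

Maps free-form predicted category names onto the LVIS vocabulary with CLIP text
embeddings (split across multiple GPUs), then scores the remapped detections with
standard or fixed-AP LVIS evaluation.
"""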
import argparse
import torch
import os
import random
import json
from tqdm import tqdm
from ultralytics.nn.text_model import MobileCLIP
import numpy as np
from transformers import logging
import torch.multiprocessing as mp
from ultralytics.utils import yaml_load

import clip
import copy
from lvis import LVIS, LVISEval
from pathlib import Path

logging.set_verbosity_error()

class LVISOpenEndedEval:
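    """Match open-ended category predictions to the LVIS vocabulary via CLIP text similarity.

    Assumes `pred_path` holds LVIS-style detection results whose `category_id`
    is a 1-based index into the open vocabulary listed in `name_path`.
    """
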
    def __init__(self, anno_path, pred_path, name_path):
        self.anno_path = anno_path
        self.pred_path = pred_path
        
        categories = yaml_load("ultralytics/cfg/datasets/lvis.yaml")["names"].values()
        self.categories = [c.split("/")[0] for c in categories]

        with open(name_path, 'r') as f:
            self.names = [x.strip() for x in f.readlines()]

        self.topk_for_mapping = 1

    def setup_clip_matching(self):
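        """Load MobileCLIP and precompute text embeddings for every LVIS class name."""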
        print("Loading CLIP model for matching")
        self.clip_model = MobileCLIP("blt", device="cuda")
        
        tokens = clip.tokenize(self.categories).cuda()
        text_features = self.clip_model.encode_text(tokens)
        
        self.lvis_embed = text_features.transpose(-1, -2)

    @torch.inference_mode()
    def match(self, total, rank, verbose=False):
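        """Remap this rank's shard of predictions onto LVIS category IDs.

        The prediction file is split across `total` workers; this worker encodes
        its shard's predicted names in batches and keeps the top-k most similar
        LVIS categories for each prediction (top-1 by default).
        """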
        self.setup_clip_matching()
        
        predictions = []
        data_all = json.load(open(self.pred_path))
        data_all = np.array_split(data_all, total)[rank].tolist()

        batch = 512
        chunk_num = (len(data_all) + batch - 1) // batch
        data_all = np.array_split(data_all, chunk_num)
        
        for batch_data in tqdm(data_all):
            text = []
            for data in batch_data:
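                # `category_id` in the open-ended predictions is a 1-based index into self.names.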
                description = self.names[data['category_id'] - 1]
                text.append(description)

            tokens = clip.tokenize(text).cuda()
            text_features = self.clip_model.encode_text(tokens)
            similarity = (100.0 * text_features @ self.lvis_embed).softmax(dim=-1)
            sim_values, sim_indices = similarity.topk(self.topk_for_mapping)
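            # Shift to 1-based indices to match LVIS category numbering.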
            category_ids = sim_indices.cpu().numpy() + 1
            scores = sim_values.cpu().numpy()
            for i, data in enumerate(batch_data):
                for cat_id, score in zip(category_ids[i], scores[i]):
                    d = copy.deepcopy(data)
                    d['category_id'] = int(cat_id)
                    predictions.append(d)
                    if verbose:
                        print(f"[pred | gt] ({score:.2f}): {text[i]} | {self.categories[cat_id-1]}")
            
        return predictions

def setup_seed(seed):
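    """Seed all RNGs and force deterministic cuDNN kernels for reproducibility."""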
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def worker(args):
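    """Per-GPU worker: pin this process to one device and match its shard of predictions."""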
    anno_path, pred_path, name_path, total, rank, device, seed, verbose = args
    setup_seed(seed)
    torch.cuda.set_device(int(device))
    
    lvis_eval = LVISOpenEndedEval(anno_path, pred_path, name_path)
    match_result = lvis_eval.match(total, rank, verbose=(rank == 0) and verbose)
    return match_result

def main(args):
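    """Run matching across all requested GPUs, merge the shards, save the remapped
    predictions as <pred>.mt.json, and evaluate with standard or fixed LVIS AP."""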
    pred_file = Path(args.pred).with_suffix('.mt.json')
    devices = args.devices.split(",")
    ranks = list(range(len(devices)))
    total = len(devices)
    mp.set_start_method('spawn')
    with mp.Pool(total) as pool:
        results = pool.map(worker,
                        [(args.json, args.pred, args.names,
                            total, rank, device, args.seed, args.verbose)
                            for rank, device in zip(ranks, devices)])

    predictions = []
    for result in results:
        predictions.extend(result)
    print(f"Total predictions: {len(predictions)}")
    predictions = sorted(predictions, key=lambda x: (x['image_id'], -x['score']))
    with open(pred_file, 'w') as f:
        json.dump(predictions, f)
    print(f"Saved predictions to {pred_file}")

    if args.fixed:
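        # Fixed-AP evaluation is delegated to a separate script under tools/.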
        print("Evaluating LVIS fixed mAP...")
        os.system(f"python ./tools/eval_fixed_ap.py {args.json} {pred_file}")
    else:
        print("Evaluating LVIS standard mAP...")
        anno = LVIS(args.json)
        pred = anno._load_json(pred_file)
        lvis_eval = LVISEval(anno, pred, "bbox")
        lvis_eval.evaluate()
        lvis_eval.accumulate()
        lvis_eval.summarize()
        lvis_eval.print_results()

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='LVIS open-ended evaluation', formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument('--json', help='path to lvis minival json file', default='./data/lvis_v1_minival.json')
    parser.add_argument('--pred', help='path to pred json file', default='./data/predictions.json')
    parser.add_argument('--names', help='path to vocab names', default='./tools/ram_tag_list.txt')
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--devices", type=str, default="0,1,2,3,4,5,6,7")
    parser.add_argument('--verbose', help='verbose', action='store_true')
    parser.add_argument('--fixed', help='evaluate by fixed ap', action='store_true')
    args = parser.parse_args()

    main(args)
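
# Example invocation (paths and device list are illustrative; adjust to your setup):
#   python eval_open_ended.py \
#       --json ./data/lvis_v1_minival.json \
#       --pred ./data/predictions.json \
#       --names ./tools/ram_tag_list.txt \
#       --devices 0,1 --fixed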