evaluate.py
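"""Evaluate text-to-image generations on the PartiPrompts (P2) dataset with CLIP score.

Usage sketch (flag names come from ``main()`` below; the model argument is whatever
``torchmetrics.multimodal.CLIPScore`` accepts, e.g. a Hugging Face CLIP checkpoint
name or a local model directory -- adjust for your setup):

    python evaluate.py -m openai/clip-vit-base-patch16 -d /path/to/p2_images -o results.json

Expected layout of the data directory, as read by ``P2CLIPScore.process()``:

    <data_dir>/
        <category_dir>/              # category name with spaces removed and "&" -> "_"
            <prompt_dir>/
                prompt_info.json     # must contain "category" and "prompt_text"
                *.png                # one or more generated images for this prompt
"""
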
from collections import defaultdict
import json
import os
import os.path as osp

import cv2
import numpy as np
from prettytable import PrettyTable
import torch
import tqdm
from torchmetrics.multimodal import CLIPScore
from torchmetrics.functional.multimodal.clip_score import _clip_score_update


class P2CLIPScore(CLIPScore):
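    """CLIP-score evaluator for PartiPrompts (P2) generations, aggregated per prompt category."""
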
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
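        # Per-category accumulators: best CLIP score per prompt, prompt counts, image counts.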
        self.category2scores = defaultdict(list)
        self.category2nprompts = defaultdict(int)
        self.category2nimages = defaultdict(int)

    def process(self, p2_images_dir):
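        """Walk ``p2_images_dir`` (category dir -> prompt dir) and score each prompt's images."""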
        prompt_dirs = []
        for cat_dir_name in os.listdir(p2_images_dir): 
            for prompt_dir_name in os.listdir(osp.join(p2_images_dir, cat_dir_name)):
                prompt_dir = osp.join(p2_images_dir, cat_dir_name, prompt_dir_name)
                prompt_dirs.append(prompt_dir)

        print("Processing...")
        for prompt_dir in tqdm.tqdm(prompt_dirs):
            prompt_json = osp.join(prompt_dir, "prompt_info.json")
            with open(prompt_json, "r") as f:
                prompt_info = json.load(f)
            category = prompt_info["category"]
            cat_dir_name = osp.basename(osp.dirname(prompt_dir))
            assert cat_dir_name == category.replace(" ", "").replace("&", "_")

            imgs = []
            for file_name in os.listdir(prompt_dir):
                if not file_name.endswith(".png"):
                    continue
                image_path = osp.join(prompt_dir, file_name)
                # cv2 loads images as BGR; convert to RGB before scoring with CLIP.
                img = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)[None, ...]
                imgs.append(img)
            assert len(imgs) >= 1, f"no .png images found in {prompt_dir}"

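            # Score every image of this prompt against its text with CLIP.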
            scores, _ = _clip_score_update(
                [prompt_info["prompt_text"]] * len(imgs),
                torch.from_numpy(np.concatenate(imgs, 0).transpose(0, 3, 1, 2)), 
                self.model, 
                self.processor
            )

            # Keep only the best (max) CLIP score among this prompt's generated images.
            self.category2scores["All"].append(scores.max().item())
            self.category2scores[category].append(scores.max().item())
            self.category2nprompts["All"] += 1
            self.category2nprompts[category] += 1
            self.category2nimages["All"] += len(imgs)
            self.category2nimages[category] += len(imgs)

    def compute(self, output_json=None):
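        """Print a per-category summary table and optionally write it to ``output_json``."""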
        pt = PrettyTable()
        pt.title = "Evaluation Results of PartiPrompts Dataset"
        pt.field_names = ["Category", "Num Prompts", "Num Images", "Mean CLIP Score"]
        for category, scores in self.category2scores.items():
            num_prompts = self.category2nprompts[category]
            num_images = self.category2nimages[category]
            mean_score = sum(scores) / len(scores)
            pt.add_row([category, num_prompts, num_images, round(mean_score, 4)])
        print(pt)

        if output_json is not None:
            with open(output_json, "w") as f:
                f.write(pt.get_json_string())


def main():
    import argparse
    parser = argparse.ArgumentParser(
        "Evaluate text2image results of PartiPrompts dataset")
    parser.add_argument("-m", "--model-dir", 
                        type=str, 
                        required=True, 
                        help="The path to the model directory.")
    parser.add_argument("-d", "--data-dir", 
                        type=str, 
                        required=True, 
                        help="The path to the evaluation data directory.")
    parser.add_argument("-o", "--output-json", 
                        type=str, 
                        default=None, 
                        help="Output json file path.")
    args = parser.parse_args()

    p2_clip_score = P2CLIPScore(args.model_dir)
    p2_clip_score.process(args.data_dir)
    p2_clip_score.compute(args.output_json)


if __name__ == "__main__":
    main()