"""Evaluate text-to-image results on the PartiPrompts dataset with CLIP score.

Expected directory layout (produced by the generation step):

    <data-dir>/<category>/<prompt>/prompt_info.json
    <data-dir>/<category>/<prompt>/*.png
"""
from collections import defaultdict
import json
import os
import os.path as osp

import cv2
import numpy as np
from prettytable import PrettyTable
import torch
import tqdm
from torchmetrics.multimodal import CLIPScore
from torchmetrics.functional.multimodal.clip_score import _clip_score_update


class P2CLIPScore(CLIPScore):
    """CLIPScore wrapper that aggregates per-category results for PartiPrompts."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.category2scores = defaultdict(list)
        self.category2nprompts = defaultdict(int)
        self.category2nimages = defaultdict(int)

    def process(self, p2_images_dir):
        # Collect every <category>/<prompt> directory under the data root.
        prompt_dirs = []
        for cat_dir_name in os.listdir(p2_images_dir):
            for prompt_dir_name in os.listdir(osp.join(p2_images_dir, cat_dir_name)):
                prompt_dir = osp.join(p2_images_dir, cat_dir_name, prompt_dir_name)
                prompt_dirs.append(prompt_dir)

        print("Processing...")
        for prompt_dir in tqdm.tqdm(prompt_dirs):
            prompt_json = osp.join(prompt_dir, "prompt_info.json")
            with open(prompt_json, "r") as f:
                prompt_info = json.load(f)
            category = prompt_info["category"]

            # The category directory name is the category with spaces removed
            # and "&" replaced by "_".
            cat_dir_name = osp.basename(osp.dirname(osp.normpath(prompt_dir)))
            assert cat_dir_name == category.replace(" ", "").replace("&", "_")

            # Load all generated images for this prompt; cv2 reads BGR, so
            # convert to RGB before handing the pixels to CLIP.
            imgs = []
            for file_name in os.listdir(prompt_dir):
                if not file_name.endswith(".png"):
                    continue
                image_path = osp.join(prompt_dir, file_name)
                img = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)[None, ...]
                imgs.append(img)
            assert len(imgs) >= 1

            # Score every image against the prompt text in one batch (NCHW uint8).
            with torch.inference_mode():
                scores, _ = _clip_score_update(
                    [prompt_info["prompt_text"]] * len(imgs),
                    torch.from_numpy(np.concatenate(imgs, 0).transpose(0, 3, 1, 2)),
                    self.model, self.processor
                )

            # Keep only the best-scoring image per prompt (rather than
            # averaging the scores of all images for that prompt).
            self.category2scores["All"].append(scores.max().item())
            self.category2scores[category].append(scores.max().item())
            self.category2nprompts["All"] += 1
            self.category2nprompts[category] += 1
            self.category2nimages["All"] += len(imgs)
            self.category2nimages[category] += len(imgs)

    def compute(self, output_json=None):
        # Print a per-category summary table and optionally dump it as JSON.
        pt = PrettyTable()
        pt.title = "Evaluation Results of PartiPrompts Dataset"
        pt.field_names = ["Category", "Num Prompts", "Num Images", "Mean CLIP Score"]
        for category, scores in self.category2scores.items():
            num_prompts = self.category2nprompts[category]
            num_images = self.category2nimages[category]
            mean_score = sum(scores) / len(scores)
            pt.add_row([category, num_prompts, num_images, round(mean_score, 4)])
        print(pt)
        if output_json is not None:
            with open(output_json, "w") as f:
                f.write(pt.get_json_string())


def main():
    import argparse
    parser = argparse.ArgumentParser(
        "Evaluate text2image results of PartiPrompts dataset")
    parser.add_argument("-m", "--model-dir", type=str, required=True,
                        help="The path to the CLIP model directory.")
    parser.add_argument("-d", "--data-dir", type=str, required=True,
                        help="The path to the evaluation data directory.")
    parser.add_argument("-o", "--output-json", type=str, default=None,
                        help="Output json file path.")
    args = parser.parse_args()

    p2_clip_score = P2CLIPScore(args.model_dir)
    p2_clip_score.process(args.data_dir)
    p2_clip_score.compute(args.output_json)


if __name__ == "__main__":
    main()
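# Example invocation (a sketch; the script name, paths, and model identifier
# below are illustrative and not defined by this repository). Any CLIP
# checkpoint accepted by torchmetrics' CLIPScore can be passed to --model-dir,
# e.g. the Hugging Face id "openai/clip-vit-base-patch16":
#
#   python eval_parti_prompts_clip_score.py \
#       --model-dir openai/clip-vit-base-patch16 \
#       --data-dir /path/to/parti_prompts_images \
#       --output-json clip_scores.json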