"""Compute the CLIP score for a set of generated images: the average cosine
similarity between OpenCLIP ViT-H-14 image embeddings and the embeddings of
the prompts that produced them."""

import argparse
import os

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from PIL import Image
from torch.utils.data import Dataset
from tqdm import tqdm

import open_clip


class TextImagePairDataset_all(Dataset):
    """Pairs prompts from a TSV file (column 'Prompt') with images named
    00000.png, 00001.png, ... in `image_dir`."""

    def __init__(self, text_file, image_dir, tokenizer, transform=None):
        self.image_dir = image_dir
        self.text_file = text_file
        self.tokenizer = tokenizer
        self.transform = transform
        df = pd.read_csv(text_file, sep='\t')
        # Images are expected to be zero-padded PNGs ordered by prompt index.
        self.image_paths = [os.path.join(image_dir, f"{i:05}.png") for i in range(len(df))]
        self.prompts = df['Prompt']
        assert len(self.image_paths) == len(self.prompts), \
            "The number of images and texts must be the same."

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        image_path = self.image_paths[idx]
        text = self.prompts[idx]
        image = Image.open(image_path).convert('RGB')
        image = self.transform(image)
        tokens = self.tokenizer(text)  # shape [1, 77]
        return tokens, image


class TextImagePairDataset(Dataset):
    """Alternative loader: one prompt per line in a plain-text file, paired
    with the images found in `image_dir` (sorted by filename)."""

    def __init__(self, text_file, image_dir, tokenizer, transform=None):
        self.image_dir = image_dir
        self.text_file = text_file
        self.tokenizer = tokenizer
        self.transform = transform
        # Sort so that the i-th image is matched with the i-th prompt line.
        self.image_paths = [os.path.join(image_dir, f) for f in sorted(os.listdir(image_dir))
                            if f.endswith(('.png', '.jpg', '.jpeg', '.tiff', '.bmp', '.gif'))]
        self.texts = []
        with open(text_file, 'r') as f:
            for line in f:
                self.texts.append(line.strip())
        assert len(self.image_paths) == len(self.texts), \
            "The number of images and texts must be the same."

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        image_path = self.image_paths[idx]
        text = self.texts[idx]
        image = Image.open(image_path).convert('RGB')
        image = self.transform(image)
        tokens = self.tokenizer(text)  # shape [1, 77]
        return tokens, image


def calculate_clip_score(texts_file, images_dir, batch_size, device, num_workers, output):
    model_clip, _, preprocess_clip = open_clip.create_model_and_transforms(
        'ViT-H-14', device=device, pretrained='laion2b_s32b_b79k')
    model_clip.eval()
    tokenizer = open_clip.get_tokenizer('ViT-H-14')

    dataset = TextImagePairDataset_all(texts_file, images_dir,
                                       tokenizer=tokenizer, transform=preprocess_clip)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False,
                                             drop_last=False, num_workers=num_workers)

    all_scores = []
    print(f"Evaluating {len(dataloader)} batches")
    for texts, imgs in tqdm(dataloader):
        # The tokenizer returns a [1, 77] tensor per sample, so the collated
        # batch has shape [B, 1, 77]; drop the singleton dimension.
        texts = texts.squeeze(1).to(device)
        imgs = imgs.to(device)
        with torch.no_grad():
            img_fts = model_clip.encode_image(imgs)
            text_fts = model_clip.encode_text(texts)
            # Per-sample cosine similarity between image and text embeddings.
            scores = F.cosine_similarity(img_fts, text_fts)
            all_scores.append(scores.cpu().numpy())

    results_name = f"{output}.txt"
    if os.path.exists(results_name):
        os.remove(results_name)
        print("deleted old results file")
    with open(results_name, 'a') as f:
        for scores in all_scores:
            for s in np.atleast_1d(scores):
                f.write(f"{s}\n")

    # Concatenate before averaging so a smaller final batch is weighted correctly.
    average_score = float(np.mean(np.concatenate([np.atleast_1d(s) for s in all_scores])))
    return average_score


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--texts",
        type=str,
        nargs="?",
        default="./PartiPrompts.tsv",  # path to the prompt TSV (needs a 'Prompt' column)
    )
    parser.add_argument(
        "--images",
        type=str,
        nargs="?",
        default="./DPM-sample",  # directory containing the generated images
    )
    parser.add_argument(
        "--output",
        type=str,
        nargs="?",
        default="./DMP_all_scores",  # results file prefix; scores are written to <output>.txt
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=1,
    )
    parser.add_argument(
        "--num_workers",
        type=int,
        default=1,
    )
    parser.add_argument(
        "--device",
        type=str,
        default=None,  # if not given, fall back to CUDA when available, else CPU
    )
    args = parser.parse_args()

    if args.device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    else:
        device = torch.device(args.device)

    clip_score = calculate_clip_score(args.texts, args.images, args.batch_size,
                                      device, args.num_workers, args.output)
    print('CLIP-score: ', clip_score)


if __name__ == '__main__':
    main()
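# Example invocation (a sketch: the script filename and paths below are
# illustrative and depend on your local layout; the flags simply mirror the
# argparse defaults defined above):
#
#   python clip_score.py \
#       --texts ./PartiPrompts.tsv \
#       --images ./DPM-sample \
#       --output ./DMP_all_scores \
#       --batch_size 1 \
#       --device cuda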