import torch
import clip
import os

from PIL import Image


def resolve_model_source(name):
    """Return the value to hand to ``clip.load`` for the ``--pt`` argument.

    Args:
        name: The raw ``--pt`` command-line value.

    Returns:
        ``pretrained_models/<name>`` when ``name`` contains ``".pt"``,
        otherwise ``name`` unchanged.
    """
    # Substring (not suffix) match is kept on purpose so that existing
    # invocations such as "model.pth" keep resolving to the same path.
    if ".pt" in name:
        return f"pretrained_models/{name}"
    return name


def main():
    """Score ``CLIP.png`` against three fixed text prompts and print the result.

    Loads the model selected by ``--pt``, runs one forward pass without
    gradients, and prints the softmax of the per-image logits as a NumPy
    array.
    """
    import argparse

    parser = argparse.ArgumentParser()
    # required=True: without a value args.pt would be None and the model
    # lookup below would fail with an unhelpful TypeError.
    parser.add_argument("--pt", type=str, required=True, help="模型名称")
    args = parser.parse_args()

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, preprocess = clip.load(resolve_model_source(args.pt), device=device)

    # Close the image file as soon as it has been turned into a tensor.
    with Image.open("CLIP.png") as picture:
        image = preprocess(picture).unsqueeze(0).to(device)
    text = clip.tokenize(["a diagram", "a dog", "a cat"]).to(device)

    with torch.no_grad():
        # Only the per-image logits are needed; the previous standalone
        # encode_image/encode_text calls produced values that were never used.
        logits_per_image, _ = model(image, text)
        probs = logits_per_image.softmax(dim=-1).cpu().numpy()

    print(probs)


if __name__ == "__main__":
    main()