import open_clip
import torch
from PIL import Image


def clean_caption(text):
    """Strip CoCa special tokens from a decoded caption string.

    Keeps only the text before the first ``<end_of_text>`` marker and removes
    every ``<start_of_text>`` marker from what remains.

    Args:
        text: Caption as returned by ``open_clip.decode``.

    Returns:
        The caption text without the special tokens.
    """
    return text.split("<end_of_text>")[0].replace("<start_of_text>", "")


def main(argv=None):
    """Caption a single image with an open_clip CoCa model and print the result.

    Args:
        argv: Optional list of command-line arguments; ``None`` parses
            ``sys.argv`` (the default argparse behavior).
    """
    import contextlib
    from argparse import ArgumentParser

    parser = ArgumentParser(
        description="Generate a caption for an image with an open_clip CoCa model."
    )
    parser.add_argument("--model_name", type=str, default="coca_ViT-L-14")
    parser.add_argument(
        "--pretrained", type=str, default="mscoco_finetuned_laion2B-s13B-b90k"
    )
    # Required: without it Image.open(None) would fail with an opaque error.
    parser.add_argument("--image_path", type=str, required=True)
    # Defaults to CUDA (the original behavior) when available; falls back to CPU.
    parser.add_argument(
        "--device",
        type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
    )
    args = parser.parse_args(argv)

    model, _, transform = open_clip.create_model_and_transforms(
        model_name=args.model_name,
        pretrained=args.pretrained,
    )
    model = model.to(args.device)
    # Inference only: disable train-mode layers such as dropout.
    model.eval()

    # Context manager closes the underlying file handle once the pixels are converted.
    with Image.open(args.image_path) as img:
        rgb = img.convert("RGB")
    im = transform(rgb).unsqueeze(0).to(args.device)

    device_type = torch.device(args.device).type
    # Mixed precision only on CUDA (same as the original torch.cuda.amp.autocast);
    # other devices run in full precision.
    amp = (
        torch.autocast(device_type=device_type)
        if device_type == "cuda"
        else contextlib.nullcontext()
    )
    with torch.no_grad(), amp:
        generated = model.generate(im)

    print(clean_caption(open_clip.decode(generated[0])))


if __name__ == "__main__":
    main()