{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# BLIP-2 image-text matching\n",
    "\n",
    "Scores how well a caption matches an image using the LAVIS `blip2_image_text_matching` model and both of its matching heads: `itm` (reported below as a match probability) and `itc` (reported below as a cosine similarity between the image and text features)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from PIL import Image\n",
    "\n",
    "from lavis.models import load_model_and_preprocess\n",
    "from lavis.processors import load_processor"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Load an example image and text"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "raw_image = Image.open(\"../docs/_static/merlion.png\").convert(\"RGB\")\n",
    "# The resized copy is only for display; raw_image itself is what gets preprocessed later.\n",
    "display(raw_image.resize((596, 437)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# setup device to use\n",
    "device = torch.device(\"cuda\") if torch.cuda.is_available() else \"cpu\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "caption = \"merlion in Singapore\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Load model and preprocessors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model, vis_processors, text_processors = load_model_and_preprocess(\"blip2_image_text_matching\", \"pretrain\", device=device, is_eval=True)\n",
    "# Alternative: use the line below instead to load the \"coco\" model type.\n",
    "# model, vis_processors, text_processors = load_model_and_preprocess(\"blip2_image_text_matching\", \"coco\", device=device, is_eval=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Preprocess image and text inputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# unsqueeze(0) adds a leading dimension of size 1 (presumably the batch axis) before the move to `device`.\n",
    "img = vis_processors[\"eval\"](raw_image).unsqueeze(0).to(device)\n",
    "txt = text_processors[\"eval\"](caption)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Compute image-text matching (ITM) score"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inference only: disable gradient tracking so the forward pass does not record autograd state.\n",
    "with torch.no_grad():\n",
    "    itm_output = model({\"image\": img, \"text_input\": txt}, match_head=\"itm\")\n",
    "# Softmax over dim=1; column 1 is the value reported as the match probability.\n",
    "itm_scores = torch.nn.functional.softmax(itm_output, dim=1)\n",
    "print(f'The image and text are matched with a probability of {itm_scores[:, 1].item():.3%}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Compute image-text feature similarity (`itc` head)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same inputs as above, scored with the `itc` head; gradients are not needed for inference.\n",
    "with torch.no_grad():\n",
    "    itc_score = model({\"image\": img, \"text_input\": txt}, match_head='itc')\n",
    "print('The image feature and text feature has a cosine similarity of %.4f'%itc_score)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  },
  "vscode": {
   "interpreter": {
    "hash": "d4d1e4263499bec80672ea0156c357c1ee493ec2b1c70f0acce89fc37c4a6abe"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}