{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### BLIP-2 feature extraction\n",
    "\n",
    "Load the LAVIS `blip2_feature_extractor` model, extract multimodal and unimodal (image / text) features for one image-caption pair, and score the pair using the low-dimensional projected features."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from PIL import Image\n",
    "\n",
    "from lavis.models import load_model_and_preprocess"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Load an example image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "raw_image = Image.open(\"../docs/_static/merlion.png\").convert(\"RGB\")\n",
    "caption = \"a large fountain spewing water into the air\"\n",
    "\n",
    "display(raw_image.resize((596, 437)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Load the model and preprocess the inputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# setup device to use -- always a torch.device (not a bare \"cpu\" string) so later code sees one type\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model, vis_processors, txt_processors = load_model_and_preprocess(name=\"blip2_feature_extractor\", model_type=\"pretrain\", is_eval=True, device=device)\n",
    "image = vis_processors[\"eval\"](raw_image).unsqueeze(0).to(device)\n",
    "text_input = txt_processors[\"eval\"](caption)\n",
    "sample = {\"image\": image, \"text_input\": [text_input]}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Multimodal features"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "features_multimodal = model.extract_features(sample)\n",
    "print(features_multimodal.multimodal_embeds.shape)\n",
    "# torch.Size([1, 32, 768]), 32 is the number of queries"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Unimodal features"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "features_image = model.extract_features(sample, mode=\"image\")\n",
    "features_text = model.extract_features(sample, mode=\"text\")\n",
    "print(features_image.image_embeds.shape)\n",
    "# torch.Size([1, 32, 768]), one embedding per query\n",
    "print(features_text.text_embeds.shape)\n",
    "# torch.Size([1, 12, 768]), 12 is the length of the tokenized caption (changes with the caption)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Normalized low-dimensional unimodal features\n",
    "\n",
    "The image-text score below multiplies each of the 32 projected query embeddings with the projected embedding of the first text token, then keeps the maximum over the queries."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# low-dimensional projected features\n",
    "print(features_image.image_embeds_proj.shape)\n",
    "# torch.Size([1, 32, 256])\n",
    "print(features_text.text_embeds_proj.shape)\n",
    "# torch.Size([1, 12, 256])\n",
    "# (1, 32, 256) @ (256, 1) -> (1, 32, 1): one score per query; .max() with no dim reduces to a 0-d tensor\n",
    "similarity = (features_image.image_embeds_proj @ features_text.text_embeds_proj[:,0,:].t()).max()\n",
    "print(similarity.item())\n",
    "# ~0.3642 (a plain float, printed the same on CPU and GPU)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  },
  "vscode": {
   "interpreter": {
    "hash": "d4d1e4263499bec80672ea0156c357c1ee493ec2b1c70f0acce89fc37c4a6abe"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}