{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/export/share/anasawadalla/miniconda3/envs/xgenmm-release-clone/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n" ] } ], "source": [ "import os\n", "\n", "from omegaconf import OmegaConf\n", "from functools import partial\n", "from PIL import Image\n", "import torch\n", "\n", "from open_flamingo import create_model_and_transforms \n", "from open_flamingo.train.any_res_data_utils import process_images" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Inference code" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n", "Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00, 1.60it/s]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "xgenmm_v1 model initialized with 3,931,031,619 trainable parameters\n", "==========Trainable Parameters\n", "Vision encoder: 0 trainable parameters\n", "Vision tokenizer: 109,901,568 trainable parameters\n", "Language model: 3,821,130,051 trainable parameters\n", "==========Total Parameters\n", "Vision encoder: 428,225,600 parameters\n", "Vision tokenizer: 109,901,568 parameters\n", "Language model: 3,821,130,051 parameters\n", "==========\n" ] } ], "source": [ "# Set model configs.\n", "model_ckpt=\"path/to/your/local/checkpoint.pt\"\n", "cfg = dict(\n", " model_family = 'xgenmm_v1',\n", " lm_path = 'microsoft/Phi-3-mini-4k-instruct',\n", " vision_encoder_path = 'google/siglip-so400m-patch14-384',\n", " vision_encoder_pretrained = 'google',\n", " num_vision_tokens = 128,\n", " image_aspect_ratio = 'anyres',\n", " anyres_patch_sampling = True,\n", " anyres_grids = [(1,2),(2,1),(2,2),(3,1),(1,3)],\n", " ckpt_pth = model_ckpt,\n", ")\n", "cfg = OmegaConf.create(cfg)\n", "\n", "additional_kwargs = {\n", " \"num_vision_tokens\": cfg.num_vision_tokens,\n", " \"image_aspect_ratio\": cfg.image_aspect_ratio,\n", " \"anyres_patch_sampling\": cfg.anyres_patch_sampling,\n", "}\n", "\n", "# Initialize the model.\n", "model, image_processor, tokenizer = create_model_and_transforms(\n", " clip_vision_encoder_path=cfg.vision_encoder_path,\n", " clip_vision_encoder_pretrained=cfg.vision_encoder_pretrained,\n", " lang_model_path=cfg.lm_path,\n", " tokenizer_path=cfg.lm_path,\n", " model_family=cfg.model_family,\n", " **additional_kwargs)\n", "\n", "ckpt = torch.load(cfg.ckpt_pth)[\"model_state_dict\"]\n", "model.load_state_dict(ckpt, strict=True)\n", "torch.cuda.empty_cache()\n", "model = model.eval().cuda()\n", "\n", "base_img_size = model.base_img_size\n", "anyres_grids = []\n", "for (m,n) in cfg.anyres_grids:\n", " anyres_grids.append([base_img_size*m, base_img_size*n])\n", "model.anyres_grids = anyres_grids" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "# Preprocessing utils.\n", "\n", "image_proc = partial(process_images, image_processor=image_processor, model_cfg=cfg)\n", "\n", "def apply_prompt_template(prompt, cfg):\n", " if 'Phi-3' in cfg.lm_path:\n", " s = (\n", " '<|system|>\\nA chat between a curious user and an artificial intelligence assistant. 
'\n", " \"The assistant gives helpful, detailed, and polite answers to the user's questions.<|end|>\\n\"\n", " f'<|user|>\\n{prompt}<|end|>\\n<|assistant|>\\n'\n", " )\n", " else:\n", " raise NotImplementedError\n", " return s" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "# Prep image input.\n", "image_path_1 = 'example_images/image-1.jpeg'\n", "image_path_2 = 'example_images/image-2.jpeg'\n", "\n", "image_1 = Image.open(image_path_1).convert('RGB')\n", "image_2 = Image.open(image_path_2).convert('RGB')\n", "images = [image_1, image_2]\n", "image_size = [image_1.size, image_2.size]\n", "image_size = [image_size]\n", "vision_x = [image_proc([img]) for img in images]\n", "vision_x = [vision_x]" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "# Prep language input.\n", "prompt = \"Look at this image and this image . What is in the second image?\"\n", "prompt = apply_prompt_template(prompt, cfg)\n", "lang_x = tokenizer([prompt], return_tensors=\"pt\")" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/export/share/anasawadalla/miniconda3/envs/xgenmm-release-clone/lib/python3.10/site-packages/transformers/generation/configuration_utils.py:515: UserWarning: `do_sample` is set to `False`. However, `temperature` is set to `0` -- this flag is only used in sample-based generation modes. You should set `do_sample=True` or unset `temperature`.\n", " warnings.warn(\n", "You are not running the flash-attention implementation, expect numerical differences.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "A black and white cat. \n" ] } ], "source": [ "# Run inference.\n", "kwargs_default = dict(do_sample=False, temperature=0, max_new_tokens=1024, top_p=None, num_beams=1)\n", "\n", "generated_text = model.generate(\n", " vision_x=vision_x, \n", " lang_x=lang_x['input_ids'].to(torch.device('cuda:0')), \n", " image_size=image_size,\n", " attention_mask=lang_x['attention_mask'].to(torch.device('cuda:0')), \n", " **kwargs_default)\n", " \n", "generated_text = tokenizer.decode(generated_text[0], skip_special_tokens=True)\n", "if 'Phi-3' in cfg.lm_path:\n", " text = generated_text.split('<|end|>')[0]\n", "else:\n", " text=generated_text\n", "\n", "print(text)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.14" } }, "nbformat": 4, "nbformat_minor": 2 }