{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "6401fb4e-c559-49e5-bc88-874a74dd54c4", "metadata": {}, "outputs": [], "source": [ "import os\n", "root_dir = os.path.join(os.getcwd(), \"..\")\n", "import sys\n", "sys.path.append(root_dir)\n", "from vtimellm.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN\n", "from vtimellm.conversation import conv_templates, SeparatorStyle\n", "from vtimellm.model.builder import load_pretrained_model, load_lora\n", "from vtimellm.utils import disable_torch_init\n", "from vtimellm.mm_utils import tokenizer_image_token, KeywordsStoppingCriteria, VideoExtractor\n", "from PIL import Image\n", "import requests\n", "from io import BytesIO\n", "from transformers import TextStreamer\n", "from easydict import EasyDict as edict\n", "try:\n", " from torchvision.transforms import InterpolationMode\n", " BICUBIC = InterpolationMode.BICUBIC\n", "except ImportError:\n", " from PIL import Image\n", " BICUBIC = Image.BICUBIC\n", "from torchvision.transforms import Compose, Resize, CenterCrop, Normalize\n", "import numpy as np\n", "import clip\n", "import torch" ] }, { "cell_type": "code", "execution_count": 2, "id": "a05c755d-ae5c-4e93-ae6f-4e4a6e795b14", "metadata": {}, "outputs": [], "source": [ "args = edict()\n", "args.model_base = \"/path/to/vicuna-7b-v1.5\"\n", "args.clip_path = os.path.join(root_dir, \"checkpoints/clip/ViT-L-14.pt\")\n", "args.pretrain_mm_mlp_adapter = os.path.join(root_dir, \"checkpoints/vtimellm-vicuna-v1-5-7b-stage1/mm_projector.bin\")\n", "args.stage2 = os.path.join(root_dir, \"checkpoints/vtimellm-vicuna-v1-5-7b-stage2\")\n", "args.stage3 = os.path.join(root_dir, \"checkpoints/vtimellm-vicuna-v1-5-7b-stage3\")\n", "args.video_path = os.path.join(root_dir, \"images/demo.mp4\")\n", "args.temperature = 0.05" ] }, { "cell_type": "code", "execution_count": 3, "id": "d3815e10-d5a5-4fdf-a98f-ae9050f23b97", "metadata": {}, "outputs": [], "source": [ "def inference(model, tokenizer, context_len, image, args):\n", " conv = conv_templates['v1'].copy()\n", " roles = conv.roles\n", " first = True\n", " while True:\n", " try:\n", " inp = input(f\"{roles[0]}: \")\n", " except EOFError:\n", " inp = \"\"\n", " if not inp:\n", " print(\"exit...\")\n", " break\n", "\n", " print(f\"{roles[1]}: \", end=\"\")\n", "\n", " if first:\n", " # first message\n", " inp = DEFAULT_IMAGE_TOKEN + '\\n' + inp\n", " conv.append_message(conv.roles[0], inp)\n", " first = False\n", " else:\n", " # later messages\n", " conv.append_message(conv.roles[0], inp)\n", " conv.append_message(conv.roles[1], None)\n", " prompt = conv.get_prompt()\n", "\n", " input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()\n", " stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2 # plain:sep(###) v1:sep2(None)\n", " keywords = [stop_str]\n", " stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)\n", " streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)\n", "\n", " with torch.inference_mode():\n", " output_ids = model.generate(\n", " input_ids,\n", " images=image[None,].cuda(),\n", " do_sample=True,\n", " temperature=args.temperature,\n", " max_new_tokens=1024,\n", " streamer=streamer,\n", " use_cache=True,\n", " stopping_criteria=[stopping_criteria]\n", " )\n", "\n", " outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()\n", " conv.messages[-1][-1] = outputs" ] }, { "cell_type": "code", "execution_count": 4, "id": 
"a29e94ff-620e-4c86-adee-47206b67b606", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "You are using a model of type llama to instantiate a model of type VTimeLLM. This is not supported for all configurations of models and can yield errors.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Loading VTimeLLM from base model...\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "8c13771904c8424abf87e3e2507e4bd8", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Loading checkpoint shards: 0%| | 0/2 [00:00\n", "images = transform(images / 255.0)\n", "images = images.to(torch.float16)\n", "with torch.no_grad():\n", " features = clip_model.encode_image(images.to('cuda'))" ] }, { "cell_type": "code", "execution_count": 8, "id": "010a986d-6121-43f2-9971-619472c0732c", "metadata": {}, "outputs": [ { "name": "stdin", "output_type": "stream", "text": [ "USER: Explain why this video is funny.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "ASSISTANT: The video is funny because the bear is dancing to the music and moving its arms and legs in a funny way. The bear's movements are exaggerated and comical, making it difficult for the person to keep up with the beat. The bear's facial expressions and body language add to the humor of the video.\n" ] }, { "name": "stdin", "output_type": "stream", "text": [ "USER: Is it a real bear?\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "ASSISTANT: No, it is not a real bear. It is a costume worn by a person.\n" ] }, { "name": "stdin", "output_type": "stream", "text": [ "USER: \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "exit...\n" ] } ], "source": [ "inference(model, tokenizer, context_len, features, args)" ] }, { "cell_type": "code", "execution_count": null, "id": "ef73e662-cab9-46fa-9df8-b45e2b1d9f5b", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.13" } }, "nbformat": 4, "nbformat_minor": 5 }