{ "cells": [ { "cell_type": "code", "execution_count": 2, "id": "6401fb4e-c559-49e5-bc88-874a74dd54c4", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "/DATA/DATANAS2/bhuang/link/gitlab/vtimellm/docs/..\n" ] } ], "source": [ "import os\n", "os.environ['CUDA_VISIBLE_DEVICES'] = '0'\n", "root_dir = os.path.join(os.getcwd(), \"..\")\n", "print(root_dir)\n", "import sys\n", "sys.path.append(root_dir)\n", "\n", "from vtimellm.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN\n", "from vtimellm.conversation import conv_templates, SeparatorStyle\n", "from vtimellm.model.builder import load_pretrained_model, load_lora\n", "from vtimellm.train.dataset import preprocess_glm\n", "from vtimellm.utils import disable_torch_init\n", "from vtimellm.mm_utils import tokenizer_image_token, KeywordsStoppingCriteria, VideoExtractor\n", "from PIL import Image\n", "import requests\n", "from io import BytesIO\n", "from transformers import TextStreamer\n", "from easydict import EasyDict as edict\n", "try:\n", " from torchvision.transforms import InterpolationMode\n", " BICUBIC = InterpolationMode.BICUBIC\n", "except ImportError:\n", " from PIL import Image\n", " BICUBIC = Image.BICUBIC\n", "from torchvision.transforms import Compose, Resize, CenterCrop, Normalize\n", "import numpy as np\n", "import clip\n", "import torch" ] }, { "cell_type": "code", "execution_count": 3, "id": "a05c755d-ae5c-4e93-ae6f-4e4a6e795b14", "metadata": {}, "outputs": [], "source": [ "model_version = 'chatglm3-6b' # vicuna-v1-5-7b\n", "args = edict()\n", "args.model_base = \"/DATA/DATANAS2/bhuang/data/vicuna-7b-v1.5\"\n", "if model_version == 'chatglm3-6b':\n", " args.model_base = os.path.join(root_dir, 'checkpoints/' + model_version)\n", "args.clip_path = os.path.join(root_dir, \"checkpoints/clip/ViT-L-14.pt\")\n", "args.pretrain_mm_mlp_adapter = os.path.join(root_dir, f\"checkpoints/vtimellm-{model_version}-stage1/mm_projector.bin\")\n", "args.stage2 = os.path.join(root_dir, f\"checkpoints/vtimellm-{model_version}-stage2\")\n", "args.stage3 = os.path.join(root_dir, f\"checkpoints/vtimellm-{model_version}-stage3\")\n", "args.temperature = 0.05" ] }, { "cell_type": "code", "execution_count": 4, "id": "7cb39bb4", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "You are using a model of type chatglm to instantiate a model of type VTimeLLM_ChatGLM. 
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "7b42d546",
   "metadata": {},
   "outputs": [],
   "source": [
    "def inference(model, tokenizer, context_len, image, args):\n",
    "    # Interactive multi-turn chat about the video; an empty input (or EOF) exits.\n",
    "    source = []\n",
    "    first = True\n",
    "    while True:\n",
    "        try:\n",
    "            inp = input(\"USER: \")\n",
    "        except EOFError:\n",
    "            inp = \"\"\n",
    "        if not inp:\n",
    "            print(\"exit...\")\n",
    "            break\n",
    "\n",
    "        print(\"ASSISTANT:\", end=\"\")\n",
    "\n",
    "        if first:\n",
    "            # the video placeholder token is prepended only to the first message\n",
    "            inp = DEFAULT_IMAGE_TOKEN + '\\n' + inp\n",
    "            first = False\n",
    "\n",
    "        source.append({\n",
    "            'from': 'human',\n",
    "            'value': inp\n",
    "        })\n",
    "        input_ids = preprocess_glm([source], tokenizer)['input_ids'].cuda()\n",
    "        # replace the trailing token so generation starts from an assistant turn\n",
    "        input_ids[0][-1] = tokenizer.get_command(\"<|assistant|>\")\n",
    "        streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)\n",
    "\n",
    "        with torch.inference_mode():\n",
    "            output_ids = model.generate(\n",
    "                input_ids,\n",
    "                images=image[None,].cuda(),\n",
    "                do_sample=True,\n",
    "                temperature=args.temperature,\n",
    "                max_new_tokens=1024,\n",
    "                streamer=streamer,\n",
    "                use_cache=True,\n",
    "                eos_token_id=[tokenizer.get_command(\"<|user|>\"), tokenizer.eos_token_id],\n",
    "            )\n",
    "\n",
    "        # strip the prompt and the trailing stop token, keep the reply for the next turn\n",
    "        outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:-1]).strip()\n",
    "        source.append({\n",
    "            'from': 'gpt',\n",
    "            'value': outputs\n",
    "        })"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "010a986d-6121-43f2-9971-619472c0732c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ASSISTANT:In the video, a man is in a dark room holding a box full of things. He opens the box, which is filled with various items. Then the man climbs a tall building and jumps from a window into the water.<|user|>\n",
      "exit...\n"
     ]
    }
   ],
   "source": [
    "inference(model, tokenizer, context_len, features, args)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ef73e662-cab9-46fa-9df8-b45e2b1d9f5b",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}