{ "cells": [ { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "# AWQ on LLaVA" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "In this notebook, we use the LLaVA model to demonstrate the performance of AWQ on multi-modal models. We implement AWQ real-INT4 inference kernels, which are wrapped as PyTorch modules and can be easily used by existing models. We also provide a simple example to show how to use AWQ to quantize a model and save/load the quantized model checkpoint." ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "To run this notebook, you need to install the following packages:\n", "- [AWQ](https://github.com/mit-han-lab/llm-awq)\n", "- [PyTorch](https://pytorch.org/)\n", "- [Accelerate](https://github.com/huggingface/accelerate)\n", "- [LLaVA](https://github.com/haotian-liu/LLaVA)\n", "- [Transformers](https://github.com/huggingface/transformers)" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/jilin/anaconda3/envs/llava/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n" ] } ], "source": [ "import torch\n", "import requests\n", "from PIL import Image\n", "from io import BytesIO\n", "from accelerate import init_empty_weights, load_checkpoint_and_dispatch\n", "from transformers import AutoTokenizer, CLIPVisionModel, CLIPImageProcessor, logging\n", "logging.set_verbosity_error() # too many warnings\n", "from llava.conversation import conv_templates, SeparatorStyle\n", "from llava.utils import disable_torch_init\n", "from llava.model import *\n", "from llava.model.utils import KeywordsStoppingCriteria\n", "from awq.models.auto import AutoAWQForCausalLM\n", "import os\n", "import gc\n", "\n", "from awq.quantize.auto_clip import apply_clip\n", "from awq.quantize.auto_scale import apply_scale\n", "\n", "# This demo only supports a single GPU for now\n", "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n", "\n", "# Special tokens LLaVA uses to splice image features into the text sequence\n", "DEFAULT_IMAGE_TOKEN = \"<image>\"\n", "DEFAULT_IMAGE_PATCH_TOKEN = \"<im_patch>\"\n", "DEFAULT_IM_START_TOKEN = \"<im_start>\"\n", "DEFAULT_IM_END_TOKEN = \"<im_end>\"\n", "\n", "def load_search_result_into_memory(model, search_path):\n", " # Apply the pre-computed AWQ search results (per-channel scales and clipping ranges) to the FP16 model\n", " awq_results = torch.load(search_path, map_location=\"cpu\")\n", " \n", " apply_scale(model, awq_results[\"scale\"])\n", " apply_clip(model, awq_results[\"clip\"])" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "Please get the LLaVA model from [LLaVA](https://github.com/haotian-liu/LLaVA) and run the following cell to generate a quantized model checkpoint first (note that we only quantize the language decoder, which dominates the model parameters). 
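For intuition, the `quant_config` used in the next cell (`w_bit=4`, `q_group_size=128`, `zero_point=True`) corresponds to group-wise, asymmetric INT4 quantization of the decoder's linear weights, applied after the AWQ scales and clipping ranges from the search results have been folded into the model. The snippet below is a minimal, illustrative sketch of that rounding scheme only; the function name `pseudo_quantize_groupwise` is ours, not part of the AWQ API, and the real kernels additionally keep the weights packed in INT4 and dequantize on the fly during the matmul.

```python
import torch

def pseudo_quantize_groupwise(w: torch.Tensor, w_bit: int = 4, group_size: int = 128) -> torch.Tensor:
    """Quantize-then-dequantize `w` with per-group asymmetric (zero-point) integer quantization."""
    orig_shape = w.shape
    w = w.reshape(-1, group_size)                      # [num_groups, group_size]
    q_max = 2 ** w_bit - 1                             # 15 for 4-bit
    w_min = w.amin(dim=1, keepdim=True)
    w_max = w.amax(dim=1, keepdim=True)
    scale = (w_max - w_min).clamp(min=1e-5) / q_max    # one FP scale per group
    zero = (-w_min / scale).round()                    # one integer zero-point per group
    q = (w / scale + zero).round().clamp(0, q_max)     # integer codes in [0, q_max]
    return ((q - zero) * scale).reshape(orig_shape)    # FP reconstruction of the INT4 codes

# Reconstruction error of plain 4-bit / group-128 rounding on a random weight matrix
w = torch.randn(4096, 4096)
print((pseudo_quantize_groupwise(w) - w).abs().mean())
```

Scaling salient channels and clipping outliers before this rounding step is what the AWQ search results are for: they reduce the reconstruction error measured above on the weights that matter most. 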
" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:09<00:00, 3.14s/it]\n", "real weight quantization...: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 40/40 [09:07<00:00, 13.69s/it]\n" ] } ], "source": [ "model_path = \"liuhaotian/LLaVA-13b-delta-v0\"\n", "quant_path = \"LLaVA-13B-v0-awq\" # place to dump quant weights\n", "search_path = \"../awq_cache/llava-13b-v0-w4-g128.pt\" # place where you stored search results\n", "\n", "# Load model\n", "model = AutoAWQForCausalLM.from_pretrained(model_path)\n", "tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)\n", "quant_config = {\"zero_point\": True, \"q_group_size\": 128, \"w_bit\": 4}\n", "\n", "# Load model and search results\n", "load_search_result_into_memory(model.model, search_path)\n", "\n", "# Run actual weight quantization\n", "model.quantize(quant_config=quant_config, run_search=False, run_quant=True)\n", "\n", "# Save quantized model\n", "model.save_quantized(quant_path)\n", "\n", "del model" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "Then input a image link and a question below.\n", "\n", "![](https://llava.hliu.cc/file=/nobackup/haotian/code/LLaVA/llava/serve/examples/extreme_ironing.jpg)\n", "\n", "## Q: What is unusual about this image?" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "query = \"What is unusual about this image?\"\n", "image_file = \"https://llava.hliu.cc/file=/nobackup/haotian/code/LLaVA/llava/serve/examples/extreme_ironing.jpg\" " ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "We first load a empty model and replace all the linear layers with WQLinear layers. Then we load the quantized weights from the checkpoint. " ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:15<00:00, 5.17s/it]\n", "real weight quantization...(init only): 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 40/40 [00:37<00:00, 1.08it/s]\n" ] } ], "source": [ "disable_torch_init()\n", "\n", "# Load model\n", "model = AutoAWQForCausalLM.from_quantized(quant_path, quant_filename=\"awq_model_w4_g128.pt\")\n", "tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/jilin/anaconda3/envs/llava/lib/python3.10/site-packages/transformers/generation/utils.py:1211: UserWarning: You have modified the pretrained model configuration to control generation. This is a deprecated strategy to control generation and will be removed soon, in a future version. 
Please use a generation configuration file (see https://huggingface.co/docs/transformers/main_classes/text_generation)\n", " warnings.warn(\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "The unusual aspect of this image is that a man is standing on a portable ironing board in the middle of the road, ironing clothes while traffic, including a yellow taxi, moves around him. This is not a typical scene you would expect to see in a city, as ironing is usually done in a private setting like a home, and not on the street amidst traffic. It brings attention to the unconventional and unexpected nature of the situation.\n" ] } ], "source": [ "def load_image(image_file):\n", " if image_file.startswith('http') or image_file.startswith('https'):\n", " response = requests.get(image_file)\n", " image = Image.open(BytesIO(response.content)).convert('RGB')\n", " else:\n", " image = Image.open(image_file).convert('RGB')\n", " return image\n", "\n", "image_processor = CLIPImageProcessor.from_pretrained(model.model.config.mm_vision_tower, torch_dtype=torch.float16)\n", "\n", "mm_use_im_start_end = getattr(model.config, \"mm_use_im_start_end\", False)\n", "tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)\n", "if mm_use_im_start_end:\n", " tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)\n", "\n", "vision_tower = model.model.get_model().vision_tower[0]\n", "if vision_tower.device.type == 'meta':\n", " vision_tower = CLIPVisionModel.from_pretrained(vision_tower.config._name_or_path, torch_dtype=torch.float16, low_cpu_mem_usage=True).cuda()\n", " model.model.get_model().vision_tower[0] = vision_tower\n", "else:\n", " vision_tower.to(device='cuda', dtype=torch.float16)\n", "vision_config = vision_tower.config\n", "vision_config.im_patch_token = tokenizer.convert_tokens_to_ids([DEFAULT_IMAGE_PATCH_TOKEN])[0]\n", "vision_config.use_im_start_end = mm_use_im_start_end\n", "if mm_use_im_start_end:\n", " vision_config.im_start_token, vision_config.im_end_token = tokenizer.convert_tokens_to_ids([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN])\n", "image_token_len = (vision_config.image_size // vision_config.patch_size) ** 2\n", "\n", "qs = query\n", "if mm_use_im_start_end:\n", " qs = qs + '\\n' + DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len + DEFAULT_IM_END_TOKEN\n", "else:\n", " qs = qs + '\\n' + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len\n", "\n", "conv_mode = \"multimodal\"\n", "\n", "conv = conv_templates[conv_mode].copy()\n", "conv.append_message(conv.roles[0], qs)\n", "conv.append_message(conv.roles[1], None)\n", "prompt = conv.get_prompt()\n", "inputs = tokenizer([prompt])\n", "\n", "image = load_image(image_file)\n", "image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]\n", "\n", "input_ids = torch.as_tensor(inputs.input_ids).cuda()\n", "\n", "stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2\n", "keywords = [stop_str]\n", "stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)\n", "\n", "with torch.inference_mode():\n", " output_ids = model.generate(\n", " input_ids,\n", " images=image_tensor.unsqueeze(0).half().cuda(),\n", " do_sample=True,\n", " temperature=0.2,\n", " max_new_tokens=1024,\n", " stopping_criteria=[stopping_criteria])\n", "\n", "input_token_len = input_ids.shape[1]\n", "n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()\n", "if n_diff_input_output > 0:\n", " print(f'[Warning] 
{n_diff_input_output} output_ids are not the same as the input_ids')\n", "outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]\n", "outputs = outputs.strip()\n", "if outputs.endswith(stop_str):\n", " outputs = outputs[:-len(stop_str)]\n", "outputs = outputs.strip()\n", "print(outputs)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.11" } }, "nbformat": 4, "nbformat_minor": 2 }
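The layer-swap pattern referenced in the markdown cell before the `from_quantized` step is generic: build the model skeleton without materializing the FP16 weights, swap every `nn.Linear` in the language decoder for a quantized linear module, and only then load the much smaller quantized state dict. The sketch below illustrates only that swap-and-load mechanic with a hypothetical `QuantLinearStub` (and a hypothetical `build_decoder_skeleton` helper); AWQ's real `WQLinear` instead stores the packed INT4 weights together with per-group scales and zero-points and runs the custom CUDA kernel in its forward pass.

```python
import torch
import torch.nn as nn

class QuantLinearStub(nn.Module):
    """Hypothetical stand-in for a quantized linear module such as AWQ's WQLinear."""
    def __init__(self, in_features: int, out_features: int, w_bit: int = 4, group_size: int = 128):
        super().__init__()
        # Buffers sized for packed 4-bit weights (eight values per int32) and per-group scales/zero-points.
        self.register_buffer("qweight", torch.zeros(out_features, in_features * w_bit // 32, dtype=torch.int32))
        self.register_buffer("scales", torch.zeros(out_features, in_features // group_size, dtype=torch.float16))
        self.register_buffer("zeros", torch.zeros(out_features, in_features // group_size, dtype=torch.float16))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # A real module would dequantize here (or fuse dequantization into a custom matmul kernel).
        raise NotImplementedError("illustrative stub only")

def swap_linear_layers(module: nn.Module) -> None:
    """Recursively replace every nn.Linear with a quantized stub, leaving its buffers empty."""
    for name, child in module.named_children():
        if isinstance(child, nn.Linear):
            setattr(module, name, QuantLinearStub(child.in_features, child.out_features))
        else:
            swap_linear_layers(child)

# Usage sketch (names are hypothetical): build the skeleton, swap the layers, then load the
# checkpoint so the saved integer weights land directly in the stubs' buffers.
# decoder = build_decoder_skeleton()
# swap_linear_layers(decoder)
# decoder.load_state_dict(torch.load("awq_model_w4_g128.pt", map_location="cpu"), strict=False)
```

This is presumably why the progress bar in the `from_quantized` cell reads "(init only)": the quantization itself already happened in cell 2, and loading only has to allocate the quantized modules and fill their buffers.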