{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# AWQ on LLaVA"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this notebook, we use the LLaVA model to demonstrate the performance of AWQ on multi-modal models. We implement AWQ real-INT4 inference kernels, which are wrapped as PyTorch modules and can be easily used by existing models. We also provide a simple example showing how to use AWQ to quantize a model and save/load the quantized checkpoint."
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In order to run this notebook, you need to install the following packages:\n",
    "- [AWQ](https://github.com/mit-han-lab/llm-awq)\n",
    "- [PyTorch](https://pytorch.org/)\n",
    "- [Accelerate](https://github.com/huggingface/accelerate)\n",
    "- [LLaVA](https://github.com/haotian-liu/LLaVA)\n",
    "- [Transformers](https://github.com/huggingface/transformers)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/jilin/anaconda3/envs/llava/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import requests\n",
    "from PIL import Image\n",
    "from io import BytesIO\n",
    "from accelerate import init_empty_weights, load_checkpoint_and_dispatch\n",
    "from transformers import AutoTokenizer, CLIPVisionModel, CLIPImageProcessor, logging\n",
    "logging.set_verbosity_error()  # too many warnings\n",
    "from llava.conversation import conv_templates, SeparatorStyle\n",
    "from llava.utils import disable_torch_init\n",
    "from llava.model import *\n",
    "from llava.model.utils import KeywordsStoppingCriteria\n",
    "from awq.models.auto import AutoAWQForCausalLM\n",
    "import os\n",
    "import gc\n",
    "\n",
    "from awq.quantize.auto_clip import apply_clip\n",
    "from awq.quantize.auto_scale import apply_scale\n",
    "\n",
    "# This demo only supports a single GPU for now\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n",
    "DEFAULT_IMAGE_TOKEN = \"<image>\"\n",
    "DEFAULT_IMAGE_PATCH_TOKEN = \"<im_patch>\"\n",
    "DEFAULT_IM_START_TOKEN = \"<im_start>\"\n",
    "DEFAULT_IM_END_TOKEN = \"<im_end>\"\n",
    "\n",
    "def load_search_result_into_memory(model, search_path):\n",
    "    awq_results = torch.load(search_path, map_location=\"cpu\")\n",
    "\n",
    "    apply_scale(model, awq_results[\"scale\"])\n",
    "    apply_clip(model, awq_results[\"clip\"])"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Please download the LLaVA model from [LLaVA](https://github.com/haotian-liu/LLaVA), then run the following cell to generate a quantized model checkpoint. Note that we quantize only the language decoder, which accounts for most of the model's parameters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:09<00:00,  3.14s/it]\n",
      "real weight quantization...: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 40/40 [09:07<00:00, 13.69s/it]\n"
     ]
    }
   ],
   "source": [
    "model_path = \"liuhaotian/LLaVA-13b-delta-v0\"\n",
    "quant_path = \"LLaVA-13B-v0-awq\" # place to dump quant weights\n",
    "search_path = \"../awq_cache/llava-13b-v0-w4-g128.pt\" # place where you stored search results\n",
    "\n",
    "# Load model\n",
    "model = AutoAWQForCausalLM.from_pretrained(model_path)\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)\n",
    "quant_config = {\"zero_point\": True, \"q_group_size\": 128, \"w_bit\": 4}\n",
    "\n",
    "# Apply the AWQ search results (scales and clipping) to the model\n",
    "load_search_result_into_memory(model.model, search_path)\n",
    "\n",
    "# Run actual weight quantization\n",
    "model.quantize(quant_config=quant_config, run_search=False, run_quant=True)\n",
    "\n",
    "# Save quantized model\n",
    "model.save_quantized(quant_path)\n",
    "\n",
    "del model"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then provide an image link and a question below.\n",
    "\n",
    "![](https://llava.hliu.cc/file=/nobackup/haotian/code/LLaVA/llava/serve/examples/extreme_ironing.jpg)\n",
    "\n",
    "## Q: What is unusual about this image?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "query = \"What is unusual about this image?\"\n",
    "image_file = \"https://llava.hliu.cc/file=/nobackup/haotian/code/LLaVA/llava/serve/examples/extreme_ironing.jpg\" "
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We first load an empty model and replace all the linear layers with WQLinear layers. Then we load the quantized weights from the checkpoint."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:15<00:00,  5.17s/it]\n",
      "real weight quantization...(init only): 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 40/40 [00:37<00:00,  1.08it/s]\n"
     ]
    }
   ],
   "source": [
    "disable_torch_init()\n",
    "\n",
    "# Load model\n",
    "model = AutoAWQForCausalLM.from_quantized(quant_path, quant_filename=\"awq_model_w4_g128.pt\")\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/jilin/anaconda3/envs/llava/lib/python3.10/site-packages/transformers/generation/utils.py:1211: UserWarning: You have modified the pretrained model configuration to control generation. This is a deprecated strategy to control generation and will be removed soon, in a future version. Please use a generation configuration file (see https://huggingface.co/docs/transformers/main_classes/text_generation)\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The unusual aspect of this image is that a man is standing on a portable ironing board in the middle of the road, ironing clothes while traffic, including a yellow taxi, moves around him. This is not a typical scene you would expect to see in a city, as ironing is usually done in a private setting like a home, and not on the street amidst traffic. It brings attention to the unconventional and unexpected nature of the situation.\n"
     ]
    }
   ],
   "source": [
    "def load_image(image_file):\n",
    "    if image_file.startswith(('http://', 'https://')):\n",
    "        response = requests.get(image_file)\n",
    "        image = Image.open(BytesIO(response.content)).convert('RGB')\n",
    "    else:\n",
    "        image = Image.open(image_file).convert('RGB')\n",
    "    return image\n",
    "\n",
    "image_processor = CLIPImageProcessor.from_pretrained(model.model.config.mm_vision_tower, torch_dtype=torch.float16)\n",
    "\n",
    "mm_use_im_start_end = getattr(model.config, \"mm_use_im_start_end\", False)\n",
    "tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)\n",
    "if mm_use_im_start_end:\n",
    "    tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)\n",
    "\n",
    "vision_tower = model.model.get_model().vision_tower[0]\n",
    "if vision_tower.device.type == 'meta':\n",
    "    vision_tower = CLIPVisionModel.from_pretrained(vision_tower.config._name_or_path, torch_dtype=torch.float16, low_cpu_mem_usage=True).cuda()\n",
    "    model.model.get_model().vision_tower[0] = vision_tower\n",
    "else:\n",
    "    vision_tower.to(device='cuda', dtype=torch.float16)\n",
    "vision_config = vision_tower.config\n",
    "vision_config.im_patch_token = tokenizer.convert_tokens_to_ids([DEFAULT_IMAGE_PATCH_TOKEN])[0]\n",
    "vision_config.use_im_start_end = mm_use_im_start_end\n",
    "if mm_use_im_start_end:\n",
    "    vision_config.im_start_token, vision_config.im_end_token = tokenizer.convert_tokens_to_ids([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN])\n",
    "image_token_len = (vision_config.image_size // vision_config.patch_size) ** 2\n",
    "\n",
    "qs = query\n",
    "if mm_use_im_start_end:\n",
    "    qs = qs + '\\n' + DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len + DEFAULT_IM_END_TOKEN\n",
    "else:\n",
    "    qs = qs + '\\n' + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len\n",
    "\n",
    "conv_mode = \"multimodal\"\n",
    "\n",
    "conv = conv_templates[conv_mode].copy()\n",
    "conv.append_message(conv.roles[0], qs)\n",
    "conv.append_message(conv.roles[1], None)\n",
    "prompt = conv.get_prompt()\n",
    "inputs = tokenizer([prompt])\n",
    "\n",
    "image = load_image(image_file)\n",
    "image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]\n",
    "\n",
    "input_ids = torch.as_tensor(inputs.input_ids).cuda()\n",
    "\n",
    "stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2\n",
    "keywords = [stop_str]\n",
    "stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)\n",
    "\n",
    "with torch.inference_mode():\n",
    "    output_ids = model.generate(\n",
    "        input_ids,\n",
    "        images=image_tensor.unsqueeze(0).half().cuda(),\n",
    "        do_sample=True,\n",
    "        temperature=0.2,\n",
    "        max_new_tokens=1024,\n",
    "        stopping_criteria=[stopping_criteria])\n",
    "\n",
    "input_token_len = input_ids.shape[1]\n",
    "n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()\n",
    "if n_diff_input_output > 0:\n",
    "    print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')\n",
    "outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]\n",
    "outputs = outputs.strip()\n",
    "if outputs.endswith(stop_str):\n",
    "    outputs = outputs[:-len(stop_str)]\n",
    "outputs = outputs.strip()\n",
    "print(outputs)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}