Commit 460d6d45 authored by Ji Lin

first commit

{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"# AWQ on Vicuna"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"In this notebook, we use Vicuna model to demonstrate the performance of AWQ on instruction-tuned models. We implement AWQ real-INT4 inference kernels, which are wrapped as Pytorch modules and can be easily used by existing models. We also provide a simple example to show how to use AWQ to quantize a model and save/load the quantized model checkpoint."
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"In order to run this notebook, you need to install the following packages:\n",
"- [AWQ]()\n",
"- [Pytorch](https://pytorch.org/)\n",
"- [Accelerate](https://github.com/huggingface/accelerate)\n",
"- [FastChat](https://github.com/lm-sys/FastChat)\n",
"- [Transformers](https://github.com/huggingface/transformers)"
]
},
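{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"A possible environment setup is sketched below (package names and versions are assumptions; the AWQ CUDA kernels may need an extra build step, see the repository README):\n",
"\n",
"```bash\n",
"pip install torch accelerate \"transformers>=4.28.0\" fschat\n",
"# install AWQ from the root of this repository\n",
"pip install -e .\n",
"```"
]
},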
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from accelerate import init_empty_weights, load_checkpoint_and_dispatch\n",
"from awq.quantize.quantizer import real_quantize_model_weight\n",
"from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig\n",
"from fastchat.serve.cli import SimpleChatIO\n",
"from fastchat.serve.inference import generate_stream \n",
"from fastchat.conversation import get_conv_template\n",
"import os\n",
"# This demo only support single GPU for now\n",
"os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\""
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Please get the Vicuna model from [FastChat](https://github.com/lm-sys/FastChat) and run the following command to generate a quantized model checkpoint first.\n",
"\n",
"```bash\n",
"mkdir quant_cache\n",
"python -m awq.entry --model_path [vicuna-7b_model_path] \\\n",
" --w_bit 4 --q_group_size 128 \\\n",
" --load_awq awq_cache/vicuna-7b-w4-g128.pt \\\n",
" --q_backend real --dump_quant quant_cache/vicuna-7b-w4-g128-awq.pt\n",
"```"
]
},
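{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"If you do not have the pre-computed AWQ search result `awq_cache/vicuna-7b-w4-g128.pt`, you can generate it first with the search entry point, using the same flags as the run scripts in this repository (this step is optional if you use the provided pre-computed results):\n",
"\n",
"```bash\n",
"mkdir -p awq_cache\n",
"python -m awq.entry --model_path [vicuna-7b_model_path] \\\n",
"    --w_bit 4 --q_group_size 128 \\\n",
"    --run_awq --dump_awq awq_cache/vicuna-7b-w4-g128.pt\n",
"```"
]
},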
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"model_path = \"\" # the path of vicuna-7b model\n",
"load_quant_path = \"quant_cache/vicuna-7b-w4-g128-awq.pt\""
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"We first load a empty model and replace all the linear layers with WQLinear layers. Then we load the quantized weights from the checkpoint. "
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Loading checkpoint shards: 100%|██████████| 2/2 [00:05<00:00, 2.50s/it]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"* skipping lm_head\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"real weight quantization...: 100%|██████████| 224/224 [00:26<00:00, 8.40it/s]\n"
]
}
],
"source": [
"config = AutoConfig.from_pretrained(model_path)\n",
"tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)\n",
"with init_empty_weights():\n",
" model = AutoModelForCausalLM.from_pretrained(model_path, config=config,\n",
" torch_dtype=torch.float16)\n",
"q_config = {\"zero_point\": True, \"q_group_size\": 128}\n",
"real_quantize_model_weight(\n",
" model, w_bit=4, q_config=q_config, init_only=True)\n",
"\n",
"model = load_checkpoint_and_dispatch(\n",
" model, load_quant_path,\n",
" device_map=\"auto\",\n",
" no_split_module_classes=[\"LlamaDecoderLayer\"]\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"User: How can I improve my time management skills?\n",
"ASSISTANT: Time management skills can be improved through a combination of techniques, such as setting clear goals, prioritizing tasks, and using time-saving tools and strategies. Here are some tips to help you improve your time management skills:\n",
"\n",
"1. Set clear goals: Establish clear and specific goals for what you want to achieve. This will help you prioritize your tasks and focus your efforts.\n",
"2. Prioritize tasks: Identify the most important tasks that need to be completed and prioritize them accordingly. Use the Eisenhower matrix to categorize tasks into urgent and important, important but not urgent, urgent but not important, and not urgent or important.\n",
"3. Use time-saving tools and strategies: Use tools like calendars, to-do lists, and time trackers to help you manage your time more effectively. Also, consider using time-saving strategies like batching, delegating, and automating tasks.\n",
"4. Practice time management techniques: Practice time management techniques like the Pomodoro technique, the 80/20 rule, and Parkinson's law to help you work more efficiently.\n",
"5. Learn to say no: Learn to say no to non-essential tasks and commitments to free up more time for what's important.\n",
"6. Take breaks: Take regular breaks throughout the day to recharge and refocus.\n",
"7. Review and adjust: Regularly review and adjust your time management strategies to ensure they are working for you.\n",
"\n",
"Remember, time management is a skill that takes time and practice to develop. Be patient with yourself and keep working on improving your time management skills.\n",
"exit...\n"
]
}
],
"source": [
"conv = get_conv_template(\"vicuna_v1.1\")\n",
"chatio = SimpleChatIO()\n",
"\n",
"inp = \"How can I improve my time management skills?\"\n",
"print(\"User:\", inp)\n",
"\n",
"while True:\n",
" if not inp:\n",
" try:\n",
" inp = chatio.prompt_for_input(conv.roles[0])\n",
" except EOFError:\n",
" inp = \"\"\n",
" if not inp:\n",
" print(\"exit...\")\n",
" break\n",
"\n",
" conv.append_message(conv.roles[0], inp)\n",
" conv.append_message(conv.roles[1], None)\n",
"\n",
" generate_stream_func = generate_stream\n",
" prompt = conv.get_prompt()\n",
"\n",
" gen_params = {\n",
" \"model\": model_path,\n",
" \"prompt\": prompt,\n",
" \"temperature\": 0.3,\n",
" \"repetition_penalty\": 1.0,\n",
" \"max_new_tokens\": 512,\n",
" \"stop\": conv.stop_str,\n",
" \"stop_token_ids\": conv.stop_token_ids,\n",
" \"echo\": False,\n",
" }\n",
"\n",
" chatio.prompt_for_output(conv.roles[1])\n",
" output_stream = generate_stream_func(model, tokenizer, gen_params, \"cuda\")\n",
" outputs = chatio.stream_output(output_stream)\n",
" conv.update_last_message(outputs.strip())\n",
" \n",
" inp = None"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "awq",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.16"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# AWQ on LLaVA"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In this notebook, we use LLaVA model to demonstrate the performance of AWQ on multi-modal models. We implement AWQ real-INT4 inference kernels, which are wrapped as Pytorch modules and can be easily used by existing models. We also provide a simple example to show how to use AWQ to quantize a model and save/load the quantized model checkpoint."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In order to run this notebook, you need to install the following packages:\n",
"- [AWQ]()\n",
"- [Pytorch](https://pytorch.org/)\n",
"- [Accelerate](https://github.com/huggingface/accelerate)\n",
"- [LLaVA](https://github.com/haotian-liu/LLaVA)\n",
"- [Transformers](https://github.com/huggingface/transformers)"
]
},
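{
"cell_type": "markdown",
"metadata": {},
"source": [
"A possible environment setup is sketched below (package names, versions, and install steps are assumptions; see the respective READMEs for the authoritative instructions):\n",
"\n",
"```bash\n",
"pip install torch torchvision accelerate \"transformers>=4.28.0\"\n",
"# install LLaVA from its repository\n",
"git clone https://github.com/haotian-liu/LLaVA && pip install -e ./LLaVA\n",
"# install AWQ from the root of this repository\n",
"pip install -e .\n",
"```"
]
},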
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/jilin/anaconda3/envs/llava/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
}
],
"source": [
"import torch\n",
"import requests\n",
"from PIL import Image\n",
"from io import BytesIO\n",
"from accelerate import init_empty_weights, load_checkpoint_and_dispatch\n",
"from transformers import AutoTokenizer, CLIPVisionModel, CLIPImageProcessor, logging\n",
"logging.set_verbosity_error() # too many warnings\n",
"from llava.conversation import conv_templates, SeparatorStyle\n",
"from llava.utils import disable_torch_init\n",
"from llava.model import *\n",
"from llava.model.utils import KeywordsStoppingCriteria\n",
"from awq.quantize.pre_quant import apply_awq\n",
"from awq.quantize.quantizer import real_quantize_model_weight\n",
"import os\n",
"import gc\n",
"\n",
"# This demo only support single GPU for now\n",
"os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n",
"DEFAULT_IMAGE_TOKEN = \"<image>\"\n",
"DEFAULT_IMAGE_PATCH_TOKEN = \"<im_patch>\"\n",
"DEFAULT_IM_START_TOKEN = \"<im_start>\"\n",
"DEFAULT_IM_END_TOKEN = \"<im_end>\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Please get the LLaVA model from [LLaVA](https://github.com/haotian-liu/LLaVA) and run the following cell to generate a quantized model checkpoint first (note that we only quantize the language decoder, which dominates the model parameters). "
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:09<00:00, 3.14s/it]\n",
"real weight quantization...: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 40/40 [09:07<00:00, 13.69s/it]\n"
]
}
],
"source": [
"model_path = \"/dataset/llava/LLaVA-13B-v0\" # Please change here \n",
"quant_path = \"../quant_cache/LLaVA-13B-v0-w4-g128-awq.pt\" # place to dump quant weights\n",
"\n",
"model = LlavaLlamaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, torch_dtype=torch.float16, use_cache=True).cuda()\n",
"\n",
"awq_results = torch.load(\"../awq_cache/llava-13b-v0-w4-g128.pt\", map_location=\"cpu\")\n",
"apply_awq(model, awq_results)\n",
"\n",
"real_quantize_model_weight(model, w_bit=4, q_config={\"zero_point\": True, \"q_group_size\": 128})\n",
"torch.save(model.cpu().state_dict(), quant_path)\n",
"\n",
"del model\n",
"gc.collect()\n",
"torch.cuda.empty_cache()"
]
},
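{
"cell_type": "markdown",
"metadata": {},
"source": [
"The cell above assumes the pre-computed AWQ search result `../awq_cache/llava-13b-v0-w4-g128.pt` is available. If it is not, one way to produce it is the `awq.entry` search used by the run scripts in this repository; whether that entry point handles a LLaVA checkpoint directly is an assumption, so treat this as a sketch:\n",
"\n",
"```bash\n",
"# Sketch only: assumes awq.entry can load a LLaVA checkpoint the same way\n",
"# it loads the LLaMA/OPT models in the provided run scripts.\n",
"python -m awq.entry --model_path /dataset/llava/LLaVA-13B-v0 \\\n",
"    --w_bit 4 --q_group_size 128 \\\n",
"    --run_awq --dump_awq ../awq_cache/llava-13b-v0-w4-g128.pt\n",
"```"
]
},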
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Then input a image link and a question below.\n",
"\n",
"![](https://llava.hliu.cc/file=/nobackup/haotian/code/LLaVA/llava/serve/examples/extreme_ironing.jpg)\n",
"\n",
"## Q: What is unusual about this image?"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"query = \"What is unusual about this image?\"\n",
"image_file = \"https://llava.hliu.cc/file=/nobackup/haotian/code/LLaVA/llava/serve/examples/extreme_ironing.jpg\" "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We first load a empty model and replace all the linear layers with WQLinear layers. Then we load the quantized weights from the checkpoint. "
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:15<00:00, 5.17s/it]\n",
"real weight quantization...(init only): 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 40/40 [00:37<00:00, 1.08it/s]\n"
]
}
],
"source": [
"\n",
"disable_torch_init()\n",
"tokenizer = AutoTokenizer.from_pretrained(model_path)\n",
"config = LlavaConfig.from_pretrained(model_path)\n",
"with init_empty_weights():\n",
" model = LlavaLlamaForCausalLM.from_pretrained(model_path, config=config,\n",
" torch_dtype=torch.float16, device_map=\"auto\")\n",
"q_config = {\"zero_point\": True, \"q_group_size\": 128}\n",
"real_quantize_model_weight(\n",
" model, w_bit=4, q_config=q_config, init_only=True)\n",
"\n",
"model = load_checkpoint_and_dispatch(\n",
" model, quant_path, device_map=\"auto\"\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/jilin/anaconda3/envs/llava/lib/python3.10/site-packages/transformers/generation/utils.py:1211: UserWarning: You have modified the pretrained model configuration to control generation. This is a deprecated strategy to control generation and will be removed soon, in a future version. Please use a generation configuration file (see https://huggingface.co/docs/transformers/main_classes/text_generation)\n",
" warnings.warn(\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"The unusual aspect of this image is that a man is standing on a portable ironing board in the middle of the road, ironing clothes while traffic, including a yellow taxi, moves around him. This is not a typical scene you would expect to see in a city, as ironing is usually done in a private setting like a home, and not on the street amidst traffic. It brings attention to the unconventional and unexpected nature of the situation.\n"
]
}
],
"source": [
"def load_image(image_file):\n",
" if image_file.startswith('http') or image_file.startswith('https'):\n",
" response = requests.get(image_file)\n",
" image = Image.open(BytesIO(response.content)).convert('RGB')\n",
" else:\n",
" image = Image.open(image_file).convert('RGB')\n",
" return image\n",
"\n",
"image_processor = CLIPImageProcessor.from_pretrained(model.config.mm_vision_tower, torch_dtype=torch.float16)\n",
"\n",
"mm_use_im_start_end = getattr(model.config, \"mm_use_im_start_end\", False)\n",
"tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)\n",
"if mm_use_im_start_end:\n",
" tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)\n",
"\n",
"vision_tower = model.get_model().vision_tower[0]\n",
"if vision_tower.device.type == 'meta':\n",
" vision_tower = CLIPVisionModel.from_pretrained(vision_tower.config._name_or_path, torch_dtype=torch.float16, low_cpu_mem_usage=True).cuda()\n",
" model.get_model().vision_tower[0] = vision_tower\n",
"else:\n",
" vision_tower.to(device='cuda', dtype=torch.float16)\n",
"vision_config = vision_tower.config\n",
"vision_config.im_patch_token = tokenizer.convert_tokens_to_ids([DEFAULT_IMAGE_PATCH_TOKEN])[0]\n",
"vision_config.use_im_start_end = mm_use_im_start_end\n",
"if mm_use_im_start_end:\n",
" vision_config.im_start_token, vision_config.im_end_token = tokenizer.convert_tokens_to_ids([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN])\n",
"image_token_len = (vision_config.image_size // vision_config.patch_size) ** 2\n",
"\n",
"qs = query\n",
"if mm_use_im_start_end:\n",
" qs = qs + '\\n' + DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len + DEFAULT_IM_END_TOKEN\n",
"else:\n",
" qs = qs + '\\n' + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len\n",
"\n",
"conv_mode = \"multimodal\"\n",
"\n",
"conv = conv_templates[conv_mode].copy()\n",
"conv.append_message(conv.roles[0], qs)\n",
"conv.append_message(conv.roles[1], None)\n",
"prompt = conv.get_prompt()\n",
"inputs = tokenizer([prompt])\n",
"\n",
"image = load_image(image_file)\n",
"image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]\n",
"\n",
"input_ids = torch.as_tensor(inputs.input_ids).cuda()\n",
"\n",
"stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2\n",
"keywords = [stop_str]\n",
"stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)\n",
"\n",
"with torch.inference_mode():\n",
" output_ids = model.generate(\n",
" input_ids,\n",
" images=image_tensor.unsqueeze(0).half().cuda(),\n",
" do_sample=True,\n",
" temperature=0.2,\n",
" max_new_tokens=1024,\n",
" stopping_criteria=[stopping_criteria])\n",
"\n",
"input_token_len = input_ids.shape[1]\n",
"n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()\n",
"if n_diff_input_output > 0:\n",
" print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')\n",
"outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]\n",
"outputs = outputs.strip()\n",
"if outputs.endswith(stop_str):\n",
" outputs = outputs[:-len(stop_str)]\n",
"outputs = outputs.strip()\n",
"print(outputs)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.11"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"
[project]
name = "awq"
version = "0.1.0"
description = "An efficient and accurate low-bit weight quantization(INT3/4) method for LLMs."
readme = "README.md"
requires-python = ">=3.9"
classifiers = [
"Programming Language :: Python :: 3",
"License :: OSI Approved :: Apache Software License",
]
dependencies = [
"accelerate", "sentencepiece", "tokenizers>=0.12.1",
"torch", "torchvision",
"transformers>=4.28.0",
"lm_eval"
]
[tool.setuptools.packages.find]
exclude = ["results*", "scripts*", "examples*"]
[tool.wheel]
exclude = ["results*", "scripts*", "examples*"]
MODEL=llama-7b
# run AWQ search (optional; we provided the pre-computed results)
python -m awq.entry --model_path /dataset/llama-hf/$MODEL \
--w_bit 4 --q_group_size 128 \
--run_awq --dump_awq awq_cache/$MODEL-w4-g128.pt
# evaluate the AWQ quantized model (simulated pseudo-quantization)
python -m awq.entry --model_path /dataset/llama-hf/$MODEL \
--tasks wikitext \
--w_bit 4 --q_group_size 128 \
--load_awq awq_cache/$MODEL-w4-g128.pt \
--q_backend fake
# generate real quantized weights (w4)
python -m awq.entry --model_path /dataset/llama-hf/$MODEL \
--w_bit 4 --q_group_size 128 \
--load_awq awq_cache/$MODEL-w4-g128.pt \
--q_backend real --dump_quant quant_cache/$MODEL-w4-g128-awq.pt
# load and evaluate the real quantized model (smaller GPU memory usage)
python -m awq.entry --model_path /dataset/llama-hf/$MODEL \
--tasks wikitext \
--w_bit 4 --q_group_size 128 \
--load_quant quant_cache/$MODEL-w4-g128-awq.pt
MODEL=opt-6.7b
# run AWQ search (optional; we provided the pre-computed results)
python -m awq.entry --model_path /dataset/opt/$MODEL \
--w_bit 4 --q_group_size 128 \
--run_awq --dump_awq awq_cache/$MODEL-w4-g128.pt
# evaluate the AWQ quantized model (simulated pseudo-quantization)
python -m awq.entry --model_path /dataset/opt/$MODEL \
--tasks wikitext \
--w_bit 4 --q_group_size 128 \
--load_awq awq_cache/$MODEL-w4-g128.pt \
--q_backend fake
# generate real quantized weights (w4)
python -m awq.entry --model_path /dataset/opt/$MODEL \
--w_bit 4 --q_group_size 128 \
--load_awq awq_cache/$MODEL-w4-g128.pt \
--q_backend real --dump_quant quant_cache/$MODEL-w4-g128-awq.pt
# load and evaluate the real quantized model (smaller GPU memory usage)
python -m awq.entry --model_path /dataset/opt/$MODEL \
--tasks wikitext \
--w_bit 4 --q_group_size 128 \
--load_quant quant_cache/$MODEL-w4-g128-awq.pt