{ "cells": [ { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "# AWQ on Vicuna" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "In this notebook, we use Vicuna model to demonstrate the performance of AWQ on instruction-tuned models. We implement AWQ real-INT4 inference kernels, which are wrapped as Pytorch modules and can be easily used by existing models. We also provide a simple example to show how to use AWQ to quantize a model and save/load the quantized model checkpoint." ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "In order to run this notebook, you need to install the following packages:\n", "- [AWQ](https://github.com/mit-han-lab/llm-awq)\n", "- [Pytorch](https://pytorch.org/)\n", "- [Accelerate](https://github.com/huggingface/accelerate)\n", "- [FastChat](https://github.com/lm-sys/FastChat)\n", "- [Transformers](https://github.com/huggingface/transformers)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import torch\n", "from accelerate import init_empty_weights, load_checkpoint_and_dispatch\n", "from awq.quantize.quantizer import real_quantize_model_weight\n", "from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig\n", "from fastchat.serve.cli import SimpleChatIO\n", "from fastchat.serve.inference import generate_stream \n", "from fastchat.conversation import get_conv_template\n", "import os\n", "# This demo only support single GPU for now\n", "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "Please get the Vicuna model from [FastChat](https://github.com/lm-sys/FastChat) and run the following command to generate a quantized model checkpoint first.\n", "\n", "```bash\n", "mkdir quant_cache\n", "python -m awq.entry --model_path [vicuna-7b_model_path] \\\n", " --w_bit 4 --q_group_size 128 \\\n", " --load_awq awq_cache/vicuna-7b-w4-g128.pt \\\n", " --q_backend real --dump_quant quant_cache/vicuna-7b-w4-g128-awq.pt\n", "```" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "model_path = \"\" # the path of vicuna-7b model\n", "load_quant_path = \"quant_cache/vicuna-7b-w4-g128-awq.pt\"" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "We first load a empty model and replace all the linear layers with WQLinear layers. Then we load the quantized weights from the checkpoint. 
" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Loading checkpoint shards: 100%|██████████| 2/2 [00:05<00:00, 2.50s/it]\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "* skipping lm_head\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "real weight quantization...: 100%|██████████| 224/224 [00:26<00:00, 8.40it/s]\n" ] } ], "source": [ "config = AutoConfig.from_pretrained(model_path)\n", "tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)\n", "with init_empty_weights():\n", " model = AutoModelForCausalLM.from_pretrained(model_path, config=config,\n", " torch_dtype=torch.float16)\n", "q_config = {\"zero_point\": True, \"q_group_size\": 128}\n", "real_quantize_model_weight(\n", " model, w_bit=4, q_config=q_config, init_only=True)\n", "\n", "model = load_checkpoint_and_dispatch(\n", " model, load_quant_path,\n", " device_map=\"auto\",\n", " no_split_module_classes=[\"LlamaDecoderLayer\"]\n", ")" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "User: How can I improve my time management skills?\n", "ASSISTANT: Time management skills can be improved through a combination of techniques, such as setting clear goals, prioritizing tasks, and using time-saving tools and strategies. Here are some tips to help you improve your time management skills:\n", "\n", "1. Set clear goals: Establish clear and specific goals for what you want to achieve. This will help you prioritize your tasks and focus your efforts.\n", "2. Prioritize tasks: Identify the most important tasks that need to be completed and prioritize them accordingly. Use the Eisenhower matrix to categorize tasks into urgent and important, important but not urgent, urgent but not important, and not urgent or important.\n", "3. Use time-saving tools and strategies: Use tools like calendars, to-do lists, and time trackers to help you manage your time more effectively. Also, consider using time-saving strategies like batching, delegating, and automating tasks.\n", "4. Practice time management techniques: Practice time management techniques like the Pomodoro technique, the 80/20 rule, and Parkinson's law to help you work more efficiently.\n", "5. Learn to say no: Learn to say no to non-essential tasks and commitments to free up more time for what's important.\n", "6. Take breaks: Take regular breaks throughout the day to recharge and refocus.\n", "7. Review and adjust: Regularly review and adjust your time management strategies to ensure they are working for you.\n", "\n", "Remember, time management is a skill that takes time and practice to develop. 
Be patient with yourself and keep working on improving your time management skills.\n", "exit...\n" ] } ], "source": [ "conv = get_conv_template(\"vicuna_v1.1\")\n", "chatio = SimpleChatIO()\n", "\n", "# Seed the chat with a fixed first question; later turns are read interactively\n", "inp = \"How can I improve my time management skills?\"\n", "print(\"User:\", inp)\n", "\n", "while True:\n", "    if not inp:\n", "        # Prompt the user for the next turn; empty input (or EOF) ends the chat\n", "        try:\n", "            inp = chatio.prompt_for_input(conv.roles[0])\n", "        except EOFError:\n", "            inp = \"\"\n", "        if not inp:\n", "            print(\"exit...\")\n", "            break\n", "\n", "    conv.append_message(conv.roles[0], inp)\n", "    conv.append_message(conv.roles[1], None)\n", "\n", "    generate_stream_func = generate_stream\n", "    # Render the full conversation into a single prompt string\n", "    prompt = conv.get_prompt()\n", "\n", "    gen_params = {\n", "        \"model\": model_path,\n", "        \"prompt\": prompt,\n", "        \"temperature\": 0.3,\n", "        \"repetition_penalty\": 1.0,\n", "        \"max_new_tokens\": 512,\n", "        \"stop\": conv.stop_str,\n", "        \"stop_token_ids\": conv.stop_token_ids,\n", "        \"echo\": False,\n", "    }\n", "\n", "    # Stream the quantized model's reply and record it in the conversation\n", "    chatio.prompt_for_output(conv.roles[1])\n", "    output_stream = generate_stream_func(model, tokenizer, gen_params, \"cuda\")\n", "    outputs = chatio.stream_output(output_stream)\n", "    conv.update_last_message(outputs.strip())\n", "\n", "    inp = None" ] } ], "metadata": { "kernelspec": { "display_name": "awq", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.16" }, "orig_nbformat": 4 }, "nbformat": 4, "nbformat_minor": 2 }