{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# AWQ on Vicuna"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this notebook, we use Vicuna model to demonstrate the performance of AWQ on instruction-tuned models. We implement AWQ real-INT4 inference kernels, which are wrapped as Pytorch modules and can be easily used by existing models. We also provide a simple example to show how to use AWQ to quantize a model and save/load the quantized model checkpoint."
   ]
  },
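  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For intuition, here is a simplified stand-in for such a quantized linear module (a hypothetical `FakeW4Linear`, not the actual `WQLinear` in this repo): the real kernel keeps packed INT4 weights on the GPU and runs a fused CUDA GEMM, while this sketch simply dequantizes the weights and calls `F.linear`.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "class FakeW4Linear(nn.Module):\n",
    "    # Illustration only: stores unpacked INT4 values as int8 in [0, 15].\n",
    "    def __init__(self, qweight, scales, zeros, bias=None, group_size=128):\n",
    "        super().__init__()\n",
    "        self.register_buffer('qweight', qweight)  # (out, in)\n",
    "        self.register_buffer('scales', scales)    # (out, in // group_size)\n",
    "        self.register_buffer('zeros', zeros)      # (out, in // group_size)\n",
    "        self.bias = bias\n",
    "        self.group_size = group_size\n",
    "\n",
    "    def forward(self, x):\n",
    "        # Dequantize group-wise, w = (q - zero) * scale, then a normal GEMM\n",
    "        s = self.scales.repeat_interleave(self.group_size, dim=1)\n",
    "        z = self.zeros.repeat_interleave(self.group_size, dim=1)\n",
    "        w = (self.qweight.float() - z) * s\n",
    "        return F.linear(x, w.to(x.dtype), self.bias)\n",
    "```"
   ]
  },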
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In order to run this notebook, you need to install the following packages:\n",
    "- [AWQ](https://github.com/mit-han-lab/llm-awq)\n",
    "- [Pytorch](https://pytorch.org/)\n",
    "- [Accelerate](https://github.com/huggingface/accelerate)\n",
    "- [FastChat](https://github.com/lm-sys/FastChat)\n",
    "- [Transformers](https://github.com/huggingface/transformers)"
   ]
  },
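  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For example, one possible setup (the exact steps may change; please follow each project's README for current instructions):\n",
    "\n",
    "```bash\n",
    "pip install torch accelerate transformers fschat\n",
    "git clone https://github.com/mit-han-lab/llm-awq\n",
    "cd llm-awq && pip install -e .\n",
    "# build the real-INT4 CUDA kernels (kernel path per the llm-awq README at the time of writing)\n",
    "cd awq/kernels && python setup.py install\n",
    "```"
   ]
  },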
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from accelerate import init_empty_weights, load_checkpoint_and_dispatch\n",
    "from awq.quantize.quantizer import real_quantize_model_weight\n",
    "from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig\n",
    "from fastchat.serve.cli import SimpleChatIO\n",
    "from fastchat.serve.inference import generate_stream \n",
    "from fastchat.conversation import get_conv_template\n",
    "import os\n",
    "# This demo only support single GPU for now\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\""
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Please get the Vicuna model from [FastChat](https://github.com/lm-sys/FastChat) and run the following command to generate a quantized model checkpoint first.\n",
    "\n",
    "```bash\n",
    "mkdir quant_cache\n",
    "python -m awq.entry --model_path [vicuna-7b_model_path] \\\n",
    "    --w_bit 4 --q_group_size 128 \\\n",
    "    --load_awq awq_cache/vicuna-7b-w4-g128.pt \\\n",
    "    --q_backend real --dump_quant quant_cache/vicuna-7b-w4-g128-awq.pt\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_path = \"\" # the path of vicuna-7b model\n",
    "load_quant_path = \"quant_cache/vicuna-7b-w4-g128-awq.pt\""
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We first load a empty model and replace all the linear layers with WQLinear layers. Then we load the quantized weights from the checkpoint. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading checkpoint shards: 100%|██████████| 2/2 [00:05<00:00,  2.50s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "* skipping lm_head\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "real weight quantization...: 100%|██████████| 224/224 [00:26<00:00,  8.40it/s]\n"
     ]
    }
   ],
   "source": [
    "config = AutoConfig.from_pretrained(model_path)\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)\n",
    "with init_empty_weights():\n",
    "    model = AutoModelForCausalLM.from_pretrained(model_path, config=config,\n",
    "                                                    torch_dtype=torch.float16)\n",
    "q_config = {\"zero_point\": True, \"q_group_size\": 128}\n",
    "real_quantize_model_weight(\n",
    "    model, w_bit=4, q_config=q_config, init_only=True)\n",
    "\n",
    "model = load_checkpoint_and_dispatch(\n",
    "    model, load_quant_path,\n",
    "    device_map=\"auto\",\n",
    "    no_split_module_classes=[\"LlamaDecoderLayer\"]\n",
    ")"
   ]
  },
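  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For reference, `real_quantize_model_weight` with `w_bit=4` and the `q_config` above applies asymmetric (zero-point) quantization over groups of 128 input channels. The helper below is a minimal sketch of the per-group math, for intuition only; the library's actual bit-packing and tensor layout differ.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "def quantize_per_group(w, n_bit=4, group_size=128):\n",
    "    # w: (out_features, in_features); in_features must be divisible by group_size\n",
    "    out, inp = w.shape\n",
    "    wg = w.reshape(out, inp // group_size, group_size)\n",
    "    w_min = wg.amin(dim=-1, keepdim=True)\n",
    "    w_max = wg.amax(dim=-1, keepdim=True)\n",
    "    q_max = 2 ** n_bit - 1  # 15 for 4-bit\n",
    "    scale = (w_max - w_min).clamp(min=1e-5) / q_max\n",
    "    zero = (-w_min / scale).round()  # per-group asymmetric zero point\n",
    "    q = (wg / scale + zero).round().clamp(0, q_max)\n",
    "    # Reconstruction (dequantization) would be: (q - zero) * scale\n",
    "    return q.reshape(out, inp), scale, zero\n",
    "```"
   ]
  },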
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "User: How can I improve my time management skills?\n",
      "ASSISTANT: Time management skills can be improved through a combination of techniques, such as setting clear goals, prioritizing tasks, and using time-saving tools and strategies. Here are some tips to help you improve your time management skills:\n",
      "\n",
      "1. Set clear goals: Establish clear and specific goals for what you want to achieve. This will help you prioritize your tasks and focus your efforts.\n",
      "2. Prioritize tasks: Identify the most important tasks that need to be completed and prioritize them accordingly. Use the Eisenhower matrix to categorize tasks into urgent and important, important but not urgent, urgent but not important, and not urgent or important.\n",
      "3. Use time-saving tools and strategies: Use tools like calendars, to-do lists, and time trackers to help you manage your time more effectively. Also, consider using time-saving strategies like batching, delegating, and automating tasks.\n",
      "4. Practice time management techniques: Practice time management techniques like the Pomodoro technique, the 80/20 rule, and Parkinson's law to help you work more efficiently.\n",
      "5. Learn to say no: Learn to say no to non-essential tasks and commitments to free up more time for what's important.\n",
      "6. Take breaks: Take regular breaks throughout the day to recharge and refocus.\n",
      "7. Review and adjust: Regularly review and adjust your time management strategies to ensure they are working for you.\n",
      "\n",
      "Remember, time management is a skill that takes time and practice to develop. Be patient with yourself and keep working on improving your time management skills.\n",
      "exit...\n"
     ]
    }
   ],
   "source": [
    "conv = get_conv_template(\"vicuna_v1.1\")\n",
    "chatio = SimpleChatIO()\n",
    "\n",
    "inp = \"How can I improve my time management skills?\"\n",
    "print(\"User:\", inp)\n",
    "\n",
    "while True:\n",
    "    if not inp:\n",
    "        try:\n",
    "            inp = chatio.prompt_for_input(conv.roles[0])\n",
    "        except EOFError:\n",
    "            inp = \"\"\n",
    "    if not inp:\n",
    "        print(\"exit...\")\n",
    "        break\n",
    "\n",
    "    conv.append_message(conv.roles[0], inp)\n",
    "    conv.append_message(conv.roles[1], None)\n",
    "\n",
    "    generate_stream_func = generate_stream\n",
    "    prompt = conv.get_prompt()\n",
    "\n",
    "    gen_params = {\n",
    "        \"model\": model_path,\n",
    "        \"prompt\": prompt,\n",
    "        \"temperature\": 0.3,\n",
    "        \"repetition_penalty\": 1.0,\n",
    "        \"max_new_tokens\": 512,\n",
    "        \"stop\": conv.stop_str,\n",
    "        \"stop_token_ids\": conv.stop_token_ids,\n",
    "        \"echo\": False,\n",
    "    }\n",
    "\n",
    "    chatio.prompt_for_output(conv.roles[1])\n",
    "    output_stream = generate_stream_func(model, tokenizer, gen_params, \"cuda\")\n",
    "    outputs = chatio.stream_output(output_stream)\n",
    "    conv.update_last_message(outputs.strip())\n",
    "    \n",
    "    inp = None"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "awq",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.16"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}