{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# AWQ on Vicuna"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this notebook, we use the Vicuna model to demonstrate the performance of AWQ on instruction-tuned models. We implement AWQ real-INT4 inference kernels, which are wrapped as PyTorch modules and can be dropped into existing models. We also provide a simple example showing how to use AWQ to quantize a model and how to save and load the quantized checkpoint."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To run this notebook, install the following packages:\n",
    "- [AWQ](https://github.com/mit-han-lab/llm-awq)\n",
    "- [PyTorch](https://pytorch.org/)\n",
    "- [Accelerate](https://github.com/huggingface/accelerate)\n",
    "- [Transformers](https://github.com/huggingface/transformers)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from accelerate import init_empty_weights, load_checkpoint_and_dispatch\n",
    "from awq.quantize.quantizer import real_quantize_model_weight\n",
    "from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig\n",
    "from tinychat.demo import gen_params, stream_output\n",
    "from tinychat.stream_generators import StreamGenerator\n",
    "from tinychat.modules import make_quant_norm, make_quant_attn, make_fused_mlp\n",
    "from tinychat.utils.prompt_templates import get_prompter\n",
    "import os\n",
    "# This demo only supports a single GPU for now\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Please get the Vicuna model from [FastChat](https://github.com/lm-sys/FastChat) and run the following command to generate a quantized model checkpoint first.\n",
    "\n",
    "```bash\n",
    "mkdir quant_cache\n",
    "python -m awq.entry --model_path [vicuna-7b_model_path] \\\n",
    "    --w_bit 4 --q_group_size 128 \\\n",
    "    --load_awq awq_cache/vicuna-7b-w4-g128.pt \\\n",
    "    --q_backend real --dump_quant quant_cache/vicuna-7b-w4-g128-awq.pt\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# model_path = \"\" # the path of vicuna-7b model\n",
    "# load_quant_path = \"quant_cache/vicuna-7b-w4-g128-awq.pt\"\n",
    "model_path = \"/data/llm/checkpoints/vicuna-hf/vicuna-7b\"\n",
    "load_quant_path = \"/data/llm/checkpoints/vicuna-hf/vicuna-7b-awq-w4g128.pt\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We first load an empty model and replace all the linear layers with WQLinear layers. Then we load the quantized weights from the checkpoint."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "8b79a82b73ab4d9191ba54f5d0f8cb86",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "real weight quantization...(init only): 100%|███████████████████| 32/32 [00:11<00:00,  2.69it/s]\n",
      "The model weights are not tied. Please use the `tie_weights` method before using the `infer_auto_device` function.\n",
      "The model weights are not tied. Please use the `tie_weights` method before using the `infer_auto_device` function.\n"
     ]
    }
   ],
   "source": [
    "config = AutoConfig.from_pretrained(model_path)\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)\n",
    "with init_empty_weights():\n",
    "    model = AutoModelForCausalLM.from_pretrained(model_path, config=config,\n",
    "                                                    torch_dtype=torch.float16)\n",
    "q_config = {\"zero_point\": True, \"q_group_size\": 128}\n",
    "real_quantize_model_weight(\n",
    "    model, w_bit=4, q_config=q_config, init_only=True)\n",
    "\n",
    "model = load_checkpoint_and_dispatch(\n",
    "    model, load_quant_path,\n",
    "    device_map=\"auto\",\n",
    "    no_split_module_classes=[\"LlamaDecoderLayer\"]\n",
    ")"
   ]
  },
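  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For intuition, `w_bit=4` with `zero_point` and `q_group_size=128` in the `q_config` above usually corresponds to group-wise asymmetric quantization: every group of 128 consecutive weights shares one scale and one integer zero point. The pure-Python sketch below illustrates that idea under this assumption. The helper names (`quantize_group`, `dequantize_group`, `quantize_row`) and the toy data are hypothetical; this is not the AWQ INT4 kernel and not the implementation behind `real_quantize_model_weight`.\n",
    "\n",
    "```python\n",
    "# Illustrative sketch only; helper names and data are hypothetical.\n",
    "\n",
    "def quantize_group(weights, w_bit=4):\n",
    "    # Asymmetric quantization of one group: map floats to integers in [0, 2^w_bit - 1].\n",
    "    q_max = 2 ** w_bit - 1\n",
    "    w_min, w_max = min(weights), max(weights)\n",
    "    # One scale per group; the floor avoids division by zero for constant groups.\n",
    "    scale = max(w_max - w_min, 1e-5) / q_max\n",
    "    # Integer zero point, clamped to the representable range.\n",
    "    zero = min(max(round(-w_min / scale), 0), q_max)\n",
    "    q = [min(max(round(w / scale) + zero, 0), q_max) for w in weights]\n",
    "    return q, scale, zero\n",
    "\n",
    "def dequantize_group(q, scale, zero):\n",
    "    # Recover approximate float weights from the integers.\n",
    "    return [(v - zero) * scale for v in q]\n",
    "\n",
    "def quantize_row(row, group_size=128, w_bit=4):\n",
    "    # Split one weight row into consecutive groups and quantize each independently.\n",
    "    return [quantize_group(row[i:i + group_size], w_bit)\n",
    "            for i in range(0, len(row), group_size)]\n",
    "\n",
    "# Toy row of 256 weights in [-0.5, 0.5], giving two groups of 128.\n",
    "row = [((i * 37) % 101 - 50) / 100.0 for i in range(256)]\n",
    "groups = quantize_row(row, group_size=128)\n",
    "restored = [w for q, s, z in groups for w in dequantize_group(q, s, z)]\n",
    "max_err = max(abs(a - b) for a, b in zip(row, restored))\n",
    "print(len(groups), max_err)\n",
    "```\n",
    "\n",
    "A smaller group size gives each scale fewer weights to cover, which tends to lower the reconstruction error at the cost of storing more scales and zero points."
   ]
  },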
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[Warning] Calling a fake MLP fusion. But still faster than Huggingface Implimentation.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "LlamaForCausalLM(\n",
       "  (model): LlamaModel(\n",
       "    (embed_tokens): Embedding(32000, 4096, padding_idx=0)\n",
       "    (layers): ModuleList(\n",
       "      (0-31): 32 x LlamaDecoderLayer(\n",
       "        (self_attn): QuantLlamaAttention(\n",
       "          (qkv_proj): WQLinear(in_features=4096, out_features=12288, bias=False, w_bit=4, group_size=128)\n",
       "          (o_proj): WQLinear(in_features=4096, out_features=4096, bias=False, w_bit=4, group_size=128)\n",
       "          (rotary_emb): QuantLlamaRotaryEmbedding()\n",
       "        )\n",
       "        (mlp): QuantLlamaMLP(\n",
       "          (down_proj): WQLinear(in_features=11008, out_features=4096, bias=False, w_bit=4, group_size=128)\n",
       "        )\n",
       "        (input_layernorm): FTLlamaRMSNorm()\n",
       "        (post_attention_layernorm): FTLlamaRMSNorm()\n",
       "      )\n",
       "    )\n",
       "    (norm): FTLlamaRMSNorm()\n",
       "  )\n",
       "  (lm_head): Linear(in_features=4096, out_features=32000, bias=False)\n",
       ")"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "make_quant_attn(model, \"cuda:0\")\n",
    "make_quant_norm(model)\n",
    "make_fused_mlp(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdin",
     "output_type": "stream",
     "text": [
      "USER:  Show me some attractions in Boston.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ASSISTANT: 1. Boston Public Library\n",
      "2. Fenway Park\n",
      "3. Harvard Square\n",
      "4. Boston Common\n",
      "5. Freedom Trail\n",
      "6. Museum of Fine Arts\n",
      "7. Isabella Stewart Gardner Museum\n",
      "8. Paul Revere House\n",
      "9. New England Aquarium\n",
      "10. Museum of Science\n",
      "==================================================\n",
      "Speed of Inference\n",
      "--------------------------------------------------\n",
      "Context Stage    : 7.18 ms/token\n",
      "Generation Stage : 9.49 ms/token\n",
      "Average Speed    : 8.53 ms/token\n",
      "==================================================\n"
     ]
    },
    {
     "name": "stdin",
     "output_type": "stream",
     "text": [
      "USER:  \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "EXIT...\n"
     ]
    }
   ],
   "source": [
    "model_prompter = get_prompter(\"llama\", model_path)\n",
    "stream_generator = StreamGenerator\n",
    "count = 0\n",
    "while True:\n",
    "    # Get input from the user\n",
    "    input_prompt = input(\"USER: \")\n",
    "    if input_prompt == \"\":\n",
    "        print(\"EXIT...\")\n",
    "        break\n",
    "    model_prompter.insert_prompt(input_prompt)\n",
    "    output_stream = stream_generator(model, tokenizer, model_prompter.model_input, gen_params, device=\"cuda:0\")\n",
    "    outputs = stream_output(output_stream)\n",
    "    model_prompter.update_template(outputs)\n",
    "    count += 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python (awq)",
   "language": "python",
   "name": "awq"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}