Commit 1e82667b authored by Casper

Merge remote-tracking branch 'origin/refactor-models'

parents 05abe7d6 46750ff9
import torch.nn as nn
def get_named_linears(module):
    return {name: m for name, m in module.named_modules() if isinstance(m, nn.Linear)}
def get_op_by_name(module, op_name):
    # get the op by its name relative to the module
......
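For orientation, here is a minimal usage sketch of the two helpers above; the toy `nn.Sequential` model is purely illustrative and not part of this commit.

```python
import torch.nn as nn

# Toy module purely for illustration.
block = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 16))

# get_named_linears maps dotted submodule names to their nn.Linear instances,
# e.g. {'0': Linear(16, 32), '2': Linear(32, 16)} for the toy block above.
print(get_named_linears(block))

# get_op_by_name then resolves one of those relative names back to the module,
# e.g. get_op_by_name(block, '0') -> the first Linear layer.
```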
@@ -31,10 +31,8 @@
"metadata": {},
"outputs": [],
"source": [
-"import torch\n",
-"from accelerate import init_empty_weights, load_checkpoint_and_dispatch\n",
-"from awq.quantize.quantizer import real_quantize_model_weight\n",
-"from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig\n",
+"from awq.models.auto import AutoAWQForCausalLM\n",
+"from transformers import AutoTokenizer\n",
"from tinychat.demo import gen_params, stream_output\n",
"from tinychat.stream_generators import StreamGenerator\n",
"from tinychat.modules import make_quant_norm, make_quant_attn, make_fused_mlp\n",
@@ -44,98 +42,32 @@
"os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Please get the Vicuna model from [FastChat](https://github.com/lm-sys/FastChat) and run the following command to generate a quantized model checkpoint first.\n",
"\n",
"```bash\n",
"mkdir quant_cache\n",
"python -m awq.entry --model_path [vicuna-7b_model_path] \\\n",
" --w_bit 4 --q_group_size 128 \\\n",
" --load_awq awq_cache/vicuna-7b-w4-g128.pt \\\n",
" --q_backend real --dump_quant quant_cache/vicuna-7b-w4-g128-awq.pt\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# model_path = \"\" # the path of vicuna-7b model\n",
"# load_quant_path = \"quant_cache/vicuna-7b-w4-g128-awq.pt\"\n",
"model_path = \"/data/llm/checkpoints/vicuna-hf/vicuna-7b\"\n",
"load_quant_path = \"/data/llm/checkpoints/vicuna-hf/vicuna-7b-awq-w4g128.pt\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We first load a empty model and replace all the linear layers with WQLinear layers. Then we load the quantized weights from the checkpoint. "
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [ "outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "8b79a82b73ab4d9191ba54f5d0f8cb86",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
-"real weight quantization...(init only): 100%|███████████████████| 32/32 [00:11<00:00, 2.69it/s]\n",
-"The model weights are not tied. Please use the `tie_weights` method before using the `infer_auto_device` function.\n",
-"The model weights are not tied. Please use the `tie_weights` method before using the `infer_auto_device` function.\n"
+"Replacing layers...: 100%|██████████| 32/32 [00:02<00:00, 11.85it/s]\n"
]
}
],
"source": [
-"config = AutoConfig.from_pretrained(model_path)\n",
-"tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)\n",
-"with init_empty_weights():\n",
-" model = AutoModelForCausalLM.from_pretrained(model_path, config=config,\n",
-" torch_dtype=torch.float16)\n",
-"q_config = {\"zero_point\": True, \"q_group_size\": 128}\n",
-"real_quantize_model_weight(\n",
-" model, w_bit=4, q_config=q_config, init_only=True)\n",
+"model_path = 'vicuna-7b-v1.5-awq'\n",
+"quant_file = 'awq_model_w4_g128.pt'\n",
"\n",
-"model = load_checkpoint_and_dispatch(\n",
-" model, load_quant_path,\n",
-" device_map=\"auto\",\n",
-" no_split_module_classes=[\"LlamaDecoderLayer\"]\n",
-")"
+"tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)\n",
+"model = AutoAWQForCausalLM.from_quantized(model_path, quant_file)"
]
},
{
"cell_type": "code",
-"execution_count": 4,
+"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[Warning] Calling a fake MLP fusion. But still faster than Huggingface Implimentation.\n"
]
},
{
"data": {
"text/plain": [
@@ -162,69 +94,47 @@
)"
]
},
-"execution_count": 4,
+"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [ "source": [
"make_quant_attn(model, \"cuda:0\")\n", "make_quant_attn(model.model, \"cuda:0\")\n",
"make_quant_norm(model)\n", "make_quant_norm(model.model)\n",
"make_fused_mlp(model)" "make_fused_mlp(model.model)"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 5, "execution_count": 4,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{
"name": "stdin",
"output_type": "stream",
"text": [
"USER: Show me some attractions in Boston.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
-"ASSISTANT: 1. Boston Public Library\n",
-"2. Fenway Park\n",
-"3. Harvard Square\n",
-"4. Boston Common\n",
-"5. Freedom Trail\n",
-"6. Museum of Fine Arts\n",
-"7. Isabella Stewart Gardner Museum\n",
-"8. Paul Revere House\n",
-"9. New England Aquarium\n",
-"10. Museum of Science\n",
+"ASSISTANT: Sure! Here are some popular tourist attractions in Boston:\n",
+"\n",
+"1. Freedom Trail - a 2.5-mile walking trail that takes you through some of the most important historical sites in Boston, including Paul Revere's House, the Old North Church, and the site of the Boston Massacre.\n",
+"2. Fenway Park - home to the Boston Red Sox baseball team, this historic ballpark is one of the oldest in Major League Baseball.\n",
+"3. Museum of Fine Arts - one of the largest art museums in the country, with a collection of over 450,000 works of art from around the world.\n",
+"4. Boston Harbor Islands National Recreation Area - a group of islands located just offshore from downtown Boston that offer stunning views of the city skyline and easy access to outdoor recreational activities like hiking and kayaking.\n",
+"5. New England Aquarium - one of the oldest and largest aquariums in the United States, featuring a wide variety of marine life, including giant whales and colorful fish.\n",
+"6. The USS Constitution Museum - located on board the USS Constitution, a historic ship that played a key role in the War of 1812 and is still in active service today.\n",
+"7. Bunker Hill Monument - a 221-foot-tall obelisk located in Charlestown that commemorates the Battle of Bunker Hill during the Revolutionary War.\n",
+"8. The Hancock Building - a historic building in the heart of Boston that offers panoramic views of the city from its observation deck.\n",
"==================================================\n",
"Speed of Inference\n",
"--------------------------------------------------\n",
-"Context Stage : 7.18 ms/token\n",
-"Generation Stage : 9.49 ms/token\n",
+"Generation Stage : 10.13 ms/token\n",
+"==================================================\n",
"Average Speed : 8.53 ms/token\n",
"==================================================\n"
]
},
{
"name": "stdin",
"output_type": "stream",
"text": [
"USER: \n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"EXIT...\n" "EXIT...\n"
] ]
} }
], ],
"source": [ "source": [
"model_prompter = get_prompter(\"llama\", model_path)\n", "model_prompter = get_prompter(model, model_path)\n",
"stream_generator = StreamGenerator\n", "stream_generator = StreamGenerator\n",
"count = 0\n", "count = 0\n",
"while True:\n", "while True:\n",
...@@ -239,20 +149,13 @@ ...@@ -239,20 +149,13 @@
" model_prompter.update_template(outputs)\n", " model_prompter.update_template(outputs)\n",
" count += 1" " count += 1"
] ]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
-"display_name": "Python (awq)",
+"display_name": "Python 3",
"language": "python",
-"name": "awq"
+"name": "python3"
},
"language_info": { "language_info": {
"codemirror_mode": { "codemirror_mode": {
...@@ -264,7 +167,7 @@ ...@@ -264,7 +167,7 @@
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.10.11" "version": "3.10.6"
} }
}, },
"nbformat": 4, "nbformat": 4,
......
@@ -55,17 +55,25 @@
"from llava.utils import disable_torch_init\n",
"from llava.model import *\n",
"from llava.model.utils import KeywordsStoppingCriteria\n",
-"from awq.quantize.pre_quant import apply_awq\n",
+"from awq.models.auto import AutoAWQForCausalLM\n",
"from awq.quantize.quantizer import real_quantize_model_weight\n",
"import os\n", "import os\n",
"import gc\n", "import gc\n",
"\n", "\n",
"from awq.quantize.auto_clip import apply_clip\n",
"from awq.quantize.auto_scale import apply_scale\n",
"\n",
"# This demo only support single GPU for now\n", "# This demo only support single GPU for now\n",
"os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n", "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n",
"DEFAULT_IMAGE_TOKEN = \"<image>\"\n", "DEFAULT_IMAGE_TOKEN = \"<image>\"\n",
"DEFAULT_IMAGE_PATCH_TOKEN = \"<im_patch>\"\n", "DEFAULT_IMAGE_PATCH_TOKEN = \"<im_patch>\"\n",
"DEFAULT_IM_START_TOKEN = \"<im_start>\"\n", "DEFAULT_IM_START_TOKEN = \"<im_start>\"\n",
"DEFAULT_IM_END_TOKEN = \"<im_end>\"" "DEFAULT_IM_END_TOKEN = \"<im_end>\"\n",
"\n",
"def load_search_result_into_memory(model, search_path):\n",
" awq_results = torch.load(search_path, map_location=\"cpu\")\n",
" \n",
" apply_scale(model, awq_results[\"scale\"])\n",
" apply_clip(model, awq_results[\"clip\"])"
]
},
{
@@ -91,20 +99,25 @@
}
],
"source": [
"model_path = \"/dataset/llava/LLaVA-13B-v0\" # Please change here \n", "model_path = \"liuhaotian/LLaVA-13b-delta-v0\"\n",
"quant_path = \"../quant_cache/LLaVA-13B-v0-w4-g128-awq.pt\" # place to dump quant weights\n", "quant_path = \"LLaVA-13B-v0-awq\" # place to dump quant weights\n",
"search_path = \"../awq_cache/llava-13b-v0-w4-g128.pt\" # place where you stored search results\n",
"\n", "\n",
"model = LlavaLlamaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, torch_dtype=torch.float16, use_cache=True).cuda()\n", "# Load model\n",
"model = AutoAWQForCausalLM.from_pretrained(model_path)\n",
"tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)\n",
"quant_config = {\"zero_point\": True, \"q_group_size\": 128, \"w_bit\": 4}\n",
"\n", "\n",
"awq_results = torch.load(\"../awq_cache/llava-13b-v0-w4-g128.pt\", map_location=\"cpu\")\n", "# Load model and search results\n",
"apply_awq(model, awq_results)\n", "load_search_result_into_memory(model.model, search_path)\n",
"\n", "\n",
"real_quantize_model_weight(model, w_bit=4, q_config={\"zero_point\": True, \"q_group_size\": 128})\n", "# Run actual weight quantization\n",
"torch.save(model.cpu().state_dict(), quant_path)\n", "model.quantize(quant_config=quant_config, run_search=False, run_quant=True)\n",
"\n", "\n",
"del model\n", "# Save quantized model\n",
"gc.collect()\n", "model.save_quantized(quant_path)\n",
"torch.cuda.empty_cache()" "\n",
"del model"
] ]
}, },
{ {
@@ -152,20 +165,11 @@
}
],
"source": [
"\n",
"disable_torch_init()\n", "disable_torch_init()\n",
"tokenizer = AutoTokenizer.from_pretrained(model_path)\n",
"config = LlavaConfig.from_pretrained(model_path)\n",
"with init_empty_weights():\n",
" model = LlavaLlamaForCausalLM.from_pretrained(model_path, config=config,\n",
" torch_dtype=torch.float16, device_map=\"auto\")\n",
"q_config = {\"zero_point\": True, \"q_group_size\": 128}\n",
"real_quantize_model_weight(\n",
" model, w_bit=4, q_config=q_config, init_only=True)\n",
"\n", "\n",
"model = load_checkpoint_and_dispatch(\n", "# Load model\n",
" model, quant_path, device_map=\"auto\"\n", "model = AutoAWQForCausalLM.from_quantized(quant_path, quant_filename=\"awq_model_w4_g128.pt\")\n",
")" "tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)"
] ]
}, },
{
@@ -198,17 +202,17 @@
" image = Image.open(image_file).convert('RGB')\n", " image = Image.open(image_file).convert('RGB')\n",
" return image\n", " return image\n",
"\n", "\n",
"image_processor = CLIPImageProcessor.from_pretrained(model.config.mm_vision_tower, torch_dtype=torch.float16)\n", "image_processor = CLIPImageProcessor.from_pretrained(model.model.config.mm_vision_tower, torch_dtype=torch.float16)\n",
"\n", "\n",
"mm_use_im_start_end = getattr(model.config, \"mm_use_im_start_end\", False)\n", "mm_use_im_start_end = getattr(model.config, \"mm_use_im_start_end\", False)\n",
"tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)\n", "tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)\n",
"if mm_use_im_start_end:\n", "if mm_use_im_start_end:\n",
" tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)\n", " tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)\n",
"\n", "\n",
"vision_tower = model.get_model().vision_tower[0]\n", "vision_tower = model.model.get_model().vision_tower[0]\n",
"if vision_tower.device.type == 'meta':\n", "if vision_tower.device.type == 'meta':\n",
" vision_tower = CLIPVisionModel.from_pretrained(vision_tower.config._name_or_path, torch_dtype=torch.float16, low_cpu_mem_usage=True).cuda()\n", " vision_tower = CLIPVisionModel.from_pretrained(vision_tower.config._name_or_path, torch_dtype=torch.float16, low_cpu_mem_usage=True).cuda()\n",
" model.get_model().vision_tower[0] = vision_tower\n", " model.model.get_model().vision_tower[0] = vision_tower\n",
"else:\n", "else:\n",
" vision_tower.to(device='cuda', dtype=torch.float16)\n", " vision_tower.to(device='cuda', dtype=torch.float16)\n",
"vision_config = vision_tower.config\n", "vision_config = vision_tower.config\n",
......
MODEL=llama-7b
# run AWQ search (optional; we provided the pre-computed results)
python -m awq.entry --model_path /dataset/llama-hf/$MODEL \
--w_bit 4 --q_group_size 128 \
--run_awq --dump_awq awq_cache/$MODEL-w4-g128.pt
# evaluate the AWQ quantize model (simulated pseudo quantization)
python -m awq.entry --model_path /dataset/llama-hf/$MODEL \
--tasks wikitext \
--w_bit 4 --q_group_size 128 \
--load_awq awq_cache/$MODEL-w4-g128.pt \
--q_backend fake
# generate real quantized weights (w4)
python -m awq.entry --model_path /dataset/llama-hf/$MODEL \
--w_bit 4 --q_group_size 128 \
--load_awq awq_cache/$MODEL-w4-g128.pt \
--q_backend real --dump_quant quant_cache/$MODEL-w4-g128-awq.pt
# load and evaluate the real quantized model (smaller gpu memory usage)
python -m awq.entry --model_path /dataset/llama-hf/$MODEL \
--tasks wikitext \
--w_bit 4 --q_group_size 128 \
--load_quant quant_cache/$MODEL-w4-g128-awq.pt
\ No newline at end of file
-MODEL=opt-6.7b
+MODEL=facebook/opt-6.7b
# run AWQ search (optional; we provided the pre-computed results)
-python -m awq.entry --model_path /dataset/opt/$MODEL \
-    --w_bit 4 --q_group_size 128 \
-    --run_awq --dump_awq awq_cache/$MODEL-w4-g128.pt
-# evaluate the AWQ quantize model (simulated pseudo quantization)
-python -m awq.entry --model_path /dataset/opt/$MODEL \
-    --tasks wikitext \
-    --w_bit 4 --q_group_size 128 \
-    --load_awq awq_cache/$MODEL-w4-g128.pt \
-    --q_backend fake
+python -m awq.entry --entry_type search \
+    --model_path $MODEL \
+    --search_path $MODEL-awq
# generate real quantized weights (w4)
-python -m awq.entry --model_path /dataset/opt/$MODEL \
-    --w_bit 4 --q_group_size 128 \
-    --load_awq awq_cache/$MODEL-w4-g128.pt \
-    --q_backend real --dump_quant quant_cache/$MODEL-w4-g128-awq.pt
+python -m awq.entry --entry_type quant \
+    --model_path $MODEL \
+    --search_path $MODEL-awq/awq_model_search_result.pt \
+    --quant_path $MODEL-awq
# load and evaluate the real quantized model (smaller gpu memory usage)
-python -m awq.entry --model_path /dataset/opt/$MODEL \
-    --tasks wikitext \
-    --w_bit 4 --q_group_size 128 \
-    --load_quant quant_cache/$MODEL-w4-g128-awq.pt
\ No newline at end of file
+python -m awq.entry --entry_type perplexity \
+    --quant_path $MODEL-awq \
+    --quant_file awq_model_w4_g128.pt
\ No newline at end of file
MODEL=lmsys/vicuna-7b-v1.5
# run AWQ search (optional; we provided the pre-computed results)
python -m awq.entry --entry_type search \
--model_path $MODEL \
--search_path $MODEL-awq
# generate real quantized weights (w4)
python -m awq.entry --entry_type quant \
--model_path $MODEL \
--search_path $MODEL-awq/awq_model_search_result.pt \
--quant_path $MODEL-awq
# load and evaluate the real quantized model (smaller gpu memory usage)
python -m awq.entry --entry_type perplexity \
--quant_path $MODEL-awq \
--quant_file awq_model_w4_g128.pt
\ No newline at end of file
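For readers who prefer the Python API over the `awq.entry` CLI shown in the scripts above, the same search → quantize → evaluate flow can be sketched with the refactored classes that appear elsewhere in this commit. This is a hedged sketch, not an exact reproduction of `awq.entry`: the class and method names come from the notebook cells in this diff, and any extra calibration arguments that `quantize()` may need for `run_search=True` are not shown here.

```python
from awq.models.auto import AutoAWQForCausalLM
from transformers import AutoTokenizer

model_path = "lmsys/vicuna-7b-v1.5"
quant_path = "vicuna-7b-v1.5-awq"
quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4}

# Load the FP16 model and tokenizer.
model = AutoAWQForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

# Run AWQ search and real weight quantization in one call.
# (The LLaVA notebook in this commit instead passes run_search=False and loads
# pre-computed search results before calling quantize().)
model.quantize(quant_config=quant_config, run_search=True, run_quant=True)

# Persist the INT4 checkpoint, then reload it later for evaluation or chat.
model.save_quantized(quant_path)
model = AutoAWQForCausalLM.from_quantized(quant_path, "awq_model_w4_g128.pt")
```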
@@ -9,7 +9,7 @@ torch_is_prebuilt = os.environ.get('TORCH_IS_PREBUILT', '0') == '1'
# Define dependencies
dependencies = [
"accelerate", "sentencepiece", "tokenizers>=0.12.1",
-"transformers>=4.31.0",
+"transformers>=4.32.0",
"lm_eval", "texttable",
"toml", "attributedict",
"protobuf"
......
@@ -7,23 +7,16 @@ We introduce TinyChat, a cutting-edge chatbot interface designed for lightweight
The current release supports:
- LLaMA-2-7B/13B-chat;
- Vicuna;
- MPT-chat;
- Falcon-instruct.
## Contents
- [Examples](#examples)
- [Benchmarks](#benchmarks)
- [Usage](#usage)
- [Reference](#reference)
@@ -91,73 +84,27 @@ The latency reported in all tables are per-token latency for the generation stage
2. Download the pretrained instruction-tuned LLMs:
- For LLaMA-2-chat, please refer to [this link](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf);
- For Vicuna, please refer to [this link](https://huggingface.co/lmsys/);
- For MPT-chat, please refer to [this link](https://huggingface.co/mosaicml/mpt-7b-chat);
- For Falcon-instruct, please refer to [this link](https://huggingface.co/tiiuae/falcon-7b-instruct).
3. Quantize instruction-tuned LLMs with AWQ (see [usage in README](../README.md#usage)).
3. Quantize instruction-tuned LLMs with AWQ:
- We provide pre-computed AWQ search results for multiple model families, including LLaMA, OPT, Vicuna, and LLaVA. To get the pre-computed AWQ search results, run:
```bash
# git lfs install # install git lfs if not already
git clone https://huggingface.co/datasets/mit-han-lab/awq-model-zoo awq_cache
```
- You may run a one-line starter below:
```bash
./scripts/llama2_demo.sh
```
Alternatively, you may go through the process step by step. We will demonstrate the quantization process with LLaMA-2. For all other models except Falcon, one only needs to change the `model_path` and saving locations. For Falcon-7B, we also need to change `q_group_size` from 128 to 64.
- Perform AWQ search and save search results (we already did it for you):
```bash
mkdir awq_cache
python -m awq.entry --model_path /PATH/TO/LLAMA2/llama-2-7b-chat \
--w_bit 4 --q_group_size 128 \
--run_awq --dump_awq awq_cache/llama-2-7b-chat-w4-g128.pt
```
- Generate real quantized weights (INT4):
```bash
mkdir quant_cache
python -m awq.entry --model_path /PATH/TO/LLAMA2/llama-2-7b-chat \
--w_bit 4 --q_group_size 128 \
--load_awq awq_cache/llama-2-7b-chat-w4-g128.pt \
--q_backend real --dump_quant quant_cache/llama-2-7b-chat-w4-g128-awq.pt
```
4. Run the TinyChat demo:
+Here, we use Vicuna as an example and assume that you have already quantized the model.
```bash
cd tinychat
-python demo.py --model_type llama \
-    --model_path /PATH/TO/LLAMA2/llama-2-7b-chat \
-    --q_group_size 128 --load_quant quant_cache/llama-2-7b-chat-w4-g128-awq.pt \
-    --precision W4A16
+python demo.py --model_path vicuna-7b-v1.5-awq
```
-Note: if you use Falcon-7B-instruct, please remember to also change `q_group_size` to 64. You may also run the following command to execute the chatbot in FP16 to compare the speed and quality of language generation:
+You may also run the following command to execute the chatbot in FP16 to compare the speed and quality of language generation:
```bash
-python demo.py --model_type llama \
-    --model_path /PATH/TO/LLAMA2/llama-2-7b-chat \
-    --precision W16A16
+python demo.py --model_path lmsys/vicuna-7b-v1.5 --precision W16A16
```
## Reference
TinyChat is inspired by the following open-source projects: [FasterTransformer](https://github.com/NVIDIA/FasterTransformer), [vLLM](https://github.com/vllm-project/vllm), [FastChat](https://github.com/lm-sys/FastChat).
import torch
import argparse
import time
import numpy as np
-import torch
-import torch.nn as nn
+from awq.models import *
+from awq.models.auto import AutoAWQForCausalLM
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, modeling_utils
from attributedict.collections import AttributeDict
from tinychat.stream_generators import StreamGenerator, FalconStreamGenerator
from tinychat.utils.load_quant import load_awq_model, load_awq_llama_fast
from tinychat.utils.prompt_templates import get_prompter, get_stop_token_ids
from tinychat.stream_generators import StreamGenerator, FalconStreamGenerator
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, modeling_utils
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
@@ -75,15 +74,12 @@ def device_warmup(device:str):
if __name__ == '__main__':
parser = argparse.ArgumentParser()
-parser.add_argument('--model_type', type=str, default='LLaMa', help='type of the model')
parser.add_argument('--model_path', type=str, default='/data/llm/checkpoints/vicuna-hf/vicuna-7b', help='path to the model')
+parser.add_argument('--quant_file', type=str, default='awq_model_w4_g128.pt', help='path to the model file')
parser.add_argument('--precision' , type=str, default='W4A16', help='compute precision')
parser.add_argument('--device' , type=str, default='cuda')
-parser.add_argument('--q_group_size', type=int, default=128)
-parser.add_argument('--load_quant', type=str, default='/data/llm/checkpoints/vicuna-hf/vicuna-7b-awq-w4g128.pt', help='path to the pre-quanted 4-bit weights')
args = parser.parse_args()
-assert args.model_type.lower() in ["llama", "falcon", "mpt"], "We only support llama & falcon & mpt now"
assert args.precision in ["W4A16", "W16A16"], "We only support W4A16/W16A16 now"
gen_params.n_predict = 512
@@ -107,30 +103,21 @@ if __name__ == '__main__':
model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
if args.precision == "W4A16":
-if args.model_type.lower() == "llama":
-model = load_awq_llama_fast(model, args.load_quant, 4, args.q_group_size, args.device)
-else:
-model = load_awq_model(model, args.load_quant, 4, args.q_group_size, args.device)
+model = AutoAWQForCausalLM.from_quantized(args.model_path, args.quant_file)
+assert model.model_type.lower() in ["llama", "refinedweb", "refinedwebmodel", "mpt"], "We only support llama & falcon & mpt now"
else:
model = AutoModelForCausalLM.from_pretrained(args.model_path, config=config, torch_dtype=torch.float16, trust_remote_code=True).to(args.device)
# device warm up
device_warmup(args.device)
-if args.model_type.lower() == 'falcon':
+if isinstance(model, FalconAWQForCausalLM):
stream_generator = FalconStreamGenerator
else:
stream_generator = StreamGenerator
-# Optimize AWQ quantized model
-if args.precision == "W4A16" and args.model_type.lower() == 'llama':
-from tinychat.modules import make_quant_norm, make_quant_attn, make_fused_mlp
-make_quant_attn(model, args.device)
-make_quant_norm(model)
-make_fused_mlp(model)
-model_prompter = get_prompter(args.model_type, args.model_path)
-stop_token_ids = get_stop_token_ids(args.model_type, args.model_path)
+model_prompter = get_prompter(model, args.model_path)
+stop_token_ids = get_stop_token_ids(model, args.model_path)
count = 0
while True:
# Get input from the user
......
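To summarize the demo.py refactor above: model loading is now driven by the quantized checkpoint itself, and the prompter/stop-token helpers take the model instance instead of a `--model_type` string. Below is a condensed, hedged sketch of that flow; the model path is taken from the notebook and README in this commit rather than the argparse default, and the fused-module calls mirror the Vicuna notebook earlier in this diff.

```python
from awq.models.auto import AutoAWQForCausalLM
from tinychat.modules import make_quant_attn, make_quant_norm, make_fused_mlp
from tinychat.utils.prompt_templates import get_prompter, get_stop_token_ids
from transformers import AutoTokenizer

model_path = "vicuna-7b-v1.5-awq"       # illustrative quantized checkpoint directory
quant_file = "awq_model_w4_g128.pt"     # default --quant_file in the new demo.py

tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
model = AutoAWQForCausalLM.from_quantized(model_path, quant_file)

# Optional fused kernels, applied to the wrapped HF model (model.model).
make_quant_attn(model.model, "cuda:0")
make_quant_norm(model.model)
make_fused_mlp(model.model)

# Prompter and stop tokens are now resolved from the model instance.
model_prompter = get_prompter(model, model_path)
stop_token_ids = get_stop_token_ids(model, model_path)
```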
MODEL_PATH=/data/llm/checkpoints/llama2-hf
MODEL_NAME=llama-2-7b-chat
# # Perform AWQ search and save search results (we already did it for you):
# mkdir -p awq_cache
# python -m awq.entry --model_path $MODEL_PATH/$MODEL_NAME \
# --w_bit 4 --q_group_size 128 \
# --run_awq --dump_awq awq_cache/llama-2-7b-chat-w4-g128.pt
# Generate real quantized weights (INT4):
mkdir -p quant_cache
python -m awq.entry --model_path $MODEL_PATH/$MODEL_NAME \
--w_bit 4 --q_group_size 128 \
--load_awq awq_cache/llama-2-7b-chat-w4-g128.pt \
--q_backend real --dump_quant quant_cache/llama-2-7b-chat-w4-g128-awq.pt
# Run the TinyChat demo:
python demo.py --model_type llama \
--model_path $MODEL_PATH/$MODEL_NAME \
--q_group_size 128 --load_quant quant_cache/llama-2-7b-chat-w4-g128-awq.pt \
--precision W4A16
@@ -30,6 +30,22 @@ def prepare_logits_processor(
processor_list.append(TopKLogitsWarper(top_k))
return processor_list
def sanitize_tensor(tensor: torch.Tensor):
    if tensor.dtype == torch.float16:
        replacement_value = 65504
    elif tensor.dtype == torch.float32:
        replacement_value = 1e20
    else:
        return tensor
    # Replace positive infinity with a large finite number
    tensor[tensor == float('inf')] = replacement_value
    # Replace negative infinity with a small finite number
    tensor[tensor == float('-inf')] = -replacement_value
    # Replace NaNs with zero
    tensor[torch.isnan(tensor)] = 0.0
    return tensor
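As an editorial aside, a quick illustration of what the new sanitize_tensor guard does to non-finite logits before sampling (the input values are arbitrary):

```python
import torch

logits = torch.tensor([1.0, float('inf'), float('-inf'), float('nan')], dtype=torch.float16)
print(sanitize_tensor(logits))
# -> roughly [1., 65504., -65504., 0.], with the float16 dtype preserved
```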
@torch.inference_mode()
def StreamGenerator(model,
@@ -82,6 +98,7 @@ def StreamGenerator(model,
else:
tmp_output_ids = None
last_token_logits = logits_processor(tmp_output_ids, logits[:, -1, :])[0]
+last_token_logits = sanitize_tensor(last_token_logits)
else:
last_token_logits = logits[0, -1, :]
if gen_params.temp < 1e-5 or gen_params.top_p < 1e-8: # greedy
......
import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM
from accelerate import init_empty_weights, load_checkpoint_and_dispatch
from awq.quantize.quantizer import real_quantize_model_weight
from awq.quantize.qmodule import WQLinear
from tqdm import tqdm
def load_awq_model(model, checkpoint, w_bit, group_size, device):
    q_config = {"zero_point": True, "q_group_size": group_size}
    real_quantize_model_weight(model, w_bit, q_config, init_only = True)
    pbar = tqdm(range(1))
    pbar.set_description('Loading checkpoint')
    for i in pbar:
        if hasattr(model.config, "tie_encoder_decoder"):
            model.config.tie_encoder_decoder = False
        if hasattr(model.config, "tie_word_embeddings"):
            model.config.tie_word_embeddings = False
        model = load_checkpoint_and_dispatch(
            model, checkpoint,
            no_split_module_classes=[
                "OPTDecoderLayer", "LlamaDecoderLayer", "BloomBlock", "MPTBlock", "DecoderLayer"]
        ).to(device)
    return model

def make_quant_linear(module, names, w_bit, groupsize, device, name=''):
    if isinstance(module, WQLinear):
        return
    for attr in dir(module):
        tmp = getattr(module, attr)
        name1 = name + '.' + attr if name != '' else attr
        if name1 in names:
            delattr(module, attr)
            setattr(module, attr, WQLinear(w_bit, groupsize, tmp.in_features, tmp.out_features, tmp.bias is not None, device))
    for name1, child in module.named_children():
        make_quant_linear(child, names, w_bit, groupsize, device, name + '.' + name1 if name != '' else name1)

def find_layers(module, layers=[nn.Linear], name=''):
    if type(module) in layers:
        return {name: module}
    res = {}
    for name1, child in module.named_children():
        res.update(find_layers(child, layers=layers, name=name + '.' + name1 if name != '' else name1))
    return res

def load_awq_llama_fast(model, checkpoint, w_bit, group_size, device):
    layers = find_layers(model)
    for name in ['lm_head']:
        if name in layers:
            del layers[name]
    make_quant_linear(model, layers, w_bit, group_size, device)
    del layers
    pbar = tqdm(range(1))
    pbar.set_description('Loading checkpoint')
    for i in pbar:
        if checkpoint.endswith('.safetensors'):
            from safetensors.torch import load_file as safe_load
            model.load_state_dict(safe_load(checkpoint))
        else:
            model.load_state_dict(torch.load(checkpoint))
    return model.to(device)
\ No newline at end of file
from typing import List
from awq.models import *
from transformers.models.llama.modeling_llama import LlamaForCausalLM
from transformers.models.falcon.modeling_falcon import FalconForCausalLM
class BasePrompter:
    def __init__(self, system_inst, role1, role2, sen_spliter = "\n", qa_spliter = "\n", decorator: List[str] = None):
@@ -125,32 +128,32 @@ class MPTChatPrompter(BasePrompter):
-def get_prompter(model_type, model_path = ""):
-    if model_type.lower() == "llama":
+def get_prompter(model, model_path = ""):
+    if isinstance(model, LlamaAWQForCausalLM) or isinstance(model, LlamaForCausalLM):
        if "vicuna" in model_path:
            return VicunaPrompter()
        else:
            return Llama2Prompter()
-    elif model_type.lower() == "falcon":
-        # return FalconPrompter()
+    elif isinstance(model, FalconAWQForCausalLM) or isinstance(model, FalconForCausalLM):
        return FalconSimplePrompter()
-    elif model_type.lower() == "mpt":
+    elif isinstance(model, MptAWQForCausalLM) or "mpt" in str(model.__class__).lower():
        if "mpt" and "chat" in model_path:
            return MPTChatPrompter()
        else:
            return MPTPrompter()
    else:
-        raise ValueError(f"model type {model_type} is not supported")
+        raise ValueError(f"model type {model.model_type} is not supported")

-def get_stop_token_ids(model_type, model_path = ""):
-    if model_type.lower() == "llama":
+def get_stop_token_ids(model, model_path = ""):
+    if isinstance(model, LlamaAWQForCausalLM) or isinstance(model, LlamaForCausalLM):
        return []
-    elif model_type.lower() == "falcon":
+    elif isinstance(model, FalconAWQForCausalLM) or isinstance(model, FalconForCausalLM):
        return [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
-    elif model_type.lower() == "mpt":
+    elif isinstance(model, MptAWQForCausalLM) or "mpt" in str(model.__class__).lower():
        if "mpt" and "chat" in model_path:
            return [50278, 0]
        else:
            return []
    else:
+        model_type = str(model.__class__).lower()
        raise ValueError(f"model type {model_type} is not supported")
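A short usage sketch of the refactored helpers above, mirroring how the new demo.py calls them; the quantized checkpoint path is illustrative only.

```python
from awq.models.auto import AutoAWQForCausalLM
from tinychat.utils.prompt_templates import get_prompter, get_stop_token_ids

# Illustrative quantized Vicuna checkpoint; any AWQ model wrapper works here.
model = AutoAWQForCausalLM.from_quantized("vicuna-7b-v1.5-awq", "awq_model_w4_g128.pt")

# The helpers now dispatch on the model instance rather than a model_type string.
prompter = get_prompter(model, "vicuna-7b-v1.5-awq")       # VicunaPrompter, since "vicuna" is in the path
stop_ids = get_stop_token_ids(model, "vicuna-7b-v1.5-awq")  # [] for llama-family models
```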