Commit 8c8498e8 authored by Casper Hansen

TinyChat working with new framework

parent 5c2c85fe
+import torch
 import argparse
 import time
 import numpy as np
-import torch
-import torch.nn as nn
-from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, modeling_utils
+from awq.models import *
+from awq.models.auto import AutoAWQForCausalLM
 from attributedict.collections import AttributeDict
-from tinychat.stream_generators import StreamGenerator, FalconStreamGenerator
-from tinychat.utils.load_quant import load_awq_model, load_awq_llama_fast
 from tinychat.utils.prompt_templates import get_prompter, get_stop_token_ids
+from tinychat.stream_generators import StreamGenerator, FalconStreamGenerator
+from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, modeling_utils
 import os
 os.environ["CUDA_VISIBLE_DEVICES"] = "0"
@@ -75,15 +74,12 @@ def device_warmup(device:str):
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
-    parser.add_argument('--model_type', type=str, default='LLaMa', help='type of the model')
     parser.add_argument('--model_path', type=str, default='/data/llm/checkpoints/vicuna-hf/vicuna-7b', help='path to the model')
+    parser.add_argument('--quant_file', type=str, default='awq_model_w4_g128.pt', help='path to the model file')
     parser.add_argument('--precision' , type=str, default='W4A16', help='compute precision')
     parser.add_argument('--device' , type=str, default='cuda')
-    parser.add_argument('--q_group_size', type=int, default=128)
-    parser.add_argument('--load_quant', type=str, default='/data/llm/checkpoints/vicuna-hf/vicuna-7b-awq-w4g128.pt', help='path to the pre-quanted 4-bit weights')
     args = parser.parse_args()
-    assert args.model_type.lower() in ["llama", "falcon", "mpt"], "We only support llama & falcon & mpt now"
     assert args.precision in ["W4A16", "W16A16"], "We only support W4A16/W16A16 now"
     gen_params.n_predict = 512
@@ -107,30 +103,28 @@ if __name__ == '__main__':
     model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
     if args.precision == "W4A16":
-        if args.model_type.lower() == "llama":
-            model = load_awq_llama_fast(model, args.load_quant, 4, args.q_group_size, args.device)
-        else:
-            model = load_awq_model(model, args.load_quant, 4, args.q_group_size, args.device)
+        model = AutoAWQForCausalLM.from_quantized(args.model_path, args.quant_file)
+        assert model.model_type.lower() in ["llama", "refinedweb", "refinedwebmodel", "mpt"], "We only support llama & falcon & mpt now"
     else:
         model = AutoModelForCausalLM.from_pretrained(args.model_path, config=config, torch_dtype=torch.float16, trust_remote_code=True).to(args.device)
     # device warm up
     device_warmup(args.device)
-    if args.model_type.lower() == 'falcon':
+    if isinstance(model, FalconAWQForCausalLM):
         stream_generator = FalconStreamGenerator
     else:
         stream_generator = StreamGenerator
     # Optimize AWQ quantized model
-    if args.precision == "W4A16" and args.model_type.lower() == 'llama':
+    if args.precision == "W4A16" and isinstance(model, LlamaAWQForCausalLM):
         from tinychat.modules import make_quant_norm, make_quant_attn, make_fused_mlp
-        make_quant_attn(model, args.device)
-        make_quant_norm(model)
-        make_fused_mlp(model)
+        make_quant_attn(model.model, args.device)
+        make_quant_norm(model.model)
+        make_fused_mlp(model.model)
-    model_prompter = get_prompter(args.model_type, args.model_path)
-    stop_token_ids = get_stop_token_ids(args.model_type, args.model_path)
+    model_prompter = get_prompter(model, args.model_path)
+    stop_token_ids = get_stop_token_ids(model, args.model_path)
     count = 0
     while True:
         # Get input from the user
...
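For reference, a minimal sketch of the loading flow this hunk switches to, written outside the diff. It is not part of the commit: it reuses the import style shown above (`from awq.models import *`, `from awq.models.auto import AutoAWQForCausalLM`), the argparse defaults as placeholder paths, and the `model.model_type` / `model.model` attributes that appear in the new lines; anything beyond that is an assumption.

# Sketch only: mirrors the new W4A16 path with the CLI defaults inlined.
from awq.models import *                        # exposes LlamaAWQForCausalLM etc., as in the diff
from awq.models.auto import AutoAWQForCausalLM

model_path = "/data/llm/checkpoints/vicuna-hf/vicuna-7b"   # --model_path default
quant_file = "awq_model_w4_g128.pt"                        # --quant_file default

# Load pre-quantized 4-bit weights through the new auto class instead of the
# removed load_awq_model / load_awq_llama_fast helpers.
model = AutoAWQForCausalLM.from_quantized(model_path, quant_file)

# Dispatch now keys on the wrapper class rather than a --model_type string;
# the wrapped transformers module is reachable as model.model.
if isinstance(model, (LlamaAWQForCausalLM, FalconAWQForCausalLM, MptAWQForCausalLM)):
    print("loaded", model.model_type)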
import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM
from accelerate import init_empty_weights, load_checkpoint_and_dispatch
from awq.quantize.quantizer import real_quantize_model_weight
from awq.quantize.qmodule import WQLinear
from tqdm import tqdm

def load_awq_model(model, checkpoint, w_bit, group_size, device):
    q_config = {"zero_point": True, "q_group_size": group_size}
    real_quantize_model_weight(model, w_bit, q_config, init_only = True)
    pbar = tqdm(range(1))
    pbar.set_description('Loading checkpoint')
    for i in pbar:
        if hasattr(model.config, "tie_encoder_decoder"):
            model.config.tie_encoder_decoder = False
        if hasattr(model.config, "tie_word_embeddings"):
            model.config.tie_word_embeddings = False
        model = load_checkpoint_and_dispatch(
            model, checkpoint,
            no_split_module_classes=[
                "OPTDecoderLayer", "LlamaDecoderLayer", "BloomBlock", "MPTBlock", "DecoderLayer"]
        ).to(device)
    return model

def make_quant_linear(module, names, w_bit, groupsize, device, name=''):
    if isinstance(module, WQLinear):
        return
    for attr in dir(module):
        tmp = getattr(module, attr)
        name1 = name + '.' + attr if name != '' else attr
        if name1 in names:
            delattr(module, attr)
            setattr(module, attr, WQLinear(w_bit, groupsize, tmp.in_features, tmp.out_features, tmp.bias is not None, device))
    for name1, child in module.named_children():
        make_quant_linear(child, names, w_bit, groupsize, device, name + '.' + name1 if name != '' else name1)

def find_layers(module, layers=[nn.Linear], name=''):
    if type(module) in layers:
        return {name: module}
    res = {}
    for name1, child in module.named_children():
        res.update(find_layers(child, layers=layers, name=name + '.' + name1 if name != '' else name1))
    return res

def load_awq_llama_fast(model, checkpoint, w_bit, group_size, device):
    layers = find_layers(model)
    for name in ['lm_head']:
        if name in layers:
            del layers[name]
    make_quant_linear(model, layers, w_bit, group_size, device)
    del layers
    pbar = tqdm(range(1))
    pbar.set_description('Loading checkpoint')
    for i in pbar:
        if checkpoint.endswith('.safetensors'):
            from safetensors.torch import load_file as safe_load
            model.load_state_dict(safe_load(checkpoint))
        else:
            model.load_state_dict(torch.load(checkpoint))
    return model.to(device)
\ No newline at end of file
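A side note on the helpers in this file (not part of the commit): `find_layers` walks a module tree and returns a dict mapping dotted submodule names to the matching layers, which `load_awq_llama_fast` then passes to `make_quant_linear`. A toy illustration follows, assuming `find_layers` is in scope; the module and sizes are made up for the example.

import torch.nn as nn

# Toy module; '0' and '2' are nn.Sequential's auto-generated child names.
toy = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 16))

linears = find_layers(toy)      # {'0': Linear(16, 32), '2': Linear(32, 16)}
linears.pop('lm_head', None)    # load_awq_llama_fast removes the output head the same way
print(list(linears.keys()))     # ['0', '2']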
 from typing import List
+from awq.models import *

 class BasePrompter:
     def __init__(self, system_inst, role1, role2, sen_spliter = "\n", qa_spliter = "\n", decorator: List[str] = None):
@@ -125,32 +126,31 @@ class MPTChatPrompter(BasePrompter):
-def get_prompter(model_type, model_path = ""):
-    if model_type.lower() == "llama":
+def get_prompter(model, model_path = ""):
+    if isinstance(model, LlamaAWQForCausalLM):
         if "vicuna" in model_path:
             return VicunaPrompter()
         else:
             return Llama2Prompter()
-    elif model_type.lower() == "falcon":
-        # return FalconPrompter()
+    elif isinstance(model, FalconAWQForCausalLM):
         return FalconSimplePrompter()
-    elif model_type.lower() == "mpt":
+    elif isinstance(model, MptAWQForCausalLM):
         if "mpt" and "chat" in model_path:
             return MPTChatPrompter()
         else:
             return MPTPrompter()
     else:
-        raise ValueError(f"model type {model_type} is not supported")
+        raise ValueError(f"model type {model.model_type} is not supported")

-def get_stop_token_ids(model_type, model_path = ""):
-    if model_type.lower() == "llama":
+def get_stop_token_ids(model, model_path = ""):
+    if isinstance(model, LlamaAWQForCausalLM):
         return []
-    elif model_type.lower() == "falcon":
+    elif isinstance(model, FalconAWQForCausalLM):
         return [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
-    elif model_type.lower() == "mpt":
+    elif isinstance(model, MptAWQForCausalLM):
         if "mpt" and "chat" in model_path:
             return [50278, 0]
         else:
             return []
     else:
-        raise ValueError(f"model type {model_type} is not supported")
+        raise ValueError(f"model type {model.model_type} is not supported")