Commit 93799c9d authored by Casper's avatar Casper
Browse files

Remove tinychat

parent 6924d106
# TinyChat: Efficient and Lightweight Chatbot with AWQ
We introduce TinyChat, a cutting-edge chatbot interface designed for lightweight resource consumption and fast inference speed on GPU platforms. It allows for seamless deployment on consumer-level GPUs such as 3090/4090 and low-power edge devices like the NVIDIA Jetson Orin, empowering users with a responsive conversational experience like never before.
The current release supports:
- LLaMA-2-7B/13B-chat;
- Vicuna;
- MPT-chat;
- Falcon-instruct.
## Contents
- [Examples](#examples)
- [Benchmarks](#benchmarks)
- [Usage](#usage)
- [Reference](#reference)
## Examples
Thanks to AWQ, TinyChat can now deliver more prompt responses through 4-bit inference. The following examples showcase that TinyChat's W4A16 generation is 2.3x faster on RTX 4090 and 1.4x faster on Jetson Orin, compared to the FP16 baselines. (Tested with [LLaMA-2-7b]( https://huggingface.co/meta-llama/Llama-2-7b-chat-hf ) model.)
* TinyChat on RTX 4090:
![TinyChat on RTX 4090: W4A16 is 2.3x faster than FP16](./figures/4090_example.gif)
* TinyChat on Jetson Orin:
![TinyChat on Jetson Orin: W4A16 is 1.4x faster than FP16](./figures/orin_example.gif)
## Benchmarks
We benchmark TinyChat on A6000 (server-class GPU), 4090 (desktop GPU) and Orin (edge GPU).
We use the default implementation from Huggingface for the FP16 baseline. The INT4 implementation applies AWQ and utilizes our fast W4A16 GPU kernel. Please notice that the end-to-end runtime for INT4 TinyChat could be further improved if we reduce the framework overhead from Huggingface (e.g. utilizing the implementation from TGI). We are working on a new release with even faster inference performance, please stay tuned!
The latency reported in all tables are per-token latency for the generation stage.
### A6000 Results
| Model | FP16 latency (ms) | INT4 latency (ms) | Speedup |
| ----------- |:-----------------:|:-----------------:|:-------:|
| LLaMA-2-7B | 27.14 | 12.44 | 2.18x |
| LLaMA-2-13B | 47.28 | 20.28 | 2.33x |
| Vicuna-7B | 26.06 | 12.43 | 2.10x |
| Vicuna-13B | 44.91 | 17.30 | 2.60x |
| MPT-7B | 22.79 | 16.87 | 1.35x |
| MPT-30B | OOM | 31.57 | -- |
| Falcon-7B | 39.44 | 27.34 | 1.44x |
### 4090 Results
| Model | FP16 latency (ms) | INT4 latency (ms) | Speedup |
| ----------- |:-----------------:|:-----------------:|:-------:|
| LLaMA-2-7B | 19.97 | 8.66 | 2.31x |
| LLaMA-2-13B | OOM | 13.54 | -- |
| Vicuna-7B | 19.09 | 8.61 | 2.22x |
| Vicuna-13B | OOM | 12.17 | -- |
| MPT-7B | 17.09 | 12.58 | 1.36x |
| MPT-30B | OOM | 23.54 | -- |
| Falcon-7B | 29.91 | 19.84 | 1.51x |
### Orin Results
| Model | FP16 latency (ms) | INT4 latency (ms) | Speedup |
| ----------- |:-----------------:|:-----------------:|:-------:|
| LLaMA-2-7B | 104.71 | 75.11 | 1.39x |
| LLaMA-2-13B | OOM | 136.81 | -- |
| Vicuna-7B | 93.12 | 65.34 | 1.43x |
| Vicuna-13B | OOM | 115.4 | -- |
| MPT-7B | 89.85 | 67.36 | 1.33x |
| Falcon-7B | 147.84 | 102.74 | 1.44x |
## Usage
1. Please follow the [AWQ installation guidance](https://github.com/mit-han-lab/llm-awq#readme) to install AWQ and its dependencies.
2. Download the pretrained instruction-tuned LLMs:
- For LLaMA-2-chat, please refer to [this link](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf);
- For Vicuna, please refer to [this link](https://huggingface.co/lmsys/);
- For MPT-chat, please refer to [this link](https://huggingface.co/mosaicml/mpt-7b-chat);
- For Falcon-instruct, please refer to [this link](https://huggingface.co/tiiuae/falcon-7b-instruct).
3. Quantize instruction-tuned LLMs with AWQ (see [usage in README](../README.md#usage)).
4. Run the TinyChat demo:
Here, we use Vicuna as an example and assume that you have already quantized the model.
```bash
cd tinychat
python demo.py --model_path vicuna-7b-v1.5-awq
```
You may also run the following command to execute the chatbot in FP16 to compare the speed and quality of language generation:
```bash
python demo.py --model_path lmsys/vicuna-7b-v1.5 --precision W16A16
```
## Reference
TinyChat is inspired by the following open-source projects: [FasterTransformer](https://github.com/NVIDIA/FasterTransformer), [vLLM](https://github.com/vllm-project/vllm), [FastChat](https://github.com/lm-sys/FastChat).
import torch
import argparse
import numpy as np
from awq.models import *
from awq.models.auto import AutoAWQForCausalLM
from attributedict.collections import AttributeDict
from tinychat.utils.prompt_templates import get_prompter, get_stop_token_ids
from tinychat.stream_generators import StreamGenerator, FalconStreamGenerator
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, modeling_utils
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# opt_params in TinyLLMEngine
gen_params = AttributeDict([
("seed", -1), # RNG seed
("n_threads", 1), # TODO: fix this
("n_predict", 512), # new tokens to predict
("n_parts", -1), # amount of model parts (-1: determine from model dimensions)
("n_ctx", 512), # context size
("n_batch", 512), # batch size for prompt processing (must be >=32 to use BLAS)
("n_keep", 0), # number of tokens to keep from initial prompt
("n_vocab", 50272), # vocabulary size
# sampling parameters
("logit_bias", dict()), # logit bias for specific tokens: <int, float>
("top_k", 40), # <= 0 to use vocab size
("top_p", 0.95), # 1.0 = disabled
("tfs_z", 1.00), # 1.0 = disabled
("typical_p", 1.00), # 1.0 = disabled
("temp", 0.70), # 1.0 = disabled
("repeat_penalty", 1.10), # 1.0 = disabled
("repeat_last_n", 64), # last n tokens to penalize (0 = disable penalty, -1 = context size)
("frequency_penalty", 0.00),# 0.0 = disabled
("presence_penalty", 0.00), # 0.0 = disabled
("mirostat", 0), # 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
("mirostat_tau", 5.00), # target entropy
("mirostat_eta", 0.10), # learning rate
])
def stream_output(output_stream):
print(f"ASSISTANT: ", end="", flush=True)
pre = 0
for outputs in output_stream:
output_text = outputs["text"]
output_text = output_text.strip().split(" ")
now = len(output_text) - 1
if now > pre:
print(" ".join(output_text[pre:now]), end=" ", flush=True)
pre = now
print(" ".join(output_text[pre:]), flush=True)
if "timing" in outputs and outputs["timing"] is not None:
timing = outputs["timing"]
context_tokens = timing["context_tokens"]
context_time = timing["context_time"]
total_tokens = timing["total_tokens"]
generation_time_list = timing["generation_time_list"]
generation_tokens = len(generation_time_list)
average_speed = (context_time + np.sum(generation_time_list)) / (context_tokens + generation_tokens)
print("=" * 50)
print("Speed of Inference")
print("-" * 50)
# print(f"Context Stage : {context_time/context_tokens * 1000:.2f} ms/token")
print(f"Generation Stage : {np.average(generation_time_list) * 1000:.2f} ms/token")
# print(f"Average Speed : {average_speed * 1000:.2f} ms/token")
print("=" * 50)
# print("token num:", total_tokens)
# print("Model total Time = ", (context_time + np.sum(generation_time_list))*1000, "ms" )
return " ".join(output_text)
def device_warmup(device:str):
warm_up = torch.randn((4096,4096)).to(device)
torch.mm(warm_up,warm_up)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--model_path', type=str, default='/data/llm/checkpoints/vicuna-hf/vicuna-7b', help='path to the model')
parser.add_argument('--quant_file', type=str, default='awq_model_w4_g128.pt', help='path to the model file')
parser.add_argument('--precision' , type=str, default='W4A16', help='compute precision')
parser.add_argument('--device' , type=str, default='cuda')
args = parser.parse_args()
assert args.precision in ["W4A16", "W16A16"], "We only support W4A16/W16A16 now"
gen_params.n_predict = 512
gen_params.n_vocab = 32000
def skip(*args, **kwargs):
pass
torch.nn.init.kaiming_uniform_ = skip
torch.nn.init.kaiming_normal_ = skip
torch.nn.init.uniform_ = skip
torch.nn.init.normal_ = skip
config = AutoConfig.from_pretrained(args.model_path, trust_remote_code=True)
if "mpt" in config.__class__.__name__.lower():
# config.init_device="meta"
tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_name, trust_remote_code=True)
else:
tokenizer = AutoTokenizer.from_pretrained(args.model_path, use_fast=False, trust_remote_code=True)
modeling_utils._init_weights = False
torch.set_default_dtype(torch.half)
model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
if args.precision == "W4A16":
model = AutoAWQForCausalLM.from_quantized(args.model_path, args.quant_file)
assert model.model_type.lower() in ["llama", "refinedweb", "refinedwebmodel", "mpt"], "We only support llama & falcon & mpt now"
else:
model = AutoModelForCausalLM.from_pretrained(args.model_path, config=config, torch_dtype=torch.float16, trust_remote_code=True).to(args.device)
# device warm up
device_warmup(args.device)
if isinstance(model, FalconAWQForCausalLM):
stream_generator = FalconStreamGenerator
else:
stream_generator = StreamGenerator
model_prompter = get_prompter(model, args.model_path)
stop_token_ids = get_stop_token_ids(model, args.model_path)
count = 0
while True:
# Get input from the user
input_prompt = input("USER: ")
if input_prompt == "":
print("EXIT...")
break
model_prompter.insert_prompt(input_prompt)
output_stream = stream_generator(model, tokenizer, model_prompter.model_input, gen_params, device=args.device, stop_token_ids = stop_token_ids)
outputs = stream_output(output_stream)
model_prompter.update_template(outputs)
count += 1
from .falcon_stream_gen import *
from .stream_gen import *
\ No newline at end of file
import gc
from threading import Thread
from typing import Iterable
import torch
import transformers
from transformers import TextIteratorStreamer, GenerationConfig
transformers.logging.set_verbosity_error()
def is_partial_stop(output: str, stop_str: str):
"""Check whether the output contains a partial stop str."""
for i in range(0, min(len(output), len(stop_str))):
if stop_str.startswith(output[-i:]):
return True
return False
@torch.inference_mode()
def FalconStreamGenerator(
model,
tokenizer,
input : str,
gen_params : dict,
device: str = "cuda:0",
context_len = 2048,
stream_interval = 2,
judge_sent_end = False,
echo: bool = False,
stop_str: str = "\nUser",
stop_token_ids = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
):
prompt = input
len_prompt = len(prompt)
inputs = tokenizer(prompt, return_tensors="pt").to(device)
input_ids = inputs["input_ids"]
attention_mask = inputs["attention_mask"]
if gen_params.top_k <= 0:
top_k = gen_params.n_vocab
else:
top_k = gen_params.top_k
max_new_tokens = gen_params.n_predict
max_src_len = context_len - max_new_tokens - 8
input_ids = input_ids[-max_src_len:] # truncate the input prompt
attention_mask = attention_mask[-max_src_len:] # truncate the input prompt
input_echo_len = len(input_ids)
stop_token_ids.append(tokenizer.eos_token_id)
decode_config = dict(skip_special_tokens=True, clean_up_tokenization_spaces=True)
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, **decode_config)
generation_config = GenerationConfig(
max_new_tokens = max_new_tokens,
do_sample = gen_params.temp >= 1e-5,
temperature = gen_params.temp,
repetition_penalty = gen_params.repeat_penalty,
no_repeat_ngram_size = 10,
top_p = gen_params.top_p,
top_k = top_k,
eos_token_id = stop_token_ids,
)
generation_kwargs = dict(
inputs=input_ids,
attention_mask=attention_mask,
streamer=streamer,
generation_config=generation_config,
)
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()
if echo:
# means keep the prompt
output = prompt
else:
output = ""
for i, new_text in enumerate(streamer):
output += new_text
if i % stream_interval == 0:
if echo:
rfind_start = len_prompt
else:
rfind_start = 0
partially_stopped = False
if stop_str:
if isinstance(stop_str, str):
pos = output.rfind(stop_str, rfind_start)
if pos != -1:
output = output[:pos]
else:
partially_stopped = is_partial_stop(output, stop_str)
elif isinstance(stop_str, Iterable):
for each_stop in stop_str:
pos = output.rfind(each_stop, rfind_start)
if pos != -1:
output = output[:pos]
break
else:
partially_stopped = is_partial_stop(output, each_stop)
if partially_stopped:
break
else:
raise ValueError("Invalid stop field type.")
# prevent yielding partial stop sequence
if not partially_stopped:
yield {
"text": output,
"usage": {
"prompt_tokens": input_echo_len,
"completion_tokens": i,
"total_tokens": input_echo_len + i,
},
"finish_reason": None,
}
output = output.strip()
# finish stream event, which contains finish reason
if i == max_new_tokens - 1:
finish_reason = "length"
elif partially_stopped:
finish_reason = None
else:
finish_reason = "stop"
yield {
"text": output,
"usage": {
"prompt_tokens": input_echo_len,
"completion_tokens": i,
"total_tokens": input_echo_len + i,
},
"finish_reason": finish_reason,
}
# clean
gc.collect()
torch.cuda.empty_cache()
\ No newline at end of file
import torch
import gc
import time
from transformers.generation.logits_process import (
LogitsProcessorList,
RepetitionPenaltyLogitsProcessor,
TemperatureLogitsWarper,
TopKLogitsWarper,
TopPLogitsWarper,
)
context_tokens = 0
context_time = 0.0
total_tokens = 0
generation_time_list = []
def prepare_logits_processor(
temperature: float, repetition_penalty: float, top_p: float, top_k: int
) -> LogitsProcessorList:
processor_list = LogitsProcessorList()
# TemperatureLogitsWarper doesn't accept 0.0, 1.0 makes it a no-op so we skip two cases.
if temperature >= 1e-5 and temperature != 1.0:
processor_list.append(TemperatureLogitsWarper(temperature))
if repetition_penalty > 1.0:
processor_list.append(RepetitionPenaltyLogitsProcessor(repetition_penalty))
if 1e-8 <= top_p < 1.0:
processor_list.append(TopPLogitsWarper(top_p))
if top_k > 0:
processor_list.append(TopKLogitsWarper(top_k))
return processor_list
def sanitize_tensor(tensor: torch.Tensor):
if tensor.dtype == torch.float16:
replacement_value = 65504
elif tensor.dtype == torch.float32:
replacement_value = 1e20
else:
return tensor
# Replace positive infinity with a large finite number
tensor[tensor == float('inf')] = replacement_value
# Replace negative infinity with a small finite number
tensor[tensor == float('-inf')] = -replacement_value
# Replace NaNs with zero
tensor[torch.isnan(tensor)] = 0.0
return tensor
@torch.inference_mode()
def StreamGenerator(model,
tokenizer,
input : str,
gen_params : dict,
device: str = "cuda:0",
stream_interval: int = 2,
echo: bool = False,
stop_token_ids = []
):
input_ids = tokenizer(input).input_ids
input_echo_len = len(input_ids)
# print(input_ids)
output_ids = list(input_ids)
len_input = len(input)
if gen_params.top_k <= 0:
top_k = gen_params.n_vocab
else:
top_k = gen_params.top_k
logits_processor = prepare_logits_processor(
gen_params.temp, gen_params.repeat_penalty, gen_params.top_p, top_k
)
past_key_values = out = None
stop_token_ids.append(tokenizer.eos_token_id)
max_new_tokens = gen_params.n_predict
for i in range(max_new_tokens):
torch.cuda.synchronize()
t_st = time.time()
if i == 0: # Context Stage
out = model(torch.as_tensor([input_ids], device=device), use_cache=True)
logits = out.logits
past_key_values = out.past_key_values
else:
out = model(
input_ids=torch.as_tensor([[token]], device=device),
use_cache=True,
past_key_values=past_key_values,
)
logits = out.logits
past_key_values = out.past_key_values
# Processing the logits
if logits_processor:
if gen_params.repeat_penalty > 1.0:
tmp_output_ids = torch.as_tensor([output_ids], device=logits.device)
else:
tmp_output_ids = None
last_token_logits = logits_processor(tmp_output_ids, logits[:, -1, :])[0]
last_token_logits = sanitize_tensor(last_token_logits)
else:
last_token_logits = logits[0, -1, :]
if gen_params.temp < 1e-5 or gen_params.top_p < 1e-8: # greedy
token = int(torch.argmax(last_token_logits))
else:
probs = torch.softmax(last_token_logits, dim=-1)
token = int(torch.multinomial(probs, num_samples=1))
output_ids.append(token)
torch.cuda.synchronize()
t_ed = time.time()
global context_time
global context_tokens
global total_tokens
global generation_time_list
if i == 0:
context_time = t_ed - t_st
context_tokens = logits.shape[1]
generation_time_list = []
else:
generation_time_list.append(t_ed-t_st)
if token in stop_token_ids:
stopped = True
else:
stopped = False
if i % stream_interval == 0 or i == max_new_tokens - 1 or stopped:
if echo:
tmp_output_ids = output_ids
rfind_start = len_input
else:
tmp_output_ids = output_ids[input_echo_len:]
rfind_start = 0
output = tokenizer.decode(
tmp_output_ids,
skip_special_tokens=True,
spaces_between_special_tokens=False,
)
partially_stopped = False
# prevent yielding partial stop sequence
if not partially_stopped:
yield {
"text": output,
"usage": {
"prompt_tokens": input_echo_len,
"completion_tokens": i,
"total_tokens": input_echo_len + i,
},
"finish_reason": None,
"timing": None,
}
if stopped:
break
# finish stream event, which contains finish reason
if i == max_new_tokens - 1:
finish_reason = "length"
elif stopped:
finish_reason = "stop"
else:
finish_reason = None
total_tokens = (context_tokens + len(generation_time_list))
yield {
"text": output,
"usage": {
"prompt_tokens": input_echo_len,
"completion_tokens": i,
"total_tokens": input_echo_len + i,
},
"finish_reason": finish_reason,
"timing":{
"context_tokens": context_tokens,
"context_time": context_time,
"total_tokens": total_tokens,
"generation_time_list": generation_time_list,
}
}
del past_key_values, out
gc.collect()
torch.cuda.empty_cache()
# return context_tokens, context_time, total_tokens, generation_time_list
\ No newline at end of file
from typing import List
from awq.models import *
from transformers.models.llama.modeling_llama import LlamaForCausalLM
from transformers.models.falcon.modeling_falcon import FalconForCausalLM
class BasePrompter:
def __init__(self, system_inst, role1, role2, sen_spliter = "\n", qa_spliter = "\n", decorator: List[str] = None):
self.system_inst = system_inst # System Instruction
self.role1 = role1 # The name of USER
self.role2 = role2 # The name of AI-Assistant
self.sen_spliter = sen_spliter # How to split system/user/assistant outputs
self.qa_spliter = qa_spliter # How to split Q&A rounds
self.decorator = decorator
if self.decorator == None:
self.starter = ""
self.stopper = ""
else:
self.starter = self.decorator[0]
self.stopper = self.decorator[1]
if self.system_inst == None:
self.template = self.starter + self.role1 + ": {prompt}" + self.stopper + self.sen_spliter \
+ self.starter + self.role2 + ":"
else:
self.template = self.starter + self.system_inst + self.stopper + self.sen_spliter \
+ self.starter + self.role1 + ": {prompt}" + self.stopper + self.sen_spliter \
+ self.starter + self.role2 + ":"
self.model_input = None
def insert_prompt(self, input_prompt):
self.model_input = self.template.format(prompt=input_prompt)
def update_template(self, outputs):
self.template = self.model_input + " " + outputs.strip() + self.stopper + self.qa_spliter \
+ self.starter + self.role1 + ": {prompt}" + self.stopper + self.sen_spliter \
+ self.starter + self.role2 + ":"
self.model_input = None
class OneShotBasePrompter(BasePrompter):
def __init__(self,
oneshot_example: List[str], # User prompt + Assistant responce
system_inst, role1, role2, sen_spliter = "\n", qa_spliter = "\n", decorator: List[str] = None):
super().__init__(system_inst, role1, role2, sen_spliter, qa_spliter)
assert len(oneshot_example) == 2, "One-shot example must be a List of 2 strs."
self.user_example = oneshot_example[0]
self.assistant_example = oneshot_example[1]
self.insert_prompt(self.user_example)
self.update_template(self.assistant_example)
class VicunaPrompter(BasePrompter):
def __init__(self):
system_inst = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions."
role1 = "USER"
role2 = "ASSISTANT"
sen_spliter = " "
qa_spliter = "</s>"
super().__init__(system_inst, role1, role2, sen_spliter, qa_spliter)
class Llama2Prompter(OneShotBasePrompter):
def __init__(self):
system_inst = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions."
role1 = "### Human"
role2 = "### Assistant"
sen_spliter = "\n"
qa_spliter = "</s>"
user_example="Got any creative ideas for a 10 year old's birthday?"
assistant_example = "Of course! Here are some creative ideas for a 10-year-old's birthday party:\n" \
+ "1. Treasure Hunt: Organize a treasure hunt in your backyard or nearby park. Create clues and riddles for the kids to solve, leading them to hidden treasures and surprises.\n" \
+ "2. Science Party: Plan a science-themed party where kids can engage in fun and interactive experiments. You can set up different stations with activities like making slime, erupting volcanoes, or creating simple chemical reactions.\n" \
+ "3. Outdoor Movie Night: Set up a backyard movie night with a projector and a large screen or white sheet. Create a cozy seating area with blankets and pillows, and serve popcorn and snacks while the kids enjoy a favorite movie under the stars.\n" \
+ "4. DIY Crafts Party: Arrange a craft party where kids can unleash their creativity. Provide a variety of craft supplies like beads, paints, and fabrics, and let them create their own unique masterpieces to take home as party favors.\n" \
+ "5. Sports Olympics: Host a mini Olympics event with various sports and games. Set up different stations for activities like sack races, relay races, basketball shooting, and obstacle courses. Give out medals or certificates to the participants.\n" \
+ "6. Cooking Party: Have a cooking-themed party where the kids can prepare their own mini pizzas, cupcakes, or cookies. Provide toppings, frosting, and decorating supplies, and let them get hands-on in the kitchen.\n" \
+ "7. Superhero Training Camp: Create a superhero-themed party where the kids can engage in fun training activities. Set up an obstacle course, have them design their own superhero capes or masks, and organize superhero-themed games and challenges.\n" \
+ "8. Outdoor Adventure: Plan an outdoor adventure party at a local park or nature reserve. Arrange activities like hiking, nature scavenger hunts, or a picnic with games. Encourage exploration and appreciation for the outdoors.\n" \
+ "Remember to tailor the activities to the birthday child's interests and preferences. Have a great celebration!"
oneshot_example = [user_example, assistant_example]
super().__init__(oneshot_example, system_inst, role1, role2, sen_spliter, qa_spliter)
class FalconSimplePrompter(BasePrompter):
def __init__(self):
system_inst = None
role1 = "User"
role2 = "Assistant"
sen_spliter = "\n\n"
qa_spliter = "\n\n"
super().__init__(system_inst, role1, role2, sen_spliter, qa_spliter)
class FalconPrompter(BasePrompter):
def __init__(self):
system_inst = "The following is a conversation between a highly knowledgeable and intelligent AI assistant, called Falcon, "\
+ "and a human user, called User. In the following interactions, User and Falcon will converse in natural language, "\
+ "and Falcon will answer User's questions. Falcon was built to be respectful, polite and inclusive. "\
+ "Falcon was built by the Technology Innovation Institute in Abu Dhabi. "\
+ "Falcon will never decline to answer a question, and always attempts to give an answer that User would be satisfied with. "\
+ "It knows a lot, and always tells the truth. The conversation begins."
role1 = "User"
role2 = "Assistant"
sen_spliter = "\n"
qa_spliter = "\n"
super().__init__(system_inst, role1, role2, sen_spliter, qa_spliter)
class MPTPrompter(BasePrompter):
def __init__(self):
system_inst = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions."
role1 = "### Human"
role2 = "### Assistant"
sen_spliter = "\n"
qa_spliter = "\n"
super().__init__(system_inst, role1, role2, sen_spliter, qa_spliter)
class MPTChatPrompter(BasePrompter):
def __init__(self):
system_inst = "system\n" \
+ "- You are a helpful assistant chatbot trained by MosaicML.\n" \
+ "- You answer questions.\n" \
+ "- You are excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.\n" \
+ "- You are more than just an information source, you are also able to write poetry, short stories, and make jokes."
role1 = "user"
role2 = "assistant"
sen_spliter = "\n"
qa_spliter = "\n"
decorator = ["<|im_start|>", "<|im_end|>"]
super().__init__(system_inst, role1, role2, sen_spliter, qa_spliter, decorator)
def get_prompter(model, model_path = ""):
if isinstance(model, LlamaAWQForCausalLM) or isinstance(model, LlamaForCausalLM):
if "vicuna" in model_path:
return VicunaPrompter()
else:
return Llama2Prompter()
elif isinstance(model, FalconAWQForCausalLM) or isinstance(model, FalconForCausalLM):
return FalconSimplePrompter()
elif isinstance(model, MptAWQForCausalLM) or "mpt" in str(model.__class__).lower():
if "mpt" and "chat" in model_path:
return MPTChatPrompter()
else:
return MPTPrompter()
else:
raise ValueError(f"model type {model.model_type} is not supported")
def get_stop_token_ids(model, model_path = ""):
if isinstance(model, LlamaAWQForCausalLM) or isinstance(model, LlamaForCausalLM):
return []
elif isinstance(model, FalconAWQForCausalLM) or isinstance(model, FalconForCausalLM):
return [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
elif isinstance(model, MptAWQForCausalLM) or "mpt" in str(model.__class__).lower():
if "mpt" and "chat" in model_path:
return [50278, 0]
else:
return []
else:
model_type = str(model.__class__).lower()
raise ValueError(f"model type {model_type} is not supported")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment