Commit 4f3e977c authored by Haotian Tang

[Major] Add TinyChat and demo.

parent 79048993
import math
import torch
import torch.nn as nn
from torch.nn import functional as F
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaRotaryEmbedding, apply_rotary_pos_emb
from awq.quantize.qmodule import WQLinear
import awq_inference_engine
class QuantLlamaRotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__()
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
self.register_buffer("inv_freq", inv_freq)
# Build here to make `torch.jit.trace` work.
self._set_cos_sin_cache(
seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
)
def _set_cos_sin_cache(self, seq_len, device, dtype):
self.max_seq_len_cached = seq_len
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Unlike the Hugging Face implementation, we do not duplicate the frequencies;
        # the fused CUDA kernel expects the cache laid out as [cos | sin] along the last
        # dimension, one row per position, stored in half precision.
        cos = freqs.cos()
        sin = freqs.sin()
        cache = torch.cat((cos, sin), dim=-1)
        self.register_buffer("cos_sin_cache", cache.half(), persistent=False)
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
positions: torch.Tensor,
):
# Apply rotary embedding to the query and key before passing them
# to the attention op.
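        # As called from QuantLlamaAttention below, `query`/`key` are views into the fused
        # QKV projection output and `positions` holds the integer position ids; the CUDA
        # kernel rotates query and key in place using `cos_sin_cache`.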
# print(positions.shape, query.shape, key.shape, self.cos_sin_cache.shape)
query = query.contiguous()
key = key.contiguous()
awq_inference_engine.rotary_embedding_neox(
positions,
query,
key,
self.dim,
self.cos_sin_cache,
)
return query, key
class QuantLlamaAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(
self,
hidden_size,
num_heads,
qkv_proj,
o_proj,
dev
):
super().__init__()
self.hidden_size = hidden_size
self.num_heads = num_heads
self.head_dim = hidden_size // num_heads
if (self.head_dim * num_heads) != self.hidden_size:
raise ValueError(f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
f" and `num_heads`: {num_heads}).")
self.qkv_proj = qkv_proj
self.o_proj = o_proj
self.rotary_emb = QuantLlamaRotaryEmbedding(self.head_dim, max_position_embeddings=2048, device = dev)
def forward(self, hidden_states, past_key_value=None, attention_mask=None, position_ids=None, output_attentions=False, use_cache=False):
"""Input shape: Batch x Time x Channel"""
bsz, q_len, _ = hidden_states.size()
qkv_states = self.qkv_proj(hidden_states)
qkv_states = qkv_states.view(bsz, q_len, 3, self.num_heads, self.head_dim)
# This updates the query and key states in-place, saving VRAM.
query_states, key_states, value_states = torch.split(qkv_states, 1, dim=2)
query_states, key_states = self.rotary_emb(query_states, key_states, position_ids)
del qkv_states
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
is_causal = past_key_value is None
kv_seq_len = q_len
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
value_states = value_states.to("cuda:0")
if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
if use_cache:
            # Since qkv_proj is fused, query_states etc. hold references to the original qkv_states tensor,
            # which can cause excessive memory usage by the cache. `contiguous` is a convenient way to work around this.
key_states = key_states.contiguous()
value_states = value_states.contiguous()
query_states = query_states.contiguous()
past_key_value = (key_states, value_states) if use_cache else None
# with torch.backends.cuda.sdp_kernel(enable_math=False):
attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, is_causal=is_causal)
del query_states, key_states, value_states
attn_output = attn_output.transpose(1, 2).reshape(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
return attn_output, None, past_key_value
def make_quant_attn(model, dev):
"""
Replace all LlamaAttention modules with QuantLlamaAttention modules, fusing the q, k, v projections.
"""
for name, m in model.named_modules():
if not isinstance(m, LlamaAttention):
continue
q_proj = m.q_proj
k_proj = m.k_proj
v_proj = m.v_proj
qweights = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1)
qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=1)
scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=1)
# g_idx = torch.cat([q_proj.g_idx, k_proj.g_idx, v_proj.g_idx], dim=0)
g_idx = None
bias = torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0) if q_proj.bias is not None else None
qkv_layer = WQLinear(q_proj.w_bit, q_proj.group_size, q_proj.in_features, q_proj.out_features + k_proj.out_features + v_proj.out_features, q_proj.bias is not None, q_proj.qweight.device)
qkv_layer.qweight = qweights
qkv_layer.qzeros = qzeros
qkv_layer.scales = scales
qkv_layer.bias = bias
        # We drop the original rotary embedding module (m.rotary_emb) here; QuantLlamaAttention applies rotary embeddings with its own fused CUDA kernel.
attn = QuantLlamaAttention(m.hidden_size, m.num_heads, qkv_layer, m.o_proj, dev)
if '.' in name:
parent_name = name.rsplit('.', 1)[0]
child_name = name[len(parent_name) + 1:]
parent = model.get_submodule(parent_name)
else:
parent_name = ''
parent = model
child_name = name
#print(f"Replacing {name} with quant_attn; parent: {parent_name}, child's name: {child_name}")
setattr(parent, child_name, attn)
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import custom_bwd, custom_fwd
from transformers.models.llama.modeling_llama import LlamaMLP
import awq_inference_engine
class QuantLlamaMLP(nn.Module):
def __init__(
self,
gate_proj,
down_proj,
up_proj,
):
super().__init__()
self.register_buffer('gate_proj_qweight', gate_proj.qweight)
self.register_buffer('gate_proj_scales', gate_proj.scales)
self.register_buffer('gate_proj_qzeros', gate_proj.qzeros)
self.register_buffer('up_proj_qweight', up_proj.qweight)
self.register_buffer('up_proj_scales', up_proj.scales)
self.register_buffer('up_proj_qzeros', up_proj.qzeros)
self.in_features = gate_proj.in_features
self.intermediate_size = gate_proj.out_features
self.out_features = down_proj.out_features
self.w_bit = gate_proj.w_bit
self.down_proj = down_proj
def forward(self, x):
return self.down_proj(self.our_llama_mlp(x))
def our_llama_mlp(self, x):
out_shape = x.shape[:-1] + (self.intermediate_size, )
x = x.reshape(-1, x.shape[-1])
gate_output = awq_inference_engine.gemm_forward_cuda(
x, self.gate_proj_qweight, self.gate_proj_scales, self.gate_proj_qzeros, 8
)
gate_output = F.silu(gate_output)
up_output = awq_inference_engine.gemm_forward_cuda(
x, self.up_proj_qweight, self.up_proj_scales, self.up_proj_qzeros, 8
)
c = gate_output * up_output
c = c.reshape(out_shape)
return c
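# For reference, `our_llama_mlp` reproduces the standard LLaMA SwiGLU MLP,
#   down_proj(silu(gate_proj(x)) * up_proj(x)),
# with the gate/up projections dispatched to the fused W4A16 GEMM kernel instead of
# fp16 nn.Linear layers. A minimal fp16 sketch of that reference computation
# (`_reference_llama_mlp` is illustrative only and assumes dense nn.Linear modules,
# not the packed INT4 buffers used above):
def _reference_llama_mlp(x, gate_proj, up_proj, down_proj):
    return down_proj(F.silu(gate_proj(x)) * up_proj(x))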
def make_fused_mlp(m, parent_name=''):
    """
    Replace all LlamaMLP modules with QuantLlamaMLP modules, which fuse several of the operations.
    """
    if not hasattr(make_fused_mlp, "called"):
        # print("[Warning] Calling a partial MLP fusion, but it is still faster than the Hugging Face implementation.")
        make_fused_mlp.called = True
if isinstance(m, LlamaMLP):
return QuantLlamaMLP(m.gate_proj, m.down_proj, m.up_proj)
for name, child in m.named_children():
child = make_fused_mlp(child, parent_name=f"{parent_name}.{name}")
if isinstance(child, QuantLlamaMLP):
setattr(m, name, child)
return m
import torch
from torch import nn
from transformers.models.llama.modeling_llama import LlamaRMSNorm
import awq_inference_engine
class FTLlamaRMSNorm(nn.Module):
def __init__(self, weight, eps=1e-6):
"""
LlamaRMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = weight
self.variance_epsilon = eps
def forward(self, x):
output = torch.empty_like(x)
awq_inference_engine.layernorm_forward_cuda(x, self.weight, output, self.variance_epsilon)
return output
def make_quant_norm(model):
"""
Replace all LlamaRMSNorm modules with FTLlamaRMSNorm modules
"""
for name, m in model.named_modules():
if not isinstance(m, LlamaRMSNorm):
continue
norm = FTLlamaRMSNorm(m.weight, m.variance_epsilon)
if '.' in name:
parent_name = name.rsplit('.', 1)[0]
child_name = name[len(parent_name) + 1:]
parent = model.get_submodule(parent_name)
else:
parent_name = ''
parent = model
child_name = name
#print(f"Replacing {name} with quant_attn; parent: {parent_name}, child's name: {child_name}")
setattr(parent, child_name, norm)
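# Minimal usage sketch for the three fusion passes above (assumes `model` already has
# its Linear layers replaced by WQLinear, e.g. via real_quantize_model_weight or
# load_awq_llama_fast; the import path below is illustrative and may differ):
#
#   from tinychat.modules import make_quant_attn, make_fused_mlp, make_quant_norm
#
#   make_quant_attn(model, "cuda:0")   # fuse q/k/v into one WQLinear + fused rotary attention
#   model = make_fused_mlp(model)      # swap LlamaMLP for QuantLlamaMLP
#   make_quant_norm(model)             # swap LlamaRMSNorm for the fused CUDA RMSNorm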
MODEL_PATH=/data/llm/checkpoints/llama2-hf
MODEL_NAME=llama-2-7b-chat
# # Perform AWQ search and save search results (we already did it for you):
# mkdir awq_cache
# python -m awq.entry --model_path $MODEL_PATH/$MODEL_NAME \
# --w_bit 4 --q_group_size 128 \
# --run_awq --dump_awq awq_cache/llama-2-7b-chat-w4-g128.pt
# Generate real quantized weights (INT4):
mkdir quant_cache
python -m awq.entry --model_path $MODEL_PATH/$MODEL_NAME \
--w_bit 4 --q_group_size 128 \
--load_awq awq_cache/llama-2-7b-chat-w4-g128.pt \
--q_backend real --dump_quant quant_cache/llama-2-7b-chat-w4-g128-awq.pt
# Run the TinyChat demo:
python demo.py --model_type llama \
--model_path $MODEL_PATH/$MODEL_NAME \
--q_group_size 128 --load_quant quant_cache/llama-2-7b-chat-w4-g128-awq.pt \
--precision W4A16
from .falcon_stream_gen import *
from .stream_gen import *
import gc
from threading import Thread
from typing import Iterable
import torch
import transformers
from transformers import TextIteratorStreamer, GenerationConfig
transformers.logging.set_verbosity_error()
def is_partial_stop(output: str, stop_str: str):
"""Check whether the output contains a partial stop str."""
for i in range(0, min(len(output), len(stop_str))):
if stop_str.startswith(output[-i:]):
return True
return False
@torch.inference_mode()
def FalconStreamGenerator(
model,
tokenizer,
input : str,
    gen_params,  # generation parameters object with attributes temp, top_p, top_k, n_predict, n_vocab, repeat_penalty
device: str = "cuda:0",
context_len = 2048,
stream_interval = 2,
judge_sent_end = False,
echo: bool = False,
stop_str: str = "\nUser",
stop_token_ids = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
):
prompt = input
len_prompt = len(prompt)
inputs = tokenizer(prompt, return_tensors="pt").to(device)
input_ids = inputs["input_ids"]
attention_mask = inputs["attention_mask"]
if gen_params.top_k <= 0:
top_k = gen_params.n_vocab
else:
top_k = gen_params.top_k
max_new_tokens = gen_params.n_predict
max_src_len = context_len - max_new_tokens - 8
    input_ids = input_ids[:, -max_src_len:]  # truncate the input prompt along the sequence dimension
    attention_mask = attention_mask[:, -max_src_len:]
    input_echo_len = input_ids.shape[1]
    stop_token_ids = stop_token_ids + [tokenizer.eos_token_id]  # avoid mutating the default list across calls
decode_config = dict(skip_special_tokens=True, clean_up_tokenization_spaces=True)
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, **decode_config)
generation_config = GenerationConfig(
max_new_tokens = max_new_tokens,
do_sample = gen_params.temp >= 1e-5,
temperature = gen_params.temp,
repetition_penalty = gen_params.repeat_penalty,
no_repeat_ngram_size = 10,
top_p = gen_params.top_p,
top_k = top_k,
eos_token_id = stop_token_ids,
)
generation_kwargs = dict(
inputs=input_ids,
attention_mask=attention_mask,
streamer=streamer,
generation_config=generation_config,
)
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()
if echo:
# means keep the prompt
output = prompt
else:
output = ""
for i, new_text in enumerate(streamer):
output += new_text
if i % stream_interval == 0:
if echo:
rfind_start = len_prompt
else:
rfind_start = 0
partially_stopped = False
if stop_str:
if isinstance(stop_str, str):
pos = output.rfind(stop_str, rfind_start)
if pos != -1:
output = output[:pos]
else:
partially_stopped = is_partial_stop(output, stop_str)
elif isinstance(stop_str, Iterable):
for each_stop in stop_str:
pos = output.rfind(each_stop, rfind_start)
if pos != -1:
output = output[:pos]
break
else:
partially_stopped = is_partial_stop(output, each_stop)
if partially_stopped:
break
else:
raise ValueError("Invalid stop field type.")
# prevent yielding partial stop sequence
if not partially_stopped:
yield {
"text": output,
"usage": {
"prompt_tokens": input_echo_len,
"completion_tokens": i,
"total_tokens": input_echo_len + i,
},
"finish_reason": None,
}
output = output.strip()
# finish stream event, which contains finish reason
if i == max_new_tokens - 1:
finish_reason = "length"
elif partially_stopped:
finish_reason = None
else:
finish_reason = "stop"
yield {
"text": output,
"usage": {
"prompt_tokens": input_echo_len,
"completion_tokens": i,
"total_tokens": input_echo_len + i,
},
"finish_reason": finish_reason,
}
# clean
gc.collect()
torch.cuda.empty_cache()
import torch
import gc
import time
from transformers.generation.logits_process import (
LogitsProcessorList,
RepetitionPenaltyLogitsProcessor,
TemperatureLogitsWarper,
TopKLogitsWarper,
TopPLogitsWarper,
)
context_tokens = 0
context_time = 0.0
total_tokens = 0
generation_time_list = []
def prepare_logits_processor(
temperature: float, repetition_penalty: float, top_p: float, top_k: int
) -> LogitsProcessorList:
processor_list = LogitsProcessorList()
# TemperatureLogitsWarper doesn't accept 0.0, 1.0 makes it a no-op so we skip two cases.
if temperature >= 1e-5 and temperature != 1.0:
processor_list.append(TemperatureLogitsWarper(temperature))
if repetition_penalty > 1.0:
processor_list.append(RepetitionPenaltyLogitsProcessor(repetition_penalty))
if 1e-8 <= top_p < 1.0:
processor_list.append(TopPLogitsWarper(top_p))
if top_k > 0:
processor_list.append(TopKLogitsWarper(top_k))
return processor_list
@torch.inference_mode()
def StreamGenerator(model,
tokenizer,
input : str,
    gen_params,  # generation parameters object with attributes temp, top_p, top_k, n_predict, n_vocab, repeat_penalty
device: str = "cuda:0",
stream_interval: int = 2,
echo: bool = False,
stop_token_ids = []
):
input_ids = tokenizer(input).input_ids
input_echo_len = len(input_ids)
# print(input_ids)
output_ids = list(input_ids)
len_input = len(input)
if gen_params.top_k <= 0:
top_k = gen_params.n_vocab
else:
top_k = gen_params.top_k
logits_processor = prepare_logits_processor(
gen_params.temp, gen_params.repeat_penalty, gen_params.top_p, top_k
)
past_key_values = out = None
    stop_token_ids = stop_token_ids + [tokenizer.eos_token_id]  # avoid mutating the default list across calls
max_new_tokens = gen_params.n_predict
for i in range(max_new_tokens):
torch.cuda.synchronize()
t_st = time.time()
if i == 0: # Context Stage
out = model(torch.as_tensor([input_ids], device=device), use_cache=True)
logits = out.logits
past_key_values = out.past_key_values
else:
out = model(
input_ids=torch.as_tensor([[token]], device=device),
use_cache=True,
past_key_values=past_key_values,
)
logits = out.logits
past_key_values = out.past_key_values
# Processing the logits
if logits_processor:
if gen_params.repeat_penalty > 1.0:
tmp_output_ids = torch.as_tensor([output_ids], device=logits.device)
else:
tmp_output_ids = None
last_token_logits = logits_processor(tmp_output_ids, logits[:, -1, :])[0]
else:
last_token_logits = logits[0, -1, :]
if gen_params.temp < 1e-5 or gen_params.top_p < 1e-8: # greedy
token = int(torch.argmax(last_token_logits))
else:
probs = torch.softmax(last_token_logits, dim=-1)
token = int(torch.multinomial(probs, num_samples=1))
output_ids.append(token)
torch.cuda.synchronize()
t_ed = time.time()
global context_time
global context_tokens
global total_tokens
global generation_time_list
if i == 0:
context_time = t_ed - t_st
context_tokens = logits.shape[1]
generation_time_list = []
else:
generation_time_list.append(t_ed-t_st)
if token in stop_token_ids:
stopped = True
else:
stopped = False
if i % stream_interval == 0 or i == max_new_tokens - 1 or stopped:
if echo:
tmp_output_ids = output_ids
rfind_start = len_input
else:
tmp_output_ids = output_ids[input_echo_len:]
rfind_start = 0
output = tokenizer.decode(
tmp_output_ids,
skip_special_tokens=True,
spaces_between_special_tokens=False,
)
partially_stopped = False
# prevent yielding partial stop sequence
if not partially_stopped:
yield {
"text": output,
"usage": {
"prompt_tokens": input_echo_len,
"completion_tokens": i,
"total_tokens": input_echo_len + i,
},
"finish_reason": None,
"timing": None,
}
if stopped:
break
# finish stream event, which contains finish reason
if i == max_new_tokens - 1:
finish_reason = "length"
elif stopped:
finish_reason = "stop"
else:
finish_reason = None
total_tokens = (context_tokens + len(generation_time_list))
yield {
"text": output,
"usage": {
"prompt_tokens": input_echo_len,
"completion_tokens": i,
"total_tokens": input_echo_len + i,
},
"finish_reason": finish_reason,
"timing":{
"context_tokens": context_tokens,
"context_time": context_time,
"total_tokens": total_tokens,
"generation_time_list": generation_time_list,
}
}
del past_key_values, out
gc.collect()
torch.cuda.empty_cache()
# return context_tokens, context_time, total_tokens, generation_time_list
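# Minimal usage sketch (hypothetical `gen_params`: StreamGenerator only reads the
# attributes shown below, so a SimpleNamespace can stand in for the real config object):
#
#   from types import SimpleNamespace
#
#   gen_params = SimpleNamespace(temp=0.7, top_p=0.9, top_k=40, n_predict=256,
#                                n_vocab=32000, repeat_penalty=1.1)
#   prev = ""
#   for out in StreamGenerator(model, tokenizer, "Tell me a joke.", gen_params,
#                              device="cuda:0", stop_token_ids=[]):
#       print(out["text"][len(prev):], end="", flush=True)
#       prev = out["text"]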
import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM
from accelerate import init_empty_weights, load_checkpoint_and_dispatch
from awq.quantize.quantizer import real_quantize_model_weight
from awq.quantize.qmodule import WQLinear
from tqdm import tqdm
def load_awq_model(model, checkpoint, w_bit, group_size, device):
q_config = {"zero_point": True, "q_group_size": group_size}
real_quantize_model_weight(model, w_bit, q_config, init_only = True)
pbar = tqdm(range(1))
pbar.set_description('Loading checkpoint')
for i in pbar:
if hasattr(model.config, "tie_encoder_decoder"):
model.config.tie_encoder_decoder = False
if hasattr(model.config, "tie_word_embeddings"):
model.config.tie_word_embeddings = False
model = load_checkpoint_and_dispatch(
model, checkpoint,
no_split_module_classes=[
"OPTDecoderLayer", "LlamaDecoderLayer", "BloomBlock", "MPTBlock", "DecoderLayer"]
).to(device)
return model
def make_quant_linear(module, names, w_bit, groupsize, device, name=''):
if isinstance(module, WQLinear):
return
for attr in dir(module):
tmp = getattr(module, attr)
name1 = name + '.' + attr if name != '' else attr
if name1 in names:
delattr(module, attr)
setattr(module, attr, WQLinear(w_bit, groupsize, tmp.in_features, tmp.out_features, tmp.bias is not None, device))
for name1, child in module.named_children():
make_quant_linear(child, names, w_bit, groupsize, device, name + '.' + name1 if name != '' else name1)
def find_layers(module, layers=[nn.Linear], name=''):
if type(module) in layers:
return {name: module}
res = {}
for name1, child in module.named_children():
res.update(find_layers(child, layers=layers, name=name + '.' + name1 if name != '' else name1))
return res
def load_awq_llama_fast(model, checkpoint, w_bit, group_size, device):
layers = find_layers(model)
for name in ['lm_head']:
if name in layers:
del layers[name]
make_quant_linear(model, layers, w_bit, group_size, device)
del layers
pbar = tqdm(range(1))
pbar.set_description('Loading checkpoint')
for i in pbar:
if checkpoint.endswith('.safetensors'):
from safetensors.torch import load_file as safe_load
model.load_state_dict(safe_load(checkpoint))
else:
model.load_state_dict(torch.load(checkpoint))
return model.to(device)
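# Minimal usage sketch (paths and kwargs are illustrative; assumes the INT4 checkpoint
# was produced by `python -m awq.entry ... --q_backend real --dump_quant ...`):
#
#   from transformers import AutoConfig
#
#   config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
#   model = AutoModelForCausalLM.from_config(config, torch_dtype=torch.float16,
#                                            trust_remote_code=True)
#   model.eval()
#   model = load_awq_llama_fast(model, "quant_cache/llama-2-7b-chat-w4-g128-awq.pt",
#                               w_bit=4, group_size=128, device="cuda:0")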
from typing import List
class BasePrompter:
def __init__(self, system_inst, role1, role2, sen_spliter = "\n", qa_spliter = "\n", decorator: List[str] = None):
self.system_inst = system_inst # System Instruction
self.role1 = role1 # The name of USER
self.role2 = role2 # The name of AI-Assistant
self.sen_spliter = sen_spliter # How to split system/user/assistant outputs
self.qa_spliter = qa_spliter # How to split Q&A rounds
self.decorator = decorator
        if self.decorator is None:
self.starter = ""
self.stopper = ""
else:
self.starter = self.decorator[0]
self.stopper = self.decorator[1]
        if self.system_inst is None:
self.template = self.starter + self.role1 + ": {prompt}" + self.stopper + self.sen_spliter \
+ self.starter + self.role2 + ":"
else:
self.template = self.starter + self.system_inst + self.stopper + self.sen_spliter \
+ self.starter + self.role1 + ": {prompt}" + self.stopper + self.sen_spliter \
+ self.starter + self.role2 + ":"
self.model_input = None
def insert_prompt(self, input_prompt):
self.model_input = self.template.format(prompt=input_prompt)
def update_template(self, outputs):
self.template = self.model_input + " " + outputs.strip() + self.stopper + self.qa_spliter \
+ self.starter + self.role1 + ": {prompt}" + self.stopper + self.sen_spliter \
+ self.starter + self.role2 + ":"
self.model_input = None
class OneShotBasePrompter(BasePrompter):
def __init__(self,
                 oneshot_example: List[str],  # User prompt + Assistant response
system_inst, role1, role2, sen_spliter = "\n", qa_spliter = "\n", decorator: List[str] = None):
super().__init__(system_inst, role1, role2, sen_spliter, qa_spliter)
        assert len(oneshot_example) == 2, "One-shot example must be a list of two strings."
self.user_example = oneshot_example[0]
self.assistant_example = oneshot_example[1]
self.insert_prompt(self.user_example)
self.update_template(self.assistant_example)
class VicunaPrompter(BasePrompter):
def __init__(self):
system_inst = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions."
role1 = "USER"
role2 = "ASSISTANT"
sen_spliter = " "
qa_spliter = "</s>"
super().__init__(system_inst, role1, role2, sen_spliter, qa_spliter)
class Llama2Prompter(OneShotBasePrompter):
def __init__(self):
system_inst = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions."
role1 = "### Human"
role2 = "### Assistant"
sen_spliter = "\n"
qa_spliter = "</s>"
user_example="Got any creative ideas for a 10 year old's birthday?"
assistant_example = "Of course! Here are some creative ideas for a 10-year-old's birthday party:\n" \
+ "1. Treasure Hunt: Organize a treasure hunt in your backyard or nearby park. Create clues and riddles for the kids to solve, leading them to hidden treasures and surprises.\n" \
+ "2. Science Party: Plan a science-themed party where kids can engage in fun and interactive experiments. You can set up different stations with activities like making slime, erupting volcanoes, or creating simple chemical reactions.\n" \
+ "3. Outdoor Movie Night: Set up a backyard movie night with a projector and a large screen or white sheet. Create a cozy seating area with blankets and pillows, and serve popcorn and snacks while the kids enjoy a favorite movie under the stars.\n" \
+ "4. DIY Crafts Party: Arrange a craft party where kids can unleash their creativity. Provide a variety of craft supplies like beads, paints, and fabrics, and let them create their own unique masterpieces to take home as party favors.\n" \
+ "5. Sports Olympics: Host a mini Olympics event with various sports and games. Set up different stations for activities like sack races, relay races, basketball shooting, and obstacle courses. Give out medals or certificates to the participants.\n" \
+ "6. Cooking Party: Have a cooking-themed party where the kids can prepare their own mini pizzas, cupcakes, or cookies. Provide toppings, frosting, and decorating supplies, and let them get hands-on in the kitchen.\n" \
+ "7. Superhero Training Camp: Create a superhero-themed party where the kids can engage in fun training activities. Set up an obstacle course, have them design their own superhero capes or masks, and organize superhero-themed games and challenges.\n" \
+ "8. Outdoor Adventure: Plan an outdoor adventure party at a local park or nature reserve. Arrange activities like hiking, nature scavenger hunts, or a picnic with games. Encourage exploration and appreciation for the outdoors.\n" \
+ "Remember to tailor the activities to the birthday child's interests and preferences. Have a great celebration!"
oneshot_example = [user_example, assistant_example]
super().__init__(oneshot_example, system_inst, role1, role2, sen_spliter, qa_spliter)
class FalconSimplePrompter(BasePrompter):
def __init__(self):
system_inst = None
role1 = "User"
role2 = "Assistant"
sen_spliter = "\n\n"
qa_spliter = "\n\n"
super().__init__(system_inst, role1, role2, sen_spliter, qa_spliter)
class FalconPrompter(BasePrompter):
def __init__(self):
system_inst = "The following is a conversation between a highly knowledgeable and intelligent AI assistant, called Falcon, "\
+ "and a human user, called User. In the following interactions, User and Falcon will converse in natural language, "\
+ "and Falcon will answer User's questions. Falcon was built to be respectful, polite and inclusive. "\
+ "Falcon was built by the Technology Innovation Institute in Abu Dhabi. "\
+ "Falcon will never decline to answer a question, and always attempts to give an answer that User would be satisfied with. "\
+ "It knows a lot, and always tells the truth. The conversation begins."
role1 = "User"
role2 = "Assistant"
sen_spliter = "\n"
qa_spliter = "\n"
super().__init__(system_inst, role1, role2, sen_spliter, qa_spliter)
class MPTPrompter(BasePrompter):
def __init__(self):
system_inst = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions."
role1 = "### Human"
role2 = "### Assistant"
sen_spliter = "\n"
qa_spliter = "\n"
super().__init__(system_inst, role1, role2, sen_spliter, qa_spliter)
class MPTChatPrompter(BasePrompter):
def __init__(self):
system_inst = "system\n" \
+ "- You are a helpful assistant chatbot trained by MosaicML.\n" \
+ "- You answer questions.\n" \
+ "- You are excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.\n" \
+ "- You are more than just an information source, you are also able to write poetry, short stories, and make jokes."
role1 = "user"
role2 = "assistant"
sen_spliter = "\n"
qa_spliter = "\n"
decorator = ["<|im_start|>", "<|im_end|>"]
super().__init__(system_inst, role1, role2, sen_spliter, qa_spliter, decorator)
def get_prompter(model_type, model_path = ""):
if model_type.lower() == "llama":
if "vicuna" in model_path:
return VicunaPrompter()
else:
return Llama2Prompter()
elif model_type.lower() == "falcon":
# return FalconPrompter()
return FalconSimplePrompter()
elif model_type.lower() == "mpt":
if any(name in model_path for name in ["mpt-7b-chat", "mpt-30b-chat"]):
return MPTChatPrompter()
else:
return MPTPrompter()
else:
raise ValueError(f"model type {model_type} is not supported")
def get_stop_token_ids(model_type, model_path = ""):
if model_type.lower() == "llama":
return []
elif model_type.lower() == "falcon":
return [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
elif model_type.lower() == "mpt":
if any(name in model_path for name in ["mpt-7b-chat", "mpt-30b-chat"]):
return [50278, 0]
else:
return []
else:
raise ValueError(f"model type {model_type} is not supported")