"""Inference for FastChat models.""" import abc import torch from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaTokenizer, AutoModel from fastchat.conversation import conv_templates, SeparatorStyle from fastchat.serve.compression import compress_module from fastchat.serve.monkey_patch_non_inplace import replace_llama_attn_with_non_inplace_operations from fastchat.serve.serve_chatglm import chatglm_generate_stream def load_model(model_name, device, num_gpus, load_8bit=False, debug=False): if device == "cpu": kwargs = {} elif device == "cuda": kwargs = {"torch_dtype": torch.float16} if num_gpus == "auto": kwargs["device_map"] = "auto" else: num_gpus = int(num_gpus) if num_gpus != 1: kwargs.update({ "device_map": "auto", "max_memory": {i: "13GiB" for i in range(num_gpus)}, }) elif device == "mps": kwargs = {"torch_dtype": torch.float16} # Avoid bugs in mps backend by not using in-place operations. replace_llama_attn_with_non_inplace_operations() else: raise ValueError(f"Invalid device: {device}") if "chatglm" in model_name: tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) model = AutoModel.from_pretrained(model_name, trust_remote_code=True).half().cuda() else: tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False) model = AutoModelForCausalLM.from_pretrained(model_name, low_cpu_mem_usage=True, **kwargs) if load_8bit: compress_module(model, device) if (device == "cuda" and num_gpus == 1) or device == "mps": model.to(device) if debug: print(model) return model, tokenizer @torch.inference_mode() def generate_stream(model, tokenizer, params, device, context_len=2048, stream_interval=2): prompt = params["prompt"] l_prompt = len(prompt) temperature = float(params.get("temperature", 1.0)) max_new_tokens = int(params.get("max_new_tokens", 256)) stop_str = params.get("stop", None) input_ids = tokenizer(prompt).input_ids output_ids = list(input_ids) max_src_len = context_len - max_new_tokens - 8 input_ids = input_ids[-max_src_len:] for i in range(max_new_tokens): if i == 0: out = model( torch.as_tensor([input_ids], device=device), use_cache=True) logits = out.logits past_key_values = out.past_key_values else: attention_mask = torch.ones( 1, past_key_values[0][0].shape[-2] + 1, device=device) out = model(input_ids=torch.as_tensor([[token]], device=device), use_cache=True, attention_mask=attention_mask, past_key_values=past_key_values) logits = out.logits past_key_values = out.past_key_values last_token_logits = logits[0][-1] if device == "mps": # Switch to CPU by avoiding some bugs in mps backend. 
last_token_logits = last_token_logits.float().to("cpu") if temperature < 1e-4: token = int(torch.argmax(last_token_logits)) else: probs = torch.softmax(last_token_logits / temperature, dim=-1) token = int(torch.multinomial(probs, num_samples=1)) output_ids.append(token) if token == tokenizer.eos_token_id: stopped = True else: stopped = False if i % stream_interval == 0 or i == max_new_tokens - 1 or stopped: output = tokenizer.decode(output_ids, skip_special_tokens=True) pos = output.rfind(stop_str, l_prompt) if pos != -1: output = output[:pos] stopped = True yield output if stopped: break del past_key_values class ChatIO(abc.ABC): @abc.abstractmethod def prompt_for_input(self, role: str) -> str: """Prompt for input from a role.""" @abc.abstractmethod def prompt_for_output(self, role: str): """Prompt for output from a role.""" @abc.abstractmethod def stream_output(self, output_stream, skip_echo_len: int): """Stream output.""" def chat_loop(model_name: str, device: str, num_gpus: str, load_8bit: bool, conv_template: str, temperature: float, max_new_tokens: int, chatio: ChatIO, debug: bool): # Model model, tokenizer = load_model(model_name, device, num_gpus, load_8bit, debug) is_chatglm = "chatglm" in str(type(model)).lower() # Chat conv = conv_templates[conv_template].copy() while True: try: inp = chatio.prompt_for_input(conv.roles[0]) except EOFError: inp = "" if not inp: print("exit...") break conv.append_message(conv.roles[0], inp) conv.append_message(conv.roles[1], None) if is_chatglm: prompt = conv.messages[conv.offset:] generate_stream_func = chatglm_generate_stream skip_echo_len = len(conv.messages[-2][1]) + 1 else: generate_stream_func = generate_stream prompt = conv.get_prompt() skip_echo_len = len(prompt) + 1 params = { "model": model_name, "prompt": prompt, "temperature": temperature, "max_new_tokens": max_new_tokens, "stop": conv.sep if conv.sep_style == SeparatorStyle.SINGLE else conv.sep2, } chatio.prompt_for_output(conv.roles[1]) output_stream = generate_stream_func(model, tokenizer, params, device) outputs = chatio.stream_output(output_stream, skip_echo_len) conv.messages[-1][-1] = " ".join(outputs) if debug: print("\n", {"prompt": prompt, "outputs": outputs}, "\n")
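

# ---------------------------------------------------------------------------
# Example usage: a minimal console front end for chat_loop.
# This is an illustrative sketch, not part of the module above. The
# SimpleChatIO class, the placeholder model path, and the "v1" template name
# are assumptions for demonstration; swap in a real checkpoint path and a key
# that actually exists in conv_templates for your FastChat version.
# ---------------------------------------------------------------------------
class SimpleChatIO(ChatIO):
    """Read user turns from stdin and stream generated words to stdout."""

    def prompt_for_input(self, role: str) -> str:
        return input(f"{role}: ")

    def prompt_for_output(self, role: str):
        print(f"{role}: ", end="", flush=True)

    def stream_output(self, output_stream, skip_echo_len: int):
        pre = 0
        outputs = []
        for raw in output_stream:
            # Drop the echoed prompt, then print only newly completed words.
            outputs = raw[skip_echo_len:].strip().split(" ")
            now = len(outputs) - 1
            if now > pre:
                print(" ".join(outputs[pre:now]), end=" ", flush=True)
                pre = now
        print(" ".join(outputs[pre:]), flush=True)
        # chat_loop re-joins this word list when recording the reply.
        return outputs


if __name__ == "__main__":
    chat_loop(
        model_name="/path/to/vicuna-7b",  # placeholder: any local causal-LM checkpoint
        device="cuda",
        num_gpus="1",
        load_8bit=False,
        conv_template="v1",  # assumption: must be a key in conv_templates
        temperature=0.7,
        max_new_tokens=512,
        chatio=SimpleChatIO(),
        debug=False,
    )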