import os import platform import signal from transformers import AutoTokenizer, AutoModel import readline base_model = '../FinGPT/FinGPT_mt_chatglm2-6b-merged' tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True) model = AutoModel.from_pretrained(base_model, trust_remote_code=True, device_map = "auto") # 多显卡支持,使用下面两行代替上面一行,将num_gpus改为你实际的显卡数量 # from utils import load_model_on_gpus # model = load_model_on_gpus("THUDM/chatglm2-6b", num_gpus=2) model = model.eval() os_name = platform.system() clear_command = 'cls' if os_name == 'Windows' else 'clear' stop_stream = False def build_prompt(history): prompt = "欢迎使用 FinGPT-ChatGLM2-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序" for query, response in history: prompt += f"\n\n用户:{query}" prompt += f"\n\nFinGPT:{response}" return prompt def signal_handler(signal, frame): global stop_stream stop_stream = True def main(): os.system(clear_command) past_key_values, history = None, [] global stop_stream print("欢迎使用由中科曙光智能与计算产业事业部开发的FinGPT-ChatGLM2-6B金融大模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序") while True: query = input("\n用户:") if query.strip() == "stop": break if query.strip() == "clear": past_key_values, history = None, [] os.system(clear_command) print("欢迎使用由中科曙光智能与计算产业事业部开发的FinGPT-ChatGLM2-6B金融大模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序") continue print("\nFinGPT:", end="") current_length = 0 for response, history, past_key_values in model.stream_chat(tokenizer, query, history=history, past_key_values=past_key_values, return_past_key_values=True, temperature=0.8, ): if stop_stream: stop_stream = False break else: print(response[current_length:], end="", flush=True) current_length = len(response) print("") if __name__ == "__main__": main()