import time
import argparse

from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel
from vllm import LLM, SamplingParams


def infer_hf_chatglm(model_path, prompt):
    '''Run ChatGLM2 inference with transformers.'''
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    model = AutoModel.from_pretrained(model_path, trust_remote_code=True, device_map="auto").half().cuda()
    model = model.eval()
    start_time = time.time()
    generated_text, _ = model.chat(tokenizer, prompt, history=[])
    print("chat time ", time.time() - start_time)
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
    return generated_text


def infer_hf_llama3(model_path, prompt):
    '''Run Llama-3 inference with transformers.'''
    input_query = {"role": "user", "content": prompt}
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForCausalLM.from_pretrained(
        model_path, torch_dtype="auto", device_map="auto")
    input_ids = tokenizer.apply_chat_template(
        [input_query], add_generation_prompt=True, return_tensors="pt").to(model.device)
    outputs = model.generate(
        input_ids,
        max_new_tokens=512,
        do_sample=True,
        temperature=1,
        top_p=0.95,
    )
    # Strip the prompt tokens and decode only the newly generated part.
    response = outputs[0][input_ids.shape[-1]:]
    generated_text = tokenizer.decode(response, skip_special_tokens=True)
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
    return generated_text


def infer_vllm_llama3(model_path, message, tp_size=1, max_model_len=1024):
    '''Run Llama-3 inference with vLLM.'''
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    messages = [{"role": "user", "content": message}]
    print(f"Prompt: {messages!r}")
    sampling_params = SamplingParams(temperature=1, top_p=0.95, max_tokens=1024,
                                     stop_token_ids=[tokenizer.eos_token_id])
    llm = LLM(model=model_path, max_model_len=max_model_len, trust_remote_code=True,
              enforce_eager=True, dtype="float16", tensor_parallel_size=tp_size)
    # Generate the answer from pre-tokenized chat-template input.
    start_time = time.time()
    prompt_token_ids = [tokenizer.apply_chat_template(messages, add_generation_prompt=True)]
    outputs = llm.generate(prompt_token_ids=prompt_token_ids, sampling_params=sampling_params)
    print("total infer time", time.time() - start_time)
    # Print the results.
    for output in outputs:
        generated_text = output.outputs[0].text
        print(f"Generated text: {generated_text!r}")


def infer_vllm_chatglm(model_path, message, tp_size=1):
    '''Run ChatGLM2 inference with vLLM.'''
    sampling_params = SamplingParams(temperature=1.0, top_p=0.9, max_tokens=1024)
    llm = LLM(model=model_path, trust_remote_code=True, enforce_eager=True,
              dtype="float16", tensor_parallel_size=tp_size)
    # Generate the answer directly from the raw prompt string.
    print(f"chatglm2 Prompt: {message!r}")
    outputs = llm.generate(message, sampling_params=sampling_params)
    # Print the results.
    for output in outputs:
        generated_text = output.outputs[0].text
        print(f"Generated text: {generated_text!r}")


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_path', default='', help='Path to the model checkpoint.')
    parser.add_argument('--query', default="DCU是什么?", help='Question to ask the model.')
    parser.add_argument('--use_hf', action='store_true',
                        help='Use transformers instead of vLLM for inference.')
    args = parser.parse_args()
    return args


if __name__ == "__main__":
    args = parse_args()
    # Route to the Llama-3 code path when "llama" appears in the model path;
    # otherwise assume a ChatGLM2 checkpoint.
    is_llama = "llama" in args.model_path
    print("Is llama", is_llama)
    if args.use_hf:
        # transformers
        if is_llama:
            infer_hf_llama3(args.model_path, args.query)
        else:
            infer_hf_chatglm(args.model_path, args.query)
    else:
        # vllm
        if is_llama:
            infer_vllm_llama3(args.model_path, args.query)
        else:
            infer_vllm_chatglm(args.model_path, args.query)
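
# Example invocations (a sketch; the script filename and model paths below are
# placeholders, not values shipped with this repo):
#
#   vLLM inference with a Llama-3 checkpoint:
#       python infer.py --model_path /path/to/llama3-8b-instruct --query "DCU是什么?"
#
#   transformers inference with a ChatGLM2 checkpoint:
#       python infer.py --model_path /path/to/chatglm2-6b --query "DCU是什么?" --use_hf
#
# Note: model routing relies on the substring "llama" appearing in --model_path
# to select the Llama-3 code path; any other path falls back to ChatGLM2.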