import os
from typing import Optional, Union

from transformers import AutoModel, AutoTokenizer, LogitsProcessorList

MODEL_PATH = os.environ.get('MODEL_PATH', 'THUDM/chatglm3-6b')
TOKENIZER_PATH = os.environ.get("TOKENIZER_PATH", MODEL_PATH)

# Loaded at import time so `main` (and any importer) can reuse one model instance.
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATH, trust_remote_code=True)
model = AutoModel.from_pretrained(MODEL_PATH, trust_remote_code=True, device_map="auto").eval()


def batch(
        model,
        tokenizer,
        prompts: Union[str, list[str]],
        max_length: int = 8192,
        num_beams: int = 1,
        do_sample: bool = True,
        top_p: float = 0.8,
        temperature: float = 0.8,
        logits_processor: Optional[LogitsProcessorList] = None,
) -> list[str]:
    """Generate one completion per prompt in a single padded batch.

    Args:
        model: A loaded causal model exposing ``generate`` and ``device``.
        tokenizer: The matching tokenizer; must provide ``get_command``
            (ChatGLM-style tokenizer).
        prompts: A single prompt string or a list of prompt strings.
        max_length: Passed to ``generate`` as ``max_length``.
        num_beams: Passed to ``generate``.
        do_sample: Passed to ``generate``.
        top_p: Passed to ``generate``.
        temperature: Passed to ``generate``.
        logits_processor: Optional logits processors for ``generate``. A fresh
            empty ``LogitsProcessorList`` is used when omitted or ``None``.

    Returns:
        The decoded continuation for each prompt (prompt tokens removed,
        surrounding whitespace stripped), in input order. An empty prompt
        list yields an empty result.
    """
    # NOTE(review): appears to make the tokenizer encode literal markers such as
    # "<|user|>" in prompt text as special tokens (ChatGLM3 tokenizer attribute —
    # confirm). This mutates the shared tokenizer and is never reset.
    tokenizer.encode_special_tokens = True

    if isinstance(prompts, str):
        prompts = [prompts]
    if not prompts:
        # Nothing to tokenize or generate; avoid calling the model with an empty batch.
        return []

    # Fresh list per call instead of a shared mutable default.
    if logits_processor is None:
        logits_processor = LogitsProcessorList()

    batched_inputs = tokenizer(prompts, return_tensors="pt", padding="longest")
    batched_inputs = batched_inputs.to(model.device)

    # Generation stops on the regular EOS token or when the model emits the
    # <|user|> / <|assistant|> role markers.
    eos_token_id = [
        tokenizer.eos_token_id,
        tokenizer.get_command("<|user|>"),
        tokenizer.get_command("<|assistant|>"),
    ]

    gen_kwargs = {
        "max_length": max_length,
        "num_beams": num_beams,
        "do_sample": do_sample,
        "top_p": top_p,
        "temperature": temperature,
        "logits_processor": logits_processor,
        "eos_token_id": eos_token_id,
    }

    batched_outputs = model.generate(**batched_inputs, **gen_kwargs)

    batched_response = []
    for input_ids, output_ids in zip(batched_inputs.input_ids, batched_outputs):
        # Every row of input_ids has the padded batch length; generate presumably
        # returns that prompt followed by the continuation, so drop the prefix.
        decoded_text = tokenizer.decode(output_ids[len(input_ids):])
        batched_response.append(decoded_text.strip())
    return batched_response


def main(batch_queries):
    """Run ``batch`` on *batch_queries* with the module-level model and tokenizer.

    Uses fixed sampling settings (max_length=2048, top_p=0.8,
    temperature=0.8, single beam) and returns the list of responses.
    """
    gen_kwargs = {
        "max_length": 2048,
        "do_sample": True,
        "top_p": 0.8,
        "temperature": 0.8,
        "num_beams": 1,
    }
    batch_responses = batch(model, tokenizer, batch_queries, **gen_kwargs)
    return batch_responses


if __name__ == "__main__":
    # Sample prompts (Chinese): requests for different kinds of stories, each
    # wrapped in ChatGLM3 chat-role markers.
    batch_queries = [
        "<|user|>\n讲个故事\n<|assistant|>",
        "<|user|>\n讲个爱情故事\n<|assistant|>",
        "<|user|>\n讲个开心故事\n<|assistant|>",
        "<|user|>\n讲个睡前故事\n<|assistant|>",
        "<|user|>\n讲个励志的故事\n<|assistant|>",
        "<|user|>\n讲个少壮不努力的故事\n<|assistant|>",
        "<|user|>\n讲个青春校园恋爱故事\n<|assistant|>",
        "<|user|>\n讲个工作故事\n<|assistant|>",
        "<|user|>\n讲个旅游的故事\n<|assistant|>",
    ]

    batch_responses = main(batch_queries)
    for response in batch_responses:
        print("=" * 10)
        print(response)