import os
import argparse

from transformers import AutoModelForCausalLM, AutoTokenizer


def infer_hf(model_path, input_text, max_new_token=32):
    """Run causal-LM inference with Hugging Face transformers.

    Args:
        model_path: Local path or hub id of the pretrained model.
        input_text: Prompt string to generate from.
        max_new_token: Maximum number of new tokens to generate.

    Side effects:
        Prints the decoded generation (prompt included) to stdout.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        device_map="auto",
    )
    # Move the tokenized inputs to wherever device_map placed the model
    # instead of hard-coding "cuda": this also works on CPU-only hosts
    # and avoids a device mismatch with the auto-dispatched weights.
    input_ids = tokenizer(input_text, return_tensors="pt").to(model.device)
    outputs = model.generate(**input_ids, max_new_tokens=max_new_token)
    print(tokenizer.decode(outputs[0]))


def parse_args():
    """Parse command-line arguments for the inference script."""
    parser = argparse.ArgumentParser(description='falcon inference')
    parser.add_argument('--input_text',
                        default='Write me a poem about Machine Learning.',
                        help='prompt to generate from')
    parser.add_argument('--model_path', default='/path/of/gemma2',
                        help='local path or hub id of the pretrained model')
    parser.add_argument('--max_new_tokens', default=32, type=int,
                        help='maximum number of new tokens to generate')
    return parser.parse_args()


if __name__ == '__main__':
    args = parse_args()
    infer_hf(args.model_path, args.input_text, args.max_new_tokens)