import time
import argparse
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

# Seed torch's default RNG at import time so repeated runs start from the same state.
torch.manual_seed(0)

def infer_hf(model_path, messages, max_new_tokens=600):
    """Generate one greedy reply to ``messages`` with a text-generation pipeline.

    Loads the causal LM and its tokenizer from ``model_path``, runs the
    pipeline once, prints the elapsed generation time and the generated
    text, and returns that text.

    Args:
        model_path: Checkpoint location handed to ``from_pretrained`` for both
            the model and the tokenizer.
        messages: Chat history passed straight to the pipeline; the caller in
            this file supplies a list of ``{"role": ..., "content": ...}`` dicts.
        max_new_tokens: Upper bound on the number of newly generated tokens.

    Returns:
        The ``generated_text`` field of the first pipeline result.
    """
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        device_map="cuda",
        torch_dtype="auto",
        trust_remote_code=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_path)

    pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
    )

    # Greedy decoding. No "temperature" entry: it is only consulted when
    # do_sample=True, and passing it alongside do_sample=False just makes
    # transformers warn about an unused generation flag.
    generation_args = {
        "max_new_tokens": max_new_tokens,
        "return_full_text": False,
        "do_sample": False,
    }

    # perf_counter is monotonic, so the measured interval cannot be skewed
    # by wall-clock adjustments. Only the pipeline call is timed; model and
    # tokenizer loading above are excluded.
    start_time = time.perf_counter()
    output = pipe(messages, **generation_args)
    elapsed = time.perf_counter() - start_time

    generated_text = output[0]['generated_text']
    print("total infer time", elapsed)
    print("output", generated_text)
    return generated_text


def parse_args():
    """Parse the command line and return the resulting namespace.

    The only option is ``--model_path``, which falls back to the default
    Phi-3-mini-128k-instruct checkpoint directory when omitted.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--model_path",
        default="/home/checkpoints/Phi-3-mini-128k-instruct/",
    )
    return cli.parse_args()


if __name__ == "__main__":
    # Resolve the checkpoint location from the command line first.
    model_path = parse_args().model_path

    # Demo chat as (role, content) pairs: a system prompt, one completed
    # user/assistant exchange, and the final user question to be answered.
    turns = (
        ("system", "You are a helpful AI assistant."),
        ("user", "Can you provide ways to eat combinations of bananas and dragonfruits?"),
        ("assistant", "Sure! Here are some ways to eat bananas and dragonfruits together: 1. Banana and dragonfruit smoothie: Blend bananas and dragonfruits together with some milk and honey. 2. Banana and dragonfruit salad: Mix sliced bananas and dragonfruits together with some lemon juice and honey."),
        ("user", "What about solving an 2x + 3 = 7 equation?"),
    )
    messages = [{"role": role, "content": content} for role, content in turns]

    infer_hf(model_path, messages)
