import argparse
from pathlib import Path

from mistral_inference.transformer import Transformer
from mistral_inference.generate import generate

from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
from mistral_common.protocol.instruct.messages import UserMessage
from mistral_common.protocol.instruct.request import ChatCompletionRequest


def build_parser() -> argparse.ArgumentParser:
    """Return the command-line parser for this single-prompt chat demo.

    Defaults reproduce the script's original fixed settings
    (64 new tokens, temperature 0.0, ``tokenizer.model.v3``).
    """
    parser = argparse.ArgumentParser(
        description="Run one chat-completion turn with a local Mistral model and print the reply."
    )
    parser.add_argument(
        "--user_prompt",
        type=str,
        default="Explain Machine Learning to me in a nutshell.",
        help="User message sent to the model.",
    )
    parser.add_argument(
        "--model_name_or_path",
        type=str,
        default="mistralai/Mistral-7B-Instruct-v0.3",
        help="Local folder containing the model weights and the tokenizer file.",
    )
    parser.add_argument(
        "--tokenizer_file",
        type=str,
        default="tokenizer.model.v3",
        help="Tokenizer file name inside the model folder.",
    )
    parser.add_argument(
        "--max_tokens",
        type=int,
        default=64,
        help="Maximum number of tokens to generate.",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.0,
        help="Sampling temperature passed to generate().",
    )
    return parser


def main(argv=None) -> None:
    """Generate a reply to ``--user_prompt`` and print it to stdout.

    Args:
        argv: Argument list to parse instead of ``sys.argv[1:]``
            (``None`` means use the real command line).

    Exits with status 2 (via ``argparse``) if the model folder does not exist.
    """
    parser = build_parser()
    args = parser.parse_args(argv)

    model_dir = Path(args.model_name_or_path)
    if not model_dir.is_dir():
        # Weights are only loaded from disk (Transformer.from_folder); nothing here
        # downloads a Hub repo id, so fail early with a readable message.
        parser.error(
            f"--model_name_or_path must be an existing local folder, got {args.model_name_or_path!r}"
        )

    tokenizer = MistralTokenizer.from_file(str(model_dir / args.tokenizer_file))
    model = Transformer.from_folder(args.model_name_or_path)

    request = ChatCompletionRequest(messages=[UserMessage(content=args.user_prompt)])
    prompt_tokens = tokenizer.encode_chat_completion(request).tokens

    # The low-level tokenizer provides both the EOS id (to stop generation) and the decoder.
    raw_tokenizer = tokenizer.instruct_tokenizer.tokenizer
    out_tokens, _ = generate(
        [prompt_tokens],  # generate() takes a batch of prompts; this script sends exactly one.
        model,
        max_tokens=args.max_tokens,
        temperature=args.temperature,
        eos_id=raw_tokenizer.eos_id,
    )
    print(raw_tokenizer.decode(out_tokens[0]))


if __name__ == "__main__":
    main()
