import torch
import os
import argparse

from transformers import AutoModel, AutoTokenizer

# Restrict this process to a single ROCm GPU. ``setdefault`` keeps the
# historical default (device 0) while letting callers choose another one,
# e.g. ``HIP_VISIBLE_DEVICES=1 python <script> ...``. The variable must be set
# before the GPU runtime is initialised; PyTorch initialises it lazily, on
# first device use (the ``.cuda()`` call in ``main``).
os.environ.setdefault("HIP_VISIBLE_DEVICES", "0")

# Resolution presets for ``model.infer``:
# mode name -> (base_size, image_size, crop_mode).
MODES = {
    "tiny": (512, 512, False),
    "small": (640, 640, False),
    "base": (1024, 1024, False),
    "large": (1280, 1280, False),
    "gundam": (1024, 640, True),
}

# Prompt templates passed to ``model.infer`` as ``prompt=``.
PROMPTS = {
    "markdown": "<image>\n<|grounding|>Convert the document to markdown. ",
    "free": "<image>\nFree OCR. ",
}


def build_arg_parser() -> argparse.ArgumentParser:
    """Return the command-line parser; its defaults reproduce the original script."""
    parser = argparse.ArgumentParser(description="Run DeepSeek-OCR on a single image.")
    parser.add_argument('--model_name_or_path', type=str, default='deepseek-ai/DeepSeek-OCR')
    parser.add_argument('--image_file', type=str, default='./doc/test.png')
    parser.add_argument('--output_path', type=str, default='./output/')
    parser.add_argument('--mode', choices=list(MODES), default='gundam',
                        help="resolution preset (default: %(default)s)")
    parser.add_argument('--task', choices=list(PROMPTS), default='markdown',
                        help="prompt template (default: %(default)s)")
    return parser


def main(argv=None):
    """Load the model, OCR ``--image_file`` and save results under ``--output_path``.

    Args:
        argv: argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        Whatever ``model.infer`` returns.

    Exits with status 2 (via ``parser.error``) if the image file does not exist.
    """
    parser = build_arg_parser()
    args = parser.parse_args(argv)

    # Check cheap preconditions before the slow, memory-hungry model load.
    if not os.path.isfile(args.image_file):
        parser.error(f"image file not found: {args.image_file}")
    os.makedirs(args.output_path, exist_ok=True)

    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, trust_remote_code=True)
    model = AutoModel.from_pretrained(
        args.model_name_or_path,
        _attn_implementation='flash_attention_2',
        trust_remote_code=True,
        use_safetensors=True,
    )
    # Cast on the host before moving to the GPU: if the weights were loaded as
    # float32, the device then never holds a full-precision copy.
    model = model.eval().to(torch.bfloat16).cuda()

    base_size, image_size, crop_mode = MODES[args.mode]
    # infer()'s own default is save_results=False, so request saving explicitly.
    res = model.infer(
        tokenizer,
        prompt=PROMPTS[args.task],
        image_file=args.image_file,
        output_path=args.output_path,
        base_size=base_size,
        image_size=image_size,
        crop_mode=crop_mode,
        save_results=True,
        test_compress=True,
    )
    print("process end, result saved to", args.output_path)
    return res


if __name__ == '__main__':
    main()