import os
from pathlib import Path

import click
from loguru import logger
from typing import List
from fastapi import FastAPI, HTTPException, Request
import magic_pdf.model as model_config
from magic_pdf.libs.version import __version__
from magic_pdf.model.doc_analyze_by_custom_model import ModelSingleton
from magic_pdf.rw.AbsReaderWriter import AbsReaderWriter
from magic_pdf.rw.DiskReaderWriter import DiskReaderWriter
from magic_pdf.tools.common import do_parse, parse_pdf_methods
from argparse import ArgumentParser
from pydantic import BaseModel
import uvicorn
import time
import configparser
from magic_pdf.dict2md.ocr_vllm_client import PredictClient,compress_image
# from magic_pdf.dict2md.ocr_client import PredictClient,compress_image
from magic_pdf.parse.pdf_client import ocrPdfClient
from magic_pdf.parse.ofd_parse import *
from magic_pdf.tools.ofd_parser import OFDParser


# FastAPI application; the route handlers defined below register on it.
app = FastAPI()
# Parse method passed to do_parse by the /pdf_ocr handler.
method = 'auto'

# Additional loguru sink: INFO and above go to parse.log, rotated at 10 MB.
logger.add("parse.log", rotation="10 MB", level="INFO",
           format="{time} {level} {message}", encoding='utf-8', enqueue=True)
# Config file path; assigned by ocr_pdf_serve() and read by /pdf_ocr.
config_path = None

# Result of the OCR server health check; assigned by ocr_pdf_serve().
ocr_status = None
# Model returned by ModelSingleton.get_model(); assigned by ocr_pdf_serve().
custom_model = None

class ocrRequest(BaseModel):
    """Request body accepted by the ``/pdf_ocr`` and ``/ofd_ocr`` endpoints."""

    # Path of the input file to parse.
    path: str
    # Directory that receives the output; the handlers create it if missing.
    output_dir: str
    # Path of an INI config file. /ofd_ocr reads it; /pdf_ocr ignores it and
    # uses the module-level config_path instead.
    config_path: str

class ocrResponse(BaseModel):
    """Status code plus output path of a parse result.

    NOTE(review): not referenced by the handlers visible in this file; they
    return plain dicts with the same two keys.
    """

    # Numeric status; the handlers use 200 on success.
    status_code: int
    # Path of the produced output.
    output_path: str


def parse_args():
    """Parse the command-line options used to start the service.

    Returns:
        argparse.Namespace: carries ``dcu_id``, ``method``, ``debug`` and
        ``config_path``.
    """
    # Imported locally: the module header only imports ArgumentParser.
    from argparse import ArgumentTypeError

    def str_to_bool(value):
        """Convert the text given to ``--debug`` into a real boolean.

        ``type=bool`` would turn every non-empty string, including
        "False", into True, so the text is interpreted explicitly.
        """
        text = str(value).strip().lower()
        if text in ('1', 'true', 't', 'yes', 'y', 'on'):
            return True
        if text in ('', '0', 'false', 'f', 'no', 'n', 'off'):
            return False
        raise ArgumentTypeError(f'expected a boolean value, got {value!r}')

    parser = ArgumentParser()
    parser.add_argument(
        '--dcu_id',
        default='0',
        help='设置DCU')
    parser.add_argument(
        '--method',
        type=parse_pdf_methods,
        help="""the method for parsing pdf.
        ocr: using ocr technique to extract information from pdf.
        txt: suitable for the text-based pdf only and outperform ocr.
        auto: automatically choose the best method for parsing pdf from ocr and txt.
        without method specified, auto will be used by default.""",
        default='auto',
    )
    parser.add_argument(
        '--debug',
        type=str_to_bool,
        # A bare "--debug" (no value) also switches debugging on.
        nargs='?',
        const=True,
        help='Enables detailed debugging information during the execution of the CLI commands.',
        default=False,
    )
    parser.add_argument(
        '--config_path',
        default='/home/practice/magic_pdf-main/magic_pdf/config.ini')

    return parser.parse_args()





def ocr_pdf_serve(args):
    """Initialise the shared service state and run the HTTP server.

    Reads the ``[server]`` section of the config file, publishes the config
    path, the OCR health status, the parse method and the loaded model
    through module-level globals used by the request handlers, then blocks
    in ``uvicorn.run``.

    Args:
        args: Parsed command-line options (the object returned by
            ``parse_args``). ``dcu_id`` and ``config_path`` are required;
            ``method`` is applied when present.

    Raises:
        ValueError: If the ``pdf_server`` config value does not contain
            both a host and a valid port.
    """
    # Imported locally: urllib is not imported at module level.
    from urllib.parse import urlsplit

    global config_path, ocr_status, custom_model, method

    os.environ["CUDA_VISIBLE_DEVICES"] = args.dcu_id

    config = configparser.ConfigParser()
    config.read(args.config_path)

    # pdf_server is a URL such as "http://host:port". urlsplit also copes
    # with a trailing slash or path, which plain splitting on ':' did not.
    pdf_server = config.get('server', 'pdf_server')
    endpoint = urlsplit(pdf_server)
    host, port = endpoint.hostname, endpoint.port
    if host is None or port is None:
        raise ValueError(
            f"pdf_server must look like 'scheme://host:port', got {pdf_server!r}")

    config_path = args.config_path
    # Honour --method; previously the option was parsed but never applied.
    method = getattr(args, 'method', None) or method

    ocr_client = PredictClient(config.get('server', 'ocr_server'))
    ocr_status = ocr_client.check_health()

    # Positional arguments: ocr=True, show_log=False.
    custom_model = ModelSingleton().get_model(True, False)

    uvicorn.run(app, host=host, port=port)

@app.get("/health")
async def health_check():
    """Liveness probe: unconditionally reports the service as healthy."""
    payload = dict(status="healthy")
    return payload

@app.post("/pdf_ocr")
async def pdf_ocr(request: ocrRequest):
    """Parse the file at ``request.path`` and return where the output went.

    Uses the module-level ``config_path``, ``ocr_status``, ``custom_model``
    and ``method``; ``request.config_path`` is not consulted here.

    Returns:
        dict: ``{"status_code": 200, "output_path": <path>}`` on success.

    Raises:
        HTTPException: 500 when parsing raised or yielded no output path.
    """
    model_config.__use_inside_model__ = True
    model_config.__model_mode__ = 'full'

    source = request.path
    target_dir = request.output_dir
    os.makedirs(target_dir, exist_ok=True)
    logger.info(f"正在处理文件: {source}")

    result_path = None
    try:
        stem = str(Path(source).stem)
        reader = DiskReaderWriter(os.path.dirname(source))
        pdf_bytes = reader.read(os.path.basename(source), AbsReaderWriter.MODE_BIN)
        result_path = do_parse(
            ocr_status,
            config_path,
            target_dir,
            stem,
            pdf_bytes,
            [],
            method,
            False,  # debug_able
            model=custom_model,
            start_page_id=0,
            end_page_id=None,
        )
    except Exception as err:
        # Any failure is logged here and reported as a 500 below.
        logger.exception(err)

    if not result_path:
        logger.error(f'文件解析失败，文件为：{source}')
        raise HTTPException(status_code=500)

    logger.info(f'文件解析成功：{result_path}')
    return {"status_code": 200, "output_path": result_path}


@app.post("/ofd_ocr")
async def ofd_ocr(request: ocrRequest):
    """Recognise the OFD file at ``request.path``.

    The file is first classified with ``check_ofd``. If it is judged to be
    an invoice, each page image is sent to the OCR client and the combined
    text is written to ``<output_dir>/<stem>.txt``. Otherwise the file is
    passed through ``ofd2pdf`` and the result is handed to the PDF OCR
    service.

    Returns:
        dict: ``{"status_code": 200, "output_path": <path>}`` on success.

    Raises:
        HTTPException: 500. Exceptions that are already HTTPException keep
            their own detail; any other error is logged and wrapped with a
            generic detail message.
    """
    try:
        # Service endpoints come from the config file named in the request.
        config = configparser.ConfigParser()
        config.read(request.config_path)
        ocr_url = config.get('server', 'ocr_server')
        pdf_server = config.get('server', 'pdf_server')

        client = PredictClient(ocr_url)
        # Not named "pdf_ocr": that would shadow the /pdf_ocr handler.
        pdf_client = ocrPdfClient(pdf_server)

        os.makedirs(request.output_dir, exist_ok=True)

        logger.info('正在判断ofd文件类型')
        check_res, ofd_imgs, pdfbytes = check_ofd(
            request.path, client, request.output_dir)

        ofd_txt = ''
        if check_res:
            # Invoice: run OCR on every page image and collect the text.
            text = '识别图片的内容，如果是发票就识别图中的文字信息，并以json格式返回'
            page_texts = []
            for ofd_img in ofd_imgs:
                compress_image(ofd_img)
                page_texts.append(json_to_txt(client.predict(ofd_img, text)))
            ofd_txts = ''.join(page + '\n' for page in page_texts)

            # Only write a result file when something was recognised.
            if ofd_txts:
                file_name = Path(request.path).stem
                ofd_txt = os.path.join(request.output_dir, f"{file_name}.txt")
                with open(ofd_txt, 'w', encoding='utf-8') as f:
                    f.write(ofd_txts)
        else:
            # Not an invoice: go through the PDF OCR service instead.
            ofd_pdf = ofd2pdf(request.path, request.output_dir, pdfbytes)
            ofd_txt = pdf_client.ocr_pdf_client(
                request.config_path, path=ofd_pdf, output_dir=request.output_dir)

        if not ofd_txt:
            logger.error(f'文件解析失败，文件为：{request.path}')
            raise HTTPException(status_code=500, detail="文件解析失败")

        logger.info(f'文件解析成功：{ofd_txt}')
        return {"status_code": 200, "output_path": ofd_txt}

    except HTTPException:
        # Already a deliberate HTTP error (raised above or by a helper);
        # keep its detail instead of re-wrapping it as an unexpected failure.
        raise
    except Exception as e:
        logger.exception(f"处理文件 {request.path} 时发生错误: {e}")
        raise HTTPException(status_code=500, detail="处理文件时发生错误") from e


# 基于关键词判断 OFD 是否为发票
def check_ofd_by_keywords(filepath):
    """Return True when one page of the OFD file contains every invoice keyword.

    The file is base64-encoded, parsed with ``OFDParser``, and the
    ``text_list`` of each page is searched for all of the keywords.

    Args:
        filepath: Path of the OFD file.

    Returns:
        bool: True as soon as one page contains all keywords, else False.

    Raises:
        HTTPException: 500 when reading or parsing the file fails.
    """
    # Explicit local import: this file has no module-level "import base64",
    # so the name was only available if the wildcard import leaked it.
    import base64

    try:
        with open(filepath, "rb") as f:
            ofdb64 = base64.b64encode(f.read()).decode("utf-8")
        res = OFDParser(ofdb64)()
        invoice_keywords = ['发票代码', '发票号码', '发票', '开票日期']

        for res_info in res:
            pages = res_info['page_info']
            # Indexed by position on purpose: only len() and integer
            # subscripting are known to work on page_info, so iterating it
            # directly could change behaviour.
            for page_no in range(len(pages)):
                text_content = str(pages[page_no].get('text_list', ''))
                if all(keyword in text_content for keyword in invoice_keywords):
                    return True
        return False
    except Exception as e:
        logger.error(f"OFD 文件判断异常: {filepath}，报错：{e}")
        raise HTTPException(status_code=500, detail="判断ofd文件类型时发生错误") from e


# 基于深度学习模型（如 Qwen）判断 OFD 是否为发票
def check_ofd_by_qwen(filepath, client, text,output_dir):
    """Ask the prediction client whether any page of the OFD is an invoice.

    The file is rendered with ``ofd2img``; each page image is compressed and
    sent to ``client.predict`` together with ``text``. The first reply that
    contains the substring 'True' ends the scan.

    Returns:
        tuple: ``(is_invoice, ofd_imgs, pdfbytes)`` where the last two are
        the values returned by ``ofd2img``.

    Raises:
        HTTPException: 500 when rendering or prediction fails.
    """
    try:
        page_images, pdf_payload = ofd2img(filepath, output_dir)
        is_invoice = False
        for image in page_images:
            compress_image(image)
            reply = client.predict(image, text)
            if 'True' in reply:
                is_invoice = True
                break
        return is_invoice, page_images, pdf_payload
    except Exception as e:
        logger.error(f"基于 Qwen 判断 OFD 文件时异常: {filepath}，报错：{e}")
        raise HTTPException(status_code=500, detail="判断ofd文件类型时发生错误")



# 综合判断 OFD 是否为发票
def check_ofd(filepath,client,output_dir):
    """Decide whether an OFD file is an invoice.

    A keyword check runs first. Only when it passes is the prediction
    client asked to confirm. When the keyword check fails, the file is still
    rendered with ``ofd2img`` so the caller gets the page images and PDF
    bytes it needs; previously this path returned unbound variables and
    raised ``UnboundLocalError``.

    Args:
        filepath: Path of the OFD file.
        client: Prediction client passed on to ``check_ofd_by_qwen``.
        output_dir: Output directory passed on to ``ofd2img``.

    Returns:
        tuple: ``(is_invoice, ofd_imgs, pdfbytes)``.

    Raises:
        HTTPException: 500 when classification or rendering fails.
    """
    if check_ofd_by_keywords(filepath):
        # All keywords present: let the model make the final call.
        text = '请判断图片是否为发票,如果是发票，请返回"True"，否则返回"False"'
        return check_ofd_by_qwen(filepath, client, text, output_dir)

    # Keyword check failed: not an invoice, but the caller still needs the
    # rendered output (same ofd2img call that check_ofd_by_qwen makes).
    try:
        ofd_imgs, pdfbytes = ofd2img(filepath, output_dir)
    except Exception as e:
        logger.error(f"OFD 文件判断异常: {filepath}，报错：{e}")
        raise HTTPException(status_code=500, detail="判断ofd文件类型时发生错误") from e
    return False, ofd_imgs, pdfbytes


def main():
    """Command-line entry point: parse the options and start the service."""
    ocr_pdf_serve(parse_args())



# Start the service only when this file is executed as a script.
if __name__ == '__main__':
    main()






