import argparse
import configparser
import json
import os
import re
import time

import requests
from loguru import logger
from PIL import Image


def parse_args():
    """Parse the command-line options of the prediction client.

    Every option is optional and falls back to a hard-coded default.

    Returns:
        argparse.Namespace with the string attributes ``config_path``
        (INI file read by ``main``), ``image_path`` (image sent for
        prediction) and ``text`` (the prompt).
    """
    option_defaults = (
        ('--config_path', '/home/practice/magic_pdf-main/magic_pdf/config.ini'),
        ('--image_path', '/home/wanglch/projects/Qwen2-VL/20240920-163701.png'),
        ('--text', "描述你在图片中看到的内容"),
    )

    cli = argparse.ArgumentParser()
    for flag, fallback in option_defaults:
        cli.add_argument(flag, default=fallback)
    return cli.parse_args()


# Characters escaped inside fenced code blocks, applied in a single pass.
_CODE_ESCAPES = str.maketrans({
    "`": r"\`",
    "<": "&lt;",
    ">": "&gt;",
    " ": "&nbsp;",
    "*": "&ast;",
    "_": "&lowbar;",
    "-": "&#45;",
    ".": "&#46;",
    "!": "&#33;",
    "(": "&#40;",
    ")": "&#41;",
    "$": "&#36;",
})


def parse_text(text):
    """Flatten markdown-style text into a single HTML-flavoured string.

    Blank lines are dropped.  A line containing a ``` fence is replaced by
    an opening ``<pre><code class="language-X">`` tag (``X`` is whatever
    follows the last backtick on that line) or, for every second fence, by
    a closing ``</code></pre>`` tag.  Lines between an opening and a
    closing fence have their special characters escaped.  Every non-fence
    line is glued onto the previous output piece with a single space, so
    the result contains no newlines.

    Args:
        text: Input text, lines separated by ``"\n"``.

    Returns:
        The converted text as one string.
    """
    pieces = []
    fences_seen = 0

    for raw in text.split("\n"):
        # Skip empty and whitespace-only lines.
        if raw.strip() == "":
            continue

        if "```" in raw:
            fences_seen += 1
            if fences_seen % 2:
                # Odd fence: a code block starts here.
                language = raw.split("`")[-1]
                pieces.append(f'<pre><code class="language-{language}">')
            else:
                # Even fence: the code block ends here.
                pieces.append("</code></pre>")
            continue

        # An odd number of fences so far means we are inside a code block.
        if fences_seen % 2 == 1:
            raw = raw.translate(_CODE_ESCAPES)

        # Join onto the previous piece with a space; the very first line
        # has nothing to attach to and starts the output instead.
        if pieces:
            pieces[-1] += " " + raw
        else:
            pieces.append(raw)

    return "".join(pieces)


# Matches a complete opening ``<pre><code ...>`` tag or the closing
# ``</code></pre>`` tag.  The capturing group makes ``re.split`` keep the
# tags in its result so code-block state can be tracked.
_CODE_TAG_RE = re.compile(r'(<pre><code[^>]*>|</code></pre>)')

# HTML entities and the characters they are turned back into, applied in
# this order.
_ENTITY_REPLACEMENTS = (
    ("&lt;", "<"),
    ("&gt;", ">"),
    ("&nbsp;", " "),
    ("&ast;", "*"),
    ("&lowbar;", "_"),
    ("&#45;", "-"),
    ("&#46;", "."),
    ("&#33;", "!"),
    ("&#40;", "("),
    ("&#41;", ")"),
    ("&#36;", "$"),
)


def unparse_text(parsed_text):
    """Strip code-block tags and undo the character escaping in a text.

    Complete ``<pre><code ...>`` and ``</code></pre>`` tags are removed
    wherever they occur, including several on one line, and the text around
    them is kept.  The listed HTML entities are decoded everywhere; the
    ``\\``` backtick escape is undone only between an opening and a closing
    tag.  Newlines in the input are preserved.

    Args:
        parsed_text: Text that may contain the tags and entities above.

    Returns:
        The text with tags removed and escapes reverted.
    """
    restored = []
    in_code_block = False

    # re.split with one capturing group alternates plain text (even
    # indices) and matched tags (odd indices).
    for index, piece in enumerate(_CODE_TAG_RE.split(parsed_text)):
        if index % 2 == 1:
            # A tag: update the state and drop the tag itself.
            in_code_block = piece != "</code></pre>"
            continue

        # Decode the HTML entities.
        for entity, char in _ENTITY_REPLACEMENTS:
            piece = piece.replace(entity, char)

        # Inside a code block, also undo the backslash-escaped backticks.
        if in_code_block:
            piece = piece.replace(r"\`", "`")

        restored.append(piece)

    return "".join(restored)


def compress_image(image_path, max_size=(1024, 1024)):
    """Shrink an image file in place so that it fits within ``max_size``.

    If the image already fits, the file is left untouched.  Otherwise it is
    resized with LANCZOS resampling, keeping the aspect ratio, and written
    back to ``image_path`` (overwriting the original) with
    ``optimize=True, quality=80``.

    Args:
        image_path: Path of the image file to read and overwrite.
        max_size: ``(max_width, max_height)`` bounding box in pixels.
    """
    max_width, max_height = max_size[0], max_size[1]

    # The context manager closes the underlying file handle even when the
    # image needs no resizing or an error occurs.
    with Image.open(image_path) as img:
        width, height = img.size
        if width <= max_width and height <= max_height:
            return

        aspect_ratio = width / height
        # Scale along the dimension that overshoots its limit the most, so
        # the result fits BOTH bounds even when max_size is not square.
        # (Cross-multiplied form of width/max_width > height/max_height.)
        if width * max_height > height * max_width:
            new_width = max_width
            new_height = int(new_width / aspect_ratio)
        else:
            new_height = max_height
            new_width = int(new_height * aspect_ratio)

        # int() truncation can yield 0 for extreme aspect ratios; keep at
        # least one pixel per side so the resize does not fail.
        target_size = (max(1, new_width), max(1, new_height))
        resized = img.resize(target_size, Image.LANCZOS)

    # Save after the source file is closed; the resized image is an
    # independent in-memory copy.
    resized.save(image_path, optimize=True, quality=80)


class PredictClient:
    """Small HTTP client for the prediction service.

    Talks to ``GET {api_url}/health`` and ``POST {api_url}/predict``.

    Args:
        api_url: Base URL of the service; endpoint paths are appended to it.
        timeout: Timeout in seconds for the ``/predict`` request, passed
            straight to ``requests`` (``None`` waits indefinitely).
        health_timeout: Timeout in seconds for the ``/health`` request.
    """

    def __init__(self, api_url, timeout=300.0, health_timeout=5.0):
        self.api_url = api_url
        # Without a timeout, requests blocks forever on a stalled server.
        self.timeout = timeout
        self.health_timeout = health_timeout

    def check_health(self):
        """Report whether the service answers its health endpoint.

        Returns:
            True if ``/health`` responds with HTTP 200; False on any other
            status code or if the request itself fails.  Failures are
            logged, never raised.
        """
        health_check_url = f'{self.api_url}/health'
        try:
            response = requests.get(health_check_url, timeout=self.health_timeout)
        except requests.exceptions.RequestException as e:
            logger.error(f'Health check request failed:{e}')
            return False

        if response.status_code == 200:
            logger.info("Server is healthy and ready to process requests.")
            return True

        logger.error(f'Server health check failed with status code:{response.status_code}')
        return False

    def predict(self, image_path: str, text: str):
        """Request a prediction for an image and prompt.

        Args:
            image_path: Path of the image.  The path string itself is sent
                in the JSON body, not the image bytes.
            text: Prompt text sent alongside the image path.

        Returns:
            The value of the ``'Generated Text'`` key in the JSON response,
            or ``''`` if that key is absent.

        Raises:
            RuntimeError: If the service responds with a non-200 status.
            requests.exceptions.RequestException: If the request fails or
                times out.
        """
        payload = {
            "image_path": image_path,
            "text": text
        }
        headers = {'Content-Type': 'application/json'}
        response = requests.post(
            f"{self.api_url}/predict",
            json=payload,
            headers=headers,
            timeout=self.timeout,
        )

        if response.status_code != 200:
            raise RuntimeError(f"Predict API request failed with status code {response.status_code}")

        result = response.json()
        return result.get('Generated Text', '')


def main():
    """Run one prediction from the command line and log the result.

    Reads the service URL from the ``ocr_server`` option in the ``[server]``
    section of the config file, sends the image path and prompt to the
    predict endpoint, and logs the generated text with the elapsed time.
    Configuration and request failures are logged rather than raised.
    """
    args = parse_args()

    config = configparser.ConfigParser()
    try:
        # ConfigParser.read() silently skips files it cannot open and
        # returns the list of files it parsed; an empty list means the
        # config was not loaded.
        if not config.read(args.config_path):
            logger.error(f"Could not read config file: {args.config_path}")
            return
        ocr_server = config.get('server', 'ocr_server')
    except configparser.Error as e:
        logger.error(f"Invalid config file {args.config_path}: {e}")
        return

    client = PredictClient(ocr_server)
    try:
        # perf_counter is monotonic, so the measured duration is not
        # affected by system clock adjustments.
        start_time = time.perf_counter()
        # NOTE: the image is sent as-is; compress_image() is not called here.
        generated_text = client.predict(args.image_path, parse_text(args.text))
        elapsed_time = time.perf_counter() - start_time

        if generated_text:
            clean_text = unparse_text(generated_text)  # undo tags/escapes in the reply
            logger.info(f"Image Path: {args.image_path}")
            logger.info(f"Generated Text: {clean_text}")
            logger.info(f"耗时为: {elapsed_time}秒")  # elapsed time in seconds
        else:
            logger.warning("Received empty generated text.")
    except requests.exceptions.RequestException as e:
        logger.error(f"Error while making request to predict service: {e}")
    except Exception as e:
        # Top-level boundary: log with traceback so the cause is not lost.
        logger.exception(f"Unexpected error occurred: {e}")


# Run the client only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
