# Copyright (c) Alibaba Cloud.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree
import configparser
import copy
import os
import re
import gc
import time

import torch
from argparse import ArgumentParser
from threading import Thread
from qwen_vl_utils import process_vision_info
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration, TextIteratorStreamer

from fastapi import FastAPI
from pydantic import BaseModel
from typing import Optional
from loguru import logger

# ASGI application object; the HTTP routes in this module are registered on it
# through the @app.get / @app.post decorators.
app = FastAPI()

# Default model checkpoint used when --checkpoint_path is not given.
DEFAULT_CKPT_PATH = '/home/practice/model/Qwen2-VL-7B-Instruct'
# NOTE(review): REVISION, BOX_TAG_PATTERN and PUNCTUATION are not referenced
# anywhere in the visible part of this file — confirm whether they are still needed.
REVISION = 'v1.0.4'
# Non-greedy pattern capturing the content between <box> and </box>,
# newlines included.
BOX_TAG_PATTERN = r"<box>([\s\S]*?)</box>"
# Collection of full-width / CJK punctuation characters plus typographic
# quotes and dashes.
PUNCTUATION = "！？。＂＃＄％＆＇（）＊＋，－／：；＜＝＞＠［＼］＾＿｀｛｜｝～｟｠｢｣､、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏."

# File sink for INFO-and-above records: UTF-8 "parse.log", rotated at 10 MB,
# with enqueue=True (records are handed to the sink through a queue).
logger.add("parse.log", rotation="10 MB", level="INFO",
           format="{time} {level} {message}", encoding='utf-8', enqueue=True)

def _get_args():
    """Parse the process command line and return the resulting namespace.

    Recognised options: ``-c/--checkpoint_path``, ``--cpu_only``,
    ``--flash_attn2``, ``--share``, ``--inbrowser``, ``--dcu_id`` and
    ``--config_path``.

    Returns:
        The ``argparse.Namespace`` produced by ``parse_args()``.
    """
    parser = ArgumentParser()

    parser.add_argument(
        '-c', '--checkpoint_path',
        type=str,
        default=DEFAULT_CKPT_PATH,
        help='Checkpoint name or path, default to %(default)r',
    )
    parser.add_argument('--cpu_only', action='store_true', help='Run demo with CPU only')

    # Plain on/off switches that share the same declaration shape.
    switches = (
        ('--flash_attn2', 'Enable flash_attention_2 when loading the model.'),
        ('--share', 'Create a publicly shareable link for the interface.'),
        ('--inbrowser', 'Automatically launch the interface in a new tab on the default browser.'),
    )
    for flag, description in switches:
        parser.add_argument(flag, action='store_true', default=False, help=description)

    parser.add_argument('--dcu_id', type=str, default='0', help='Specify the GPU ID to load the model onto.')
    parser.add_argument('--config_path', default='./magic_pdf/config.ini')

    return parser.parse_args()


def _load_model_processor(args):
    """Load the Qwen2-VL model and its processor from ``args.checkpoint_path``.

    Device placement is decided as follows:
      * ``args.cpu_only`` set: ``device_map='cpu'``;
      * otherwise, ``args.dcu_id`` not None: the whole model is mapped to
        ``cuda:<dcu_id>`` and ``CUDA_VISIBLE_DEVICES`` is set to that id;
      * otherwise: ``device_map='auto'``.

    Weights are requested as ``torch.float16``; ``flash_attention_2`` is
    requested only when ``args.flash_attn2`` is set.

    Args:
        args: Namespace providing ``checkpoint_path``, ``cpu_only``,
            ``dcu_id`` and ``flash_attn2``.

    Returns:
        A ``(model, processor)`` tuple.
    """
    if args.cpu_only:
        device_map = 'cpu'
    elif args.dcu_id is None:
        device_map = 'auto'
    else:
        device_map = {'': f'cuda:{args.dcu_id}'}
        # NOTE(review): the same id is used both as CUDA_VISIBLE_DEVICES and as
        # the cuda:<id> ordinal. If the variable takes effect for this process,
        # only one device is visible, so a non-zero id may fail to resolve —
        # confirm on a multi-device host.
        os.environ['CUDA_VISIBLE_DEVICES'] = args.dcu_id
        print('使用DCU推理:', f'cuda:{args.dcu_id}')

    load_kwargs = {'torch_dtype': torch.float16, 'device_map': device_map}
    if args.flash_attn2:
        load_kwargs['attn_implementation'] = 'flash_attention_2'

    model = Qwen2VLForConditionalGeneration.from_pretrained(args.checkpoint_path, **load_kwargs)
    processor = AutoProcessor.from_pretrained(args.checkpoint_path)
    return model, processor


def _parse_text(text):
    lines = text.split("\n")
    lines = [line for line in lines if line.strip() != ""]  # 去除空行
    count = 0
    parsed_lines = []

    for i, line in enumerate(lines):
        if "```" in line:
            count += 1
            items = line.split("`")
            if count % 2 == 1:
                # 开始代码块
                parsed_lines.append(f'<pre><code class="language-{items[-1]}">')
            else:
                # 结束代码块
                parsed_lines.append(f"</code></pre>")
        else:
            if i > 0 and count % 2 == 1:
                # 转义代码块内的特殊字符
                line = line.replace("`", r"\`")
                line = line.replace("<", "&lt;")
                line = line.replace(">", "&gt;")
                line = line.replace(" ", "&nbsp;")
                line = line.replace("*", "&ast;")
                line = line.replace("_", "&lowbar;")
                line = line.replace("-", "&#45;")
                line = line.replace(".", "&#46;")
                line = line.replace("!", "&#33;")
                line = line.replace("(", "&#40;")
                line = line.replace(")", "&#41;")
                line = line.replace("$", "&#36;")
            # 使用空格连接行
            if parsed_lines:
                parsed_lines[-1] += " " + line
            else:
                parsed_lines.append(line)

    text = "".join(parsed_lines)
    return text


def _remove_image_special(text):
    text = text.replace('<ref>', '').replace('</ref>', '')
    return re.sub(r'<box>.*?(</box>|$)', '', text)


def _is_video_file(filename):
    video_extensions = ['.mp4', '.avi', '.mkv', '.mov', '.wmv', '.flv', '.webm', '.mpeg']
    return any(filename.lower().endswith(ext) for ext in video_extensions)


def _transform_messages(original_messages):
    transformed_messages = []
    for message in original_messages:
        new_content = []
        for item in message['content']:
            if 'image' in item:
                new_item = {'type': 'image', 'image': item['image']}
            elif 'text' in item:
                new_item = {'type': 'text', 'text': item['text']}
            elif 'video' in item:
                new_item = {'type': 'video', 'video': item['video']}
            else:
                continue
            new_content.append(new_item)

        new_message = {'role': message['role'], 'content': new_content}
        transformed_messages.append(new_message)

    return transformed_messages


def _gc():
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def call_local_model(model, processor, messages):
    """Generate a reply for ``messages`` and stream it back incrementally.

    The messages are normalised with ``_transform_messages``, rendered through
    the processor's chat template, and handed to ``model.generate`` on a
    background thread. Generation is capped at 512 new tokens.

    Args:
        model: Object exposing ``generate`` and ``device``.
        processor: Object exposing ``apply_chat_template``, ``tokenizer`` and
            a call interface that builds the model inputs.
        messages: Chat messages in the format accepted by
            ``_transform_messages``.

    Yields:
        The ``_parse_text`` rendering of everything generated so far, once per
        chunk delivered by the streamer.
    """
    chat_messages = _transform_messages(messages)

    prompt = processor.apply_chat_template(chat_messages, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs = process_vision_info(chat_messages)
    model_inputs = processor(
        text=[prompt],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors='pt',
    ).to(model.device)

    # The streamer is given a 20 s timeout and skips the prompt and special tokens.
    streamer = TextIteratorStreamer(
        processor.tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True
    )

    # generate() runs on its own thread so the text can be consumed from the
    # streamer while generation is still in progress.
    generate_kwargs = {'max_new_tokens': 512, 'streamer': streamer, **model_inputs}
    worker = Thread(target=model.generate, kwargs=generate_kwargs)
    worker.start()

    accumulated = ''
    for piece in streamer:
        accumulated += piece
        yield _parse_text(accumulated)


def create_predict_fn(model, processor):
    """Build a streaming chat handler bound to ``model`` and ``processor``.

    Args:
        model: Model forwarded to ``call_local_model``.
        processor: Processor forwarded to ``call_local_model``.

    Returns:
        A generator function ``predict(_chatbot, task_history)``.

        ``_chatbot`` is a list of ``(query, response)`` display pairs and
        ``task_history`` a list of ``(q, a)`` pairs, where ``q`` is either the
        user's text or a tuple/list whose first element is a file path.
        The generator yields ``_chatbot`` after every streamed update; both
        lists are modified in place.
    """
    def predict(_chatbot, task_history):
        chat_query = _chatbot[-1][0]
        query = task_history[-1][0]
        if len(chat_query) == 0:
            # Empty prompt: discard the pending turn from both histories.
            _chatbot.pop()
            task_history.pop()
            # A `return <value>` inside a generator is never delivered to the
            # code iterating it, so the trimmed chatbot is yielded explicitly.
            yield _chatbot
            return
        print('User: ' + _parse_text(query))
        history_cp = copy.deepcopy(task_history)
        full_response = ''
        messages = []
        content = []
        for q, a in history_cp:
            if isinstance(q, (tuple, list)):
                # File entry: queue it as media for the next text turn.
                if _is_video_file(q[0]):
                    content.append({'video': f'file://{q[0]}'})
                else:
                    content.append({'image': f'file://{q[0]}'})
            else:
                # Text entry closes the user turn (with any queued media) and
                # is followed by the recorded assistant answer.
                content.append({'text': q})
                messages.append({'role': 'user', 'content': content})
                messages.append({'role': 'assistant', 'content': [{'text': a}]})
                content = []
        # Drop the last assistant entry: it belongs to the turn being answered.
        messages.pop()

        for response in call_local_model(model, processor, messages):
            full_response = _parse_text(response)
            _chatbot[-1] = (_parse_text(chat_query), _remove_image_special(full_response))
            yield _chatbot

        task_history[-1] = (query, full_response)
        print('Qwen-VL-Chat: ' + _parse_text(full_response))
        yield _chatbot

    return predict


# Parse the command line and load the model when the module is imported, so
# the route handlers can use the module-level `model` and `processor`.
args = _get_args()
model, processor = _load_model_processor(args)


# Request body accepted by the POST /predict endpoint.
class Item(BaseModel):
    # Image reference; passed through unchanged as the message's 'image' entry.
    image_path: str
    # Prompt text sent to the model together with the image.
    text: str

# GET /health: liveness probe that returns a fixed payload and does not touch
# the model.
@app.get("/health")
async def health_check():
    return {"status": "healthy"}


# POST /predict: run the model once on the given image and prompt, and return
# the final generated text as {"Generated Text": <str>}. The result and the
# elapsed wall-clock time are written to the log.
@app.post("/predict")
async def predict(item: Item):
    # NOTE(review): the generation loop below blocks inside a coroutine, so no
    # other request handled by this event loop progresses until it finishes.
    messages = [
        {
            'role': 'user',
            'content': [
                {'image': item.image_path},
                {'text': item.text}
            ]
        }
    ]
    start = time.time()
    latest = ''
    try:
        # Each streamed value is the full text so far; only the last one is
        # needed, so it is post-processed once after the stream ends.
        for response in call_local_model(model, processor, messages):
            latest = response
        generated_text = _parse_text(latest)
    finally:
        # Free collected objects / cached GPU memory even when generation
        # fails part-way through.
        _gc()
    end = time.time()
    logger.info(f'【{item.image_path}】解析的结果是：{generated_text},耗时为：{end-start}')
    return {"Generated Text": generated_text}


if __name__ == "__main__":
    import uvicorn

    # Script entry point: read the listen address from the INI file named by
    # --config_path and serve the app with uvicorn.
    args = _get_args()
    config = configparser.ConfigParser()
    config.read(args.config_path)

    # [server] ocr_server is expected to look like "<scheme>://<host>:<port>".
    address = config.get('server', 'ocr_server').split('://')[1]
    address_parts = address.split(':')
    host = address_parts[0]
    port = int(address_parts[1])

    uvicorn.run(app, host=host, port=port)



