"""Command-line client that sends an image path and a text prompt to a prediction server.

The script reads the server URL from an INI file, recompresses the image in
place, posts the request and logs the generated text together with the
elapsed time.
"""
import argparse
import configparser
import json
import os
import re
import time
from typing import Optional, Tuple

import requests
from loguru import logger
from PIL import Image

# NOTE(review): the HTML tags and entities below were reconstructed after the
# markup inside the original string literals was lost; confirm them against
# the upstream version of this file.
# Characters escaped inside fenced code blocks, paired with their HTML entity.
_CODE_ESCAPES = (
    ("<", "&lt;"),
    (">", "&gt;"),
    (" ", "&nbsp;"),
    ("*", "&ast;"),
    ("_", "&lowbar;"),
    ("-", "&#45;"),
    (".", "&#46;"),
    ("!", "&#33;"),
    ("(", "&#40;"),
    (")", "&#41;"),
    ("$", "&#36;"),
)

# One-pass translation table; no replacement text contains a character that is
# itself escaped, so this is equivalent to applying the replacements in order.
_ESCAPE_TABLE = str.maketrans({"`": r"\`", **dict(_CODE_ESCAPES)})

_CODE_OPEN_RE = re.compile(r"<pre><code[^>]*>")
_CODE_CLOSE_RE = re.compile(r"(?:<br>)?</code></pre>")


def parse_args():
    """Parse the command-line options and return the resulting namespace."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--config_path',
        default='/home/practice/magic_pdf-main/magic_pdf/config.ini',
        help='Path to the INI file holding the [server] ocr_server setting',
    )
    parser.add_argument(
        '--image_path',
        default='/path/to/your/image.png',
        help='Path to the image file'
    )
    parser.add_argument(
        '--text',
        default="描述你在图片中看到的内容",
        help='Text input for the model'
    )
    return parser.parse_args()


def parse_text(text: str) -> str:
    """Convert Markdown-style text into a single HTML-flavoured line.

    Blank lines are dropped. A line containing a triple-backtick fence becomes
    an opening ``<pre><code class="language-...">`` tag (odd occurrences) or a
    closing ``</code></pre>`` tag (even occurrences). Every other line is
    appended to the previous output item with a single space; lines inside a
    fenced block have their special characters escaped first.

    Args:
        text: The raw text to convert.

    Returns:
        The converted text with all items concatenated.
    """
    fence_count = 0
    parsed_lines = []
    for line in text.split("\n"):
        if not line.strip():
            continue  # drop blank lines
        if "```" in line:
            fence_count += 1
            if fence_count % 2 == 1:
                # Opening fence: whatever follows the last backtick is the language.
                language = line.split("`")[-1]
                parsed_lines.append(f'<pre><code class="language-{language}">')
            else:
                # Closing fence.
                parsed_lines.append("</code></pre>")
            continue
        if fence_count % 2 == 1:
            # Inside a code block: escape characters that would otherwise be
            # interpreted as Markdown or HTML.
            line = line.translate(_ESCAPE_TABLE)
        # Join the line onto the previous item with a space.
        if parsed_lines:
            parsed_lines[-1] += " " + line
        else:
            parsed_lines.append(line)
    return "".join(parsed_lines)


def unparse_text(parsed_text: str) -> str:
    """Strip code-block tags and reverse the escaping applied by ``parse_text``.

    Opening ``<pre><code ...>`` and closing ``</code></pre>`` tags are removed
    wherever they occur on a line, HTML entities are turned back into their
    characters on every line, and backslash-escaped backticks are restored on
    lines that belong to a code block.

    Args:
        parsed_text: Text that may contain the tags and entities above.

    Returns:
        The cleaned text, with the original line breaks preserved.
    """
    in_code_block = False
    unparsed_lines = []
    for line in parsed_text.split("\n"):
        open_pos = max((m.start() for m in _CODE_OPEN_RE.finditer(line)), default=-1)
        close_pos = max((m.start() for m in _CODE_CLOSE_RE.finditer(line)), default=-1)
        # The line holds code if a block was already open or one opens here.
        holds_code = in_code_block or open_pos >= 0

        if open_pos >= 0 or close_pos >= 0:
            line = _CODE_OPEN_RE.sub("", line)
            line = _CODE_CLOSE_RE.sub("", line)
            # Whichever tag comes last decides the state for the next line.
            in_code_block = open_pos > close_pos

        # Reverse the HTML entities.
        for char, entity in _CODE_ESCAPES:
            line = line.replace(entity, char)

        # Restore backslash-escaped backticks inside code blocks.
        if holds_code:
            line = line.replace(r"\`", "`")

        unparsed_lines.append(line)

    return "\n".join(unparsed_lines)


def compress_image(image_path: str, max_size: Tuple[int, int] = (2048, 2048)) -> None:
    """Shrink an image to fit within ``max_size`` and re-save it in place.

    The aspect ratio is preserved and images already within the limits are not
    resized. The file at ``image_path`` is always overwritten, because the
    image is re-saved with ``optimize=True, quality=80`` either way.

    Args:
        image_path: Path of the image to compress; the file is overwritten.
        max_size: Maximum ``(width, height)`` in pixels.
    """
    max_width, max_height = max_size
    with Image.open(image_path) as source:
        width, height = source.size
        output = source
        if width > max_width or height > max_height:
            # Scale by the more restrictive side so both limits are respected.
            # Integer arithmetic avoids float rounding; max(1, ...) prevents a
            # zero-pixel side for extreme aspect ratios.
            if width * max_height >= height * max_width:
                new_width = max_width
                new_height = max(1, height * max_width // width)
            else:
                new_height = max_height
                new_width = max(1, width * max_height // height)
            output = source.resize((new_width, new_height), Image.LANCZOS)
        output.save(image_path, optimize=True, quality=80)


class PredictClient:
    """Minimal HTTP client for the prediction service's /health and /predict endpoints."""

    def __init__(self, api_url: str, timeout: Optional[float] = None,
                 health_timeout: Optional[float] = 5.0):
        """Store the connection settings.

        Args:
            api_url: Base URL of the service.
            timeout: Seconds to wait for a prediction; ``None`` waits indefinitely.
            health_timeout: Seconds to wait for the health check.
        """
        self.api_url = api_url
        self.timeout = timeout
        self.health_timeout = health_timeout

    def _url(self, endpoint: str) -> str:
        """Join the base URL and an endpoint name without doubling the slash."""
        return f"{self.api_url.rstrip('/')}/{endpoint}"

    def check_health(self) -> bool:
        """Return True if the /health endpoint answers with HTTP 200, else False."""
        try:
            response = requests.get(self._url("health"), timeout=self.health_timeout)
        except requests.exceptions.RequestException as e:
            logger.error(f'Health check request failed:{e}')
            return False

        if response.status_code == 200:
            logger.info("Server is healthy and ready to process requests.")
            return True
        logger.error(f'Server health check failed with status code:{response.status_code}')
        return False

    def predict(self, image_path: str, text: str):
        """Post the image path and prompt to /predict and return the generated text.

        Args:
            image_path: Path of the image, sent to the server as-is.
            text: Prompt text for the model.

        Returns:
            The value of the ``'Generated Text'`` key in the JSON response, or
            an empty string when the key is absent.

        Raises:
            RuntimeError: If the server answers with a status other than 200.
            requests.exceptions.RequestException: If the request itself fails.
        """
        payload = {
            "image_path": image_path,
            "text": text
        }
        headers = {'Content-Type': 'application/json'}
        response = requests.post(self._url("predict"), json=payload,
                                 headers=headers, timeout=self.timeout)
        if response.status_code != 200:
            raise RuntimeError(
                f"Predict API request failed with status code {response.status_code}"
            )
        return response.json().get('Generated Text', '')


def main():
    """Read the config, compress the image, run one prediction and log the result."""
    args = parse_args()

    config = configparser.ConfigParser()
    # ConfigParser.read() silently skips unreadable files, so check its result.
    if not config.read(args.config_path):
        logger.error(f"Could not read config file: {args.config_path}")
        return
    try:
        ocr_server = config.get('server', 'ocr_server')
    except configparser.Error as e:
        logger.error(f"Invalid config file {args.config_path}: {e}")
        return

    client = PredictClient(ocr_server)

    try:
        start_time = time.perf_counter()  # start of the timed section

        # Compress the image (overwrites the file at args.image_path).
        compress_image(args.image_path)

        generated_text = client.predict(args.image_path, parse_text(args.text))

        elapsed_time = time.perf_counter() - start_time  # compression + request

        if generated_text:
            clean_text = unparse_text(generated_text)  # strip tags and entities
            logger.info(f"Image Path: {args.image_path}")
            logger.info(f"Generated Text: {clean_text}")
            logger.info(f"耗时为: {elapsed_time}秒")  # log the elapsed time
        else:
            logger.warning("Received empty generated text.")
    except requests.exceptions.RequestException as e:
        logger.error(f"Error while making request to predict service: {e}")
    except Exception as e:
        logger.error(f"Unexpected error occurred: {e}")


if __name__ == "__main__":
    main()