Unverified Commit bcbbee8c authored by Xiaomeng Zhao's avatar Xiaomeng Zhao Committed by GitHub
Browse files

Merge pull request #2622 from myhloli/dev

Dev
parents 3cc3f754 ced5a7b4
import socket
from api import create_app
from pathlib import Path
import yaml
def get_local_ip():
    """Return the local IP address the host uses for outbound traffic.

    A UDP socket is connected to 8.8.8.8:80 (Google DNS) and the address the
    operating system bound the socket to is read back.

    :return: the local IP address as a string
    :raises OSError: if the socket cannot be created or connected
    """
    # The context manager closes the socket even when connect() or
    # getsockname() raises; the previous explicit close() was skipped then.
    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
        sock.connect(('8.8.8.8', 80))  # Google DNS server
        return sock.getsockname()[0]
# Resolve everything relative to the directory that contains this module.
current_file_path = Path(__file__).resolve()
base_dir = current_file_path.parent
# YAML configuration shipped next to the application code.
config_path = base_dir / "config" / "config.yaml"
class ConfigMap(dict):
    """A dict whose items can also be read and written as attributes.

    ``cfg.PORT`` is equivalent to ``cfg['PORT']`` and ``cfg.PORT = 1`` is
    equivalent to ``cfg['PORT'] = 1``.
    """

    __setattr__ = dict.__setitem__

    def __getattr__(self, name):
        """Return ``self[name]``; raise AttributeError when the key is absent.

        Attribute lookups must fail with AttributeError (not KeyError) so
        that ``hasattr`` and ``getattr(obj, name, default)`` keep working.
        """
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name) from None
# Load the YAML file and select the section named by "CurrentConfig"
# (falling back to the "DevelopmentConfig" section).
with open(str(config_path), mode='r', encoding='utf-8') as fd:
    data = yaml.load(fd, Loader=yaml.FullLoader)
_config = data.get(data.get("CurrentConfig", "DevelopmentConfig"))

# Copy the selected section into an attribute-accessible mapping.
config = ConfigMap(_config.items())
config['base_dir'] = base_dir

# Build the SQLAlchemy connection string from the optional "database" entry.
database = _config.get("database")
if database:
    db_type = database.get("type")
    if db_type == "sqlite":
        database_uri = f'sqlite:///{base_dir}/{database.get("path")}'
    elif db_type == "mysql":
        database_uri = f'mysql+pymysql://{database.get("user")}:{database.get("password")}@{database.get("host")}:{database.get("port")}/{database.get("database")}?'
    else:
        # Unknown database type: leave the URI empty.
        database_uri = ''
    config['SQLALCHEMY_DATABASE_URI'] = database_uri

ip_address = get_local_ip()
port = config.get("PORT", 5559)

config.update({
    # SERVER_NAME is built from the detected local address and the port
    'SERVER_NAME': f'{ip_address}:{port}',
    # APPLICATION_ROOT
    'APPLICATION_ROOT': '/',
    # PREFERRED_URL_SCHEME
    'PREFERRED_URL_SCHEME': 'http',
})

app = create_app(config)

if __name__ == '__main__':
    app.run(host="0.0.0.0", port=port, debug=config.get("DEBUG", False))
from flask import jsonify
class ResponseCode:
    """Numeric codes and the default message used in API response payloads."""

    SUCCESS = 200
    # Historical misspelling; kept so existing references keep working.
    PARAM_WARING = 400
    # Correctly spelled alias for the same value.
    PARAM_WARNING = PARAM_WARING
    MESSAGE = "success"
def generate_response(data=None, code=ResponseCode.SUCCESS, msg=ResponseCode.MESSAGE, **kwargs):
    """
    Build the JSON envelope returned by the API.

    The HTTP status is always 200; the business outcome is carried in the
    ``code`` and ``success`` fields of the body.

    NOTE: the default ``msg`` is "success", so the 'fail' fallback is only
    used when the caller passes a falsy ``msg`` together with a non-200 code.

    :param data: payload to return
    :param code: business status code
    :param msg: message to return
    :param kwargs: extra fields merged into the JSON body
    :return: the JSON response object
    """
    succeeded = bool(code == 200)
    if succeeded:
        msg = msg or 'success'
    else:
        msg = msg or 'fail'
    payload = dict(code=code, success=succeeded, data=data, msg=msg, **kwargs)
    res = jsonify(payload)
    res.status_code = 200
    return res
import json
from flask import request
from werkzeug.exceptions import HTTPException
class ApiException(HTTPException):
    """Base class for API errors; renders itself as a JSON body."""

    code = 500
    msg = 'Sorry, we made a mistake Σ(っ °Д °;)っ'
    msgZH = ""
    error_code = 999

    def __init__(self, msg=None, msgZH=None, code=None, error_code=None, headers=None):
        """Override the class-level defaults with any truthy argument.

        ``headers`` is accepted but not used by this constructor.
        """
        overrides = (
            ('code', code),
            ('msg', msg),
            ('msgZH', msgZH),
            ('error_code', error_code),
        )
        for attr_name, value in overrides:
            if value:
                setattr(self, attr_name, value)
        super().__init__(msg, None)

    @staticmethod
    def get_error_url():
        """Return "<METHOD> <path>" for the current request, without the query string."""
        method = request.method
        path_only = str(request.full_path).split('?')[0]
        return ' '.join((method, path_only))

    def get_body(self, environ=None, scope=None):
        """Return the error payload serialized as a JSON string."""
        return json.dumps({
            'msg': self.msg,
            'error_code': self.error_code,
            'request': self.get_error_url(),
        })

    def get_headers(self, environ=None, scope=None):
        """Return the response headers: the body is always JSON."""
        return [("Content-Type", "application/json")]
\ No newline at end of file
import hashlib
import mimetypes
import urllib.parse
def is_pdf(filename, file):
    """
    Decide whether an uploaded file is a PDF.

    The (URL-decoded) file name is used to guess a MIME type. A name that maps
    to a non-PDF MIME type is rejected outright; otherwise (PDF MIME type, or
    no MIME type could be guessed) the first bytes of the content must be the
    ``%PDF-`` signature. The file position is reset to 0 after every read.

    :param filename: file name, possibly URL-encoded
    :param file: binary file object supporting ``read`` and ``seek``
    :return: True if the file is considered a PDF, otherwise False
    """
    try:
        # URL-decode the file name so special characters are handled
        decoded_filename = urllib.parse.unquote(filename)
        mime_type, _ = mimetypes.guess_type(decoded_filename)
        print(f"Detected MIME type: {mime_type}")
        # A known MIME type that is not PDF settles the question.
        if mime_type is not None and mime_type != 'application/pdf':
            return False
        return _has_pdf_signature(file)
    except Exception as e:
        print(f"Error checking PDF format: {str(e)}")
        # On error, still try to decide from the file header alone.
        try:
            return _has_pdf_signature(file)
        except Exception:
            # Was a bare ``except:``, which also swallowed KeyboardInterrupt
            # and SystemExit.
            return False


def _has_pdf_signature(file):
    """Return True if *file* starts with ``%PDF-``; rewinds the file to 0."""
    file_start = file.read(5)
    file.seek(0)  # reset the file pointer
    return file_start.startswith(b'%PDF-')
def url_is_pdf(file):
    """
    Check whether a file object holds a PDF by looking at its content only.

    :param file: binary file object supporting ``read`` and ``seek``
    :return: True if the content starts with ``%PDF-``, otherwise False
    """
    header = file.read(5)
    file.seek(0)  # rewind so the caller can read the file from the start
    return header.startswith(b'%PDF-')
def calculate_file_hash(file, algorithm='sha256'):
    """
    Compute the hex digest of a file object's content.

    The file is read in 64 KB chunks from its current position and rewound to
    position 0 afterwards.

    :param file: binary file object supporting ``read`` and ``seek``
    :param algorithm: name of a hashlib constructor, e.g. 'sha256', 'md5', 'sha1'
    :return: the hexadecimal digest string
    """
    digest = getattr(hashlib, algorithm)()
    chunk_size = 65536  # 64KB chunks
    while True:
        chunk = file.read(chunk_size)
        if len(chunk) == 0:
            break
        digest.update(chunk)
    file.seek(0)
    return digest.hexdigest()
def singleton_func(cls):
    """Decorator that makes *cls* produce a single shared instance.

    The first call constructs the instance with the given arguments; every
    later call returns that same object and ignores its arguments.
    """
    created = {}

    def _singleton(*args, **kwargs):
        if cls in created:
            return created[cls]
        obj = cls(*args, **kwargs)
        created[cls] = obj
        return obj

    return _singleton
from api.analysis.models import *
\ No newline at end of file
import os
from loguru import logger
from pathlib import Path
from datetime import datetime
def setup_log(config):
    """
    Register a daily-rotating file sink with loguru.

    The file is written to a ``log`` directory two levels above this module
    and is named after the current date.

    :param config: mapping providing the "LOG_LEVEL" entry
    :return: None
    """
    log_dir = Path(__file__).parent.parent / "log"
    if not log_dir.exists():
        log_dir.mkdir(parents=True, exist_ok=True)
    log_level = config.get("LOG_LEVEL")
    today = datetime.now().strftime("%Y-%m-%d")
    log_file_path = log_dir / f'log_{today}.log'
    logger.add(
        str(log_file_path),
        rotation='00:00',
        encoding='utf-8',
        level=log_level,
        enqueue=True,
    )
import os
import unicodedata
# When FTLANG_CACHE is unset or empty, point it at
# <parent of this module's directory>/resources/fasttext-langdetect.
# NOTE(review): presumably read by fast_langdetect (imported right after) to
# locate its model files — confirm.
if not os.environ.get("FTLANG_CACHE"):
    current_file_path = os.path.abspath(__file__)
    current_dir = os.path.dirname(current_file_path)
    root_dir = os.path.dirname(current_dir)
    ftlang_cache_dir = os.path.join(root_dir, 'resources', 'fasttext-langdetect')
    os.environ["FTLANG_CACHE"] = ftlang_cache_dir
    # print(os.getenv("FTLANG_CACHE"))
from fast_langdetect import detect_language
def detect_lang(text: str) -> str:
    """Detect the language of *text* and return its code in lower case.

    Empty (or None) input yields "". If the first detection attempt fails,
    the text is retried with all Unicode control-category characters (major
    category "C") removed. A failure of that retry propagates to the caller.

    :param text: text to classify
    :return: lower-cased language code, or "" when it cannot be lower-cased
    """
    if not text:
        return ""
    try:
        lang_upper = detect_language(text)
    except Exception:
        # Retry without control characters (was a bare ``except:``, which
        # also intercepted KeyboardInterrupt/SystemExit).
        html_no_ctrl_chars = ''.join(
            ch for ch in text if unicodedata.category(ch)[0] != 'C'
        )
        lang_upper = detect_language(html_no_ctrl_chars)
    try:
        lang = lang_upper.lower()
    except Exception:
        # The detector returned something that is not string-like.
        lang = ""
    return lang
if __name__ == '__main__':
    # Manual smoke test: show the cache location, then classify a few samples.
    print(os.getenv("FTLANG_CACHE"))
    samples = (
        "This is a test.",
        "<html>This is a test</html>",
        "这个是中文测试。",
        "<html>这个是中文测试。</html>",
    )
    for sample in samples:
        print(detect_lang(sample))
import re
def escape_special_markdown_char(pymu_blocks):
    """
    Escape characters that carry markdown meaning in the text of every span.

    Spans whose "_type" is 'inline-equation' or 'interline-equation' are left
    untouched. The blocks are modified in place and also returned.
    """
    equation_types = ('inline-equation', 'interline-equation')
    special_chars = ("*", "`", "~", "$")
    for block in pymu_blocks:
        for line in block['lines']:
            for span in line['spans']:
                text = span['text']
                if span.get("_type", None) in equation_types:
                    continue
                if not text:
                    continue
                for ch in special_chars:
                    text = text.replace(ch, "\\" + ch)
                span['text'] = text
    return pymu_blocks
def ocr_escape_special_markdown_char(content):
    """
    Return *content* with the markdown-significant characters
    ``*``, `````, ``~`` and ``$`` prefixed by a backslash.
    """
    escaped = content
    for special in ("*", "`", "~", "$"):
        escaped = ("\\" + special).join(escaped.split(special))
    return escaped
class ContentType:
    """String tags stored in a span's 'type' field."""

    # Non-text content
    Image = 'image'
    Table = 'table'
    # Textual content
    Text = 'text'
    # Formula content
    InlineEquation = 'inline_equation'
    InterlineEquation = 'interline_equation'
class BlockType:
    """String tags stored in a block's 'type' field."""

    # Image container and its parts
    Image = 'image'
    ImageBody = 'image_body'
    ImageCaption = 'image_caption'
    ImageFootnote = 'image_footnote'

    # Table container and its parts
    Table = 'table'
    TableBody = 'table_body'
    TableCaption = 'table_caption'
    TableFootnote = 'table_footnote'

    # Remaining block kinds
    Text = 'text'
    Title = 'title'
    InterlineEquation = 'interline_equation'
    Footnote = 'footnote'
    Discarded = 'discarded'
class CategoryId:
    """Integer category identifiers.

    NOTE(review): presumably the class ids emitted by the layout/formula/OCR
    models — confirm against the model code.
    """

    Title = 0
    Text = 1
    Abandon = 2
    ImageBody = 3
    ImageCaption = 4
    TableBody = 5
    TableCaption = 6
    TableFootnote = 7
    InterlineEquation_Layout = 8
    # Ids 9-12 are not assigned here.
    InlineEquation = 13
    InterlineEquation_YOLO = 14
    OcrText = 15
    ImageFootnote = 101
import re
import wordninja
from .libs.language import detect_lang
from .libs.markdown_utils import ocr_escape_special_markdown_char
from .libs.ocr_content_type import BlockType, ContentType
def __is_hyphen_at_line_end(line):
    """
    Tell whether a line ends in a hyphenated word fragment.

    Args:
        line (str): The line of text to check.
    Returns:
        bool: True if the line ends with one or more ASCII letters followed by
        a hyphen (trailing whitespace allowed), False otherwise.
    """
    trailing_hyphen = re.search(r'[A-Za-z]+-\s*$', line)
    return trailing_hyphen is not None
def split_long_words(text):
    """Break up over-long words in *text* with wordninja.

    The text is split on single spaces; inside each piece every run of word
    characters longer than 10 characters is replaced by its wordninja split
    joined with spaces. All other characters are kept as they are.
    """

    def _rewrite(segment):
        tokens = re.findall(r'\w+|[^\w]', segment, re.UNICODE)
        return ''.join(
            ' '.join(wordninja.split(token)) if len(token) > 10 else token
            for token in tokens
        )

    return ' '.join(_rewrite(segment) for segment in text.split(' '))
def join_path(*args):
    """Concatenate the string forms of *args*, dropping trailing '/' from each.

    No separator is inserted between the parts.
    """
    trimmed = [str(part).rstrip('/') for part in args]
    return ''.join(trimmed)
def ocr_mk_mm_markdown_with_para_and_pagination(pdf_info_dict: list,
                                                img_buket_path):
    """Render each page's paragraph blocks to markdown in 'mm' mode.

    Pages without 'para_blocks' are skipped and do not consume a page number,
    so 'page_no' counts emitted pages starting at 0.

    :param pdf_info_dict: list of per-page dicts, each optionally holding 'para_blocks'
    :param img_buket_path: prefix passed on for building image links
    :return: list of {'page_no': int, 'md_content': str} dicts
    """
    pages = []
    for page_info in pdf_info_dict:
        paras_of_layout = page_info.get('para_blocks')
        if not paras_of_layout:
            continue
        page_markdown = ocr_mk_markdown_with_para_core_v2(
            paras_of_layout, 'mm', img_buket_path)
        # The number of pages emitted so far is the next page number.
        pages.append({
            'page_no': len(pages),
            'md_content': '\n\n'.join(page_markdown),
        })
    return pages
def merge_para_with_text(para_block):
    """Flatten the lines and spans of a paragraph block into one string.

    Text spans are markdown-escaped (English-looking spans additionally get
    long words split), inline equations are wrapped in ``$`` and interline
    equations in ``$$``. Whether a space follows each piece depends on the
    language detected for the whole line.

    :param para_block: dict with 'lines', each line holding 'spans'
    :return: the merged paragraph text
    """

    def detect_language(text):
        # Cheap heuristic: 'en' when at least half of the characters are
        # ASCII letters, 'unknown' otherwise, 'empty' for empty input.
        if len(text) == 0:
            return 'empty'
        latin_count = sum(len(run) for run in re.findall(r'[a-zA-Z]+', text))
        return 'en' if latin_count / len(text) >= 0.5 else 'unknown'

    no_space_langs = ['zh', 'ja', 'ko']
    pieces = []
    for line in para_block['lines']:
        spans = line['spans']

        # Detect the language on the whole line: some documents put a single
        # character in each span, which is too little to classify reliably.
        line_text = ''.join(
            span['content'].strip()
            for span in spans
            if span['type'] == ContentType.Text
        )
        line_lang = detect_lang(line_text) if line_text != '' else ''

        for span in spans:
            span_type = span['type']
            content = ''
            if span_type == ContentType.Text:
                content = span['content']
                # Only English text gets long-word splitting; splitting
                # Chinese text would lose characters.
                if detect_language(content) == 'en':
                    content = ocr_escape_special_markdown_char(
                        split_long_words(content))
                else:
                    content = ocr_escape_special_markdown_char(content)
            elif span_type == ContentType.InlineEquation:
                content = f" ${span['content']}$ "
            elif span_type == ContentType.InterlineEquation:
                content = f"\n$$\n{span['content']}\n$$\n"

            if content == '':
                continue

            if line_lang in no_space_langs:
                # Chinese/Japanese/Korean: no separator between pieces.
                pieces.append(content)
            elif line_lang == 'en':
                if __is_hyphen_at_line_end(content):
                    # Hyphenated line break: drop the last character and add
                    # no space.
                    pieces.append(content[:-1])
                else:
                    pieces.append(content + ' ')
            else:
                # Other languages: pieces are separated by a space.
                pieces.append(content + ' ')
    return ''.join(pieces)
def ocr_mk_markdown_with_para_core_v2(paras_of_layout,
                                      mode,
                                      img_buket_path=''):
    """Render a page's paragraph blocks into a list of markdown strings.

    Text, title and interline-equation blocks are always rendered. Image and
    table blocks are skipped in 'nlp' mode and rendered in 'mm' mode; with any
    other mode they produce no text. Blocks that render to nothing but
    whitespace are omitted.

    :param paras_of_layout: iterable of block dicts with a 'type' key
    :param mode: 'nlp' or 'mm'
    :param img_buket_path: prefix used when building image links
    :return: list with one markdown string per rendered block
    """

    def _image_markdown(span):
        return f"\n![]({join_path(img_buket_path, span['image_path'])}) \n"

    def _merged_text(blocks, wanted_type):
        # Text of every sub-block of the wanted type, in document order.
        return ''.join(
            merge_para_with_text(block)
            for block in blocks
            if block['type'] == wanted_type
        )

    page_markdown = []
    for para_block in paras_of_layout:
        para_type = para_block['type']
        para_text = ''

        if para_type in (BlockType.Text, BlockType.InterlineEquation):
            para_text = merge_para_with_text(para_block)
        elif para_type == BlockType.Title:
            para_text = f'# {merge_para_with_text(para_block)}'
        elif para_type == BlockType.Image:
            if mode == 'nlp':
                continue
            if mode == 'mm':
                sub_blocks = para_block['blocks']
                # 1st: image body
                for block in sub_blocks:
                    if block['type'] != BlockType.ImageBody:
                        continue
                    for line in block['lines']:
                        for span in line['spans']:
                            if span['type'] == ContentType.Image:
                                para_text += _image_markdown(span)
                # 2nd: image caption
                para_text += _merged_text(sub_blocks, BlockType.ImageCaption)
                # 3rd: image footnote
                para_text += _merged_text(sub_blocks, BlockType.ImageFootnote)
        elif para_type == BlockType.Table:
            if mode == 'nlp':
                continue
            if mode == 'mm':
                sub_blocks = para_block['blocks']
                # 1st: table caption
                para_text += _merged_text(sub_blocks, BlockType.TableCaption)
                # 2nd: table body
                for block in sub_blocks:
                    if block['type'] != BlockType.TableBody:
                        continue
                    for line in block['lines']:
                        for span in line['spans']:
                            if span['type'] != ContentType.Table:
                                continue
                            # Prefer table-model output (latex, then html);
                            # otherwise fall back to the table image.
                            if span.get('latex', ''):
                                para_text += f"\n\n$\n {span['latex']}\n$\n\n"
                            elif span.get('html', ''):
                                para_text += f"\n\n{span['html']}\n\n"
                            else:
                                para_text += _image_markdown(span)
                # 3rd: table footnote
                para_text += _merged_text(sub_blocks, BlockType.TableFootnote)

        stripped = para_text.strip()
        if stripped != '':
            page_markdown.append(stripped + ' ')
    return page_markdown
def before_request():
    """Pre-request hook placeholder; does nothing and returns None."""
    return
def after_request(response):
    """Add permissive CORS headers to *response* and return it."""
    cors_headers = (
        ('Access-Control-Allow-Origin', '*'),
        ('Access-Control-Allow-Headers', 'Content-Type,Authorization'),
    )
    for header_name, header_value in cors_headers:
        response.headers.add(header_name, header_value)
    return response
# Base configuration
BaseConfig: &base
  DEBUG: false
  PORT: 5559
  LOG_LEVEL: "DEBUG"
  SQLALCHEMY_TRACK_MODIFICATIONS: true
  SQLALCHEMY_DATABASE_URI: ""
  PROPAGATE_EXCEPTIONS: true
  SECRET_KEY: "#$%^&**$##*(*^%%$**((&"
  JWT_SECRET_KEY: "#$%^&**$##*(*^%%$**((&"
  JWT_ACCESS_TOKEN_EXPIRES: 3600
  PDF_UPLOAD_FOLDER: "upload_pdf"
  PDF_ANALYSIS_FOLDER: "analysis_pdf"
  # Path of the built front-end bundle
  REACT_APP_DIST: "../../web/dist/"
  # File access path
  FILE_API: "/api/v2/analysis/pdf_img?as_attachment=False"
# Development configuration
DevelopmentConfig:
  <<: *base
  database:
    type: sqlite
    path: config/mineru_web.db
# Production configuration
ProductionConfig:
  <<: *base
# Testing configuration
TestingConfig:
  <<: *base
# Configuration section currently in use
CurrentConfig: "DevelopmentConfig"
import json
import shutil
import os
import requests
from modelscope import snapshot_download
def download_json(url, timeout=30):
    """Download *url* and return the decoded JSON document.

    :param url: address of the JSON file
    :param timeout: seconds to wait for the server before giving up; without
        a timeout ``requests.get`` can block indefinitely
    :return: the parsed JSON content
    :raises requests.HTTPError: if the server answers with an error status
    """
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()  # fail loudly on HTTP error statuses
    return response.json()
def download_and_modify_json(url, local_filename, modifications):
    """Create or refresh a local JSON config and apply *modifications* to it.

    An existing file is kept when its 'config_version' is at least 1.2.0;
    otherwise (older version, or no file at all) the template is downloaded
    from *url*. The entries of *modifications* are then written on top and
    the result is saved to *local_filename* as UTF-8.

    :param url: location of the template JSON
    :param local_filename: path of the local config file
    :param modifications: mapping of top-level keys to set
    """
    if os.path.exists(local_filename):
        # Read with the same encoding the file is written with, and close
        # the handle deterministically.
        with open(local_filename, encoding='utf-8') as f:
            data = json.load(f)
        config_version = data.get('config_version', '0.0.0')
        # Compare numerically: as plain strings '1.10.0' sorts before '1.2.0'.
        if _version_key(config_version) < _version_key('1.2.0'):
            data = download_json(url)
    else:
        data = download_json(url)

    # Apply the requested changes
    for key, value in modifications.items():
        data[key] = value

    # Save the modified content
    with open(local_filename, 'w', encoding='utf-8') as f:
        json.dump(data, f, ensure_ascii=False, indent=4)


def _version_key(version):
    """Turn a dotted version such as '1.10.0' into a tuple of ints.

    A version that cannot be parsed sorts lowest, which forces a re-download.
    """
    try:
        return tuple(int(piece) for piece in str(version).split('.'))
    except ValueError:
        return (0,)
if __name__ == '__main__':
    # Only these parts of the model repository are downloaded.
    mineru_patterns = [
        # "models/Layout/LayoutLMv3/*",
        "models/Layout/YOLO/*",
        "models/MFD/YOLO/*",
        "models/MFR/unimernet_hf_small_2503/*",
        "models/OCR/paddleocr_torch/*",
        # "models/TabRec/TableMaster/*",
        # "models/TabRec/StructEqTable/*",
    ]
    model_dir = snapshot_download('opendatalab/PDF-Extract-Kit-1.0', allow_patterns=mineru_patterns)
    layoutreader_model_dir = snapshot_download('ppaanngggg/layoutreader')
    # The models live in the "models" sub-directory of the snapshot.
    model_dir = f'{model_dir}/models'
    print(f'model_dir is: {model_dir}')
    print(f'layoutreader_model_dir is: {layoutreader_model_dir}')

    # paddleocr_model_dir = model_dir + '/OCR/paddleocr'
    # user_paddleocr_dir = os.path.expanduser('~/.paddleocr')
    # if os.path.exists(user_paddleocr_dir):
    # shutil.rmtree(user_paddleocr_dir)
    # shutil.copytree(paddleocr_model_dir, user_paddleocr_dir)

    # Write the config template into the user's home directory with the
    # downloaded model locations filled in.
    json_url = 'https://gcore.jsdelivr.net/gh/opendatalab/MinerU@master/magic-pdf.template.json'
    config_file_name = 'magic-pdf.json'
    home_dir = os.path.expanduser('~')
    config_file = os.path.join(home_dir, config_file_name)
    json_mods = dict([
        ('models-dir', model_dir),
        ('layoutreader-model-dir', layoutreader_model_dir),
    ])
    download_and_modify_json(json_url, config_file, json_mods)
    print(f'The configuration file has been configured successfully, the path is: {config_file}')
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment