Commit 6c151151 authored by myhloli's avatar myhloli
Browse files

fix(mfr): improve LaTeX formula processing and repair

- Add functions to fix LaTeX left and right commands
- Implement brace matching and repair in LaTeX formulas
- Remove unnecessary whitespace and repair LaTeX code
- Replace specific LaTeX commands with appropriate alternatives
- Add logging for debugging purposes
parent bfb80cb2
...@@ -5,6 +5,7 @@ from typing import Optional ...@@ -5,6 +5,7 @@ from typing import Optional
import torch import torch
from ftfy import fix_text from ftfy import fix_text
from loguru import logger
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, AutoTokenizer, PretrainedConfig, PreTrainedModel from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, AutoTokenizer, PretrainedConfig, PreTrainedModel
from transformers import VisionEncoderDecoderConfig, VisionEncoderDecoderModel from transformers import VisionEncoderDecoderConfig, VisionEncoderDecoderModel
...@@ -57,9 +58,296 @@ class TokenizerWrapper: ...@@ -57,9 +58,296 @@ class TokenizerWrapper:
return toks return toks
def fix_latex_left_right(s):
"""
修复LaTeX中的\left和\right命令
1. 确保它们后面跟有效分隔符
2. 平衡\left和\right的数量
"""
# 白名单分隔符
valid_delims = r'[()\[\]{}/|]|\\{|\\}|\\lceil|\\rceil|\\lfloor|\\rfloor|/|\\backslash|\\uparrow|\\downarrow|\\Uparrow|\\Downarrow|\\||\\.'
# 为\left后缺失有效分隔符的情况添加点
def fix_left_delim(match):
cmd = match.group(1) # \left
rest = match.group(2) if len(match.groups()) > 1 else ""
if not rest or not re.match(f"^({valid_delims})", rest):
return cmd + "."
return match.group(0)
# 为\right后缺失有效分隔符的情况添加点
def fix_right_delim(match):
cmd = match.group(1) # \right
rest = match.group(2) if len(match.groups()) > 1 else ""
if not rest or not re.match(f"^({valid_delims})", rest):
return cmd + "."
return match.group(0)
# 使用更精确的模式匹配\left和\right命令
# 确保它们是独立的命令,不是其他命令的一部分
s = re.sub(r'(\\left)(\S*)', fix_left_delim, s)
s = re.sub(r'(\\right)(\S*)', fix_right_delim, s)
# 更精确地计算\left和\right的数量
left_count = len(re.findall(r'\\left(?![a-zA-Z])', s)) # 不匹配\lefteqn等
right_count = len(re.findall(r'\\right(?![a-zA-Z])', s)) # 不匹配\rightarrow等
if left_count != right_count:
logger.debug(f"latex:{s}")
logger.warning(f"left_count: {left_count}, right_count: {right_count}")
if left_count > right_count:
s += ''.join(['\\right.' for _ in range(left_count - right_count)])
elif right_count > left_count:
# 不再在开头插入\left.,而是在第一个\right前插入
if '\\right' in s:
# 找出所有\right的位置
right_positions = [m.start() for m in re.finditer(r'\\right(?![a-zA-Z])', s)]
# 从前到后为每个缺失的\left处理
new_s = s
offset = 0
for i in range(min(right_count - left_count, len(right_positions))):
pos = right_positions[i] + offset
new_s = new_s[:pos] + '\\left.' + new_s[pos:]
offset += 6 # \left.的长度
s = new_s
return fix_left_right_pairs(s)
def fix_left_right_pairs(latex_formula):
"""
检测并修复LaTeX公式中\left和\right不在同一组的情况
Args:
latex_formula (str): 输入的LaTeX公式
Returns:
str: 修复后的LaTeX公式
"""
# 用于跟踪花括号嵌套层级
brace_stack = []
# 用于存储\left信息: (位置, 深度, 分隔符)
left_stack = []
# 存储需要调整的\right信息: (开始位置, 结束位置, 目标位置)
adjustments = []
i = 0
while i < len(latex_formula):
# 检查是否是转义字符
if i > 0 and latex_formula[i - 1] == '\\':
backslash_count = 0
j = i - 1
while j >= 0 and latex_formula[j] == '\\':
backslash_count += 1
j -= 1
if backslash_count % 2 == 1:
i += 1
continue
# 检测\left命令
if i + 5 < len(latex_formula) and latex_formula[i:i + 5] == "\\left" and i + 5 < len(latex_formula):
delimiter = latex_formula[i + 5]
left_stack.append((i, len(brace_stack), delimiter))
i += 6 # 跳过\left和分隔符
continue
# 检测\right命令
elif i + 6 < len(latex_formula) and latex_formula[i:i + 6] == "\\right" and i + 6 < len(latex_formula):
delimiter = latex_formula[i + 6]
if left_stack:
left_pos, left_depth, left_delim = left_stack.pop()
# 如果\left和\right不在同一花括号深度
if left_depth != len(brace_stack):
# 找到\left所在花括号组的结束位置
target_pos = find_group_end(latex_formula, left_pos, left_depth)
if target_pos != -1:
# 记录需要移动的\right
adjustments.append((i, i + 7, target_pos))
i += 7 # 跳过\right和分隔符
continue
# 处理花括号
if latex_formula[i] == '{':
brace_stack.append(i)
elif latex_formula[i] == '}':
if brace_stack:
brace_stack.pop()
i += 1
# 应用调整,从后向前处理以避免索引变化
if not adjustments:
return latex_formula
result = list(latex_formula)
adjustments.sort(reverse=True, key=lambda x: x[0])
for start, end, target in adjustments:
# 提取\right部分
right_part = result[start:end]
# 从原位置删除
del result[start:end]
# 在目标位置插入
result.insert(target, ''.join(right_part))
return ''.join(result)
def find_group_end(text, pos, depth):
"""查找特定深度的花括号组的结束位置"""
current_depth = depth
i = pos
while i < len(text):
if text[i] == '{' and (i == 0 or not is_escaped(text, i)):
current_depth += 1
elif text[i] == '}' and (i == 0 or not is_escaped(text, i)):
current_depth -= 1
if current_depth < depth:
return i
i += 1
return -1 # 未找到对应结束位置
def is_escaped(text, pos):
"""检查字符是否被转义"""
backslash_count = 0
j = pos - 1
while j >= 0 and text[j] == '\\':
backslash_count += 1
j -= 1
return backslash_count % 2 == 1
def fix_unbalanced_braces(latex_formula):
"""
检测LaTeX公式中的花括号是否闭合,并删除无法配对的花括号
Args:
latex_formula (str): 输入的LaTeX公式
Returns:
str: 删除无法配对的花括号后的LaTeX公式
"""
stack = [] # 存储左括号的索引
unmatched = set() # 存储不匹配括号的索引
i = 0
while i < len(latex_formula):
# 检查是否是转义的花括号
if latex_formula[i] in ['{', '}']:
# 计算前面连续的反斜杠数量
backslash_count = 0
j = i - 1
while j >= 0 and latex_formula[j] == '\\':
backslash_count += 1
j -= 1
# 如果前面有奇数个反斜杠,则该花括号是转义的,不参与匹配
if backslash_count % 2 == 1:
i += 1
continue
# 否则,该花括号参与匹配
if latex_formula[i] == '{':
stack.append(i)
else: # latex_formula[i] == '}'
if stack: # 有对应的左括号
stack.pop()
else: # 没有对应的左括号
unmatched.add(i)
i += 1
# 所有未匹配的左括号
unmatched.update(stack)
# 构建新字符串,删除不匹配的括号
return ''.join(char for i, char in enumerate(latex_formula) if i not in unmatched)
def process_latex(input_string):
"""
处理LaTeX公式中的反斜杠:
1. 如果\后跟特殊字符(#$%&~_^\\{})或空格,保持不变
2. 如果\后跟两个小写字母,保持不变
3. 其他情况,在\后添加空格
Args:
input_string (str): 输入的LaTeX公式
Returns:
str: 处理后的LaTeX公式
"""
def replace_func(match):
# 获取\后面的字符
next_char = match.group(1)
# 如果是特殊字符或空格,保持不变
if next_char in "#$%&~_^|\\{} \t\n\r\v\f":
return match.group(0)
# 如果是字母,检查下一个字符
if 'a' <= next_char <= 'z' or 'A' <= next_char <= 'Z':
pos = match.start() + 2 # \x后的位置
if pos < len(input_string) and ('a' <= input_string[pos] <= 'z' or 'A' <= input_string[pos] <= 'Z'):
# 下一个字符也是字母,保持不变
return match.group(0)
# 其他情况,在\后添加空格
return '\\' + ' ' + next_char
# 匹配\后面跟一个字符的情况
pattern = r'\\(.)'
return re.sub(pattern, replace_func, input_string)
def latex_rm_whitespace(s: str): def latex_rm_whitespace(s: str):
"""Remove unnecessary whitespace from LaTeX code. """Remove unnecessary whitespace from LaTeX code.
""" """
# logger.debug(f"latex_orig: {s}")
s = fix_unbalanced_braces(s)
# left right不匹配的情况(只考虑了不在同一个组里挪到同一个组里的逻辑,没有考虑right比left多的情况)
# 还有加一个\left或\right后至少要跟随一个符号,如果没符号就补.
s = fix_latex_left_right(s)
# s = fix_left_right_pairs(s)
# 用正则删除\left,\left.,\right,\right.
# s = re.sub(r'\\left\.?|\\right\.?', '', s)
# 替换\up命令
s = re.sub(r'\\up([a-zA-Z]+)',
lambda m: m.group(0) if m.group(1) in ["arrow", "downarrow", "lus", "silon",] else f"\\{m.group(1)}", s)
# 替换\underbar为underline
s = re.sub(r'\\underbar', r'\\underline', s)
# 删除\lefteqn
s = re.sub(r'\\lefteqn', r'', s)
# 删除\boldmath
s = re.sub(r'\\boldmath', r'', s)
# \Bar换成\hat
s = re.sub(r'\\Bar', r'\\hat', s)
# \后缺失空格的补空格
s = process_latex(s)
# \qquad后补空格
s = re.sub(r'\\qquad(?!\s)', r'\\qquad ', s)
# 先保存 "\ " 序列,防止被错误处理 # 先保存 "\ " 序列,防止被错误处理
s = re.sub(r'\\ ', r'\\SPACE', s) s = re.sub(r'\\ ', r'\\SPACE', s)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment