Commit 2d1a0f2c authored by myhloli's avatar myhloli
Browse files

fix(mfr): optimize LaTeX formula repair functionality

- Improve \left and \right command handling in LaTeX formulas
- Enhance environment type matching for array, matrix, and other structures
- Refactor code for better readability and maintainability
parent c8747cff
...@@ -58,6 +58,12 @@ class TokenizerWrapper: ...@@ -58,6 +58,12 @@ class TokenizerWrapper:
return toks return toks
LEFT_PATTERN = re.compile(r'(\\left)(\S*)')
RIGHT_PATTERN = re.compile(r'(\\right)(\S*)')
LEFT_COUNT_PATTERN = re.compile(r'\\left(?![a-zA-Z])')
RIGHT_COUNT_PATTERN = re.compile(r'\\right(?![a-zA-Z])')
LEFT_RIGHT_REMOVE_PATTERN = re.compile(r'\\left\.?|\\right\.?')
def fix_latex_left_right(s): def fix_latex_left_right(s):
""" """
修复LaTeX中的\left和\right命令 修复LaTeX中的\left和\right命令
...@@ -71,31 +77,22 @@ def fix_latex_left_right(s): ...@@ -71,31 +77,22 @@ def fix_latex_left_right(s):
r'\Uparrow', r'\Downarrow', r'\|', r'\.'] r'\Uparrow', r'\Downarrow', r'\|', r'\.']
# 为\left后缺失有效分隔符的情况添加点 # 为\left后缺失有效分隔符的情况添加点
def fix_left_delim(match): def fix_delim(match, is_left=True):
cmd = match.group(1) # \left cmd = match.group(1) # \left 或 \right
rest = match.group(2) if len(match.groups()) > 1 else ""
# if not rest or not re.match(f"^({valid_delims})", rest):
if not rest or rest not in valid_delims_list:
return cmd + "."
return match.group(0)
# 为\right后缺失有效分隔符的情况添加点
def fix_right_delim(match):
cmd = match.group(1) # \right
rest = match.group(2) if len(match.groups()) > 1 else "" rest = match.group(2) if len(match.groups()) > 1 else ""
# if not rest or not re.match(f"^({valid_delims})", rest):
if not rest or rest not in valid_delims_list: if not rest or rest not in valid_delims_list:
return cmd + "." return cmd + "."
return match.group(0) return match.group(0)
# 使用更精确的模式匹配\left和\right命令 # 使用更精确的模式匹配\left和\right命令
# 确保它们是独立的命令,不是其他命令的一部分 # 确保它们是独立的命令,不是其他命令的一部分
s = re.sub(r'(\\left)(\S*)', fix_left_delim, s) # 使用预编译正则和统一回调函数
s = re.sub(r'(\\right)(\S*)', fix_right_delim, s) s = LEFT_PATTERN.sub(lambda m: fix_delim(m, True), s)
s = RIGHT_PATTERN.sub(lambda m: fix_delim(m, False), s)
# 更精确地计算\left和\right的数量 # 更精确地计算\left和\right的数量
left_count = len(re.findall(r'\\left(?![a-zA-Z])', s)) # 不匹配\lefteqn等 left_count = len(LEFT_COUNT_PATTERN.findall(s)) # 不匹配\lefteqn等
right_count = len(re.findall(r'\\right(?![a-zA-Z])', s)) # 不匹配\rightarrow等 right_count = len(RIGHT_COUNT_PATTERN.findall(s)) # 不匹配\rightarrow等
if left_count == right_count: if left_count == right_count:
# 如果数量相等,检查是否在同一组 # 如果数量相等,检查是否在同一组
...@@ -104,7 +101,7 @@ def fix_latex_left_right(s): ...@@ -104,7 +101,7 @@ def fix_latex_left_right(s):
# 如果数量不等,移除所有\left和\right # 如果数量不等,移除所有\left和\right
# logger.debug(f"latex:{s}") # logger.debug(f"latex:{s}")
# logger.warning(f"left_count: {left_count}, right_count: {right_count}") # logger.warning(f"left_count: {left_count}, right_count: {right_count}")
return re.sub(r'\\left\.?|\\right\.?', '', s) return LEFT_RIGHT_REMOVE_PATTERN.sub('', s)
def fix_left_right_pairs(latex_formula): def fix_left_right_pairs(latex_formula):
...@@ -302,6 +299,12 @@ def process_latex(input_string): ...@@ -302,6 +299,12 @@ def process_latex(input_string):
return re.sub(pattern, replace_func, input_string) return re.sub(pattern, replace_func, input_string)
# 常见的在KaTeX/MathJax中可用的数学环境
ENV_TYPES = ['array', 'matrix', 'pmatrix', 'bmatrix', 'vmatrix',
'Bmatrix', 'Vmatrix', 'cases', 'aligned', 'gathered']
ENV_BEGIN_PATTERNS = {env: re.compile(r'\\begin\{' + env + r'\}') for env in ENV_TYPES}
ENV_END_PATTERNS = {env: re.compile(r'\\end\{' + env + r'\}') for env in ENV_TYPES}
ENV_FORMAT_PATTERNS = {env: re.compile(r'\\begin\{' + env + r'\}\{([^}]*)\}') for env in ENV_TYPES}
def fix_latex_environments(s): def fix_latex_environments(s):
""" """
...@@ -309,45 +312,22 @@ def fix_latex_environments(s): ...@@ -309,45 +312,22 @@ def fix_latex_environments(s):
1. 如果缺少\begin标签则在开头添加 1. 如果缺少\begin标签则在开头添加
2. 如果缺少\end标签则在末尾添加 2. 如果缺少\end标签则在末尾添加
""" """
# 常见的在KaTeX/MathJax中可用的数学环境 for env in ENV_TYPES:
env_types = [ begin_count = len(ENV_BEGIN_PATTERNS[env].findall(s))
'array', 'matrix', 'pmatrix', 'bmatrix', 'vmatrix', end_count = len(ENV_END_PATTERNS[env].findall(s))
'Bmatrix', 'Vmatrix', 'cases', 'aligned', 'gathered'
]
for env in env_types:
# 计算\begin{env}和\end{env}的数量
begin_pattern = r'\\begin\{' + env + r'\}'
end_pattern = r'\\end\{' + env + r'\}'
begin_count = len(re.findall(begin_pattern, s))
end_count = len(re.findall(end_pattern, s))
# 处理两种不匹配情况
if begin_count != end_count: if begin_count != end_count:
# 情况1:缺少\begin - 在开头添加缺失的\begin{env}
if end_count > begin_count: if end_count > begin_count:
# 尝试从现有的\begin{env}中提取格式 format_match = ENV_FORMAT_PATTERNS[env].search(s)
format_match = re.search(r'\\begin\{' + env + r'\}\{([^}]*)\}', s) default_format = '{c}' if env == 'array' else ''
# 默认格式,对于array需要列格式
default_format = ''
if env == 'array':
default_format = '{c}' # 默认单列居中
format_str = '{' + format_match.group(1) + '}' if format_match else default_format format_str = '{' + format_match.group(1) + '}' if format_match else default_format
# 添加缺失的\begin{env}
missing_count = end_count - begin_count missing_count = end_count - begin_count
begin_command = '\\begin{' + env + '}' + format_str + ' ' begin_command = '\\begin{' + env + '}' + format_str + ' '
s = begin_command * missing_count + s s = begin_command * missing_count + s
else:
# 情况2:缺少\end - 在末尾添加缺失的\end{env}
elif begin_count > end_count:
# 添加缺失的\end{env}
missing_count = begin_count - end_count missing_count = begin_count - end_count
end_command = ' \\end{' + env + '}' s = s + (' \\end{' + env + '}') * missing_count
s = s + end_command * missing_count
return s return s
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment