# Unique model identifier
modelCode=1452
# Model name
modelName=olmocr_pytorch
# Model description
modelDescription=Multimodal OCR model developed by Ai2, based on Alibaba's Qwen
# Application scenarios
appScenario=Inference,OCR,E-commerce,Education,Broadcast media,Transportation,Government
# Framework type
frameType=pytorch
from .version import VERSION, VERSION_SHORT
# olmOCR-Bench
We developed olmOCR-Bench to automatically and effectively evaluate document-level
parsing and OCR across a variety of tools.
olmOCR-Bench works by testing various "facts" or "properties" about document pages at the PDF level.
We work from PDFs directly because PDFs preserve some digital metadata and information that is helpful
and commonly available. Almost any other format can be converted to a PDF, but not the reverse.
## Property classes
- Text presence/absence
- This task checks that a given small piece of text (e.g. at the 1-3 sentence level) is present with high probability within
a parsed document. It focuses on documents with ambiguity around headers, footers, and other borderline content. Fuzzy
matching of the text is still allowed.
- Natural Reading Order
- This task ensures that blocks of text which are present have a defined order relative to one another. For example,
on a document that contains multiple news articles on one page, you'd want to see that the first sentence of the
first article appears after the heading of that article. But, you may be okay with swapping the order of those
two articles.
- Table Accuracy
- Pages with tables get parsed out and are checked for accuracy on a direct row/column/title basis.
- Formula Accuracy
- Extract a formula from the document, render it, and compare the rendering using a foundation model.
Table Format (see the hedged sketch below for an illustration):
- pdf_filename
- Task ID
- Type: text_presence, text_absence, reading_order, table
- text_presence, text_absence: {text: str, fuzzy_threshold: float}
- reading_order: {target_text_presence: task_id, appears_before: task_id, appears_after: task_id}
- table: {table_index: int, also needs to be fuzzy, e.g. does a row exist with column text X, does a column exist with a row containing Y}
- formula: TODO
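As a rough illustration, the sketch below shows how a couple of test rules could be built and written out to a .jsonl file using the helper classes referenced elsewhere in this repo (olmocr.bench.tests); the concrete values are made up, and the exact on-disk serialization is whatever save_tests produces.

```python
# Hedged sketch: constructing test rules programmatically (illustrative values only).
from olmocr.bench.tests import TextPresenceTest, MathTest, save_tests

tests = [
    # Text presence: this sentence must appear (threshold=1.0 means an exact match) on page 1.
    TextPresenceTest(
        pdf="doc1.pdf",
        page=1,
        id="doc1_present_00",
        type="present",
        threshold=1.0,
        text="A sentence that must appear in the parsed output.",
    ),
    # Formula accuracy: the parse must contain an equation equivalent to this LaTeX.
    MathTest(
        id="doc1_pg1_math_000",
        pdf="doc1.pdf",
        page=1,
        type="math",
        math="a+b",
    ),
]

save_tests(tests, "doc1_rules.jsonl")
```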
## Creation
We sampled documents from the same source as olmocrmix. We run them through two models and see which ones have the biggest
plain-text diffs while still containing lots of good text, and which aren't just table/formula heavy for now.
Then, we extract text presence/absence markers and verify them using the tinyhost UI.
We write those to JSON. We may also do some embedding and grouping to try to get lots of variation, at least when
prioritizing manual review.
Later, we will repeat the same for tables and formulas.
Finally, we write the evaluator script, which outputs a nicely templated, tinyhostable results page.
## Running
We do not want to depend on a model producing any specific output format.
Step 1. Download the dataset with all PDFs (all will be single page) to /pdfs
Step 2. Run your extraction on it and point the output to a folder, e.g. olmocr-v2_1/, where you expect pdf_page1.md for the /pdfs/pdf_page1.pdf file
Step 3. Run the evaluation script
Step 4. Get the results, and use tinyhost to view all failing examples
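For concreteness, the evaluation script expects an input folder laid out roughly as follows (names here are illustrative; the repeat-suffixed .md naming is what the conversion harness in this repo produces):

```
sample_data/
├── pdfs/                      # single-page input PDFs
│   └── doc1.pdf
├── doc1_rules.jsonl           # test rules (every *.jsonl in this folder is loaded)
├── olmocr-v2_1/               # one subfolder per candidate tool
│   └── doc1_pg1_repeat1.md
└── marker/
    └── doc1_pg1_repeat1.md
```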
### Running existing scripts
```bash
# Marker
pip install marker-pdf==1.5.4
python olmocr/bench/runners/run_marker.py olmocr/bench/sample_data/pdfs

# GOT-OCR
pip install verovio torchvision
python olmocr/bench/runners/run_gotocr.py olmocr/bench/sample_data/pdfs

# MinerU
conda create -n MinerU python=3.10
conda activate MinerU
pip install -U magic-pdf[full]==1.1.0 --extra-index-url https://wheels.myhloli.com
pip install huggingface_hub
wget https://github.com/opendatalab/MinerU/raw/master/scripts/download_models_hf.py -O download_models_hf.py
python download_models_hf.py
python olmocr/bench/runners/run_mineru.py olmocr/bench/sample_data/pdfs
```
#!/usr/bin/env python3
"""
This script runs olmocr bench.
It will take as an argument a folder, and scan it for .jsonl files which contain the various rules and properties that we will check.
It will then validate the JSON files to make sure they are all valid.
Then, each other folder in there (besides /pdfs) represents a pipeline tool that we will evaluate.
We will validate that each one of those contains at least one .md file (or repeated generations, e.g. _pg{page}_repeat{repeat}.md)
corresponding to its parse for every .pdf in the /pdfs folder.
Then, we will read each one, and check if they pass against all the rules.
If a rule fails on some of the repeats, a short explanation is printed.
The final score is averaged over the repeated generations.
Statistical analysis, including bootstrap confidence intervals, is provided for the results.
Pairwise permutation tests are conducted between specific candidate pairs.
"""
import argparse
import glob
import os
import sys
import re
from typing import Dict, List, Tuple, Optional
from pypdf import PdfReader
from .tests import BasePDFTest, BaselineTest, load_tests
from .katex.render import clear_cache_dir
from .utils import calculate_bootstrap_ci, perform_permutation_test
def evaluate_candidate(
candidate_folder: str, all_tests: List[BasePDFTest], pdf_basenames: List[str], force: bool=False
) -> Tuple[float, int, List[str], List[str], Dict[str, List[float]], List[float]]:
"""
For the candidate folder (pipeline tool output), validate that it contains at least one .md file
(i.e. repeated generations like _pg{page}_repeat{repeat}.md) for every PDF in the pdf folder.
Then, run each rule against all corresponding .md files and average the results.
Returns a tuple:
(overall_score, total_tests, candidate_errors, test_failures, test_type_breakdown, all_test_scores)
- overall_score: Average fraction of tests passed (averaged over repeats and tests).
- total_tests: Total number of tests evaluated.
- candidate_errors: List of candidate errors (e.g. missing files).
- test_failures: List of failure messages for tests not passing on all repeats.
- test_type_breakdown: Dictionary mapping test type to list of average pass ratios for tests of that type.
- all_test_scores: List of all individual test scores (used for bootstrapping).
"""
candidate_errors = []
test_failures = []
test_type_breakdown = {} # key: test type, value: list of average pass ratios
all_test_scores = [] # Store all individual test scores for bootstrapping
candidate_name = os.path.basename(candidate_folder)
# Map each PDF to its corresponding MD repeats (e.g., doc1_pg1_repeat1.md, doc1_pg2_repeat2.md, etc.)
pdf_to_md_files = {}
for pdf_name in pdf_basenames:
md_base = os.path.splitext(pdf_name)[0]
# Updated regex for new format: {pdf_name}_pg<page>_repeat<repeat>.md
md_regex = re.compile(rf"^{re.escape(md_base)}_pg\d+_repeat\d+\.md$")
# List all files in the candidate folder and filter using regex
all_files = os.listdir(candidate_folder)
md_files = [os.path.join(candidate_folder, f) for f in all_files if md_regex.match(f)]
if not md_files and not force:
candidate_errors.append(
f"Candidate '{candidate_name}' is missing MD repeats for {pdf_name} "
f"(expected files matching {md_base}_pg{{page}}_repeat*.md)."
)
else:
pdf_to_md_files[pdf_name] = md_files
if candidate_errors:
return (0.0, len(all_tests), candidate_errors, test_failures, test_type_breakdown, all_test_scores)
total_test_score = 0.0
# Evaluate each test. Each test references a PDF (e.g., "doc1.pdf") and a specific page.
for test in all_tests:
test_type = test.type
if test_type not in test_type_breakdown:
test_type_breakdown[test_type] = []
pdf_name = test.pdf
md_base = os.path.splitext(pdf_name)[0]
md_files = pdf_to_md_files.get(pdf_name, [])
# Filter MD files for the specific page corresponding to the test
page_md_files = [f for f in md_files if re.search(rf"_pg{test.page}_", os.path.basename(f))]
if not page_md_files:
candidate_errors.append(
f"Candidate '{candidate_name}' is missing MD repeats for {pdf_name} page {test.page} "
f"(expected files matching {md_base}_pg{test.page}_repeat*.md)."
)
continue
repeat_passes = 0
num_repeats = 0
explanations = []
for md_path in page_md_files:
num_repeats += 1
try:
with open(md_path, "r", encoding="utf-8") as f:
md_content = f.read()
except Exception as e:
candidate_errors.append(f"Error reading {md_path}: {e}")
continue
try:
# Use the test's run method to evaluate the content
passed, explanation = test.run(md_content)
if passed:
repeat_passes += 1
else:
explanations.append(explanation)
except Exception as e:
candidate_errors.append(f"Error running test {test.id} on {md_path}: {e}")
explanations.append(str(e))
test_avg = repeat_passes / num_repeats if num_repeats > 0 else 0.0
all_test_scores.append(test_avg) # Add to list for bootstrapping
total_test_score += test_avg
if test_avg < 1.0:
test_failures.append(
f"Test {test.id} on {md_base} page {test.page} average pass ratio: {test_avg:.3f} "
f"({repeat_passes}/{num_repeats} repeats passed). Ex: {explanations[0] if explanations else 'No explanation'}"
)
test_type_breakdown[test_type].append(test_avg)
overall_score = total_test_score / len(all_tests) if all_tests else 0.0
return (overall_score, len(all_tests), candidate_errors, test_failures, test_type_breakdown, all_test_scores)
def main():
parser = argparse.ArgumentParser(description="Run OLMOCR Bench.")
parser.add_argument(
"--input_folder",
default=os.path.join(os.path.dirname(__file__), "sample_data"),
help="Path to the folder containing .jsonl files, /pdfs folder, and pipeline tool subfolders.",
)
parser.add_argument(
"--force",
action="store_true",
help="Run benchmark even if some files are missing",
)
parser.add_argument(
"--candidate",
type=str,
default=None,
help="Run test only for a single candidate"
)
parser.add_argument(
"--bootstrap_samples",
type=int,
default=1000,
help="Number of bootstrap samples for confidence interval calculation (default: 1000).",
)
parser.add_argument(
"--confidence_level",
type=float,
default=0.95,
help="Confidence level for interval calculation (default: 0.95 for 95% CI).",
)
parser.add_argument(
"--permutation_tests",
action="store_true",
help="Run permutation testing",
)
args = parser.parse_args()
input_folder = args.input_folder
n_bootstrap = args.bootstrap_samples
ci_level = args.confidence_level
pdf_folder = os.path.join(input_folder, "pdfs")
# Check that the pdfs folder exists
if not os.path.exists(pdf_folder):
print("Error: /pdfs folder must exist in your data directory.", file=sys.stderr)
sys.exit(1)
# Find all pdf files in the pdf folder
all_pdf_files = list(glob.glob(os.path.join(pdf_folder, "*.pdf")))
if not all_pdf_files:
print(f"Error: No PDF files found in {pdf_folder}", file=sys.stderr)
sys.exit(1)
# Get PDF basenames (e.g. "doc1.pdf")
pdf_basenames = [os.path.basename(p) for p in all_pdf_files]
# Find and validate .jsonl files in the input folder
jsonl_files = glob.glob(os.path.join(input_folder, "*.jsonl"))
if not jsonl_files:
print(f"Error: No .jsonl files found in {input_folder}.", file=sys.stderr)
sys.exit(1)
# Load and concatenate all test rules from JSONL files
all_tests = []
for jsonl_path in jsonl_files:
tests = load_tests(jsonl_path)
all_tests.extend(tests)
if not all_tests:
print("No valid tests found. Exiting.", file=sys.stderr)
sys.exit(1)
# Add in a default repeat test for every PDF that doesn't already have one
for pdf in pdf_basenames:
if not any(t.type == "baseline" for t in all_tests if t.pdf == pdf):
all_tests.append(BaselineTest(id=f"{pdf}_baseline", pdf=pdf, page=1, type="baseline"))
# Make sure that each PDF and page has at least one test in it
for pdf in pdf_basenames:
pdf_doc = PdfReader(os.path.join(pdf_folder, pdf))
for page in range(1, len(pdf_doc.pages) + 1):
if not any(test for test in all_tests if test.pdf == pdf and test.page == page) and not args.force:
print(f"No dataset entry found for pdf {pdf} page {page}")
sys.exit(1)
# Identify candidate pipeline folders (subdirectories of input_folder excluding /pdfs)
candidate_folders = []
for entry in os.listdir(input_folder):
full_path = os.path.join(input_folder, entry)
if args.candidate is not None:
if entry == args.candidate:
candidate_folders.append(full_path)
else:
if os.path.isdir(full_path) and entry != "pdfs":
candidate_folders.append(full_path)
if not candidate_folders:
print("Error: No candidate pipeline folders found (subdirectories besides 'pdfs').", file=sys.stderr)
sys.exit(1)
candidate_folders.sort()
# Evaluate each candidate
summary = []
print("\nRunning tests for each candidate:")
for candidate in candidate_folders:
candidate_name = os.path.basename(candidate)
overall_score, total_tests, candidate_errors, test_failures, test_type_breakdown, all_test_scores = evaluate_candidate(
candidate, all_tests, pdf_basenames, args.force,
)
# Calculate confidence interval
if all_test_scores:
ci = calculate_bootstrap_ci(all_test_scores, n_bootstrap=n_bootstrap, ci_level=ci_level)
else:
ci = (0.0, 0.0)
summary.append((candidate_name, overall_score, total_tests, candidate_errors, test_failures, test_type_breakdown, ci, all_test_scores))
print(f"\nCandidate: {candidate_name}")
if candidate_errors:
for err in candidate_errors:
print(f" [ERROR] {err}")
else:
if test_failures:
for fail in test_failures:
print(f" [FAIL] {fail}")
print(f" Average Score: {overall_score * 100:.1f}% (95% CI: [{ci[0] * 100:.1f}%, {ci[1] * 100:.1f}%]) over {total_tests} tests.")
# Print final summary with breakdown by test type
print("\n" + "=" * 60)
print("Final Summary with 95% Confidence Intervals:")
for candidate_name, overall_score, total_tests, candidate_errors, _, test_type_breakdown, ci, _ in summary:
if candidate_errors:
status = "FAILED (errors)"
ci_str = "N/A"
ciw_str = ""
else:
status = f"{overall_score * 100:0.1f}%"
half_width = ((ci[1] - ci[0]) / 2) * 100
ciw_str = f"± {half_width:0.1f}%"
ci_str = f"[{ci[0] * 100:0.1f}%, {ci[1] * 100:0.1f}%]"
print(f"{candidate_name:20s} : Average Score: {status} {ciw_str}")
for ttype, scores in test_type_breakdown.items():
if scores:
avg = sum(scores) / len(scores) * 100
else:
avg = 0.0
print(f" {ttype:8s}: {avg:0.1f}% average pass rate over {len(scores)} tests")
print("")
# Perform pairwise permutation tests
if args.permutation_tests:
print("\n" + "=" * 60)
print("Pairwise Permutation Tests:")
valid_candidates = [c for c in summary if not c[3]] # Filter out candidates with errors
olmocr_candidates = sorted([c for c in valid_candidates if "olmocr" in c[0].lower()], key=lambda x: x[1], reverse=True)
non_olmocr_candidates = sorted([c for c in valid_candidates if "olmocr" not in c[0].lower()], key=lambda x: x[1], reverse=True)
top_olmocr = olmocr_candidates[0] if olmocr_candidates else None
top_non_olmocr = non_olmocr_candidates[0] if non_olmocr_candidates else None
top_two_olmocr = olmocr_candidates[:2]
# Test 1: Top olmocr vs Top non-olmocr
if top_olmocr and top_non_olmocr:
olmocr_name, olmocr_score = top_olmocr[0], top_olmocr[1]
non_olmocr_name, non_olmocr_score = top_non_olmocr[0], top_non_olmocr[1]
olmocr_scores = top_olmocr[7] # all_test_scores
non_olmocr_scores = top_non_olmocr[7] # all_test_scores
diff, p_value = perform_permutation_test(
olmocr_scores, non_olmocr_scores
)
print(f"\nComparison 1: Top olmocr vs Top non-olmocr candidate")
print(f" {olmocr_name} ({olmocr_score*100:.1f}%) vs {non_olmocr_name} ({non_olmocr_score*100:.1f}%)")
print(f" Difference: {diff*100:.2f}% (positive means {olmocr_name} is better)")
print(f" p-value: {p_value:.4f}")
if p_value < 0.05:
print(f" Result: Statistically significant difference (p < 0.05)")
else:
print(f" Result: No statistically significant difference (p ≥ 0.05)")
else:
print("\nCannot perform olmocr vs non-olmocr comparison: Missing candidates")
# Test 2: Top two olmocr candidates (if there are at least two)
if len(top_two_olmocr) >= 2:
olmocr1_name, olmocr1_score = top_two_olmocr[0][0], top_two_olmocr[0][1]
olmocr2_name, olmocr2_score = top_two_olmocr[1][0], top_two_olmocr[1][1]
olmocr1_scores = top_two_olmocr[0][7] # all_test_scores
olmocr2_scores = top_two_olmocr[1][7] # all_test_scores
diff, p_value = perform_permutation_test(
olmocr1_scores, olmocr2_scores
)
print(f"\nComparison 2: Top two olmocr candidates")
print(f" {olmocr1_name} ({olmocr1_score*100:.1f}%) vs {olmocr2_name} ({olmocr2_score*100:.1f}%)")
print(f" Difference: {diff*100:.2f}% (positive means {olmocr1_name} is better)")
print(f" p-value: {p_value:.4f}")
if p_value < 0.05:
print(f" Result: Statistically significant difference (p < 0.05)")
else:
print(f" Result: No statistically significant difference (p ≥ 0.05)")
else:
print("\nCannot perform top two olmocr comparison: Not enough olmocr candidates")
print("=" * 60)
if __name__ == "__main__":
main()
import argparse
import asyncio
import glob
import importlib
import os
from functools import partial
from itertools import product
from tqdm import tqdm
from pypdf import PdfReader
def parse_method_arg(method_arg):
"""
Parse a method configuration string of the form:
method_name[:key=value[:key2=value2...]]
Returns:
(method_name, kwargs_dict, folder_name)
"""
parts = method_arg.split(":")
name = parts[0]
kwargs = {}
folder_name = name # Default folder name is the method name
for extra in parts[1:]:
if "=" in extra:
key, value = extra.split("=", 1)
if key == "name":
folder_name = value
continue
try:
converted = int(value)
except ValueError:
try:
converted = float(value)
except ValueError:
converted = value
kwargs[key] = converted
else:
raise ValueError(f"Extra argument '{extra}' is not in key=value format")
return name, kwargs, folder_name
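# Illustration of the parsing above (not part of the original file):
#   parse_method_arg("mineru:temperature=2:name=mineru_hot")
#   returns ("mineru", {"temperature": 2}, "mineru_hot")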
# Wrapper to run synchronous functions in the event loop
async def run_sync_in_executor(func, *args, **kwargs):
"""Run a synchronous function in the default executor"""
loop = asyncio.get_running_loop()
return await loop.run_in_executor(None, partial(func, *args, **kwargs))
async def process_pdf(pdf_path, page_num, method, kwargs, output_path, is_async):
"""Process a single PDF and save the result to output_path"""
try:
if is_async:
# Run async function directly
markdown = await method(pdf_path, page_num=page_num, **kwargs)
else:
# Run synchronous function in the executor
markdown = await run_sync_in_executor(method, pdf_path, page_num=page_num, **kwargs)
if markdown is None:
print(f"Warning, did not get output for {os.path.basename(output_path)}")
# Write blank to this file, so that it's marked as an error and not just skipped in evals
with open(output_path, "w") as out_f:
out_f.write("")
return False
# Write the markdown to the output file
with open(output_path, "w") as out_f:
out_f.write(markdown)
return True
except Exception as ex:
print(f"Exception {str(ex)} occurred while processing {os.path.basename(output_path)}")
# Write blank to this file, so that it's marked as an error and not just skipped in evals
with open(output_path, "w") as out_f:
out_f.write("")
return False
async def process_pdfs(config, pdf_directory, data_directory, repeats, force, max_parallel=None):
"""
Process PDFs using asyncio for both sync and async methods,
limiting the number of concurrent tasks to max_parallel.
"""
for candidate in config.keys():
print(f"Starting conversion using {candidate} with kwargs: {config[candidate]['kwargs']}")
folder_name = config[candidate]["folder_name"]
candidate_output_dir = os.path.join(data_directory, folder_name)
os.makedirs(candidate_output_dir, exist_ok=True)
method = config[candidate]["method"]
kwargs = config[candidate]["kwargs"]
is_async = asyncio.iscoroutinefunction(method)
all_pdfs = glob.glob(os.path.join(pdf_directory, "*.pdf"))
all_pdfs.sort()
# Prepare all tasks
tasks = []
task_descriptions = {}
for pdf_path in all_pdfs:
pdf = PdfReader(pdf_path)
num_pages = len(pdf.pages)
base_name = os.path.basename(pdf_path).replace(".pdf", "")
for repeat in range(1, repeats + 1):
for page_num in range(1, num_pages + 1):
output_filename = f"{base_name}_pg{page_num}_repeat{repeat}.md"
output_path = os.path.join(candidate_output_dir, output_filename)
if os.path.exists(output_path) and not force:
print(f"Skipping {base_name}_pg{page_num}_repeat{repeat} for {candidate}, file already exists")
print("Rerun with --force flag to force regeneration")
continue
task = process_pdf(pdf_path, page_num, method, kwargs, output_path, is_async)
tasks.append(task)
task_descriptions[id(task)] = f"{base_name}_pg{page_num}_repeat{repeat} ({candidate})"
# Process tasks with semaphore to limit concurrency
semaphore = asyncio.Semaphore(max_parallel or 1) # Default to 1 if not specified
async def process_with_semaphore(task):
async with semaphore:
return await task
# Wrap each task with the semaphore
limited_tasks = [process_with_semaphore(task) for task in tasks]
# Process tasks with progress bar
if limited_tasks:
completed = 0
with tqdm(total=len(limited_tasks), desc=f"Processing {candidate}") as pbar:
for task in asyncio.as_completed(limited_tasks):
try:
result = await task
if result:
completed += 1
except Exception as e:
print(f"Task failed: {e}")
finally:
pbar.update(1)
print(f"Completed {completed} out of {len(limited_tasks)} tasks for {candidate}")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run PDF conversion using specified OCR methods and extra parameters.")
parser.add_argument(
"methods",
nargs="+",
help="Methods to run in the format method[:key=value ...]. "
"Example: gotocr mineru:temperature=2 marker:u=3. "
"Use 'name=folder_name' to specify a custom output folder name.",
)
parser.add_argument("--repeats", type=int, default=1, help="Number of times to repeat the conversion for each PDF.")
parser.add_argument("--dir", type=str, default=os.path.join(os.path.dirname(__file__), "sample_data"), help="Path to the data folder in which to save outputs, pdfs should be in /pdfs folder within it.")
parser.add_argument("--force", action="store_true", default=False, help="Force regenerating of output files, even if they already exist")
parser.add_argument("--parallel", type=int, default=1, help="Maximum number of concurrent tasks")
args = parser.parse_args()
# Mapping of method names to a tuple: (module path, function name)
available_methods = {
"olmocr_pipeline": ("olmocr.bench.runners.run_olmocr_pipeline", "run_olmocr_pipeline"),
"gotocr": ("olmocr.bench.runners.run_gotocr", "run_gotocr"),
"marker": ("olmocr.bench.runners.run_marker", "run_marker"),
"mineru": ("olmocr.bench.runners.run_mineru", "run_mineru"),
"chatgpt": ("olmocr.bench.runners.run_chatgpt", "run_chatgpt"),
"gemini": ("olmocr.bench.runners.run_gemini", "run_gemini"),
"mistral": ("olmocr.bench.runners.run_mistral", "run_mistral"),
"server": ("olmocr.bench.runners.run_server", "run_server"),
}
# Build config by importing only requested methods.
config = {}
for method_arg in args.methods:
method_name, extra_kwargs, folder_name = parse_method_arg(method_arg)
if method_name not in available_methods:
parser.error(f"Unknown method: {method_name}. " f"Available methods: {', '.join(available_methods.keys())}")
module_path, function_name = available_methods[method_name]
# Dynamically import the module and get the function.
module = importlib.import_module(module_path)
function = getattr(module, function_name)
config[method_name] = {"method": function, "kwargs": extra_kwargs, "folder_name": folder_name}
data_directory = args.dir
pdf_directory = os.path.join(data_directory, "pdfs")
# Run the async process function with the parallel argument
asyncio.run(process_pdfs(config, pdf_directory, data_directory, args.repeats, args.force, args.parallel))
#!/usr/bin/env python3
"""
Extract inner-most spans and their bounding boxes, and the mathML output,
from rendered LaTeX equations using Playwright and KaTeX.
Caching is maintained via a SHA1-based hash stored as a JSON file.
Requirements:
pip install playwright
python -m playwright install chromium
Place katex.min.css and katex.min.js in the same directory as this script
"""
import os
import re
import html
import hashlib
import pathlib
import json
import shutil
from dataclasses import dataclass
from typing import List
import unittest
import html.entities
from lxml import etree
from playwright.sync_api import sync_playwright, Error as PlaywrightError
@dataclass
class BoundingBox:
x: float
y: float
width: float
height: float
@dataclass
class SpanInfo:
text: str
bounding_box: BoundingBox
@dataclass
class RenderedEquation:
mathml: str
spans: List[SpanInfo]
def get_equation_hash(equation, bg_color="white", text_color="black", font_size=24):
"""
Calculate SHA1 hash of the equation string and rendering parameters.
"""
params_str = f"{equation}|{bg_color}|{text_color}|{font_size}"
return hashlib.sha1(params_str.encode('utf-8')).hexdigest()
def get_cache_dir():
"""
Get the cache directory for equations, creating it if it doesn't exist.
"""
cache_dir = pathlib.Path.home() / '.cache' / 'olmocr' / 'bench' / 'equations'
cache_dir.mkdir(parents=True, exist_ok=True)
return cache_dir
def clear_cache_dir():
"""
Clear all files and subdirectories in the cache directory.
"""
cache_dir = get_cache_dir()
if cache_dir.exists() and cache_dir.is_dir():
shutil.rmtree(cache_dir)
cache_dir.mkdir(parents=True, exist_ok=True) # Recreate the empty directory
def render_equation(
equation,
bg_color="white",
text_color="black",
font_size=24,
use_cache=True,
debug_dom=False,
):
"""
Render a LaTeX equation using Playwright and KaTeX, extract the inner-most span elements
(those without child elements that contain non-whitespace text) along with their bounding boxes,
and also extract the MathML output generated by KaTeX.
Returns:
RenderedEquation: A dataclass containing the mathml string and a list of SpanInfo dataclasses.
"""
# Calculate hash for caching
eq_hash = get_equation_hash(equation, bg_color, text_color, font_size)
cache_dir = get_cache_dir()
cache_file = cache_dir / f"{eq_hash}.json"
cache_error_file = cache_dir / f"{eq_hash}_error"
if use_cache:
if cache_error_file.exists():
return None
if cache_file.exists():
with open(cache_file, 'r') as f:
data = json.load(f)
spans = [
SpanInfo(
text=s["text"],
bounding_box=BoundingBox(
x=s["boundingBox"]["x"],
y=s["boundingBox"]["y"],
width=s["boundingBox"]["width"],
height=s["boundingBox"]["height"],
)
)
for s in data["spans"]
]
return RenderedEquation(mathml=data["mathml"], spans=spans)
# Escape backslashes for JavaScript string
escaped_equation = json.dumps(equation)
# Get local paths for KaTeX files
script_dir = os.path.dirname(os.path.abspath(__file__))
katex_css_path = os.path.join(script_dir, "katex.min.css")
katex_js_path = os.path.join(script_dir, "katex.min.js")
if not os.path.exists(katex_css_path) or not os.path.exists(katex_js_path):
raise FileNotFoundError(f"KaTeX files not found. Please ensure katex.min.css and katex.min.js are in {script_dir}")
with sync_playwright() as p:
browser = p.chromium.launch()
page = browser.new_page(viewport={"width": 800, "height": 400})
# Basic HTML structure
page_html = f"""
<!DOCTYPE html>
<html>
<head>
<style>
body {{
display: flex;
justify-content: center;
align-items: center;
height: 100vh;
margin: 0;
background-color: {bg_color};
color: {text_color};
}}
#equation-container {{
padding: 0;
font-size: {font_size}px;
}}
</style>
</head>
<body>
<div id="equation-container"></div>
</body>
</html>
"""
page.set_content(page_html)
page.add_style_tag(path=katex_css_path)
page.add_script_tag(path=katex_js_path)
page.wait_for_load_state("networkidle")
katex_loaded = page.evaluate("typeof katex !== 'undefined'")
if not katex_loaded:
raise RuntimeError("KaTeX library failed to load. Check your katex.min.js file.")
try:
error_message = page.evaluate(f"""
() => {{
try {{
katex.render({escaped_equation}, document.getElementById("equation-container"), {{
displayMode: true,
throwOnError: true
}});
return null;
}} catch (error) {{
console.error("KaTeX error:", error.message);
return error.message;
}}
}}
""")
except PlaywrightError as ex:
print(escaped_equation)
error_message = str(ex)
raise
if error_message:
print(f"Error rendering equation: '{equation}'")
print(error_message)
cache_error_file.touch()
browser.close()
return None
page.wait_for_selector(".katex", state="attached")
if debug_dom:
katex_dom_html = page.evaluate("""
() => {
return document.getElementById("equation-container").innerHTML;
}
""")
print("\n===== KaTeX DOM HTML =====")
print(katex_dom_html)
# Extract inner-most spans with non-whitespace text
spans_info = page.evaluate("""
() => {
const spans = Array.from(document.querySelectorAll('span'));
const list = [];
spans.forEach(span => {
// Check if this span has no child elements and contains non-whitespace text
if (span.children.length === 0 && /\S/.test(span.textContent)) {
const rect = span.getBoundingClientRect();
list.push({
text: span.textContent.trim(),
boundingBox: {
x: rect.x,
y: rect.y,
width: rect.width,
height: rect.height
}
});
}
});
return list;
}
""")
if debug_dom:
print("\n===== Extracted Span Information =====")
print(spans_info)
# Extract mathML output (if available) from the KaTeX output.
# We try to get the <math> element within an element with class "katex-mathml".
mathml = page.evaluate("""
() => {
const mathElem = document.querySelector('.katex-mathml math');
return mathElem ? mathElem.outerHTML : "";
}
""")
browser.close()
# Build the result as a RenderedEquation dataclass
rendered_eq = RenderedEquation(
mathml=mathml,
spans=[
SpanInfo(
text=s["text"],
bounding_box=BoundingBox(
x=s["boundingBox"]["x"],
y=s["boundingBox"]["y"],
width=s["boundingBox"]["width"],
height=s["boundingBox"]["height"]
)
)
for s in spans_info
]
)
# Save to cache (convert dataclasses to a JSON-serializable dict)
cache_data = {
"mathml": rendered_eq.mathml,
"spans": [
{
"text": span.text,
"boundingBox": {
"x": span.bounding_box.x,
"y": span.bounding_box.y,
"width": span.bounding_box.width,
"height": span.bounding_box.height
}
}
for span in rendered_eq.spans
]
}
with open(cache_file, 'w') as f:
json.dump(cache_data, f)
return rendered_eq
def compare_rendered_equations(reference: RenderedEquation, hypothesis: RenderedEquation) -> bool:
"""
First, try to determine whether the normalized MathML of the reference is contained
as a substring of the normalized MathML of the hypothesis.
If that fails, then perform a neighbor‐based matching:
Each span in the hypothesis must be matched to a span in the reference (with identical text),
and if a hypothesis span has an immediate neighbor (up, down, left, or right), then its candidate
reference span must have a neighbor in the same direction that (if already matched) is the candidate
for the hypothesis neighbor – otherwise, the candidate must have the same text as the hypothesis neighbor.
The algorithm uses backtracking to explore all possible assignments.
"""
from bs4 import BeautifulSoup
def extract_inner(mathml: str) -> str:
try:
# Use the "xml" parser so that BeautifulSoup parses MathML correctly,
# handling HTML entities along the way.
soup = BeautifulSoup(mathml, "xml")
semantics = soup.find("semantics")
if semantics:
# Concatenate the string representation of all children except <annotation>
inner_parts = [
str(child)
for child in semantics.contents
if getattr(child, "name", None) != "annotation"
]
return ''.join(inner_parts)
else:
return str(soup)
except Exception as e:
print("Error parsing MathML with BeautifulSoup:", e)
print(mathml)
return mathml
def normalize(s: str) -> str:
return re.sub(r'\s+', '', s)
# First, try a fast mathML substring check.
reference_inner = normalize(extract_inner(reference.mathml))
hypothesis_inner = normalize(extract_inner(hypothesis.mathml))
if reference_inner in hypothesis_inner:
return True
# Fallback: neighbor-based matching using the spans.
# First, print out the original span lists.
# print("Hypothesis spans:")
# for s in hypothesis.spans:
# print(s)
# print("---")
# print("Reference spans:")
# for s in reference.spans:
# print(s)
# print("---")
# We swap H and R so that we are effectively checking if the reference is contained in the hypothesis.
H, R = reference.spans, hypothesis.spans
H = [span for span in H if span.text != "\u200b"]
R = [span for span in R if span.text != "\u200b"]
# Build candidate map: for each span in H (reference), record indices in R with matching text.
candidate_map = {}
for i, hspan in enumerate(H):
candidate_map[i] = [j for j, rsp in enumerate(R) if rsp.text == hspan.text]
if not candidate_map[i]:
return False # no candidate for a given span, so we fail immediately
# print("Candidate Map:")
# print(candidate_map)
# Function to compute neighbor mappings for a list of spans.
def compute_neighbors(spans, tol=5):
neighbors = {}
for i, span in enumerate(spans):
cx = span.bounding_box.x + span.bounding_box.width / 2
cy = span.bounding_box.y + span.bounding_box.height / 2
up = down = left = right = None
up_dist = down_dist = left_dist = right_dist = None
for j, other in enumerate(spans):
if i == j:
continue
ocx = other.bounding_box.x + other.bounding_box.width / 2
ocy = other.bounding_box.y + other.bounding_box.height / 2
# Up: candidate must be above (ocy < cy) and nearly aligned horizontally.
if ocy < cy and abs(ocx - cx) <= tol:
dist = cy - ocy
if up is None or dist < up_dist:
up = j
up_dist = dist
# Down: candidate below.
if ocy > cy and abs(ocx - cx) <= tol:
dist = ocy - cy
if down is None or dist < down_dist:
down = j
down_dist = dist
# Left: candidate left.
if ocx < cx and abs(ocy - cy) <= tol:
dist = cx - ocx
if left is None or dist < left_dist:
left = j
left_dist = dist
# Right: candidate right.
if ocx > cx and abs(ocy - cy) <= tol:
dist = ocx - cx
if right is None or dist < right_dist:
right = j
right_dist = dist
neighbors[i] = {"up": up, "down": down, "left": left, "right": right}
return neighbors
hyp_neighbors = compute_neighbors(H)
ref_neighbors = compute_neighbors(R)
# print("Neighbor Map for Reference spans (H):")
# for i, nb in hyp_neighbors.items():
# print(f"Span {i}: {nb}")
# print("Neighbor Map for Hypothesis spans (R):")
# for i, nb in ref_neighbors.items():
# print(f"Span {i}: {nb}")
# Backtracking search for an injection f: {0,...,n-1} -> {indices in R} that preserves neighbor relations.
n = len(H)
used = [False] * len(R)
assignment = {}
def backtrack(i):
if i == n:
return True
for cand in candidate_map[i]:
if used[cand]:
continue
# Tentatively assign hypothesis span i (from H) to reference span cand (in R).
assignment[i] = cand
used[cand] = True
valid = True
# Check neighbor constraints for all directions.
for direction in ["up", "down", "left", "right"]:
hyp_nb = hyp_neighbors[i].get(direction)
ref_nb = ref_neighbors[cand].get(direction)
if hyp_nb is not None:
expected_text = H[hyp_nb].text
# The candidate in R must have a neighbor in that direction.
if ref_nb is None:
valid = False
break
# If the neighbor in H is already assigned, then the candidate neighbor must match.
if hyp_nb in assignment:
if assignment[hyp_nb] != ref_nb:
valid = False
break
else:
# If not yet assigned, the neighbor text in R must match the expected text.
if R[ref_nb].text != expected_text:
valid = False
break
if valid:
if backtrack(i + 1):
return True
# Backtrack this candidate assignment.
used[cand] = False
del assignment[i]
return False
result = backtrack(0)
return result
class TestRenderedEquationComparison(unittest.TestCase):
def test_exact_match(self):
# Both calls with identical LaTeX should produce matching MathML output.
eq1 = render_equation("a+b", use_cache=False)
eq2 = render_equation("a+b", use_cache=False)
self.assertTrue(compare_rendered_equations(eq1, eq2))
def test_whitespace_difference(self):
# Differences in whitespace in the LaTeX input should not affect the MathML output.
eq1 = render_equation("a+b", use_cache=False)
eq2 = render_equation("a + b", use_cache=False)
self.assertTrue(compare_rendered_equations(eq1, eq2))
def test_not_found(self):
# Completely different equations should not match.
eq1 = render_equation("c-d", use_cache=False)
eq2 = render_equation("a+b", use_cache=False)
self.assertFalse(compare_rendered_equations(eq1, eq2))
def test_align_block_contains_needle(self):
# The MathML output of the plain equation should be found within the align block output.
eq_plain = render_equation("a+b", use_cache=False)
eq_align = render_equation("\\begin{align*}a+b\\end{align*}", use_cache=False)
self.assertTrue(compare_rendered_equations(eq_plain, eq_align))
def test_align_block_needle_not_in(self):
# An align block rendering a different equation should not contain the MathML of an unrelated equation.
eq_align = render_equation("\\begin{align*}a+b\\end{align*}", use_cache=False)
eq_diff = render_equation("c-d", use_cache=False)
self.assertFalse(compare_rendered_equations(eq_diff, eq_align))
def test_big(self):
ref_rendered = render_equation("\\nabla \\cdot \\mathbf{E} = \\frac{\\rho}{\\varepsilon_0}", use_cache=False, debug_dom=False)
align_rendered = render_equation("""\\begin{align*}\\nabla \\cdot \\mathbf{E} = \\frac{\\rho}{\\varepsilon_0}\\end{align*}""", use_cache=False, debug_dom=False)
self.assertTrue(compare_rendered_equations(ref_rendered, align_rendered))
def test_dot_end1(self):
ref_rendered = render_equation("\\lambda_{g}=\\sum_{s \\in S} \\zeta_{n}^{\\psi(g s)}=\\sum_{i=1}^{k}\\left[\\sum_{s, R s=\\mathcal{I}_{i}} \\zeta_{n}^{\\varphi(g s)}\\right]")
align_rendered = render_equation("\\lambda_{g}=\\sum_{s \\in S} \\zeta_{n}^{\\psi(g s)}=\\sum_{i=1}^{k}\\left[\\sum_{s, R s=\\mathcal{I}_{i}} \\zeta_{n}^{\\varphi(g s)}\\right].")
self.assertTrue(compare_rendered_equations(ref_rendered, align_rendered))
def test_dot_end2(self):
ref_rendered = render_equation("\\lambda_{g}=\\sum_{s \\in S} \\zeta_{n}^{\\psi(g s)}=\\sum_{i=1}^{k}\\left[\\sum_{s, R s=\\mathcal{I}_{i}} \\zeta_{n}^{\\psi(g s)}\\right]")
align_rendered = render_equation("\\lambda_g = \\sum_{s \\in S} \\zeta_n^{\\psi(gs)} = \\sum_{i=1}^{k} \\left[ \\sum_{s, Rs = \\mathcal{I}_i} \\zeta_n^{\\psi(gs)} \\right]")
self.assertTrue(compare_rendered_equations(ref_rendered, align_rendered))
def test_lambda(self):
ref_rendered = render_equation("\\lambda_g = \\lambda_{g'}")
align_rendered = render_equation("\\lambda_{g}=\\lambda_{g^{\\prime}}")
self.assertTrue(compare_rendered_equations(ref_rendered, align_rendered))
def test_gemini(self):
ref_rendered = render_equation("u \\in (R/\\operatorname{Ann}_R(x_i))^{\\times}")
align_rendered = render_equation("u \\in\\left(R / \\operatorname{Ann}_{R}\\left(x_{i}\\right)\\right)^{\\times}")
self.assertTrue(compare_rendered_equations(ref_rendered, align_rendered))
if __name__ == "__main__":
unittest.main()
#!/usr/bin/env python3
# This script goes to
# https://arxiv.org/list/math/recent?skip=0&show=2000
# and downloads all the source PDFs, as well as their LaTeX equivalents, and puts them together into
# a math_data folder.
# It does so by searching for links of the form:
# <a href="/pdf/2503.08675" title="Download PDF" id="pdf-2503.08675" aria-labelledby="pdf-2503.08675">pdf</a>
import argparse
import os
import re
import time
import io
import tarfile
import requests
from tqdm import tqdm
def download_and_extract_source(paper_id, data_dir):
source_url = f"https://export.arxiv.org/src/{paper_id}"
print(f"Downloading source for {paper_id} from {source_url}...")
response = requests.get(source_url)
if response.status_code != 200:
print(f"Error downloading source for {paper_id}: HTTP {response.status_code}")
return False
# Try to open as a tar archive.
try:
file_obj = io.BytesIO(response.content)
with tarfile.open(fileobj=file_obj, mode='r:*') as tar:
# Filter for regular .tex files.
members = [m for m in tar.getmembers() if m.isfile() and m.name.endswith('.tex')]
print("Found TeX files:", [m.name for m in members])
if len(members) == 1:
member = members[0]
extracted = tar.extractfile(member)
if extracted is None:
print(f"Error extracting {paper_id}: Could not read the file from the archive.")
return False
content = extracted.read()
out_path = os.path.join(data_dir, f"{paper_id}.tex")
with open(out_path, "wb") as f:
f.write(content)
print(f"Saved tex source for {paper_id} as {out_path}")
return True
else:
print(f"Error: {paper_id} contains multiple .tex files or none. Skipping extraction.")
return False
except tarfile.ReadError:
# Not a tar archive; assume it's a single file.
out_path = os.path.join(data_dir, f"{paper_id}.tex")
with open(out_path, "wb") as f:
f.write(response.content)
print(f"Saved non-archive tex source for {paper_id} as {out_path}")
return True
def download_pdf(paper_id, data_dir):
pdf_url = f"https://arxiv.org/pdf/{paper_id}.pdf"
print(f"Downloading PDF for {paper_id} from {pdf_url}...")
response = requests.get(pdf_url)
if response.status_code != 200:
print(f"Error downloading PDF for {paper_id}: HTTP {response.status_code}")
return False
out_path = os.path.join(data_dir, f"{paper_id}.pdf")
with open(out_path, "wb") as f:
f.write(response.content)
print(f"Saved PDF for {paper_id} as {out_path}")
return True
def main():
parser = argparse.ArgumentParser(
description="Download and extract arXiv LaTeX source files and PDFs only if both succeed."
)
parser.add_argument(
"--url",
type=str,
default="https://arxiv.org/list/math/recent?skip=0&show=2000",
help="URL of the arXiv list page to scrape (default: %(default)s)"
)
parser.add_argument(
"--data_dir",
type=str,
default="math_data/pdfs",
help="Directory to save downloaded files (default: %(default)s)"
)
args = parser.parse_args()
if not os.path.exists(args.data_dir):
os.makedirs(args.data_dir)
print(f"Downloading list page from {args.url}...")
response = requests.get(args.url)
if response.status_code != 200:
print(f"Error downloading list page: HTTP {response.status_code}")
return
# Find all pdf links in the form: <a href="/pdf/2503.08675" ...>pdf</a>
pattern = re.compile(r'href="/pdf/(\d+\.\d+)"')
paper_ids = pattern.findall(response.text)
print(f"Found {len(paper_ids)} papers.")
# For each paper, only keep the files if both the tex extraction and pdf download succeed.
for paper_id in tqdm(paper_ids):
tex_success = download_and_extract_source(paper_id, args.data_dir)
if not tex_success:
print(f"Skipping PDF download for {paper_id} because tex extraction failed.")
continue
pdf_success = download_pdf(paper_id, args.data_dir)
if not pdf_success:
# Remove the tex file if the PDF download fails.
tex_path = os.path.join(args.data_dir, f"{paper_id}.tex")
if os.path.exists(tex_path):
os.remove(tex_path)
print(f"Removed tex file for {paper_id} because PDF download failed.")
time.sleep(1)
if __name__ == "__main__":
main()
import argparse
import base64
import os
import re
import time
from collections import Counter
from difflib import SequenceMatcher
import syntok.segmenter as segmenter
from google import genai
from google.genai import types
from olmocr.bench.tests import TextPresenceTest, save_tests
from olmocr.data.renderpdf import render_pdf_to_base64png
LABEL_WIDTH = 8 # fixed width for printing labels
# Uses a gemini prompt to get the most likely clean sentence from a pdf page
last_gemini_call = time.perf_counter()
def clean_base_sentence(pdf_path: str, page_num: int, base_sentence: str) -> str:
client = genai.Client(
api_key=os.environ.get("GEMINI_API_KEY"),
)
image_base64 = render_pdf_to_base64png(pdf_path, page_num=page_num, target_longest_image_dim=2048)
image_part = types.Part(inline_data=types.Blob(mime_type="image/png", data=base64.b64decode(image_base64)))
model = "gemini-2.0-flash-thinking-exp-01-21" # Consider using a more stable model for production
# model="gemini-2.0-flash-001"
contents = [
types.Content(
role="user",
parts=[
image_part,
types.Part.from_text(
text=f"""Base: {base_sentence}
Consider the sentence labeled "Base" above in the document image attached. What is the correct reading of this document within the image of the page? I need it to be exact down to the individual character and that's very important to get right. It needs to match the picture, not the provided text. Please just output the correct full sentence exactly how it appears in the document image and nothing else. You can merge hyphenated words back together, and don't output any new lines."""
),
],
),
]
generate_content_config = types.GenerateContentConfig(
temperature=0.7,
top_p=0.95,
top_k=64,
max_output_tokens=500,
response_mime_type="text/plain",
)
response = client.models.generate_content(
model=model,
contents=contents,
config=generate_content_config,
)
# Basic rate limiting
global last_gemini_call
if time.perf_counter() - last_gemini_call < 6:
time.sleep(6 - (time.perf_counter() - last_gemini_call))
last_gemini_call = time.perf_counter()
# Return response
if response is not None and response.candidates is not None and len(response.candidates) > 0:
return response.candidates[0].content.parts[0].text
else:
return None
def parse_sentences(text: str) -> list[str]:
"""
Splits a text into a list of sentence strings using syntok.
Preserves original spacing and punctuation.
"""
sentences = []
for paragraph in segmenter.process(text):
for sentence in paragraph:
# Reconstruct the sentence with original spacing
sentence_str = ""
for token in sentence:
sentence_str += token.spacing + token.value
# Trim any leading whitespace
sentence_str = sentence_str.lstrip()
sentences.append(sentence_str)
return sentences
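# Illustration (assuming syntok's default segmentation; not part of the original file):
#   parse_sentences("First sentence. Second one!")
#   returns ["First sentence.", "Second one!"]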
def compare_votes_for_file(base_pdf_file: str, base_pdf_page: int, base_text: str, candidate_texts: list[str], max_diffs: int) -> None:
"""
For each sentence in the base text, finds the best matching sentence from
each candidate text (using a similarity threshold). If any candidate sentences
differ from the base sentence, collects that diff (base sentence plus variant
votes) for later printing. At the end, prints only the top N diffs (by total vote count)
for the file.
Comparison is case-insensitive, but output preserves original capitalization.
"""
base_sentences = parse_sentences(base_text)
# Parse all candidate texts into lists of sentences
candidate_sentences_list = [parse_sentences(ct) for ct in candidate_texts]
diffs = [] # list to hold diff entries
for b_sentence in base_sentences:
b_sentence = b_sentence.replace("\n", " ").strip()
votes = []
for c_sentences in candidate_sentences_list:
best_ratio = 0.0
best_candidate = None
# Find the candidate sentence with the highest similarity to b_sentence
# using case-insensitive comparison
for c_sentence in c_sentences:
ratio = SequenceMatcher(None, b_sentence.lower(), c_sentence.lower()).ratio()
if ratio > best_ratio:
best_ratio = ratio
best_candidate = c_sentence # Keep original capitalization for output
# Append the candidate if it passes the similarity threshold (0.5)
if best_ratio > 0.5 and best_candidate is not None:
votes.append(best_candidate.strip())
# Only consider variants that differ when compared case-insensitively
variant_votes = [vote for vote in votes if vote.lower() != b_sentence.lower()]
if variant_votes:
diff_entry = {
"base": b_sentence,
"variants": Counter(variant_votes),
"vote_count": len(variant_votes),
}
diffs.append(diff_entry)
# Sort diffs by vote_count descending and take only the top max_diffs
diffs.sort(key=lambda d: d["vote_count"], reverse=True)
top_diffs = diffs[:max_diffs]
tests = []
for index, diff in enumerate(top_diffs):
base_sentence = diff["base"]
variant_counter = diff["variants"]
# Print base sentence using fixed-width label formatting
print(f"{'Base:':<{LABEL_WIDTH}} {base_sentence}")
print(f"{'Variants:':<{LABEL_WIDTH}}")
for variant, count in variant_counter.items():
label = f"{count}x:"
print(f"{label:<{LABEL_WIDTH}} {variant}")
# Get the clean version of the sentence
cleaned = clean_base_sentence(base_pdf_file, base_pdf_page, base_sentence)
print(f"{'Clean:':<{LABEL_WIDTH}} {cleaned}")
print("-" * 40)
if cleaned is None:
cleaned = base_sentence
tests.append(
TextPresenceTest(
pdf=os.path.basename(base_pdf_file),
page=base_pdf_page,
id=f"{os.path.basename(base_pdf_file).replace('.pdf', '')}_minediff_{index:02d}",
type="present",
threshold=1.0,
text=cleaned,
)
)
return tests
def get_pdf_from_md(md_path: str) -> str:
base = os.path.basename(md_path)
base = re.sub(r"_\d+\.md$", ".pdf", base)
return os.path.join(os.path.dirname(md_path), "..", "pdfs", base)
def main():
parser = argparse.ArgumentParser(description="Compares sentences from base and candidate texts, printing differences.")
parser.add_argument("--base", default=os.path.join(os.path.dirname(__file__), "chatgpt"), help="Path to the folder containing base .md files.")
parser.add_argument("--compare", default=os.path.join(os.path.dirname(__file__), "olmocr"), help="Path to the folder containing candidate .md files.")
parser.add_argument("--max-diffs", type=int, default=5, help="Maximum number of diffs to display per file.")
parser.add_argument(
"--output", default="mine_diffs_candidates.jsonl", type=str, help="Output of potential candidate test proposals, to be verified or added to dataset"
)
args = parser.parse_args()
base_path = args.base
compare_path = args.compare
max_diffs = args.max_diffs
# Collect all .md files from the base and compare folders
base_files = [f for f in os.listdir(base_path) if f.endswith(".md")]
all_tests = []
# Process each base file and print out the vote differences
for bf in base_files:
base_file_path = os.path.join(base_path, bf)
with open(base_file_path, "r", encoding="utf-8") as f:
base_text = f.read()
compare_files = [f for f in os.listdir(compare_path) if f.endswith(".md") and re.sub(r"_\d+\.md$", "", f) == re.sub(r"_\d+\.md$", "", bf)]
if not compare_files:
print(f"Skipping {bf}: nothing to compare against")
continue
# Read all candidate texts at once
candidate_texts = []
for cf in compare_files:
with open(os.path.join(compare_path, cf), "r", encoding="utf-8") as f:
candidate_texts.append(f.read())
base_pdf_file = get_pdf_from_md(base_file_path)
base_pdf_page = 1
print(f"Results for base file: {bf}")
tests = compare_votes_for_file(base_pdf_file, base_pdf_page, base_text, candidate_texts, max_diffs)
all_tests.extend(tests)
print("")
# Output test candidates for review after each file, in case there are errors
save_tests(all_tests, args.output)
if __name__ == "__main__":
main()
#!/usr/bin/env python3
"""
mine_math.py - Extract and validate math equations from candidate files and TeX bases.
This upgraded version:
• Uses the Python logging module for cleaner logging.
• Uses tqdm to display a progress bar.
• Uses ProcessPoolExecutor to process TeX file groups in parallel.
• For each TeX file, shuffles its pages randomly and processes them one-by-one.
Once three pages return at least one equation each, further pages are skipped.
• Adds an argparse argument for the similarity threshold for matches.
• Saves JSONL outputs incrementally as each TeX file group is processed.
Usage:
python mine_math.py --math_data /path/to/math_data --candidate candidate_folder --output_file math_tests.jsonl
[--max_pages 3] [--parallel 8] [--sim_threshold 0.7]
"""
import argparse
import glob
import os
import re
import random
import json
import logging
from typing import List, Optional, Tuple, Dict
from concurrent.futures import ProcessPoolExecutor, as_completed
from fuzzysearch import find_near_matches
from rapidfuzz import fuzz
from tqdm import tqdm
from olmocr.bench.tests import MathTest # Assumes MathTest is JSON serializable or has __dict__
from olmocr.bench.tests import save_tests # Original saving function (not used for incremental save)
from olmocr.bench.katex.render import render_equation
import numpy as np
import numba
# --- Logging Setup ---
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s [%(levelname)s] %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
# --- Utility Functions ---
def normalize_text(text: str) -> str:
"""Normalize text for better matching."""
text = re.sub(r'\s+', " ", text)
replacements = {
"‘": "'", "’": "'", "‚": "'",
"“": '"', "”": '"', "„": '"',
"_": "_",
"–": "-", "—": "-", "‑": "-", "‒": "-"
}
for fancy_char, ascii_char in replacements.items():
text = text.replace(fancy_char, ascii_char)
return text
def extract_tex_content(tex_file: str) -> str:
"""Extract the content from a TeX file."""
try:
with open(tex_file, 'r', encoding='utf-8') as f:
return f.read()
except UnicodeDecodeError:
try:
with open(tex_file, 'r', encoding='latin-1') as f:
return f.read()
except Exception as e:
logging.error("Error reading %s: %s", tex_file, e)
return ""
def extract_candidate_content(candidate_file: str) -> str:
"""Extract the content from a candidate .md file."""
try:
with open(candidate_file, 'r', encoding='utf-8') as f:
return f.read()
except Exception as e:
logging.error("Error reading %s: %s", candidate_file, e)
return ""
def extract_math_from_tex(tex_content: str) -> List[Tuple[str, str]]:
"""
Extract math equations from TeX content.
Returns list of tuples (equation_type, equation_content)
"""
math_equations = []
# Patterns for display math
display_patterns = [
(r'\$\$(.*?)\$\$', '$$'),
(r'\\begin\{equation\}(.*?)\\end\{equation\}', 'equation'),
(r'\\begin\{equation\*\}(.*?)\\end\{equation\*\}', 'equation*'),
(r'\\begin\{align\}(.*?)\\end\{align\}', 'align'),
(r'\\begin\{align\*\}(.*?)\\end\{align\*\}', 'align*'),
(r'\\begin\{displaymath\}(.*?)\\end\{displaymath\}', 'displaymath'),
(r'\\\[(.*?)\\\]', 'displaymath')
]
# Patterns for inline math
inline_patterns = [
(r'\$(.*?)\$', 'inline'),
(r'\\\((.*?)\\\)', 'inline')
]
for pattern_list in [display_patterns, inline_patterns]:
for pattern, eq_type in pattern_list:
matches = re.finditer(pattern, tex_content, re.DOTALL)
for match in matches:
equation = match.group(1).strip()
if equation and not equation.isspace():
math_equations.append((eq_type, equation))
return math_equations
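# Illustration of the extraction above (not part of the original file):
#   extract_math_from_tex(r"Text $a+b$ and \begin{equation}c=d\end{equation}")
#   returns [("equation", "c=d"), ("inline", "a+b")]
#   (display patterns are scanned before inline ones, hence the ordering)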
@numba.njit
def compute_dp(candidate_arr, text_arr):
m = candidate_arr.shape[0]
n = text_arr.shape[0]
dp = np.empty((m + 1, n + 1), dtype=np.int32)
# For empty candidate, cost is 0 (can match anywhere in text)
for j in range(n + 1):
dp[0, j] = 0
# When text is empty, need to delete all candidate characters.
for i in range(1, m + 1):
dp[i, 0] = i
for i in range(1, m + 1):
for j in range(1, n + 1):
cost = 0 if candidate_arr[i - 1] == text_arr[j - 1] else 1
dp[i, j] = min(dp[i - 1, j - 1] + cost, # substitution or match
dp[i - 1, j] + 1, # deletion (from candidate)
dp[i, j - 1] + 1) # insertion (in candidate)
return dp
@numba.njit
def find_best_end(dp, m, n):
best_distance = 1 << 30 # a large number
best_end = 0
for j in range(n + 1):
if dp[m, j] < best_distance:
best_distance = dp[m, j]
best_end = j
return best_end, best_distance
@numba.njit
def backtrack(dp, candidate_arr, text_arr, m, best_end):
i = m
j = best_end
while i > 0:
# Check for a diagonal move (match or substitution)
if j > 0 and dp[i, j] == dp[i - 1, j - 1] + (0 if candidate_arr[i - 1] == text_arr[j - 1] else 1):
i -= 1
j -= 1
elif dp[i, j] == dp[i - 1, j] + 1:
i -= 1
else:
j -= 1
return j # start index in text
def find_matching_content(candidate_text: str, tex_content: str, sim_threshold: float) -> Optional[str]:
"""
Find the substring of tex_content that most closely matches candidate_text using
dynamic programming accelerated by numba. Returns the matching substring if its
normalized similarity (1 - (edit_distance / len(candidate_text))) is above sim_threshold,
otherwise returns None.
"""
candidate_norm = normalize_text(candidate_text)
tex_norm = normalize_text(tex_content)
m = len(candidate_norm)
n = len(tex_norm)
if m == 0 or n == 0:
return None
# Convert strings to numpy arrays of integer character codes.
candidate_arr = np.empty(m, dtype=np.int32)
for i, c in enumerate(candidate_norm):
candidate_arr[i] = ord(c)
text_arr = np.empty(n, dtype=np.int32)
for j, c in enumerate(tex_norm):
text_arr[j] = ord(c)
dp = compute_dp(candidate_arr, text_arr)
best_end, min_distance = find_best_end(dp, m, n)
similarity = (m - min_distance) / m
logging.info("Similarity: %.3f", similarity)
if similarity < sim_threshold:
return None
start_index = backtrack(dp, candidate_arr, text_arr, m, best_end)
return tex_norm[start_index:best_end]
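# Worked illustration of the matching above (not part of the original file):
#   candidate "abc" against text "xxabcx" -> best edit distance 0 ending at index 5,
#   similarity = (3 - 0) / 3 = 1.0, so the substring "abc" (text[2:5]) is returned.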
def parse_candidate_filename(filename: str) -> Optional[Tuple[str, int]]:
"""
Parse candidate filename in the format: [tex file basename]_pg[pagenum]_repeat1.md
Returns tuple (tex_basename, page_num) or None if the format doesn't match.
"""
basename = os.path.basename(filename)
match = re.match(r"(.+)_pg(\d+)_repeat\d+\.md$", basename)
if match:
tex_basename = match.group(1)
page_num = int(match.group(2))
return tex_basename, page_num
return None
def validate_equation(equation: str) -> bool:
"""
Validate that an equation renders correctly with KaTeX.
Returns True if the equation is valid, False otherwise.
"""
rendered = render_equation(equation)
return rendered is not None
def process_candidate_file(candidate_file: str, pdfs_folder: str, sim_threshold: float) -> List[MathTest]:
"""
Process a single candidate file.
Returns a list of MathTest objects extracted from the corresponding TeX file.
"""
logging.info("Processing %s", candidate_file)
tests = []
parse_result = parse_candidate_filename(candidate_file)
if not parse_result:
logging.error("Filename %s does not match expected format.", candidate_file)
return tests
tex_basename, page_num = parse_result
tex_file_path = os.path.join(pdfs_folder, f"{tex_basename}.tex")
if not os.path.exists(tex_file_path):
logging.error("TeX file %s not found for candidate %s.", tex_file_path, candidate_file)
return tests
candidate_text = extract_candidate_content(candidate_file)
tex_content = extract_tex_content(tex_file_path)
if not tex_content:
logging.error("No content extracted from %s", tex_file_path)
return tests
matching_tex = find_matching_content(candidate_text, tex_content, sim_threshold)
if not matching_tex:
logging.warning("No matching TeX content found in %s for candidate %s", tex_file_path, candidate_file)
return tests
logging.debug("Matching TeX content: %s", matching_tex)
math_equations = extract_math_from_tex(matching_tex)
if not math_equations:
logging.warning("No math equations found in matching content for candidate %s", candidate_file)
return tests
# Filter out equations that are too short, remove duplicates, and shuffle
math_equations = [(eq_type, eq.strip()) for (eq_type, eq) in math_equations if len(eq.strip()) > 20]
math_equations = list(set(math_equations))
random.shuffle(math_equations)
for i, (eq_type, equation) in enumerate(math_equations):
if validate_equation(equation):
test_id = f"{tex_basename}_pg{page_num}_math_{i:03d}"
math_test = MathTest(
id=test_id,
pdf=f"{tex_basename}.pdf",
page=page_num,
type="math",
math=equation,
)
tests.append(math_test)
if len(tests) >= 10:
break
return tests
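# Each MathTest produced above becomes one JSONL record via save_tests; an illustrative
# entry (hypothetical values, showing only the fields set here) looks like:
#   {"id": "paper42_pg3_math_000", "pdf": "paper42.pdf", "page": 3, "type": "math",
#    "math": "\\frac{\\partial u}{\\partial t} = \\alpha \\nabla^{2} u"}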
def process_tex_file_group(tex_basename: str, candidate_files: List[str], pdfs_folder: str,
sim_threshold: float, max_pages: int) -> List[MathTest]:
"""
For a given TeX file, group candidate files by page, randomly shuffle the pages,
and process them one-by-one. Stop once max_pages (pages with valid equations) have
been processed.
"""
tests = []
valid_pages = set()
# Group candidate files by page number.
page_dict: Dict[int, List[str]] = {}
for candidate_file in candidate_files:
parse_result = parse_candidate_filename(candidate_file)
if not parse_result:
continue
_, page_num = parse_result
page_dict.setdefault(page_num, []).append(candidate_file)
# For each page, randomly choose one candidate file.
distinct_candidate_files = []
for page_num, files in page_dict.items():
chosen_file = random.choice(files)
distinct_candidate_files.append(chosen_file)
# Shuffle the pages randomly.
random.shuffle(distinct_candidate_files)
# Process pages sequentially until max_pages with valid equations have been found.
for candidate_file in distinct_candidate_files:
result = process_candidate_file(candidate_file, pdfs_folder, sim_threshold)
if result:
tests.extend(result)
# Mark this page as valid.
page_num = parse_candidate_filename(candidate_file)[1]
valid_pages.add(page_num)
if len(valid_pages) >= max_pages:
break
return tests
def main():
parser = argparse.ArgumentParser(
description="Extract math equations from candidate files and corresponding TeX bases."
)
parser.add_argument("--math_data", required=True, help="Path to math_data folder")
parser.add_argument("--candidate", required=True, help="Candidate folder name inside math_data")
parser.add_argument("--max_pages", type=int, default=3, help="Maximum distinct pages with equations to process per TeX document")
parser.add_argument("--parallel", type=int, default=8, help="Maximum process pool workers")
parser.add_argument("--sim_threshold", type=float, default=0.7, help="Similarity threshold for matching candidate text")
args = parser.parse_args()
candidate_folder = os.path.join(args.math_data, args.candidate)
pdfs_folder = os.path.join(args.math_data, "pdfs")
candidate_files = glob.glob(os.path.join(candidate_folder, "*.md"))
logging.info("Found %d candidate files.", len(candidate_files))
# Group candidate files by TeX basename.
tex_groups: Dict[str, List[str]] = {}
for candidate_file in candidate_files:
parse_result = parse_candidate_filename(candidate_file)
if not parse_result:
continue
tex_basename, _ = parse_result
tex_groups.setdefault(tex_basename, []).append(candidate_file)
logging.info("Found %d TeX groups.", len(tex_groups))
# Remove output file if it exists to start fresh
output_file = os.path.join(args.math_data, "math_tests.jsonl")
if os.path.exists(output_file):
os.remove(output_file)
all_math_tests = []
# Process each TeX group in parallel using ProcessPoolExecutor.
with ProcessPoolExecutor(max_workers=args.parallel) as executor:
future_to_tex = {
executor.submit(process_tex_file_group, tex_basename, candidate_list, pdfs_folder,
args.sim_threshold, args.max_pages): tex_basename
for tex_basename, candidate_list in tex_groups.items()
}
for future in tqdm(as_completed(future_to_tex), total=len(future_to_tex), desc="Processing TeX files"):
tex_basename = future_to_tex[future]
try:
tests = future.result()
all_math_tests.extend(tests)
# Incrementally save tests as each TeX group finishes processing.
save_tests(all_math_tests, output_file)
except Exception as e:
logging.error("Error processing TeX group %s: %s", tex_basename, e)
logging.info("Found %d valid math equations from %d TeX groups.", len(all_math_tests), len(tex_groups))
logging.info("Results incrementally saved to %s", output_file)
if __name__ == "__main__":
main()
import re
from dataclasses import dataclass
from typing import Optional
def claude_response_format_schema() -> dict:
return (
{
"name": "page_response",
"description": "Extracts text from pdf's.",
"input_schema": {
"type": "object",
"properties": {
"primary_language": {
"type": ["string", "null"],
"description": "The primary language of the text using two-letter codes or null if there is no text at all that you think you should read.",
},
"is_rotation_valid": {
"type": "boolean",
"description": "Is this page oriented correctly for reading? Answer only considering the textual content, do not factor in the rotation of any charts, tables, drawings, or figures.",
},
"rotation_correction": {
"type": "integer",
"description": "Indicates the degree of clockwise rotation needed if the page is not oriented correctly.",
"enum": [0, 90, 180, 270],
"default": 0,
},
"is_table": {
"type": "boolean",
"description": "Indicates if the majority of the page content is in tabular format.",
},
"is_diagram": {
"type": "boolean",
"description": "Indicates if the majority of the page content is a visual diagram.",
},
"natural_text": {
"type": ["string", "null"],
"description": "The natural text content extracted from the page.",
},
},
"required": [
"primary_language",
"is_rotation_valid",
"rotation_correction",
"is_table",
"is_diagram",
"natural_text",
],
},
}
)
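# An illustrative tool_use input conforming to the schema above (hypothetical values):
#   {"primary_language": "en", "is_rotation_valid": true, "rotation_correction": 0,
#    "is_table": false, "is_diagram": false, "natural_text": "Introduction\n\nThis paper presents..."}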
def gemini_response_format_schema() -> dict:
return (
{
"type": "OBJECT",
"properties": {
"primary_language": {
"type": "STRING",
"description": "The primary language of the text using two-letter codes or null if there is no text at all that you think you should read.",
},
"is_rotation_valid": {
"type": "BOOL",
"description": "Is this page oriented correctly for reading? Answer only considering the textual content, do not factor in the rotation of any charts, tables, drawings, or figures.",
},
"rotation_correction": {
"type": "INTEGER",
"enum": [0, 90, 180, 270],
"description": "Indicates the degree of clockwise rotation needed if the page is not oriented correctly.",
},
"is_table": {"type": "BOOL", "description": "Indicates if the majority of the page content is in tabular format."},
"is_diagram": {"type": "BOOL", "description": "Indicates if the majority of the page content is a visual diagram."},
"natural_text": {"type": "STRING", "description": "The natural text content extracted from the page."},
},
"required": ["primary_language", "is_rotation_valid", "rotation_correction", "is_table", "is_diagram", "natural_text"],
"propertyOrdering": ["primary_language", "is_rotation_valid", "rotation_correction", "is_table", "is_diagram", "natural_text"],
}
)
def build_find_difference_prompt(base_text: str) -> str:
return (
f"Below is an image of a document page, along with raw textual content previously extracted using different models."
f"Your goal is to carefully identify the differences between the extracted texts from both models and determine which one is more accurate by comparing them with the image."
f"Only return the differences and specify which model extracted the text with higher accuracy.\n"
f"Do not hallucinate.\n"
f"RAW_TEXT_START\n{base_text}\nRAW_TEXT_END"
)
import json
import os
from openai import OpenAI
from olmocr.data.renderpdf import render_pdf_to_base64png
from olmocr.prompts.anchor import get_anchor_text
from olmocr.prompts.prompts import (
PageResponse,
build_openai_silver_data_prompt,
openai_response_format_schema,
)
def run_chatgpt(pdf_path: str, page_num: int = 1, model: str = "gpt-4o-2024-08-06", temperature: float = 0.1) -> str:
"""
Convert a page of a PDF file to markdown using the commercial OpenAI API.
See run_server.py for running against an OpenAI-compatible server.
Args:
pdf_path (str): The local path to the PDF file.
page_num (int): The page number to process (starting from 1).
model (str): The OpenAI model to use.
temperature (float): The temperature parameter for generation.
Returns:
str: The OCR result in markdown format.
"""
# Convert the requested page of the PDF to a base64-encoded PNG image.
image_base64 = render_pdf_to_base64png(pdf_path, page_num=page_num, target_longest_image_dim=2048)
anchor_text = get_anchor_text(pdf_path, page_num, pdf_engine="pdfreport")
if not os.getenv("OPENAI_API_KEY"):
raise SystemExit("You must specify an OPENAI_API_KEY")
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
response = client.chat.completions.create(
model=model,
messages=[
{
"role": "user",
"content": [
{"type": "text", "text": build_openai_silver_data_prompt(anchor_text)},
{"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_base64}"}},
],
}
],
temperature=temperature,
max_tokens=3000,
response_format=openai_response_format_schema(),
)
assert len(response.choices) > 0
assert response.choices[0].message.refusal is None
assert response.choices[0].finish_reason == "stop"
raw_response = response.choices[0].message.content
data = json.loads(raw_response)
data = PageResponse(**data)
return data.natural_text
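# Minimal usage sketch (hypothetical path; requires OPENAI_API_KEY to be set):
#   markdown = run_chatgpt("olmocr/bench/sample_data/pdfs/some_document.pdf", page_num=1)
#   print(markdown)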
import json
import os
from anthropic import Anthropic
from prompts import (
build_openai_silver_data_prompt,
claude_response_format_schema,
)
from olmocr.data.renderpdf import render_pdf_to_base64png
from olmocr.prompts.anchor import get_anchor_text
def run_claude(pdf_path: str, page_num: int = 1, model: str = "claude-3-7-sonnet-20250219", temperature: float = 0.1) -> str:
"""
Convert a page of a PDF file to markdown using Claude for OCR.
This function renders the specified page of the PDF to an image, runs OCR on that image,
and returns the OCR result as a markdown-formatted string.
Args:
pdf_path (str): The local path to the PDF file.
page_num (int): The page number to process (starting from 1).
model (str): The Claude model to use.
temperature (float): The temperature parameter for generation.
Returns:
str: The OCR result in markdown format.
"""
if not os.getenv("ANTHROPIC_API_KEY"):
raise SystemExit("You must specify an ANTHROPIC_API_KEY")
image_base64 = render_pdf_to_base64png(pdf_path, page_num=page_num, target_longest_image_dim=2048)
anchor_text = get_anchor_text(pdf_path, page_num, pdf_engine="pdfreport")
client = Anthropic(api_key=os.getenv("ANTHROPIC_API_KEY"))
response = client.messages.create(
model=model,
max_tokens=3000,
temperature=temperature,
# system=system_prompt,
tools=[claude_response_format_schema()],
messages=[
{
"role": "user",
"content": [
{"type": "image", "source": {"type": "base64", "media_type": "image/png", "data": image_base64}},
{
"type": "text",
"text": f"{build_openai_silver_data_prompt(anchor_text)}. Use the page_response tool to respond. If the propeties are true, then extract the text from them and respond in natural_text.",
},
],
}
],
)
json_sentiment = None
for content in response.content:
if content.type == "tool_use" and content.name == "page_response":
json_sentiment = content.input
break
if json_sentiment is None:
raise ValueError("No page_response tool call was found in the Claude response")
return json.dumps(json_sentiment, indent=2)
import base64
import os
from google.ai import generativelanguage as glm
from google.api_core import client_options
from olmocr.data.renderpdf import render_pdf_to_base64png
from olmocr.prompts.anchor import get_anchor_text
from olmocr.prompts.prompts import build_openai_silver_data_prompt
def run_gemini(pdf_path: str, page_num: int = 1, model: str = "gemini-2.0-flash", temperature: float = 0.1) -> str:
"""
Convert a page of a PDF file to markdown using Gemini's vision capabilities.
This function renders the specified page of the PDF to an image, runs OCR on that image,
and returns the OCR result as a markdown-formatted string.
Args:
pdf_path (str): The local path to the PDF file.
page_num (int): The page number to process (starting from 1).
model (str): The Gemini model to use.
temperature (float): The temperature parameter for generation.
Returns:
str: The OCR result in markdown format.
"""
if not os.getenv("GEMINI_API_KEY"):
raise SystemExit("You must specify an GEMINI_API_KEY")
image_base64 = render_pdf_to_base64png(pdf_path, page_num=page_num, target_longest_image_dim=2048)
anchor_text = get_anchor_text(pdf_path, page_num, pdf_engine="pdfreport")
api_key = os.getenv("GEMINI_API_KEY")
client = glm.GenerativeServiceClient(
client_options=client_options.ClientOptions(
api_key=api_key,
),
)
image_part = glm.Part(inline_data=glm.Blob(mime_type="image/png", data=base64.b64decode(image_base64)))
text_part = glm.Part(text=f"""{build_openai_silver_data_prompt(anchor_text)}""")
generation_config = glm.GenerationConfig(
temperature=temperature,
top_p=1.0,
top_k=32,
max_output_tokens=4096,
)
# response_schema = gemini_response_format_schema()
request = glm.GenerateContentRequest(
model=f"models/{model}",
contents=[glm.Content(parts=[image_part, text_part])],
generation_config=generation_config,
)
# request = glm.GenerateContentRequest(
# model=f"models/{model}",
# contents=[glm.Content(parts=[image_part, text_part])],
# generation_config=generation_config,
# tools=[
# glm.Tool(
# function_declarations=[
# glm.FunctionDeclaration(
# name="page_response",
# parameters=response_schema
# )
# ]
# )
# ],
# tool_config=glm.ToolConfig(
# function_calling_config=glm.FunctionCallingConfig(
# mode="any",
# allowed_function_names=["page_response"]
# )
# )
# )
response = client.generate_content(request)
assert len(response.candidates) > 0, "No candidates found"
assert response.candidates[0].finish_reason == glm.Candidate.FinishReason.STOP, "Finish reason was not STOP, likely a processing error or repetition failure"
result = response.candidates[0].content.parts[0].text
return result
import base64
import os
import tempfile
import torch
from transformers import AutoModel, AutoTokenizer
from olmocr.data.renderpdf import render_pdf_to_base64png
# Global cache for the model and tokenizer.
_device = "cuda" if torch.cuda.is_available() else "cpu"
_model = None
_tokenizer = None
def load_model():
"""
Load the GOT-OCR model and tokenizer if they haven't been loaded already.
Returns:
model: The GOT-OCR model loaded on the appropriate device.
tokenizer: The corresponding tokenizer.
"""
global _model, _tokenizer
if _model is None or _tokenizer is None:
_tokenizer = AutoTokenizer.from_pretrained("ucaslcl/GOT-OCR2_0", trust_remote_code=True)
_model = AutoModel.from_pretrained(
"ucaslcl/GOT-OCR2_0",
trust_remote_code=True,
use_safetensors=True,
revision="979938bf89ccdc949c0131ddd3841e24578a4742",
pad_token_id=_tokenizer.eos_token_id,
)
_model = _model.eval().to(_device)
return _model, _tokenizer
def run_gotocr(pdf_path: str, page_num: int = 1, ocr_type: str = "ocr") -> str:
"""
Convert a page of a PDF file to markdown using GOT-OCR.
This function renders the specified page of the PDF to an image, runs OCR on that image,
and returns the OCR result as a markdown-formatted string.
Args:
pdf_path (str): The local path to the PDF file.
page_num (int): The page number to process (starting from 1).
ocr_type (str): The GOT-OCR inference mode to use (default "ocr").
Returns:
str: The OCR result in markdown format.
"""
# Ensure the model is loaded (cached across calls)
model, tokenizer = load_model()
# Convert the requested page of the PDF to a base64-encoded PNG image.
base64image = render_pdf_to_base64png(pdf_path, page_num=page_num, target_longest_image_dim=1024)
# Write the image to a temporary file.
with tempfile.NamedTemporaryFile("wb", suffix=".png", delete=False) as tmp:
tmp.write(base64.b64decode(base64image))
tmp_filename = tmp.name
# Run GOT-OCR on the saved image.
result = model.chat(tokenizer, tmp_filename, ocr_type=ocr_type)
# Clean up the temporary file.
os.remove(tmp_filename)
return result
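# Minimal usage sketch (hypothetical path; the GOT-OCR2_0 weights are downloaded on the
# first call and the model/tokenizer are cached in the module-level globals above):
#   text = run_gotocr("olmocr/bench/sample_data/pdfs/some_document.pdf", page_num=1, ocr_type="ocr")
#   print(text)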