Commit 89e60e48 authored by wanglch

Initial commit
import os
import random
import tempfile
from concurrent.futures import ThreadPoolExecutor
from difflib import SequenceMatcher
from urllib.parse import urlparse
import boto3
from jinja2 import Template
from tqdm import tqdm
from olmocr.data.renderpdf import render_pdf_to_base64png
session = boto3.Session(profile_name="s2")
s3_client = session.client("s3")
def generate_diff_html(a, b):
"""
Generates HTML with differences between strings a and b.
Additions in 'b' are highlighted in green, deletions from 'a' are highlighted in red.
"""
seq_matcher = SequenceMatcher(None, a, b)
output_html = ""
for opcode, a0, a1, b0, b1 in seq_matcher.get_opcodes():
if opcode == "equal":
output_html += a[a0:a1]
elif opcode == "insert":
output_html += f"<span class='added'>{b[b0:b1]}</span>"
elif opcode == "delete":
output_html += f"<span class='removed'>{a[a0:a1]}</span>"
elif opcode == "replace":
output_html += f"<span class='removed'>{a[a0:a1]}</span><span class='added'>{b[b0:b1]}</span>"
return output_html
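# Illustrative doctest-style sketch of the markup produced above (example strings are
# hypothetical), assuming difflib.SequenceMatcher's standard opcode semantics:
#
#   >>> generate_diff_html("the cat sat", "the cat sat down")
#   "the cat sat<span class='added'> down</span>"
#
# Note that the inputs are not HTML-escaped, so angle brackets in the source text
# would be rendered as markup by the template.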
def process_entry(i, entry):
# Randomly decide whether to display gold on the left or right
if random.choice([True, False]):
left_text, right_text = entry["gold_text"], entry["eval_text"]
left_class, right_class = "gold", "eval"
left_metadata, right_metadata = entry.get("gold_metadata", ""), entry.get("eval_metadata", "")
else:
left_text, right_text = entry["eval_text"], entry["gold_text"]
left_class, right_class = "eval", "gold"
left_metadata, right_metadata = entry.get("eval_metadata", ""), entry.get("gold_metadata", "")
# Generate diff for right_text compared to left_text
diff_html = generate_diff_html(left_text, right_text)
left_text = "<p>" + left_text.replace("\n", "</p><p>") + "</p>"
right_text = "<p>" + right_text.replace("\n", "</p><p>") + "</p>"
diff_html = "<p>" + diff_html.replace("\n", "</p><p>") + "</p>"
parsed_url = urlparse(entry["s3_path"])
bucket = parsed_url.netloc
s3_key = parsed_url.path.lstrip("/")
signed_pdf_link = s3_client.generate_presigned_url("get_object", Params={"Bucket": bucket, "Key": s3_key}, ExpiresIn=604800)
with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as tmp_pdf:
pdf_path = tmp_pdf.name
bucket, key = entry["s3_path"].replace("s3://", "").split("/", 1)
s3_client.download_file(bucket, key, pdf_path)
page_image_base64 = render_pdf_to_base64png(tmp_pdf.name, entry["page"], target_longest_image_dim=1024)
return {
"entry_id": i,
"page_image": page_image_base64,
"s3_path": entry["s3_path"],
"page": entry["page"],
"key": entry.get("entry_key", entry["s3_path"] + "_" + str(entry["page"])),
"alignment": entry["alignment"],
"signed_pdf_link": signed_pdf_link,
"left_metadata": left_metadata,
"right_metadata": right_metadata,
"left_text": left_text,
"right_text": right_text,
"diff_text": diff_html,
"left_class": left_class,
"right_class": right_class,
"gold_class": "gold" if left_class == "gold" else "eval",
"eval_class": "eval" if right_class == "eval" else "gold",
}
def create_review_html(data, filename="review_page.html"):
# Load the Jinja2 template from the file
template_path = os.path.join(os.path.dirname(__file__), "evalhtml_template.html")
with open(template_path, "r") as f:
template = Template(f.read())
entries = []
with ThreadPoolExecutor() as executor:
# Submit tasks to the executor
futures = [executor.submit(process_entry, i, entry) for i, entry in enumerate(data)]
# Process the results as they are completed
for future in tqdm(futures):
entries.append(future.result())
# Render the template with the entries
final_html = template.render(entries=entries)
# Write the HTML content to the specified file
with open(filename, "w") as f:
f.write(final_html)
print(f"HTML file '{filename}' created successfully!")
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Text Evaluation Review</title>
<style>
body {
font-family: Arial, sans-serif;
background-color: #f9f9f9;
margin: 0;
padding: 20px;
}
h1 {
text-align: center;
font-size: 2em;
color: #333;
}
.container {
width: 100%;
max-width: 1600px;
margin: 0 auto;
}
.entry {
display: grid;
grid-template-columns: 1fr 1fr 1fr;
grid-gap: 20px;
margin-bottom: 20px;
padding: 20px;
background-color: #fff;
border-radius: 8px;
box-shadow: 0 0 10px rgba(0,0,0,0.1);
transition: background-color 0.3s ease;
}
.text-block {
padding: 10px;
background-color: #f1f1f1;
border-radius: 6px;
min-height: 100px;
display: flex;
flex-direction: column;
justify-content: space-between;
cursor: pointer;
position: relative;
}
.text-block:hover {
background-color: #e0e0e0;
}
.text-block.selected {
background-color: lightgreen;
border: 2px solid black;
}
.alignment {
font-size: 0.9em;
color: #777;
margin-top: 10px;
}
.reveal-box {
position: fixed;
top: 20px;
right: 20px;
padding: 15px;
background-color: white;
border: 1px solid #ccc;
border-radius: 8px;
box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);
z-index: 1000;
width: 200px;
}
.reveal-box input {
margin-right: 10px;
}
.reveal-info {
margin-top: 10px;
font-size: 0.9em;
color: #333;
}
.revealed .gold {
background-color: #fff9e6;
}
.revealed .eval {
background-color: #e6f3ff;
}
.revealed .text-block.selected {
border: 2px solid black;
}
.entry > div:first-child img {
width: 100%;
height: auto;
object-fit: cover;
border-radius: 6px;
cursor: pointer;
transition: transform 0.3s ease, box-shadow 0.3s ease;
}
/* Full-screen preview mode */
.full-screen img {
position: fixed;
top: 50%;
left: 50%;
transform: translate(-50%, -50%);
width: unset !important;
max-width: 90vw;
height: auto;
max-height: 90vh;
z-index: 1001;
box-shadow: 0 0 20px rgba(0, 0, 0, 0.5);
}
.overlay {
position: fixed;
top: 0;
left: 0;
width: 100vw;
height: 100vh;
background-color: rgba(0, 0, 0, 0.7);
z-index: 1000;
display: none;
}
.overlay.active {
display: block;
}
/* Voting Buttons Styles */
.voting-buttons {
margin-top: 10px;
display: flex;
flex-direction: column;
gap: 5px;
}
.voting-buttons button {
padding: 8px 12px;
border: none;
border-radius: 4px;
cursor: pointer;
background-color: #007BFF;
color: white;
transition: background-color 0.3s ease, border 0.3s ease;
}
.voting-buttons button:hover {
background-color: #0056b3;
}
.voting-buttons button.invalid {
background-color: #dc3545;
}
.voting-buttons button.invalid:hover {
background-color: #a71d2a;
}
.voting-buttons button.both-good {
background-color: #28a745;
}
.voting-buttons button.both-good:hover {
background-color: #1e7e34;
}
.voting-buttons button.both-bad {
background-color: #ffc107;
color: #212529;
}
.voting-buttons button.both-bad:hover {
background-color: #e0a800;
}
/* Selected State for Voting Buttons */
.voting-buttons button.selected {
border: 3px solid #000;
}
/* for diffs */
.added {
background-color: #d4fcdc;
}
.removed {
background-color: #fcd4d4;
text-decoration: line-through;
}
/* Diff Toggle Styles */
body.diffed .right-text {
display: none;
}
body.diffed .diff-text {
display: block;
}
.diff-text {
display: none;
}
</style>
</head>
<body>
<h1>Text Evaluation Review</h1>
<!-- Floating Reveal Box -->
<div class="reveal-box">
<div>
<input type="checkbox" id="diff-toggle" />
<label for="diff-toggle">Toggle diff</label>
</div>
<div>
<input type="checkbox" id="reveal-toggle" />
<label for="reveal-toggle">Reveal Gold/Eval</label>
</div>
<div class="reveal-info" id="vote-info">Votes</div>
</div>
<div class="container">
{% for entry in entries %}
<div class="entry {{ entry.gold_class }} {{ entry.eval_class }}" data-entry-id="{{ entry.key }}" data-left-metadata="{{ entry.left_metadata }}" data-right-metadata="{{ entry.right_metadata }}">
<div class="image-container">
<img src="data:image/png;base64,{{ entry.page_image }}" alt="Render">
<div class="alignment">Alignment: {{ entry.alignment }}</div>
<a href="{{entry.signed_pdf_link}}#page={{ entry.page }}" target="_blank">{{ entry.s3_path }} (Page {{ entry.page }})</a>
<!-- Voting Buttons -->
<div class="voting-buttons">
<button class="both-good" data-vote="both_good">Both Good</button>
<button class="both-bad" data-vote="both_bad">Both Bad</button>
<button class="invalid" data-vote="invalid_pdf">Invalid PDF</button>
</div>
</div>
<div class="text-block {{ entry.left_class }}" data-choice="left">
<div>{{ entry.left_text|safe }}</div>
</div>
<!-- Updated Right Text-Block with separate divs for right_text and diff_text -->
<div class="text-block {{ entry.right_class }}" data-choice="right">
<div class="right-text">{{ entry.right_text|safe }}</div>
<div class="diff-text">{{ entry.diff_text|safe }}</div>
</div>
</div>
{% endfor %}
</div>
<!-- Overlay for full-screen preview -->
<div class="overlay"></div>
<script>
document.addEventListener('DOMContentLoaded', () => {
fetchDataAndUpdatePage();
// Toggle the full-screen image preview
const overlay = document.querySelector('.overlay');
document.querySelectorAll('.image-container img').forEach(img => {
img.addEventListener('click', () => {
const entry = img.closest('.entry');
entry.classList.toggle('full-screen');
overlay.classList.toggle('active');
});
});
overlay.addEventListener('click', () => {
document.querySelectorAll('.full-screen').forEach(entry => {
entry.classList.remove('full-screen');
});
overlay.classList.remove('active');
});
// Handle Reveal Gold/Eval Toggle
document.getElementById('reveal-toggle').addEventListener('change', (e) => {
document.body.classList.toggle('revealed', e.target.checked);
updateReveal();
});
// Handle Diff Toggle
document.getElementById('diff-toggle').addEventListener('change', (e) => {
document.body.classList.toggle('diffed', e.target.checked);
toggleDiff(e.target.checked);
});
// Handle text-block selections
document.querySelectorAll('.text-block').forEach(block => {
block.addEventListener('click', () => selectChoice(block));
});
// Handle voting buttons
document.querySelectorAll('.voting-buttons button').forEach(button => {
button.addEventListener('click', () => handleVote(button));
});
});
// Utility function to sanitize s3_path for use as a key
function sanitizeKey(key) {
return key.replace(/[^a-zA-Z0-9-_]/g, '_');
}
async function fetchDataAndUpdatePage() {
let datastore = await fetchDatastore();
document.querySelectorAll('.entry').forEach(entry => {
const entryKey = sanitizeKey(entry.getAttribute('data-entry-id'));
const leftBlock = entry.querySelector('.text-block[data-choice="left"]');
const rightBlock = entry.querySelector('.text-block[data-choice="right"]');
const voteButtons = entry.querySelectorAll('.voting-buttons button');
if (datastore[entryKey]) {
const choice = datastore[entryKey];
if (choice === 'left' || choice === 'right') {
const selectedBlock = choice === 'left' ? leftBlock : rightBlock;
selectChoice(selectedBlock, false);
} else {
// Handle additional voting choices
handleAdditionalVote(entry, choice, false);
}
}
});
// Ensure diff state is consistent on load
const diffToggle = document.getElementById('diff-toggle');
toggleDiff(diffToggle.checked);
}
async function selectChoice(block, save = true) {
let datastore = await fetchDatastore();
const entry = block.closest('.entry');
const entryKey = sanitizeKey(entry.getAttribute('data-entry-id'));
const blocks = entry.querySelectorAll('.text-block');
blocks.forEach(b => b.classList.remove('selected'));
block.classList.add('selected');
datastore[entryKey] = block.getAttribute('data-choice');
const numVotes = Object.keys(datastore).length;
document.getElementById("vote-info").innerText = `Total Votes: ${numVotes}`;
if (save) {
putDatastore(datastore); // Save entire datastore
}
}
async function handleVote(button, save = true) {
let datastore = await fetchDatastore();
const entry = button.closest('.entry');
const entryKey = sanitizeKey(entry.getAttribute('data-entry-id'));
const choice = button.getAttribute('data-vote');
// Deselect any selected voting buttons within this entry
const voteButtons = entry.querySelectorAll('.voting-buttons button');
voteButtons.forEach(b => b.classList.remove('selected'));
// Select the clicked button
button.classList.add('selected');
// Deselect any selected text-blocks
const textBlocks = entry.querySelectorAll('.text-block');
textBlocks.forEach(b => b.classList.remove('selected'));
datastore[entryKey] = choice;
const numVotes = Object.keys(datastore).length;
document.getElementById("vote-info").innerText = `Total Votes: ${numVotes}`;
if (save) {
putDatastore(datastore); // Save entire datastore
}
}
async function handleAdditionalVote(entry, choice, save = true) {
let datastore = await fetchDatastore();
const entryKey = sanitizeKey(entry.getAttribute('data-entry-id'));
// Select the appropriate voting button based on the choice
const voteButton = entry.querySelector(`.voting-buttons button[data-vote="${choice}"]`);
if (voteButton) {
// Deselect other voting buttons
const voteButtons = entry.querySelectorAll('.voting-buttons button');
voteButtons.forEach(b => b.classList.remove('selected'));
// Select the current button
voteButton.classList.add('selected');
}
datastore[entryKey] = choice;
if (save) {
putDatastore(datastore);
}
}
async function updateReveal() {
let datastore = await fetchDatastore();
let goldVotes = 0;
let evalVotes = 0;
let bothGoodVotes = 0;
let bothBadVotes = 0;
let invalidPdfVotes = 0;
document.querySelectorAll('.entry').forEach(entry => {
const entryKey = sanitizeKey(entry.getAttribute('data-entry-id'));
const leftBlock = entry.querySelector('.text-block[data-choice="left"]');
const rightBlock = entry.querySelector('.text-block[data-choice="right"]');
const vote = datastore[entryKey];
if (vote === 'left') {
if (leftBlock.classList.contains('gold')) {
goldVotes++;
} else {
evalVotes++;
}
} else if (vote === 'right') {
if (rightBlock.classList.contains('gold')) {
goldVotes++;
} else {
evalVotes++;
}
} else if (vote === 'both_good') {
bothGoodVotes++;
} else if (vote === 'both_bad') {
bothBadVotes++;
} else if (vote === 'invalid_pdf') {
invalidPdfVotes++;
}
});
const totalVotes = goldVotes + evalVotes + bothGoodVotes + bothBadVotes + invalidPdfVotes;
const goldPercentage = totalVotes > 0 ? Math.round((goldVotes / totalVotes) * 100) : 0;
const evalPercentage = totalVotes > 0 ? Math.round((evalVotes / totalVotes) * 100) : 0;
const bothGoodPercentage = totalVotes > 0 ? Math.round((bothGoodVotes / totalVotes) * 100) : 0;
const bothBadPercentage = totalVotes > 0 ? Math.round((bothBadVotes / totalVotes) * 100) : 0;
const invalidPdfPercentage = totalVotes > 0 ? Math.round((invalidPdfVotes / totalVotes) * 100) : 0;
document.getElementById("vote-info").innerText = `Gold: ${goldPercentage}% | Eval: ${evalPercentage}% | Both Good: ${bothGoodPercentage}% | Both Bad: ${bothBadPercentage}% | Invalid PDF: ${invalidPdfPercentage}%`;
document.querySelectorAll('.entry').forEach(entry => {
const entryKey = sanitizeKey(entry.getAttribute('data-entry-id'));
const vote = datastore[entryKey];
if (vote === 'left' || vote === 'right') {
const selectedBlock = vote === 'left' ? entry.querySelector('.text-block[data-choice="left"]') : entry.querySelector('.text-block[data-choice="right"]');
selectedBlock.classList.add('selected');
}
// Additional votes already handled in handleAdditionalVote
});
}
// Function to toggle diff text
function toggleDiff(isDiffed) {
if (isDiffed) {
document.body.classList.add('diffed');
} else {
document.body.classList.remove('diffed');
}
}
</script>
</body>
</html>
# This script will build a set of scores for the accuracy of a given pdf conversion tactic against a gold dataset
import argparse
import hashlib
import json
import logging
import os
import random
import sys
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional
import boto3
import zstandard
from smart_open import register_compressor, smart_open
from tqdm import tqdm
from .dolma_refine.aligners import HirschbergAligner
from .dolma_refine.metrics import DocumentEditSimilarity
from .dolma_refine.segmenters import SpacySegmenter
from .evalhtml import create_review_html
logging.getLogger("pypdf").setLevel(logging.ERROR)
CACHE_DIR = os.path.join(Path.home(), ".cache", "pdf_gold_data_cache")
s3_client = boto3.client("s3")
def _handle_zst(file_obj, mode):
return zstandard.open(file_obj, mode)
register_compressor(".zstd", _handle_zst)
register_compressor(".zst", _handle_zst)
# Helper function to download files from S3
def download_from_s3(s3_path: str, local_path: str):
bucket_name, key = s3_path.replace("s3://", "").split("/", 1)
s3_client.download_file(bucket_name, key, local_path)
def is_debugging():
return sys.gettrace() is not None
# Compute a hash of the file contents to check for changes
def compute_file_hash(file_path: str) -> str:
hash_md5 = hashlib.md5()
with open(file_path, "rb") as f:
for chunk in iter(lambda: f.read(4096), b""):
hash_md5.update(chunk)
return hash_md5.hexdigest()
# A single method which can take in any format json entry (openai regular, openai structured, birr)
# and normalize it to a common structure for use later in the evaluation
@dataclass(frozen=True)
class NormalizedEntry:
s3_path: str
pagenum: int
text: Optional[str]
finish_reason: Optional[str]
error: Optional[str] = None
@staticmethod
def from_goldkey(goldkey: str, **kwargs):
s3_path = goldkey[: goldkey.rindex("-")]
page_num = int(goldkey[goldkey.rindex("-") + 1 :])
return NormalizedEntry(s3_path, page_num, **kwargs)
@property
def goldkey(self):
return f"{self.s3_path}-{self.pagenum}"
def normalize_json_entry(data: dict) -> NormalizedEntry:
if "outputs" in data:
# Birr case
if data["outputs"] is None:
text = None
finish_reason = None
else:
text = data["outputs"][0]["text"]
finish_reason = data["outputs"][0]["finish_reason"]
# Try to parse the structured output if possible
try:
if text is not None:
parsed_content = json.loads(text)
text = parsed_content["natural_text"]
except json.JSONDecodeError:
pass
return NormalizedEntry.from_goldkey(goldkey=data["custom_id"], text=text, finish_reason=finish_reason, error=data.get("completion_error", None))
elif all(field in data for field in ["s3_path", "pagenum", "text", "error", "finish_reason"]):
return NormalizedEntry(**data)
elif "response" in data and "body" in data["response"] and "choices" in data["response"]["body"]:
# OpenAI case
try:
# Attempt to parse the JSON content from OpenAI's response
parsed_content = json.loads(data["response"]["body"]["choices"][0]["message"]["content"])
return NormalizedEntry.from_goldkey(
goldkey=data["custom_id"], text=parsed_content["natural_text"], finish_reason=data["response"]["body"]["choices"][0]["finish_reason"]
)
except json.JSONDecodeError:
# Fallback if content is not valid JSON
return NormalizedEntry.from_goldkey(
goldkey=data["custom_id"],
text=data["response"]["body"]["choices"][0]["message"]["content"],
finish_reason=data["response"]["body"]["choices"][0]["finish_reason"],
)
else:
# SGLang case
try:
# Attempt to parse the JSON content from SGLang's response
parsed_content = json.loads(data["response"]["choices"][0]["message"]["content"])
return NormalizedEntry.from_goldkey(
goldkey=data["custom_id"], text=parsed_content["natural_text"], finish_reason=data["response"]["choices"][0]["finish_reason"]
)
except json.JSONDecodeError:
# Fallback if content is not valid JSON
return NormalizedEntry.from_goldkey(
goldkey=data["custom_id"],
text=data["response"]["choices"][0]["message"]["content"],
finish_reason=data["response"]["choices"][0]["finish_reason"],
)
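# Comment-only example of the Birr branch above (the dict is hypothetical, shaped like
# the fields this function reads): plain text that fails json.loads() is kept as-is.
#
#   >>> entry = {
#   ...     "custom_id": "s3://bucket/doc.pdf-3",
#   ...     "outputs": [{"text": "Page three text", "finish_reason": "stop"}],
#   ... }
#   >>> normalize_json_entry(entry).goldkey
#   's3://bucket/doc.pdf-3'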
# Loads every .json file from the gold data path (and saves it to a temp folder for quick loading next time)
# Returns a map from "custom_id", ex. "s3://ai2-s2-pdfs/39ce/3db4516cd6e7d7f8e580a494c7a665a6a16a.pdf-4" (where the -4 means page 4),
# to the gold standard text
def load_gold_data(gold_data_path: str, max_workers: int = 8) -> dict:
"""
Load gold data from JSONL files in a multithreaded manner.
Args:
gold_data_path (str): Path to the directory containing JSONL files.
max_workers (int, optional): Maximum number of threads to use. Defaults to 8.
Returns:
dict: A dictionary containing gold data entries.
"""
if not os.path.exists(CACHE_DIR):
os.makedirs(CACHE_DIR)
gold_data: Dict[str, str] = {}
total_errors = 0
total_overruns = 0
gold_jsonl_files: List[str] = list_jsonl_files(gold_data_path)
def process_file(path: str) -> tuple:
"""Process a single JSONL file and return its data and error counts."""
file_gold_data = {}
file_errors = 0
file_overruns = 0
with smart_open(path, "r") as f:
for line in f:
data = json.loads(line)
data = normalize_json_entry(data)
if data.error is not None:
file_errors += 1
elif data.finish_reason != "stop":
file_overruns += 1
else:
file_gold_data[data.goldkey] = data.text
return file_gold_data, file_errors, file_overruns
with ThreadPoolExecutor(max_workers=max_workers) as executor:
# Submit all file processing tasks
futures = [executor.submit(process_file, path) for path in gold_jsonl_files]
# Gather results as they complete
for future in as_completed(futures):
try:
file_gold_data, file_errors, file_overruns = future.result()
gold_data.update(file_gold_data)
total_errors += file_errors
total_overruns += file_overruns
except Exception as e:
print(f"Error processing a file: {e}")
print(f"Loaded {len(gold_data):,} gold data entries for comparison")
print(f"Gold processing errors: {total_errors}")
print(f"Gold overrun errors: {total_overruns}")
print("-----------------------------------------------------------")
return gold_data
# Helper function to list all .jsonl files from a directory or an S3 bucket
def list_jsonl_files(path: str) -> list:
valid_endings = [".json", ".jsonl", ".json.zstd", ".jsonl.zstd"]
jsonl_files = []
if path.startswith("s3://"):
bucket_name, prefix = path.replace("s3://", "").split("/", 1)
paginator = s3_client.get_paginator("list_objects_v2")
pages = paginator.paginate(Bucket=bucket_name, Prefix=prefix)
for page in pages:
for obj in page.get("Contents", []):
if any(obj["Key"].endswith(ending) for ending in valid_endings):
jsonl_files.append(f"s3://{bucket_name}/{obj['Key']}")
else:
# If it's a local directory, list all .jsonl files
for root, _, files in os.walk(path):
for file in files:
if any(file.endswith(ending) for ending in valid_endings):
jsonl_files.append(os.path.join(root, file))
return jsonl_files
# Takes in a path to a local directory or s3://[bucket]/[prefix path] where your jsonl files are stored
# This is most likely the output location of the refiner
# Expecting each jsonl line to include {s3_path: [path to original pdf], page: [pagenum], text: [proper page text]}
# Returns per-page alignment (edit-similarity) scores against the gold data, plus aggregate counts
def process_jsonl_file(jsonl_file, gold_data, comparer):
page_data = {}
total_alignment_score: float = 0.0
char_weighted_alignment_score: float = 0.0
total_pages = 0
total_chars = 0
total_errors = 0
total_overruns = 0
with smart_open(jsonl_file, "r") as f:
for line in f:
data = json.loads(line)
data = normalize_json_entry(data)
if data.goldkey not in gold_data:
continue
gold_text = gold_data[data.goldkey]
eval_text = data.text
gold_text = gold_text or ""
eval_text = eval_text or ""
if data.error is not None:
total_errors += 1
eval_text = f"[Error processing this page: {data.error}]"
if data.error is None and data.finish_reason != "stop":
total_overruns += 1
eval_text += f"\n[Error processing this page: overrun {data.finish_reason}]"
if len(gold_text.strip()) < 3 and len(eval_text.strip()) < 3:
alignment = 1.0
else:
alignment = comparer.compute(gold_text, eval_text)
page_data[data.goldkey] = {"s3_path": data.s3_path, "page": data.pagenum, "gold_text": gold_text, "eval_text": eval_text, "alignment": alignment}
total_alignment_score += alignment
char_weighted_alignment_score += alignment * len(gold_text)
total_chars += len(gold_text)
total_pages += 1
return total_alignment_score, char_weighted_alignment_score, total_chars, total_pages, total_errors, total_overruns, page_data
def do_eval(gold_data_path: str, eval_data_path: str, review_page_name: str, review_page_size: int) -> tuple[float, list[dict]]:
gold_data = load_gold_data(gold_data_path)
total_alignment_score = 0
total_char_alignment_score = 0
total_weight = 0
total_pages = 0
total_errors = 0
total_overruns = 0
total_pages_compared = set()
page_eval_data = []
segmenter = SpacySegmenter("spacy")
aligner = HirschbergAligner(match_score=1, mismatch_score=-1, indel_score=-1)
comparer = DocumentEditSimilarity(segmenter=segmenter, aligner=aligner)
# List all .jsonl files in the directory or S3 bucket
jsonl_files = list_jsonl_files(eval_data_path)
if not jsonl_files:
raise ValueError("No .jsonl files found in the specified path.")
print(f"Found {len(jsonl_files):,} files to evaluate")
with ProcessPoolExecutor() if not is_debugging() else ThreadPoolExecutor() as executor:
# Prepare the future tasks
futures = [executor.submit(process_jsonl_file, jsonl_file, gold_data, comparer) for jsonl_file in jsonl_files]
# Process each future as it completes
for future in tqdm(as_completed(futures), total=len(jsonl_files)):
alignment_score, char_weighted_score, chars, pages, errors, overruns, page_data = future.result() # Get the result of the completed task
# Aggregate statistics
total_alignment_score += alignment_score
total_char_alignment_score += char_weighted_score
total_weight += chars
total_pages += pages
total_errors += errors
total_overruns += overruns
total_pages_compared |= page_data.keys()
# Generate the eval data
for pd_key, pd in page_data.items():
# if pd["alignment"] > 0.97:
# continue
# if len(pd["gold_text"]) < 200 and len(pd["eval_text"]) < 200:
# continue
# if "[Error processing this page: overrun" not in pd["eval_text"]:
# continue
page_eval_data.append(pd)
print(f"Compared {len(total_pages_compared):,} pages")
print(f"Found {total_errors} errors in the eval set, and {total_overruns} cases of length overruns")
print(f"Mean page-weighted alignment: {total_alignment_score / total_pages:.3f}")
print(f"Mean char-weighted alignment: {total_char_alignment_score / total_weight:.3f}")
print("")
print("...creating review page")
# TODO Temporary filter to see other stuff
# page_eval_data = [x for x in page_eval_data if "NO ENGLISH TEXT" not in x["gold_text"]]
# Select the review_page_size entries with the lowest alignment scores
page_eval_data.sort(key=lambda x: x["alignment"])
create_review_html(page_eval_data[:review_page_size], filename=review_page_name + "_worst.html")
# Select random entries to return in the page_eval_data
page_eval_data = random.sample(page_eval_data, review_page_size)
create_review_html(page_eval_data, filename=review_page_name + "_sample.html")
return total_char_alignment_score / total_weight, page_eval_data
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Transform JSONL files by extracting and renaming specific fields.")
parser.add_argument("--name", default="review_page", help="What name to give to this evaluation/comparison")
parser.add_argument(
"--review_size",
default=20,
type=int,
help="Number of entries to show on the generated review page",
)
parser.add_argument(
"gold_data_path",
type=str,
help='Path to the gold data directory containing JSONL files. Can be a local path or S3 URL. Can be openai "done" data, or birr "done" data',
)
parser.add_argument(
"eval_data_path",
type=str,
help='Path to the eval data directory containing JSONL files. Can be a local path or S3 URL. Can be openai "done" data, or birr "done" data',
)
args = parser.parse_args()
result = do_eval(gold_data_path=args.gold_data_path, eval_data_path=args.eval_data_path, review_page_name=args.name, review_page_size=args.review_size)
import csv
import re
from collections import defaultdict
from typing import Any, DefaultDict
from urllib.parse import parse_qs, urlencode, urlsplit, urlunsplit
import requests # type: ignore
def fetch_review_page_html(url):
"""
Fetch the HTML from the Tinyhost URL.
"""
resp = requests.get(url)
resp.raise_for_status()
return resp.text
def extract_presigned_url(html):
"""
Given the HTML of the page, extract the `presignedGetUrl`.
Returns None if not found.
"""
match = re.search(r"const presignedGetUrl = \"(.*?)\";", html)
if not match:
return None
return match.group(1)
def fetch_presigned_datastore(presigned_url):
"""
Fetch the JSON datastore from the presigned URL.
Returns a dict. If any error or no content, returns {}.
"""
try:
# Clean up the presigned URL (sometimes the signature may need re-encoding)
url_parts = urlsplit(presigned_url)
query_params = parse_qs(url_parts.query)
encoded_query = urlencode(query_params, doseq=True)
cleaned_url = urlunsplit((url_parts.scheme, url_parts.netloc, url_parts.path, encoded_query, url_parts.fragment))
resp = requests.get(cleaned_url, headers={"Host": url_parts.netloc, "User-Agent": "Mozilla/5.0"})
resp.raise_for_status()
return resp.json()
except Exception as e:
print(f"Error fetching datastore from {presigned_url}: {e}")
return {}
def sanitize_key(key):
return re.sub(r"[^a-zA-Z0-9-_]", "_", key)
def parse_entry_metadata(html):
"""
Parse each .entry block from the HTML to figure out:
- data-entry-id
- left-metadata
- right-metadata
- classes that might indicate 'gold', 'eval', etc.
Returns a dict:
{
entry_id_1: {
'left_metadata': str,
'right_metadata': str,
'class_str': str,
},
...
}
Note: This uses a regex that looks for something like:
<div class="entry gold eval" data-entry-id="..."
data-left-metadata="..." data-right-metadata="...">
"""
pattern = re.compile(
r'<div\s+class="entry([^"]*)"\s+data-entry-id="([^"]+)"\s+data-left-metadata="([^"]+)"\s+data-right-metadata="([^"]+)"', re.DOTALL | re.MULTILINE
)
entries = {}
for m in pattern.finditer(html):
class_str = m.group(1).strip() # e.g. " gold eval"
entry_id = m.group(2).strip()
left_md = m.group(3).strip()
right_md = m.group(4).strip()
# Transform the HTML's data-entry-id to match the JS datastore keys:
entry_id = sanitize_key(entry_id)
entries[entry_id] = {
"class_str": class_str,
"left_metadata": left_md,
"right_metadata": right_md,
}
return entries
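# Comment sketch of what the regex above captures for a single entry div (the HTML
# snippet and method names are made up, but follow evalhtml_template.html's attribute order):
#
#   >>> html = '<div class="entry gold eval" data-entry-id="s3://b/x.pdf_1" data-left-metadata="chatgpt" data-right-metadata="sglang">'
#   >>> parse_entry_metadata(html)
#   {'s3___b_x_pdf_1': {'class_str': 'gold eval', 'left_metadata': 'chatgpt', 'right_metadata': 'sglang'}}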
def build_comparison_report(entries_dict, datastore):
"""
Build a comparison report showing how often each type of method
beats each other type of method, based on user votes in the datastore.
We assume:
- If user vote is 'left', then left_metadata's method "wins".
- If user vote is 'right', then right_metadata's method "wins".
- If user vote is 'both_good', 'both_bad', or 'invalid_pdf',
we do not count it as a direct matchup.
Returns a structure:
comparisons[(A, B)] = [A_wins, B_wins],
where A < B lexicographically in that tuple.
"""
comparisons: DefaultDict[Any, list[int]] = defaultdict(lambda: [0, 0])
for entry_id, vote in datastore.items():
if entry_id not in entries_dict:
# No matching <div> found for this key in the HTML
continue
left_method = entries_dict[entry_id]["left_metadata"]
right_method = entries_dict[entry_id]["right_metadata"]
if left_method == right_method:
# Same "method" on both sides => skip
continue
if vote == "left":
# left_method beats right_method
pair = tuple(sorted([left_method, right_method]))
if pair[0] == left_method:
comparisons[pair][0] += 1
else:
comparisons[pair][1] += 1
elif vote == "right":
# right_method beats left_method
pair = tuple(sorted([left_method, right_method]))
if pair[0] == right_method:
comparisons[pair][0] += 1
else:
comparisons[pair][1] += 1
else:
# "both_good", "both_bad", "invalid_pdf", etc. -> not counted
pass
return comparisons
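# Small worked example (comment only; method names are placeholders): two entries where
# the voter picked "left" once and "right" once, both times favoring the same method.
#
#   entries = {"e1": {"left_metadata": "method_b", "right_metadata": "method_a", "class_str": ""},
#              "e2": {"left_metadata": "method_a", "right_metadata": "method_b", "class_str": ""}}
#   votes = {"e1": "left", "e2": "right"}
#   build_comparison_report(entries, votes)
#   -> {("method_a", "method_b"): [0, 2]}   # method_b wins both head-to-head matchups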
def elo_update(ratingA, ratingB, scoreA, scoreB, k=32):
"""
Perform a single ELO update for a match between A and B.
- ratingA, ratingB are current ELO ratings of A and B.
- scoreA, scoreB in {0 or 1} (1 if the player is the winner, 0 if loser).
- Returns (new_ratingA, new_ratingB).
"""
# Expected scores for each player
expectedA = 1 / (1 + 10 ** ((ratingB - ratingA) / 400))
expectedB = 1 / (1 + 10 ** ((ratingA - ratingB) / 400))
new_ratingA = ratingA + k * (scoreA - expectedA)
new_ratingB = ratingB + k * (scoreB - expectedB)
return new_ratingA, new_ratingB
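# Hand-checked sanity example (comment only): with both methods at the initial 1500
# rating and k=32, the expected score for equal ratings is 0.5, so the winner gains
# 32 * (1 - 0.5) = 16 points and the loser drops 16.
#
#   >>> elo_update(1500, 1500, 1, 0, k=32)
#   (1516.0, 1484.0)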
def compute_elo_arena(comparisons, k=32, initial_rating=1500):
"""
Given the aggregated comparisons dict:
comparisons[(A, B)] = [A_wins, B_wins]
1) Collect all unique methods.
2) Initialize them to initial_rating (1500).
3) For each pair (A, B), apply A_wins times the scenario
"A beats B" -> ELO update
B_wins times the scenario "B beats A" -> ELO update
Because we don't have a strict order of matches, we just
apply them in some consistent but arbitrary order.
Returns a dict { method_name: final_elo_rating }
"""
# 1) Collect all unique methods
methods = set()
for A, B in comparisons.keys():
methods.add(A)
methods.add(B)
# 2) Initialize ratings
ratings = {m: float(initial_rating) for m in methods}
# 3) Walk through each pair
for (A, B), (A_wins, B_wins) in comparisons.items():
for _ in range(A_wins):
# A beats B
oldA = ratings[A]
oldB = ratings[B]
newA, newB = elo_update(oldA, oldB, 1, 0, k=k)
ratings[A] = newA
ratings[B] = newB
for _ in range(B_wins):
# B beats A
oldA = ratings[A]
oldB = ratings[B]
newA, newB = elo_update(oldA, oldB, 0, 1, k=k)
ratings[A] = newA
ratings[B] = newB
return ratings
def make_report(urls):
"""
Main function that:
- Fetches each HTML page
- Extracts presignedGetUrl
- Fetches the JSON datastore
- Parses .entry blocks for metadata
- Produces an overall "win rate" report for each method vs method
- Produces an ELO arena result for each method
"""
# Aggregate all entries from all URLs into a single dict
# so each entry_id is unique across all pages (they usually are).
master_entries_dict = {}
master_datastore = {}
for url in urls:
try:
html = fetch_review_page_html(url)
except Exception as e:
print(f"Error fetching HTML from {url}: {e}")
continue
# Extract the presignedGetUrl
presigned_url = extract_presigned_url(html)
if not presigned_url:
print(f"Warning: Could not find presignedGetUrl in {url}")
continue
# Fetch the datastore
datastore = fetch_presigned_datastore(presigned_url)
# Parse the HTML for entry metadata
entries_dict = parse_entry_metadata(html)
# Merge into master
for k, v in entries_dict.items():
master_entries_dict[k] = v
for k, v in datastore.items():
master_datastore[k] = v
# Now build the comparison report
comparisons = build_comparison_report(master_entries_dict, master_datastore)
print("=== Pairwise Win/Loss Report ===")
if not comparisons:
print("No head-to-head comparisons found (did not find left/right votes).")
return
# Print out each matchup
for (A, B), (A_wins, B_wins) in comparisons.items():
total = A_wins + B_wins
A_rate = A_wins / total * 100 if total else 0
B_rate = B_wins / total * 100 if total else 0
print(f"{A} vs {B}: " f"{A} wins={A_wins} ({A_rate:.1f}%), " f"{B} wins={B_wins} ({B_rate:.1f}%)")
# -- ADDED: Write the same data to scoreelo.csv
with open("scoreelo.csv", "w", newline="", encoding="utf-8") as csvfile:
writer = csv.writer(csvfile)
writer.writerow(["MethodA", "MethodB", "A_wins", "B_wins", "A_rate(%)", "B_rate(%)"])
for (A, B), (A_wins, B_wins) in comparisons.items():
total = A_wins + B_wins
A_rate = A_wins / total * 100 if total else 0
B_rate = B_wins / total * 100 if total else 0
writer.writerow([A, B, A_wins, B_wins, f"{A_rate:.1f}", f"{B_rate:.1f}"])
# ==== ELO Arena ====
elo_ratings = compute_elo_arena(comparisons, k=32, initial_rating=1500)
# Sort methods by final rating descending
sorted_ratings = sorted(elo_ratings.items(), key=lambda x: x[1], reverse=True)
print("\n=== ELO Arena Results ===")
for method, rating in sorted_ratings:
print(f"{method}: {rating:.2f}")
if __name__ == "__main__":
# Example usage
urls = [
"https://jakep-tinyhost.s3.amazonaws.com/review_page_0-ff70abb8f517.html?AWSAccessKeyId=AKIASHLPW4FEVZOPGK46&Signature=NarEyyCfvusCh%2FHdB47VfHOnnBs%3D&Expires=1738359221",
"https://jakep-tinyhost.s3.amazonaws.com/review_page_1-0800f9af46cf.html?AWSAccessKeyId=AKIASHLPW4FEVZOPGK46&Signature=ncTWAu5rSndBJJsU26HRYDaK6i8%3D&Expires=1738359222",
"https://jakep-tinyhost.s3.amazonaws.com/review_page_10-f7081f6ca6f9.html?AWSAccessKeyId=AKIASHLPW4FEVZOPGK46&Signature=gYX8yjGyYshRqXGgdsX17%2Fdi9Ig%3D&Expires=1738359223",
"https://jakep-tinyhost.s3.amazonaws.com/review_page_11-355dc69335bc.html?AWSAccessKeyId=AKIASHLPW4FEVZOPGK46&Signature=7%2Bc5qoa8Tbk06z0VcvJiIIVAz9M%3D&Expires=1738359224",
"https://jakep-tinyhost.s3.amazonaws.com/review_page_12-95fce9bf0c18.html?AWSAccessKeyId=AKIASHLPW4FEVZOPGK46&Signature=fw4PBo0LnxikmLZ8xH%2BGD%2F%2BhXMU%3D&Expires=1738359225",
"https://jakep-tinyhost.s3.amazonaws.com/review_page_13-f88f7d7482bf.html?AWSAccessKeyId=AKIASHLPW4FEVZOPGK46&Signature=yXkQp9oFDtroKgiO50EwpYdGLcA%3D&Expires=1738359226",
"https://jakep-tinyhost.s3.amazonaws.com/review_page_14-8ac0b974bfd5.html?AWSAccessKeyId=AKIASHLPW4FEVZOPGK46&Signature=EgZTpj1%2FdzMBUgd%2BX4pVZ1Sp%2FrA%3D&Expires=1738359226",
"https://jakep-tinyhost.s3.amazonaws.com/review_page_15-e3136188de5c.html?AWSAccessKeyId=AKIASHLPW4FEVZOPGK46&Signature=YKhAv4unNIlRcerQAaHN4kjc4qI%3D&Expires=1738359227",
"https://jakep-tinyhost.s3.amazonaws.com/review_page_16-2c5abde50d49.html?AWSAccessKeyId=AKIASHLPW4FEVZOPGK46&Signature=Mj8%2BK5ISKzAYQFeYvmzTgCPcRwA%3D&Expires=1738359228",
"https://jakep-tinyhost.s3.amazonaws.com/review_page_17-f13132a4cdcc.html?AWSAccessKeyId=AKIASHLPW4FEVZOPGK46&Signature=%2FHuzw2cjJ4oFm91UXojPnGzYi8Q%3D&Expires=1738359229",
"https://jakep-tinyhost.s3.amazonaws.com/review_page_18-25070f2aa05e.html?AWSAccessKeyId=AKIASHLPW4FEVZOPGK46&Signature=ctd%2BUIM%2FxryJm%2FcwA%2BRZ%2FbRzBp8%3D&Expires=1738359230",
"https://jakep-tinyhost.s3.amazonaws.com/review_page_19-d436ee434162.html?AWSAccessKeyId=AKIASHLPW4FEVZOPGK46&Signature=jVdFKobIoHlbTQ7zziG%2BXiIQ0Fo%3D&Expires=1738359230",
"https://jakep-tinyhost.s3.amazonaws.com/review_page_2-a5ece743fd31.html?AWSAccessKeyId=AKIASHLPW4FEVZOPGK46&Signature=K8hIrjWtvo4SLVQrOB8TiXLgNJk%3D&Expires=1738359231",
"https://jakep-tinyhost.s3.amazonaws.com/review_page_3-9ce03af05f51.html?AWSAccessKeyId=AKIASHLPW4FEVZOPGK46&Signature=T0fLGSH%2Bv%2F19veqbxnLxoSf7gVA%3D&Expires=1738359232",
"https://jakep-tinyhost.s3.amazonaws.com/review_page_4-94eec18f8027.html?AWSAccessKeyId=AKIASHLPW4FEVZOPGK46&Signature=u2R1LundKpfnAUCcD%2BdGHA6uIR0%3D&Expires=1738359233",
"https://jakep-tinyhost.s3.amazonaws.com/review_page_5-377d0a7d8f5a.html?AWSAccessKeyId=AKIASHLPW4FEVZOPGK46&Signature=5R38ZQAR9ew5x%2BRmMVQbTqbfVh0%3D&Expires=1738359234",
"https://jakep-tinyhost.s3.amazonaws.com/review_page_6-537b22646a26.html?AWSAccessKeyId=AKIASHLPW4FEVZOPGK46&Signature=PLOELum1qzOXW8Cm5rfZphlFeMw%3D&Expires=1738359235",
"https://jakep-tinyhost.s3.amazonaws.com/review_page_7-a4a7dcb08f20.html?AWSAccessKeyId=AKIASHLPW4FEVZOPGK46&Signature=DxPHukGXEpPrEPL6TF9QBKPE1Xg%3D&Expires=1738359236",
"https://jakep-tinyhost.s3.amazonaws.com/review_page_8-48a71c829863.html?AWSAccessKeyId=AKIASHLPW4FEVZOPGK46&Signature=TjEINKj69HdmXsKY59k4f3PieeM%3D&Expires=1738359237",
"https://jakep-tinyhost.s3.amazonaws.com/review_page_9-8557438928c3.html?AWSAccessKeyId=AKIASHLPW4FEVZOPGK46&Signature=F7sQxw5A%2FDOcOaa%2FQSeqepH0PQc%3D&Expires=1738359238",
]
# import tinyhost
# print(tinyhost.tinyhost(urls))
make_report(urls)
from .filter import PdfFilter
from functools import lru_cache
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
@lru_cache()
def load_coherency_model(model_name: str = "HuggingFaceTB/SmolLM-135M"):
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
model.eval() # Set the model to evaluation mode
return tokenizer, model
def get_document_coherency(text: str) -> float:
"""
Calculates the coherency of a document based on the log likelihood of its tokens.
Handles texts longer than the model's maximum token limit by splitting them into chunks.
Args:
text (str): The input text to evaluate.
Returns:
float: The average log likelihood per token as a measure of coherency.
"""
tokenizer, model = load_coherency_model()
# Determine the model's maximum number of tokens
max_length = tokenizer.model_max_length - 1
# Some tokenizers have a default value indicating no limit; use model config if so
if max_length > 1_000_000:
max_length = model.config.max_position_embeddings
# Tokenize the entire text
tokens = tokenizer.encode(text, return_tensors="pt").squeeze(0)
total_log_likelihood = 0.0
total_tokens = 0
# Split tokens into chunks that fit within the model's max length
for i in range(0, len(tokens), max_length):
chunk = tokens[i : i + max_length]
inputs = chunk.unsqueeze(0) # Add batch dimension
# Move inputs to CPU (ensure compatibility)
inputs = {k: v.cpu() for k, v in {"input_ids": inputs}.items()}
with torch.no_grad():
outputs = model(**inputs, labels=inputs["input_ids"])
# Compute log likelihood for the chunk
log_likelihood = -outputs.loss.item() * chunk.size(0)
total_log_likelihood += log_likelihood
total_tokens += chunk.size(0)
# Calculate the average log likelihood per token
avg_log_likelihood = total_log_likelihood / total_tokens if total_tokens > 0 else 0.0
return avg_log_likelihood
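# Illustrative usage (comment sketch; strings are made up): scores are average per-token
# log likelihoods, so they are negative, and values closer to zero indicate more
# "natural" text under the small SmolLM model loaded above.
#
#   coherent = get_document_coherency("The quick brown fox jumps over the lazy dog.")
#   garbled = get_document_coherency("Th3 qu!ck br0wn f0x jmps ovr teh l@zy d0g.")
#   # expected ordering for typical inputs, not a guarantee:
#   # garbled < coherent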
import logging
import re
import subprocess
from collections import Counter
from typing import List
from lingua import Language, LanguageDetectorBuilder
from pypdf import PdfReader
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
class PdfFilter:
def __init__(
self,
languages_to_keep=None,
apply_form_check=True,
apply_download_spam_check=True,
download_spam_threshold=0.004,
):
super().__init__()
self.language_detector = LanguageDetectorBuilder.from_all_languages().with_preloaded_language_models().build()
self.languages_to_keep = languages_to_keep if languages_to_keep is not None else [Language.ENGLISH]
self.apply_form_check = apply_form_check
self.apply_download_spam_check = apply_download_spam_check
self.download_spam_threshold = download_spam_threshold
def _is_form(self, pdf_reader) -> bool:
# Check if the PDF is a form
if pdf_reader.get_form_text_fields():
return True
return False # Not a form
def _is_download_spam(self, base_text: str) -> bool:
seo_words = {
"download",
"pdf",
"epub",
"mobi",
"free",
"ebook",
"file",
"save",
"casino",
"viagra",
"cialis",
"ciprofloxacin",
}
base_text = base_text.strip().lower()
clean_text = re.sub(r"\W+", " ", base_text)
word_counts = Counter(clean_text.split())
total_words = len(clean_text.split())
if total_words == 0:
return False
seo_score = sum(word_counts[word] for word in seo_words if word in word_counts)
return (seo_score / total_words) > self.download_spam_threshold
# Returns True if there is something wrong with this PDF
def filter_out_pdf(self, local_pdf_path: str) -> bool:
try:
# Attempt to read the PDF at the beginning
pdf_reader = PdfReader(local_pdf_path)
# Form check
if self.apply_form_check and self._is_form(pdf_reader):
logger.info(f"Filtering out {local_pdf_path} because it's a form")
return True # Filter out
except Exception as e:
logger.warning(f"Error reading PDF {local_pdf_path}: {e}")
return True # Filter out the PDF if an exception occurs
# Read the first five pages of text for language calculation
pdftotext_result = subprocess.run(
["pdftotext", "-f", "1", "-l", "5", local_pdf_path, "-"],
timeout=60,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
)
if pdftotext_result.returncode != 0:
logger.warning(f"pdftotext returned {pdftotext_result.returncode} on {local_pdf_path}")
return True # Filter out
base_text = pdftotext_result.stdout.decode("utf-8")
alpha_count = sum(c.isalpha() for c in base_text)
if len(base_text) < 200:
logger.info(f"Keeping {local_pdf_path} on the safe side because not enough text exists in it to analyze")
return False # keep the pdf
if alpha_count / len(base_text) < 0.50:
logger.info(f"Keeping {local_pdf_path} on the safe side because it's text does not contain many letters so it might be OCRed badly")
return False # keep the pdf
# Language check
language = self.language_detector.detect_language_of(base_text)
if language not in self.languages_to_keep:
logger.info(f"Filtering out {local_pdf_path} because language was {language}")
return True # Filter out
# Download spam check
if self.apply_download_spam_check and self._is_download_spam(base_text):
logger.info(f"Filtering out {local_pdf_path} because of SEO/download spam")
return True # Filter out
return False # Keep the PDF
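# Minimal usage sketch (comment only; the path is a placeholder): by default the filter
# drops forms, non-English documents, and documents whose text looks like SEO/download
# spam, i.e. more than download_spam_threshold of the words come from the seo_words set
# (with the default 0.004, a 1000-word page is flagged once it contains 5 such words).
#
#   pdf_filter = PdfFilter()
#   if pdf_filter.filter_out_pdf("/tmp/example.pdf"):
#       print("filtered out")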
if __name__ == "__main__":
import tempfile
from concurrent.futures import FIRST_COMPLETED, ProcessPoolExecutor, wait
import boto3
from tqdm import tqdm
from olmocr.s3_utils import parse_s3_path
# Quiet logs from pypdf
logging.getLogger("pypdf").setLevel(logging.ERROR)
def process_pdf(s3_path):
"""
Process a single PDF file to determine if it should be kept or removed.
"""
s3_bucket, s3_key = parse_s3_path(s3_path)
pdf_s3 = boto3.client("s3")
with tempfile.NamedTemporaryFile(suffix=".pdf", delete=True) as tmp_file:
pdf_s3.download_fileobj(s3_bucket, s3_key, tmp_file)
tmp_file.flush()
# Perform filtering logic
if filter.filter_out_pdf(tmp_file.name):
return s3_path, "remove"
else:
return s3_path, "keep"
# Load the list of S3 paths with a progress bar
with open("/home/ubuntu/s2pdf_paths_1M.txt", "r") as f:
s3_work_paths: List[str] = list(filter(None, (line.strip() for line in tqdm(f, desc="Loading paths"))))
# Initialize the PDF filter
filter = PdfFilter(
languages_to_keep={Language.ENGLISH, None},  # None means the language could not be detected; that's okay, keep it, it might just be poorly OCRed
apply_download_spam_check=True,
apply_form_check=True,
)
# Output files
keep_path = "/home/ubuntu/s2pdf_paths_filter_keep.txt"
remove_path = "/home/ubuntu/s2pdf_paths_filter_remove.txt"
max_pending = 20 # Limit on the number of concurrent futures
total_pdfs = len(s3_work_paths)
pdf_iter = iter(s3_work_paths) # Iterator for PDFs
# Process the PDFs with limited concurrent futures
with open(keep_path, "w") as fkeep, open(remove_path, "w") as fremove:
with ProcessPoolExecutor(max_workers=max_pending) as executor:
pending_futures = {}
with tqdm(total=total_pdfs, desc="Processing PDFs") as pbar:
# Submit initial batch of futures
for _ in range(min(max_pending, total_pdfs)):
s3_path = next(pdf_iter)
future = executor.submit(process_pdf, s3_path)
pending_futures[future] = s3_path
while pending_futures:
# Wait for the next future to complete
done, _ = wait( # type: ignore
pending_futures.keys(),
timeout=0.1,
return_when=FIRST_COMPLETED,
)
for future in done:
s3_path = pending_futures.pop(future)
try:
s3_path, result = future.result()
if result == "keep":
fkeep.write(s3_path + "\n")
elif result == "remove":
fremove.write(s3_path + "\n")
except Exception as e:
print(f"Error processing {s3_path}: {e}")
# Update the progress bar
pbar.update(1)
# Submit a new future if there are more PDFs
try:
s3_path = next(pdf_iter)
future = executor.submit(process_pdf, s3_path)
pending_futures[future] = s3_path
except StopIteration:
pass # No more PDFs to process
import asyncio
import time
from collections import defaultdict, deque
from typing import Any, Deque, Dict, List, Set
class MetricsKeeper:
def __init__(self, window=60 * 5):
"""
Initializes the MetricsKeeper.
Args:
window (int): Time window in seconds for recent metrics. Defaults to 5 minutes.
"""
self.window = window # Time window in seconds
self.start_time = time.time() # Timestamp when MetricsKeeper was created
self.total_metrics = defaultdict(int) # Cumulative metrics since start
self.window_metrics: Deque[Any] = deque() # Deque to store (timestamp, metrics_dict)
self.window_sum = defaultdict(int) # Sum of metrics within the window
def add_metrics(self, **kwargs):
"""
Adds metrics to the keeper.
Args:
**kwargs: Arbitrary keyword arguments representing metric names and their values.
"""
current_time = time.time()
# Update cumulative metrics
for key, value in kwargs.items():
self.total_metrics[key] += value
# Append current metrics with timestamp to the deque
self.window_metrics.append((current_time, kwargs))
# Update window sums
for key, value in kwargs.items():
self.window_sum[key] += value
# Remove metrics that are outside the time window
while self.window_metrics and self.window_metrics[0][0] < current_time - self.window:
old_time, old_metrics = self.window_metrics.popleft()
for key, value in old_metrics.items():
self.window_sum[key] -= value
if self.window_sum[key] <= 0:
del self.window_sum[key] # Clean up to prevent negative counts
def __str__(self):
"""
Returns a formatted string of metrics showing tokens/sec since start and within the window.
Returns:
str: Formatted metrics string as a table.
"""
current_time = time.time()
elapsed_time = current_time - self.start_time
window_time = min(self.window, elapsed_time) if elapsed_time > 0 else 1 # Prevent division by zero
# Header
header = f"{'Metric Name':<30} {'Lifetime (tokens/sec)':>25} {'Recently (tokens/sec)':>25}"
separator = "-" * len(header)
lines = [header, separator]
# Sort metrics alphabetically for consistency
for key in sorted(self.total_metrics.keys()):
total = self.total_metrics[key]
window = self.window_sum.get(key, 0)
total_rate = total / elapsed_time if elapsed_time > 0 else 0
window_rate = window / window_time if window_time > 0 else 0
line = f"{key:<20} {total_rate:>25.2f} {window_rate:>25.2f}"
lines.append(line)
return "\n".join(lines)
class WorkerTracker:
def __init__(self):
"""
Initializes the WorkerTracker with a default dictionary.
Each worker ID maps to another dictionary that holds counts for each state.
"""
# Mapping from worker_id to a dictionary of state counts
self.worker_status: Dict[int, Dict[str, int]] = defaultdict(lambda: defaultdict(int))
self.lock = asyncio.Lock()
async def clear_work(self, worker_id: int):
async with self.lock:
self.worker_status[worker_id].clear()
async def track_work(self, worker_id: int, work_item_id: str, state: str):
"""
Update the state count for a specific worker.
Args:
worker_id (int): The ID of the worker.
work_item_id (str): The unique identifier of the work item (unused in this implementation).
state (str): The state to increment for the work item.
"""
async with self.lock:
self.worker_status[worker_id][state] += 1
async def get_status_table(self) -> str:
"""
Generate a formatted table of the current status of all workers.
Returns:
str: A string representation of the workers' statuses.
"""
async with self.lock:
# Determine all unique states across all workers
all_states: Set[str] = set()
for states in self.worker_status.values():
all_states.update(states.keys())
sorted_states: List[str] = sorted(all_states)
headers = ["Worker ID"] + sorted_states # type: ignore
rows = []
for worker_id, states in sorted(self.worker_status.items()):
row = [str(worker_id)]
for state in sorted_states:
count = states.get(state, 0)
row.append(str(count))
rows.append(row)
# Calculate column widths
col_widths = [len(header) for header in headers]
for row in rows:
for idx, cell in enumerate(row):
col_widths[idx] = max(col_widths[idx], len(cell))
# Create the table header
header_line = " | ".join(header.ljust(col_widths[idx]) for idx, header in enumerate(headers))
separator = "-+-".join("-" * col_widths[idx] for idx in range(len(headers)))
# Create the table rows
row_lines = [" | ".join(cell.ljust(col_widths[idx]) for idx, cell in enumerate(row)) for row in rows]
# Combine all parts
table = "\n".join([header_line, separator] + row_lines)
return table
def __str__(self):
"""
String representation is not directly supported.
Use 'await get_status_table()' to retrieve the status table.
"""
raise NotImplementedError("Use 'await get_status_table()' to get the status table.")
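# Usage sketch (comment only; IDs and states are placeholders): the tracker is async
# because it is shared across concurrent workers.
#
#   async def demo():
#       tracker = WorkerTracker()
#       await tracker.track_work(worker_id=0, work_item_id="doc.pdf-1", state="started")
#       await tracker.track_work(worker_id=0, work_item_id="doc.pdf-1", state="finished")
#       print(await tracker.get_status_table())
#
#   asyncio.run(demo())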
import argparse
import asyncio
import atexit
import base64
import datetime
import hashlib
import json
import logging
import multiprocessing
import os
import random
import re
import shutil
import sys
import tempfile
import time
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
from concurrent.futures.process import BrokenProcessPool
from dataclasses import dataclass
from functools import cache, partial
from io import BytesIO
from urllib.parse import urlparse
import boto3
import httpx
import torch
from botocore.exceptions import ClientError
from huggingface_hub import snapshot_download
from PIL import Image
from pypdf import PdfReader
from tqdm import tqdm
from olmocr.check import (
check_poppler_version,
check_sglang_version,
check_torch_gpu_available,
)
from olmocr.data.renderpdf import render_pdf_to_base64png
from olmocr.filter.filter import Language, PdfFilter
from olmocr.metrics import MetricsKeeper, WorkerTracker
from olmocr.prompts import PageResponse, build_finetuning_prompt
from olmocr.prompts.anchor import get_anchor_text
from olmocr.s3_utils import (
download_zstd_csv,
expand_s3_glob,
get_s3_bytes,
get_s3_bytes_with_backoff,
parse_s3_path,
)
from olmocr.version import VERSION
from olmocr.work_queue import LocalWorkQueue, S3WorkQueue, WorkQueue
# Initialize logger
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
logger.propagate = False
sglang_logger = logging.getLogger("sglang")
sglang_logger.propagate = False
file_handler = logging.FileHandler("olmocr-pipeline-debug.log", mode="a")
file_handler.setLevel(logging.DEBUG)
file_handler.setFormatter(logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s"))
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)
console_handler.setFormatter(logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s"))
# Add handlers to the logger
logger.addHandler(file_handler)
logger.addHandler(console_handler)
sglang_logger.addHandler(file_handler)
# Quiet logs from pypdf
logging.getLogger("pypdf").setLevel(logging.ERROR)
# Global s3 clients for the whole script; we have two separate ones in case your workspace and your pdfs are in different accounts
workspace_s3 = boto3.client("s3")
pdf_s3 = boto3.client("s3")
# Global variables for token statistics
metrics = MetricsKeeper(window=60 * 5)
tracker = WorkerTracker()
# Process pool for offloading CPU-bound work, like calculating anchor texts; capped at 32 workers, otherwise it can spawn way too many workers on a big machine
process_pool = ProcessPoolExecutor(max_workers=min(multiprocessing.cpu_count() // 2 + 1, 32), mp_context=multiprocessing.get_context("spawn"))
# Filter object, cached so it will only get loaded when/if you need it
get_pdf_filter = cache(lambda: PdfFilter(languages_to_keep={Language.ENGLISH, None}, apply_download_spam_check=True, apply_form_check=True))
SGLANG_SERVER_PORT = 30024
@dataclass(frozen=True)
class PageResult:
s3_path: str
page_num: int
response: PageResponse
input_tokens: int
output_tokens: int
is_fallback: bool
async def build_page_query(local_pdf_path: str, page: int, target_longest_image_dim: int, target_anchor_text_len: int, image_rotation: int = 0) -> dict:
MAX_TOKENS = 3000
assert image_rotation in [0, 90, 180, 270], "Invalid image rotation provided in build_page_query"
# Allow the page rendering to process in the background while we get the anchor text (which blocks the main thread)
image_base64 = asyncio.to_thread(render_pdf_to_base64png, local_pdf_path, page, target_longest_image_dim=target_longest_image_dim)
# GET ANCHOR TEXT IS NOT THREAD SAFE!! Ahhhh..... don't try to do it
# and it's also CPU bound, so it needs to run in a process pool
loop = asyncio.get_running_loop()
anchor_text = loop.run_in_executor(
process_pool, partial(get_anchor_text, pdf_engine="pdfreport", target_length=target_anchor_text_len), local_pdf_path, page
)
image_base64, anchor_text = await asyncio.gather(image_base64, anchor_text) # type: ignore
if image_rotation != 0:
image_bytes = base64.b64decode(image_base64)
with Image.open(BytesIO(image_bytes)) as img:
rotated_img = img.rotate(-image_rotation, expand=True)
# Save the rotated image to a bytes buffer
buffered = BytesIO()
rotated_img.save(buffered, format="PNG")
# Encode the rotated image back to base64
image_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
return {
"model": "Qwen/Qwen2-VL-7B-Instruct",
"messages": [
{
"role": "user",
"content": [
{"type": "text", "text": build_finetuning_prompt(anchor_text)},
{"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_base64}"}},
],
}
],
"max_tokens": MAX_TOKENS,
"temperature": 0.8,
}
# Manual simple implementation of HTTP Post
# It feels strange perhaps, but httpx and aiohttp are very complex beasts
# Ex. the sessionpool in httpcore has 4 different locks in it, and I've noticed
# that at the scale of 100M+ requests, they deadlock in different strange ways
async def apost(url, json_data):
parsed_url = urlparse(url)
host = parsed_url.hostname
port = parsed_url.port or 80
path = parsed_url.path or "/"
writer = None
try:
reader, writer = await asyncio.open_connection(host, port)
json_payload = json.dumps(json_data)
request = (
f"POST {path} HTTP/1.1\r\n"
f"Host: {host}\r\n"
f"Content-Type: application/json\r\n"
f"Content-Length: {len(json_payload)}\r\n"
f"Connection: close\r\n\r\n"
f"{json_payload}"
)
writer.write(request.encode())
await writer.drain()
# Read status line
status_line = await reader.readline()
if not status_line:
raise ConnectionError("No response from server")
status_parts = status_line.decode().strip().split(" ", 2)
if len(status_parts) < 2:
raise ValueError(f"Malformed status line: {status_line.decode().strip()}")
status_code = int(status_parts[1])
# Read headers
headers = {}
while True:
line = await reader.readline()
if line in (b"\r\n", b"\n", b""):
break
key, _, value = line.decode().partition(":")
headers[key.strip().lower()] = value.strip()
# Read response body
if "content-length" in headers:
body_length = int(headers["content-length"])
response_body = await reader.readexactly(body_length)
else:
raise ConnectionError("Anything other than fixed content length responses are not implemented yet")
return status_code, response_body
except Exception as e:
# Pass through errors
raise e
finally:
# But just make sure to close the socket on your way out
if writer is not None:
try:
writer.close()
await writer.wait_closed()
except:
pass
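# Example usage (mirrors the call in process_page below; only plain HTTP with a
# fixed Content-Length response is supported):
#   status_code, body = await apost(f"http://localhost:{SGLANG_SERVER_PORT}/v1/chat/completions", json_data=query)
#   data = json.loads(body)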
async def process_page(args, worker_id: int, pdf_orig_path: str, pdf_local_path: str, page_num: int) -> PageResult:
COMPLETION_URL = f"http://localhost:{SGLANG_SERVER_PORT}/v1/chat/completions"
MAX_RETRIES = args.max_page_retries
exponential_backoffs = 0
local_anchor_text_len = args.target_anchor_text_len
local_image_rotation = 0
attempt = 0
await tracker.track_work(worker_id, f"{pdf_orig_path}-{page_num}", "started")
while attempt < MAX_RETRIES:
query = await build_page_query(pdf_local_path, page_num, args.target_longest_image_dim, local_anchor_text_len, image_rotation=local_image_rotation)
logger.info(f"Built page query for {pdf_orig_path}-{page_num}")
try:
status_code, response_body = await apost(COMPLETION_URL, json_data=query)
if status_code == 400:
raise ValueError(f"Got BadRequestError from server: {response_body}, skipping this response")
elif status_code == 500:
raise ValueError(f"Got InternalServerError from server: {response_body}, skipping this response")
elif status_code != 200:
raise ValueError(f"Error http status {status_code}")
base_response_data = json.loads(response_body)
if base_response_data["usage"]["total_tokens"] > args.model_max_context:
local_anchor_text_len = max(1, local_anchor_text_len // 2)
logger.info(f"Reducing anchor text len to {local_anchor_text_len} for {pdf_orig_path}-{page_num}")
raise ValueError("Response exceeded model_max_context, cannot use this response")
metrics.add_metrics(
sglang_input_tokens=base_response_data["usage"].get("prompt_tokens", 0),
sglang_output_tokens=base_response_data["usage"].get("completion_tokens", 0),
)
model_response_json = json.loads(base_response_data["choices"][0]["message"]["content"])
page_response = PageResponse(**model_response_json)
if not page_response.is_rotation_valid and attempt < MAX_RETRIES - 1:
logger.info(
f"Got invalid_page rotation for {pdf_orig_path}-{page_num} attempt {attempt}, retrying with {page_response.rotation_correction} rotation"
)
local_image_rotation = page_response.rotation_correction
raise ValueError(f"invalid_page rotation for {pdf_orig_path}-{page_num}")
await tracker.track_work(worker_id, f"{pdf_orig_path}-{page_num}", "finished")
return PageResult(
pdf_orig_path,
page_num,
page_response,
input_tokens=base_response_data["usage"].get("prompt_tokens", 0),
output_tokens=base_response_data["usage"].get("completion_tokens", 0),
is_fallback=False,
)
except (ConnectionError, OSError, asyncio.TimeoutError) as e:
logger.warning(f"Client error on attempt {attempt} for {pdf_orig_path}-{page_num}: {type(e)} {e}")
            # Now we want to do exponential backoff, and not count this as an actual page retry
            # Page retries are supposed to be for fixing bad results from the model, but actual requests to sglang
            # are supposed to work. Most likely this means the server is just restarting
sleep_delay = 10 * (2**exponential_backoffs)
exponential_backoffs += 1
logger.info(f"Sleeping for {sleep_delay} seconds on {pdf_orig_path}-{page_num} to allow server restart")
await asyncio.sleep(sleep_delay)
except asyncio.CancelledError:
logger.info(f"Process page {pdf_orig_path}-{page_num} cancelled")
await tracker.track_work(worker_id, f"{pdf_orig_path}-{page_num}", "cancelled")
raise
except json.JSONDecodeError as e:
logger.warning(f"JSON decode error on attempt {attempt} for {pdf_orig_path}-{page_num}: {e}")
attempt += 1
except ValueError as e:
logger.warning(f"ValueError on attempt {attempt} for {pdf_orig_path}-{page_num}: {type(e)} - {e}")
attempt += 1
except Exception as e:
logger.exception(f"Unexpected error on attempt {attempt} for {pdf_orig_path}-{page_num}: {type(e)} - {e}")
attempt += 1
logger.error(f"Failed to process {pdf_orig_path}-{page_num} after {MAX_RETRIES} attempts.")
await tracker.track_work(worker_id, f"{pdf_orig_path}-{page_num}", "errored")
return PageResult(
pdf_orig_path,
page_num,
PageResponse(
natural_text=get_anchor_text(pdf_local_path, page_num, pdf_engine="pdftotext"),
primary_language=None,
is_rotation_valid=True,
rotation_correction=0,
is_table=False,
is_diagram=False,
),
input_tokens=0,
output_tokens=0,
is_fallback=True,
)
async def process_pdf(args, worker_id: int, pdf_orig_path: str):
with tempfile.NamedTemporaryFile("wb+", suffix=".pdf") as tf:
try:
data = await asyncio.to_thread(lambda: get_s3_bytes_with_backoff(pdf_s3, pdf_orig_path))
tf.write(data)
tf.flush()
except ClientError as ex:
if ex.response["Error"]["Code"] == "NoSuchKey":
logger.info(f"S3 File Not found, skipping it completely {pdf_orig_path}")
return None
else:
raise
try:
reader = PdfReader(tf.name)
num_pages = reader.get_num_pages()
except:
logger.exception(f"Could not count number of pages for {pdf_orig_path}, aborting document")
return None
logger.info(f"Got {num_pages} pages to do for {pdf_orig_path} in worker {worker_id}")
if args.apply_filter and get_pdf_filter().filter_out_pdf(tf.name):
logger.info(f"Filtering out pdf {pdf_orig_path}")
return None
# List to hold the tasks for processing each page
page_tasks = []
page_results = []
try:
async with asyncio.TaskGroup() as tg:
for page_num in range(1, num_pages + 1):
task = tg.create_task(process_page(args, worker_id, pdf_orig_path, tf.name, page_num))
page_tasks.append(task)
# Collect the results from the entire task group, assuming no exceptions
page_results = [task.result() for task in page_tasks]
num_fallback_pages = sum(page_result.is_fallback for page_result in page_results)
if num_fallback_pages / num_pages > args.max_page_error_rate:
logger.error(
f"Document {pdf_orig_path} has {num_fallback_pages} fallback pages out of {num_pages} exceeding max_page_error_rate of {args.max_page_error_rate}, discarding document."
)
return None
elif num_fallback_pages > 0:
logger.warning(
f"Document {pdf_orig_path} processed with {num_fallback_pages} fallback pages out of {num_pages}, proceeding to build Dolma document."
)
return build_dolma_document(pdf_orig_path, page_results)
except Exception as e:
# Check for ExceptionGroup with BrokenProcessPool
if isinstance(e, ExceptionGroup):
broken_pool, other = e.split(BrokenProcessPool)
if broken_pool is not None: # Found at least one BrokenProcessPool
logger.critical("Encountered BrokenProcessPool, exiting process.")
sys.exit(1)
logger.exception(f"Exception in process_pdf for {pdf_orig_path}: {e}")
# You can't build a dolma doc with even 1 failed page, so just get out of here
# However, you don't want to propagate an exception higher up and cancel the entire work_group
return None
def build_dolma_document(pdf_orig_path, page_results):
# Build the document text and page spans
document_text = ""
pdf_page_spans = []
current_char_pos = 0
for index, page_result in enumerate(page_results):
if page_result.response.natural_text is not None:
content = page_result.response.natural_text + ("\n" if index < len(page_results) - 1 else "")
else:
content = ""
start_pos = current_char_pos
document_text += content
current_char_pos = len(document_text)
pdf_page_spans.append([start_pos, current_char_pos, page_result.page_num])
if not document_text:
logger.info(f"No document text for {pdf_orig_path}")
return None # Return None if the document text is empty
# Build the Dolma document
metadata = {
"Source-File": pdf_orig_path,
"olmocr-version": VERSION,
"pdf-total-pages": len(page_results),
"total-input-tokens": sum(page.input_tokens for page in page_results),
"total-output-tokens": sum(page.output_tokens for page in page_results),
"total-fallback-pages": sum(page.is_fallback for page in page_results),
}
id_ = hashlib.sha1(document_text.encode()).hexdigest()
dolma_doc = {
"id": id_,
"text": document_text,
"source": "olmocr",
"added": datetime.datetime.now().strftime("%Y-%m-%d"),
"created": datetime.datetime.now().strftime("%Y-%m-%d"),
"metadata": metadata,
"attributes": {"pdf_page_numbers": pdf_page_spans},
}
return dolma_doc
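# Example of the span bookkeeping above (illustrative values): two pages whose natural_text
# values are "Hello" and "World" yield document_text "Hello\nWorld" and
# pdf_page_spans [[0, 6, 1], [6, 11, 2]], i.e. [start offset, end offset, 1-indexed page number].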
async def worker(args, work_queue: WorkQueue, semaphore, worker_id):
while True:
# Wait until allowed to proceed
await semaphore.acquire()
work_item = await work_queue.get_work()
if work_item is None:
logger.info(f"Worker {worker_id} exiting due to empty queue")
semaphore.release()
break
logger.info(f"Worker {worker_id} processing work item {work_item.hash}")
await tracker.clear_work(worker_id)
try:
async with asyncio.TaskGroup() as tg:
dolma_tasks = [tg.create_task(process_pdf(args, worker_id, pdf)) for pdf in work_item.work_paths]
logger.info(f"Created all tasks for {work_item.hash}")
logger.info(f"Finished TaskGroup for worker on {work_item.hash}")
dolma_docs = []
            for task in dolma_tasks:
                try:
                    result = task.result()
                except Exception:
                    # some dolma doc creations may have failed, just skip them
                    continue
                if result is not None:
                    dolma_docs.append(result)
logger.info(f"Got {len(dolma_docs)} docs for {work_item.hash}")
# Write the Dolma documents to a local temporary file in JSONL format
with tempfile.NamedTemporaryFile(mode="w+", delete=False) as tf:
for doc in dolma_docs:
tf.write(json.dumps(doc))
tf.write("\n")
tf.flush()
# Define the output S3 path using the work_hash
output_final_path = os.path.join(args.workspace, "results", f"output_{work_item.hash}.jsonl")
if output_final_path.startswith("s3://"):
bucket, key = parse_s3_path(output_final_path)
workspace_s3.upload_file(tf.name, bucket, key)
else:
shutil.copyfile(tf.name, output_final_path)
# Update finished token counts from successful documents
metrics.add_metrics(
finished_input_tokens=sum(doc["metadata"]["total-input-tokens"] for doc in dolma_docs),
finished_output_tokens=sum(doc["metadata"]["total-output-tokens"] for doc in dolma_docs),
)
await work_queue.mark_done(work_item)
except Exception as e:
logger.exception(f"Exception occurred while processing work_hash {work_item.hash}: {e}")
finally:
semaphore.release()
async def sglang_server_task(args, semaphore):
model_name_or_path = args.model
# if "://" in model_name_or_path:
# # TODO, Fix this code so that we support the multiple s3/weka paths, or else remove it
# model_cache_dir = os.path.join(os.path.expanduser('~'), '.cache', 'olmocr', 'model')
# download_directory(model_name_or_path, model_cache_dir)
# # Check the rope config and make sure it's got the proper key
# with open(os.path.join(model_cache_dir, "config.json"), "r") as cfin:
# config_data = json.load(cfin)
# if "rope_type" in config_data["rope_scaling"]:
# del config_data["rope_scaling"]["rope_type"]
# config_data["rope_scaling"]["type"] = "mrope"
# with open(os.path.join(model_cache_dir, "config.json"), "w") as cfout:
# json.dump(config_data, cfout)
    # Check GPU memory; lower-memory devices need a bit less KV cache space because the VLM takes additional memory
gpu_memory = torch.cuda.get_device_properties(0).total_memory / (1024**3) # Convert to GB
mem_fraction_arg = ["--mem-fraction-static", "0.80"] if gpu_memory < 60 else []
cmd = [
"python3",
"-m",
"sglang.launch_server",
"--model-path",
model_name_or_path,
"--chat-template",
args.model_chat_template,
# "--context-length", str(args.model_max_context), # Commented out due to crashes
"--port",
str(SGLANG_SERVER_PORT),
"--log-level-http",
"warning",
]
cmd.extend(mem_fraction_arg)
proc = await asyncio.create_subprocess_exec(
*cmd,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
# Ensure the subprocess is terminated on exit
def _kill_proc():
proc.terminate()
atexit.register(_kill_proc)
# Shared variables between tasks
last_running_req, last_queue_req = 0, 0
server_printed_ready_message = False
last_semaphore_release = time.time()
async def process_line(line):
nonlocal last_running_req, last_queue_req, last_semaphore_release, server_printed_ready_message
sglang_logger.info(line)
# if the server hasn't initialized yet, log all the lines to the main logger also, so that the user
# can see any warnings/errors more easily
if not server_printed_ready_message:
logger.info(line)
if "Detected errors during sampling" in line:
logger.error("Cannot continue, sampling errors detected, model is probably corrupt")
sys.exit(1)
        # TODO: track down this issue in sglang itself; otherwise it will cause the server to lock up
if "IndexError: list index out of range" in line:
logger.error("IndexError in model, restarting server")
proc.terminate()
if not server_printed_ready_message and "The server is fired up and ready to roll!" in line:
server_printed_ready_message = True
last_semaphore_release = time.time()
match = re.search(r"#running-req: (\d+)", line)
if match:
last_running_req = int(match.group(1))
match = re.search(r"#queue-req: (\d+)", line)
if match:
last_queue_req = int(match.group(1))
logger.info(f"sglang running req: {last_running_req} queue req: {last_queue_req}")
async def read_stream(stream):
while True:
line = await stream.readline()
if not line:
break
try:
line = line.decode("utf-8").rstrip()
await process_line(line)
except Exception as ex:
logger.warning(f"Got {ex} when reading log line from inference server, skipping")
async def timeout_task():
nonlocal last_running_req, last_queue_req, last_semaphore_release
try:
while True:
await asyncio.sleep(1)
if server_printed_ready_message and last_queue_req == 0 and time.time() - last_semaphore_release > 30 and semaphore.locked():
semaphore.release()
last_semaphore_release = time.time()
logger.info("Semaphore released, allowing a worker to proceed.")
except asyncio.CancelledError:
pass # Clean up if the task is cancelled
    # Start tasks to read stdout, stderr, and handle timeout logic
    stdout_task = asyncio.create_task(read_stream(proc.stdout))
    stderr_task = asyncio.create_task(read_stream(proc.stderr))
    timeout_monitor_task = asyncio.create_task(timeout_task())
    try:
        await proc.wait()
    except asyncio.CancelledError:
        logger.info("Got cancellation request for SGLang server")
        proc.terminate()
        raise
    timeout_monitor_task.cancel()
    await asyncio.gather(stdout_task, stderr_task, timeout_monitor_task, return_exceptions=True)
async def sglang_server_host(args, semaphore):
MAX_RETRIES = 5
retry = 0
while retry < MAX_RETRIES:
await sglang_server_task(args, semaphore)
logger.warning("SGLang server task ended")
retry += 1
if retry >= MAX_RETRIES:
logger.error(f"Ended up starting the sglang server more than {retry} times, cancelling pipeline")
logger.error("")
logger.error("Please make sure sglang is installed according to the latest instructions here: https://docs.sglang.ai/start/install.html")
sys.exit(1)
async def sglang_server_ready():
max_attempts = 300
delay_sec = 1
url = f"http://localhost:{SGLANG_SERVER_PORT}/v1/models"
for attempt in range(1, max_attempts + 1):
try:
async with httpx.AsyncClient() as session:
response = await session.get(url)
if response.status_code == 200:
logger.info("sglang server is ready.")
return
else:
logger.info(f"Attempt {attempt}: Unexpected status code {response.status_code}")
except Exception as e:
logger.warning(f"Attempt {attempt}: Please wait for sglang server to become ready...")
await asyncio.sleep(delay_sec)
raise Exception("sglang server did not become ready after waiting.")
async def download_model(model_name_or_path: str):
logger.info(f"Downloading model '{model_name_or_path}'")
snapshot_download(repo_id=model_name_or_path)
logger.info(f"Model download complete '{model_name_or_path}'")
async def metrics_reporter(work_queue):
while True:
# Leading newlines preserve table formatting in logs
logger.info(f"Queue remaining: {work_queue.size}")
logger.info("\n" + str(metrics))
logger.info("\n" + str(await tracker.get_status_table()))
await asyncio.sleep(10)
def submit_beaker_job(args):
from beaker import ( # type: ignore
Beaker,
Constraints,
EnvVar,
ExperimentSpec,
ImageSource,
Priority,
ResultSpec,
SecretNotFound,
TaskContext,
TaskResources,
TaskSpec,
)
b = Beaker.from_env(default_workspace=args.beaker_workspace)
account = b.account.whoami()
owner = account.name
beaker_image = f"jakep/olmocr-inference-{VERSION}"
task_name = f"olmocr-{os.path.basename(args.workspace.rstrip('/'))}"
# Take out --beaker flag so the workers will just run things
args_list = [arg for arg in sys.argv[1:] if arg != "--beaker"]
# Take out the --pdfs [arg] or --pdfs=[arg], since the queue is populated locally
args_list = [arg for i, arg in enumerate(args_list) if not (arg.startswith("--pdfs") or (i > 0 and args_list[i - 1] == "--pdfs"))]
try:
b.secret.get(f"{owner}-WEKA_ACCESS_KEY_ID", args.beaker_workspace)
b.secret.get(f"{owner}-WEKA_SECRET_ACCESS_KEY", args.beaker_workspace)
b.secret.get(f"{owner}-AWS_CREDENTIALS_FILE", args.beaker_workspace)
except SecretNotFound:
print(
f"Expected beaker secrets for accessing Weka and S3 are not found. Are you okay to write those to your beaker workspace {args.beaker_workspace}? [y/n]"
)
if input().strip().lower() != "y":
print("Exiting...")
sys.exit(1)
b.secret.write(f"{owner}-WEKA_ACCESS_KEY_ID", os.environ.get("WEKA_ACCESS_KEY_ID", ""), args.beaker_workspace)
b.secret.write(f"{owner}-WEKA_SECRET_ACCESS_KEY", os.environ.get("WEKA_SECRET_ACCESS_KEY", ""), args.beaker_workspace)
b.secret.write(
f"{owner}-AWS_CREDENTIALS_FILE",
open(os.path.join(os.path.expanduser("~"), ".aws", "credentials")).read(),
args.beaker_workspace,
)
env_var_secrets = [
EnvVar(name="WEKA_ACCESS_KEY_ID", secret=f"{owner}-WEKA_ACCESS_KEY_ID"),
EnvVar(name="WEKA_SECRET_ACCESS_KEY", secret=f"{owner}-WEKA_SECRET_ACCESS_KEY"),
EnvVar(name="AWS_CREDENTIALS_FILE", secret=f"{owner}-AWS_CREDENTIALS_FILE"),
]
try:
b.secret.get("OLMOCR_PREVIEW_HF_TOKEN", args.beaker_workspace)
env_var_secrets.append(EnvVar(name="HF_TOKEN", secret="OLMOCR_PREVIEW_HF_TOKEN"))
except SecretNotFound:
pass
try:
b.secret.get("OE_DATA_GCS_SA_KEY", args.beaker_workspace)
env_var_secrets.append(EnvVar(name="GOOGLE_APPLICATION_CREDENTIALS_FILE", secret="OE_DATA_GCS_SA_KEY"))
except SecretNotFound:
print("Input the olmo-gcs SA key if you would like to load weights from gcs (end with a double newline):")
lines = []
prev_empty = False
for line in iter(input, None):
if not line and prev_empty:
break
prev_empty = not line
lines.append(line)
gcs_sa_key = "\n".join(lines[:-1]).strip() # Remove the last empty line
if gcs_sa_key:
b.secret.write("OE_DATA_GCS_SA_KEY", gcs_sa_key, args.beaker_workspace)
env_var_secrets.append(EnvVar(name="GOOGLE_APPLICATION_CREDENTIALS_FILE", secret="OE_DATA_GCS_SA_KEY"))
# Create the experiment spec
experiment_spec = ExperimentSpec(
budget="ai2/oe-data",
description=task_name,
tasks=[
TaskSpec(
name=task_name,
propagate_failure=False,
propagate_preemption=False,
replicas=args.beaker_gpus,
context=TaskContext(
priority=Priority(args.beaker_priority),
preemptible=True,
),
image=ImageSource(beaker=beaker_image),
command=["python", "-m", "olmocr.pipeline"] + args_list,
env_vars=[EnvVar(name="BEAKER_JOB_NAME", value=task_name), EnvVar(name="OWNER", value=owner)] + env_var_secrets,
resources=TaskResources(gpu_count=1),
constraints=Constraints(cluster=args.beaker_cluster if isinstance(args.beaker_cluster, list) else [args.beaker_cluster]),
result=ResultSpec(path="/noop-results"),
)
],
)
experiment_data = b.experiment.create(spec=experiment_spec, workspace=args.beaker_workspace)
print(f"Experiment URL: https://beaker.org/ex/{experiment_data.id}")
def print_stats(args):
LONG_CONTEXT_THRESHOLD = 32768
assert args.workspace.startswith("s3://"), "Printing stats functionality only works with s3 workspaces for now."
# Get total work items and completed items
index_file_s3_path = os.path.join(args.workspace, "work_index_list.csv.zstd")
output_glob = os.path.join(args.workspace, "results", "*.jsonl")
done_work_items = expand_s3_glob(workspace_s3, output_glob)
work_queue = {parts[0]: parts[1:] for line in download_zstd_csv(workspace_s3, index_file_s3_path) if (parts := line.strip().split(",")) and line.strip()}
total_items = len(work_queue)
completed_items = len(done_work_items)
def process_output_file(s3_path):
try:
data = get_s3_bytes(workspace_s3, s3_path)
doc_count = 0
total_input_tokens = 0
total_output_tokens = 0
total_pages = 0
total_fallback_pages = 0
processed_paths = set()
# Counters for long context docs within a single file
long_context_docs = 0
long_context_tokens = 0
for line in data.decode("utf-8").splitlines():
if line.strip():
doc = json.loads(line)
doc_count += 1
doc_input_tokens = doc["metadata"].get("total-input-tokens", 0)
doc_output_tokens = doc["metadata"].get("total-output-tokens", 0)
doc_pages = doc["metadata"].get("pdf-total-pages", 0)
doc_fallback_pages = doc["metadata"].get("total-fallback-pages", 0)
total_input_tokens += doc_input_tokens
total_output_tokens += doc_output_tokens
total_pages += doc_pages
total_fallback_pages += doc_fallback_pages
processed_paths.add(doc["metadata"]["Source-File"])
# Check if this doc exceeds the long context threshold
if doc_output_tokens > LONG_CONTEXT_THRESHOLD:
long_context_docs += 1
long_context_tokens += doc_output_tokens
return (
doc_count,
total_input_tokens,
total_output_tokens,
total_pages,
total_fallback_pages,
processed_paths,
long_context_docs,
long_context_tokens,
)
except Exception as e:
logger.warning(f"Error processing {s3_path}: {e}")
return 0, 0, 0, 0, 0, set(), 0, 0
print("\nProcessing output files...")
docs_total = 0
input_tokens_total = 0
output_tokens_total = 0
pages_total = 0
fallback_pages_total = 0
all_processed_paths = set()
original_paths = set()
# Counters for long context documents across all files
long_context_docs_count = 0
long_context_tokens_total = 0
# First collect all original PDF paths
for done_work_item in done_work_items:
if match := re.search(r"output_(\w+).jsonl", done_work_item):
done_work_hash = match.group(1)
original_paths.update(work_queue[done_work_hash])
with ThreadPoolExecutor() as executor:
futures = {executor.submit(process_output_file, item): item for item in done_work_items}
for future in tqdm(as_completed(futures), total=len(futures)):
(doc_count, input_tokens, output_tokens, pages, fallback_pages, processed_paths, long_context_docs, long_context_tokens) = future.result()
docs_total += doc_count
input_tokens_total += input_tokens
output_tokens_total += output_tokens
pages_total += pages
fallback_pages_total += fallback_pages
all_processed_paths.update(processed_paths)
long_context_docs_count += long_context_docs
long_context_tokens_total += long_context_tokens
skipped_paths = original_paths - all_processed_paths
print("\nWork Items Status:")
print(f"Total work items: {total_items:,}")
print(f"Completed items: {completed_items:,}")
print(f"Remaining items: {total_items - completed_items:,}")
print("\nResults:")
print(f"Total documents processed: {docs_total:,}")
print(f"Total documents skipped: {len(skipped_paths):,}")
print(f"Total pages on fallback: {fallback_pages_total:,}")
print(f"Total pages processed: {pages_total:,}")
print(f"\nTotal output tokens: {output_tokens_total:,}")
print(f"Projected output tokens: {round((output_tokens_total/max(1, completed_items))*total_items):,}")
print(f"\nAverage pages per doc: {pages_total/max(1,docs_total):,.1f}")
print(f"Average output tokens per doc: {output_tokens_total/max(1,docs_total):,.1f}")
print(f"Average output tokens per page: {output_tokens_total/max(1,pages_total):,.1f}")
# Print long context documents stats
print(f"\nLong Context Documents (>{LONG_CONTEXT_THRESHOLD} tokens): {long_context_docs_count:,}")
print(f"Total tokens in long context documents: {long_context_tokens_total:,}")
async def main():
parser = argparse.ArgumentParser(description="Manager for running millions of PDFs through a batch inference pipeline")
parser.add_argument(
"workspace",
help="The filesystem path where work will be stored, can be a local folder, or an s3 path if coordinating work with many workers, s3://bucket/prefix/ ",
)
parser.add_argument(
"--pdfs",
nargs="*",
help="Path to add pdfs stored in s3 to the workspace, can be a glob path s3://bucket/prefix/*.pdf or path to file containing list of pdf paths",
default=None,
)
parser.add_argument("--workspace_profile", help="S3 configuration profile for accessing the workspace", default=None)
parser.add_argument("--pdf_profile", help="S3 configuration profile for accessing the raw pdf documents", default=None)
parser.add_argument("--pages_per_group", type=int, default=500, help="Aiming for this many pdf pages per work item group")
parser.add_argument("--max_page_retries", type=int, default=8, help="Max number of times we will retry rendering a page")
parser.add_argument("--max_page_error_rate", type=float, default=0.004, help="Rate of allowable failed pages in a document, 1/250 by default")
parser.add_argument("--workers", type=int, default=8, help="Number of workers to run at a time")
parser.add_argument("--apply_filter", action="store_true", help="Apply basic filtering to English pdfs which are not forms, and not likely seo spam")
parser.add_argument("--stats", action="store_true", help="Instead of running any job, reports some statistics about the current workspace")
# Model parameters
parser.add_argument(
"--model",
help="List of paths where you can find the model to convert this pdf. You can specify several different paths here, and the script will try to use the one which is fastest to access",
default="allenai/olmOCR-7B-0225-preview",
)
parser.add_argument("--model_max_context", type=int, default="8192", help="Maximum context length that the model was fine tuned under")
parser.add_argument("--model_chat_template", type=str, default="qwen2-vl", help="Chat template to pass to sglang server")
parser.add_argument("--target_longest_image_dim", type=int, help="Dimension on longest side to use for rendering the pdf pages", default=1024)
parser.add_argument("--target_anchor_text_len", type=int, help="Maximum amount of anchor text to use (characters)", default=6000)
# Beaker/job running stuff
parser.add_argument("--beaker", action="store_true", help="Submit this job to beaker instead of running locally")
parser.add_argument("--beaker_workspace", help="Beaker workspace to submit to", default="ai2/olmocr")
parser.add_argument(
"--beaker_cluster",
help="Beaker clusters you want to run on",
default=["ai2/jupiter-cirrascale-2", "ai2/ceres-cirrascale", "ai2/neptune-cirrascale", "ai2/saturn-cirrascale", "ai2/augusta-google-1"],
)
parser.add_argument("--beaker_gpus", type=int, default=1, help="Number of gpu replicas to run")
parser.add_argument("--beaker_priority", type=str, default="normal", help="Beaker priority level for the job")
args = parser.parse_args()
global workspace_s3, pdf_s3
# setup the job to work in beaker environment, load secrets, adjust logging, etc.
if "BEAKER_JOB_NAME" in os.environ:
sglang_logger.addHandler(console_handler)
cred_path = os.path.join(os.path.expanduser("~"), ".aws", "credentials")
os.makedirs(os.path.dirname(cred_path), exist_ok=True)
with open(cred_path, "w") as f:
f.write(os.environ.get("AWS_CREDENTIALS_FILE"))
cred_path = os.path.join(os.path.expanduser("~"), ".gcs", "credentials")
os.makedirs(os.path.dirname(cred_path), exist_ok=True)
with open(cred_path, "w") as f:
f.write(os.environ.get("GOOGLE_APPLICATION_CREDENTIALS_FILE"))
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = cred_path
workspace_s3 = boto3.client("s3")
pdf_s3 = boto3.client("s3")
if args.workspace_profile:
workspace_session = boto3.Session(profile_name=args.workspace_profile)
workspace_s3 = workspace_session.client("s3")
if args.pdf_profile:
pdf_session = boto3.Session(profile_name=args.pdf_profile)
pdf_s3 = pdf_session.client("s3")
# We need poppler to load the initial pdfs, even if we are not processing them here
check_poppler_version()
# Create work queue
if args.workspace.startswith("s3://"):
work_queue = S3WorkQueue(workspace_s3, args.workspace)
else:
work_queue = LocalWorkQueue(args.workspace)
if args.pdfs:
logger.info("Got --pdfs argument, going to add to the work queue")
pdf_work_paths = set()
for pdf_path in args.pdfs:
# Expand s3 paths
if pdf_path.startswith("s3://"):
logger.info(f"Expanding s3 glob at {pdf_path}")
pdf_work_paths |= set(expand_s3_glob(pdf_s3, pdf_path))
elif os.path.exists(pdf_path):
if pdf_path.endswith(".pdf"):
if open(pdf_path, "rb").read(4) == b"%PDF":
logger.info(f"Loading file at {pdf_path} as PDF document")
pdf_work_paths.add(pdf_path)
else:
logger.warning(f"File at {pdf_path} is not a valid PDF")
elif pdf_path.endswith(".txt"):
logger.info(f"Loading file at {pdf_path} as list of paths")
with open(pdf_path, "r") as f:
pdf_work_paths |= set(filter(None, (line.strip() for line in f)))
else:
raise ValueError(f"Unsupported file extension for {pdf_path}")
else:
raise ValueError("pdfs argument needs to be either a local path, an s3 path, or an s3 glob pattern...")
logger.info(f"Found {len(pdf_work_paths):,} total pdf paths to add")
# Estimate average pages per pdf
sample_size = min(100, len(pdf_work_paths))
sampled_pdfs = random.sample(list(pdf_work_paths), sample_size)
page_counts = []
for pdf in tqdm(sampled_pdfs, desc="Sampling PDFs to calculate optimal length"):
try:
# Download the PDF to a temp file
with tempfile.NamedTemporaryFile(suffix=".pdf") as tmp_file:
tmp_file.write(get_s3_bytes(pdf_s3, pdf))
tmp_file.flush()
reader = PdfReader(tmp_file.name)
page_counts.append(len(reader.pages))
except Exception as e:
logger.warning(f"Failed to read {pdf}: {e}")
if page_counts:
avg_pages_per_pdf = sum(page_counts) / len(page_counts)
else:
logger.warning("Could not read any PDFs to estimate average page count.")
avg_pages_per_pdf = 10 # Default to 10 pages per PDF if sampling fails
items_per_group = max(1, int(args.pages_per_group / avg_pages_per_pdf))
logger.info(f"Calculated items_per_group: {items_per_group} based on average pages per PDF: {avg_pages_per_pdf:.2f}")
# Now call populate_queue
await work_queue.populate_queue(pdf_work_paths, items_per_group)
if args.stats:
print_stats(args)
return
if args.beaker:
submit_beaker_job(args)
return
# If you get this far, then you are doing inference and need a GPU
check_sglang_version()
check_torch_gpu_available()
logger.info(f"Starting pipeline with PID {os.getpid()}")
# Download the model before you do anything else
await download_model(args.model)
# Initialize the work queue
await work_queue.initialize_queue()
# Create a semaphore to control worker access
# We only allow one worker to move forward with requests, until the server has no more requests in its queue
# This lets us get full utilization by having many workers, but also to be outputting dolma docs as soon as possible
# As soon as one worker is no longer saturating the gpu, the next one can start sending requests
semaphore = asyncio.Semaphore(1)
sglang_server = asyncio.create_task(sglang_server_host(args, semaphore))
await sglang_server_ready()
metrics_task = asyncio.create_task(metrics_reporter(work_queue))
# Create worker tasks to process the queue concurrently.
worker_tasks = []
for i in range(args.workers):
task = asyncio.create_task(worker(args, work_queue, semaphore, worker_id=i))
worker_tasks.append(task)
# Wait for all worker tasks to finish
await asyncio.gather(*worker_tasks)
# Wait for server to stop
process_pool.shutdown(wait=False)
sglang_server.cancel()
metrics_task.cancel()
logger.info("Work done")
if __name__ == "__main__":
asyncio.run(main())
from .prompts import (
PageResponse,
build_finetuning_prompt,
build_openai_silver_data_prompt,
extract_raw_text,
openai_response_format_schema,
)
# This file generates anchor text in a variety of different ways
# The goal here is to generate a bit of text which can be used to help prompt a VLM
# to better understand a document
import random
import re
import subprocess
from dataclasses import dataclass
from typing import List, Literal
import ftfy
import pypdfium2 as pdfium
from pypdf import PdfReader
from pypdf.generic import RectangleObject
from olmocr.filter.coherency import get_document_coherency
def get_anchor_text(
local_pdf_path: str, page: int, pdf_engine: Literal["pdftotext", "pdfium", "pypdf", "topcoherency", "pdfreport"], target_length: int = 4000
) -> str:
assert page > 0, "Pages are 1-indexed in pdf-land"
if pdf_engine == "pdftotext":
return _get_pdftotext(local_pdf_path, page)
elif pdf_engine == "pdfium":
return _get_pdfium(local_pdf_path, page)
elif pdf_engine == "pypdf":
return _get_pypdf_raw(local_pdf_path, page)
elif pdf_engine == "topcoherency":
options = {
"pdftotext": _get_pdftotext(local_pdf_path, page),
"pdfium": _get_pdfium(local_pdf_path, page),
"pypdf_raw": _get_pypdf_raw(local_pdf_path, page),
}
scores = {label: get_document_coherency(text) for label, text in options.items()}
best_option_label = max(scores, key=scores.get) # type: ignore
best_option = options[best_option_label]
print(f"topcoherency chosen: {best_option_label}")
return best_option
elif pdf_engine == "pdfreport":
return _linearize_pdf_report(_pdf_report(local_pdf_path, page), max_length=target_length)
else:
raise NotImplementedError("Unknown engine")
def _get_pdftotext(local_pdf_path: str, page: int) -> str:
pdftotext_result = subprocess.run(
["pdftotext", "-f", str(page), "-l", str(page), local_pdf_path, "-"],
timeout=60,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
)
assert pdftotext_result.returncode == 0
return pdftotext_result.stdout.decode("utf-8")
def _get_pypdf_raw(local_pdf_path: str, page: int) -> str:
reader = PdfReader(local_pdf_path)
pypage = reader.pages[page - 1]
return pypage.extract_text()
def _get_pdfium(local_pdf_path: str, page: int) -> str:
pdf = pdfium.PdfDocument(local_pdf_path)
textpage = pdf[page - 1].get_textpage()
return textpage.get_text_bounded()
def _transform_point(x, y, m):
x_new = m[0] * x + m[2] * y + m[4]
y_new = m[1] * x + m[3] * y + m[5]
return x_new, y_new
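# The 6-element matrix follows the PDF convention m = [a, b, c, d, e, f], so that
#   x' = a*x + c*y + e   and   y' = b*x + d*y + f
# For example (illustrative), m = [1, 0, 0, 1, 10, 20] maps the point (5, 5) to (15, 25).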
def _mult(m: List[float], n: List[float]) -> List[float]:
return [
m[0] * n[0] + m[1] * n[2],
m[0] * n[1] + m[1] * n[3],
m[2] * n[0] + m[3] * n[2],
m[2] * n[1] + m[3] * n[3],
m[4] * n[0] + m[5] * n[2] + n[4],
m[4] * n[1] + m[5] * n[3] + n[5],
]
@dataclass(frozen=True)
class Element:
pass
@dataclass(frozen=True)
class BoundingBox:
x0: float
y0: float
x1: float
y1: float
@staticmethod
def from_rectangle(rect: RectangleObject) -> "BoundingBox":
return BoundingBox(rect[0], rect[1], rect[2], rect[3])
@dataclass(frozen=True)
class TextElement(Element):
text: str
x: float
y: float
@dataclass(frozen=True)
class ImageElement(Element):
name: str
bbox: BoundingBox
@dataclass(frozen=True)
class PageReport:
mediabox: BoundingBox
text_elements: List[TextElement]
image_elements: List[ImageElement]
def _pdf_report(local_pdf_path: str, page_num: int) -> PageReport:
reader = PdfReader(local_pdf_path)
page = reader.pages[page_num - 1]
resources = page.get("/Resources", {})
xobjects = resources.get("/XObject", {})
text_elements, image_elements = [], []
def visitor_body(text, cm, tm, font_dict, font_size):
txt2user = _mult(tm, cm)
text_elements.append(TextElement(text, txt2user[4], txt2user[5]))
def visitor_op(op, args, cm, tm):
if op == b"Do":
xobject_name = args[0]
xobject = xobjects.get(xobject_name)
if xobject and xobject["/Subtype"] == "/Image":
# Compute image bbox
# The image is placed according to the CTM
_width = xobject.get("/Width")
_height = xobject.get("/Height")
x0, y0 = _transform_point(0, 0, cm)
x1, y1 = _transform_point(1, 1, cm)
image_elements.append(ImageElement(xobject_name, BoundingBox(min(x0, x1), min(y0, y1), max(x0, x1), max(y0, y1))))
page.extract_text(visitor_text=visitor_body, visitor_operand_before=visitor_op)
return PageReport(
mediabox=BoundingBox.from_rectangle(page.mediabox),
text_elements=text_elements,
image_elements=image_elements,
)
def _merge_image_elements(images: List[ImageElement], tolerance: float = 0.5) -> List[ImageElement]:
n = len(images)
parent = list(range(n)) # Initialize Union-Find parent pointers
def find(i):
# Find with path compression
root = i
while parent[root] != root:
root = parent[root]
while parent[i] != i:
parent_i = parent[i]
parent[i] = root
i = parent_i
return root
def union(i, j):
# Union by attaching root of one tree to another
root_i = find(i)
root_j = find(j)
if root_i != root_j:
parent[root_i] = root_j
def bboxes_overlap(b1: BoundingBox, b2: BoundingBox, tolerance: float) -> bool:
# Compute horizontal and vertical distances between boxes
h_dist = max(0, max(b1.x0, b2.x0) - min(b1.x1, b2.x1))
v_dist = max(0, max(b1.y0, b2.y0) - min(b1.y1, b2.y1))
# Check if distances are within tolerance
return h_dist <= tolerance and v_dist <= tolerance
# Union overlapping images
for i in range(n):
for j in range(i + 1, n):
if bboxes_overlap(images[i].bbox, images[j].bbox, tolerance):
union(i, j)
# Group images by their root parent
groups: dict[int, list[int]] = {}
for i in range(n):
root = find(i)
groups.setdefault(root, []).append(i)
# Merge images in the same group
merged_images = []
for indices in groups.values():
# Initialize merged bounding box
merged_bbox = images[indices[0]].bbox
merged_name = images[indices[0]].name
for idx in indices[1:]:
bbox = images[idx].bbox
# Expand merged_bbox to include the current bbox
merged_bbox = BoundingBox(
x0=min(merged_bbox.x0, bbox.x0),
y0=min(merged_bbox.y0, bbox.y0),
x1=max(merged_bbox.x1, bbox.x1),
y1=max(merged_bbox.y1, bbox.y1),
)
# Optionally, update the name
merged_name += f"+{images[idx].name}"
merged_images.append(ImageElement(name=merged_name, bbox=merged_bbox))
    # Return the merged image elements
return merged_images
def _cap_split_string(text: str, max_length: int) -> str:
if len(text) <= max_length:
return text
head_length = max_length // 2 - 3
tail_length = head_length
head = text[:head_length].rsplit(" ", 1)[0] or text[:head_length]
tail = text[-tail_length:].split(" ", 1)[-1] or text[-tail_length:]
return f"{head} ... {tail}"
def _cleanup_element_text(element_text: str) -> str:
MAX_TEXT_ELEMENT_LENGTH = 250
TEXT_REPLACEMENTS = {"[": "\\[", "]": "\\]", "\n": "\\n", "\r": "\\r", "\t": "\\t"}
text_replacement_pattern = re.compile("|".join(re.escape(key) for key in TEXT_REPLACEMENTS.keys()))
element_text = ftfy.fix_text(element_text).strip()
# Replace square brackets with escaped brackets and other escaped chars
element_text = text_replacement_pattern.sub(lambda match: TEXT_REPLACEMENTS[match.group(0)], element_text)
return _cap_split_string(element_text, MAX_TEXT_ELEMENT_LENGTH)
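# The escaping above keeps literal brackets and control whitespace from colliding with the
# "[XxY]text" line format emitted by _linearize_pdf_report below, and each element's text is
# capped at 250 characters via _cap_split_string.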
def _linearize_pdf_report(report: PageReport, max_length: int = 4000) -> str:
result = ""
result += f"Page dimensions: {report.mediabox.x1:.1f}x{report.mediabox.y1:.1f}\n"
if max_length < 20:
return result
images = _merge_image_elements(report.image_elements)
# Process image elements
image_strings = []
for element in images:
image_str = f"[Image {element.bbox.x0:.0f}x{element.bbox.y0:.0f} to {element.bbox.x1:.0f}x{element.bbox.y1:.0f}]\n"
# Use element's unique identifier (e.g., id or position) for comparison
image_strings.append((element, image_str))
# Process text elements
text_strings = []
for element in report.text_elements: # type: ignore
if len(element.text.strip()) == 0: # type: ignore
continue
element_text = _cleanup_element_text(element.text) # type: ignore
text_str = f"[{element.x:.0f}x{element.y:.0f}]{element_text}\n" # type: ignore
text_strings.append((element, text_str))
# Combine all elements with their positions for sorting
all_elements: list[tuple[str, ImageElement, str, tuple[float, float]]] = []
for elem, s in image_strings:
position = (elem.bbox.x0, elem.bbox.y0)
all_elements.append(("image", elem, s, position))
for elem, s in text_strings:
position = (elem.x, elem.y) # type: ignore
all_elements.append(("text", elem, s, position))
# Calculate total length
total_length = len(result) + sum(len(s) for _, _, s, _ in all_elements)
if total_length <= max_length:
# Include all elements
for _, _, s, _ in all_elements:
result += s
return result
# Identify elements with min/max coordinates
edge_elements = set()
if images:
min_x0_image = min(images, key=lambda e: e.bbox.x0)
max_x1_image = max(images, key=lambda e: e.bbox.x1)
min_y0_image = min(images, key=lambda e: e.bbox.y0)
max_y1_image = max(images, key=lambda e: e.bbox.y1)
edge_elements.update([min_x0_image, max_x1_image, min_y0_image, max_y1_image])
if report.text_elements:
text_elements = [e for e in report.text_elements if len(e.text.strip()) > 0]
if text_elements:
min_x_text = min(text_elements, key=lambda e: e.x)
max_x_text = max(text_elements, key=lambda e: e.x)
min_y_text = min(text_elements, key=lambda e: e.y)
max_y_text = max(text_elements, key=lambda e: e.y)
edge_elements.update([min_x_text, max_x_text, min_y_text, max_y_text]) # type: ignore
# Keep track of element IDs to prevent duplication
selected_element_ids = set()
selected_elements = []
# Include edge elements first
for elem_type, elem, s, position in all_elements:
if elem in edge_elements and id(elem) not in selected_element_ids:
selected_elements.append((elem_type, elem, s, position))
selected_element_ids.add(id(elem))
# Calculate remaining length
current_length = len(result) + sum(len(s) for _, _, s, _ in selected_elements)
_remaining_length = max_length - current_length
# Exclude edge elements from the pool
remaining_elements = [(elem_type, elem, s, position) for elem_type, elem, s, position in all_elements if id(elem) not in selected_element_ids]
# Sort remaining elements by their positions (e.g., x-coordinate and then y-coordinate)
# remaining_elements.sort(key=lambda x: (x[3][0], x[3][1]))
# Shuffle remaining elements randomly
random.shuffle(remaining_elements)
# Add elements until reaching max_length
for elem_type, elem, s, position in remaining_elements:
if current_length + len(s) > max_length:
break
selected_elements.append((elem_type, elem, s, position))
selected_element_ids.add(id(elem))
current_length += len(s)
# Sort selected elements by their positions to maintain logical order
selected_elements.sort(key=lambda x: (x[3][0], x[3][1]))
# Build the final result
for _, _, s, _ in selected_elements:
result += s
return result
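# Illustrative output shape of _linearize_pdf_report (values made up for clarity):
#   Page dimensions: 612.0x792.0
#   [Image 34x500 to 275x700]
#   [102x713]Some text element cleaned up by _cleanup_element_text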
import re
from dataclasses import dataclass
from typing import Optional
# This is the prompt we use for getting chat gpt 4o to convert documents into our silver training data
def build_openai_silver_data_prompt(base_text: str) -> str:
return (
f"Below is the image of one page of a PDF document, as well as some raw textual content that was previously extracted for it that includes position information for each image and block of text (The origin [0x0] of the coordinates is in the lower left corner of the image). "
f"Just return the plain text representation of this document as if you were reading it naturally.\n"
f"Turn equations into a LaTeX representation, and tables into markdown format. Remove the headers and footers, but keep references and footnotes.\n"
f"Read any natural handwriting.\n"
f"This is likely one page out of several in the document, so be sure to preserve any sentences that come from the previous page, or continue onto the next page, exactly as they are.\n"
f"If there is no text at all that you think you should read, you can output null.\n"
f"Do not hallucinate.\n"
f"RAW_TEXT_START\n{base_text}\nRAW_TEXT_END"
)
@dataclass(frozen=True)
class PageResponse:
primary_language: Optional[str]
is_rotation_valid: bool
rotation_correction: int
is_table: bool
is_diagram: bool
natural_text: Optional[str]
def __post_init__(self):
# Validate rotation_correction is one of the allowed values
if self.rotation_correction not in {0, 90, 180, 270}:
raise ValueError("rotation_correction must be one of [0, 90, 180, 270].")
# Type checks
if not isinstance(self.primary_language, (str, type(None))):
raise TypeError("primary_language must be of type Optional[str].")
if not isinstance(self.is_rotation_valid, bool):
raise TypeError("is_rotation_valid must be of type bool.")
if not isinstance(self.rotation_correction, int):
raise TypeError("rotation_correction must be of type int.")
if not isinstance(self.is_table, bool):
raise TypeError("is_table must be of type bool.")
if not isinstance(self.is_diagram, bool):
raise TypeError("is_diagram must be of type bool.")
if not isinstance(self.natural_text, (str, type(None))):
raise TypeError("natural_text must be of type Optional[str].")
def openai_response_format_schema() -> dict:
return {
"type": "json_schema",
"json_schema": {
"name": "page_response",
"schema": {
"type": "object",
"properties": {
"primary_language": {
"type": ["string", "null"],
"description": "The primary language of the text using two-letter codes or null if there is no text at all that you think you should read.",
},
"is_rotation_valid": {
"type": "boolean",
"description": "Is this page oriented correctly for reading? Answer only considering the textual content, do not factor in the rotation of any charts, tables, drawings, or figures.",
},
"rotation_correction": {
"type": "integer",
"description": "Indicates the degree of clockwise rotation needed if the page is not oriented correctly.",
"enum": [0, 90, 180, 270],
"default": 0,
},
"is_table": {
"type": "boolean",
"description": "Indicates if the majority of the page content is in tabular format.",
},
"is_diagram": {
"type": "boolean",
"description": "Indicates if the majority of the page content is a visual diagram.",
},
"natural_text": {
"type": ["string", "null"],
"description": "The natural text content extracted from the page.",
},
},
"additionalProperties": False,
"required": [
"primary_language",
"is_rotation_valid",
"rotation_correction",
"is_table",
"is_diagram",
"natural_text",
],
},
"strict": True,
},
}
# This is a base prompt that will be used for training and running the fine tuned model
# It's simplified from the prompt which was used to generate the silver data, and can change from dataset to dataset
def build_finetuning_prompt(base_text: str) -> str:
return (
f"Below is the image of one page of a document, as well as some raw textual content that was previously extracted for it. "
f"Just return the plain text representation of this document as if you were reading it naturally.\n"
f"Do not hallucinate.\n"
f"RAW_TEXT_START\n{base_text}\nRAW_TEXT_END"
)
# Extracts the anchor text component from an existing prompt string
def extract_raw_text(prompt: str) -> str:
pattern = r"RAW_TEXT_START\s*\n(.*?)\nRAW_TEXT_END"
# Use re.DOTALL to ensure that the dot matches newline characters
match = re.search(pattern, prompt, re.DOTALL)
if match:
return match.group(1).strip()
else:
raise ValueError("Prompt does not contain raw text")
import random
import string
import time
import unittest
import re
class RepeatDetector:
def __init__(self, max_ngram_size: int = 10):
self.max_ngram_size = max_ngram_size
self.data = ""
def add_letters(self, new_str: str):
self.data += new_str
def ngram_repeats(self) -> list[int]:
result = [0] * self.max_ngram_size
if not self.data:
return result
# Normalize all whitespace to single spaces
text = re.sub(r'\s+', ' ', self.data)
# For each n-gram size
for size in range(1, self.max_ngram_size + 1):
if len(text) < size:
continue
# Get the last n-gram
target = text[-size:]
# Count backwards from the end to find repeats
count = 0
pos = len(text) - size # Start position for previous n-gram
while pos >= 0:
if text[pos : pos + size] == target:
count += 1
pos -= size # Move back by the size of the n-gram
else:
break
result[size - 1] = count
return result
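    # Example (matches test_basicTest2 below): with data "abab" and max_ngram_size=3,
    # ngram_repeats() returns [1, 2, 1] -- the trailing 1-gram "b" repeats once at the end,
    # the trailing 2-gram "ab" twice, and the trailing 3-gram "bab" once.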
class RepeatDetectorTest(unittest.TestCase):
def test_basicTest1(self):
d = RepeatDetector(max_ngram_size=3)
d.add_letters("a")
self.assertEqual(d.ngram_repeats(), [1, 0, 0])
def test_basicTest2(self):
d = RepeatDetector(max_ngram_size=3)
d.add_letters("abab")
self.assertEqual(d.ngram_repeats(), [1, 2, 1])
def test_longer_sequence(self):
d = RepeatDetector(max_ngram_size=3)
d.add_letters("aabaabaa")
self.assertEqual(d.ngram_repeats(), [2, 1, 2])
def test_no_repeats(self):
d = RepeatDetector(max_ngram_size=3)
d.add_letters("abc")
self.assertEqual(d.ngram_repeats(), [1, 1, 1])
def test_empty_data(self):
d = RepeatDetector(max_ngram_size=3)
self.assertEqual(d.ngram_repeats(), [0, 0, 0])
def test_max_ngram_greater_than_data_length(self):
d = RepeatDetector(max_ngram_size=5)
d.add_letters("abc")
self.assertEqual(d.ngram_repeats(), [1, 1, 1, 0, 0])
def test_large_single_char(self):
d = RepeatDetector(max_ngram_size=5)
d.add_letters("a" * 10000)
self.assertEqual(d.ngram_repeats(), [10000, 5000, 3333, 2500, 2000])
def test_repeating_pattern(self):
d = RepeatDetector(max_ngram_size=5)
d.add_letters("abcabcabcabc")
self.assertEqual(d.ngram_repeats(), [1, 1, 4, 1, 1])
def test_mixed_characters(self):
d = RepeatDetector(max_ngram_size=4)
d.add_letters("abcdabcabcdabc")
self.assertEqual(d.ngram_repeats(), [1, 1, 1, 1])
def test_palindrome(self):
d = RepeatDetector(max_ngram_size=5)
d.add_letters("racecar")
self.assertEqual(d.ngram_repeats(), [1, 1, 1, 1, 1])
def test_repeats_not_at_end(self):
d = RepeatDetector(max_ngram_size=3)
d.add_letters("abcabcxyz")
self.assertEqual(d.ngram_repeats(), [1, 1, 1])
def test_long_repeat_at_end(self):
d = RepeatDetector(max_ngram_size=5)
d.add_letters("abcabcabcabcabcabcabcabcabcabc")
self.assertEqual(d.ngram_repeats(), [1, 1, 10, 1, 1])
def test_large_repeating_pattern(self):
d = RepeatDetector(max_ngram_size=4)
pattern = "abcd"
repeat_count = 1000
d.add_letters(pattern * repeat_count)
self.assertEqual(d.ngram_repeats(), [1, 1, 1, repeat_count])
def test_unicode_characters(self):
d = RepeatDetector(max_ngram_size=3)
d.add_letters("αβγαβγ")
self.assertEqual(d.ngram_repeats(), [1, 1, 2])
def test_random_data(self):
random.seed(42)
d = RepeatDetector(max_ngram_size=5)
data = "".join(random.choices(string.ascii_letters, k=10000))
d.add_letters(data)
counts = d.ngram_repeats()
for count in counts:
self.assertTrue(0 <= count <= len(data))
def test_special_characters(self):
d = RepeatDetector(max_ngram_size=4)
d.add_letters("@@##@@##")
self.assertEqual(d.ngram_repeats(), [2, 1, 1, 2])
def test_incremental_addition(self):
d = RepeatDetector(max_ngram_size=3)
d.add_letters("abc")
self.assertEqual(d.ngram_repeats(), [1, 1, 1])
d.add_letters("abc")
self.assertEqual(d.ngram_repeats(), [1, 1, 2])
d.add_letters("abc")
self.assertEqual(d.ngram_repeats(), [1, 1, 3])
def test_long_non_repeating_sequence(self):
d = RepeatDetector(max_ngram_size=5)
d.add_letters("abcdefghijklmnopqrstuvwxyz")
self.assertEqual(d.ngram_repeats(), [1, 1, 1, 1, 1])
def test_alternating_characters(self):
d = RepeatDetector(max_ngram_size=4)
d.add_letters("ababababab")
self.assertEqual(d.ngram_repeats(), [1, 5, 1, 2])
class BenchmarkRepeatDetect(unittest.TestCase):
def testLargeRandom(self):
all_data = []
        for _ in range(1000):
all_data.append("".join(random.choices("a", k=10000)))
start = time.perf_counter()
for data in all_data:
d = RepeatDetector(max_ngram_size=20)
d.add_letters(data)
print(d.ngram_repeats())
end = time.perf_counter()
print(f"testLargeRandom took {end-start:0.0001f} seconds")
if __name__ == "__main__":
unittest.main()
import base64
import concurrent.futures
import glob
import hashlib
import logging
import os
import posixpath
import time
from io import BytesIO, TextIOWrapper
from pathlib import Path
from typing import List, Optional
from urllib.parse import urlparse
import boto3
import requests # type: ignore
import zstandard as zstd
from boto3.s3.transfer import TransferConfig
from botocore.config import Config
from botocore.exceptions import ClientError
from google.cloud import storage
from tqdm import tqdm
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
def parse_s3_path(s3_path: str) -> tuple[str, str]:
if not (s3_path.startswith("s3://") or s3_path.startswith("gs://") or s3_path.startswith("weka://")):
raise ValueError("s3_path must start with s3://, gs://, or weka://")
parsed = urlparse(s3_path)
bucket = parsed.netloc
key = parsed.path.lstrip("/")
return bucket, key
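# Example of the bucket/key split (illustrative path):
#   parse_s3_path("s3://my-bucket/some/prefix/doc.pdf") -> ("my-bucket", "some/prefix/doc.pdf")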
def expand_s3_glob(s3_client, s3_glob: str) -> dict[str, str]:
"""
Expand an S3 path that may or may not contain wildcards (e.g., *.pdf).
Returns a dict of {'s3://bucket/key': etag} for each matching object.
Raises a ValueError if nothing is found or if a bare prefix was provided by mistake.
"""
parsed = urlparse(s3_glob)
if not parsed.scheme.startswith("s3"):
raise ValueError("Path must start with s3://")
bucket = parsed.netloc
raw_path = parsed.path.lstrip("/")
prefix = posixpath.dirname(raw_path)
pattern = posixpath.basename(raw_path)
# Case 1: We have a wildcard
if any(wc in pattern for wc in ["*", "?", "[", "]"]):
if prefix and not prefix.endswith("/"):
prefix += "/"
paginator = s3_client.get_paginator("list_objects_v2")
matched = {}
for page in paginator.paginate(Bucket=bucket, Prefix=prefix):
for obj in page.get("Contents", []):
key = obj["Key"]
if glob.fnmatch.fnmatch(key, posixpath.join(prefix, pattern)): # type: ignore
matched[f"s3://{bucket}/{key}"] = obj["ETag"].strip('"')
return matched
# Case 2: No wildcard → single file or a bare prefix
try:
# Attempt to head a single file
resp = s3_client.head_object(Bucket=bucket, Key=raw_path)
if resp["ContentType"] == "application/x-directory":
raise ValueError(f"'{s3_glob}' appears to be a folder. " f"Use a wildcard (e.g., '{s3_glob.rstrip('/')}/*.pdf') to match files.")
return {f"s3://{bucket}/{raw_path}": resp["ETag"].strip('"')}
except ClientError as e:
if e.response["Error"]["Code"] == "404":
# Check if it's actually a folder with contents
check_prefix = raw_path if raw_path.endswith("/") else raw_path + "/"
paginator = s3_client.get_paginator("list_objects_v2")
for page in paginator.paginate(Bucket=bucket, Prefix=check_prefix):
if page.get("Contents"):
raise ValueError(f"'{s3_glob}' appears to be a folder. " f"Use a wildcard (e.g., '{s3_glob.rstrip('/')}/*.pdf') to match files.")
raise ValueError(f"No object or prefix found at '{s3_glob}'. Check your path or add a wildcard.")
else:
raise
def get_s3_bytes(s3_client, s3_path: str, start_index: Optional[int] = None, end_index: Optional[int] = None) -> bytes:
# Fall back for local files
if os.path.exists(s3_path):
assert start_index is None and end_index is None, "Range query not supported yet"
with open(s3_path, "rb") as f:
return f.read()
bucket, key = parse_s3_path(s3_path)
# Build the range header if start_index and/or end_index are specified
range_header = None
if start_index is not None and end_index is not None:
# Range: bytes=start_index-end_index
range_value = f"bytes={start_index}-{end_index}"
range_header = {"Range": range_value}
elif start_index is not None and end_index is None:
# Range: bytes=start_index-
range_value = f"bytes={start_index}-"
range_header = {"Range": range_value}
elif start_index is None and end_index is not None:
# Range: bytes=-end_index (last end_index bytes)
range_value = f"bytes=-{end_index}"
range_header = {"Range": range_value}
if range_header:
obj = s3_client.get_object(Bucket=bucket, Key=key, Range=range_header["Range"])
else:
obj = s3_client.get_object(Bucket=bucket, Key=key)
return obj["Body"].read()
def get_s3_bytes_with_backoff(s3_client, pdf_s3_path, max_retries: int = 8, backoff_factor: int = 2):
attempt = 0
while attempt < max_retries:
try:
return get_s3_bytes(s3_client, pdf_s3_path)
except ClientError as e:
            # For certain error codes (AccessDenied, NoSuchKey), raise immediately instead of retrying
if e.response["Error"]["Code"] in ("AccessDenied", "NoSuchKey"):
logger.error(f"{e.response['Error']['Code']} error when trying to access {pdf_s3_path}: {e}")
raise
else:
wait_time = backoff_factor**attempt
logger.warning(f"Attempt {attempt+1} failed to get_s3_bytes for {pdf_s3_path}: {e}. Retrying in {wait_time} seconds...")
time.sleep(wait_time)
attempt += 1
except Exception as e:
wait_time = backoff_factor**attempt
logger.warning(f"Attempt {attempt+1} failed to get_s3_bytes for {pdf_s3_path}: {e}. Retrying in {wait_time} seconds...")
time.sleep(wait_time)
attempt += 1
logger.error(f"Failed to get_s3_bytes for {pdf_s3_path} after {max_retries} retries.")
raise Exception("Failed to get_s3_bytes after retries")
def put_s3_bytes(s3_client, s3_path: str, data: bytes):
bucket, key = parse_s3_path(s3_path)
s3_client.put_object(Bucket=bucket, Key=key, Body=data, ContentType="text/plain; charset=utf-8")
def parse_custom_id(custom_id: str) -> tuple[str, int]:
s3_path = custom_id[: custom_id.rindex("-")]
page_num = int(custom_id[custom_id.rindex("-") + 1 :])
return s3_path, page_num
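# Example (illustrative): parse_custom_id("s3://bucket/doc.pdf-12") -> ("s3://bucket/doc.pdf", 12);
# everything before the final dash is the s3 path and the trailing number is the page.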
def download_zstd_csv(s3_client, s3_path):
"""Download and decompress a .zstd CSV file from S3."""
try:
compressed_data = get_s3_bytes(s3_client, s3_path)
dctx = zstd.ZstdDecompressor()
decompressed = dctx.decompress(compressed_data)
text_stream = TextIOWrapper(BytesIO(decompressed), encoding="utf-8")
lines = text_stream.readlines()
logger.info(f"Downloaded and decompressed {s3_path}")
return lines
except s3_client.exceptions.NoSuchKey:
logger.info(f"No existing {s3_path} found in s3, starting fresh.")
return []
def upload_zstd_csv(s3_client, s3_path, lines):
"""Compress and upload a list of lines as a .zstd CSV file to S3."""
joined_text = "\n".join(lines)
compressor = zstd.ZstdCompressor()
compressed = compressor.compress(joined_text.encode("utf-8"))
put_s3_bytes(s3_client, s3_path, compressed)
logger.info(f"Uploaded compressed {s3_path}")
def is_running_on_gcp():
"""Check if the script is running on a Google Cloud Platform (GCP) instance."""
try:
# GCP metadata server URL to check instance information
response = requests.get(
"http://metadata.google.internal/computeMetadata/v1/instance/", headers={"Metadata-Flavor": "Google"}, timeout=1 # Set a short timeout
)
return response.status_code == 200
except requests.RequestException:
return False
def download_directory(model_choices: List[str], local_dir: str):
"""
Download the model to a specified local directory.
The function will attempt to download from the first available source in the provided list.
Supports Weka (weka://), Google Cloud Storage (gs://), and Amazon S3 (s3://) links.
Args:
model_choices (List[str]): List of model paths (weka://, gs://, or s3://).
local_dir (str): Local directory path where the model will be downloaded.
Raises:
ValueError: If no valid model path is found in the provided choices.
"""
local_path = Path(os.path.expanduser(local_dir))
local_path.mkdir(parents=True, exist_ok=True)
logger.info(f"Local directory set to: {local_path}")
# Reorder model_choices to prioritize weka:// links
weka_choices = [path for path in model_choices if path.startswith("weka://")]
    # Hack: skip weka:// sources when running on the pluto/augusta Beaker clusters
if os.environ.get("BEAKER_NODE_HOSTNAME", "").lower().startswith("pluto") or os.environ.get("BEAKER_NODE_HOSTNAME", "").lower().startswith("augusta"):
weka_choices = []
other_choices = [path for path in model_choices if not path.startswith("weka://")]
prioritized_choices = weka_choices + other_choices
for model_path in prioritized_choices:
logger.info(f"Attempting to download from: {model_path}")
try:
if model_path.startswith("weka://"):
download_dir_from_storage(model_path, str(local_path), storage_type="weka")
logger.info(f"Successfully downloaded model from Weka: {model_path}")
return
elif model_path.startswith("gs://"):
download_dir_from_storage(model_path, str(local_path), storage_type="gcs")
logger.info(f"Successfully downloaded model from Google Cloud Storage: {model_path}")
return
elif model_path.startswith("s3://"):
download_dir_from_storage(model_path, str(local_path), storage_type="s3")
logger.info(f"Successfully downloaded model from S3: {model_path}")
return
else:
logger.warning(f"Unsupported model path scheme: {model_path}")
except Exception as e:
logger.error(f"Failed to download from {model_path}: {e}")
continue
raise ValueError("Failed to download the model from all provided sources.")
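# Minimal usage sketch (bucket names are hypothetical): weka:// sources are tried first unless
# the Beaker host checks above disable them, then the remaining gs:// and s3:// choices.
#     download_directory(
#         [
#             "weka://example-weka-bucket/models/checkpoint-10000/",
#             "s3://example-bucket/models/checkpoint-10000/",
#         ],
#         "~/.cache/example_model",
#     )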
def download_dir_from_storage(storage_path: str, local_dir: str, storage_type: str):
"""
Generalized function to download model files from different storage services
to a local directory, syncing using MD5 hashes where possible.
Args:
storage_path (str): The path to the storage location (weka://, gs://, or s3://).
local_dir (str): The local directory where files will be downloaded.
storage_type (str): Type of storage ('weka', 'gcs', or 's3').
Raises:
ValueError: If the storage type is unsupported or credentials are missing.
"""
bucket_name, prefix = parse_s3_path(storage_path)
total_files = 0
objects = []
if storage_type == "gcs":
client = storage.Client()
bucket = client.bucket(bucket_name)
blobs = list(bucket.list_blobs(prefix=prefix))
total_files = len(blobs)
logger.info(f"Found {total_files} files in GCS bucket '{bucket_name}' with prefix '{prefix}'.")
def should_download(blob, local_file_path):
return compare_hashes_gcs(blob, local_file_path)
def download_blob(blob, local_file_path):
try:
blob.download_to_filename(local_file_path)
logger.info(f"Successfully downloaded {blob.name} to {local_file_path}")
except Exception as e:
logger.error(f"Failed to download {blob.name} to {local_file_path}: {e}")
raise
items = blobs
elif storage_type in ("s3", "weka"):
if storage_type == "weka":
weka_access_key = os.getenv("WEKA_ACCESS_KEY_ID")
weka_secret_key = os.getenv("WEKA_SECRET_ACCESS_KEY")
if not weka_access_key or not weka_secret_key:
raise ValueError("WEKA_ACCESS_KEY_ID and WEKA_SECRET_ACCESS_KEY must be set for Weka access.")
endpoint_url = "https://weka-aus.beaker.org:9000"
boto3_config = Config(max_pool_connections=500, signature_version="s3v4", retries={"max_attempts": 10, "mode": "standard"})
s3_client = boto3.client(
"s3", endpoint_url=endpoint_url, aws_access_key_id=weka_access_key, aws_secret_access_key=weka_secret_key, config=boto3_config
)
else:
s3_client = boto3.client("s3", config=Config(max_pool_connections=500))
paginator = s3_client.get_paginator("list_objects_v2")
pages = paginator.paginate(Bucket=bucket_name, Prefix=prefix)
for page in pages:
if "Contents" in page:
objects.extend(page["Contents"])
else:
logger.warning(f"No contents found in page: {page}")
total_files = len(objects)
logger.info(f"Found {total_files} files in {'Weka' if storage_type == 'weka' else 'S3'} bucket '{bucket_name}' with prefix '{prefix}'.")
transfer_config = TransferConfig(
multipart_threshold=8 * 1024 * 1024, multipart_chunksize=8 * 1024 * 1024, max_concurrency=10, use_threads=True # Reduced for WekaFS compatibility
)
def should_download(obj, local_file_path):
return compare_hashes_s3(obj, local_file_path, storage_type)
def download_blob(obj, local_file_path):
logger.info(f"Starting download of {obj['Key']} to {local_file_path}")
try:
with open(local_file_path, "wb") as f:
s3_client.download_fileobj(bucket_name, obj["Key"], f, Config=transfer_config)
logger.info(f"Successfully downloaded {obj['Key']} to {local_file_path}")
except Exception as e:
logger.error(f"Failed to download {obj['Key']} to {local_file_path}: {e}")
raise
items = objects
else:
raise ValueError(f"Unsupported storage type: {storage_type}")
with concurrent.futures.ThreadPoolExecutor() as executor:
futures = []
for item in items:
if storage_type == "gcs":
relative_path = os.path.relpath(item.name, prefix)
else:
relative_path = os.path.relpath(item["Key"], prefix)
local_file_path = os.path.join(local_dir, relative_path)
os.makedirs(os.path.dirname(local_file_path), exist_ok=True)
if should_download(item, local_file_path):
futures.append(executor.submit(download_blob, item, local_file_path))
else:
total_files -= 1 # Decrement total_files as we're skipping this file
if total_files > 0:
for future in tqdm(concurrent.futures.as_completed(futures), total=total_files, desc=f"Downloading from {storage_type.upper()}"):
try:
future.result()
except Exception as e:
logger.error(f"Error occurred during download: {e}")
else:
logger.info("All files are up-to-date. No downloads needed.")
logger.info(f"Downloaded model from {storage_type.upper()} to {local_dir}")
def compare_hashes_gcs(blob, local_file_path: str) -> bool:
"""Compare MD5 hashes for GCS blobs."""
if os.path.exists(local_file_path):
remote_md5_base64 = blob.md5_hash
hash_md5 = hashlib.md5()
with open(local_file_path, "rb") as f:
for chunk in iter(lambda: f.read(8192), b""):
hash_md5.update(chunk)
local_md5 = hash_md5.digest()
remote_md5 = base64.b64decode(remote_md5_base64)
if remote_md5 == local_md5:
logger.info(f"File '{local_file_path}' already up-to-date. Skipping download.")
return False
else:
logger.info(f"File '{local_file_path}' differs from GCS. Downloading.")
return True
else:
logger.info(f"File '{local_file_path}' does not exist locally. Downloading.")
return True
def compare_hashes_s3(obj, local_file_path: str, storage_type: str) -> bool:
"""Compare MD5 hashes or sizes for S3 objects (including Weka)."""
if os.path.exists(local_file_path):
        if storage_type == "weka":
            # For Weka, skip the hash comparison and always re-download existing files
            return True
else:
etag = obj["ETag"].strip('"')
if "-" in etag:
# Multipart upload, compare sizes
remote_size = obj["Size"]
local_size = os.path.getsize(local_file_path)
if remote_size == local_size:
logger.info(f"File '{local_file_path}' size matches remote multipart file. Skipping download.")
return False
else:
logger.info(f"File '{local_file_path}' size differs from remote multipart file. Downloading.")
return True
else:
hash_md5 = hashlib.md5()
with open(local_file_path, "rb") as f:
for chunk in iter(lambda: f.read(8192), b""):
hash_md5.update(chunk)
local_md5 = hash_md5.hexdigest()
if etag == local_md5:
logger.info(f"File '{local_file_path}' already up-to-date. Skipping download.")
return False
else:
logger.info(f"File '{local_file_path}' differs from remote. Downloading.")
return True
else:
logger.info(f"File '{local_file_path}' does not exist locally. Downloading.")
return True
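# Why the multipart branch above compares sizes instead of hashes: an ETag containing "-" is
# not an MD5 of the whole object; it is typically the MD5 of the concatenated per-part MD5
# digests plus "-<part count>". Sketch of that computation, assuming the upload used a uniform
# part size such as the 8 MiB multipart_chunksize configured in TransferConfig above:
def _example_multipart_etag(local_file_path: str, part_size: int = 8 * 1024 * 1024) -> str:
    part_digests = b""
    num_parts = 0
    with open(local_file_path, "rb") as f:
        for chunk in iter(lambda: f.read(part_size), b""):
            part_digests += hashlib.md5(chunk).digest()
            num_parts += 1
    return f"{hashlib.md5(part_digests).hexdigest()}-{num_parts}"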
model:
name_or_path: allenai/Molmo-7B-O-0924
arch: causal
use_flash_attn: true
wandb:
project: pdelfin
entity: ai2-llm
generate:
max_length: 8192
train_data:
seed: 1337
cache_location: /data/jakep/pdfdata/pdelfin_cache
sources:
- name: openai_batch_data_v5_1_train
response_glob_path: /data/jakep/pdfdata/openai_batch_data_v5_1_train_done/*.json
target_longest_image_dim: [1024]
target_anchor_text_len: [6000]
- name: openai_batch_data_v5_1_iabooks_train
response_glob_path: /data/jakep/pdfdata/openai_batch_data_v5_1_iabooks_train_done/*.json
target_longest_image_dim: [1024]
target_anchor_text_len: [6000]
valid_data:
cache_location: /data/jakep/pdfdata/pdelfin_cache
metric_for_best_model: openai_batch_data_v5_1_eval_loss
sources:
    # These tend to be small, so loading them directly from S3 is fine
- name: openai_batch_data_v5_1_eval
response_glob_path: s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_eval/*.json
target_longest_image_dim: [1024]
target_anchor_text_len: [6000]
- name: openai_batch_data_v5_1_iabooks_eval
response_glob_path: s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_iabooks_eval/*.json
target_longest_image_dim: [1024]
target_anchor_text_len: [6000]
# Mostly pulled from https://github.com/QwenLM/Qwen2/blob/main/examples/sft/finetune.sh
hparams:
batch_size: 1
eval_batch_size: 1
gradient_accumulation_steps: 4
gradient_checkpointing: true
find_unused_parameters: true
clip_grad_norm: 1.0
learning_rate: 3e-4
max_steps: 10000
pad_multiple_of: 16
log_every_steps: 10
eval_every_steps: 100
optim: adamw_torch
lr_scheduler: cosine
weight_decay: 0.01
warmup_ratio: 0.03
# From https://github.com/QwenLM/Qwen2/blob/main/examples/sft/finetune.py
lora:
rank: 32
alpha: 32
dropout: 0.05
task_type: CAUSAL_LM
target_modules:
# attention layers in main transformer
- att_proj
- ff_proj
- attn_out
- ff_out
# vision transformer attention and FF
- attention.wq
- attention.wk
- attention.wv
- attention.wo
- feed_forward.w1
- feed_forward.w2
# vision image projector
- vision_backbone.image_projector.w1
- vision_backbone.image_projector.w2
- vision_backbone.image_projector.w3
save:
path: s3://ai2-oe-data/jakep/experiments/molmo-o-0924/v1/models/
save_every_steps: 1000
max_workers: 10
model:
name_or_path: allenai/Molmo-7B-O-0924
arch: causal
use_flash_attn: true
wandb:
project: pdelfin
entity: ai2-llm
generate:
max_length: 4096
train_data:
seed: 1337
cache_location: /data/jakep/pdfdata/pdelfin_cache
sources:
- name: openai_batch_data_v5_1_train
response_glob_path: /data/jakep/pdfdata/openai_batch_data_v5_1_train_done/*.json
target_longest_image_dim: [1024]
target_anchor_text_len: [6000]
- name: openai_batch_data_v5_1_iabooks_train
response_glob_path: /data/jakep/pdfdata/openai_batch_data_v5_1_iabooks_train_done/*.json
target_longest_image_dim: [1024]
target_anchor_text_len: [6000]
valid_data:
cache_location: /data/jakep/pdfdata/pdelfin_cache
metric_for_best_model: openai_batch_data_v5_1_eval_loss
sources:
    # These tend to be small, so loading them directly from S3 is fine
- name: openai_batch_data_v5_1_eval
response_glob_path: s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_eval/*.json
target_longest_image_dim: [1024]
target_anchor_text_len: [6000]
# Mostly pulled from https://github.com/QwenLM/Qwen2/blob/main/examples/sft/finetune.sh
hparams:
batch_size: 1
eval_batch_size: 1
gradient_accumulation_steps: 4
gradient_checkpointing: true
find_unused_parameters: true
clip_grad_norm: 1.0
learning_rate: 1e-4
max_steps: 10000
pad_multiple_of: 16
log_every_steps: 10
eval_every_steps: 100
optim: adamw_torch
lr_scheduler: cosine
weight_decay: 0.01
warmup_ratio: 0.03
# From https://github.com/QwenLM/Qwen2/blob/main/examples/sft/finetune.py
lora:
rank: 32
alpha: 32
dropout: 0.05
task_type: CAUSAL_LM
target_modules:
# attention layers in main transformer
- att_proj
- ff_proj
- attn_out
- ff_out
# vision transformer attention and FF
- attention.wq
- attention.wk
- attention.wv
- attention.wo
- feed_forward.w1
- feed_forward.w2
# vision image projector
- vision_backbone.image_projector.w1
- vision_backbone.image_projector.w2
- vision_backbone.image_projector.w3
save:
path: s3://ai2-oe-data/jakep/experiments/molmo-o-0924/v1/models/
save_every_steps: 1000
max_workers: 10
model:
name_or_path: Qwen/Qwen2-VL-2B-Instruct
arch: causal
use_flash_attn: true
wandb:
project: pdelfin
entity: ai2-llm
# TODO This is not used
format:
instruction_template: "Original:"
response_template: "Rewritten:"
# Template from here: https://github.com/QwenLM/Qwen2/blob/main/examples/sft/finetune.py#L30
chat_template: |
{% for message in messages %}
{{'<|im_start|>' + message['role'] + '\n' + message['content']}}
{% if loop.last %}
{{ '<|im_end|>'}}
{% else %}
{{ '<|im_end|>\n' }}
{% endif %}
{% endfor %}
generate:
max_length: 4096
train_data:
seed: 1337
sources:
- name: openai_batch_data_v2
query_glob_path: s3://ai2-oe-data/jakep/openai_batch_data_v2/*.jsonl
response_glob_path: s3://ai2-oe-data/jakep/openai_batch_done_v2/*.json
backend:
- openai
size: 100_000
valid_data:
sources:
- name: openai_batch_data_eval_mini
query_glob_path: s3://ai2-oe-data/jakep/openai_batch_data_eval_mini/*.jsonl
response_glob_path: s3://ai2-oe-data/jakep/openai_batch_done_eval_mini/*.json
backend:
- openai
size: 100_000
# Mostly pulled from https://github.com/QwenLM/Qwen2/blob/main/examples/sft/finetune.sh
hparams:
batch_size: 1
eval_batch_size: 1
gradient_accumulation_steps: 4
gradient_checkpointing: false
clip_grad_norm: 1.0
learning_rate: 3e-4
max_steps: 2000
pad_multiple_of: 16
log_every_steps: 50
eval_every_steps: 1000
optim: adamw_torch
lr_scheduler: cosine
weight_decay: 0.01
warmup_ratio: 0.03
# From https://github.com/QwenLM/Qwen2/blob/main/examples/sft/finetune.py
lora:
rank: 32
alpha: 32
dropout: 0.05
task_type: causal_lm
target_modules:
- q_proj
- k_proj
- v_proj
- o_proj
- gate_proj
- up_proj
- down_proj
- visual.blocks.[0-9]+.attn.qkv
- visual.blocks.[0-9]+.attn.proj
- visual.blocks.[0-9]+.mlp.fc1
- visual.blocks.[0-9]+.mlp.fc2
- visual.merger.mlp.0
- visual.merger.mlp.2
save:
path: s3://ai2-oe-data/jakep/experiments/qwen2vl-pdf/v1/models/
save_every_steps: 1000
max_workers: 10
model:
name_or_path: Qwen/Qwen2-VL-2B-Instruct
arch: causal
use_flash_attn: true
wandb:
project: pdelfin
entity: ai2-llm
# TODO This is not used
format:
instruction_template: "Original:"
response_template: "Rewritten:"
# Template from here: https://github.com/QwenLM/Qwen2/blob/main/examples/sft/finetune.py#L30
chat_template: |
{% for message in messages %}
{{'<|im_start|>' + message['role'] + '\n' + message['content']}}
{% if loop.last %}
{{ '<|im_end|>'}}
{% else %}
{{ '<|im_end|>\n' }}
{% endif %}
{% endfor %}
generate:
max_length: 4096
train_data:
seed: 1337
sources:
- name: openai_batch_data_v2
query_glob_path: s3://ai2-oe-data/jakep/openai_batch_data_v2/*.jsonl
response_glob_path: s3://ai2-oe-data/jakep/openai_batch_done_v2/*.json
backend:
- openai
size: 100_000
valid_data:
sources:
- name: openai_batch_data_eval_mini
query_glob_path: s3://ai2-oe-data/jakep/openai_batch_data_eval_mini/*.jsonl
response_glob_path: s3://ai2-oe-data/jakep/openai_batch_done_eval_mini/*.json
backend:
- openai
size: 100_000
# Mostly pulled from https://github.com/QwenLM/Qwen2/blob/main/examples/sft/finetune.sh
hparams:
batch_size: 1
eval_batch_size: 1
gradient_accumulation_steps: 4
gradient_checkpointing: false
clip_grad_norm: 1.0
learning_rate: 3e-4
max_steps: 2000
pad_multiple_of: 16
log_every_steps: 50
eval_every_steps: 1000
optim: adamw_torch
lr_scheduler: cosine
weight_decay: 0.01
warmup_ratio: 0.03
# From https://github.com/QwenLM/Qwen2/blob/main/examples/sft/finetune.py
# Disable LORA for now, because we want the visual network to get trained too
# lora:
# rank: 32
# alpha: 32
# dropout: 0.05
# task_type: causal_lm
# target_modules:
# - q_proj
# - k_proj
# - v_proj
# - o_proj
# - gate_proj
# - up_proj
# - down_proj
save:
path: s3://ai2-oe-data/jakep/experiments/qwen2vl-pdf/v1/models/
save_every_steps: 1000
max_workers: 10