//import React from 'react';
const PermitGuidelinesTemplate = () => {
// Sample data - you can replace these with your own
const guidelineItems = [
{
number: 'iii.',
content: 'Not rely on personal preference or opinion, or regional interpretation of statute, regulation or guidance that is inconsistent with the Department\'s statewide interpretation. Staff should confer with the appropriate Bureau Director as necessary.'
},
{
number: 'iv.',
content: 'Process technically adequate and scientifically sound applications for final approval to minimize elapsed time in accordance with the Permit Decision Guarantee.'
},
{
number: 'v.',
content: 'Where the Application Manager determines that the technical information submitted with the application does not meet technical guidance or standards published by the Department, the application must provide the scientific or engineering basis to support the application. Note that deviations from technical guidance can generally be approved, by the appropriate section chief and manager, when warranted, provided acceptable justification has been submitted. Minor deficiencies that can be easily corrected should be addressed through a telephone call with the applicant and consultant, and may negate the need for a deficiency letter. The Program Manager or District Manager will be responsible for making that decision.'
},
{
number: 'vi.',
content: 'If an application fails to provide the technical information necessary to document that applicable regulatory and statutory requirements will be achieved, it is technically deficient and the Application Manager will prepare a technical deficiency letter. Again, all deficiencies noted must cite the statutory or regulatory obligation that the application has failed to meet and the Section Chief and the Program Manager will routinely review these letters. For District Oil and Gas Offices and District Mining Offices the Permits Chief and the Manager will review the letters.'
},
{
number: 'vii.',
content: 'Applicant responses that do not make the application technically adequate within the established response timeframe will be subject to the Elevated Review Process below. Applications that are made technically adequate within the established response timeframe will proceed to processing for final action.'
}
];
// Footnote data
const footnote = {
number: '2',
content: 'More technically complex projects and applications may receive additional deficiency letters as appropriate prior to a decision point. This exception will not void inclusion in the Permit Decision Guarantee and will follow program specific guidance that is developed. The more technically complex projects and applications are noted with an asterisk ("*") in Appendix A.'
};
// Document info
const documentInfo = "021-2100-001 / November 2, 2012 / Page 11";
// Special note about technical deficiency letter
const technicalDeficiencyNote = {
prefix: 'One',
superscript: '2',
content: ' technical deficiency letter will be sent. Each deficiency cited must note the statute, regulation or technical guidance provision. Technical guidance provides a means to compliance, but may not be used or cited when issuing a permit denial. The letter will state, as necessary, that the Permit Decision Guarantee is no longer applicable and offer the applicant an opportunity to meet and discuss the deficiencies. The letter will include a deadline for submission of the deficient information.'
};
return (
<div className="bg-white p-8 max-w-4xl mx-auto font-serif text-black">
<div className="mb-8">
{guidelineItems.map((item, index) => (
<div key={index} className="mb-6 flex">
<div className="w-12 flex-shrink-0 font-bold">{item.number}</div>
<div className="flex-grow">{item.content}</div>
</div>
))}
{/* Technical deficiency letter note */}
<div className="mb-6 ml-12">
<p>
{technicalDeficiencyNote.prefix}
<sup>{technicalDeficiencyNote.superscript}</sup>
{technicalDeficiencyNote.content}
</p>
</div>
</div>
{/* Horizontal line */}
<div className="border-t border-gray-400 my-6"></div>
{/* Footnote section */}
<div className="text-sm">
<p>
<sup>{footnote.number}</sup> {footnote.content}
</p>
</div>
{/* Document info */}
<div className="text-center mt-6 text-sm">
{documentInfo}
</div>
</div>
);
};
//export default PermitGuidelinesTemplate;
window.BookPageTemplate = PermitGuidelinesTemplate;
import json
import re
import numpy as np
from bs4 import BeautifulSoup
from dataclasses import asdict, dataclass
from enum import Enum
from typing import List, Optional, Tuple, Dict, Any
from fuzzysearch import find_near_matches
from rapidfuzz import fuzz
from tqdm import tqdm
from olmocr.repeatdetect import RepeatDetector
from .katex.render import render_equation, compare_rendered_equations
class TestType(str, Enum):
BASELINE = "baseline"
PRESENT = "present"
ABSENT = "absent"
ORDER = "order"
TABLE = "table"
MATH = "math"
class TestChecked(str, Enum):
VERIFIED = "verified"
REJECTED = "rejected"
class ValidationError(Exception):
"""Exception raised for validation errors."""
pass
def normalize_text(md_content: str) -> str:
# Normalize whitespace in the md_content
md_content = re.sub(r'\s+', " ", md_content)
# Dictionary of characters to replace: keys are fancy characters, values are ASCII equivalents
replacements = {
"‘": "'", "’": "'", "‚": "'",
"“": "\"", "”": "\"", "„": "\"",
"_": "_",
"–": "-", "—": "-", "‑": "-", "‒": "-"
}
# Apply all replacements from the dictionary
for fancy_char, ascii_char in replacements.items():
md_content = md_content.replace(fancy_char, ascii_char)
return md_content
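# Illustrative example (added for clarity, not part of the original module):
#   normalize_text('“Smart  quotes” — and\tdashes')  ->  '"Smart quotes" - and dashes'
# i.e. runs of whitespace collapse to a single space and curly quotes / long dashes become ASCII.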
@dataclass(kw_only=True)
class BasePDFTest:
"""
Base class for all PDF test types.
Attributes:
pdf: The PDF filename.
page: The page number for the test.
id: Unique identifier for the test.
type: The type of test.
max_diffs: The maximum number of character differences allowed when fuzzy matching.
checked: Optional manual verification status for the test.
"""
pdf: str
page: int
id: str
type: str
max_diffs: int = 0
checked: Optional[TestChecked] = None
def __post_init__(self):
if not self.pdf:
raise ValidationError("PDF filename cannot be empty")
if not self.id:
raise ValidationError("Test ID cannot be empty")
if not isinstance(self.max_diffs, int) or self.max_diffs < 0:
raise ValidationError("max_diffs must be a non-negative integer")
if self.type not in {t.value for t in TestType}:
raise ValidationError(f"Invalid test type: {self.type}")
def run(self, md_content: str) -> Tuple[bool, str]:
"""
Run the test on the provided markdown content.
Args:
md_content: The content of the .md file.
Returns:
A tuple (passed, explanation) where 'passed' is True if the test passes,
and 'explanation' provides details when the test fails.
"""
raise NotImplementedError("Subclasses must implement the run method")
@dataclass
class TextPresenceTest(BasePDFTest):
"""
Test to verify the presence or absence of specific text in a PDF.
Attributes:
text: The text string to search for.
case_sensitive: Whether the search is case sensitive (default True).
"""
text: str
case_sensitive: bool = True
def __post_init__(self):
super().__post_init__()
if self.type not in {TestType.PRESENT.value, TestType.ABSENT.value}:
raise ValidationError(f"Invalid type for TextPresenceTest: {self.type}")
if not self.text.strip():
raise ValidationError("Text field cannot be empty")
def run(self, md_content: str) -> Tuple[bool, str]:
reference_query = self.text
# Normalize whitespace in the md_content
md_content = normalize_text(md_content)
if not self.case_sensitive:
reference_query = reference_query.lower()
md_content = md_content.lower()
# Threshold for fuzzy matching derived from max_diffs
threshold = 1.0 - (self.max_diffs / (len(reference_query) if len(reference_query) > 0 else 1))
best_ratio = fuzz.partial_ratio(reference_query, md_content) / 100.0
if self.type == TestType.PRESENT.value:
if best_ratio >= threshold:
return True, ""
else:
msg = f"Expected '{reference_query[:40]}...' with threshold {threshold} " f"but best match ratio was {best_ratio:.3f}"
return False, msg
else: # ABSENT
if best_ratio < threshold:
return True, ""
else:
msg = f"Expected absence of '{reference_query[:40]}...' with threshold {threshold} " f"but best match ratio was {best_ratio:.3f}"
return False, msg
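# Worked example (added for clarity): for a PRESENT test whose text is 20 characters long with
# max_diffs=2, threshold = 1.0 - 2/20 = 0.9, so the test passes when fuzz.partial_ratio finds a
# substring of md_content that matches the query at 90% similarity or better.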
@dataclass
class TextOrderTest(BasePDFTest):
"""
Test to verify that one text appears before another in a PDF.
Attributes:
before: The text expected to appear first.
after: The text expected to appear after the 'before' text.
"""
before: str
after: str
def __post_init__(self):
super().__post_init__()
if self.type != TestType.ORDER.value:
raise ValidationError(f"Invalid type for TextOrderTest: {self.type}")
if not self.before.strip():
raise ValidationError("Before field cannot be empty")
if not self.after.strip():
raise ValidationError("After field cannot be empty")
def run(self, md_content: str) -> Tuple[bool, str]:
md_content = normalize_text(md_content)
before_matches = find_near_matches(self.before, md_content, max_l_dist=self.max_diffs)
after_matches = find_near_matches(self.after, md_content, max_l_dist=self.max_diffs)
if not before_matches:
return False, f"'before' text '{self.before[:40]}...' not found with max_l_dist {self.max_diffs}"
if not after_matches:
return False, f"'after' text '{self.after[:40]}...' not found with max_l_dist {self.max_diffs}"
for before_match in before_matches:
for after_match in after_matches:
if before_match.start < after_match.start:
return True, ""
return False, (f"Could not find a location where '{self.before[:40]}...' appears before " f"'{self.after[:40]}...'.")
@dataclass
class TableTest(BasePDFTest):
"""
Test to verify that certain properties of a table hold, namely that some cells appear in the correct position relative to other cells
"""
# This is the target cell, which must exist in at least one place in the table
cell: str
# These properties say that the cell immediately up/down/left/right of the target cell has the string specified
up: str = ""
down: str = ""
left: str = ""
right: str = ""
# These properties say that the heading cell of the target cell's column (top_heading) or row (left_heading), i.e. the first non-empty cell from the top or left, has the string value specified
top_heading: str = ""
left_heading: str = ""
def __post_init__(self):
super().__post_init__()
if self.type != TestType.TABLE.value:
raise ValidationError(f"Invalid type for TableTest: {self.type}")
def parse_markdown_tables(self, md_content: str) -> List[np.ndarray]:
"""
Extract and parse all markdown tables from the provided content.
Args:
md_content: The markdown content containing tables
Returns:
A list of numpy arrays, each representing a parsed table
"""
# Updated regex to allow optional leading and trailing pipes
table_pattern = (
r'(\|?(?:[^|\n]*\|)+[^|\n]*\|?)\s*\n'
r'\|?(?:[ :-]+\|)+[ :-]+\|?\s*\n'
r'((?:\|?(?:[^|\n]*\|)+[^|\n]*\|?\s*\n)+)'
)
table_matches = re.finditer(table_pattern, md_content)
parsed_tables = []
for table_match in table_matches:
# Extract header and body from the table match
header_row = table_match.group(1).strip()
body_rows = table_match.group(2).strip().split('\n')
# Process header and rows to remove leading/trailing pipes
header_cells = [cell.strip() for cell in header_row.split('|')]
if header_cells and header_cells[0] == '':
header_cells = header_cells[1:]
if header_cells and header_cells[-1] == '':
header_cells = header_cells[:-1]
# Process table body rows
table_data = []
for row in [header_row] + body_rows:
if '|' not in row: # Skip separator row
continue
cells = [cell.strip() for cell in row.split('|')]
if cells and cells[0] == '':
cells = cells[1:]
if cells and cells[-1] == '':
cells = cells[:-1]
table_data.append(cells)
# Skip separator row (second row with dashes)
if len(table_data) > 1 and all('-' in cell for cell in table_data[1]):
table_data = [table_data[0]] + table_data[2:]
# Convert to numpy array for easier manipulation
# Ensure all rows have the same number of columns by padding if necessary
max_cols = max(len(row) for row in table_data)
padded_data = [row + [''] * (max_cols - len(row)) for row in table_data]
table_array = np.array(padded_data)
parsed_tables.append(table_array)
return parsed_tables
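# Illustrative example (added for clarity, not part of the original code): given the markdown
#   | Name | Value |
#   |------|-------|
#   | a    | 1     |
# this method returns a single 2x2 numpy array [['Name', 'Value'], ['a', '1']], with the
# separator row excluded and leading/trailing pipes stripped from every row.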
def parse_html_tables(self, html_content: str) -> List[np.ndarray]:
"""
Extract and parse all HTML tables from the provided content.
Args:
html_content: The HTML content containing tables
Returns:
A list of numpy arrays, each representing a parsed table
"""
soup = BeautifulSoup(html_content, 'html.parser')
tables = soup.find_all('table')
parsed_tables = []
for table in tables:
rows = table.find_all(['tr'])
table_data = []
for row in rows:
cells = row.find_all(['th', 'td'])
row_data = [cell.get_text().strip() for cell in cells]
table_data.append(row_data)
# Ensure all rows have the same number of columns
if table_data:
max_cols = max(len(row) for row in table_data)
padded_data = [row + [''] * (max_cols - len(row)) for row in table_data]
table_array = np.array(padded_data)
parsed_tables.append(table_array)
return parsed_tables
def run(self, content: str) -> Tuple[bool, str]:
"""
Run the table test on provided content.
Finds all tables (markdown and/or HTML based on content_type) and checks if any cell
matches the target cell and satisfies the specified relationships.
Args:
content: The content containing tables (markdown or HTML)
Returns:
A tuple (passed, explanation) where 'passed' is True if the test passes,
and 'explanation' provides details when the test fails.
"""
# Initialize variables to track tables and results
tables_to_check = []
failed_reasons = []
# Threshold for fuzzy matching derived from max_diffs
threshold = 1.0 - (self.max_diffs / (len(self.cell) if len(self.cell) > 0 else 1))
# Parse tables based on content_type
md_tables = self.parse_markdown_tables(content)
tables_to_check.extend(md_tables)
html_tables = self.parse_html_tables(content)
tables_to_check.extend(html_tables)
# If no tables found, return failure
if not tables_to_check:
return False, f"No tables found in the content"
# Check each table
for table_array in tables_to_check:
# Find all cells that match the target cell using fuzzy matching
matches = []
for i in range(table_array.shape[0]):
for j in range(table_array.shape[1]):
cell_content = table_array[i, j]
similarity = fuzz.ratio(self.cell, cell_content) / 100.0
if similarity >= threshold:
matches.append((i, j))
# If no matches found in this table, continue to the next table
if not matches:
continue
# Check the relationships for each matching cell
for row_idx, col_idx in matches:
all_relationships_satisfied = True
current_failed_reasons = []
# Check up relationship
if self.up and row_idx > 0:
up_cell = table_array[row_idx - 1, col_idx]
up_similarity = fuzz.ratio(self.up, up_cell) / 100.0
if up_similarity < threshold:
all_relationships_satisfied = False
current_failed_reasons.append(f"Cell above '{up_cell}' doesn't match expected '{self.up}' (similarity: {up_similarity:.2f})")
# Check down relationship
if self.down and row_idx < table_array.shape[0] - 1:
down_cell = table_array[row_idx + 1, col_idx]
down_similarity = fuzz.ratio(self.down, down_cell) / 100.0
if down_similarity < threshold:
all_relationships_satisfied = False
current_failed_reasons.append(f"Cell below '{down_cell}' doesn't match expected '{self.down}' (similarity: {down_similarity:.2f})")
# Check left relationship
if self.left and col_idx > 0:
left_cell = table_array[row_idx, col_idx - 1]
left_similarity = fuzz.ratio(self.left, left_cell) / 100.0
if left_similarity < threshold:
all_relationships_satisfied = False
current_failed_reasons.append(f"Cell to the left '{left_cell}' doesn't match expected '{self.left}' (similarity: {left_similarity:.2f})")
# Check right relationship
if self.right and col_idx < table_array.shape[1] - 1:
right_cell = table_array[row_idx, col_idx + 1]
right_similarity = fuzz.ratio(self.right, right_cell) / 100.0
if right_similarity < threshold:
all_relationships_satisfied = False
current_failed_reasons.append(f"Cell to the right '{right_cell}' doesn't match expected '{self.right}' (similarity: {right_similarity:.2f})")
# Check top heading relationship
if self.top_heading and row_idx > 0:
# Find the first non-empty cell in the same column (starting from the top)
top_heading_cell = ""
for i in range(row_idx):
if table_array[i, col_idx].strip():
top_heading_cell = table_array[i, col_idx]
break
if not top_heading_cell:
all_relationships_satisfied = False
current_failed_reasons.append(f"No non-empty top heading found in column {col_idx}")
else:
top_similarity = fuzz.ratio(self.top_heading, top_heading_cell) / 100.0
if top_similarity < threshold:
all_relationships_satisfied = False
current_failed_reasons.append(f"Top heading '{top_heading_cell}' doesn't match expected '{self.top_heading}' (similarity: {top_similarity:.2f})")
# Check left heading relationship
if self.left_heading and col_idx > 0:
# Find the first non-empty cell in the same row (starting from the left)
left_heading_cell = ""
for j in range(col_idx):
if table_array[row_idx, j].strip():
left_heading_cell = table_array[row_idx, j]
break
if not left_heading_cell:
all_relationships_satisfied = False
current_failed_reasons.append(f"No non-empty left heading found in row {row_idx}")
else:
left_heading_similarity = fuzz.ratio(self.left_heading, left_heading_cell) / 100.0
if left_heading_similarity < threshold:
all_relationships_satisfied = False
current_failed_reasons.append(f"Left heading '{left_heading_cell}' doesn't match expected '{self.left_heading}' (similarity: {left_heading_similarity:.2f})")
# If all relationships are satisfied for this cell, the test passes
if all_relationships_satisfied:
return True, ""
else:
failed_reasons.extend(current_failed_reasons)
# If we've gone through all tables and all matching cells and none satisfied all relationships
if not failed_reasons:
return False, f"No cell matching '{self.cell}' found in any table with threshold {threshold}"
else:
return False, f"Found cells matching '{self.cell}' but relationships were not satisfied: {'; '.join(failed_reasons)}"
@dataclass
class BaselineTest(BasePDFTest):
"""
This test makes sure that several baseline quality checks pass for the output generation.
Namely, the output is not blank, not endlessly repeating, and contains characters of the proper
character sets.
"""
max_repeats: int = 30
def run(self, content: str) -> Tuple[bool, str]:
if len("".join(c for c in content if c.isalnum()).strip()) == 0:
return False, "The text contains no alphanumeric characters"
# Makes sure that the content has no egregious repeated ngrams at the end, which indicate a degradation of quality
# Honestly, this check doesn't seem to catch anything at the moment; it could be refactored into a
# broader "text-quality" test that measures repetition, non-blank content, character sets, etc.
d = RepeatDetector(max_ngram_size=5)
d.add_letters(content)
repeats = d.ngram_repeats()
for index, count in enumerate(repeats):
if count > self.max_repeats:
return False, f"Text ends with {count} repeating {index+1}-grams, invalid"
pattern = re.compile(
r'['
r'\u4e00-\u9FFF' # CJK Unified Ideographs (Chinese characters)
r'\u3040-\u309F' # Hiragana (Japanese)
r'\u30A0-\u30FF' # Katakana (Japanese)
r'\U0001F600-\U0001F64F' # Emoticons (Emoji)
r'\U0001F300-\U0001F5FF' # Miscellaneous Symbols and Pictographs (Emoji)
r'\U0001F680-\U0001F6FF' # Transport and Map Symbols (Emoji)
r'\U0001F1E0-\U0001F1FF' # Regional Indicator Symbols (flags, Emoji)
r']',
flags=re.UNICODE)
matches = pattern.findall(content)
if matches:
return False, f"Text contains disallowed characters {matches}"
return True, ""
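# Note (added for clarity): the character-class pattern above means that any CJK ideographs,
# Japanese kana, or emoji in the output will cause the baseline test to fail with a
# "disallowed characters" error.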
@dataclass
class MathTest(BasePDFTest):
math: str
def __post_init__(self):
super().__post_init__()
if self.type != TestType.MATH.value:
raise ValidationError(f"Invalid type for MathTest: {self.type}")
if len(self.math.strip()) == 0:
raise ValidationError("Math test must have a non-empty math expression")
self.reference_render = render_equation(self.math)
if self.reference_render is None:
raise ValidationError(f"Math equation {self.math} was not able to render")
def run(self, content: str) -> Tuple[bool, str]:
# Store both the search pattern and the full pattern to replace
patterns = [
(r'\$\$(.+?)\$\$', r'\$\$(.+?)\$\$'), # $$...$$
(r'\\\((.+?)\\\)', r'\\\((.+?)\\\)'), # \(...\)
(r'\\\[(.+?)\\\]', r'\\\[(.+?)\\\]'), # \[...\]
(r'\$(.+?)\$', r'\$(.+?)\$') # $...$
]
equations = []
modified_content = content
for search_pattern, replace_pattern in patterns:
# Find all matches for the current pattern
matches = re.findall(search_pattern, modified_content, re.DOTALL)
equations.extend([e.strip() for e in matches])
# Replace all instances of this pattern with empty strings
modified_content = re.sub(replace_pattern, '', modified_content, flags=re.DOTALL)
# If an equation in the markdown exactly matches our math string, then that's good enough
# we don't have to do a more expensive comparison
if any(hyp == self.math for hyp in equations):
return True, ""
# If not, then render the reference equation and compare it against a render of each hypothesis
for hypothesis in equations:
hypothesis_render = render_equation(hypothesis)
if not hypothesis_render:
continue
if compare_rendered_equations(self.reference_render, hypothesis_render):
return True, ""
return False, f"No match found for {self.math} anywhere in content"
def load_tests(jsonl_file: str) -> List[BasePDFTest]:
"""
Load tests from a JSONL file.
Args:
jsonl_file: Path to the JSONL file containing test definitions.
Returns:
A list of test objects.
"""
tests: List[BasePDFTest] = []
unique_ids = set()
with open(jsonl_file, "r") as file:
for line_number, line in tqdm(enumerate(file, start=1), desc="Loading tests"):
line = line.strip()
if not line:
continue
try:
data = json.loads(line)
test_type = data.get("type")
if test_type in {TestType.PRESENT.value, TestType.ABSENT.value}:
test = TextPresenceTest(**data)
elif test_type == TestType.ORDER.value:
test = TextOrderTest(**data)
elif test_type == TestType.TABLE.value:
test = TableTest(**data)
elif test_type == TestType.MATH.value:
test = MathTest(**data)
else:
raise ValidationError(f"Unknown test type: {test_type}")
if test.id in unique_ids:
raise ValidationError(f"Test with duplicate id {test.id} found, error loading tests.")
else:
unique_ids.add(test.id)
tests.append(test)
except json.JSONDecodeError as e:
print(f"Error parsing JSON on line {line_number}: {e}")
raise
except (ValidationError, KeyError) as e:
print(f"Error on line {line_number}: {e}")
raise
except Exception as e:
print(f"Unexpected error on line {line_number}: {e}")
raise
return tests
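# Illustrative JSONL line accepted by load_tests (field values here are made up):
#   {"pdf": "doc.pdf", "page": 1, "id": "doc_present_01", "type": "present", "text": "Hello world", "max_diffs": 2}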
def save_tests(tests: List[BasePDFTest], jsonl_file: str) -> None:
"""
Save tests to a JSONL file using asdict for conversion.
Args:
tests: A list of test objects.
jsonl_file: Path to the output JSONL file.
"""
with open(jsonl_file, "w") as file:
for test in tests:
file.write(json.dumps(asdict(test)) + "\n")
import numpy as np
from typing import Dict, List, Tuple, Optional
def calculate_bootstrap_ci(
test_scores: List[float],
n_bootstrap: int = 1000,
ci_level: float = 0.95
) -> Tuple[float, float]:
"""
Calculate bootstrap confidence interval for test scores.
Args:
test_scores: List of test scores (0.0 to 1.0 for each test)
n_bootstrap: Number of bootstrap samples to generate
ci_level: Confidence interval level (default: 0.95 for 95% CI)
Returns:
Tuple of (lower_bound, upper_bound) representing the confidence interval
"""
if not test_scores:
return (0.0, 0.0)
# Convert to numpy array for efficiency
scores = np.array(test_scores)
# Generate bootstrap samples
bootstrap_means = []
for _ in range(n_bootstrap):
# Sample with replacement
sample = np.random.choice(scores, size=len(scores), replace=True)
bootstrap_means.append(np.mean(sample))
# Calculate confidence interval
alpha = (1 - ci_level) / 2
lower_bound = np.percentile(bootstrap_means, alpha * 100)
upper_bound = np.percentile(bootstrap_means, (1 - alpha) * 100)
return (lower_bound, upper_bound)
def perform_permutation_test(
scores_a: List[float],
scores_b: List[float],
n_permutations: int = 10000
) -> Tuple[float, float]:
"""
Perform a permutation test to determine if there's a significant difference
between two sets of test scores.
Args:
scores_a: List of test scores for candidate A
scores_b: List of test scores for candidate B
n_permutations: Number of permutations to perform
Returns:
Tuple of (observed_difference, p_value)
"""
if not scores_a or not scores_b:
return (0.0, 1.0)
# Calculate observed difference in means
observed_diff = np.mean(scores_a) - np.mean(scores_b)
# Combine all scores
combined = np.concatenate([scores_a, scores_b])
n_a = len(scores_a)
n_combined = len(combined)
# Perform permutation test
count_greater_or_equal = 0
for _ in range(n_permutations):
# Shuffle the combined array
np.random.shuffle(combined)
# Split into two groups of original sizes
perm_a = combined[:n_a]
perm_b = combined[n_a:]
# Calculate difference in means
perm_diff = np.mean(perm_a) - np.mean(perm_b)
# Count how many permuted differences are >= to observed difference in absolute value
if abs(perm_diff) >= abs(observed_diff):
count_greater_or_equal += 1
# Calculate p-value
p_value = count_greater_or_equal / n_permutations
return (observed_diff, p_value)
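# Illustrative usage (added for clarity; the score lists below are made-up examples):
#   scores_a = [1.0, 0.0, 1.0, 1.0, 1.0]
#   scores_b = [0.0, 1.0, 0.0, 0.0, 1.0]
#   low, high = calculate_bootstrap_ci(scores_a)            # 95% CI for the mean score of A
#   diff, p = perform_permutation_test(scores_a, scores_b)  # observed mean difference and p-value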
#!/usr/bin/env python3
import argparse
import json
import os
import re
import sys
from collections import defaultdict
from urllib.parse import parse_qs, urlencode, urlsplit, urlunsplit
import requests
from olmocr.data.renderpdf import render_pdf_to_base64png
def parse_rules_file(file_path):
"""Parse the rules file and organize rules by PDF."""
pdf_rules = defaultdict(list)
with open(file_path, "r") as f:
for line in f:
line = line.strip()
if not line:
continue
try:
rule = json.loads(line)
# Add checked field if it doesn't exist
if "checked" not in rule:
rule["checked"] = None
if "pdf" in rule:
pdf_rules[rule["pdf"]].append(rule)
except json.JSONDecodeError:
print(f"Warning: Could not parse line as JSON: {line}")
return pdf_rules
def get_rule_html(rule, rule_index):
"""Generate HTML representation for a rule with interactive elements."""
rule_type = rule.get("type", "unknown")
rule_id = f"rule-{rule_index}"
# Determine status button class based on 'checked' value
checked_status = rule.get("checked")
thumbs_up_class = "active" if checked_status == "verified" else ""
thumbs_down_class = "active" if checked_status == "rejected" else ""
# Create thumbs up/down buttons
status_button = f"""
<div class="status-control">
<button class="status-button thumbs-up {thumbs_up_class}"
data-rule-id="{rule_id}"
data-action="verified"
onclick="toggleStatus(this)"></button>
<button class="status-button thumbs-down {thumbs_down_class}"
data-rule-id="{rule_id}"
data-action="rejected"
onclick="toggleStatus(this)"></button>
</div>
"""
# Create HTML based on rule type
if rule_type == "present":
return f"""
<tr class="rule-row present-rule" data-rule-id="{rule_id}" data-rule-index="{rule_index}">
<td>{status_button}</td>
<td><span class="rule-type present">PRESENT</span></td>
<td>
<div class="editable-text"
contenteditable="true"
data-rule-id="{rule_id}"
data-field="text"
onblur="updateRuleText(this)">{rule.get('text', '')}</div>
</td>
<td>Threshold: {rule.get('threshold', 'N/A')}</td>
</tr>
"""
elif rule_type == "absent":
return f"""
<tr class="rule-row absent-rule" data-rule-id="{rule_id}" data-rule-index="{rule_index}">
<td>{status_button}</td>
<td><span class="rule-type absent">ABSENT</span></td>
<td>
<div class="editable-text"
contenteditable="true"
data-rule-id="{rule_id}"
data-field="text"
onblur="updateRuleText(this)">{rule.get('text', '')}</div>
</td>
<td>Threshold: {rule.get('threshold', 'N/A')}</td>
</tr>
"""
elif rule_type == "order":
return f"""
<tr class="rule-row order-rule" data-rule-id="{rule_id}" data-rule-index="{rule_index}">
<td>{status_button}</td>
<td><span class="rule-type order">ORDER</span></td>
<td>
<p><strong>Before:</strong>
<span class="editable-text"
contenteditable="true"
data-rule-id="{rule_id}"
data-field="before"
onblur="updateRuleText(this)">{rule.get('before', '')}</span>
</p>
<p><strong>After:</strong>
<span class="editable-text"
contenteditable="true"
data-rule-id="{rule_id}"
data-field="after"
onblur="updateRuleText(this)">{rule.get('after', '')}</span>
</p>
</td>
<td>Threshold: {rule.get('threshold', 'N/A')}</td>
</tr>
"""
else:
return f"""
<tr class="rule-row unknown-rule" data-rule-id="{rule_id}" data-rule-index="{rule_index}">
<td>{status_button}</td>
<td><span class="rule-type unknown">UNKNOWN</span></td>
<td>Unknown rule type: {rule_type}</td>
<td></td>
</tr>
"""
def generate_html(pdf_rules, rules_file_path):
"""Generate the HTML page with PDF renderings and interactive rules."""
# Limit to 10 unique PDFs
pdf_names = list(pdf_rules.keys())[:10]
# Prepare rules data for JavaScript
all_rules = []
for pdf_name in pdf_names:
all_rules.extend(pdf_rules[pdf_name])
rules_json = json.dumps(all_rules)
html = """
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Interactive PDF Rules Visualizer</title>
<style>
body {
font-family: Arial, sans-serif;
margin: 0;
padding: 20px;
background-color: #f5f5f5;
}
.container {
max-width: 1920px;
margin: 0 auto;
}
h1 {
color: #333;
text-align: center;
margin-bottom: 30px;
}
.pdf-container {
background-color: white;
border-radius: 8px;
box-shadow: 0 2px 10px rgba(0, 0, 0, 0.1);
margin-bottom: 30px;
overflow: hidden;
}
.pdf-header {
background-color: #4a6fa5;
color: white;
padding: 15px;
font-size: 18px;
font-weight: bold;
}
.pdf-content {
display: flex;
flex-direction: row;
padding: 20px;
}
@media (max-width: 1200px) {
.pdf-content {
flex-direction: column;
}
}
.pdf-image {
flex: 0 0 50%;
max-width: 800px;
text-align: center;
padding-right: 20px;
}
.pdf-image img {
max-width: 100%;
height: auto;
border: 1px solid #ddd;
}
.rules-container {
flex: 1;
overflow: auto;
}
.rules-table {
width: 100%;
border-collapse: collapse;
}
.rules-table th {
background-color: #4a6fa5;
color: white;
padding: 10px;
text-align: left;
}
.rules-table td {
padding: 10px;
border-bottom: 1px solid #ddd;
vertical-align: top;
}
.rule-type {
display: inline-block;
padding: 5px 10px;
border-radius: 4px;
color: white;
font-weight: bold;
}
.present {
background-color: #28a745;
}
.absent {
background-color: #dc3545;
}
.order {
background-color: #fd7e14;
}
.unknown {
background-color: #6c757d;
}
.rule-row:hover {
background-color: #f8f9fa;
}
/* New styles for interactive elements */
.editable-text {
min-height: 20px;
padding: 5px;
border-radius: 4px;
border: 1px solid transparent;
transition: border-color 0.2s;
}
.editable-text:hover {
border-color: #ccc;
background-color: #f8f9fa;
}
.editable-text:focus {
outline: none;
border-color: #4a6fa5;
background-color: #fff;
}
.status-control {
display: flex;
justify-content: center;
align-items: center;
gap: 8px;
}
.status-button {
width: 36px;
height: 36px;
border-radius: 4px;
border: 1px solid #ccc;
background-color: #f8f9fa;
cursor: pointer;
transition: all 0.2s;
display: flex;
justify-content: center;
align-items: center;
}
.status-button:hover {
border-color: #999;
background-color: #e9ecef;
}
.thumbs-up:before {
content: "👍";
font-size: 18px;
opacity: 0.5;
}
.thumbs-down:before {
content: "👎";
font-size: 18px;
opacity: 0.5;
}
.thumbs-up.active {
background-color: #28a745;
border-color: #28a745;
}
.thumbs-up.active:before {
opacity: 1;
color: white;
}
.thumbs-down.active {
background-color: #dc3545;
border-color: #dc3545;
}
.thumbs-down.active:before {
opacity: 1;
color: white;
}
</style>
</head>
<body>
<div class="container">
<h1>Interactive PDF Rules Visualizer</h1>
"""
# Global rule index for unique IDs
rule_index = 0
for pdf_name in pdf_names:
rules = pdf_rules[pdf_name]
# Render the PDF (first page only) from the /pdfs folder
try:
pdf_path = os.path.join(os.path.dirname(rules_file_path), "pdfs", pdf_name)
base64_img = render_pdf_to_base64png(pdf_path, 0)
img_html = f'<img src="data:image/png;base64,{base64_img}" alt="{pdf_name}">'
except Exception as e:
img_html = f'<div class="error">Error rendering PDF: {str(e)}</div>'
html += f"""
<div class="pdf-container">
<div class="pdf-header">{pdf_name}</div>
<div class="pdf-content">
<div class="pdf-image">
{img_html}
</div>
<div class="rules-container">
<table class="rules-table">
<thead>
<tr>
<th>Status</th>
<th>Type</th>
<th>Content</th>
<th>Parameters</th>
</tr>
</thead>
<tbody>
"""
for rule in rules:
html += get_rule_html(rule, rule_index)
rule_index += 1
html += """
</tbody>
</table>
</div>
</div>
</div>
"""
# Add JavaScript to manage interactivity and datastore integration
html += f"""
</div>
<script>
// Store all rules data (initially injected from the JSON file)
let rulesData = {rules_json};
// Function to toggle status button
function toggleStatus(button) {{
const ruleRow = button.closest('.rule-row');
const ruleIndex = parseInt(ruleRow.dataset.ruleIndex);
const action = button.dataset.action;
const currentState = rulesData[ruleIndex].checked;
const newState = (currentState === action) ? null : action;
rulesData[ruleIndex].checked = newState;
// Update UI for status buttons
const buttons = ruleRow.querySelectorAll('.status-button');
buttons.forEach(btn => {{
if (btn.dataset.action === newState) {{
btn.classList.add('active');
}} else {{
btn.classList.remove('active');
}}
}});
// Upload updated data to datastore
uploadRulesData();
outputJSON();
}}
// Function to update rule text
function updateRuleText(element) {{
const ruleRow = element.closest('.rule-row');
const ruleIndex = parseInt(ruleRow.dataset.ruleIndex);
const field = element.dataset.field;
const newText = element.innerText.trim();
// Update the rules data
rulesData[ruleIndex][field] = newText;
// Upload updated data to datastore
uploadRulesData();
outputJSON();
}}
// Function to output JSONL to console
function outputJSON() {{
console.clear();
console.log("Updated JSONL:");
rulesData.forEach(rule => {{
console.log(JSON.stringify(rule));
}});
}}
// Function to upload rulesData to datastore using putDatastore
async function uploadRulesData() {{
try {{
await putDatastore(rulesData);
console.log("Datastore updated successfully");
}} catch (error) {{
console.error("Failed to update datastore", error);
}}
}}
// Function to update UI from rulesData (used after fetching datastore state)
function updateUIFromRulesData() {{
document.querySelectorAll('.rule-row').forEach(ruleRow => {{
const ruleIndex = parseInt(ruleRow.dataset.ruleIndex);
const rule = rulesData[ruleIndex];
// Update status buttons
const buttons = ruleRow.querySelectorAll('.status-button');
buttons.forEach(btn => {{
if (btn.dataset.action === rule.checked) {{
btn.classList.add('active');
}} else {{
btn.classList.remove('active');
}}
}});
// Update editable text fields
ruleRow.querySelectorAll('.editable-text').forEach(div => {{
const field = div.dataset.field;
if (rule[field] !== undefined) {{
div.innerText = rule[field];
}}
}});
}});
}}
// On page load, fetch data from the datastore and update UI accordingly
document.addEventListener('DOMContentLoaded', async function() {{
try {{
const datastoreState = await fetchDatastore();
if (datastoreState.length) {{
rulesData = datastoreState;
updateUIFromRulesData();
outputJSON();
}}
}} catch (error) {{
console.error("Error fetching datastore", error);
}}
}});
</script>
</body>
</html>
"""
return html
def get_page_datastore(html: str):
"""
Fetch the JSON datastore from the presigned URL.
Returns a dict. If any error or no content, returns {}.
"""
match = re.search(r"const presignedGetUrl = \"(.*?)\";", html)
if not match:
return None
presigned_url = match.group(1)
try:
# Clean up the presigned URL (sometimes the signature may need re-encoding)
url_parts = urlsplit(presigned_url)
query_params = parse_qs(url_parts.query)
encoded_query = urlencode(query_params, doseq=True)
cleaned_url = urlunsplit((url_parts.scheme, url_parts.netloc, url_parts.path, encoded_query, url_parts.fragment))
resp = requests.get(cleaned_url)
resp.raise_for_status()
return resp.json()
except Exception as e:
print(f"Error fetching datastore from {presigned_url}: {e}")
return None
def main():
parser = argparse.ArgumentParser(description="Generate an interactive HTML visualization of PDF rules.")
parser.add_argument("rules_file", help="Path to the rules file (JSON lines format)")
parser.add_argument("-o", "--output", help="Output HTML file path", default="interactive_pdf_rules.html")
args = parser.parse_args()
if not os.path.exists(args.rules_file):
print(f"Error: Rules file not found: {args.rules_file}")
sys.exit(1)
if os.path.exists(args.output):
print(f"Output file {args.output} already exists, attempting to reload it's datastore")
with open(args.output, "r") as df:
datastore = get_page_datastore(df.read())
if datastore is None:
print(f"Datastore for {args.output} is empty, please run tinyhost and verify your rules and then rerun the script")
sys.exit(1)
print(f"Loaded {len(datastore)} entries from datastore, updating {args.rules_file}")
with open(args.rules_file, "w") as of:
for rule in datastore:
of.write(json.dumps(rule) + "\n")
return
pdf_rules = parse_rules_file(args.rules_file)
html = generate_html(pdf_rules, args.rules_file)
with open(args.output, "w") as f:
f.write(html)
print(f"Interactive HTML visualization created: {args.output}")
if __name__ == "__main__":
main()
import importlib.util
import logging
import subprocess
import sys
logger = logging.getLogger(__name__)
def check_poppler_version():
try:
result = subprocess.run(["pdftoppm", "-h"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
if result.returncode == 0 and result.stderr.startswith("pdftoppm"):
logger.info("pdftoppm is installed and working.")
else:
logger.error("pdftoppm is installed but returned an error.")
sys.exit(1)
except FileNotFoundError:
logger.error("pdftoppm is not installed.")
logger.error("Check the README at https://github.com/allenai/olmocr/blob/main/README.md for installation instructions")
sys.exit(1)
def check_sglang_version():
if importlib.util.find_spec("sglang") is None:
logger.error("Please make sure sglang is installed according to the latest instructions here: https://docs.sglang.ai/start/install.html")
logger.error("Sglang needs to be installed with a separate command in order to find all dependencies properly.")
sys.exit(1)
def check_torch_gpu_available(min_gpu_memory: int = 20 * 1024**3):
try:
import torch
except ImportError:
logger.error("Pytorch must be installed, visit https://pytorch.org/ for installation instructions")
raise
try:
gpu_memory = torch.cuda.get_device_properties(0).total_memory
assert gpu_memory >= min_gpu_memory
except Exception:
logger.error(f"Torch was not able to find a GPU with at least {min_gpu_memory // (1024 ** 3)} GB of GPU memory.")
raise
if __name__ == "__main__":
check_poppler_version()
check_sglang_version()
import argparse
import glob
import json
import os
import random
from concurrent.futures import ProcessPoolExecutor, as_completed
from typing import List
from urllib.parse import urlparse
import boto3
from pypdf import PdfReader
from tqdm import tqdm
from olmocr.data.renderpdf import render_pdf_to_base64png
from olmocr.filter import PdfFilter
from olmocr.prompts import (
build_openai_silver_data_prompt,
openai_response_format_schema,
)
from olmocr.prompts.anchor import get_anchor_text
TARGET_IMAGE_DIM = 2048
pdf_filter = PdfFilter()
def build_page_query(local_pdf_path: str, pretty_pdf_path: str, page: int) -> dict:
image_base64 = render_pdf_to_base64png(local_pdf_path, page, TARGET_IMAGE_DIM)
anchor_text = get_anchor_text(local_pdf_path, page, pdf_engine="pdfreport")
# DEBUG: temporary code that performs the API call live, kept here for debugging
# from openai import OpenAI
# client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
# response = client.chat.completions.create(
# model="gpt-4o-2024-08-06",
# messages= [
# {
# "role": "user",
# "content": [
# {"type": "text", "text": build_openai_silver_data_prompt(anchor_text)},
# {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_base64}"}}
# ],
# }
# ],
# temperature=0.1,
# max_tokens=3000,
# logprobs=True,
# top_logprobs=5,
# response_format=openai_response_format_schema()
# )
# print(response)
# Construct the OpenAI Batch API request format.
# There are a few tricks to know when doing data processing with OpenAI's APIs:
# - Use the batch query system: it is half the price for exactly the same performance.
# - Use structured outputs. If your application is not an actual chatbot, use structured outputs!
#   Even if the last 10 queries you ran with the regular chat API returned exactly what you wanted
#   without extra "LLM fluff text", that doesn't mean it will hold across thousands of queries.
# - Structured outputs also let you cheat: the model answers fields in the order they appear in your
#   schema, so you can have it answer some preparatory or chain-of-thought style questions first
#   before getting to the meat of the response, which gives better answers.
# - Check your prompt for typos; it makes a performance difference!
# - Ask for logprobs: they cost nothing extra and can help identify problematic responses later.
return {
"custom_id": f"{pretty_pdf_path}-{page}",
"method": "POST",
"url": "/v1/chat/completions",
"body": {
"model": "gpt-4o-2024-08-06",
"messages": [
{
"role": "user",
"content": [
{"type": "text", "text": build_openai_silver_data_prompt(anchor_text)},
{"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_base64}"}},
],
}
],
"temperature": 0.1,
"max_tokens": 6000,
"logprobs": True,
"top_logprobs": 5,
"response_format": openai_response_format_schema(),
},
}
def sample_pdf_pages(num_pages: int, first_n_pages: int, max_sample_pages: int) -> list:
if num_pages <= first_n_pages:
return list(range(1, num_pages + 1)) # Return all pages if fewer than first_n_pages
sample_pages = list(range(1, first_n_pages + 1)) # Always get the first_n_pages
remaining_pages = list(range(first_n_pages + 1, num_pages + 1))
if remaining_pages:
sample_pages += random.sample(remaining_pages, min(max_sample_pages - first_n_pages, len(remaining_pages)))
return sample_pages
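# Illustrative example (added for clarity): with num_pages=20, first_n_pages=2 and
# max_sample_pages=5, this returns pages [1, 2] plus 3 distinct pages drawn at random from 3..20.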
def fetch_s3_file(s3_url: str, local_path: str) -> str:
parsed = urlparse(s3_url)
bucket_name = parsed.netloc
key = parsed.path.lstrip("/")
s3 = boto3.client("s3")
s3.download_file(bucket_name, key, local_path)
return local_path
def process_pdf(pdf_path: str, first_n_pages: int, max_sample_pages: int, no_filter: bool) -> List[dict]:
if pdf_path.startswith("s3://"):
local_pdf_path = os.path.join("/tmp", os.path.basename(pdf_path))
fetch_s3_file(pdf_path, local_pdf_path)
else:
local_pdf_path = pdf_path
if (not no_filter) and pdf_filter.filter_out_pdf(local_pdf_path):
print(f"Skipping {local_pdf_path} due to common filter")
return []
pretty_pdf_path = pdf_path
pdf = PdfReader(local_pdf_path)
num_pages = len(pdf.pages)
sample_pages = sample_pdf_pages(num_pages, first_n_pages, max_sample_pages)
result = []
for page in sample_pages:
try:
query = build_page_query(local_pdf_path, pretty_pdf_path, page)
result.append(query)
except Exception as e:
print(f"Error processing page {page} of {pdf_path}: {e}")
return result
def main():
parser = argparse.ArgumentParser(description="Sample PDFs and create requests for GPT-4o.")
parser.add_argument("--glob_path", type=str, help="Local or S3 path glob (e.g., *.pdf or s3://bucket/pdfs/*.pdf).")
parser.add_argument("--path_list", type=str, help="Path to a file containing paths to PDFs, one per line.")
parser.add_argument("--no_filter", action="store_true", help="Disables the basic spam/language filtering so that ALL pdfs listed are used")
parser.add_argument("--num_sample_docs", type=int, default=5000, help="Number of PDF documents to sample.")
parser.add_argument("--first_n_pages", type=int, default=0, help="Always sample the first N pages of each PDF.")
parser.add_argument("--max_sample_pages", type=int, default=15, help="Max number of pages to sample per PDF.")
parser.add_argument("--output", type=str, default="openai_batch_data", help="Output destination")
parser.add_argument("--reservoir_size", type=int, default=None, help="Size of the reservoir for sampling paths. Defaults to 10x num_sample_docs.")
args = parser.parse_args()
# Set default reservoir_size if not provided
if args.reservoir_size is None:
args.reservoir_size = 10 * args.num_sample_docs
# Initialize reservoir sampling variables
pdf_paths = []
n = 0 # Total number of items seen
# Load PDF paths from glob or path_list using reservoir sampling
if args.glob_path:
if args.glob_path.startswith("s3://"):
# Handle S3 globbing using boto3 with pagination
parsed = urlparse(args.glob_path)
s3 = boto3.client("s3")
bucket_name = parsed.netloc
prefix = os.path.dirname(parsed.path.lstrip("/")) + "/"
paginator = s3.get_paginator("list_objects_v2")
page_iterator = paginator.paginate(Bucket=bucket_name, Prefix=prefix)
for page in page_iterator:
for obj in page.get("Contents", []):
if obj["Key"].endswith(".pdf"):
n += 1
path = f"s3://{bucket_name}/{obj['Key']}"
if len(pdf_paths) < args.reservoir_size:
pdf_paths.append(path)
else:
s = random.randint(1, n)
if s <= args.reservoir_size:
pdf_paths[s - 1] = path
else:
# Handle local globbing using glob.iglob()
for path in glob.iglob(args.glob_path, recursive=True):
n += 1
if len(pdf_paths) < args.reservoir_size:
pdf_paths.append(path)
else:
s = random.randint(1, n)
if s <= args.reservoir_size:
pdf_paths[s - 1] = path
elif args.path_list:
with open(args.path_list, "r") as f:
for line in f:
n += 1
path = line.strip()
if len(pdf_paths) < args.reservoir_size:
pdf_paths.append(path)
else:
s = random.randint(1, n)
if s <= args.reservoir_size:
pdf_paths[s - 1] = path
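# Note (added for clarity): the replacement logic above is standard reservoir sampling (Algorithm R);
# after n paths have been seen, each one is kept in pdf_paths with probability reservoir_size / n,
# so the reservoir is a uniform random sample of all paths without holding them all in memory.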
# Shuffle the reservoir
random.shuffle(pdf_paths)
print(f"Loaded and shuffled {len(pdf_paths)} paths to use.")
# Write the sampled requests out, rolling over to a new JSONL output file whenever the 99 MB size cap is reached
cur_file_num = 0
output_dir = args.output
max_file_size = 99 * 1024 * 1024 # 99MB in bytes
cur_file_size = 0
cur_file_path = os.path.join(output_dir, f"output_{cur_file_num}.jsonl")
# Ensure output directory exists
os.makedirs(output_dir, exist_ok=True)
# Open the first file for writing
cur_file = open(cur_file_path, "w")
# Counter to track PDFs that produce at least one output
pdfs_with_output = 0
# Using ProcessPoolExecutor to process files concurrently
with ProcessPoolExecutor() as executor:
futures = {}
with tqdm(desc="Processing PDFs", leave=False, total=args.num_sample_docs) as pb:
for pdf_path in pdf_paths:
futures[executor.submit(process_pdf, pdf_path, args.first_n_pages, args.max_sample_pages, args.no_filter)] = pdf_path
for future in as_completed(futures):
has_output = False # Track if the current PDF produces at least one request
try:
request_results = future.result() # Get the result from the worker process
for request_obj in request_results:
request_json = json.dumps(request_obj)
request_size = len(request_json.encode("utf-8")) # Calculate size in bytes
# Check if the current request can fit in the current file
if cur_file_size + request_size > max_file_size:
# Close the current file and create a new one
cur_file.close()
cur_file_num += 1
cur_file_path = os.path.join(output_dir, f"output_{cur_file_num}.jsonl")
cur_file = open(cur_file_path, "w")
cur_file_size = 0 # Reset file size
# Write the JSON entry to the file
cur_file.write(request_json)
cur_file.write("\n")
cur_file_size += request_size
has_output = True # At least one request object was generated
if has_output:
pdfs_with_output += 1
pb.update(1)
if pdfs_with_output >= args.num_sample_docs:
executor.shutdown(cancel_futures=True)
break
except Exception as e:
print(f"Error processing {pdf_path}: {str(e)}")
# Close the last open file
cur_file.close()
# Print or log the number of PDFs that resulted in at least one output
print(f"Number of sampled PDFs that produced at least one output: {pdfs_with_output}")
if __name__ == "__main__":
main()
import argparse
import csv
import json
import os
import random
import re
import sqlite3
from collections import Counter
from concurrent.futures import ProcessPoolExecutor, as_completed
from typing import Optional
from urllib.parse import urlparse
from tqdm import tqdm
def parse_pdf_hash(pretty_pdf_path: str) -> Optional[str]:
pattern = r"s3://ai2-s2-pdfs/([a-f0-9]{4})/([a-f0-9]+)\.pdf-\d+"
match = re.match(pattern, pretty_pdf_path)
if match:
return match.group(1) + match.group(2)
return None
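# Illustrative example (added for clarity; the hash below is made up): for a custom_id such as
#   s3://ai2-s2-pdfs/ab12/0123456789abcdef.pdf-3
# this returns "ab120123456789abcdef", the 4-character prefix folder concatenated with the file hash.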
def cache_athena_csv_to_db(athena_csv_path: str) -> str:
db_path = athena_csv_path + ".db"
if not os.path.exists(db_path):
conn = sqlite3.connect(db_path)
cursor = conn.cursor()
cursor.execute("PRAGMA synchronous = OFF;")
cursor.execute("PRAGMA journal_mode = MEMORY;")
cursor.execute(
"""
CREATE TABLE pdf_mapping (
pdf_hash TEXT PRIMARY KEY,
uri TEXT
)
"""
)
with open(athena_csv_path, "r", encoding="utf-8") as f:
reader = csv.DictReader(f)
batch = []
for row in tqdm(reader):
batch.append((row["distinct_pdf_hash"], row["uri"]))
if len(batch) == 1000:
cursor.executemany("INSERT INTO pdf_mapping (pdf_hash, uri) VALUES (?, ?)", batch)
conn.commit()
batch = []
if batch:
cursor.executemany("INSERT INTO pdf_mapping (pdf_hash, uri) VALUES (?, ?)", batch)
conn.commit()
conn.close()
return db_path
def get_uri_from_db(db_path: str, pdf_hash: str) -> Optional[str]:
conn = sqlite3.connect(db_path)
cursor = conn.cursor()
cursor.execute("SELECT uri FROM pdf_mapping WHERE pdf_hash = ?", (pdf_hash,))
result = cursor.fetchone()
conn.close()
return result[0] if result else None
def process_file(filepath, db_path):
results = []
with open(filepath, "r", encoding="utf-8") as f:
for line in f:
line = line.strip()
if not line:
continue
try:
data = json.loads(line)
except json.JSONDecodeError:
continue
custom_id = data.get("custom_id")
if not custom_id:
continue
pdf_hash = parse_pdf_hash(custom_id)
if not pdf_hash:
continue
uri = get_uri_from_db(db_path, pdf_hash)
domain = None
if uri:
parsed = urlparse(uri)
domain = parsed.netloc
results.append((custom_id, uri, domain))
return results
def main():
parser = argparse.ArgumentParser(
description="Review silver dataset and provide summary statistics based on source URL and also provide a few data samples for review."
)
parser.add_argument(
"--input",
type=str,
default="openai_batch_data",
help="Input folder, which is the output of the buildsilver.py script",
)
parser.add_argument(
"--output",
type=str,
default="openai_batch_data_summary",
help="Output destination (folder)",
)
parser.add_argument(
"--athena-csv",
type=str,
default="/home/ubuntu/s2pdf_url_data/c974870d-3b06-4793-9a62-d46d38e2c8b2.csv",
help="CSV file that maps pdf_hash to uri",
)
parser.add_argument(
"--sample-size",
type=int,
default=20,
help="How many sample rows to include in the sample CSV",
)
args = parser.parse_args()
db_path = cache_athena_csv_to_db(args.athena_csv)
all_rows = []
filepaths = [os.path.join(args.input, filename) for filename in os.listdir(args.input) if filename.endswith(".jsonl")]
with ProcessPoolExecutor() as executor:
future_to_file = {executor.submit(process_file, filepath, db_path): filepath for filepath in filepaths}
for future in tqdm(as_completed(future_to_file), total=len(filepaths)):
try:
results = future.result()
all_rows.extend(results)
except Exception as e:
print(f"Error processing file: {future_to_file[future]}\n{e}")
os.makedirs(args.output, exist_ok=True)
output_csv_path = os.path.join(args.output, "custom_id_to_url.csv")
with open(output_csv_path, "w", encoding="utf-8", newline="") as f:
writer = csv.writer(f)
writer.writerow(["custom_id", "uri", "domain"])
for cid, uri, domain in all_rows:
writer.writerow([cid, uri if uri else "", domain if domain else ""])
domain_counter: Counter[str] = Counter()
for _, _, domain in all_rows:
if domain:
domain_counter[domain] += 1
most_common_domains = domain_counter.most_common(1000)
domain_csv_path = os.path.join(args.output, "top_1000_domains.csv")
with open(domain_csv_path, "w", encoding="utf-8", newline="") as f:
writer = csv.writer(f)
writer.writerow(["domain", "count"])
for domain, count in most_common_domains:
writer.writerow([domain, count])
sample_size = min(args.sample_size, len(all_rows))
sample_rows = random.sample(all_rows, sample_size) if all_rows else []
sample_csv_path = os.path.join(args.output, "data_samples.csv")
with open(sample_csv_path, "w", encoding="utf-8", newline="") as f:
writer = csv.writer(f)
writer.writerow(["custom_id", "uri", "domain"])
for cid, uri, domain in sample_rows:
writer.writerow([cid, uri if uri else "", domain if domain else ""])
print(f"Summary files written to: {args.output}")
print(f" - Full mapping: {output_csv_path}")
print(f" - Top domains: {domain_csv_path}")
print(f" - Samples: {sample_csv_path}")
if __name__ == "__main__":
main()
import argparse
import base64
import glob
import os
import random
from concurrent.futures import ProcessPoolExecutor, as_completed
from typing import List
from urllib.parse import urlparse
import boto3
from pypdf import PdfReader, PdfWriter
from tqdm import tqdm
from olmocr.data.renderpdf import render_pdf_to_base64png
from olmocr.filter import PdfFilter
pdf_filter = PdfFilter()
def sample_pdf_pages(num_pages: int, first_n_pages: int, max_sample_pages: int) -> List[int]:
"""
Returns a list of sampled page indices (1-based).
- Always include the first_n_pages (or all pages if num_pages < first_n_pages).
- Randomly sample the remaining pages up to a total of max_sample_pages.
"""
if num_pages <= first_n_pages:
return list(range(1, num_pages + 1))
sample_pages = list(range(1, first_n_pages + 1))
remaining_pages = list(range(first_n_pages + 1, num_pages + 1))
if remaining_pages:
# How many random pages to pick beyond the first_n_pages
random_pick = min(max_sample_pages - first_n_pages, len(remaining_pages))
sample_pages += random.sample(remaining_pages, random_pick)
return sample_pages
def fetch_s3_file(s3_url: str, local_path: str) -> str:
"""
Download a file from an S3 URI (s3://bucket/key) to local_path.
"""
parsed = urlparse(s3_url)
bucket_name = parsed.netloc
key = parsed.path.lstrip("/")
s3 = boto3.client("s3")
s3.download_file(bucket_name, key, local_path)
return local_path
def extract_single_page_pdf(input_pdf_path: str, page_number: int, output_pdf_path: str) -> None:
"""
Extracts exactly one page (page_number, 1-based) from input_pdf_path
and writes to output_pdf_path.
"""
reader = PdfReader(input_pdf_path)
writer = PdfWriter()
# Page numbers in PdfReader are 0-based
writer.add_page(reader.pages[page_number - 1])
with open(output_pdf_path, "wb") as f:
writer.write(f)
def process_pdf(pdf_path: str, first_n_pages: int, max_sample_pages: int, no_filter: bool, output_dir: str):
"""
- Download the PDF locally if it's in S3.
- Optionally filter the PDF (if no_filter=False).
- Sample the pages.
- For each sampled page, extract a one-page PDF and also render it to PNG.
"""
if pdf_path.startswith("s3://"):
local_pdf_path = os.path.join("/tmp", os.path.basename(pdf_path))
fetch_s3_file(pdf_path, local_pdf_path)
else:
local_pdf_path = pdf_path
if (not no_filter) and pdf_filter.filter_out_pdf(local_pdf_path):
print(f"Skipping {local_pdf_path} due to filter.")
return False
# Make sure we have an absolute path for the PDF name
base_pdf_name = os.path.splitext(os.path.basename(pdf_path))[0]
reader = PdfReader(local_pdf_path)
num_pages = len(reader.pages)
sampled_pages = sample_pdf_pages(num_pages, first_n_pages, max_sample_pages)
# For each sampled page, produce a single-page PDF and a PNG
for page_num in sampled_pages:
single_pdf_name = f"{base_pdf_name}_page{page_num}.pdf"
single_png_name = f"{base_pdf_name}_page{page_num}.png"
single_pdf_path = os.path.join(output_dir, single_pdf_name)
single_png_path = os.path.join(output_dir, single_png_name)
try:
# 1) Extract single-page PDF
extract_single_page_pdf(local_pdf_path, page_num, single_pdf_path)
# 2) Render that single-page PDF to a PNG
b64png = render_pdf_to_base64png(single_pdf_path, page_num=0, target_longest_image_dim=1024)
with open(single_png_path, "wb") as pngf:
pngf.write(base64.b64decode(b64png))
except Exception as e:
print(f"Error while processing {pdf_path}, page {page_num}: {e}")
return True
def main():
parser = argparse.ArgumentParser(description="Sample PDFs, extract single-page PDFs, and render them as PNG.")
parser.add_argument("--glob_path", type=str, help="Local or S3 path glob (e.g., *.pdf or s3://bucket/pdfs/*.pdf).")
parser.add_argument("--path_list", type=str, help="Path to a file containing paths to PDFs, one per line.")
parser.add_argument("--no_filter", action="store_true", help="Disables filtering so that ALL PDFs are processed.")
parser.add_argument("--num_sample_docs", type=int, default=2000, help="Number of PDF documents to sample.")
parser.add_argument("--first_n_pages", type=int, default=0, help="Always sample the first N pages of each PDF.")
parser.add_argument("--max_sample_pages", type=int, default=1, help="Max number of pages to sample per PDF.")
parser.add_argument("--output_dir", type=str, default="sampled_pages_output", help="Output directory for the extracted PDFs and PNGs.")
parser.add_argument("--reservoir_size", type=int, default=None, help="Size of the reservoir for sampling paths. Defaults to 10x num_sample_docs.")
args = parser.parse_args()
# Set default reservoir_size if not provided
if args.reservoir_size is None:
args.reservoir_size = 10 * args.num_sample_docs
os.makedirs(args.output_dir, exist_ok=True)
# Reservoir sample for PDF paths
pdf_paths = []
n = 0 # total number of items seen
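# Reservoir sampling (Algorithm R): keep the first reservoir_size paths, then replace a
# random slot with probability reservoir_size / n for each subsequent path, so every path
# seen so far has an equal chance of ending up in the reservoir.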
# Either load from glob or from path_list
if args.glob_path:
if args.glob_path.startswith("s3://"):
# Handle S3 globbing
parsed = urlparse(args.glob_path)
s3 = boto3.client("s3")
bucket_name = parsed.netloc
prefix = os.path.dirname(parsed.path.lstrip("/")) + "/"
paginator = s3.get_paginator("list_objects_v2")
page_iterator = paginator.paginate(Bucket=bucket_name, Prefix=prefix)
for page in page_iterator:
for obj in page.get("Contents", []):
if obj["Key"].endswith(".pdf"):
n += 1
path = f"s3://{bucket_name}/{obj['Key']}"
if len(pdf_paths) < args.reservoir_size:
pdf_paths.append(path)
else:
s = random.randint(1, n)
if s <= args.reservoir_size:
pdf_paths[s - 1] = path
else:
# Handle local globbing
for path in glob.iglob(args.glob_path, recursive=True):
n += 1
if len(pdf_paths) < args.reservoir_size:
pdf_paths.append(path)
else:
s = random.randint(1, n)
if s <= args.reservoir_size:
pdf_paths[s - 1] = path
elif args.path_list:
with open(args.path_list, "r") as f:
for line in f:
path = line.strip()
if not path:
continue
n += 1
if len(pdf_paths) < args.reservoir_size:
pdf_paths.append(path)
else:
s = random.randint(1, n)
if s <= args.reservoir_size:
pdf_paths[s - 1] = path
# Shuffle the reservoir so we don't always pick from the front
random.shuffle(pdf_paths)
print(f"Loaded and shuffled {len(pdf_paths)} PDF paths. Will process up to {args.num_sample_docs} of them.")
pdfs_with_output = 0
# Use a ProcessPoolExecutor to parallelize PDF processing
# You may reduce max_workers if you have memory/CPU constraints
with ProcessPoolExecutor() as executor:
futures = {}
# Submit tasks
for pdf_path in pdf_paths:
future = executor.submit(process_pdf, pdf_path, args.first_n_pages, args.max_sample_pages, args.no_filter, args.output_dir)
futures[future] = pdf_path
# Track completion
for future in tqdm(as_completed(futures), total=len(futures), desc="Processing PDFs"):
if future.result():
pdfs_with_output += 1
if pdfs_with_output >= args.num_sample_docs:
# Cancel remaining tasks
executor.shutdown(cancel_futures=True)
break
print(f"Done. Processed or attempted to process {pdfs_with_output} PDFs. Output is in: {args.output_dir}")
if __name__ == "__main__":
main()
import argparse
import json
import logging
import os
import re
import sys
import tempfile
from concurrent.futures import ProcessPoolExecutor, as_completed
from pathlib import Path
import boto3
# Import Plotly for plotting
import plotly.express as px
import smart_open
from olmocr.data.renderpdf import render_pdf_to_base64png
from olmocr.prompts import build_finetuning_prompt
from olmocr.prompts.anchor import get_anchor_text
def setup_logging():
"""Configure logging for the script."""
logging.basicConfig(level=logging.INFO, format="[%(asctime)s] %(levelname)s: %(message)s", handlers=[logging.StreamHandler(sys.stdout)])
def is_s3_path(path):
"""Check if the given path is an S3 path."""
return str(path).startswith("s3://")
def download_pdf_from_s3(s3_path: str, pdf_profile: str) -> str:
"""
Downloads a PDF file from S3 to a temporary local file and returns the local file path.
Args:
s3_path (str): S3 path in the format s3://bucket/key
pdf_profile (str): The name of the boto3 profile to use.
Returns:
str: Path to the downloaded PDF file in the local filesystem.
"""
# Parse the bucket and key from the s3_path
# s3_path format: s3://bucket_name/some/folder/file.pdf
path_without_scheme = s3_path.split("s3://", 1)[1]
bucket_name, key = path_without_scheme.split("/", 1)
# Create a session with the specified profile or default
session = boto3.Session(profile_name=pdf_profile) if pdf_profile else boto3.Session()
s3_client = session.client("s3")
# Create a temporary local file
tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".pdf")
tmp_file.close()  # We only need the path; close the handle so it is not held open while boto3 writes to it
local_path = tmp_file.name
logging.info(f"Downloading PDF from {s3_path} to {local_path} using profile {pdf_profile}")
s3_client.download_file(bucket_name, key, local_path)
return local_path
def transform_json_object(obj):
"""
Transform a single JSON object by extracting and renaming specific fields.
Args:
obj (dict): Original JSON object.
Returns:
dict or None: Transformed JSON object, or None if there's an error.
"""
try:
transformed = {
"custom_id": obj["custom_id"],
"chat_messages": obj["body"]["messages"],
"temperature": obj["body"]["temperature"],
"max_tokens": obj["body"]["max_tokens"],
}
return transformed
except KeyError as e:
logging.error(f"Missing key {e} in object: {obj.get('custom_id', 'unknown')}")
return None
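# Illustrative example (values are placeholders): an input object like
# {"custom_id": "...", "body": {"messages": [...], "temperature": 0.1, "max_tokens": 3000}}
# becomes {"custom_id": "...", "chat_messages": [...], "temperature": 0.1, "max_tokens": 3000}.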
def process_file(input_file: str, output_file: str, rewrite_prompt_str: bool, pdf_profile: str):
"""
Process a single JSONL file: read, transform, and write to output.
Args:
input_file (str): Path or URL to the input JSONL file.
output_file (str): Path or URL to the output JSONL file.
rewrite_prompt_str (bool): Flag to rewrite the prompt string.
pdf_profile (str): Boto3 profile to use when fetching PDFs from S3.
"""
processed_count = 0
error_count = 0
prompt_lengths = []
try:
with smart_open.open(input_file, "r", encoding="utf-8") as infile, smart_open.open(output_file, "w", encoding="utf-8") as outfile:
for line_number, line in enumerate(infile, 1):
line = line.strip()
if not line:
continue # Skip empty lines
try:
obj = json.loads(line)
except json.JSONDecodeError as e:
logging.error(f"JSON decode error in file {input_file} at line {line_number}: {e}")
error_count += 1
continue
transformed = transform_json_object(obj)
if transformed is not None and rewrite_prompt_str:
# We look for RAW_TEXT_START ... RAW_TEXT_END in the existing content
pattern = r"RAW_TEXT_START\s*\n(.*?)\nRAW_TEXT_END"
match = re.search(pattern, transformed["chat_messages"][0]["content"][0]["text"], re.DOTALL)
if match:
# We found raw page text, but we'll attempt to regenerate it
goldkey = obj["custom_id"]
# goldkey might look like: "s3://bucket/path/to/file.pdf-23"
# s3_path = everything up to the last dash
# page = everything after the dash
try:
s3_path = goldkey[: goldkey.rindex("-")]
page = int(goldkey[goldkey.rindex("-") + 1 :])
except (ValueError, IndexError) as e:
logging.error(f"Could not parse the page number from custom_id {goldkey}: {e}")
error_count += 1
continue
# If the path is an S3 path, download to a local temp file; else assume local
if is_s3_path(s3_path):
local_pdf_path = download_pdf_from_s3(s3_path, pdf_profile)
else:
local_pdf_path = s3_path
# Recalculate the anchor text
raw_page_text = get_anchor_text(local_pdf_path, page, pdf_engine="pdfreport", target_length=6000)
image_base64 = render_pdf_to_base64png(local_pdf_path, page, 1024)
transformed["chat_messages"][0]["content"][0]["text"] = build_finetuning_prompt(raw_page_text)
transformed["chat_messages"][0]["content"][1]["image_url"]["url"] = f"data:image/png;base64,{image_base64}"
# Clean up the temp PDF file if it was downloaded
if is_s3_path(s3_path):
try:
os.remove(local_pdf_path)
except OSError as remove_err:
logging.error(f"Failed to remove temporary PDF file {local_pdf_path}: {remove_err}")
if transformed is not None:
prompt_text = transformed["chat_messages"][0]["content"][0]["text"]
prompt_length = len(prompt_text)
if prompt_length > 6000:
print(transformed["custom_id"], "length ", prompt_length)
prompt_lengths.append(prompt_length)
outfile.write(json.dumps(transformed) + "\n")
processed_count += 1
else:
error_count += 1
logging.info(f"Processed '{input_file}': {processed_count} records transformed, {error_count} errors.")
return prompt_lengths
except Exception as e:
logging.exception(e)
logging.error(f"Failed to process file {input_file}: {e}")
return []
def construct_output_file_path(input_file_path, input_dir, output_dir):
"""
Given an input file path, input directory, and output directory,
construct the corresponding output file path.
Args:
input_file_path (str): Path to the input file.
input_dir (str): Path to the input directory.
output_dir (str): Path to the output directory.
Returns:
str: Path to the output file.
"""
input_file = Path(input_file_path)
if is_s3_path(input_dir):
# For S3 paths, manually construct the relative path based on the input S3 path
input_prefix = input_dir.split("s3://")[1]
input_prefix = input_prefix.rstrip("*") # Remove any glob patterns like *.jsonl
# Remove the 's3://' part from input_file_path and extract the relative part
input_file_key = input_file_path.split("s3://")[1]
relative_path = input_file_key[len(input_prefix) :].lstrip("/")
# Construct the output S3 path by appending the relative part to the output S3 directory
output_file_path = output_dir.rstrip("/") + "/" + relative_path
else:
# For local paths, use the existing relative path logic
input_dir_path = Path(input_dir)
relative_path = input_file.relative_to(input_dir_path)
output_file_path = str(Path(output_dir) / relative_path)
return output_file_path
def list_input_files(input_dir):
"""
List all JSONL files in the input directory. If input_dir is an S3 path, handle
globbing manually by listing objects and filtering based on patterns.
Args:
input_dir (str): Path to the input directory or S3 URL.
Returns:
list: List of input file paths.
"""
if is_s3_path(input_dir):
import fnmatch
# Parse bucket and prefix
bucket_name = input_dir.split("s3://")[1].split("/")[0]
path_and_pattern = "/".join(input_dir.split("s3://")[1].split("/")[1:])
# Separate the prefix and pattern
if "/" in path_and_pattern:
prefix = path_and_pattern.rsplit("/", 1)[0] + "/"
pattern = path_and_pattern.rsplit("/", 1)[1]
else:
prefix = ""
pattern = path_and_pattern
# Use a Boto3 session (no specific PDF profile needed here if only listing)
session = boto3.Session()
s3 = session.resource("s3")
bucket = s3.Bucket(bucket_name)
files = []
for obj in bucket.objects.filter(Prefix=prefix):
if fnmatch.fnmatch(obj.key, f"{prefix}{pattern}"):
files.append(f"s3://{bucket_name}/{obj.key}")
return files
else:
input_dir_path = Path(input_dir)
return [str(p) for p in input_dir_path.glob("*.jsonl")]
def main():
setup_logging()
parser = argparse.ArgumentParser(description="Transform JSONL files by extracting and renaming specific fields.")
parser.add_argument(
"--rewrite_finetuning_prompt",
action="store_true",
default=True,
help="Rewrite the input prompt from a standard OPENAI instruction format into a finetuned format.",
)
parser.add_argument("input_dir", type=str, help="Path to the input directory containing JSONL files. Can be a local path or S3 URL.")
parser.add_argument("output_dir", type=str, help="Path to the output directory where transformed JSONL files will be saved. Can be a local path or S3 URL.")
parser.add_argument("--jobs", "-j", type=int, default=20, help="Number of parallel jobs to run (default: 20).")
parser.add_argument("--pdf_profile", type=str, default=None, help="Boto3 profile to use for downloading PDFs from S3. Defaults to the default session.")
args = parser.parse_args()
input_dir = args.input_dir.rstrip("/")
output_dir = args.output_dir.rstrip("/")
max_jobs = args.jobs
# List input files
input_files = list_input_files(input_dir)
if not input_files:
logging.warning(f"No JSONL files found in '{input_dir}'. Exiting.")
sys.exit(0)
logging.info(f"Found {len(input_files)} JSONL files to process.")
# Prepare tasks for parallel processing
tasks = []
for input_file in input_files:
output_file = construct_output_file_path(input_file, input_dir, output_dir)
tasks.append((input_file, output_file))
# Process files in parallel
all_prompt_lengths = []
with ProcessPoolExecutor(max_workers=max_jobs) as executor:
future_to_file = {
executor.submit(process_file, input_file, output_file, args.rewrite_finetuning_prompt, args.pdf_profile): input_file
for input_file, output_file in tasks
}
for future in as_completed(future_to_file):
input_file = future_to_file[future]
try:
prompt_lengths = future.result()
all_prompt_lengths.extend(prompt_lengths)
except Exception as exc:
logging.error(f"File {input_file} generated an exception: {exc}")
logging.info("All files have been processed.")
# Plot histogram of prompt lengths
if all_prompt_lengths:
fig = px.histogram(all_prompt_lengths, nbins=50, title="Histogram of Prompt Lengths")
fig.update_xaxes(title="Prompt Length")
fig.update_yaxes(title="Frequency")
try:
fig.write_image("prompt_lengths_histogram.png")
logging.info("Histogram of prompt lengths has been saved to 'prompt_lengths_histogram.png'.")
except Exception as e:
logging.error(f"Failed to save the histogram image: {e}")
logging.error("Please make sure that the 'kaleido' package is installed (pip install -U kaleido).")
fig.write_html("prompt_lengths_histogram.html")
logging.info("Histogram of prompt lengths has been saved to 'prompt_lengths_histogram.html'.")
else:
logging.warning("No prompt lengths were collected; histogram will not be generated.")
if __name__ == "__main__":
main()
import argparse
import json
import logging
import os
import re
import sys
from concurrent.futures import ProcessPoolExecutor, as_completed
from pathlib import Path
import smart_open
from cached_path import cached_path
def setup_logging():
"""Configure logging for the script."""
logging.basicConfig(level=logging.INFO, format="[%(asctime)s] %(levelname)s: %(message)s", handlers=[logging.StreamHandler(sys.stdout)])
def is_s3_path(path):
"""Check if the given path is an S3 path."""
return str(path).startswith("s3://")
def process_file(input_file: str, output_file: str, rewrite_prompt_str: bool):
"""
Process a single JSONL file: read, transform, and write to output.
Args:
input_file (str): Path or URL to the input JSONL file.
output_file (str): Path or URL to the output JSONL file.
rewrite_prompt_str (bool): If True, rebuild each prompt by re-rendering the source PDF and recomputing the anchor text.
"""
processed_count = 0
error_count = 0
try:
with smart_open.open(input_file, "r", encoding="utf-8") as infile, smart_open.open(output_file, "w", encoding="utf-8") as outfile:
for line_number, line in enumerate(infile, 1):
line = line.strip()
if not line:
continue # Skip empty lines
try:
obj = json.loads(line)
except json.JSONDecodeError as e:
logging.error(f"JSON decode error in file {input_file} at line {line_number}: {e}")
error_count += 1
continue
if obj is not None and rewrite_prompt_str:
pattern = r"RAW_TEXT_START\s*\n(.*?)\nRAW_TEXT_END"
# Use re.DOTALL to ensure that the dot matches newline characters
match = re.search(pattern, obj["body"]["messages"][0]["content"][0]["text"], re.DOTALL)
if match:
# Raw page text found: rebuild the request by reloading the source PDF and recalculating the anchor text
goldkey = obj["custom_id"]
s3_path = goldkey[: goldkey.rindex("-")]
page = int(goldkey[goldkey.rindex("-") + 1 :])
# Save the pdf to a temporary cache folder
local_pdf_path = cached_path(s3_path, quiet=True)
from olmocr.data.buildsilver import build_page_query
obj = build_page_query(local_pdf_path, s3_path, page)
# raw_page_text = get_anchor_text(local_pdf_path, page, pdf_engine="pdfreport")
# from olmocr.prompts import build_openai_silver_data_prompt
# obj["body"]["messages"][0]["content"][0]["text"] = build_openai_silver_data_prompt(raw_page_text)
if obj is not None:
outfile.write(json.dumps(obj) + "\n")
processed_count += 1
else:
error_count += 1
logging.info(f"Processed '{input_file}': {processed_count} records transformed, {error_count} errors.")
except Exception as e:
logging.exception(e)
logging.error(f"Failed to process file {input_file}: {e}")
def construct_output_file_path(input_file_path, input_dir, output_dir):
"""
Given an input file path, input directory, and output directory,
construct the corresponding output file path.
Args:
input_file_path (str): Path to the input file.
input_dir (str): Path to the input directory.
output_dir (str): Path to the output directory.
Returns:
str: Path to the output file.
"""
input_file = Path(input_file_path)
if is_s3_path(input_dir):
# For S3 paths, manually construct the relative path based on the input S3 path
input_prefix = input_dir.split("s3://")[1]
input_prefix = input_prefix.rstrip("*") # Remove any glob patterns like *.jsonl
# Remove the 's3://' part from input_file_path and extract the relative part
input_file_key = input_file_path.split("s3://")[1]
relative_path = input_file_key[len(input_prefix) :].lstrip("/")
# Construct the output S3 path by appending the relative part to the output S3 directory
output_file_path = output_dir.rstrip("/") + "/" + relative_path
else:
# For local paths, use the existing relative path logic
input_dir_path = Path(input_dir)
relative_path = input_file.relative_to(input_dir_path)
output_file_path = str(Path(output_dir) / relative_path)
return output_file_path
def list_input_files(input_dir):
"""
List all JSONL files in the input directory. If input_dir is an S3 path, handle
globbing manually by listing objects and filtering based on patterns.
Args:
input_dir (str): Path to the input directory or S3 URL.
Returns:
list: List of input file paths.
"""
if is_s3_path(input_dir):
# List matching S3 objects with boto3 and filter keys with fnmatch
import fnmatch
import boto3
# Parse bucket and prefix
bucket_name = input_dir.split("s3://")[1].split("/")[0]
path_and_pattern = "/".join(input_dir.split("s3://")[1].split("/")[1:])
# Separate the prefix and pattern
if "/" in path_and_pattern:
prefix = path_and_pattern.rsplit("/", 1)[0] + "/"
pattern = path_and_pattern.rsplit("/", 1)[1]
else:
prefix = ""
pattern = path_and_pattern
# Set up S3 resource and bucket
s3 = boto3.resource("s3")
bucket = s3.Bucket(bucket_name)
# Get all objects and filter them manually based on the pattern
files = []
for obj in bucket.objects.filter(Prefix=prefix):
if fnmatch.fnmatch(obj.key, f"{prefix}{pattern}"):
files.append(f"s3://{bucket_name}/{obj.key}")
return files
else:
# Local path handling (with glob pattern)
input_dir_path = Path(input_dir)
return [str(p) for p in input_dir_path.glob("*.jsonl")]
def main():
setup_logging()
parser = argparse.ArgumentParser(description="Transform JSONL files by extracting and renaming specific fields.")
parser.add_argument("--rewrite_prompt", action="store_true", default=False, help="Rewrites the input prompt by reloading the pdf from source")
parser.add_argument("input_dir", type=str, help="Path to the input directory containing JSONL files. Can be a local path or S3 URL.")
parser.add_argument("output_dir", type=str, help="Path to the output directory where transformed JSONL files will be saved. Can be a local path or S3 URL.")
parser.add_argument("--jobs", "-j", type=int, default=20, help="Number of parallel jobs to run (default: 20).")
args = parser.parse_args()
input_dir = args.input_dir.rstrip("/")
output_dir = args.output_dir.rstrip("/")
max_jobs = args.jobs
if not output_dir.startswith("s3:"):
os.makedirs(output_dir, exist_ok=True)
# List input files
input_files = list_input_files(input_dir)
if not input_files:
logging.warning(f"No JSONL files found in '{input_dir}'. Exiting.")
sys.exit(0)
logging.info(f"Found {len(input_files)} JSONL files to process.")
# Prepare tasks for parallel processing
tasks = []
for input_file in input_files:
output_file = construct_output_file_path(input_file, input_dir, output_dir)
tasks.append((input_file, output_file))
# Process files in parallel
with ProcessPoolExecutor(max_workers=max_jobs) as executor:
future_to_file = {executor.submit(process_file, input_file, output_file, args.rewrite_prompt): input_file for input_file, output_file in tasks}
for future in as_completed(future_to_file):
input_file = future_to_file[future]
try:
future.result()
except Exception as exc:
logging.error(f"File {input_file} generated an exception: {exc}")
logging.info("All files have been processed.")
if __name__ == "__main__":
main()
import base64
import io
import subprocess
from typing import List
from PIL import Image
def get_pdf_media_box_width_height(local_pdf_path: str, page_num: int) -> tuple[float, float]:
"""
Get the MediaBox dimensions for a specific page in a PDF file using the pdfinfo command.
:param local_pdf_path: Path to the PDF file
:param page_num: The page number (1-based) for which to extract MediaBox dimensions
:return: A (width, height) tuple in PDF points; raises ValueError if the MediaBox cannot be found
"""
# Construct the pdfinfo command to extract info for the specific page
command = ["pdfinfo", "-f", str(page_num), "-l", str(page_num), "-box", "-enc", "UTF-8", local_pdf_path]
# Run the command using subprocess
result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
# Check if there is any error in executing the command
if result.returncode != 0:
raise ValueError(f"Error running pdfinfo: {result.stderr}")
# Parse the output to find MediaBox
output = result.stdout
for line in output.splitlines():
if "MediaBox" in line:
media_box_str: List[str] = line.split(":")[1].strip().split()
media_box: List[float] = [float(x) for x in media_box_str]
return abs(media_box[0] - media_box[2]), abs(media_box[3] - media_box[1])
raise ValueError("MediaBox not found in the PDF info.")
def render_pdf_to_base64png(local_pdf_path: str, page_num: int, target_longest_image_dim: int = 2048) -> str:
longest_dim = max(get_pdf_media_box_width_height(local_pdf_path, page_num))
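# E.g., a US Letter page is 612 x 792 pt; with the default target of 2048 px the
# resolution passed to pdftoppm below is 2048 * 72 / 792 ≈ 186 DPI.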
# Convert PDF page to PNG using pdftoppm
pdftoppm_result = subprocess.run(
[
"pdftoppm",
"-png",
"-f",
str(page_num),
"-l",
str(page_num),
"-r",
str(target_longest_image_dim * 72 / longest_dim),  # PDF sizes are in points (72 per inch), so this DPI scales the longest side to target_longest_image_dim pixels
local_pdf_path,
],
timeout=120,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
)
assert pdftoppm_result.returncode == 0, pdftoppm_result.stderr
return base64.b64encode(pdftoppm_result.stdout).decode("utf-8")
def render_pdf_to_base64webp(local_pdf_path: str, page: int, target_longest_image_dim: int = 1024):
base64_png = render_pdf_to_base64png(local_pdf_path, page, target_longest_image_dim)
png_image = Image.open(io.BytesIO(base64.b64decode(base64_png)))
webp_output = io.BytesIO()
png_image.save(webp_output, format="WEBP")
return base64.b64encode(webp_output.getvalue()).decode("utf-8")
def get_png_dimensions_from_base64(base64_data) -> tuple[int, int]:
"""
Returns the (width, height) of a PNG image given its base64-encoded data,
without base64-decoding the entire payload or loading the PNG itself.
This keeps it fast enough to use for filtering.
Parameters:
- base64_data (str): Base64-encoded PNG image data.
Returns:
- tuple: (width, height) of the image.
Raises:
- ValueError: If the data is not a valid PNG image or the required bytes are not found.
"""
# The 8-byte PNG signature base64-encodes to 12 characters; comparing the first 8
# characters checks the first 6 signature bytes, which is sufficient here
png_signature_base64 = base64.b64encode(b"\x89PNG\r\n\x1a\n").decode("ascii")
if not base64_data.startswith(png_signature_base64[:8]):
raise ValueError("Not a valid PNG file")
# Positions in the binary data where width and height are stored
width_start = 16 # Byte position where width starts (0-based indexing)
_width_end = 20 # Byte position where width ends (exclusive)
_height_start = 20
height_end = 24
# Compute the byte range needed (from width_start to height_end)
start_byte = width_start
end_byte = height_end
# Calculate base64 character positions
# Each group of 3 bytes corresponds to 4 base64 characters
base64_start = (start_byte // 3) * 4
base64_end = ((end_byte + 2) // 3) * 4 # Add 2 to ensure we cover partial groups
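# Worked example: width starts at byte 16 and height ends at byte 24, so
# base64_start = (16 // 3) * 4 = 20 and base64_end = ((24 + 2) // 3) * 4 = 32.
# Those 12 base64 characters decode to bytes 15..23 of the PNG, which is just
# enough to cover the 4-byte width and 4-byte height fields of the IHDR chunk.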
# Extract the necessary base64 substring
base64_substring = base64_data[base64_start:base64_end]
# Decode only the necessary bytes
decoded_bytes = base64.b64decode(base64_substring)
# Compute the offset within the decoded bytes
offset = start_byte % 3
# Extract width and height bytes
width_bytes = decoded_bytes[offset : offset + 4]
height_bytes = decoded_bytes[offset + 4 : offset + 8]
if len(width_bytes) < 4 or len(height_bytes) < 4:
raise ValueError("Insufficient data to extract dimensions")
# Convert bytes to integers
width = int.from_bytes(width_bytes, "big")
height = int.from_bytes(height_bytes, "big")
return width, height
# Sends a list of batch files to OpenAI for processing.
# It also waits for the batches to finish, downloads the results, saves its state,
# and lets you submit more data than the 100 GB file-storage limit of the OpenAI API
# by uploading in chunks and deleting files once their batches complete.
import argparse
import datetime
import json
import os
import time
from openai import OpenAI
from tqdm import tqdm
# Set up OpenAI client (API key should be set in the environment)
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
MAX_OPENAI_DISK_SPACE = 100 * 1024 * 1024 * 1024 # Max is 100GB on openAI
UPLOAD_STATE_FILENAME = "SENDSILVER_DATA"
# Function to upload a file to OpenAI and start batch processing
def upload_and_start_batch(file_path):
# Upload the file to OpenAI
with open(file_path, "rb") as file:
print(f"Uploading {file_path} to OpenAI Batch API...")
upload_response = client.files.create(file=file, purpose="batch")
file_id = upload_response.id
print(f"File uploaded successfully: {file_id}")
# Create a batch job
print(f"Creating batch job for {file_path}...")
batch_response = client.batches.create(
input_file_id=file_id, endpoint="/v1/chat/completions", completion_window="24h", metadata={"description": "pdf gold/silver data"}
)
batch_id = batch_response.id
print(f"Batch created successfully: {batch_id}")
return batch_id
def download_batch_result(batch_id, output_folder):
# Retrieve the batch result from OpenAI API
batch_data = client.batches.retrieve(batch_id)
if batch_data.status != "completed":
print(f"WARNING: {batch_id} is not completed, status: {batch_data.status}")
return batch_id, False
if batch_data.output_file_id is None:
print(f"WARNING: {batch_id} is completed, but no output file was generated")
return batch_id, False
print(f"Downloading batch data for {batch_id}")
file_response = client.files.content(batch_data.output_file_id)
# Define output file path
output_file = os.path.join(output_folder, f"{batch_id}.json")
# Save the result to a file
with open(output_file, "w") as f:
f.write(str(file_response.text))
return batch_id, True
ALL_STATES = ["init", "processing", "completed", "errored_out", "could_not_upload"]
FINISHED_STATES = ["completed", "errored_out"]
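# Each .jsonl file moves through a simple state machine: "init" (not yet uploaded)
# -> "processing" (batch submitted to OpenAI) -> "completed" or "errored_out".
# State is persisted to UPLOAD_STATE_FILENAME inside the input folder, so the script
# can be stopped and resumed.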
def _json_datetime_decoder(obj):
if "last_checked" in obj:
try:
obj["last_checked"] = datetime.datetime.fromisoformat(obj["last_checked"])
except (TypeError, ValueError):
pass # If it's not a valid ISO format, leave it as is
return obj
def _json_datetime_encoder(obj):
if isinstance(obj, datetime.datetime):
return obj.isoformat() # Convert datetime to ISO format string
raise TypeError(f"Object of type {obj.__class__.__name__} is not JSON serializable")
def get_state(folder_path: str) -> dict:
state_file = os.path.join(folder_path, UPLOAD_STATE_FILENAME)
try:
with open(state_file, "r") as f:
return json.load(f, object_hook=_json_datetime_decoder)
except (json.decoder.JSONDecodeError, FileNotFoundError):
# List all .jsonl files in the specified folder
jsonl_files = [f for f in os.listdir(folder_path) if f.endswith(".jsonl")]
if not jsonl_files:
raise Exception("No JSONL files found to process")
state = {
f: {
"filename": f,
"batch_id": None,
"state": "init",
"size": os.path.getsize(os.path.join(folder_path, f)),
"last_checked": datetime.datetime.now(),
}
for f in jsonl_files
}
with open(state_file, "w") as f:
json.dump(state, f, default=_json_datetime_encoder)
return state
def update_state(folder_path: str, filename: str, **kwargs):
all_state = get_state(folder_path)
for kwarg_name, kwarg_value in kwargs.items():
all_state[filename][kwarg_name] = kwarg_value
all_state[filename]["last_checked"] = datetime.datetime.now()
state_file = os.path.join(folder_path, UPLOAD_STATE_FILENAME)
temp_file = state_file + ".tmp"
# Write to temporary file first
with open(temp_file, "w") as f:
json.dump(all_state, f, default=_json_datetime_encoder)
f.flush()
os.fsync(f.fileno())
# Atomic rename of temporary file to target file
os.replace(temp_file, state_file)
return all_state
def get_total_space_usage():
return sum(file.bytes for file in client.files.list())
def get_estimated_space_usage(folder_path):
all_states = get_state(folder_path)
return sum(s["size"] for s in all_states.values() if s["state"] == "processing")
def get_next_work_item(folder_path):
all_states = list(get_state(folder_path).values())
all_states = [s for s in all_states if s["state"] not in FINISHED_STATES]
all_states.sort(key=lambda s: s["last_checked"])
return all_states[0] if len(all_states) > 0 else None
def get_done_total(folder_path):
processing, done, total = 0, 0, 0
for state in get_state(folder_path).values():
if state["state"] in FINISHED_STATES:
done += 1
if state["state"] == "processing":
processing += 1
total += 1
return processing, done, total
# Main function to process all .jsonl files in a folder
def process_folder(folder_path: str, max_gb: int):
output_folder = f"{folder_path.rstrip('/')}_done"
os.makedirs(output_folder, exist_ok=True)
last_loop_time = datetime.datetime.now()
starting_free_space = MAX_OPENAI_DISK_SPACE - get_total_space_usage()
if starting_free_space < (max_gb * 1024**3) * 2:
raise ValueError(
f"Insufficient free space in OpenAI's file storage: only {starting_free_space / 1024**3:.1f} GB left, but 2x{max_gb} GB are required (1x for your uploads, 1x for your results)."
)
while not all(state["state"] in FINISHED_STATES for state in get_state(folder_path).values()):
processing, done, total = get_done_total(folder_path)
print(f"Total items {total}, processing {processing}, done {done}, {done/total*100:.1f}%")
work_item = get_next_work_item(folder_path)
print(f"Processing {os.path.basename(work_item['filename'])}, cur status = {work_item['state']}")
# If the previous pass over the work items ended less than a second ago, sleep briefly to avoid busy-waiting
if last_loop_time > datetime.datetime.now() - datetime.timedelta(seconds=1):
time.sleep(0.2)
if work_item["state"] == "init":
if get_estimated_space_usage(folder_path) < (max_gb * 1024**3):
try:
batch_id = upload_and_start_batch(os.path.join(folder_path, work_item["filename"]))
update_state(folder_path, work_item["filename"], state="processing", batch_id=batch_id)
except Exception as ex:
print(ex)
update_state(folder_path, work_item["filename"], state="init")
else:
print("waiting for something to finish processing before uploading more")
# Touch last_checked so this item goes to the back of the queue and get_next_work_item picks a different one next
update_state(folder_path, work_item["filename"])
elif work_item["state"] == "processing":
batch_data = client.batches.retrieve(work_item["batch_id"])
if batch_data.status == "completed":
batch_id, success = download_batch_result(work_item["batch_id"], output_folder)
if success:
update_state(folder_path, work_item["filename"], state="completed")
else:
update_state(folder_path, work_item["filename"], state="errored_out")
try:
client.files.delete(batch_data.input_file_id)
except Exception as ex:
print(ex)
print("Could not delete old input data")
try:
client.files.delete(batch_data.output_file_id)
except Exception as ex:
print(ex)
print("Could not delete old output data")
elif batch_data.status in ["failed", "expired", "cancelled"]:
update_state(folder_path, work_item["filename"], state="errored_out")
try:
client.files.delete(batch_data.input_file_id)
except Exception as ex:
print(ex)
print("Could not delete old input data")
else:
# Touch last_checked so this item goes to the back of the queue and get_next_work_item picks a different one next
update_state(folder_path, work_item["filename"])
last_loop_time = datetime.datetime.now()
print("All work has been completed")
if __name__ == "__main__":
# Set up argument parsing for folder input
parser = argparse.ArgumentParser(description="Upload .jsonl files and process batches in OpenAI API.")
parser.add_argument("--max_gb", type=int, default=25, help="Max number of GB of batch processing files to upload at one time")
parser.add_argument("--clear_all_files", action="store_true", help="Helper to delete ALL files stored in your openai account")
parser.add_argument("folder", type=str, help="Path to the folder containing .jsonl files")
args = parser.parse_args()
if args.clear_all_files:
all_files = list(client.files.list())
if input(f"Are you sure you want to delete {len(all_files)} files from your OpenAI account? [y/N]").lower() == "y":
for file in tqdm(all_files):
client.files.delete(file.id)
quit()
# Process the folder and start batches
process_folder(args.folder, args.max_gb)
import datetime
import hashlib
import json
from dataclasses import dataclass
@dataclass(frozen=True)
class PdfOutput:
path: str
text: str
total_pdf_pages: int
processed_pdf_pages: int
def mk_dolma_doc(self, **kwargs) -> str:
metadata = {
"Source-File": self.path,
"pdf-pages": self.processed_pdf_pages,
"pdf-total-pages": self.total_pdf_pages,
# Kwargs are added as extra metadata
**kwargs,
}
id_ = hashlib.sha1(self.text.encode()).hexdigest()
dolma_doc = {
"id": id_,
"text": self.text,
"source": "s2pdf",
"added": datetime.datetime.now().strftime("%Y-%m-%d"),
"created": datetime.datetime.now().strftime("%Y-%m-%d"),
"metadata": metadata,
}
return json.dumps(dolma_doc)
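# Illustrative output (values are placeholders): mk_dolma_doc() serializes something like
# {"id": "<sha1 of text>", "text": "...", "source": "s2pdf", "added": "2024-01-01",
#  "created": "2024-01-01", "metadata": {"Source-File": "...", "pdf-pages": 3, "pdf-total-pages": 10}}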
import argparse
import dataclasses
import functools
import random
import re
from concurrent.futures import ProcessPoolExecutor, as_completed
from itertools import combinations
import boto3
from dolma_refine.evaluate.aligners import HirschbergAligner
from dolma_refine.evaluate.metrics import DocumentEditSimilarity
from dolma_refine.evaluate.segmenters import SpacySegmenter
from tqdm import tqdm
from olmocr.eval.evalhtml import create_review_html
from olmocr.s3_utils import expand_s3_glob, get_s3_bytes
@dataclasses.dataclass
class Comparison:
pdf_path: str
comparison_a_path: str
comparison_b_path: str
comparison_a_str: str
comparison_b_str: str
alignment: float
@property
def comparison_a_method(self):
match = re.search(r"page[0-9]+_(\w+)\.md$", self.comparison_a_path)
if match:
return match.group(1)
raise ValueError(f"No match found in path: {self.comparison_a_path}")
@property
def comparison_b_method(self):
match = re.search(r"page[0-9]+_(\w+)\.md$", self.comparison_b_path)
if match:
return match.group(1)
raise ValueError(f"No match found in path: {self.comparison_b_path}")
def process_single_pdf(pdf_path, all_mds, comparisons, segmenter_name="spacy"):
"""Process a single PDF and return its comparisons."""
# Create resources inside the worker process
s3_client = boto3.client("s3")
segmenter = SpacySegmenter(segmenter_name)
aligner = HirschbergAligner(match_score=1, mismatch_score=-1, indel_score=-1)
comparer = DocumentEditSimilarity(segmenter=segmenter, aligner=aligner)
pdf_comps = []
result_comps = []
# Get all comparison files for this PDF
for comp in comparisons:
comp_path = pdf_path.replace(".pdf", f"_{comp}.md")
if comp_path in all_mds:
pdf_comps.append(comp_path)
# Generate all possible combinations
for compa, compb in combinations(pdf_comps, 2):
if random.choice([True, False]):
compa, compb = compb, compa
# Get the text content
text_a = get_s3_bytes(s3_client, compa).decode("utf-8")
text_b = get_s3_bytes(s3_client, compb).decode("utf-8")
result_comps.append(
Comparison(
pdf_path=pdf_path,
comparison_a_path=compa,
comparison_b_path=compb,
comparison_a_str=text_a,
comparison_b_str=text_b,
alignment=comparer.compute(text_a, text_b),
)
)
return result_comps
def build_review_page(args, comparisons, index=0):
page_data = []
for comp in comparisons:
page_data.append(
{
"s3_path": comp.pdf_path,
"page": 1,
"entry_key": comp.pdf_path + "-" + comp.comparison_a_method + "-" + comp.comparison_b_method,
"gold_text": comp.comparison_a_str,
"gold_metadata": comp.comparison_a_method,
"eval_text": comp.comparison_b_str,
"eval_metadata": comp.comparison_b_method,
"alignment": comp.alignment,
}
)
report_name = f"{args.name}{f'_{index}' if args.num_copies > 1 else ''}.html"
create_review_html(page_data, report_name)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Generates comparison voting pages between different pairs of parses for a PDF.")
parser.add_argument("--name", default="review_page", help="What name to give to this evaluation/comparison")
parser.add_argument(
"--review_size",
default=50,
type=int,
help="Number of entries to show on the generated review page",
)
parser.add_argument(
"--max_workers",
type=int,
default=None,
help="Maximum number of worker processes to use for parallel processing",
)
parser.add_argument("--comparisons", default=["pdelf", "marker", "gotocr_format", "mineru"], help="Different variants to compare against")
parser.add_argument(
"--num_copies",
default=1,
type=int,
help="Number of reports to generate, labeled _0, _1, etc. if greater than 1",
)
parser.add_argument(
"s3_path", type=str, help="Path to the folder where you keep your data files, expecting to see *.md files in there along with *.png and *.pdf"
)
args = parser.parse_args()
# Create S3 client only for initial file listing
s3_client = boto3.client("s3")
# Get all PDFs and MD files
all_pdfs = set(expand_s3_glob(s3_client, args.s3_path + "/*.pdf"))
all_mds = set(expand_s3_glob(s3_client, args.s3_path + "/*.md"))
all_comps = []
# Create a partial function with all the common arguments
process_pdf = functools.partial(process_single_pdf, all_mds=all_mds, comparisons=args.comparisons)
# Use ProcessPoolExecutor for parallel processing
with ProcessPoolExecutor(max_workers=args.max_workers) as executor:
# Submit all PDF processing tasks
future_to_pdf = {executor.submit(process_pdf, pdf_path): pdf_path for pdf_path in all_pdfs}
# Process results as they complete using tqdm for progress
for future in tqdm(as_completed(future_to_pdf), total=len(all_pdfs)):
pdf_path = future_to_pdf[future]
try:
pdf_results = future.result()
all_comps.extend(pdf_results)
except Exception as e:
print(f"Error processing {pdf_path}: {str(e)}")
# Remove all results where the alignment is 0.96 or higher, as those parses are too similar to be useful
all_comps = [c for c in all_comps if c.alignment < 0.96]
# Shuffle the results
random.shuffle(all_comps)
# Generate the specified number of copies of the report
for i in range(args.num_copies):
start_index = i * args.review_size
end_index = start_index + args.review_size
# Check if there is enough data for the next report
if start_index >= len(all_comps):
print(f"Not enough data to generate report {i}. Stopping early.")
break
build_review_page(args, all_comps[start_index:end_index], index=i)
from typing import Type
from sequence_align.pairwise import hirschberg, needleman_wunsch
from .registry import BaseRegistry
class AlignerRegistry(BaseRegistry[Type["BaseAligner"]]):
"""A registry for aligners."""
class BaseAligner:
def __init__(self, *args, **kwargs):
super().__init__()
def align(self, gold: list[str], pred: list[str]) -> tuple[list[str], list[str]]:
raise NotImplementedError()
@AlignerRegistry.add("hirschberg")
class HirschbergAligner(BaseAligner):
def __init__(
self,
match_score: float = 1.0,
mismatch_score: float = -1.0,
indel_score: float = -1.0,
gap_token: str = "▓",
):
self.match_score = match_score
self.mismatch_score = mismatch_score
self.indel_score = indel_score
self.gap_token = gap_token
super().__init__()
def align(self, gold: list[str], pred: list[str]) -> tuple[list[str], list[str]]:
return hirschberg(
gold,
pred,
match_score=self.match_score,
mismatch_score=self.mismatch_score,
indel_score=self.indel_score,
gap=self.gap_token,
)
@AlignerRegistry.add("needleman-wunsch")
class NeedlemanWunschAligner(BaseAligner):
def __init__(
self,
match_score: float = 1.0,
mismatch_score: float = -1.0,
indel_score: float = -1.0,
gap_token: str = "▓",
):
self.match_score = match_score
self.mismatch_score = mismatch_score
self.indel_score = indel_score
self.gap_token = gap_token
super().__init__()
def align(self, gold: list[str], pred: list[str]) -> tuple[list[str], list[str]]:
return needleman_wunsch(
gold,
pred,
match_score=self.match_score,
mismatch_score=self.mismatch_score,
indel_score=self.indel_score,
gap=self.gap_token,
)
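# Illustrative behaviour (not an exhaustive spec of sequence_align): aligning
# ["the", "cat", "sat"] against ["the", "sat"] pads the shorter sequence with the
# gap token, e.g. (["the", "cat", "sat"], ["the", "▓", "sat"]).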
import bisect
from typing import Type
import regex as re
from tqdm import tqdm
from .aligners import AlignerRegistry, BaseAligner
from .registry import BaseRegistry
from .segmenters import BaseSegmenter, SegmenterRegistry
class TextMetricRegistry(BaseRegistry[Type["BaseTextMetric"]]):
"""A registry for text metrics."""
class BaseTextMetric:
def __init__(self, *args, **kwargs):
super().__init__()
def compute(self, gold: str, pred: str) -> float:
raise NotImplementedError()
def batch_compute(self, golds: list[str], preds: list[str]) -> list[float]:
it = tqdm(
zip(golds, preds),
total=min(len(golds), len(preds)),
desc=type(self).__name__,
unit="samples",
unit_scale=True,
)
return [self.compute(gold, pred) for gold, pred in it]
class BaseTextAlignMetric(BaseTextMetric):
def __init__(
self,
segmenter: str | BaseSegmenter,
aligner: str | BaseAligner = "hirschberg",
aligner_kwargs: dict = {},
segmenter_kwargs: dict = {},
gap_token: str = "▓",
*args,
**kwargs,
):
if isinstance(segmenter, str):
self.segmenter = SegmenterRegistry.get(segmenter)(segmenter, **segmenter_kwargs)
else:
self.segmenter = segmenter
if isinstance(aligner, str):
self.aligner = AlignerRegistry.get(aligner)(aligner, **aligner_kwargs)
else:
self.aligner = aligner
self.gap_token = gap_token
def segment(self, seq_a_tokens: list[str], seq_b_tokens: list[str]) -> list[tuple[list[str], list[str]]]:
return [(seq_a_tokens, seq_b_tokens)]
def align(self, seq_a_tokens: list[str], seq_b_tokens: list[str]) -> tuple[list[str], list[str]]:
return self.aligner.align(seq_a_tokens, seq_b_tokens)
def tokenize(self, text: str) -> list[str]:
return [w for w in re.split(r"(\p{P}+|\s+)", text) if w]
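# Tokenization keeps punctuation and whitespace runs as their own tokens, e.g.
# tokenize("Hello, world") -> ["Hello", ",", " ", "world"].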
def compute(self, gold: str, pred: str) -> float:
raise NotImplementedError()
@TextMetricRegistry.add("document_edit_similarity")
class DocumentEditSimilarity(BaseTextAlignMetric):
def _score_aligned(self, aligned_gold_tokens: list[str], aligned_pred_tokens: list[str]) -> float:
insertions = deletions = matches = substitutions = 0.0
for gold_symbol, pred_symbol in zip(aligned_gold_tokens, aligned_pred_tokens):
if gold_symbol == self.gap_token:
insertions += 1
elif pred_symbol == self.gap_token:
deletions += 1
elif gold_symbol == pred_symbol:
matches += 1
else:
substitutions += 1
if total := insertions + deletions + matches + substitutions:
return matches / total
return 0.0
def compute(self, gold: str, pred: str) -> float:
gold_tokens = self.tokenize(gold)
pred_tokens = self.tokenize(pred)
aligned_gold_tokens, aligned_pred_tokens = self.align(gold_tokens, pred_tokens)
return self._score_aligned(aligned_gold_tokens, aligned_pred_tokens)
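# Worked example: with gold tokens ["the", " ", "cat"] and pred tokens ["the", " ", "dog"],
# the alignment yields 2 matches and 1 substitution, so the similarity is 2 / 3 ≈ 0.67.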
def find_align_gaps(aligned_text: list[str], gap_token: str = "▓", gap_threshold: int = 3) -> list[int]:
consecutive_gaps_counter = 0
above_threshold_locs: list[int] = []
for aligned_pos, symbol in enumerate(aligned_text):
if symbol == gap_token:
consecutive_gaps_counter += 1
else:
consecutive_gaps_counter = 0
if consecutive_gaps_counter >= gap_threshold:
above_threshold_locs.append(aligned_pos)
consecutive_gaps_counter = 0
return above_threshold_locs
def make_unaligned_text(tokens: list[str], gap_token: str = "▓") -> str:
return "".join(symbol for symbol in tokens if symbol != gap_token)
def find_sentences(
tokens: list[str],
sentences: list[str],
gap_token: str = "▓",
):
matches: list[tuple[int, int]] = []
original_text = ""
original: list[int] = []
original_to_aligned: list[int] = []
for i, token in enumerate(tokens):
if token != gap_token:
original_text += token
original.append(len(original_text))
original_to_aligned.append(i)
matches = []
for sentence in sentences:
start_pos = original_text.find(sentence)
if start_pos < 0:
continue
end_pos = start_pos + len(sentence)
start_token = original_to_aligned[bisect.bisect_left(original, start_pos)]
end_token = original_to_aligned[min(bisect.bisect_right(original, end_pos), len(original) - 1)]
matches.append((start_token, end_token))
return matches
def merge_spans(spans: list[tuple[int, int]]) -> list[tuple[int, int]]:
if not spans:
return []
# Sort spans based on start position
sorted_spans = sorted(spans, key=lambda x: x[0])
merged = [sorted_spans[0]]
for current in sorted_spans[1:]:
last = merged[-1]
# If current span overlaps with last merged span, update the end of last span
if current[0] <= last[1]:
merged[-1] = (last[0], max(last[1], current[1]))
else:
merged.append(current)
return merged
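# Example: merge_spans([(1, 4), (3, 6), (8, 9)]) -> [(1, 6), (8, 9)].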
def make_sentences_around_gaps(sent_locs: list[tuple[int, int]], gaps_locs: list[int], window: int):
sent_start_only = [start for start, _ in sent_locs]
sentences_with_gaps = []
# collect all sentences that are around the gaps
for gap in gaps_locs:
start_idx = bisect.bisect_left(sent_start_only, gap)
fwd_window = max(0, start_idx - window)
bwd_window = min(len(sent_locs) - 1, start_idx + window)
sentences_with_gaps.append((sent_locs[fwd_window][0], sent_locs[bwd_window][-1]))
# merge overlapping sentences
sentences_with_gaps = merge_spans(sentences_with_gaps)
return sentences_with_gaps
@TextMetricRegistry.add("paragraph_edit_similarity")
class ParagraphEditSimilarity(DocumentEditSimilarity):
def __init__(
self,
segmenter: str | BaseSegmenter,
aligner: str | BaseAligner = "hirschberg",
aligner_kwargs: dict = {},
segmenter_kwargs: dict = {},
gap_token: str = "▓",
gap_threshold: int = 3,
sent_window: int = 1,
*args,
**kwargs,
):
super().__init__(
segmenter=segmenter,
aligner=aligner,
aligner_kwargs=aligner_kwargs,
segmenter_kwargs=segmenter_kwargs,
gap_token=gap_token,
)
self.gap_threshold = gap_threshold
self.sent_window = sent_window
def segment(self, seq_a_tokens: list[str], seq_b_tokens: list[str]) -> list[tuple[list[str], list[str]]]:
all_spans = []
for seq_tokens in (seq_a_tokens, seq_b_tokens):
text = make_unaligned_text(tokens=seq_tokens, gap_token=self.gap_token)
sentences = self.segmenter.segment(text)
sent_locs = find_sentences(tokens=seq_tokens, sentences=sentences, gap_token=self.gap_token)
gaps_locs = find_align_gaps(aligned_text=seq_tokens, gap_token=self.gap_token, gap_threshold=3)
sentences_with_gaps = make_sentences_around_gaps(sent_locs=sent_locs, gaps_locs=gaps_locs, window=self.sent_window)
all_spans.extend(sentences_with_gaps)
return [(seq_a_tokens[start:end], seq_b_tokens[start:end]) for start, end in merge_spans(all_spans)]
def compute(self, gold: str, pred: str) -> float:
gold_tokens = self.tokenize(gold)
pred_tokens = self.tokenize(pred)
aligned_gold_tokens, aligned_pred_tokens = self.align(gold_tokens, pred_tokens)
scores = []
for gold_segment, pred_segment in self.segment(aligned_gold_tokens, aligned_pred_tokens):
score = self._score_aligned(gold_segment, pred_segment)
scores.append(score)
return sum(scores) / len(scores) if scores else 1.0
import re
from typing import (
Callable,
Dict,
Generator,
Generic,
Literal,
Optional,
Tuple,
Type,
TypeVar,
overload,
)
T = TypeVar("T")
R = TypeVar("R")
class BaseRegistry(Generic[T]):
"""A registry for objects."""
_registry_of_registries: Dict[str, Type["BaseRegistry"]] = {}
_registry_storage: Dict[str, Tuple[T, Optional[str]]]
@classmethod
def _add_to_registry_of_registries(cls) -> None:
name = cls.__name__
if name not in cls._registry_of_registries:
cls._registry_of_registries[name] = cls
@classmethod
def registries(cls) -> Generator[Tuple[str, Type["BaseRegistry"]], None, None]:
"""Yield all registries in the registry of registries."""
yield from sorted(cls._registry_of_registries.items())
@classmethod
def _get_storage(cls) -> Dict[str, Tuple[T, Optional[str]]]:
if not hasattr(cls, "_registry_storage"):
cls._registry_storage = {}
return cls._registry_storage # pyright: ignore
@classmethod
def items(cls) -> Generator[Tuple[str, T], None, None]:
"""Yield all items in the registry."""
yield from sorted((n, t) for (n, (t, _)) in cls._get_storage().items())
@classmethod
def items_with_description(cls) -> Generator[Tuple[str, T, Optional[str]], None, None]:
"""Yield all items in the registry with their descriptions."""
yield from sorted((n, t, d) for (n, (t, d)) in cls._get_storage().items())
@classmethod
def add(cls, name: str, desc: Optional[str] = None) -> Callable[[R], R]:
"""Add a class to the registry."""
# Add the registry to the registry of registries
cls._add_to_registry_of_registries()
def _add(
inner_self: T,
inner_name: str = name,
inner_desc: Optional[str] = desc,
inner_cls: Type[BaseRegistry] = cls,
) -> T:
"""Add a tagger to the registry using tagger_name as the name."""
existing = inner_cls.get(inner_name, raise_on_missing=False)
if existing and existing != inner_self:
if inner_self.__module__ == "__main__":
return inner_self
raise ValueError(f"Tagger {inner_name} already exists")
inner_cls._get_storage()[inner_name] = (inner_self, inner_desc)
return inner_self
return _add # type: ignore
@classmethod
def remove(cls, name: str) -> bool:
"""Remove a tagger from the registry."""
if name in cls._get_storage():
cls._get_storage().pop(name)
return True
return False
@classmethod
def has(cls, name: str) -> bool:
"""Check if a tagger exists in the registry."""
return name in cls._get_storage()
@overload
@classmethod
def get(cls, name: str) -> T: ...
@overload
@classmethod
def get(cls, name: str, raise_on_missing: Literal[True]) -> T: ...
@overload
@classmethod
def get(cls, name: str, raise_on_missing: Literal[False]) -> Optional[T]: ...
@classmethod
def get(cls, name: str, raise_on_missing: bool = True) -> Optional[T]:
"""Get a tagger from the registry; raise ValueError if it doesn't exist."""
matches = [registered for registered in cls._get_storage() if re.match(registered, name)]
if len(matches) > 1:
raise ValueError(f"Multiple taggers match {name}: {', '.join(matches)}")
elif len(matches) == 0:
if raise_on_missing:
tagger_names = ", ".join([tn for tn, _ in cls.items()])
raise ValueError(f"Unknown tagger {name}; available taggers: {tagger_names}")
return None
else:
name = matches[0]
t, _ = cls._get_storage()[name]
return t
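# Typical usage (illustrative; MyRegistry, MyBase and DefaultImpl are hypothetical names):
#
#   class MyRegistry(BaseRegistry[Type["MyBase"]]):
#       """A registry for my components."""
#
#   @MyRegistry.add("default", desc="Reference implementation")
#   class DefaultImpl(MyBase): ...
#
#   impl_cls = MyRegistry.get("default")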
from typing import Type
from spacy.lang.en import English
from .registry import BaseRegistry
class SegmenterRegistry(BaseRegistry[Type["BaseSegmenter"]]):
"""A registry for segmenters."""
class BaseSegmenter:
def __init__(self, segmenter_name_or_path: str, *args, **kwargs):
super().__init__()
def segment(self, text: str) -> list[str]:
raise NotImplementedError()
@SegmenterRegistry.add("spacy")
class SpacySegmenter(BaseSegmenter):
def __init__(self, segmenter_name_or_path: str, *args, **kwargs):
assert segmenter_name_or_path == "spacy", "Only 'spacy' segmenter is supported"
self.nlp = English()
self.nlp.add_pipe("sentencizer")
def segment(self, text: str) -> list[str]:
return [sent.text_with_ws for sent in self.nlp(text).sents]