"src/routes/vscode:/vscode.git/clone" did not exist on "0fdb346a318595eee8be4214e1034ffde33d0b8f"
Commit 89e60e48 authored by wanglch's avatar wanglch
Browse files

Initial commit

parents
Pipeline #2484 canceled with stages
import glob
import io
import json
import os
import unittest
from pypdf import PdfReader
from olmocr.data.renderpdf import get_pdf_media_box_width_height
from olmocr.prompts.anchor import _linearize_pdf_report, _pdf_report, get_anchor_text
class AnchorTest(unittest.TestCase):
    """Tests for anchor-text extraction (olmocr.prompts.anchor) against a
    collection of known-tricky PDFs in the ``gnarly_pdfs`` fixture directory.

    Most tests print the extracted anchor text for manual inspection and
    assert only that it respects the requested length budget.
    """

    def testExtractText(self):
        # Smoke test: drive pypdf's visitor-based text extraction over an
        # OCR'd page and dump what it sees.
        local_pdf_path = os.path.join(os.path.dirname(__file__), "gnarly_pdfs", "some_ocr1.pdf")
        reader = PdfReader(local_pdf_path)
        page = reader.pages[0]

        def visitor_body(text, cm, tm, font_dict, font_size):
            print(repr(text), cm, tm, font_size)

        def visitor_op(op, args, cm, tm):
            # print(op, args, cm, tm)
            pass

        page.extract_text(visitor_text=visitor_body, visitor_operand_before=visitor_op)

    def testAnchorBase(self):
        # A PDF where naive pdftotext merges two columns incorrectly.
        local_pdf_path = os.path.join(os.path.dirname(__file__), "gnarly_pdfs", "pdftotext_two_column_issue.pdf")
        report = _pdf_report(local_pdf_path, 2)
        print(report)
        print(get_anchor_text(local_pdf_path, 2, pdf_engine="pdfreport"))

    def testAnchorImage(self):
        # A scanned/OCR'd page: the report should still be produced.
        local_pdf_path = os.path.join(os.path.dirname(__file__), "gnarly_pdfs", "some_ocr1.pdf")
        report = _pdf_report(local_pdf_path, 1)
        print(report)
        print(get_anchor_text(local_pdf_path, 1, pdf_engine="pdfreport"))

    def testSmallPage(self):
        # Page with an unusually small media box.
        local_pdf_path = os.path.join(os.path.dirname(__file__), "gnarly_pdfs", "small_page_size.pdf")
        report = _pdf_report(local_pdf_path, 1)
        print(report)
        print(get_anchor_text(local_pdf_path, 1, pdf_engine="pdfreport"))

    def testBadUTFSurrogatePairsGeneration(self):
        # The anchor text must be valid UTF-8 after JSON round-tripping:
        # pyarrow's JSON reader rejects unpaired surrogates, so a successful
        # read proves the text was sanitized.
        local_pdf_path = os.path.join(os.path.dirname(__file__), "gnarly_pdfs", "badlines.pdf")
        anchor_text = get_anchor_text(local_pdf_path, 4, pdf_engine="pdfreport")
        jsondata = json.dumps({"text": anchor_text})

        import pyarrow as pa
        import pyarrow.compute as pc
        import pyarrow.json as paj

        buffer = io.BytesIO(jsondata.encode("utf-8"))
        paj.read_json(buffer, read_options=paj.ReadOptions(use_threads=False, block_size=len(jsondata)))

    def testLargePromptHint1(self):
        local_pdf_path = os.path.join(os.path.dirname(__file__), "gnarly_pdfs", "large_prompt_hint1.pdf")
        anchor_text = get_anchor_text(local_pdf_path, 4, pdf_engine="pdfreport")
        print(anchor_text)
        print(len(anchor_text))
        self.assertLessEqual(len(anchor_text), 1000)

    def testLargePromptHint2(self):
        local_pdf_path = os.path.join(os.path.dirname(__file__), "gnarly_pdfs", "large_prompt_hint2.pdf")
        anchor_text = get_anchor_text(local_pdf_path, 2, pdf_engine="pdfreport")
        print(anchor_text)
        print(len(anchor_text))
        self.assertLessEqual(len(anchor_text), 4000)

    def testLargePromptHint3(self):
        local_pdf_path = os.path.join(os.path.dirname(__file__), "gnarly_pdfs", "large_prompt_hint3.pdf")
        anchor_text = get_anchor_text(local_pdf_path, 2, pdf_engine="pdfreport")
        print(anchor_text)
        print(len(anchor_text))
        self.assertLessEqual(len(anchor_text), 4000)

    def testNewsPaperPromptHint(self):
        local_pdf_path = os.path.join(os.path.dirname(__file__), "gnarly_pdfs", "newspaper.pdf")
        anchor_text = get_anchor_text(local_pdf_path, 1, pdf_engine="pdfreport")
        print(anchor_text)
        print(len(anchor_text))
        self.assertLessEqual(len(anchor_text), 4000)

    def testTobaccoPaperMissingParagraphs(self):
        local_pdf_path = os.path.join(os.path.dirname(__file__), "gnarly_pdfs", "tobacco_missed_tokens_pg1.pdf")
        anchor_text = get_anchor_text(local_pdf_path, 1, pdf_engine="pdfreport")
        print(anchor_text)
        print(len(anchor_text))
        self.assertLessEqual(len(anchor_text), 4000)

    def testAnchorOtherLengths(self):
        # target_length should act as a hard cap for arbitrary budgets.
        local_pdf_path = os.path.join(os.path.dirname(__file__), "gnarly_pdfs", "tobacco_missed_tokens_pg1.pdf")

        anchor_text = get_anchor_text(local_pdf_path, 1, pdf_engine="pdfreport", target_length=2000)
        print(anchor_text)
        print(len(anchor_text))
        self.assertLessEqual(len(anchor_text), 2000)

        anchor_text = get_anchor_text(local_pdf_path, 1, pdf_engine="pdfreport", target_length=6000)
        print(anchor_text)
        print(len(anchor_text))
        self.assertLessEqual(len(anchor_text), 6000)

    def testFailingAnchor(self):
        local_pdf_path = os.path.join(os.path.dirname(__file__), "gnarly_pdfs", "failing_anchor_pg4.pdf")
        anchor_text = get_anchor_text(local_pdf_path, 4, pdf_engine="pdfreport")
        print(anchor_text)
        print(len(anchor_text))
        self.assertLessEqual(len(anchor_text), 4000)

    def testEmptyAnchor(self):
        # With a zero-length budget only the page-dimension header survives.
        local_pdf_path = os.path.join(os.path.dirname(__file__), "gnarly_pdfs", "tobacco_missed_tokens_pg1.pdf")
        anchor_text = get_anchor_text(local_pdf_path, 1, pdf_engine="pdfreport", target_length=0)
        self.assertEqual(anchor_text.strip(), "Page dimensions: 612.0x792.0")

    def testCannotLoad(self):
        # A PDF that pypdf struggles with should still yield a bounded anchor
        # text rather than raising.
        local_pdf_path = os.path.join(os.path.dirname(__file__), "gnarly_pdfs", "load_v_error.pdf")
        page = 5

        anchor_text = get_anchor_text(local_pdf_path, page, pdf_engine="pdfreport", target_length=6000)
        print(anchor_text)
        print(len(anchor_text))
        self.assertLessEqual(len(anchor_text), 6000)

    @unittest.skip("TODO, this unit test still fails, the map text is too large.")
    def testExcessiveMapAnchor(self):
        local_pdf_path = os.path.join(os.path.dirname(__file__), "gnarly_pdfs", "map1.pdf")
        anchor_text = get_anchor_text(local_pdf_path, 1, pdf_engine="pdfreport", target_length=6000)
        print(anchor_text)
        print(len(anchor_text))
        self.assertLessEqual(len(anchor_text), 4000)

    def testKyleOnePageAnchors1(self):
        local_pdf_path = os.path.join(os.path.dirname(__file__), "gnarly_pdfs", "dolma-page-1.pdf")
        anchor_text = get_anchor_text(local_pdf_path, 1, pdf_engine="pdfreport", target_length=6000)
        print(anchor_text)
        print(len(anchor_text))
        self.assertLessEqual(len(anchor_text), 6000)

    def testKyleOnePageAnchors2(self):
        local_pdf_path = os.path.join(os.path.dirname(__file__), "gnarly_pdfs", "olmo-page-1.pdf")
        anchor_text = get_anchor_text(local_pdf_path, 1, pdf_engine="pdfreport", target_length=6000)
        print(anchor_text)
        print(len(anchor_text))
        self.assertLessEqual(len(anchor_text), 6000)
class BuildSilverTest(unittest.TestCase):
    """Tests for olmocr.data.buildsilver query construction."""

    def testSmallPage(self):
        # Even a tiny source page must be rendered so its longest side is
        # exactly 2048 px in the OpenAI batch query payload.
        local_pdf_path = os.path.join(os.path.dirname(__file__), "gnarly_pdfs", "small_page_size.pdf")

        from olmocr.data.buildsilver import build_page_query

        result = build_page_query(local_pdf_path, "s3://test.pdf", 1)

        from olmocr.data.renderpdf import get_png_dimensions_from_base64

        base64data = result["body"]["messages"][0]["content"][1]["image_url"]["url"]

        # Strip the data-URL prefix if present before decoding dimensions.
        if base64data.startswith("data:image/png;base64,"):
            base64data = base64data[22:]

        width, height = get_png_dimensions_from_base64(base64data)
        print(width, height)
        # Use a unittest assertion instead of a bare `assert`, which is
        # silently stripped when Python runs with -O.
        self.assertEqual(max(width, height), 2048)
class TestRenderPdf(unittest.TestCase):
    """Cross-checks the fast media-box reader against pypdf's own parsing."""

    def testFastMediaBoxMatchesPyPdf(self):
        # Every fixture PDF, every page: both implementations must agree on
        # the media-box dimensions to three decimal places.
        fixture_glob = os.path.join(os.path.dirname(__file__), "gnarly_pdfs", "*.pdf")
        for pdf_path in glob.glob(fixture_glob):
            reader = PdfReader(pdf_path)
            print("checking", pdf_path)
            for page_index, pypdf_page in enumerate(reader.pages):
                fast_width, fast_height = get_pdf_media_box_width_height(pdf_path, page_index + 1)
                self.assertAlmostEqual(fast_width, pypdf_page.mediabox.width, places=3)
                self.assertAlmostEqual(fast_height, pypdf_page.mediabox.height, places=3)
class TestOutputSamplePage(unittest.TestCase):
    """Emits one sample anchor text so it can be eyeballed in the test log."""

    def testTobaccoPaper(self):
        pdf_path = os.path.join(os.path.dirname(__file__), "gnarly_pdfs", "tobacco_missed_tokens_pg1.pdf")
        sample_text = get_anchor_text(pdf_path, 1, "pdfreport", target_length=6000)
        print("")
        print(sample_text)
        print("")
import unittest
from functools import partial
import pytest
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoProcessor
from olmocr.train.dataloader import (
build_finetuning_dataset,
extract_openai_batch_response,
list_dataset_files,
load_jsonl_into_ds,
)
from olmocr.train.dataprep import batch_prepare_data_for_qwen2_training
@pytest.mark.nonci
class TestBatchQueryResponseDataset(unittest.TestCase):
    """Integration tests for the S3-backed fine-tuning dataset loaders.

    Marked ``nonci`` because they read from live S3 buckets and download a
    HuggingFace processor.
    """

    def testLoadS3(self):
        ds = load_jsonl_into_ds("s3://ai2-oe-data/jakep/openai_batch_data_v2/*.jsonl", first_n_files=3)
        print(f"Loaded {len(ds)} entries")
        print(ds)
        print(ds["train"])

    def testFinetuningDS(self):
        ds = build_finetuning_dataset(
            response_glob_path="s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_eval/*.json",
        )
        print(ds)

        processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
        ds = ds.with_transform(partial(batch_prepare_data_for_qwen2_training, processor=processor, target_longest_image_dim=1024, target_anchor_text_len=6000))
        print(ds[0])

    def testPlotSequenceLengthHistogram(self):
        import plotly.express as px

        ds = build_finetuning_dataset(
            response_glob_path="s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_eval/*.json",
        )
        # Load the processor once; the original code fetched it a second
        # time after the transform, which was redundant.
        processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
        ds = ds.with_transform(partial(batch_prepare_data_for_qwen2_training, processor=processor, target_longest_image_dim=1024, target_anchor_text_len=6000))

        initial_len = len(ds)
        train_dataloader = DataLoader(ds, batch_size=1, num_workers=30, shuffle=False)

        max_seen_len = 0
        steps = 0
        sequence_lengths = []  # List to store sequence lengths

        for entry in tqdm(train_dataloader):
            num_input_tokens = entry["input_ids"].shape[1]
            max_seen_len = max(max_seen_len, num_input_tokens)
            sequence_lengths.append(num_input_tokens)  # Collecting sequence lengths

            if steps % 100 == 0:
                print(f"Max input len {max_seen_len}")

            steps += 1
            # model.forward(**{k: v.to("cuda:0") for (k,v) in entry.items()})

        print(f"Max input len {max_seen_len}")
        print(f"Total elements before filtering: {initial_len}")
        print(f"Total elements after filtering: {steps}")

        # Plotting the histogram using Plotly
        fig = px.histogram(
            sequence_lengths, nbins=100, title="Distribution of Input Sequence Lengths", labels={"value": "Sequence Length", "count": "Frequency"}
        )
        fig.write_image("sequence_lengths_histogram.png")
import base64
import os
import random
import re
import unittest
from io import BytesIO
from unittest.mock import patch
import numpy as np
import pytest
import requests
import torch
from PIL import Image
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoProcessor
from olmocr.train.core.config import DataConfig, SourceConfig, TrainConfig
from olmocr.train.dataloader import build_finetuning_dataset
from olmocr.train.dataprep import (
batch_prepare_data_for_molmo_training,
build_finetuning_prompt,
prepare_data_for_molmo_training,
prepare_data_for_qwen2_training,
)
from olmocr.train.utils import make_dataset
@pytest.mark.nonci
class TestDataprep(unittest.TestCase):
    """Tests for Qwen2-VL data preparation: tensor dtypes, end-token
    placement, and label masking produced by ``make_dataset``.

    Marked ``nonci``: requires S3 access and HuggingFace model downloads.
    """

    def testFullDataloader(self):
        processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
        config = TrainConfig(
            train_data=DataConfig(
                seed=42,
                sources=[
                    SourceConfig(
                        name="eval_test",
                        target_longest_image_dim=1024,
                        target_anchor_text_len=6000,
                        response_glob_path="s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_eval/*.json",
                    )
                ],
            ),
            valid_data=DataConfig(
                seed=42,
                sources=[
                    SourceConfig(
                        name="eval_test",
                        target_longest_image_dim=1024,
                        target_anchor_text_len=6000,
                        response_glob_path="s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_eval/*.json",
                    )
                ],
            ),
        )
        train_dataset, valid_dataset = make_dataset(config, processor)

        im_end_token_ids = processor.tokenizer("<|im_end|>\n", add_special_tokens=False)["input_ids"]

        # train_dataloader = DataLoader(train_dataset, batch_size=1, num_workers=4, shuffle=False)
        for entry in train_dataset:
            print({x: (y.shape, y.dtype) for (x, y) in entry.items()})

            self.assertEqual(entry["input_ids"].dtype, np.int64)
            self.assertEqual(entry["attention_mask"].dtype, np.int64)
            self.assertEqual(entry["labels"].dtype, np.int64)
            self.assertEqual(entry["pixel_values"].dtype, np.float32)
            self.assertEqual(entry["image_grid_thw"].dtype, np.int64)

            # Extract input_ids and labels
            input_ids = entry["input_ids"]
            labels = entry["labels"]

            # 1. Verify that the last token is the end token
            # Ensure input_ids is long enough
            self.assertTrue(len(input_ids) >= len(im_end_token_ids), "Input IDs are shorter than the end token sequence.")

            # Compare the last tokens of input_ids with im_end_token_ids
            self.assertEqual(
                input_ids[-len(im_end_token_ids) :].tolist(), im_end_token_ids, "The last tokens of input_ids do not match the end token sequence."
            )

            # 2. Ensure labels are masked correctly and match input_ids after the mask
            # Find where labels start being non-masked (-100 is the mask value)
            label_indices = np.where(labels != -100)[0]

            # There should be at least one label that is not masked
            self.assertTrue(len(label_indices) > 0, "No unmasked labels found in labels array.")

            first_label_index = label_indices[0]

            # Ensure the masked portion is at least 10 tokens long
            self.assertTrue(first_label_index >= 10, "Masked portion of labels is less than 10 tokens.")

            # Check that all values before first_label_index are -100
            self.assertTrue(np.all(labels[:first_label_index] == -100), "Labels before the first unmasked token are not all -100.")

            # Check that the unmasked labels match the corresponding input_ids
            self.assertTrue(
                np.array_equal(labels[first_label_index:], input_ids[first_label_index:]), "Unmasked labels do not match the corresponding input_ids."
            )

            # Optionally, verify that the last unmasked tokens in labels match the end token IDs
            unmasked_labels = labels[labels != -100]
            self.assertEqual(
                unmasked_labels[-len(im_end_token_ids) :].tolist(), im_end_token_ids, "The last unmasked tokens in labels do not match the end token sequence."
            )

    def testListTargetAnchorLength(self):
        processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
        config = TrainConfig(
            train_data=DataConfig(
                seed=42,
                sources=[
                    SourceConfig(
                        name="eval_test",
                        target_longest_image_dim=1024,
                        target_anchor_text_len=[0, 6000],  # Only 0 and 6000
                        response_glob_path="s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_eval/*.json",
                    )
                ],
            ),
            valid_data=DataConfig(
                seed=42,
                sources=[
                    SourceConfig(
                        name="eval_test",
                        target_longest_image_dim=1024,
                        target_anchor_text_len=[0, 6000],  # Only 0 and 6000
                        response_glob_path="s3://ai2-oe-data/jakep/pdfdata/openai_batch_done_v5_1_eval/*.json",
                    )
                ],
            ),
        )

        # Set a fixed seed for reproducibility
        random.seed(42)

        train_dataset, valid_dataset = make_dataset(config, processor)

        zero_count = 0
        full_count = 0
        num_iterations = 100

        for i in range(num_iterations):
            entry = train_dataset[0]  # Get the first entry repeatedly

            # Basic type checks
            self.assertEqual(entry["input_ids"].dtype, np.int64)
            self.assertEqual(entry["attention_mask"].dtype, np.int64)
            self.assertEqual(entry["labels"].dtype, np.int64)
            self.assertEqual(entry["pixel_values"].dtype, np.float32)
            self.assertEqual(entry["image_grid_thw"].dtype, np.int64)

            # Get the input text before the response
            # Find where labels start being non-masked (-100 is the mask value)
            label_indices = np.where(entry["labels"] != -100)[0]
            first_label_index = label_indices[0] if len(label_indices) > 0 else len(entry["input_ids"])

            # Decode the input portion to check the prompt
            input_text = processor.tokenizer.decode(entry["input_ids"][:first_label_index])

            # A zero-length anchor leaves only the page-dimension header
            # between the RAW_TEXT markers.
            pattern = r"RAW_TEXT_START\nPage dimensions: (\d+\.?\d*)x(\d+\.?\d*)\s+RAW_TEXT_END"
            match = re.search(pattern, input_text, flags=re.MULTILINE)

            if match:
                zero_count += 1
            else:
                full_count += 1

        # Verify the distribution: with two candidate lengths [0, 6000] each
        # should be chosen roughly half the time.  (The original comment
        # claimed 10%/90%, which contradicted these assertions.)
        zero_ratio = zero_count / num_iterations
        full_ratio = full_count / num_iterations

        print(zero_count, full_count)

        self.assertTrue(0.45 <= zero_ratio <= 0.55, f"Expected zero-length ratio around 0.5, got {zero_ratio:.2f}")
        self.assertTrue(0.45 <= full_ratio <= 0.55, f"Expected full-length ratio around 0.5, got {full_ratio:.2f}")

        # Verify total adds up to 100%
        self.assertEqual(zero_count + full_count, num_iterations, "Total count should equal number of iterations")
@pytest.mark.nonci
class TestMolmoDataPrep(unittest.TestCase):
    """Tests for the Molmo variants of the training data-prep functions.

    PDF rendering and anchor extraction are mocked out so these tests only
    exercise tokenization, label masking, and tensor packaging.  Marked
    ``nonci`` because the Molmo processor is downloaded from HuggingFace.
    """

    def testMolmoDefaultSetup(self):
        # Sanity-check the stock Molmo processor on an arbitrary web image.
        processor = AutoProcessor.from_pretrained("allenai/Molmo-7B-O-0924", trust_remote_code=True, torch_dtype="auto", device_map="auto")

        # process the image and text
        inputs = processor.process(images=[Image.open(requests.get("https://picsum.photos/id/237/536/354", stream=True).raw)], text="Describe this image.")

        print(inputs.keys())
        print(inputs["input_ids"])
        print(processor.tokenizer.batch_decode(inputs["input_ids"]))

        labels = processor.tokenizer("This is a page of the pdf that's the text", return_tensors="np")
        print(labels)
        print(processor.tokenizer.batch_decode(labels["input_ids"]))

    def testMolmoDataPrep(self):
        # Initialize the processor for Molmo
        processor = AutoProcessor.from_pretrained("allenai/Molmo-7B-O-0924", trust_remote_code=True, torch_dtype="auto", device_map="auto")

        # Create a mock example
        example = {
            "local_pdf_path": os.path.join(os.path.dirname(__file__), "gnarly_pdfs", "edgar.pdf"),
            "page_num": 1,
            "response": "This is the response text.",
        }

        # Define target dimensions and anchor text lengths
        target_longest_image_dim = [1024]
        target_anchor_text_len = [0, 6000]

        # Set a fixed seed for reproducibility
        random.seed(42)

        # Mock the functions that require actual PDF files
        with (
            patch("olmocr.prompts.anchor.get_anchor_text") as mock_get_anchor_text,
            patch("olmocr.data.renderpdf.render_pdf_to_base64png") as mock_render_pdf_to_base64png,
        ):
            # Set return values for the mocked functions
            mock_get_anchor_text.return_value = "This is the anchor text."

            # Create a red square image and encode it in base64
            img = Image.new("RGB", (100, 100), color="red")
            buffered = BytesIO()
            img.save(buffered, format="PNG")
            img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
            mock_render_pdf_to_base64png.return_value = img_str

            # Process the example using the prepare_data_for_molmo_training function
            processed_example = prepare_data_for_molmo_training(
                example, processor, target_longest_image_dim=target_longest_image_dim, target_anchor_text_len=target_anchor_text_len
            )

            # Basic type checks
            self.assertIsInstance(processed_example["input_ids"], torch.Tensor, "input_ids should be a torch.Tensor")
            self.assertIsInstance(processed_example["attention_mask"], torch.Tensor, "attention_mask should be a torch.Tensor")
            self.assertIsInstance(processed_example["labels"], torch.Tensor, "labels should be a torch.Tensor")
            self.assertIsInstance(processed_example["images"], torch.Tensor, "images should be a torch.Tensor")
            self.assertIsInstance(processed_example["image_input_idx"], torch.Tensor, "image_input_idx should be a torch.Tensor")
            self.assertIsInstance(processed_example["image_masks"], torch.Tensor, "image_masks should be a torch.Tensor")

            # Check tensor dimensions
            self.assertEqual(len(processed_example["input_ids"].shape), 1, "input_ids should be a 1D tensor")
            self.assertEqual(
                processed_example["input_ids"].shape, processed_example["attention_mask"].shape, "input_ids and attention_mask should have the same shape"
            )
            self.assertEqual(processed_example["input_ids"].shape, processed_example["labels"].shape, "input_ids and labels should have the same shape")

            # Verify label masking
            # Find where labels start being non-masked (-100 is the mask value)
            label_indices = torch.where(processed_example["labels"] != -100)[0]

            # There should be at least one label that is not masked
            self.assertTrue(len(label_indices) > 0, "No unmasked labels found in labels array.")

            first_label_index = label_indices[0]

            # Ensure the masked portion is reasonable (at least a few tokens long)
            self.assertTrue(first_label_index >= 5, "Masked portion of labels is too short")

            # Check that all values before first_label_index are -100
            self.assertTrue(torch.all(processed_example["labels"][:first_label_index] == -100), "Labels before the first unmasked token are not all -100.")

            # Verify attention mask
            self.assertTrue(torch.all(processed_example["attention_mask"] == 1), "All attention mask values should be 1")

            # Verify image input indices
            self.assertTrue(
                torch.all(processed_example["image_input_idx"] < len(processed_example["input_ids"])),
                "Image input indices should be within the range of input_ids length",
            )

            # Decode and verify content structure
            decoded_input = processor.tokenizer.decode(processed_example["input_ids"])
            self.assertIn("This is the anchor text", decoded_input, "Anchor text should be present in the decoded input")

            # Verify that unmasked labels decode to the response text
            unmasked_labels = processed_example["labels"][processed_example["labels"] != -100]
            decoded_labels = processor.tokenizer.decode(unmasked_labels)
            self.assertIn("This is the response text", decoded_labels, "Response text should be present in the decoded labels")

    def testBatchMolmoDataPrep(self):
        """Test the batch preparation function for Molmo"""
        processor = AutoProcessor.from_pretrained("allenai/Molmo-7B-O-0924", trust_remote_code=True, torch_dtype="auto", device_map="auto")

        # Create a mock batch
        batch = {
            "local_pdf_path": [os.path.join(os.path.dirname(__file__), "gnarly_pdfs", "edgar.pdf")],
            "page_num": [1],
            "response": ["This is the response text."],
        }

        target_longest_image_dim = [1024]
        target_anchor_text_len = [0, 6000]

        # Mock the necessary functions
        with (
            patch("olmocr.prompts.anchor.get_anchor_text") as mock_get_anchor_text,
            patch("olmocr.data.renderpdf.render_pdf_to_base64png") as mock_render_pdf_to_base64png,
        ):
            mock_get_anchor_text.return_value = "This is the anchor text."

            img = Image.new("RGB", (100, 100), color="red")
            buffered = BytesIO()
            img.save(buffered, format="PNG")
            img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
            mock_render_pdf_to_base64png.return_value = img_str

            # Process the batch
            processed_batch = batch_prepare_data_for_molmo_training(
                batch, processor, target_longest_image_dim=target_longest_image_dim, target_anchor_text_len=target_anchor_text_len
            )

            # Verify batch structure
            self.assertEqual(len(processed_batch["input_ids"]), 1, "Batch size should be 1")
            self.assertEqual(len(processed_batch["attention_mask"]), 1, "Batch size should be 1")
            self.assertEqual(len(processed_batch["labels"]), 1, "Batch size should be 1")
            self.assertEqual(len(processed_batch["images"]), 1, "Batch size should be 1")
            self.assertEqual(len(processed_batch["image_input_idx"]), 1, "Batch size should be 1")
            self.assertEqual(len(processed_batch["image_masks"]), 1, "Batch size should be 1")

            # Verify the first item in the batch
            first_item = {k: v[0] for k, v in processed_batch.items()}
            self.assertIsInstance(first_item["input_ids"], torch.Tensor, "Batch item should contain torch.Tensor")
            self.assertTrue(torch.all(first_item["attention_mask"] == 1), "All attention mask values should be 1")
import os
import unittest
from pypdf import PdfReader
from olmocr.filter import PdfFilter
class PdfFilterTest(unittest.TestCase):
    """Tests for PdfFilter's form-detection toggle."""

    def testFormLaterPages(self):
        # A PDF whose form fields appear only on later pages is filtered out
        # exactly when the form check is enabled.  Build the fixture path
        # once instead of repeating the expression.
        pdf_path = os.path.join(os.path.dirname(__file__), "gnarly_pdfs", "form_on_later_pages.pdf")

        self.filter = PdfFilter(apply_form_check=True)
        self.assertTrue(self.filter.filter_out_pdf(pdf_path))

        self.filter = PdfFilter(apply_form_check=False)
        self.assertFalse(self.filter.filter_out_pdf(pdf_path))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment