# run_olmocr_pipeline.py 3.03 KB
# Newer Older
# wanglch's avatar
# wanglch committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import asyncio
import logging
from dataclasses import dataclass
from typing import Optional

# Import necessary components from olmocr
from olmocr.pipeline import (
    MetricsKeeper,
    PageResult,
    WorkerTracker,
    process_page,
    sglang_server_host,
    sglang_server_ready
)

# Configure process-wide (root) logging at INFO when this module is imported,
# then obtain the named logger that this runner writes its messages to.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger("olmocr_runner")


# Basic configuration
@dataclass
class Args:
    """Settings handed to the olmocr pipeline helpers used by this runner.

    Instances are passed as ``args`` to ``sglang_server_host`` and
    ``process_page``. The field order defines the positional constructor,
    so keep it stable.
    """

    # Model checkpoint identifier (defaults to the olmOCR 7B preview).
    model: str = "allenai/olmOCR-7B-0225-preview"
    # Name of the chat template to use with that model.
    model_chat_template: str = "qwen2-vl"
    # Maximum model context size (presumably in tokens — confirm in olmocr).
    model_max_context: int = 8_192
    # Target size of the longest side of the rendered page image
    # (presumably pixels — confirm in olmocr).
    target_longest_image_dim: int = 1_024
    # Target length of anchor text taken from the PDF (presumably characters).
    target_anchor_text_len: int = 6_000
    # Retry budget for a single page.
    max_page_retries: int = 8
    # Tolerated page error rate; not read anywhere in this file
    # (presumably used by olmocr's batch workers — verify).
    max_page_error_rate: float = 0.004


def _ensure_pipeline_globals() -> None:
    """Ensure olmocr.pipeline has ``metrics``/``tracker`` objects and alias them here.

    A ``global`` statement binds names in *this* module only. A function imported
    from olmocr.pipeline (such as ``process_page``) resolves its module-level names
    in olmocr.pipeline's namespace. So if the pipeline relies on module-level
    ``metrics``/``tracker`` (as the original "ensure initialized" code implies),
    they must exist on that module. The aliases keep the module-level names this
    runner has always created, now pointing at the objects the pipeline sees.
    """
    global metrics, tracker
    import olmocr.pipeline as olmocr_pipeline

    if getattr(olmocr_pipeline, "metrics", None) is None:
        olmocr_pipeline.metrics = MetricsKeeper(window=60 * 5)
    if getattr(olmocr_pipeline, "tracker", None) is None:
        olmocr_pipeline.tracker = WorkerTracker()
    metrics = olmocr_pipeline.metrics
    tracker = olmocr_pipeline.tracker


async def _ensure_sglang_server(args: "Args") -> None:
    """Reuse a responsive sglang server, or start one in the background and wait for it.

    Raises:
        Whatever ``sglang_server_ready()`` raises if the server never becomes ready.
    """
    global _server_task
    try:
        await asyncio.wait_for(sglang_server_ready(), timeout=5)
    except Exception:
        # A timeout (or any failure) of the quick probe means no usable server yet.
        pass
    else:
        logger.info("Using existing sglang server")
        return

    existing = globals().get("_server_task")
    if existing is None or existing.done():
        logger.info("Starting new sglang server")
        # Keep a module-level reference: the event loop holds only weak references
        # to tasks, so a task referenced only by a local variable may be garbage
        # collected after this coroutine returns. The server is left running for
        # reuse by later calls made on the same event loop.
        _server_task = asyncio.create_task(sglang_server_host(args, asyncio.Semaphore(1)))
    else:
        # A server started by an earlier call on this loop is still booting;
        # starting a second one would just compete for the same resources.
        logger.info("Waiting for sglang server started by an earlier call")
    await sglang_server_ready()


async def run_olmocr_pipeline(
    pdf_path: str, page_num: int = 1, args: Optional["Args"] = None
) -> Optional[str]:
    """
    Process a single page of a PDF using the official olmocr pipeline's process_page function.

    An sglang server that already answers the readiness probe within 5 seconds is
    reused. Otherwise one is started as a background task on the current event
    loop and awaited until ready.

    Args:
        pdf_path: Path to the PDF file (used as both the original and local path).
        page_num: Page number to process (1-indexed).
        args: Pipeline configuration; a default ``Args()`` is used when omitted.

    Returns:
        The extracted text from the page, or None if the page number is invalid,
        page processing raised, or the pipeline produced no text.

    Raises:
        Whatever ``sglang_server_ready()`` raises if a newly started server never
        becomes ready (server startup is outside the page-processing error handling).
    """
    # Fail fast: pages are 1-indexed, so anything below 1 cannot succeed and
    # should not cost a server start plus a full retry cycle.
    if page_num < 1:
        logger.error(f"Invalid page number {page_num}: pages are 1-indexed")
        return None

    _ensure_pipeline_globals()

    if args is None:
        args = Args()
    worker_id = 0  # Using 0 as default worker ID

    await _ensure_sglang_server(args)

    try:
        # process_page expects both an original path and a local path; for a
        # local file they are the same.
        page_result: PageResult = await process_page(
            args=args,
            worker_id=worker_id,
            pdf_orig_path=pdf_path,
            pdf_local_path=pdf_path,
            page_num=page_num,
        )

        if not page_result or not page_result.response:
            return None

        # NOTE(review): olmocr's PageResult is assumed to expose `is_fallback` when
        # process_page gave up and substituted non-OCR fallback text. Confirm this
        # against the installed olmocr version; getattr makes this a no-op if the
        # field is absent.
        if getattr(page_result, "is_fallback", False):
            logger.warning(
                f"olmocr returned fallback (non-OCR) text for {pdf_path} page {page_num}"
            )
        return page_result.response.natural_text

    except Exception as e:
        # logger.exception logs at ERROR level and keeps the traceback.
        logger.exception(f"Error processing page: {type(e).__name__} - {str(e)}")
        return None
    # The sglang server is deliberately left running for potential reuse.


async def main(pdf_path: str = "your_pdf_path.pdf", page_num: int = 1) -> None:
    """Example usage: extract one page's text and print a short preview.

    Args:
        pdf_path: PDF to process; the default is a placeholder to replace.
        page_num: Page to process (1-indexed).
    """
    preview_len = 200  # Print at most the first 200 chars

    result = await run_olmocr_pipeline(pdf_path, page_num)
    # run_olmocr_pipeline signals "no text obtained" with None; test for that
    # rather than truthiness so an empty-string result is not reported as a failure.
    if result is None:
        print("Failed to extract text from the page")
        return

    # Only mark the preview as truncated when text was actually cut off.
    suffix = "..." if len(result) > preview_len else ""
    print(f"Extracted text: {result[:preview_len]}{suffix}")


if __name__ == "__main__":
    # Script entry point: build the example coroutine and drive it to completion
    # on a fresh event loop.
    _entry = main()
    asyncio.run(_entry)