# batch_build_dataset.py
import concurrent.futures
import heapq

import fitz

from magic_pdf.data.dataset import PymuDocDataset
from magic_pdf.data.utils import fitz_doc_to_image  # PyMuPDF


def partition_array_greedy(arr, k):
    """Split item indices into ``k`` groups with roughly equal total weight.

    Greedy "largest first" heuristic: items are visited in descending weight
    order and each one is assigned to the group whose running total is
    currently the smallest (lowest group index wins ties).

    Parameters:
    -----------
    arr : sequence of tuples
        Items to partition. Only ``item[1]`` is read; it is the item's
        numeric weight (e.g. ``(pdf_path, num_pages)``).
    k : int
        Desired number of partitions. Clamped to ``len(arr)`` when larger,
        so no partition is ever empty.

    Returns:
    --------
    partitions : list of lists of int
        Indices into ``arr`` (not the items themselves). Every index appears
        in exactly one partition. An empty ``arr`` yields ``[]``.

    Raises:
    -------
    ValueError
        If ``k`` is not positive.
    """
    if k <= 0:
        raise ValueError('k must be a positive integer')
    n = len(arr)
    k = min(k, n)  # never create empty partitions
    if k == 1:
        return [list(range(n))]
    if k == n:
        # One item per partition; also covers empty input (k == n == 0).
        return [[i] for i in range(n)]

    # sorted() is stable even with reverse=True, so equal weights keep their
    # original relative order.
    order = sorted(range(n), key=lambda i: arr[i][1], reverse=True)

    partitions = [[] for _ in range(k)]
    # Min-heap of (running_sum, partition_index). Comparing tuples breaks
    # equal sums on the lower index, same as scanning for the first minimum.
    # The initial list is already sorted, hence a valid heap.
    heap = [(0, p) for p in range(k)]
    for idx in order:
        total, p = heap[0]
        partitions[p].append(idx)  # store the original index
        heapq.heapreplace(heap, (total + arr[idx][1], p))
    return partitions


def process_pdf_batch(pdf_jobs, idx):
    """Render every page of each PDF in ``pdf_jobs`` with ``fitz_doc_to_image``.

    The PDFs are processed sequentially within this call; parallelism comes
    from the caller submitting several batches to a process pool.

    Parameters:
    -----------
    pdf_jobs : list of tuples
        ``(pdf_path, page_count)`` pairs. Only ``pdf_path`` is used; every
        page of the opened document is rendered regardless of the count.
    idx : int
        Caller-chosen tag (the partition number), echoed back unchanged so
        results arriving out of order can be matched to their batch.

    Returns:
    --------
    (idx, images) : tuple
        ``images[i]`` is the list of ``fitz_doc_to_image`` results for
        ``pdf_jobs[i]``, one entry per page, in page order.
    """
    images = []
    for pdf_path, _ in pdf_jobs:
        # Close each document as soon as its pages are rendered instead of
        # leaving the file handle to the garbage collector.
        with fitz.open(pdf_path) as doc:
            images.append([fitz_doc_to_image(page) for page in doc])
    return idx, images


def batch_build_dataset(pdf_paths, k, lang=None):
    """Build one ``PymuDocDataset`` per PDF, rendering pages in parallel.

    PDFs are grouped into at most ``k`` partitions with roughly equal total
    page counts (``partition_array_greedy``). Each partition is rendered in
    a worker process by ``process_pdf_batch``. The datasets are then built
    in this process and given the rendered images.

    Parameters:
    -----------
    pdf_paths : list of str
        Paths of the PDF files to load.
    k : int
        Maximum number of partitions / worker processes. Must be positive.
    lang : str or None
        Passed through unchanged to ``PymuDocDataset``.

    Returns:
    --------
    results : list
        ``results[i]`` is the ``PymuDocDataset`` for ``pdf_paths[i]``, or
        ``None`` if that PDF could not be opened or its partition failed to
        render (both failures are reported with ``print``).

    Raises:
    -------
    ValueError
        If ``k`` is not positive (raised by ``partition_array_greedy``).
    """
    # Page counts of the PDFs that could be opened. pdf_info[j] describes
    # pdf_paths[source_index[j]]; unreadable files are skipped, so these
    # lists may be shorter than pdf_paths.
    pdf_info = []
    source_index = []
    for path_idx, pdf_path in enumerate(pdf_paths):
        try:
            with fitz.open(pdf_path) as doc:
                num_pages = len(doc)
        except Exception as e:
            print(f'Error opening {pdf_path}: {e}')
            continue
        pdf_info.append((pdf_path, num_pages))
        source_index.append(path_idx)

    # Balance the partitions by page count, since every page gets rendered.
    partitions = partition_array_greedy(pdf_info, k)

    results = [None] * len(pdf_paths)
    if not partitions:
        return results

    # Render each partition in its own process; results are keyed by the
    # partition number echoed back by process_pdf_batch.
    all_images_h = {}
    with concurrent.futures.ProcessPoolExecutor(
        max_workers=len(partitions)
    ) as executor:
        futures = [
            executor.submit(
                process_pdf_batch,
                [pdf_info[info_idx] for info_idx in partition],
                sn,
            )
            for sn, partition in enumerate(partitions)
        ]
        for future in concurrent.futures.as_completed(futures):
            try:
                sn, images = future.result()
            except Exception as e:
                print(f'Error processing partition: {e}')
                continue
            all_images_h[sn] = images

    for sn, partition in enumerate(partitions):
        if sn not in all_images_h:
            # Rendering failed for this partition (already reported above);
            # its PDFs stay None rather than crashing on a missing key.
            continue
        for info_idx, page_images in zip(partition, all_images_h[sn]):
            with open(pdf_info[info_idx][0], 'rb') as f:
                pdf_bytes = f.read()
            dataset = PymuDocDataset(pdf_bytes, lang=lang)
            dataset.set_images(page_images)
            # Map back through source_index so skipped (unopenable) PDFs do
            # not shift later results out of alignment with pdf_paths.
            results[source_index[info_idx]] = dataset
    return results