"vscode:/vscode.git/clone" did not exist on "e9e9bdb8d904f009e8b1e54af9f77624d481cfb2"
dataset.py 12 KB
Newer Older
1
import os
2
from abc import ABC, abstractmethod
3
from typing import Callable, Iterator
4
5

import fitz
6
from loguru import logger
7
8
9
10

from magic_pdf.config.enums import SupportedPdfParseMethod
from magic_pdf.data.schemas import PageInfo
from magic_pdf.data.utils import fitz_doc_to_image
11
from magic_pdf.filter import classify
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33


class PageableData(ABC):
    """Abstract interface of a single page.

    A page can be rendered to an image, exposes its underlying PyMuPDF page
    and its size, and can be annotated with rectangles and text.
    """

    @abstractmethod
    def get_image(self) -> dict:
        """Render the page as an image.

        Returns:
            dict: the image representation of this page
        """

    @abstractmethod
    def get_doc(self) -> fitz.Page:
        """Return the underlying PyMuPDF page.

        Returns:
            fitz.Page: the PyMuPDF page object
        """

    @abstractmethod
    def get_page_info(self) -> PageInfo:
        """Return the page info of this page.

        Returns:
            PageInfo: the page info of this page
        """

    @abstractmethod
    def draw_rect(self, rect_coords, color, fill, fill_opacity, width, overlay):
        """Draw a rectangle on the page.

        Args:
            rect_coords (list[float]): four-element array holding the top-left
                and bottom-right coordinates, [x0, y0, x1, y1]
            color (list[float] | None): three-element RGB of the border line,
                None means no border line
            fill (list[float] | None): RGB used to fill the rectangle, None
                means the rectangle is not filled
            fill_opacity (float): opacity of the fill, in the range [0, 1]
            width (float): width of the border line
            overlay (bool): whether the rectangle is drawn in the foreground
                or in the background.
                NOTE(review): the earlier wording said "True means
                background"; PyMuPDF documents overlay=True as foreground.
                Confirm which meaning implementations should follow.
        """

    @abstractmethod
    def insert_text(self, coord, content, fontsize, color):
        """Insert text on the page.

        Args:
            coord (list[float]): position of the text.
                NOTE(review): previously described as a four-element
                [x0, y0, x1, y1] array; confirm against the implementation,
                which may expect a point instead.
            content (str): the text content
            fontsize (int): font size of the text
            color (list[float] | None): three-element RGB of the text, None
                uses the default font color
        """

60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97

class Dataset(ABC):
    """Abstract interface of a collection of pages built from raw bytes."""

    @abstractmethod
    def __len__(self) -> int:
        """Return the number of pages in the dataset."""

    @abstractmethod
    def __iter__(self) -> Iterator[PageableData]:
        """Iterate over the pages of the dataset."""

    @abstractmethod
    def supported_methods(self) -> list[SupportedPdfParseMethod]:
        """Return the parse methods this dataset supports.

        Returns:
            list[SupportedPdfParseMethod]: the supported methods; valid
                methods are OCR and TXT
        """

    @abstractmethod
    def data_bits(self) -> bytes:
        """Return the bits used to create this dataset."""

    @abstractmethod
    def get_page(self, page_id: int) -> PageableData:
        """Return the page indexed by page_id.

        Args:
            page_id (int): the index of the page

        Returns:
            PageableData: the page doc object
        """

    @abstractmethod
    def dump_to_file(self, file_path: str):
        """Dump the dataset to a file.

        Args:
            file_path (str): the file path
        """

    @abstractmethod
    def apply(self, proc: Callable, *args, **kwargs):
        """Apply a callable to this dataset.

        Args:
            proc (Callable): invoked as follows:
                proc(self, *args, **kwargs)

        Returns:
            Any: the result generated by proc
        """

    @abstractmethod
    def classify(self) -> SupportedPdfParseMethod:
        """Classify the dataset.

        Returns:
            SupportedPdfParseMethod: the parse method chosen for this dataset
        """

    @abstractmethod
    def clone(self):
        """Return a clone of this dataset."""

134
135

class PymuDocDataset(Dataset):
    """Dataset backed by a PDF opened with PyMuPDF."""

    def __init__(self, bits: bytes, lang=None):
        """Initialize the dataset, which wraps the PyMuPDF document.

        Args:
            bits (bytes): the bytes of the pdf
            lang (str | None): language setting. '' is stored as None,
                'auto' stores the result of auto_detect_lang(bits), any other
                value is stored unchanged.
        """
        self._raw_fitz = fitz.open('pdf', bits)
        self._records = [Doc(v) for v in self._raw_fitz]
        self._data_bits = bits
        self._raw_data = bits
        # Cache for classify(); filled on first call.
        self._classify_result = None

        if lang == '':
            self._lang = None
        elif lang == 'auto':
            # Imported lazily, only when detection is actually requested.
            from magic_pdf.model.sub_modules.language_detection.utils import \
                auto_detect_lang
            self._lang = auto_detect_lang(bits)
            logger.info(f'lang: {lang}, detect_lang: {self._lang}')
        else:
            self._lang = lang
            logger.info(f'lang: {lang}')

    def __len__(self) -> int:
        """The page number of the pdf."""
        return len(self._records)

    def __iter__(self) -> Iterator[PageableData]:
        """Yield the page doc object."""
        return iter(self._records)

    def supported_methods(self) -> list[SupportedPdfParseMethod]:
        """The method supported by this dataset.

        Returns:
            list[SupportedPdfParseMethod]: the supported methods
        """
        return [SupportedPdfParseMethod.OCR, SupportedPdfParseMethod.TXT]

    def data_bits(self) -> bytes:
        """The pdf bits used to create this dataset."""
        return self._data_bits

    def get_page(self, page_id: int) -> PageableData:
        """The page doc object.

        Args:
            page_id (int): the page doc index

        Returns:
            PageableData: the page doc object
        """
        return self._records[page_id]

    def dump_to_file(self, file_path: str):
        """Save the wrapped document to a file.

        The parent directory is created first unless it is empty, '.' or '..'.

        Args:
            file_path (str): the file path
        """
        dir_name = os.path.dirname(file_path)
        if dir_name not in ('', '.', '..'):
            os.makedirs(dir_name, exist_ok=True)
        self._raw_fitz.save(file_path)

    def apply(self, proc: Callable, *args, **kwargs):
        """Apply a callable to this dataset.

        If the caller passes a ``lang`` keyword and this dataset has a
        language set, the dataset's language replaces the caller's value.

        Args:
            proc (Callable): invoked as follows:
                proc(dataset, *args, **kwargs)

        Returns:
            Any: the result generated by proc
        """
        if 'lang' in kwargs and self._lang is not None:
            kwargs['lang'] = self._lang
        return proc(self, *args, **kwargs)

    def classify(self) -> SupportedPdfParseMethod:
        """Classify the dataset.

        The result of the module-level classify() is computed once from the
        pdf bits and cached.

        Returns:
            SupportedPdfParseMethod: the classification result
        """
        if self._classify_result is None:
            self._classify_result = classify(self._data_bits)
        return self._classify_result

    def clone(self):
        """Clone this dataset from its raw bytes, keeping its language."""
        # Pass the resolved language so the clone behaves like the original
        # in apply(); a language detected via 'auto' is reused rather than
        # detected again.
        return PymuDocDataset(self._raw_data, lang=self._lang)

    def set_images(self, images):
        """Assign a pre-rendered image to every page.

        Args:
            images: indexable collection where images[i] belongs to page i;
                it must hold at least len(self) entries.
        """
        for idx, record in enumerate(self._records):
            record.set_image(images[idx])
233
234

class ImageDataset(Dataset):
    """Dataset backed by an image that is converted to a PDF via PyMuPDF."""

    def __init__(self, bits: bytes, lang=None):
        """Initialize the dataset, which wraps the PyMuPDF document.

        Args:
            bits (bytes): the bytes of the photo, which is converted to pdf
                first and then opened with PyMuPDF.
            lang (str | None): language setting. '' is stored as None,
                'auto' stores the result of auto_detect_lang, any other value
                is stored unchanged.
        """
        pdf_bytes = fitz.open(stream=bits).convert_to_pdf()
        self._raw_fitz = fitz.open('pdf', pdf_bytes)
        self._records = [Doc(v) for v in self._raw_fitz]
        # _raw_data keeps the original image bytes (used by clone());
        # _data_bits keeps the converted pdf bytes (returned by data_bits()).
        self._raw_data = bits
        self._data_bits = pdf_bytes

        if lang == '':
            self._lang = None
        elif lang == 'auto':
            # Imported lazily, only when detection is actually requested.
            from magic_pdf.model.sub_modules.language_detection.utils import \
                auto_detect_lang
            # Hand the detector the converted pdf bytes: PymuDocDataset feeds
            # the same helper pdf bytes, whereas `bits` here is a raw image.
            self._lang = auto_detect_lang(pdf_bytes)
            logger.info(f'lang: {lang}, detect_lang: {self._lang}')
        else:
            self._lang = lang
            logger.info(f'lang: {lang}')

    def __len__(self) -> int:
        """The length of the dataset."""
        return len(self._records)

    def __iter__(self) -> Iterator[PageableData]:
        """Yield the page object."""
        return iter(self._records)

    def supported_methods(self):
        """The method supported by this dataset.

        Returns:
            list[SupportedPdfParseMethod]: the supported methods
        """
        return [SupportedPdfParseMethod.OCR]

    def data_bits(self) -> bytes:
        """The pdf bits used to create this dataset."""
        return self._data_bits

    def get_page(self, page_id: int) -> PageableData:
        """The page doc object.

        Args:
            page_id (int): the page doc index

        Returns:
            PageableData: the page doc object
        """
        return self._records[page_id]

    def dump_to_file(self, file_path: str):
        """Save the wrapped document to a file.

        The parent directory is created first unless it is empty, '.' or '..'.

        Args:
            file_path (str): the file path
        """
        dir_name = os.path.dirname(file_path)
        if dir_name not in ('', '.', '..'):
            os.makedirs(dir_name, exist_ok=True)
        self._raw_fitz.save(file_path)

    def apply(self, proc: Callable, *args, **kwargs):
        """Apply a callable to this dataset.

        If the caller passes a ``lang`` keyword and this dataset has a
        language set, the dataset's language replaces the caller's value,
        the same way PymuDocDataset.apply does.

        Args:
            proc (Callable): invoked as follows:
                proc(dataset, *args, **kwargs)

        Returns:
            Any: the result generated by proc
        """
        if 'lang' in kwargs and self._lang is not None:
            kwargs['lang'] = self._lang
        return proc(self, *args, **kwargs)

    def classify(self) -> SupportedPdfParseMethod:
        """Classify the dataset.

        Returns:
            SupportedPdfParseMethod: always OCR for an image dataset
        """
        return SupportedPdfParseMethod.OCR

    def clone(self):
        """Clone this dataset from its raw image bytes, keeping its language."""
        # Pass the resolved language so the clone behaves like the original
        # in apply(); a language detected via 'auto' is reused rather than
        # detected again.
        return ImageDataset(self._raw_data, lang=self._lang)

    def set_images(self, images):
        """Assign a pre-rendered image to every page.

        Args:
            images: indexable collection where images[i] belongs to page i;
                it must hold at least len(self) entries.
        """
        for idx, record in enumerate(self._records):
            record.set_image(images[idx])
327
328
329

class Doc(PageableData):
    """Page wrapper initialized with a PyMuPDF page object.

    Attributes that are not defined on this wrapper are looked up on the
    wrapped page.
    """

    def __init__(self, doc: fitz.Page):
        self._doc = doc
        # Cached image; filled by get_image() or set_image().
        self._img = None

    def get_image(self) -> dict:
        """Return the image info, rendering the page on first use.

        Returns:
            dict: the value produced by fitz_doc_to_image, documented as {
                img: np.ndarray,
                width: int,
                height: int
            }
        """
        if self._img is None:
            self._img = fitz_doc_to_image(self._doc)
        return self._img

    def set_image(self, img):
        """Store a pre-rendered image unless one is already cached.

        Args:
            img: the image to cache. It is returned unchanged by get_image().
                NOTE(review): previously documented as np.ndarray, while
                get_image() documents a dict; confirm the expected structure
                against the callers.
        """
        if self._img is None:
            self._img = img

    def get_doc(self) -> fitz.Page:
        """Get the PyMuPDF page object.

        Returns:
            fitz.Page: the PyMuPDF page object
        """
        return self._doc

    def get_page_info(self) -> PageInfo:
        """Get the page info of the page.

        Returns:
            PageInfo: built from the width and height of the page rectangle
        """
        page_w = self._doc.rect.width
        page_h = self._doc.rect.height
        return PageInfo(w=page_w, h=page_h)

    def __getattr__(self, name):
        """Delegate attributes missing on the wrapper to the wrapped page.

        Raises:
            AttributeError: if the wrapped page has no such attribute, so
                hasattr() and getattr() with a default behave as expected
                instead of silently yielding None.
        """
        # __getattr__ only runs after normal lookup failed. If '_doc' itself
        # is missing (instance created without __init__, e.g. while being
        # copied or unpickled), reading self._doc below would re-enter this
        # method without end.
        if name == '_doc':
            raise AttributeError(name)
        return getattr(self._doc, name)

    def draw_rect(self, rect_coords, color, fill, fill_opacity, width, overlay):
        """Draw a rectangle by forwarding to the wrapped page's draw_rect.

        Args:
            rect_coords (list[float]): four-element array holding the top-left
                and bottom-right coordinates, [x0, y0, x1, y1]
            color (list[float] | None): three-element RGB of the border line,
                None means no border line
            fill (list[float] | None): RGB used to fill the rectangle, None
                means the rectangle is not filled
            fill_opacity (float): opacity of the fill, in the range [0, 1]
            width (float): width of the border line
            overlay (bool): forwarded unchanged as PyMuPDF's ``overlay``.
                NOTE(review): the earlier wording said "True means
                background"; PyMuPDF documents overlay=True as foreground.
                Confirm.
        """
        self._doc.draw_rect(
            rect_coords,
            color=color,
            fill=fill,
            fill_opacity=fill_opacity,
            width=width,
            overlay=overlay,
        )

    def insert_text(self, coord, content, fontsize, color):
        """Insert text by forwarding to the wrapped page's insert_text.

        Args:
            coord (list[float]): forwarded as the first positional argument
                of the page's insert_text.
                NOTE(review): previously described as a four-element
                [x0, y0, x1, y1] array, but PyMuPDF documents that argument
                as a point; confirm against the callers.
            content (str): the text content
            fontsize (int): font size of the text
            color (list[float] | None): three-element RGB of the text, None
                uses the default font color
        """
        self._doc.insert_text(coord, content, fontsize=fontsize, color=color)