# model_utils.py
import time
import torch
from loguru import logger
4
import numpy as np
5
6
7
from magic_pdf.libs.clean_memory import clean_memory


8
9
def crop_img(input_res, input_np_img, crop_paste_x=0, crop_paste_y=0):
    """Crop a layout region out of an image and paste it onto a padded white canvas.

    The region is the axis-aligned box whose top-left corner is
    ``(poly[0], poly[1])`` and bottom-right corner is ``(poly[4], poly[5])``
    (values truncated with ``int``). The crop is placed on a white canvas that
    adds ``crop_paste_x`` columns on the left and right and ``crop_paste_y`` rows
    on the top and bottom.

    Parts of the box that fall outside ``input_np_img`` are left white. The
    visible pixels keep their position relative to the box, so the coordinates
    in ``return_list`` stay valid for mapping results back.

    Args:
        input_res: Mapping with a ``'poly'`` sequence; elements 0, 1, 4 and 5 are
            read as xmin, ymin, xmax, ymax.
        input_np_img: Image array indexed as ``[y, x, ...]``. The canvas is
            3-channel uint8, so this presumably is an HxWx3 uint8 image -- TODO
            confirm against callers.
        crop_paste_x: White margin (in pixels) added on each horizontal side.
        crop_paste_y: White margin (in pixels) added on each vertical side.

    Returns:
        ``(return_image, return_list)``. ``return_image`` is a uint8 array of shape
        ``(crop_new_height, crop_new_width, 3)``. ``return_list`` is
        ``[crop_paste_x, crop_paste_y, crop_xmin, crop_ymin, crop_xmax, crop_ymax,
        crop_new_width, crop_new_height]``.
    """
    poly = input_res['poly']
    crop_xmin, crop_ymin = int(poly[0]), int(poly[1])
    crop_xmax, crop_ymax = int(poly[4]), int(poly[5])

    # Canvas size: the box plus the margin on both sides.
    crop_new_width = crop_xmax - crop_xmin + crop_paste_x * 2
    crop_new_height = crop_ymax - crop_ymin + crop_paste_y * 2

    # White background; np.full avoids the temporary from ``np.ones(...) * 255``.
    return_image = np.full((crop_new_height, crop_new_width, 3), 255, dtype=np.uint8)

    # Clamp the source window to the image. Without this, a negative box coordinate
    # is interpreted by numpy as "from the end" (wrap-around). A box extending past
    # the border yields a slice smaller than the paste window, and the assignment
    # below then fails to broadcast.
    img_height, img_width = input_np_img.shape[:2]
    src_x0, src_y0 = max(crop_xmin, 0), max(crop_ymin, 0)
    src_x1, src_y1 = min(crop_xmax, img_width), min(crop_ymax, img_height)

    if src_x1 > src_x0 and src_y1 > src_y0:
        # Shift the paste position by however much of the box was clipped on the
        # top/left, so pixels land where they sit inside the box.
        dst_x0 = crop_paste_x + (src_x0 - crop_xmin)
        dst_y0 = crop_paste_y + (src_y0 - crop_ymin)
        return_image[dst_y0:dst_y0 + (src_y1 - src_y0),
                     dst_x0:dst_x0 + (src_x1 - src_x0)] = input_np_img[src_y0:src_y1, src_x0:src_x1]

    return_list = [crop_paste_x, crop_paste_y, crop_xmin, crop_ymin, crop_xmax, crop_ymax,
                   crop_new_width, crop_new_height]
    return return_image, return_list


def get_res_list_from_layout_res(layout_res):
    """Partition layout detections into OCR, table and formula lists.

    Each entry of ``layout_res`` must provide ``'category_id'`` (anything
    ``int()`` accepts). Formula entries also need a ``'poly'`` sequence.

    Routing by integer category id:
      * 13, 14            -> formula boxes (``{"bbox": [xmin, ymin, xmax, ymax]}``
                             built from poly[0], poly[1], poly[4], poly[5]);
      * 0, 1, 2, 4, 6, 7  -> OCR regions (the original dict, appended as-is);
      * 5                 -> table regions (the original dict, appended as-is);
      * anything else     -> dropped.
    What each id means semantically is defined by the layout model, which this
    module does not show.

    Returns:
        ``(ocr_res_list, table_res_list, single_page_mfdetrec_res)``.
    """
    ocr_regions = []
    table_regions = []
    formula_boxes = []

    for region in layout_res:
        category = int(region['category_id'])
        if category in (13, 14):
            poly = region['poly']
            bbox = [int(poly[0]), int(poly[1]), int(poly[4]), int(poly[5])]
            formula_boxes.append({"bbox": bbox})
        elif category in (0, 1, 2, 4, 6, 7):
            ocr_regions.append(region)
        elif category == 5:
            table_regions.append(region)

    return ocr_regions, table_regions, formula_boxes


def clean_vram(device, vram_threshold=8):
    """Call ``clean_memory`` for small-memory devices and log the time it took.

    Args:
        device: Device identifier forwarded to ``get_vram`` and ``clean_memory``.
        vram_threshold: Cleanup runs only when ``get_vram(device)`` returns a
            truthy value no larger than this. ``get_vram`` reports GiB.
    """
    vram_gb = get_vram(device)
    # get_vram yields None for devices it does not measure; skip those as well as
    # devices larger than the threshold.
    if not (vram_gb and vram_gb <= vram_threshold):
        return

    started_at = time.time()
    clean_memory(device)
    elapsed = round(time.time() - started_at, 2)
    logger.info(f"gc time: {elapsed}")


def get_vram(device):
    """Return the total memory of a CUDA or NPU device in GiB, else None.

    Args:
        device: Device string or ``torch.device``. Only the ``str()`` prefix
            ("cuda" / "npu") selects the backend that is queried.

    Returns:
        ``total_memory / 1024**3`` as a float when the device is a CUDA device
        and torch reports CUDA as available, or an NPU device and
        ``torch_npu.npu.is_available()`` is true. Otherwise None.

    Raises:
        ImportError: if an "npu" device is requested and ``torch_npu`` is not installed.
    """
    bytes_per_gib = 1024 ** 3
    device_name = str(device)

    if torch.cuda.is_available() and device_name.startswith("cuda"):
        return torch.cuda.get_device_properties(device).total_memory / bytes_per_gib

    if device_name.startswith("npu"):
        # Imported only on this branch, so torch_npu is needed only for npu devices.
        import torch_npu

        if torch_npu.npu.is_available():
            return torch_npu.npu.get_device_properties(device).total_memory / bytes_per_gib

    return None