"src/turbomind/models/llama/BlockManager.h" did not exist on "06125966d7054a53458086f342734ea01dc2faf4"
demo.py 2.88 KB
Newer Older
mashun1's avatar
mashun1 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
from util.utils import get_som_labeled_img, check_ocr_box, get_caption_model_processor, get_yolo_model
import torch
from ultralytics import YOLO
from PIL import Image

def parse_args(argv=None):
    """Parse command-line options for the demo.

    Only ``--img_path`` is required. Every other option defaults to the value
    this script previously hard-coded, so existing invocations keep working.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``img_path``, ``device``, ``som_model_path``,
        ``caption_model_path``, ``box_threshold`` and ``output``.
    """
    from argparse import ArgumentParser

    parser = ArgumentParser(
        description="Parse UI elements in a single screenshot and save the labeled overlay."
    )
    parser.add_argument(
        "--img_path", type=str, required=True,
        help="image to parse, e.g. imgs/google_page.png, imgs/windows_home.png, "
             "imgs/windows_multitab.png, imgs/omni3.jpg, imgs/ios.png, imgs/word.png, imgs/excel2.png",
    )
    parser.add_argument(
        "--device", type=str, default=None,
        help="torch device to run on (default: 'cuda' if available, else 'cpu')",
    )
    parser.add_argument(
        "--som_model_path", type=str, default="weights/OmniParser-v2/icon_detect/model.pt",
        help="weights passed to get_yolo_model",
    )
    parser.add_argument(
        "--caption_model_path", type=str, default="weights/OmniParser-v2/icon_caption",
        help="florence2 caption model directory",
    )
    parser.add_argument(
        "--box_threshold", type=float, default=0.05,
        help="BOX_TRESHOLD passed to get_som_labeled_img",
    )
    parser.add_argument(
        "--output", type=str, default="demo.png",
        help="where to save the labeled image",
    )
    return parser.parse_args(argv)


def build_draw_bbox_config(image_size, reference_size=3200):
    """Scale the box/label drawing parameters to the image resolution.

    Args:
        image_size: ``(width, height)`` of the input image; only the longest
            side is used.
        reference_size: Longest side at which the base values apply unscaled.

    Returns:
        dict with ``text_scale``, ``text_thickness``, ``text_padding`` and
        ``thickness``, used as ``draw_bbox_config`` for
        ``get_som_labeled_img``. The integer entries are clamped to at least 1
        so small images still get visible boxes.
    """
    ratio = max(image_size) / reference_size
    return {
        'text_scale': 0.8 * ratio,
        'text_thickness': max(int(2 * ratio), 1),
        'text_padding': max(int(3 * ratio), 1),
        'thickness': max(int(3 * ratio), 1),
    }


def main(argv=None):
    """Run OCR, element detection and captioning on one image and save the overlay.

    Loads the icon-detection model (via ``get_yolo_model``) and the florence2
    caption model, runs ``check_ocr_box`` (PaddleOCR) on the image, labels it
    with ``get_som_labeled_img`` and writes the annotated image to
    ``--output``. Prints the device, image size, stage timings and the number
    of parsed elements.

    Args:
        argv: Optional argument list; ``None`` reads ``sys.argv``.
    """
    import base64
    import io
    import time

    import matplotlib.pyplot as plt

    args = parse_args(argv)
    image_path = args.img_path

    # Fall back to CPU instead of crashing on machines without a GPU.
    device = args.device or ('cuda' if torch.cuda.is_available() else 'cpu')

    som_model = get_yolo_model(args.som_model_path)
    som_model.to(device)
    print(f'model to {device}')

    # two choices for caption model: fine-tuned blip2 or florence2
    caption_model_processor = get_caption_model_processor(
        model_name="florence2",
        model_name_or_path=args.caption_model_path,
        device=device,
    )

    # Only the size is needed here; the helpers below read the file by path,
    # so close the handle right away.
    with Image.open(image_path) as image:
        image_size = image.size
    print('image size:', image_size)

    draw_bbox_config = build_draw_bbox_config(image_size)

    start = time.perf_counter()
    (text, ocr_bbox), _is_goal_filtered = check_ocr_box(
        image_path,
        display_img=False,
        output_bb_format='xyxy',
        goal_filtering=None,
        easyocr_args={'paragraph': False, 'text_threshold': 0.9},
        use_paddleocr=True,
    )
    ocr_done = time.perf_counter()

    labeled_img_b64, _label_coordinates, parsed_content_list = get_som_labeled_img(
        image_path,
        som_model,
        BOX_TRESHOLD=args.box_threshold,  # (sic) keyword name expected by util.utils
        output_coord_in_ratio=True,
        ocr_bbox=ocr_bbox,
        draw_bbox_config=draw_bbox_config,
        caption_model_processor=caption_model_processor,
        ocr_text=text,
        use_local_semantics=True,
        iou_threshold=0.7,
        scale_img=False,
        batch_size=128,
    )
    labeling_done = time.perf_counter()
    print(f'ocr time: {ocr_done - start:.2f}s, labeling time: {labeling_done - ocr_done:.2f}s')
    print(f'parsed elements: {len(parsed_content_list)}')

    # The labeled image comes back base64-encoded; decode it and save a plot.
    labeled_img = Image.open(io.BytesIO(base64.b64decode(labeled_img_b64)))
    plt.figure(figsize=(15, 15))
    plt.axis('off')
    plt.imshow(labeled_img)
    plt.savefig(args.output)
    plt.close()


if __name__ == "__main__":
    main()