som_tom.py 6.23 KB
Newer Older
chenzk's avatar
v1.0  
chenzk committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
import torch
from PIL import Image, ImageDraw, ImageFont
import numpy as np
import matplotlib.pyplot as plt

def som_prompting(image, pos_traces, neg_traces, draw_som_positive=False, draw_som_negative=False):
    """
    Assign shuffled numeric marks (set-of-marks) to traces and optionally draw them.

    All positive and negative traces share one pool of 1-based ids taken from
    ``torch.randperm``, so the ids of positive and negative traces are
    interleaved at random. When drawing is enabled for a group, each mark is
    rendered as a red disc with the white id, centred on the trace point
    ``traces[0, 0, i]``. ``image`` is drawn on in place.

    Args:
        image: PIL image to annotate (modified in place and also returned).
        pos_traces: tensor of positive traces, one trace per index of dim 2.
            Presumably shaped (batch, time, num_traces, 2) holding (x, y)
            points -- TODO confirm with callers.
        neg_traces: tensor of negative traces, same layout as ``pos_traces``.
        draw_som_positive: if True, draw the marks of the positive traces.
        draw_som_negative: if True, draw the marks of the negative traces.

    Returns:
        Tuple ``(image, pos_traces_som, neg_traces_som, pos_mark_ids,
        neg_mark_ids, all_idx)``:
        - ``pos_traces_som`` / ``neg_traces_som``: dict mark id -> trace slice
          ``traces[:, :, i]``;
        - ``pos_mark_ids`` / ``neg_mark_ids``: mark ids in trace order;
        - ``all_idx``: the 0-based permutation the ids come from
          (``mark id == all_idx[k] + 1``); positives use ``all_idx[:num_pos]``
          and negatives the remainder.
    """
    draw = ImageDraw.Draw(image)
    # textbbox depends only on text, font and draw mode, not on canvas size, so
    # one tiny RGB canvas serves every measurement (instead of allocating an
    # image as large as the input for each call).
    measure = ImageDraw.Draw(Image.new('RGB', (1, 1)))

    def get_text_size(text, font):
        """Return the (right, bottom) extent of ``text`` drawn at the origin."""
        _, _, width, height = measure.textbbox((0, 0), text=text, font=font)
        return width, height

    def draw_marks(point, text_size, mark_id, font):
        """Draw a red disc sized to the label, then the white id on top of it."""
        radius = max(text_size) // 2
        draw.ellipse(((point[0] - radius - 1, point[1] - radius - 1,
                       point[0] + radius + 1, point[1] + radius + 1)), fill='red')
        # -3 nudges the glyphs up to visually centre them inside the disc.
        draw.text((point[0] - text_size[0] // 2, point[1] - text_size[1] // 2 - 3),
                  str(mark_id), fill='white', font=font)

    # Grow the font until the smaller extent of the two-digit label "55"
    # reaches 3% of the image width.
    font_path = "data/utils/arial.ttf"
    fontsize = 1
    font = ImageFont.truetype(font_path, fontsize)
    while min(get_text_size("55", font)) < 0.03 * image.size[0]:
        fontsize += 1
        font = ImageFont.truetype(font_path, fontsize)

    num_pos = pos_traces.shape[2]
    num_neg = neg_traces.shape[2]
    # One shared shuffled id pool for positives and negatives.
    all_idx = torch.randperm(num_pos + num_neg)

    def assign_marks(traces, offset, draw_enabled):
        """Give each trace in ``traces`` the id all_idx[offset + i] + 1 and optionally draw it."""
        traces_som = {}
        mark_ids = []
        for i in range(traces.shape[2]):
            trace = traces[:, :, i]
            mark_id = all_idx[offset + i].item() + 1
            if draw_enabled:
                draw_marks(trace[0][0], get_text_size(str(mark_id), font), mark_id, font)
            traces_som[mark_id] = trace
            mark_ids.append(mark_id)
        return traces_som, mark_ids

    # Positives are processed (and drawn) before negatives, as before.
    pos_traces_som, pos_mark_ids = assign_marks(pos_traces, 0, draw_som_positive)
    neg_traces_som, neg_mark_ids = assign_marks(neg_traces, num_pos, draw_som_negative)

    return image, pos_traces_som, neg_traces_som, pos_mark_ids, neg_mark_ids, all_idx

def som_prompting_with_priors(image, pos_traces_som, neg_traces_som, pos_mark_ids, neg_mark_ids, all_idx, step_offset=1, draw_som_positive=False, draw_som_negative=False):
    """
    Redraw previously assigned marks at a later time step of each trace.

    For every ``mark_id -> trace`` entry of ``pos_traces_som`` /
    ``neg_traces_som`` (as returned by ``som_prompting``), the mark is drawn
    as a red disc with the white id at point ``trace[0, step]``. Here
    ``step = step_offset`` when it is smaller than ``trace.shape[1]``;
    otherwise the last step (``-1``) is used. ``image`` is drawn on in place.

    Args:
        image: PIL image to annotate (modified in place and also returned).
        pos_traces_som: dict mark id -> positive trace tensor; presumably
            shaped (batch, time, 2) -- TODO confirm with callers.
        neg_traces_som: dict mark id -> negative trace tensor, same layout.
        pos_mark_ids, neg_mark_ids, all_idx: accepted for interface
            compatibility; not used by this function.
        step_offset: time index at which to place the marks.
        draw_som_positive: if True, draw the positive marks.
        draw_som_negative: if True, draw the negative marks.

    Returns:
        The annotated ``image``.
    """
    draw = ImageDraw.Draw(image)
    # textbbox does not depend on canvas size; measure on one tiny RGB canvas
    # instead of allocating an input-sized image per measurement.
    measure = ImageDraw.Draw(Image.new('RGB', (1, 1)))

    def get_text_size(text, font):
        """Return the (right, bottom) extent of ``text`` drawn at the origin."""
        _, _, width, height = measure.textbbox((0, 0), text=text, font=font)
        return width, height

    def draw_marks(point, text_size, mark_id, font):
        """Draw a red disc sized to the label, then the white id on top of it."""
        radius = max(text_size) // 2
        draw.ellipse(((point[0] - radius - 1, point[1] - radius - 1,
                       point[0] + radius + 1, point[1] + radius + 1)), fill='red')
        # -3 nudges the glyphs up to visually centre them inside the disc.
        draw.text((point[0] - text_size[0] // 2, point[1] - text_size[1] // 2 - 3),
                  str(mark_id), fill='white', font=font)

    # Grow the font until the smaller extent of "55" reaches 2% of the width
    # (smaller marks than the 3% used when the marks are first assigned).
    font_path = "data/utils/arial.ttf"
    fontsize = 1
    font = ImageFont.truetype(font_path, fontsize)
    while min(get_text_size("55", font)) < 0.02 * image.size[0]:
        fontsize += 1
        font = ImageFont.truetype(font_path, fontsize)

    def redraw(traces_som, draw_enabled):
        """Draw every mark of ``traces_som`` at the clamped time step."""
        if not draw_enabled:
            return
        for mark_id, trace in traces_som.items():
            step = step_offset if step_offset < trace.shape[1] else -1
            point = trace[:, step][0]
            draw_marks(point, get_text_size(str(mark_id), font), mark_id, font)

    # Positives first, then negatives, matching the original drawing order.
    redraw(pos_traces_som, draw_som_positive)
    redraw(neg_traces_som, draw_som_negative)

    return image

def tom_prompting(trace, image, pos_traces, neg_traces, draw_som_positive=False, draw_som_negative=False):
    """
    Overlay the trace-of-marks (full point trajectories) on the image.

    Positive and negative traces are concatenated along dim 2 and rendered by
    ``trace.visualizer._draw_pred_tracks``. Colours come from
    ``trace.visualizer.color_map`` and are scaled to 0-255:

    * ``visualizer.mode == "rainbow"``: each trace gets a single colour chosen
      from the y coordinate of its point ``tracks[0, 0, n]``, normalised over
      all traces with ``plt.Normalize``;
    * any other mode: the colour depends only on the time step (``t / T``) and
      is shared by all traces at that step.

    ``draw_som_positive`` / ``draw_som_negative`` are accepted but not used.

    Returns:
        A new PIL image (built with ``Image.fromarray``) with the traces drawn.
    """
    visualizer = trace.visualizer
    all_tracks = torch.cat([pos_traces, neg_traces], dim=2).cpu().numpy()
    # Unpacking also enforces a 4-D array; presumably (batch, T, N, 2).
    _, num_steps, num_points, _ = all_tracks.shape
    colors = np.zeros((num_steps, num_points, 3))

    if visualizer.mode != "rainbow":
        # One colour per time step, broadcast over every trace.
        for step in range(num_steps):
            colors[step] = np.array(visualizer.color_map(step / num_steps)[:3]) * 255
    else:
        # One colour per trace, keyed on the y coordinate of tracks[0, 0, n].
        first_y = all_tracks[0, 0, :, 1]
        normalize = plt.Normalize(first_y.min(), first_y.max())
        for point_idx in range(num_points):
            rgb = visualizer.color_map(normalize(first_y[point_idx]))
            colors[:, point_idx] = np.array(rgb[:3]) * 255

    frame = np.array(image).astype(np.uint8)
    frame = visualizer._draw_pred_tracks(frame, all_tracks[0], colors)
    return Image.fromarray(frame)