extractor.py 3.91 KB
Newer Older
HHL's avatar
v  
HHL committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import torch
from torch import nn
from torch._C import device
from torchvision.ops import roi_align


def convert_to_roi_format(lines_box):
    """Merge a list of per-image box tensors into one ``roi_align`` RoI tensor.

    Args:
        lines_box: sequence of 2-D tensors, one per image; entry ``i`` holds
            that image's boxes, one per row.

    Returns:
        Tensor with the batch index ``i`` prepended as column 0 to every row
        of ``lines_box[i]``, all images concatenated in order. The index
        column uses the dtype and device of the concatenated boxes.
    """
    all_boxes = torch.cat(lines_box, dim=0)
    # new_full inherits dtype/device from all_boxes, so the index column
    # always matches the coordinate columns it is concatenated with.
    index_column = torch.cat(
        [
            all_boxes.new_full((per_image.shape[0], 1), image_idx)
            for image_idx, per_image in enumerate(lines_box)
        ],
        dim=0,
    )
    return torch.cat([index_column, all_boxes], dim=1)


class RoIPool(nn.Module):
    """Sample one feature vector per cell of a fixed grid laid over the feature map.

    The full spatial extent of ``feats`` is split into ``pool_W x pool_H``
    equal cells. Each cell is reduced with ``roi_align`` using
    ``output_size=(1, 1)`` and ``sampling_ratio=1``, so one sampling point
    is taken per cell.

    Args:
        pool_size: ``(pool_W, pool_H)`` giving the number of cells along the
            width and along the height. Non-square grids are supported.
    """

    def __init__(self, pool_size):
        super().__init__()
        self.pool_size = pool_size

    def gen_rois(self, feats):
        """Build the grid RoIs for every batch element of ``feats``.

        Args:
            feats: tensor whose first dim is the batch and whose last two
                dims are ``(H, W)``.

        Returns:
            ``(B * pool_H * pool_W, 5)`` tensor.
            - Each row is ``(batch_idx, x1, y1, x2, y2)`` in feature-map
              coordinates.
            - Rows are grouped by batch index.
            - Within a batch, rows are ordered row-major: y outer, x inner.
            - The tensor has the dtype and device of ``feats``.
        """
        *_, H, W = feats.shape
        pool_W, pool_H = self.pool_size

        cell_w = W / pool_W
        cell_h = H / pool_H

        # Cell edges: pool_W + 1 positions along x, pool_H + 1 along y.
        edges_x = torch.arange(0, pool_W + 1, 1).to(feats) * cell_w
        edges_y = torch.arange(0, pool_H + 1, 1).to(feats) * cell_h

        # torch.stack needs all four components to share one shape. Repeating
        # the x-edges pool_H times (rows) and the y-edges pool_W times (then
        # transposing) makes every component (pool_H, pool_W), which also
        # works when pool_W != pool_H.
        cells = torch.stack(
            [
                edges_x[:-1].repeat(pool_H, 1),
                edges_y[:-1].repeat(pool_W, 1).transpose(0, 1),
                edges_x[1:].repeat(pool_H, 1),
                edges_y[1:].repeat(pool_W, 1).transpose(0, 1),
            ],
            dim=-1,
        ).view(-1, 4)

        batch_size = feats.shape[0]
        num_cells = cells.shape[0]
        # Batch-major layout: all cells of image 0, then all cells of image 1, ...
        batch_ids = torch.arange(batch_size, dtype=feats.dtype, device=feats.device)
        batch_ids = batch_ids.repeat_interleave(num_cells).unsqueeze(1)
        return torch.cat([batch_ids, cells.repeat(batch_size, 1)], dim=-1)

    def forward(self, feats):
        """Pool the grid cells of ``feats``.

        Args:
            feats: ``(B, C, H, W)`` feature map, the 4-D layout ``roi_align``
                expects.

        Returns:
            ``(B, pool_W * pool_H, C)`` tensor. The cells follow the order
            produced by ``gen_rois``.
        """
        rois = self.gen_rois(feats)
        cell_feats = roi_align(
            input=feats,
            boxes=rois,
            output_size=(1, 1),
            spatial_scale=1.0,
            sampling_ratio=1,
        )
        batch_size = feats.shape[0]
        num_cells = int(self.pool_size[0] * self.pool_size[1])
        return cell_feats.reshape(batch_size, num_cells, -1)


def tensor_convert_to_roi_format(line_bboxes):
    """Flatten a batched box tensor into ``roi_align``'s ``(K, 5)`` RoI format.

    Args:
        line_bboxes: ``(B, L, 4)`` tensor holding ``L`` boxes for each of
            ``B`` images. Coordinates are presumably ``(x1, y1, x2, y2)`` as
            ``roi_align`` expects (TODO confirm with callers).

    Returns:
        ``(B * L, 5)`` tensor whose rows are ``(b, *line_bboxes[b, l])`` in
        ``(b, l)`` row-major order. The batch-index column is always float32,
        on the device of ``line_bboxes``.
    """
    B, L, _ = line_bboxes.shape
    # Broadcast the batch index along L instead of filling each sample in a
    # Python loop. float32 matches the previous `.float()` conversion.
    roi_ids = (
        torch.arange(B, dtype=torch.float32, device=line_bboxes.device)
        .view(B, 1, 1)
        .expand(B, L, 1)
    )
    rois = torch.cat([roi_ids, line_bboxes], dim=-1).reshape(-1, 5)
    return rois


class RoiFeatExtraxtor(nn.Module):
    """Pool a single feature vector per box out of a feature map.

    Args:
        scale: ``spatial_scale`` handed to ``roi_align``. Box coordinates are
            multiplied by it to land on the feature map's grid.
    """

    def __init__(self, scale):
        super().__init__()
        self.scale = scale

    def forward(self, feats, line_bboxes):
        """Return one pooled feature per box.

        Args:
            feats: 4-D ``(N, C, H, W)`` feature map, as ``roi_align`` requires.
            line_bboxes: ``(B, L, 4)`` batch of boxes; row ``b`` is pooled
                from batch element ``b`` of ``feats``.

        Returns:
            ``(B, L, C)`` tensor. The output size is ``(1, 1)``, so each box
            yields exactly ``C`` values.
        """
        roi_table = tensor_convert_to_roi_format(line_bboxes)
        pooled = roi_align(
            input=feats,
            boxes=roi_table,
            output_size=(1, 1),
            spatial_scale=self.scale,
            sampling_ratio=1
        )

        num_rois = pooled.shape[0]
        flat = pooled.reshape(num_rois, -1)
        batch_size, boxes_per_image = line_bboxes.shape[:2]
        return flat.view(batch_size, boxes_per_image, -1)


class RecogFeatExtraxtor(nn.Module):
    """Crop per-box feature patches of a caller-chosen size with ``roi_align``.

    Args:
        scale: ``spatial_scale`` handed to ``roi_align``. Box coordinates are
            multiplied by it to land on the feature map's grid.
    """

    def __init__(self, scale):
        super().__init__()
        self.scale = scale

    def forward(self, feats, line_bboxes, output_size=(1, 1)):
        """Return the raw ``roi_align`` output for every box.

        Args:
            feats: 4-D ``(N, C, H, W)`` feature map, as ``roi_align`` requires.
            line_bboxes: ``(B, L, 4)`` batch of boxes; row ``b`` is cropped
                from batch element ``b`` of ``feats``.
            output_size: pooled patch size forwarded to ``roi_align``.

        Returns:
            ``(B * L, C, out_h, out_w)`` tensor, unreshaped. It uses two
            sampling points per bin along each axis (``sampling_ratio=2``).
        """
        return roi_align(
            input=feats,
            boxes=tensor_convert_to_roi_format(line_bboxes),
            output_size=output_size,
            spatial_scale=self.scale,
            sampling_ratio=2,
        )


class ImageRegionExtractor(nn.Module):
    """Cut a fixed-size patch for every box and flatten it to a vector.

    Args:
        scale: ``spatial_scale`` handed to ``roi_align``. Box coordinates are
            multiplied by it to land on the input's grid.
        output_size: ``output_size`` handed to ``roi_align``, i.e. the
            pooled patch size.
    """

    def __init__(self, scale, output_size):
        super().__init__()
        self.scale = scale
        self.output_size = output_size

    def forward(self, images, line_bboxes):
        """Return one flattened patch per box.

        Args:
            images: 4-D ``(N, C, H, W)`` tensor, as ``roi_align`` requires.
            line_bboxes: ``(B, L, 4)`` batch of boxes; row ``b`` is cropped
                from batch element ``b`` of ``images``.

        Returns:
            ``(B, L, D)`` tensor, where ``D`` is ``C`` times the number of
            pooled cells in ``output_size``.
        """
        roi_table = tensor_convert_to_roi_format(line_bboxes)
        patches = roi_align(
            input=images,
            boxes=roi_table,
            output_size=self.output_size,
            spatial_scale=self.scale,
            sampling_ratio=1
        )

        flat = patches.reshape(patches.shape[0], -1)
        batch_size, boxes_per_image = line_bboxes.shape[:2]
        return flat.view(batch_size, boxes_per_image, -1)