# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is adapted from:
https://github.com/open-mmlab/mmocr/blob/main/mmocr/models/textrecog/encoders/sar_encoder.py
https://github.com/open-mmlab/mmocr/blob/main/mmocr/models/textrecog/decoders/sar_decoder.py
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math
import paddle
from paddle import ParamAttr
import paddle.nn as nn
import paddle.nn.functional as F


class SAREncoder(nn.Layer):
    """Holistic-feature encoder of SAR.

    The backbone feature map is max-pooled over its full height, the
    resulting per-column sequence is fed through a 2-layer LSTM/GRU, and the
    RNN output at the last valid time step is projected by a linear layer.

    Args:
        enc_bi_rnn (bool): If True, use bidirectional RNN in encoder.
        enc_drop_rnn (float): Dropout probability of RNN layer in encoder.
        enc_gru (bool): If True, use GRU, else LSTM in encoder.
        d_model (int): Dim of channels from backbone.
        d_enc (int): Dim of encoder RNN layer.
        mask (bool): If True, mask padding in RNN sequence.
    """

    def __init__(self,
                 enc_bi_rnn=False,
                 enc_drop_rnn=0.1,
                 enc_gru=False,
                 d_model=512,
                 d_enc=512,
                 mask=True,
                 **kwargs):
        super().__init__()
        assert isinstance(enc_bi_rnn, bool)
        assert isinstance(enc_drop_rnn, (int, float))
        assert 0 <= enc_drop_rnn < 1.0
        assert isinstance(enc_gru, bool)
        assert isinstance(d_model, int)
        assert isinstance(d_enc, int)
        assert isinstance(mask, bool)

        self.enc_bi_rnn = enc_bi_rnn
        self.enc_drop_rnn = enc_drop_rnn
        self.mask = mask

        # 2-layer RNN encoder; time_major=False -> input is bsz * T * C.
        if enc_bi_rnn:
            direction = 'bidirectional'
        else:
            direction = 'forward'
        rnn_kwargs = dict(
            input_size=d_model,
            hidden_size=d_enc,
            num_layers=2,
            time_major=False,
            dropout=enc_drop_rnn,
            direction=direction)
        if enc_gru:
            self.rnn_encoder = nn.GRU(**rnn_kwargs)
        else:
            self.rnn_encoder = nn.LSTM(**rnn_kwargs)

        # global feature transformation; a bidirectional RNN doubles the
        # output width.
        encoder_rnn_out_size = d_enc * (int(enc_bi_rnn) + 1)
        self.linear = nn.Linear(encoder_rnn_out_size, encoder_rnn_out_size)

    def forward(self, feat, img_metas=None):
        """Encode a feature map into one holistic feature per sample.

        Args:
            feat (Tensor): Backbone feature map laid out as bsz * C * H * W.
            img_metas (list | None): ``[label, ..., valid_ratio]``. Its first
                entry is only used to check the batch size; its last entry
                holds per-sample valid-width ratios and is used only when
                ``self.mask`` is True.

        Returns:
            Tensor: bsz * C' holistic feature, C' being the RNN output size.
        """
        if img_metas is not None:
            assert len(img_metas[0]) == paddle.shape(feat)[0]

        valid_ratios = None
        if img_metas is not None and self.mask:
            valid_ratios = img_metas[-1]

        h_feat = feat.shape[2]  # bsz c h w
        # Collapse the height axis so each image column becomes a time step.
        feat_v = F.max_pool2d(
            feat, kernel_size=(h_feat, 1), stride=1, padding=0)
        feat_v = feat_v.squeeze(2)  # bsz * C * W
        feat_v = paddle.transpose(feat_v, perm=[0, 2, 1])  # bsz * W * C
        holistic_feat = self.rnn_encoder(feat_v)[0]  # bsz * T * C

        if valid_ratios is not None:
            # Take, per sample, the RNN output at the last non-padded step:
            # min(T, ceil(ratio * T)) - 1.
            # NOTE(review): a ratio of exactly 0 yields step -1, i.e. the
            # last (padded) step — confirm ratios are always > 0.
            valid_hf = []
            T = paddle.shape(holistic_feat)[1]
            for i in range(paddle.shape(valid_ratios)[0]):
                valid_step = paddle.minimum(
                    T, paddle.ceil(valid_ratios[i] * T).astype('int32')) - 1
                valid_hf.append(holistic_feat[i, valid_step, :])
            valid_hf = paddle.stack(valid_hf, axis=0)
        else:
            valid_hf = holistic_feat[:, -1, :]  # bsz * C
        holistic_feat = self.linear(valid_hf)  # bsz * C

        return holistic_feat


class BaseDecoder(nn.Layer):
    """Base decoder that dispatches to a train or test decoding routine.

    Subclasses implement ``forward_train`` and ``forward_test``.
    """

    def __init__(self, **kwargs):
        super().__init__()

    def forward_train(self, feat, out_enc, targets, img_metas):
        raise NotImplementedError

    def forward_test(self, feat, out_enc, img_metas):
        raise NotImplementedError

    def forward(self,
                feat,
                out_enc,
                label=None,
                img_metas=None,
                train_mode=True):
        """Run ``forward_train`` (teacher forcing) or ``forward_test``.

        Args:
            feat: Backbone feature map.
            out_enc: Encoder output (holistic feature).
            label: Target sequence; only used when ``train_mode`` is True.
            img_metas: Extra per-batch metadata forwarded to the routine.
            train_mode (bool): Selects which routine runs.

        Returns:
            Whatever the selected routine returns.
        """
        # Stored on the instance so subclass helpers can check which mode
        # the current call runs in.
        self.train_mode = train_mode

        if train_mode:
            return self.forward_train(feat, out_enc, label, img_metas)
        return self.forward_test(feat, out_enc, img_metas)


class ParallelSARDecoder(BaseDecoder):
    """Parallel SAR decoder with 2D attention over the feature map.

    The last two class indices are reserved: ``out_channels - 2`` is the
    start token and ``out_channels - 1`` is the padding token. The
    prediction layer outputs ``out_channels - 1`` scores (padding excluded).

    Args:
        out_channels (int): Output class number, including the start and
            padding tokens.
        enc_bi_rnn (bool): If True, use bidirectional RNN in encoder.
        dec_bi_rnn (bool): If True, use bidirectional RNN in decoder.
        dec_drop_rnn (float): Dropout of RNN layer in decoder.
        dec_gru (bool): If True, use GRU, else LSTM in decoder.
        d_model (int): Dim of channels from backbone.
        d_enc (int): Dim of encoder RNN layer.
        d_k (int): Dim of channels of attention module.
        pred_dropout (float): Dropout probability of prediction layer.
        max_text_length (int): Maximum sequence length for decoding.
        mask (bool): If True, mask padding in feature map.
        pred_concat (bool): If True, concat glimpse feature from
            attention with holistic feature and hidden state.
    """

    def __init__(
            self,
            out_channels,  # 90 + unknown + start + padding
            enc_bi_rnn=False,
            dec_bi_rnn=False,
            dec_drop_rnn=0.0,
            dec_gru=False,
            d_model=512,
            d_enc=512,
            d_k=64,
            pred_dropout=0.1,
            max_text_length=30,
            mask=True,
            pred_concat=True,
            **kwargs):
        super().__init__()

        self.num_classes = out_channels
        self.enc_bi_rnn = enc_bi_rnn
        self.d_k = d_k
        self.start_idx = out_channels - 2
        self.padding_idx = out_channels - 1
        self.max_seq_len = max_text_length
        self.mask = mask
        self.pred_concat = pred_concat

        encoder_rnn_out_size = d_enc * (int(enc_bi_rnn) + 1)
        decoder_rnn_out_size = encoder_rnn_out_size * (int(dec_bi_rnn) + 1)

        # 2D attention layer: query from decoder RNN output, key from a
        # 3x3 conv over the feature map, score from a d_k -> 1 projection.
        self.conv1x1_1 = nn.Linear(decoder_rnn_out_size, d_k)
        self.conv3x3_1 = nn.Conv2D(
            d_model, d_k, kernel_size=3, stride=1, padding=1)
        self.conv1x1_2 = nn.Linear(d_k, 1)

        # Decoder RNN layer
        if dec_bi_rnn:
            direction = 'bidirectional'
        else:
            direction = 'forward'

        rnn_kwargs = dict(
            input_size=encoder_rnn_out_size,
            hidden_size=encoder_rnn_out_size,
            num_layers=2,
            time_major=False,
            dropout=dec_drop_rnn,
            direction=direction)
        if dec_gru:
            self.rnn_decoder = nn.GRU(**rnn_kwargs)
        else:
            self.rnn_decoder = nn.LSTM(**rnn_kwargs)

        # Decoder input embedding (same width as the encoder output so the
        # holistic feature can be prepended as the first decoder input).
        self.embedding = nn.Embedding(
            self.num_classes,
            encoder_rnn_out_size,
            padding_idx=self.padding_idx)

        # Prediction layer
        self.pred_dropout = nn.Dropout(pred_dropout)
        pred_num_classes = self.num_classes - 1
        if pred_concat:
            # [decoder hidden, attention glimpse, holistic feature]
            fc_in_channel = decoder_rnn_out_size + d_model + encoder_rnn_out_size
        else:
            fc_in_channel = d_model
        self.prediction = nn.Linear(fc_in_channel, pred_num_classes)

    def _2d_attention(self,
                      decoder_input,
                      feat,
                      holistic_feat,
                      valid_ratios=None):
        """Run the decoder RNN and attend over ``feat`` for every step.

        Args:
            decoder_input (Tensor): bsz * (seq_len + 1) * emb_dim.
            feat (Tensor): Backbone feature map, bsz * C * h * w.
            holistic_feat (Tensor): Encoder output, bsz * 1 * C'.
            valid_ratios: Per-sample valid-width ratios, or None for no mask.

        Returns:
            Tensor: bsz * (seq_len + 1) * num_classes class scores.
        """
        y = self.rnn_decoder(decoder_input)[0]
        # y: bsz * (seq_len + 1) * hidden_size

        attn_query = self.conv1x1_1(y)  # bsz * (seq_len + 1) * attn_size
        bsz, seq_len, attn_size = attn_query.shape
        attn_query = paddle.unsqueeze(attn_query, axis=[3, 4])
        # (bsz, seq_len + 1, attn_size, 1, 1)

        attn_key = self.conv3x3_1(feat)
        # bsz * attn_size * h * w
        attn_key = attn_key.unsqueeze(1)
        # bsz * 1 * attn_size * h * w

        # Additive attention: broadcast query over h * w, key over steps.
        attn_weight = paddle.tanh(paddle.add(attn_key, attn_query))
        # bsz * (seq_len + 1) * attn_size * h * w
        attn_weight = paddle.transpose(attn_weight, perm=[0, 1, 3, 4, 2])
        # bsz * (seq_len + 1) * h * w * attn_size
        attn_weight = self.conv1x1_2(attn_weight)
        # bsz * (seq_len + 1) * h * w * 1
        bsz, T, h, w, c = paddle.shape(attn_weight)
        assert c == 1

        if valid_ratios is not None:
            # cal mask of attention weight: columns beyond the valid width
            # get -inf so softmax gives them zero weight.
            # NOTE(review): a valid width of 0 masks every column and the
            # softmax below would produce NaN — confirm ratios are > 0.
            for i in range(paddle.shape(valid_ratios)[0]):
                valid_width = paddle.minimum(
                    w, paddle.ceil(valid_ratios[i] * w).astype("int32"))
                if valid_width < w:
                    attn_weight[i, :, :, valid_width:, :] = float('-inf')

        # Softmax jointly over all h * w spatial positions.
        attn_weight = paddle.reshape(attn_weight, [bsz, T, -1])
        attn_weight = F.softmax(attn_weight, axis=-1)

        attn_weight = paddle.reshape(attn_weight, [bsz, T, h, w, c])
        attn_weight = paddle.transpose(attn_weight, perm=[0, 1, 4, 2, 3])
        # attn_weight: bsz * T * c * h * w
        # feat: bsz * c * h * w
        attn_feat = paddle.sum(paddle.multiply(feat.unsqueeze(1), attn_weight),
                               (3, 4),
                               keepdim=False)
        # bsz * (seq_len + 1) * C

        # Linear transformation
        if self.pred_concat:
            hf_c = holistic_feat.shape[-1]
            holistic_feat = paddle.expand(
                holistic_feat, shape=[bsz, seq_len, hf_c])
            y = self.prediction(paddle.concat((y, attn_feat, holistic_feat), 2))
        else:
            y = self.prediction(attn_feat)
        # bsz * (seq_len + 1) * num_classes
        if self.train_mode:
            y = self.pred_dropout(y)

        return y

    def forward_train(self, feat, out_enc, label, img_metas):
        '''
        Teacher-forced decoding of all steps in one pass.

        img_metas: [label, valid_ratio]
        Returns bsz * seq_len * num_classes scores (the step fed by the
        holistic feature is dropped).
        '''
        if img_metas is not None:
            assert paddle.shape(img_metas[0])[0] == paddle.shape(feat)[0]

        valid_ratios = None
        if img_metas is not None and self.mask:
            valid_ratios = img_metas[-1]

        lab_embedding = self.embedding(label)
        # bsz * seq_len * emb_dim
        out_enc = out_enc.unsqueeze(1)
        # bsz * 1 * emb_dim
        in_dec = paddle.concat((out_enc, lab_embedding), axis=1)
        # bsz * (seq_len + 1) * C
        out_dec = self._2d_attention(
            in_dec, feat, out_enc, valid_ratios=valid_ratios)

        return out_dec[:, 1:, :]  # bsz * seq_len * num_classes

    def forward_test(self, feat, out_enc, img_metas):
        '''
        Greedy decoding for ``self.max_seq_len`` steps.

        img_metas: [label, valid_ratio] or None
        Returns bsz * seq_len * num_classes softmax probabilities.
        '''
        if img_metas is not None:
            assert len(img_metas[0]) == feat.shape[0]

        valid_ratios = None
        if img_metas is not None and self.mask:
            valid_ratios = img_metas[-1]

        seq_len = self.max_seq_len
        bsz = feat.shape[0]
        start_token = paddle.full(
            (bsz, ), fill_value=self.start_idx, dtype='int64')
        # bsz
        start_token = self.embedding(start_token)
        # bsz * emb_dim
        emb_dim = start_token.shape[1]
        start_token = start_token.unsqueeze(1)
        start_token = paddle.expand(start_token, shape=[bsz, seq_len, emb_dim])
        # bsz * seq_len * emb_dim
        out_enc = out_enc.unsqueeze(1)
        # bsz * 1 * emb_dim
        decoder_input = paddle.concat((out_enc, start_token), axis=1)
        # bsz * (seq_len + 1) * emb_dim

        outputs = []
        for i in range(1, seq_len + 1):
            # The whole sequence is re-decoded each step; only step i is read.
            decoder_output = self._2d_attention(
                decoder_input, feat, out_enc, valid_ratios=valid_ratios)
            char_output = decoder_output[:, i, :]  # bsz * num_classes
            char_output = F.softmax(char_output, -1)
            outputs.append(char_output)
            max_idx = paddle.argmax(char_output, axis=1, keepdim=False)
            char_embedding = self.embedding(max_idx)  # bsz * emb_dim
            # Feed the greedy prediction back as the next step's input.
            if i < seq_len:
                decoder_input[:, i + 1, :] = char_embedding

        outputs = paddle.stack(outputs, 1)  # bsz * seq_len * num_classes

        return outputs


class SARHead(nn.Layer):
    """SAR recognition head: ``SAREncoder`` followed by ``ParallelSARDecoder``.

    Args:
        in_channels (int): Channels of the backbone feature map (``d_model``
            of both encoder and decoder).
        out_channels (int): Number of classes, including the start and
            padding tokens reserved by the decoder.
        enc_dim (int): Hidden size of the encoder RNN (``d_enc``).
        max_text_length (int): Maximum decoding length at test time.
        enc_bi_rnn, enc_drop_rnn, enc_gru: Forwarded to ``SAREncoder``.
        dec_bi_rnn, dec_drop_rnn, dec_gru, d_k, pred_dropout, pred_concat:
            Forwarded to ``ParallelSARDecoder``.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 enc_dim=512,
                 max_text_length=30,
                 enc_bi_rnn=False,
                 enc_drop_rnn=0.1,
                 enc_gru=False,
                 dec_bi_rnn=False,
                 dec_drop_rnn=0.0,
                 dec_gru=False,
                 d_k=512,
                 pred_dropout=0.1,
                 pred_concat=True,
                 **kwargs):
        super(SARHead, self).__init__()

        # encoder module
        self.encoder = SAREncoder(
            enc_bi_rnn=enc_bi_rnn,
            enc_drop_rnn=enc_drop_rnn,
            enc_gru=enc_gru,
            d_model=in_channels,
            d_enc=enc_dim)

        # decoder module
        self.decoder = ParallelSARDecoder(
            out_channels=out_channels,
            enc_bi_rnn=enc_bi_rnn,
            dec_bi_rnn=dec_bi_rnn,
            dec_drop_rnn=dec_drop_rnn,
            dec_gru=dec_gru,
            d_model=in_channels,
            d_enc=enc_dim,
            d_k=d_k,
            pred_dropout=pred_dropout,
            max_text_length=max_text_length,
            pred_concat=pred_concat)

    def forward(self, feat, targets=None):
        '''
        Encode ``feat`` and decode it into per-step class scores.

        targets: [label, valid_ratio]. Required in training mode (``label``
            drives teacher forcing); optional at inference, where only
            ``valid_ratio`` is used for masking.
        Returns bsz * seq_len * num_classes: raw scores in training mode,
        softmax probabilities in eval mode.
        '''
        holistic_feat = self.encoder(feat, targets)  # bsz c

        if self.training:
            label = targets[0]  # label
            final_out = self.decoder(
                feat, holistic_feat, label, img_metas=targets)
        else:
            final_out = self.decoder(
                feat,
                holistic_feat,
                label=None,
                img_metas=targets,
                train_mode=False)
            # (bsz, seq_len, num_classes)

        return final_out