# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from paddle import nn
from ppocr.modeling.transforms import build_transform
from ppocr.modeling.backbones import build_backbone
from ppocr.modeling.necks import build_neck
from ppocr.modeling.heads import build_head

# Public API of this module.
__all__ = ['BaseModel']


class BaseModel(nn.Layer):
    """Generic OCR model: optional transform -> backbone -> optional neck -> head.

    Each stage is built from its own sub-dict of ``config``. The stage's
    ``in_channels`` is set from the ``out_channels`` of the stage that
    precedes it (or from ``config['in_channels']`` for the first stage).
    """

    def __init__(self, config):
        """
        the module for OCR.

        args:
            config (dict): the super parameters for module. Keys read here:
                'model_type' (required, forwarded to ``build_backbone``),
                'Backbone' and 'Head' (required stage configs),
                'Transform' and 'Neck' (optional; a missing key or None
                disables the stage), 'in_channels' (default 3) and
                'return_all_feats' (default False).

        Note:
            The 'in_channels' entry of every stage sub-dict that gets built
            is overwritten in place, so the caller's ``config`` is mutated.
        """
        super(BaseModel, self).__init__()
        in_channels = config.get('in_channels', 3)
        model_type = config['model_type']

        # build transform,
        # for rec, transform can be TPS, None
        # for det and cls, transform should be None,
        # if you make model differently, you can use transform in det and cls
        self.use_transform = config.get('Transform') is not None
        if self.use_transform:
            config['Transform']['in_channels'] = in_channels
            self.transform = build_transform(config['Transform'])
            in_channels = self.transform.out_channels

        # build backbone, backbone is needed for det, rec and cls
        config["Backbone"]['in_channels'] = in_channels
        self.backbone = build_backbone(config["Backbone"], model_type)
        in_channels = self.backbone.out_channels

        # build neck
        # for rec, neck can be cnn, rnn or reshape(None)
        # for det, neck can be FPN, BIFPN and so on.
        # for cls, neck should be none
        self.use_neck = config.get('Neck') is not None
        if self.use_neck:
            config['Neck']['in_channels'] = in_channels
            self.neck = build_neck(config['Neck'])
            in_channels = self.neck.out_channels

        # build head, head is needed for det, rec and cls
        config["Head"]['in_channels'] = in_channels
        self.head = build_head(config["Head"])

        # When True, forward() returns the dict of intermediate outputs
        # instead of only the head output.
        self.return_all_feats = config.get("return_all_feats", False)

    def forward(self, x, data=None, mode='Train'):
        """Run the input through every configured stage.

        args:
            x: model input, fed to the transform (if any) and then the
                backbone. Its expected shape is defined by those stages.
            data: optional targets. When not None they are passed to the
                head as ``targets``.
            mode (str): 'Train', 'Eval' or 'Test'. The head receives ``mode``
                only when ``data`` is given and mode is 'Eval' or 'Test'.

        returns:
            The head output. If ``return_all_feats`` is set, returns instead
            a dict with keys 'backbone_out', 'neck_out' and 'head_out'.
            'neck_out' equals the backbone output when no neck is used.
        """
        y = dict()
        if self.use_transform:
            x = self.transform(x)
        x = self.backbone(x)
        y["backbone_out"] = x
        if self.use_neck:
            x = self.neck(x)
        y["neck_out"] = x
        if data is None:
            x = self.head(x)
        elif mode in ('Eval', 'Test'):
            x = self.head(x, targets=data, mode=mode)
        else:
            x = self.head(x, targets=data)
        y["head_out"] = x
        if self.return_all_feats:
            return y
        return x