base_model.py 3.64 KB
Newer Older
MissPenguin's avatar
refine  
MissPenguin committed
1
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
WenmuZhou's avatar
WenmuZhou committed
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from paddle import nn
18
from ppocr.modeling.transforms import build_transform
WenmuZhou's avatar
WenmuZhou committed
19
20
21
22
from ppocr.modeling.backbones import build_backbone
from ppocr.modeling.necks import build_neck
from ppocr.modeling.heads import build_head

dyning's avatar
dyning committed
23
__all__ = ['BaseModel']
WenmuZhou's avatar
WenmuZhou committed
24

WenmuZhou's avatar
WenmuZhou committed
25

dyning's avatar
dyning committed
26
class BaseModel(nn.Layer):
    """Compose an OCR model from optional transform, backbone, neck and head.

    The stages are built in order (transform -> backbone -> neck -> head) and
    each stage's ``out_channels`` is passed on as the next stage's
    ``in_channels``.
    """

    def __init__(self, config):
        """
        Build the sub-modules described by ``config``.

        args:
            config (dict): the hyper-parameters for the module. Keys read here:
                - 'in_channels' (optional, defaults to 3): channels of the
                  input fed to the first stage.
                - 'model_type' (required): forwarded to ``build_backbone``.
                - 'Transform' (optional): transform config; missing or None
                  disables the transform stage.
                - 'Backbone' (required): backbone config.
                - 'Neck' (optional): neck config; missing or None disables
                  the neck stage.
                - 'Head' (optional): head config; missing or None disables
                  the head stage.
                - 'return_all_feats' (optional, defaults to False): see
                  ``forward``.

        Note: the sub-configs are modified in place; an 'in_channels' entry
        is written into each sub-config that is used.
        """
        super(BaseModel, self).__init__()
        in_channels = config.get('in_channels', 3)
        model_type = config['model_type']

        # build transform,
        # for rec, transform can be TPS or None
        # for det and cls, transform should be None,
        # if you build the model differently, you can use a transform in
        # det and cls
        transform_cfg = config.get('Transform')
        self.use_transform = transform_cfg is not None
        if self.use_transform:
            transform_cfg['in_channels'] = in_channels
            self.transform = build_transform(transform_cfg)
            in_channels = self.transform.out_channels

        # build backbone, backbone is needed for det, rec and cls
        config["Backbone"]['in_channels'] = in_channels
        self.backbone = build_backbone(config["Backbone"], model_type)
        in_channels = self.backbone.out_channels

        # build neck
        # for rec, neck can be cnn, rnn or reshape(None)
        # for det, neck can be FPN, BIFPN and so on.
        # for cls, neck should be None
        neck_cfg = config.get('Neck')
        self.use_neck = neck_cfg is not None
        if self.use_neck:
            neck_cfg['in_channels'] = in_channels
            self.neck = build_neck(neck_cfg)
            in_channels = self.neck.out_channels

        # build head, head is needed for det, rec and cls
        head_cfg = config.get('Head')
        self.use_head = head_cfg is not None
        if self.use_head:
            head_cfg['in_channels'] = in_channels
            self.head = build_head(head_cfg)

        self.return_all_feats = config.get("return_all_feats", False)

    def forward(self, x, data=None):
        """
        Run the enabled stages in order and return the model output.

        args:
            x: the model input, passed to the first enabled stage.
            data: extra targets, forwarded to the head as ``targets=data``.

        returns:
            - If ``return_all_feats`` is False: the output of the last
              enabled stage, unchanged.
            - If ``return_all_feats`` is True and the layer is training: a
              dict with 'backbone_out', 'neck_out' and either 'head_out' or
              the keys of the dict produced by the last stage.
            - If ``return_all_feats`` is True and the layer is not training:
              ``{"head_out": ...}`` when a 'head_out' entry is available,
              otherwise the dict produced by the last stage.
        """
        y = {}
        if self.use_transform:
            x = self.transform(x)
        x = self.backbone(x)
        y["backbone_out"] = x
        if self.use_neck:
            x = self.neck(x)
        # recorded even when no neck is used; it then equals the backbone output
        y["neck_out"] = x
        if self.use_head:
            x = self.head(x, targets=data)

        # for multi head, save ctc neck out for udml
        if isinstance(x, dict) and 'ctc_neck' in x:
            y["neck_out"] = x["ctc_neck"]
            y["head_out"] = x
        elif isinstance(x, dict):
            y.update(x)
        else:
            y["head_out"] = x

        if not self.return_all_feats:
            return x
        if self.training:
            return y
        if "head_out" in y:
            return {"head_out": y["head_out"]}
        # A dict output without 'ctc_neck' and without its own 'head_out' key
        # leaves no 'head_out' entry in y; return that dict directly instead
        # of raising KeyError.
        return x