base_model.py 2.93 KB
Newer Older
WenmuZhou's avatar
WenmuZhou committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from paddle import nn
19
from ppocr.modeling.transforms import build_transform
WenmuZhou's avatar
WenmuZhou committed
20
21
22
23
from ppocr.modeling.backbones import build_backbone
from ppocr.modeling.necks import build_neck
from ppocr.modeling.heads import build_head

dyning's avatar
dyning committed
24
__all__ = ['BaseModel']
WenmuZhou's avatar
WenmuZhou committed
25

WenmuZhou's avatar
WenmuZhou committed
26

dyning's avatar
dyning committed
27
class BaseModel(nn.Layer):
    """
    OCR model assembled from an optional transform, a backbone,
    an optional neck and a head, applied in that order.
    """

    def __init__(self, config):
        """
        Build the sub-modules of the OCR model from a configuration.

        args:
            config (dict): the hyper-parameters for the module. Keys read
                here: 'in_channels' (optional, defaults to 3), 'model_type',
                'Transform' (optional), 'Backbone', 'Neck' (optional) and
                'Head'. An 'in_channels' entry is written into every
                sub-config that is used, so ``config`` is modified in place.
        """
        super(BaseModel, self).__init__()

        # Running channel count: starts at the input channels and is
        # replaced by each stage's ``out_channels`` once that stage is built.
        channels = config.get('in_channels', 3)
        model_type = config['model_type']

        # Transform stage (optional).
        # For rec, the transform can be TPS or None.
        # For det and cls, the transform should be None, unless you build
        # the model differently and want a transform there as well.
        self.use_transform = ('Transform' in config and
                              config['Transform'] is not None)
        if self.use_transform:
            config['Transform']['in_channels'] = channels
            self.transform = build_transform(config['Transform'])
            channels = self.transform.out_channels

        # Backbone stage: needed for det, rec and cls.
        config["Backbone"]['in_channels'] = channels
        self.backbone = build_backbone(config["Backbone"], model_type)
        channels = self.backbone.out_channels

        # Neck stage (optional).
        # For rec, the neck can be cnn, rnn or reshape (None).
        # For det, the neck can be FPN, BiFPN and so on.
        # For cls, the neck should be None.
        self.use_neck = 'Neck' in config and config['Neck'] is not None
        if self.use_neck:
            config['Neck']['in_channels'] = channels
            self.neck = build_neck(config['Neck'])
            channels = self.neck.out_channels

        # Head stage: needed for det, rec and cls.
        config["Head"]['in_channels'] = channels
        self.head = build_head(config["Head"])

    def forward(self, x, data=None):
        """
        Run the input through transform (if any), backbone, neck (if any)
        and head.

        args:
            x: the model input, passed to the first enabled stage.
            data: optional extra input; when it is not None it is forwarded
                to the head as a second positional argument.

        returns:
            whatever the head returns.
        """
        if self.use_transform:
            x = self.transform(x)

        feats = self.backbone(x)

        if self.use_neck:
            feats = self.neck(feats)

        # The head is only given ``data`` when the caller supplied it.
        return self.head(feats) if data is None else self.head(feats, data)