save_load.py 6.1 KB
Newer Older
LDOUBLEV's avatar
LDOUBLEV committed
1
2
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
WenmuZhou's avatar
WenmuZhou committed
3
4
5
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
LDOUBLEV's avatar
LDOUBLEV committed
6
7
8
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
WenmuZhou's avatar
WenmuZhou committed
9
10
11
12
13
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
LDOUBLEV's avatar
LDOUBLEV committed
14
15
16
17
18
19
20

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import errno
import os
WenmuZhou's avatar
WenmuZhou committed
21
22
import pickle
import six
LDOUBLEV's avatar
LDOUBLEV committed
23

WenmuZhou's avatar
WenmuZhou committed
24
import paddle
LDOUBLEV's avatar
LDOUBLEV committed
25

littletomatodonkey's avatar
littletomatodonkey committed
26
27
from ppocr.utils.logging import get_logger

Double_V's avatar
Double_V committed
28
__all__ = ['init_model', 'save_model', 'load_dygraph_params']
LDOUBLEV's avatar
LDOUBLEV committed
29
30


WenmuZhou's avatar
WenmuZhou committed
31
def _mkdir_if_not_exist(path, logger):
LDOUBLEV's avatar
LDOUBLEV committed
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
    """
    mkdir if not exists, ignore the exception when multiprocess mkdir together
    """
    if not os.path.exists(path):
        try:
            os.makedirs(path)
        except OSError as e:
            if e.errno == errno.EEXIST and os.path.isdir(path):
                logger.warning(
                    'be happy if some process has already created {}'.format(
                        path))
            else:
                raise OSError('Failed to mkdir {}'.format(path))


littletomatodonkey's avatar
littletomatodonkey committed
47
def init_model(config, model, optimizer=None, lr_scheduler=None):
    """
    Load model parameters from a checkpoint or a pretrained model.

    Priority: ``Global.checkpoints`` (full training state, enables resume)
    over ``Global.pretrained_model`` (weights only). If neither is set,
    training starts from scratch.

    Args:
        config (dict): full config; only ``config['Global']`` is read.
        model: paddle model whose state dict is populated in place.
        optimizer: optional optimizer; its state is restored only when
            resuming from a checkpoint.
        lr_scheduler: unused here, kept for interface compatibility.

    Returns:
        dict: best-metric state read from ``<checkpoints>.states`` when
        resuming (with ``start_epoch`` set when an epoch was recorded),
        otherwise an empty dict.
    """
    logger = get_logger()
    global_config = config['Global']
    checkpoints = global_config.get('checkpoints')
    pretrained_model = global_config.get('pretrained_model')
    best_model_dict = {}
    if checkpoints:
        assert os.path.exists(checkpoints + ".pdparams"), \
            "Given dir {}.pdparams not exist.".format(checkpoints)
        assert os.path.exists(checkpoints + ".pdopt"), \
            "Given dir {}.pdopt not exist.".format(checkpoints)
        para_dict = paddle.load(checkpoints + '.pdparams')
        opti_dict = paddle.load(checkpoints + '.pdopt')
        model.set_state_dict(para_dict)
        if optimizer is not None:
            optimizer.set_state_dict(opti_dict)

        # Side-car file holding best metrics and the last finished epoch.
        if os.path.exists(checkpoints + '.states'):
            with open(checkpoints + '.states', 'rb') as f:
                states_dict = pickle.load(f) if six.PY2 else pickle.load(
                    f, encoding='latin1')
            best_model_dict = states_dict.get('best_model_dict', {})
            if 'epoch' in states_dict:
                # Resume from the epoch after the one that was saved.
                best_model_dict['start_epoch'] = states_dict['epoch'] + 1
        logger.info("resume from {}".format(checkpoints))
    elif pretrained_model:
        # Accept a single path or a list of paths; later entries overwrite
        # overlapping parameters of earlier ones.
        if not isinstance(pretrained_model, list):
            pretrained_model = [pretrained_model]
        for pretrained in pretrained_model:
            if not (os.path.isdir(pretrained) or
                    os.path.exists(pretrained + '.pdparams')):
                raise ValueError("Model pretrain path {} does not "
                                 "exists.".format(pretrained))
            param_state_dict = paddle.load(pretrained + '.pdparams')
            model.set_state_dict(param_state_dict)
            # Bug fix: log the path actually loaded on this iteration
            # instead of dumping the whole list every time.
            logger.info("load pretrained model from {}".format(pretrained))
    else:
        logger.info('train from scratch')
    return best_model_dict
LDOUBLEV's avatar
LDOUBLEV committed
90
91


Double_V's avatar
Double_V committed
92
93
def load_dygraph_params(config, model, logger, optimizer):
    """
    Restore model weights, preferring a full checkpoint over a
    pretrained-weights file.

    When ``Global.checkpoints`` points at an existing ``.pdparams`` file,
    the full training state is restored through ``init_model`` and its
    best-metric dict is returned. Otherwise ``Global.pretrained_model``
    weights are loaded; keys are paired positionally with the model's
    state dict and shape-incompatible parameters are skipped with an
    info message.

    Returns:
        dict: best-metric state when resuming from a checkpoint,
        otherwise an empty dict.
    """
    ckp = config['Global']['checkpoints']
    if ckp and os.path.exists(ckp + ".pdparams"):
        # Full resume path: model + optimizer + best metrics.
        return init_model(config, model, optimizer)

    pm = config['Global']['pretrained_model']
    if pm is None:
        return {}
    if not os.path.exists(pm) and not os.path.exists(pm + ".pdparams"):
        logger.info(f"The pretrained_model {pm} does not exists!")
        return {}

    if not pm.endswith('.pdparams'):
        pm = pm + '.pdparams'
    loaded = paddle.load(pm)
    current = model.state_dict()
    # Pair model keys with loaded keys by position; keep only the
    # shape-compatible ones. NOTE(review): positional pairing assumes both
    # state dicts enumerate parameters in the same order.
    matched = {}
    for model_key, loaded_key in zip(current.keys(), loaded.keys()):
        if list(current[model_key].shape) != list(loaded[loaded_key].shape):
            logger.info(
                f"The shape of model params {model_key} {current[model_key].shape} not matched with loaded params {loaded_key} {loaded[loaded_key].shape} !"
            )
            continue
        matched[model_key] = loaded[loaded_key]
    model.set_state_dict(matched)
    logger.info(f"loaded pretrained_model successful from {pm}")
    return {}

119

LDOUBLEV's avatar
fix bug  
LDOUBLEV committed
120
121
122
123
124
125
126
127
128
129
130
131
132
133
def load_pretrained_params(model, path):
    """
    Load pretrained weights from *path* into *model*.

    Keys are paired positionally between the model's state dict and the
    loaded parameter file; parameters whose shapes differ are skipped
    with a printed message.

    Returns:
        The model when weights were loaded, ``False`` when *path* is
        ``None`` or no parameter file exists there.
    """
    if path is None:
        return False
    if not (os.path.exists(path) or os.path.exists(path + ".pdparams")):
        print(f"The pretrained_model {path} does not exists!")
        return False

    if not path.endswith('.pdparams'):
        path = path + '.pdparams'
    loaded = paddle.load(path)
    current = model.state_dict()
    compatible = {}
    # Keep only parameters whose shapes agree with the model's.
    for k1, k2 in zip(current.keys(), loaded.keys()):
        if list(current[k1].shape) != list(loaded[k2].shape):
            print(
                f"The shape of model params {k1} {current[k1].shape} not matched with loaded params {k2} {loaded[k2].shape} !"
            )
            continue
        compatible[k1] = loaded[k2]
    model.set_state_dict(compatible)
    print(f"load pretrain successful from {path}")
    return model
Double_V's avatar
Double_V committed
141

142

143
def save_model(model,
               optimizer,
               model_path,
               logger,
               is_best=False,
               prefix='ppocr',
               **kwargs):
    """
    Persist a training snapshot under *model_path*.

    Creates the directory if needed, then writes three sibling files:
    ``<prefix>.pdparams`` (model weights), ``<prefix>.pdopt`` (optimizer
    state) and ``<prefix>.states`` (pickled ``**kwargs``, e.g. metrics).

    Args:
        model: paddle model whose state dict is saved.
        optimizer: optimizer whose state dict is saved.
        model_path (str): target directory.
        logger: logger for the completion message.
        is_best (bool): only changes the wording of the log message.
        prefix (str): file-name stem shared by the three output files.
        **kwargs: arbitrary training state pickled into the .states file.
    """
    _mkdir_if_not_exist(model_path, logger)
    save_stem = os.path.join(model_path, prefix)
    paddle.save(model.state_dict(), save_stem + '.pdparams')
    paddle.save(optimizer.state_dict(), save_stem + '.pdopt')
    # Side-car file with metrics/config so training can resume later;
    # protocol 2 keeps the pickle readable by older Python versions.
    with open(save_stem + '.states', 'wb') as states_file:
        pickle.dump(kwargs, states_file, protocol=2)
    message = ('save best model is to {}' if is_best else
               'save model in {}').format(save_stem)
    logger.info(message)