# model_saving_utils.py
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities to save models."""

from __future__ import absolute_import
from __future__ import division
19
# from __future__ import google_type_annotations
20
21
22
23
24
25
from __future__ import print_function

import os

from absl import logging
import tensorflow as tf
26
import typing
27
28


29
30
31
32
def export_bert_model(model_export_path: typing.Text,
                      model: tf.keras.Model,
                      checkpoint_dir: typing.Optional[typing.Text] = None,
                      restore_model_using_load_weights: bool = False) -> None:
  """Export BERT model for serving which does not include the optimizer.

  Args:
    model_export_path: Path to which exported model will be saved.
    model: Keras model object to export.
    checkpoint_dir: Path from which model weights will be loaded, if
      specified. When empty/None the model is exported with its current
      weights.
    restore_model_using_load_weights: If True, weights are restored with the
      Keras `model.load_weights()` API; if False (the default), they are
      restored with the `tf.train.Checkpoint.restore()` API.
      There are 2 different ways to save checkpoints. One is using
      tf.train.Checkpoint and another is using Keras model.save_weights().
      Custom training loop implementation uses tf.train.Checkpoint API
      and Keras ModelCheckpoint callback internally uses model.save_weights()
      API. Since these two APIs cannot be used together, the model loading
      logic must take into account how the model checkpoint was saved.

  Raises:
    ValueError: when either model_export_path or model is not specified, or
      when checkpoint_dir is given but no checkpoint can be found in it.
  """
  if not model_export_path:
    raise ValueError('model_export_path must be specified.')
  if not isinstance(model, tf.keras.Model):
    raise ValueError('model must be a tf.keras.Model object.')

  if checkpoint_dir:
    if restore_model_using_load_weights:
      # Weights were written with Keras model.save_weights().
      model_weight_path = os.path.join(checkpoint_dir, 'checkpoint')
      # Explicit check rather than `assert`: asserts are stripped under
      # `python -O`, which would let an un-restored model be exported.
      if not tf.io.gfile.exists(model_weight_path):
        raise ValueError(
            'No Keras weights checkpoint found at %s.' % model_weight_path)
      model.load_weights(model_weight_path)
    else:
      # Weights were written through the tf.train.Checkpoint API.
      checkpoint = tf.train.Checkpoint(model=model)

      # Restores the model from latest checkpoint.
      latest_checkpoint_file = tf.train.latest_checkpoint(checkpoint_dir)
      if not latest_checkpoint_file:
        raise ValueError(
            'No checkpoint found in directory %s.' % checkpoint_dir)
      logging.info('Checkpoint file %s found and restoring from '
                   'checkpoint', latest_checkpoint_file)
      # Fail loudly if any object tracked by the checkpoint is missing from
      # the model, instead of exporting partially restored weights.
      checkpoint.restore(
          latest_checkpoint_file).assert_existing_objects_matched()

  # The optimizer state is not needed for serving, so it is left out of the
  # exported SavedModel.
  model.save(model_export_path, include_optimizer=False, save_format='tf')