model_saving_utils.py 2.89 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities to save models."""

import os

from absl import logging
import tensorflow as tf
21
import typing
22
23


24
25
26
27
def export_bert_model(model_export_path: typing.Text,
                      model: tf.keras.Model,
                      checkpoint_dir: typing.Optional[typing.Text] = None,
                      restore_model_using_load_weights: bool = False) -> None:
  """Export BERT model for serving which does not include the optimizer.

  If `checkpoint_dir` is given, weights are first restored into `model` from
  that directory; the model is then saved in TF SavedModel format without
  its optimizer state.

  Args:
      model_export_path: Path to which exported model will be saved.
      model: Keras model object to export.
      checkpoint_dir: Path from which model weights will be loaded, if
        specified.
      restore_model_using_load_weights: If True, restore weights with the
        Keras model.load_weights() API; if False, restore them with the
        tf.train.Checkpoint restore() API. There are 2 different ways to save
        checkpoints. One is using tf.train.Checkpoint and another is using
        Keras model.save_weights(). The custom training loop implementation
        uses the tf.train.Checkpoint API, while the Keras ModelCheckpoint
        callback internally uses the model.save_weights() API. Since these
        two APIs cannot be used together, the model loading logic must take
        into account how the model checkpoint was saved.

  Raises:
    ValueError: If `model_export_path` is empty, if `model` is not a
      tf.keras.Model, or if `checkpoint_dir` is given but no weights /
      checkpoint can be found in it for the selected restore API.
  """
  if not model_export_path:
    raise ValueError('model_export_path must be specified.')
  if not isinstance(model, tf.keras.Model):
    raise ValueError('model must be a tf.keras.Model object.')

  if checkpoint_dir:
    if restore_model_using_load_weights:
      # Weights written by Keras model.save_weights() are expected under the
      # 'checkpoint' name inside checkpoint_dir.
      model_weight_path = os.path.join(checkpoint_dir, 'checkpoint')
      # Explicit check rather than `assert`: asserts are stripped under
      # `python -O`, which would turn this into an obscure failure later.
      if not tf.io.gfile.exists(model_weight_path):
        raise ValueError(
            'No weights found at %s to restore with model.load_weights().' %
            model_weight_path)
      model.load_weights(model_weight_path)
    else:
      checkpoint = tf.train.Checkpoint(model=model)

      # Restores the model from latest checkpoint.
      latest_checkpoint_file = tf.train.latest_checkpoint(checkpoint_dir)
      # latest_checkpoint() returns None when the directory holds no
      # checkpoint; fail loudly instead of passing None to restore().
      if not latest_checkpoint_file:
        raise ValueError('No checkpoint found in %s.' % checkpoint_dir)
      logging.info('Checkpoint file %s found and restoring from '
                   'checkpoint', latest_checkpoint_file)
      # Fails if an already-built object of `model` has no matching value in
      # the checkpoint.
      checkpoint.restore(
          latest_checkpoint_file).assert_existing_objects_matched()

  model.save(model_export_path, include_optimizer=False, save_format='tf')