# model_saving_utils.py
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities to save models."""

from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function

import os
import typing

from absl import logging
import tensorflow as tf


def export_bert_model(
    model_export_path: typing.Text,
    model: tf.keras.Model,
    checkpoint_dir: typing.Optional[typing.Text] = None) -> None:
  """Export BERT model for serving which does not include the optimizer.

  If `checkpoint_dir` is given, the model weights are first restored from the
  latest checkpoint found there; the model is then written with
  `model.save(..., include_optimizer=False, save_format='tf')`.

  Arguments:
      model_export_path: Path to which exported model will be saved.
      model: Keras model object to export.
      checkpoint_dir: Path from which model weights will be loaded, if
        specified.

  Raises:
    ValueError when either model_export_path or model is not specified, or
    when checkpoint_dir is specified but contains no checkpoint.
  """
  if not model_export_path:
    raise ValueError('model_export_path must be specified.')
  if not isinstance(model, tf.keras.Model):
    raise ValueError('model must be a tf.keras.Model object.')

  if checkpoint_dir:
    # Restores the model from latest checkpoint.
    checkpoint = tf.train.Checkpoint(model=model)
    latest_checkpoint_file = tf.train.latest_checkpoint(checkpoint_dir)
    # Explicit raise instead of `assert`: asserts are stripped when Python
    # runs with -O, which would silently skip this check.
    if not latest_checkpoint_file:
      raise ValueError(
          'No checkpoint found in checkpoint_dir: %s' % checkpoint_dir)
    logging.info('Checkpoint file %s found and restoring from '
                 'checkpoint', latest_checkpoint_file)
    # Fail loudly if objects already created in `model` are not matched by
    # values in the checkpoint.
    checkpoint.restore(latest_checkpoint_file).assert_existing_objects_matched()

  # SavedModel ('tf') format; the optimizer is deliberately excluded.
  model.save(model_export_path, include_optimizer=False, save_format='tf')


def export_pretraining_checkpoint(
    checkpoint_dir: typing.Text,
    model: tf.keras.Model,
    checkpoint_name: typing.Optional[
        typing.Text] = 'pretrained/bert_model.ckpt'):
  """Exports BERT model as a checkpoint without optimizer.

  The model is first restored from the latest checkpoint in `checkpoint_dir`
  and then written out again through a `tf.train.Checkpoint` that tracks only
  `model`.

  Arguments:
      checkpoint_dir: Path to where training model checkpoints are stored.
      model: Keras model object to export.
      checkpoint_name: File name or suffix path to export pretrained checkpoint,
        joined onto `checkpoint_dir`. If None, 'pretrained/bert_model.ckpt' is
        used.

  Returns:
    The path of the written checkpoint, as returned by
    `tf.train.Checkpoint.save`.

  Raises:
    ValueError when either checkpoint_dir or model is not specified, or when
    no checkpoint is found in checkpoint_dir.
  """
  if not checkpoint_dir:
    raise ValueError('checkpoint_dir must be specified.')
  if not isinstance(model, tf.keras.Model):
    raise ValueError('model must be a tf.keras.Model object.')
  if checkpoint_name is None:
    # The annotation admits None; fall back to the default suffix rather than
    # letting os.path.join fail with a TypeError.
    checkpoint_name = 'pretrained/bert_model.ckpt'

  checkpoint = tf.train.Checkpoint(model=model)

  latest_checkpoint_file = tf.train.latest_checkpoint(checkpoint_dir)
  # Explicit raise instead of `assert`, which is stripped under `python -O`.
  if not latest_checkpoint_file:
    raise ValueError(
        'No checkpoint found in checkpoint_dir: %s' % checkpoint_dir)
  logging.info('Checkpoint file %s found and restoring from '
               'checkpoint', latest_checkpoint_file)
  checkpoint.restore(latest_checkpoint_file).assert_existing_objects_matched()

  saved_path = checkpoint.save(os.path.join(checkpoint_dir, checkpoint_name))
  logging.info('Exporting the model as a new TF checkpoint: %s', saved_path)
  return saved_path


class BertModelCheckpoint(tf.keras.callbacks.Callback):
  """Keras callback that saves model at the end of every epoch."""

  def __init__(self, checkpoint_dir, checkpoint):
    """Initializes BertModelCheckpoint.

    Arguments:
      checkpoint_dir: Directory of the to be saved checkpoint file.
      checkpoint: tf.train.Checkpoint object.

    Raises:
      ValueError: if `checkpoint` is not a `tf.train.Checkpoint`.
    """
    super(BertModelCheckpoint, self).__init__()
    # `checkpoint_file_name` is a str.format template filled in on_epoch_end.
    # Escape any braces that belong to the directory itself so that only the
    # `{global_step}` field is substituted (a directory such as '/tmp/{run}'
    # would otherwise make .format() raise).
    escaped_dir = os.fspath(checkpoint_dir).replace('{', '{{').replace(
        '}', '}}')
    self.checkpoint_file_name = os.path.join(
        escaped_dir, 'bert_training_checkpoint_step_{global_step}.ckpt')
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not isinstance(checkpoint, tf.train.Checkpoint):
      raise ValueError('checkpoint must be a tf.train.Checkpoint object.')
    self.checkpoint = checkpoint

  def on_epoch_end(self, epoch, logs=None):
    """Saves `self.checkpoint`, naming the file after the optimizer step.

    Arguments:
      epoch: Index of the epoch that just finished (unused).
      logs: Metric results for the epoch (unused).
    """
    # The step count is read from the optimizer attached to the model Keras
    # is training.
    global_step = tf.keras.backend.get_value(self.model.optimizer.iterations)
    formatted_file_name = self.checkpoint_file_name.format(
        global_step=global_step)
    saved_path = self.checkpoint.save(formatted_file_name)
    logging.info('Saving model TF checkpoint to : %s', saved_path)