model_helpers.py 3.28 KB
Newer Older
Hongkun Yu's avatar
Hongkun Yu committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15
16
17
18
"""Miscellaneous functions that can be called by models."""

import numbers

19
from absl import logging
20
import tensorflow as tf
21

22
from tensorflow.python.util import nest
23
# pylint:disable=logging-format-interpolation
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49


def past_stop_threshold(stop_threshold, eval_metric):
  """Decide whether training has reached its early-stopping target.

  Args:
    stop_threshold: float or None. Training should stop once `eval_metric`
      is greater than or equal to this value. None disables the check.
    eval_metric: float, the most recent value of the metric being monitored.

  Returns:
    True if `eval_metric >= stop_threshold`, False otherwise (always False
    when `stop_threshold` is None).

  Raises:
    ValueError: if `stop_threshold` is not None and either it or
      `eval_metric` is not a `numbers.Number` (the threshold is checked
      first).
  """
  # No threshold configured: never stop, and skip validation entirely.
  if stop_threshold is None:
    return False

  # Each argument paired with the error raised when it is not numeric.
  checks = (
      (stop_threshold,
       "Threshold for checking stop conditions must be a number."),
      (eval_metric,
       "Eval metric being checked against stop conditions "
       "must be a number."),
  )
  for value, message in checks:
    if not isinstance(value, numbers.Number):
      raise ValueError(message)

  # Written as `not (a >= b)` rather than `a < b` so that a NaN metric is
  # treated as "threshold not reached", exactly like a direct `>=` test.
  if not (eval_metric >= stop_threshold):
    return False

  logging.info("Stop threshold of {} was passed with metric value {}.".format(
      stop_threshold, eval_metric))
  return True
55
56


Hongkun Yu's avatar
Hongkun Yu committed
57
58
59
60
61
62
def generate_synthetic_data(input_shape,
                            input_value=0,
                            input_dtype=None,
                            label_shape=None,
                            label_value=0,
                            label_dtype=None):
  """Create a repeating dataset of constant-valued tensors.

  Every element of the returned dataset is identical: each leaf shape in
  `input_shape` (and in `label_shape`, when labels are requested) is filled
  with a single constant value via `tf.constant`.

  Args:
    input_shape: a tf.TensorShape object or nested tf.TensorShapes. The shape
      of the input data.
    input_value: Value of each input element.
    input_dtype: Input dtype. If None, will be inferred from `input_value`.
    label_shape: a tf.TensorShape object or nested tf.TensorShapes. The shape
      of the label data. If falsy (e.g. None, the default), no labels are
      produced.
    label_value: Value of each label element.
    label_dtype: Label dtype. If None, will be inferred from `label_value`.

  Returns:
    A `tf.data.Dataset` built with `.repeat()` and no count. Each element is
    the input structure alone, or an `(input, label)` tuple when
    `label_shape` is set.
  """
  # TODO(kathywu): Replace with SyntheticDataset once it is in contrib.
  # nest.map_structure applies tf.constant to every leaf shape, so a nested
  # structure of shapes yields a matching nested structure of tensors.
  input_element = nest.map_structure(
      lambda s: tf.constant(input_value, input_dtype, s), input_shape)
  element = input_element

  # Labels are attached based on the truthiness of `label_shape` (not an
  # `is not None` check), so any falsy shape structure yields inputs only.
  if label_shape:
    label_element = nest.map_structure(
        lambda s: tf.constant(label_value, label_dtype, s), label_shape)
    element = (input_element, label_element)

  # from_tensors wraps `element` as a single dataset element; repeat() with
  # no count repeats it indefinitely.
  return tf.data.Dataset.from_tensors(element).repeat()
88
89
90


def apply_clean(flags_obj):
  """Remove the existing model directory when cleaning is requested.

  Does nothing unless `flags_obj.clean` is truthy and the directory named by
  `flags_obj.model_dir` exists; otherwise logs the removal and deletes that
  directory recursively with `tf.io.gfile.rmtree`.

  Args:
    flags_obj: object exposing `clean` and `model_dir` attributes —
      presumably the parsed command-line flags (the log message refers to a
      `--clean` flag); confirm against callers.
  """
  # Check the flag first so the filesystem is never touched when cleaning
  # was not requested.
  if not flags_obj.clean:
    return

  target_dir = flags_obj.model_dir
  if not tf.io.gfile.exists(target_dir):
    return

  logging.info("--clean flag set. Removing existing model dir:"
               " {}".format(target_dir))
  tf.io.gfile.rmtree(target_dir)