utils.py 3.39 KB
Newer Older
Chen Chen's avatar
Chen Chen committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Common utils for tasks."""
17
18
from typing import Any, Callable

Chen Chen's avatar
Chen Chen committed
19
from absl import logging
20
import orbit
Chen Chen's avatar
Chen Chen committed
21
22
23
24
import tensorflow as tf
import tensorflow_hub as hub


Chen Chen's avatar
Chen Chen committed
25
def get_encoder_from_hub(hub_model_path: str) -> tf.keras.Model:
  """Wraps a TF-Hub encoder in a Keras model that takes named dict inputs.

  The model has three int32 `tf.keras.layers.Input`s of shape `(None,)`,
  keyed and named `input_word_ids`, `input_mask` and `input_type_ids`. They
  are fed to a trainable `hub.KerasLayer` loaded from `hub_model_path`.

  Args:
    hub_model_path: The path to the tfhub model.

  Returns:
    A tf.keras.Model whose inputs are the dict described above. For a legacy
    hub module (one whose `serving_default` signature declares exactly two
    outputs) the model outputs are a dict with `pooled_output` and
    `sequence_output`; otherwise they are whatever the hub layer returns when
    called on the input dict.
  """
  input_names = ('input_word_ids', 'input_mask', 'input_type_ids')
  encoder_inputs = {
      name: tf.keras.layers.Input(shape=(None,), dtype=tf.int32, name=name)
      for name in input_names
  }
  hub_layer = hub.KerasLayer(hub_model_path, trainable=True)

  # The legacy hub model takes a list as input and returns a Tuple of
  # `pooled_output` and `sequence_output`, while the new hub model takes dict
  # as input and returns a dict. The two flavours are told apart by how many
  # outputs the `serving_default` signature declares.
  # TODO(chendouble): Remove the support of legacy hub model when the new ones
  # are released.
  serving_outputs = hub_layer.resolved_object.signatures[
      'serving_default'].outputs
  if len(serving_outputs) != 2:
    logging.info('Use the new hub module with dict as input/output.')
    return tf.keras.Model(
        inputs=encoder_inputs, outputs=hub_layer(encoder_inputs))

  logging.info('Use the legacy hub module with list as input/output.')
  # The legacy module expects a positional list in this fixed order.
  pooled_output, sequence_output = hub_layer(
      [encoder_inputs[name] for name in input_names])
  return tf.keras.Model(
      inputs=encoder_inputs,
      outputs={
          'pooled_output': pooled_output,
          'sequence_output': sequence_output,
      })
65
66
67


def predict(predict_step_fn: Callable[[Any], Any],
            aggregate_fn: Callable[[Any, Any], Any], dataset: tf.data.Dataset):
  """Runs `predict_step_fn` over all of `dataset` and aggregates the results.

  Each step pulls one element from the dataset iterator, runs
  `predict_step_fn` on it under the current `tf.distribute` strategy, and
  converts the per-replica outputs to their local results. The results are
  folded together with `aggregate_fn`.

  Args:
    predict_step_fn: A callable such as `def predict_step(inputs)`, where
      `inputs` are input tensors.
    aggregate_fn: A callable such as `def aggregate_fn(state, value)`, where
      `state` is the aggregate so far (starting from `None`) and `value` is
      one step's outputs of `predict_step_fn`, with every per-replica value
      replaced by the strategy's `experimental_local_results` for it.
    dataset: A `tf.data.Dataset` object.

  Returns:
    The aggregated predictions.
  """

  @tf.function
  def predict_step(data_iterator):
    """Runs one prediction step on every replica of the current strategy."""
    strategy = tf.distribute.get_strategy()
    per_replica_outputs = strategy.run(
        predict_step_fn, args=(next(data_iterator),))
    return tf.nest.map_structure(strategy.experimental_local_results,
                                 per_replica_outputs)

  run_loop = orbit.utils.create_loop_fn(predict_step)
  # A `num_steps` of -1 keeps stepping until the dataset is exhausted.
  return run_loop(
      iter(dataset), num_steps=-1, state=None, reduce_fn=aggregate_fn)  # pytype: disable=wrong-arg-types