# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Common utils for tasks."""

from typing import Any, Callable

import orbit
import tensorflow as tf
import tensorflow_hub as hub


def get_encoder_from_hub(hub_model_path: str) -> tf.keras.Model:
  """Gets an encoder from hub.

  Loads the module at `hub_model_path` as a `hub.KerasLayer` with
  `trainable=True` and wraps it in a functional Keras model. The model is fed
  by three int32 inputs declared with `shape=(None,)`.

  Args:
    hub_model_path: The path to the tfhub model.

  Returns:
    A tf.keras.Model whose inputs are a dict keyed by `input_word_ids`,
    `input_mask` and `input_type_ids`, and whose outputs are whatever the hub
    layer returns when called on that dict.
  """
  input_word_ids = tf.keras.layers.Input(
      shape=(None,), dtype=tf.int32, name='input_word_ids')
  input_mask = tf.keras.layers.Input(
      shape=(None,), dtype=tf.int32, name='input_mask')
  input_type_ids = tf.keras.layers.Input(
      shape=(None,), dtype=tf.int32, name='input_type_ids')
  hub_layer = hub.KerasLayer(hub_model_path, trainable=True)
  # The hub layer is called with a single dict; the same dict is reused as the
  # model's `inputs` so that callers feed tensors by these key names.
  dict_input = dict(
      input_word_ids=input_word_ids,
      input_mask=input_mask,
      input_type_ids=input_type_ids)
  output_dict = hub_layer(dict_input)

  return tf.keras.Model(inputs=dict_input, outputs=output_dict)


def predict(predict_step_fn: Callable[[Any], Any],
            aggregate_fn: Callable[[Any, Any], Any], dataset: tf.data.Dataset):
  """Runs prediction over a dataset and aggregates the per-step outputs.

  Args:
    predict_step_fn: A callable such as `def predict_step(inputs)`, where
      `inputs` are input tensors.
    aggregate_fn: A callable such as `def aggregate_fn(state, value)`, where
      `value` is the outputs from `predict_step_fn`. It is handed to the loop
      as `reduce_fn`, and the initial `state` is None.
    dataset: A `tf.data.Dataset` object.

  Returns:
    The aggregated predictions.
  """

  @tf.function
  def predict_step(data_iterator):
    """Runs one prediction step under the current distribution strategy."""
    strategy = tf.distribute.get_strategy()
    step_outputs = strategy.run(
        predict_step_fn, args=(next(data_iterator),))
    # Map every element of the (possibly nested) output structure through
    # `experimental_local_results`.
    return tf.nest.map_structure(strategy.experimental_local_results,
                                 step_outputs)

  run_loop = orbit.utils.create_loop_fn(predict_step)
  dataset_iterator = iter(dataset)
  # `num_steps=-1` makes the loop run until the dataset is exhausted.
  return run_loop(
      dataset_iterator, num_steps=-1, state=None, reduce_fn=aggregate_fn)  # pytype: disable=wrong-arg-types