# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions specific to running TensorFlow on TPUs."""

import tensorflow as tf


# "local" is a magic word in the TPU cluster resolver; it informs the resolver
# to use the local CPU as the compute device. This is useful for testing and
# debugging; the code flow is ostensibly identical, but without the need to
# actually have a TPU on the other end.
LOCAL = "local"
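
# A minimal sketch of how LOCAL can be consumed, assuming the TF 1.x contrib
# cluster resolver and TPU RunConfig (wiring here is illustrative, not
# prescribed by this module):
#
#   resolver = tf.contrib.cluster_resolver.TPUClusterResolver(tpu=LOCAL)
#   run_config = tf.contrib.tpu.RunConfig(cluster=resolver)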


def construct_scalar_host_call(metric_dict, model_dir, prefix=""):
  # type: (dict, str, str) -> (function, list)
  """Construct a host call to log scalars when training on TPU.

  Args:
    metric_dict: A dict of the tensors to be logged.
    model_dir: The location to write the summary.
    prefix: The prefix (if any) to prepend to the metric names.

  Returns:
    A tuple of (function, args_to_be_passed_to_said_function).
  """
  metric_names = list(metric_dict.keys())

  def host_call_fn(global_step, *args):
    """Training host call. Creates scalar summaries for training metrics.

    This function is executed on the CPU and should not directly reference
    any Tensors in the rest of the `model_fn`. To pass Tensors from the
    model to the `host_call_fn`, pass them as part of the `host_call`. See
    https://www.tensorflow.org/api_docs/python/tf/contrib/tpu/TPUEstimatorSpec
    for more information.

    Arguments should match the list of `Tensor` objects passed as the second
    element in the tuple passed to `host_call`.

    Args:
      global_step: `Tensor` with shape `[batch]` for the global step.
      *args: Remaining tensors to log.

    Returns:
      List of summary ops to run on the CPU host.
    """
    step = global_step[0]
    with tf.contrib.summary.create_file_writer(
        logdir=model_dir, filename_suffix=".host_call").as_default():
      with tf.contrib.summary.always_record_summaries():
        for i, name in enumerate(metric_names):
          tf.contrib.summary.scalar(prefix + name, args[i][0], step=step)

        return tf.contrib.summary.all_summary_ops()

  # To log scalars such as the current learning rate and gradient norm to
  # TensorBoard, the summary op needs to be run on the host CPU via
  # host_call. host_call expects [batch_size, ...] Tensors; thus reshape to
  # introduce a batch dimension. These Tensors are implicitly concatenated to
  # [params['batch_size']].
  global_step_tensor = tf.reshape(
      tf.compat.v1.train.get_or_create_global_step(), [1])
  other_tensors = [tf.reshape(metric_dict[key], [1]) for key in metric_names]

  return host_call_fn, [global_step_tensor] + other_tensors
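
# A minimal usage sketch inside a TPUEstimator `model_fn` (assumes
# `learning_rate`, `loss`, `train_op`, `mode`, and `params` are defined by
# the surrounding model code):
#
#   host_call = construct_scalar_host_call(
#       metric_dict={"learning_rate": learning_rate},
#       model_dir=params["model_dir"], prefix="training/")
#   return tf.contrib.tpu.TPUEstimatorSpec(
#       mode=mode, loss=loss, train_op=train_op, host_call=host_call)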


def embedding_matmul(embedding_table, values, mask, name="embedding_matmul"):
  """Performs embedding lookup via a matmul.

  The matrix to be multiplied by the embedding table Tensor is constructed
  via an implementation of scatter based on broadcasting embedding indices
  and performing an equality comparison against a broadcasted
  range(num_embedding_table_rows). All masked positions will produce an
  embedding vector of zeros.

  Args:
    embedding_table: Tensor of embedding table.
      Rank 2 (table_size x embedding dim)
    values: Tensor of embedding indices. Rank 2 (batch x n_indices)
    mask: Tensor of mask / weights. Rank 2 (batch x n_indices)
    name: Optional name scope for created ops.

  Returns:
    Rank 3 tensor of embedding vectors.
  """

  with tf.name_scope(name):
    n_embeddings = embedding_table.get_shape().as_list()[0]
    batch_size, padded_size = values.shape.as_list()

    # Broadcast the indices, the mask, and range(n_embeddings) to a common
    # [batch, padded_size, n_embeddings] shape so that an equality test
    # against the row ids yields a mask-weighted one-hot selection matrix.
    emb_idcs = tf.tile(
        tf.reshape(values, (batch_size, padded_size, 1)), (1, 1, n_embeddings))
    emb_weights = tf.tile(
        tf.reshape(mask, (batch_size, padded_size, 1)), (1, 1, n_embeddings))
    col_idcs = tf.tile(
        tf.reshape(tf.range(n_embeddings), (1, 1, n_embeddings)),
        (batch_size, padded_size, 1))
    one_hot = tf.where(
        tf.equal(emb_idcs, col_idcs), emb_weights,
        tf.zeros((batch_size, padded_size, n_embeddings)))

    # [batch, padded_size, n_embeddings] . [n_embeddings, emb_dim]
    #   -> [batch, padded_size, emb_dim]
    return tf.tensordot(one_hot, embedding_table, 1)
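

# A minimal shape sketch for embedding_matmul; the tensors below are
# hypothetical:
#
#   table = tf.ones([100, 16])                    # 100 embeddings of dim 16
#   ids = tf.constant([[1, 5, 0]])                # batch=1, n_indices=3
#   mask = tf.constant([[1., 1., 0.]])            # final position masked out
#   vectors = embedding_matmul(table, ids, mask)  # -> shape [1, 3, 16]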