# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Trainer network for dual encoder style models."""
# pylint: disable=g-classes-have-attributes
import collections
import tensorflow as tf

from official.nlp.modeling import layers


@tf.keras.utils.register_keras_serializable(package='Text')
class DualEncoder(tf.keras.Model):
  """A dual encoder model based on a transformer-based encoder.

  This is an implementation of the dual encoder network structure based on the
  transformer stack, as described in ["Language-agnostic BERT Sentence
  Embedding"](https://arxiv.org/abs/2007.01852).

  The DualEncoder allows a user to pass in a transformer stack, and builds a
  dual encoder model based on that transformer stack.

  Args:
    network: A transformer network which should output an encoding output.
    max_seq_length: The maximum allowed sequence length for transformer.
    normalize: If set to True, normalize the encoding produced by transformer.
    logit_scale: The scaling factor of dot products when doing training.
    logit_margin: The margin between positive and negative when doing training.
    output: The output style for this network. Can be either 'logits' or
      'predictions'. If set to 'predictions', it will output the embedding
      produced by the transformer network.
  """

  def __init__(self,
               network: tf.keras.Model,
               max_seq_length: int = 32,
               normalize: bool = True,
               logit_scale: float = 1.0,
               logit_margin: float = 0.0,
               output: str = 'logits',
               **kwargs) -> None:
    # Validate `output` up front so an unsupported value fails before any part
    # of the graph is built. The message matches the original (previously
    # raised only after the left tower had been constructed).
    if output not in ('logits', 'predictions'):
      raise ValueError('output type %s is not supported' % output)

    if output == 'logits':
      left_word_ids = tf.keras.layers.Input(
          shape=(max_seq_length,), dtype=tf.int32, name='left_word_ids')
      left_mask = tf.keras.layers.Input(
          shape=(max_seq_length,), dtype=tf.int32, name='left_mask')
      left_type_ids = tf.keras.layers.Input(
          shape=(max_seq_length,), dtype=tf.int32, name='left_type_ids')
    else:
      # Keep consistent with legacy BERT hub module input names.
      left_word_ids = tf.keras.layers.Input(
          shape=(max_seq_length,), dtype=tf.int32, name='input_word_ids')
      left_mask = tf.keras.layers.Input(
          shape=(max_seq_length,), dtype=tf.int32, name='input_mask')
      left_type_ids = tf.keras.layers.Input(
          shape=(max_seq_length,), dtype=tf.int32, name='input_type_ids')

    left_inputs = [left_word_ids, left_mask, left_type_ids]

    # The encoder network may return either a [sequence_output, pooled_output]
    # list (legacy networks) or a dict keyed by output name.
    left_outputs = network(left_inputs)
    if isinstance(left_outputs, list):
      left_sequence_output, left_encoded = left_outputs
    else:
      left_sequence_output = left_outputs['sequence_output']
      left_encoded = left_outputs['pooled_output']

    if normalize:
      # L2-normalize the sentence embedding so dot products become cosine
      # similarities.
      left_encoded = tf.keras.layers.Lambda(
          lambda x: tf.nn.l2_normalize(x, axis=1))(
              left_encoded)

    if output == 'logits':
      # Training mode: build the right tower (sharing `network` weights) and
      # score left-vs-right with a scaled, margin-adjusted dot product.
      right_word_ids = tf.keras.layers.Input(
          shape=(max_seq_length,), dtype=tf.int32, name='right_word_ids')
      right_mask = tf.keras.layers.Input(
          shape=(max_seq_length,), dtype=tf.int32, name='right_mask')
      right_type_ids = tf.keras.layers.Input(
          shape=(max_seq_length,), dtype=tf.int32, name='right_type_ids')

      right_inputs = [right_word_ids, right_mask, right_type_ids]

      right_outputs = network(right_inputs)
      if isinstance(right_outputs, list):
        _, right_encoded = right_outputs
      else:
        right_encoded = right_outputs['pooled_output']

      if normalize:
        right_encoded = tf.keras.layers.Lambda(
            lambda x: tf.nn.l2_normalize(x, axis=1))(
                right_encoded)

      dot_products = layers.MatMulWithMargin(
          logit_scale=logit_scale,
          logit_margin=logit_margin,
          name='dot_product')

      inputs = [
          left_word_ids, left_mask, left_type_ids, right_word_ids, right_mask,
          right_type_ids
      ]

      left_logits, right_logits = dot_products(left_encoded, right_encoded)

      outputs = dict(left_logits=left_logits, right_logits=right_logits)

    else:
      # output == 'predictions': single-tower embedding export.
      inputs = [left_word_ids, left_mask, left_type_ids]

      # To keep consistent with legacy BERT hub modules, the outputs are
      # "pooled_output" and "sequence_output".
      outputs = dict(
          sequence_output=left_sequence_output, pooled_output=left_encoded)

    # b/164516224
    # Once we've created the network using the Functional API, we call
    # super().__init__ as though we were invoking the Functional API Model
    # constructor, resulting in this object having all the properties of a model
    # created using the Functional API. Once super().__init__ is called, we
    # can assign attributes to `self` - note that all `self` assignments are
    # below this line.
    super(DualEncoder, self).__init__(inputs=inputs, outputs=outputs, **kwargs)

    config_dict = {
        'network': network,
        'max_seq_length': max_seq_length,
        'normalize': normalize,
        'logit_scale': logit_scale,
        'logit_margin': logit_margin,
        'output': output,
    }
    # We are storing the config dict as a namedtuple here to ensure checkpoint
    # compatibility with an earlier version of this model which did not track
    # the config dict attribute. TF does not track immutable attrs which
    # do not contain Trackables, so by creating a config namedtuple instead of
    # a dict we avoid tracking it.
    config_cls = collections.namedtuple('Config', config_dict.keys())
    self._config = config_cls(**config_dict)

    self.network = network

  def get_config(self):
    """Returns the model configuration as a plain, serializable dict."""
    return dict(self._config._asdict())

  @classmethod
  def from_config(cls, config, custom_objects=None):
    """Creates a `DualEncoder` from the output of `get_config`."""
    return cls(**config)

  @property
  def checkpoint_items(self):
    """Returns a dictionary of items to be additionally checkpointed."""
    items = dict(encoder=self.network)
    return items