# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Trainer network for dual encoder style models."""
# pylint: disable=g-classes-have-attributes
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function

# Import libraries
import tensorflow as tf

from official.nlp.modeling import layers


@tf.keras.utils.register_keras_serializable(package='Text')
class DualEncoder(tf.keras.Model):
  """A dual encoder model based on a transformer-based encoder.

  This is an implementation of the dual encoder network structure based on the
  transformer stack, as described in ["Language-agnostic BERT Sentence
  Embedding"](https://arxiv.org/abs/2007.01852)

  The DualEncoder allows a user to pass in a transformer stack, and build a dual
  encoder model based on the transformer stack.

  Arguments:
    network: A transformer network which should output an encoding output.
    max_seq_length: The maximum allowed sequence length for transformer.
    normalize: If set to True, normalize the encoding produced by transformer.
    logit_scale: The scaling factor of dot products when doing training.
    logit_margin: The margin between positive and negative when doing training.
    output: The output style for this network. Can be either 'logits' or
      'predictions'. If set to 'predictions', it will output the embedding
      produced by transformer network.
  """

  def __init__(self,
               network: tf.keras.Model,
               max_seq_length: int = 32,
               normalize: bool = True,
               logit_scale: float = 1.0,
               logit_margin: float = 0.0,
               output: str = 'logits',
               **kwargs) -> None:
    # Disable Keras attribute tracking so the assignments below are not
    # auto-tracked as sub-layers before the deferred functional-API
    # super().__init__ call at the end of this constructor.
    self._self_setattr_tracking = False
    # Raw constructor arguments, returned verbatim by get_config() so the
    # model can be rebuilt via from_config().
    self._config = {
        'network': network,
        'max_seq_length': max_seq_length,
        'normalize': normalize,
        'logit_scale': logit_scale,
        'logit_margin': logit_margin,
        'output': output,
    }

    # The shared encoder. Both towers below call this same instance, so the
    # left and right sides share weights.
    self.network = network

    # Left-tower inputs: word ids, attention mask and token type ids, each of
    # shape (batch, max_seq_length).
    left_word_ids = tf.keras.layers.Input(
        shape=(max_seq_length,), dtype=tf.int32, name='left_word_ids')
    left_mask = tf.keras.layers.Input(
        shape=(max_seq_length,), dtype=tf.int32, name='left_mask')
    left_type_ids = tf.keras.layers.Input(
        shape=(max_seq_length,), dtype=tf.int32, name='left_type_ids')

    left_inputs = [left_word_ids, left_mask, left_type_ids]
    # The encoder returns a pair; only the second element (the pooled
    # sentence-level encoding) is used, the first output is discarded.
    _, left_encoded = network(left_inputs)

    if normalize:
      # L2-normalize along the feature axis so that downstream dot products
      # are cosine similarities.
      left_encoded = tf.keras.layers.Lambda(
          lambda x: tf.nn.l2_normalize(x, axis=1))(left_encoded)

    if output == 'logits':
      # Training mode: build the mirror-image right tower (weights shared via
      # the same `network` instance) and emit pairwise similarity logits.
      right_word_ids = tf.keras.layers.Input(
          shape=(max_seq_length,), dtype=tf.int32, name='right_word_ids')
      right_mask = tf.keras.layers.Input(
          shape=(max_seq_length,), dtype=tf.int32, name='right_mask')
      right_type_ids = tf.keras.layers.Input(
          shape=(max_seq_length,), dtype=tf.int32, name='right_type_ids')

      right_inputs = [right_word_ids, right_mask, right_type_ids]
      _, right_encoded = network(right_inputs)

      if normalize:
        right_encoded = tf.keras.layers.Lambda(
            lambda x: tf.nn.l2_normalize(x, axis=1))(right_encoded)

      # Project-local layer producing scaled, margin-adjusted left/right
      # similarity logits; see layers.MatMulWithMargin for exact semantics.
      dot_products = layers.MatMulWithMargin(logit_scale=logit_scale,
                                             logit_margin=logit_margin,
                                             name='dot_product')

      inputs = [left_word_ids, left_mask, left_type_ids, right_word_ids,
                right_mask, right_type_ids]
      left_logits, right_logits = dot_products(left_encoded, right_encoded)

      outputs = [left_logits, right_logits]

    elif output == 'predictions':
      # Inference mode: only the left tower is built; the model maps a single
      # input triple to its (optionally normalized) encoding.
      inputs = [left_word_ids, left_mask, left_type_ids]
      outputs = left_encoded
    else:
      raise ValueError('output type %s is not supported' % output)

    # Initialize as a functional Keras model over the symbolic graph above.
    super(DualEncoder, self).__init__(
        inputs=inputs, outputs=outputs, **kwargs)

  def get_config(self):
    """Returns the constructor arguments used to build this model.

    NOTE(review): the dict holds the live `network` model object rather than
    a nested config; serialization relies on Keras handling the sub-model.
    """
    return self._config

  @classmethod
  def from_config(cls, config, custom_objects=None):
    """Re-creates a DualEncoder from the dict produced by `get_config`."""
    return cls(**config)

  @property
  def checkpoint_items(self):
    """Returns a dictionary of items to be additionally checkpointed."""
    # Expose the shared encoder under the 'encoder' key so its weights can be
    # saved/restored independently of the full dual-encoder model.
    items = dict(encoder=self.network)
    return items