bert_pretrainer.py 11.1 KB
Newer Older
A. Unique TensorFlower's avatar
A. Unique TensorFlower committed
1
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Frederick Liu's avatar
Frederick Liu committed
14

15
"""BERT Pre-training model."""
16
# pylint: disable=g-classes-have-attributes
17
import collections
Hongkun Yu's avatar
Hongkun Yu committed
18
import copy
19
20
from typing import List, Optional

Hongkun Yu's avatar
Hongkun Yu committed
21
from absl import logging
22
import gin
Hongkun Yu's avatar
Hongkun Yu committed
23
import tensorflow as tf
Hongkun Yu's avatar
Hongkun Yu committed
24

Scott Zhu's avatar
Scott Zhu committed
25
from official.modeling import tf_utils
Hongkun Yu's avatar
Hongkun Yu committed
26
from official.nlp.modeling import layers
Hongkun Yu's avatar
Hongkun Yu committed
27
28
29
30
31
from official.nlp.modeling import networks


@tf.keras.utils.register_keras_serializable(package='Text')
class BertPretrainer(tf.keras.Model):
  """Legacy BERT pre-training model built with the Keras Functional API.

  [Note] New projects should use `BertPretrainerV2` instead.

  Given a transformer encoder, this model attaches two training heads to it:
  a masked language model (`layers.MaskedLM`) on the encoder's sequence output
  and a classifier (`networks.Classification`) on the encoder's classification
  output. The model's inputs are the encoder's own inputs followed by an int32
  `masked_lm_positions` input of length `num_token_predictions`. Its outputs
  are a dict with the keys `masked_lm` and `classification`.

  Args:
    network: A transformer network whose call returns a
      `(sequence_output, classification_output)` pair. If either element is a
      list (e.g. per-layer outputs), its last entry is used.
    num_classes: Number of classes predicted by the classification head.
    num_token_predictions: Number of masked positions the masked LM predicts.
      Must not exceed the encoder's sequence length when that length is
      statically known.
    embedding_table: Embedding table for the masked LM head. When None,
      `network.get_embedding_table()` is used. Note that this argument is not
      recorded by `get_config()`.
    activation: Activation for the masked LM head, or None for no activation.
    initializer: Initializer for both heads. Each head receives its own clone
      of it. Defaults to 'glorot_uniform'.
    output: Output style of both heads; either `logits` or `predictions`.
    **kwargs: Extra keyword arguments forwarded to `tf.keras.Model`.

  Raises:
    ValueError: If the encoder's static sequence length is smaller than
      `num_token_predictions`.
  """

  def __init__(self,
               network,
               num_classes,
               num_token_predictions,
               embedding_table=None,
               activation=None,
               initializer='glorot_uniform',
               output='logits',
               **kwargs):
    # The model's inputs are a shallow copy of the encoder's inputs, because
    # `masked_lm_positions` gets appended to them further down.
    encoder_inputs = network.inputs
    model_inputs = copy.copy(encoder_inputs)

    # Call the encoder on its own input tensors rather than on the copy.
    # Because of deferred construction, the copy would already contain the
    # extra input by the time the encoder is invoked.
    seq_out, pooled_out = network(encoder_inputs)

    # Encoders configured to expose every layer return lists; keep only the
    # final entry.
    if isinstance(seq_out, list):
      seq_out = seq_out[-1]
    if isinstance(pooled_out, list):
      pooled_out = pooled_out[-1]

    seq_len = seq_out.shape.as_list()[1]
    if seq_len is not None and seq_len < num_token_predictions:
      raise ValueError(
          "The passed network's output length is %s, which is less than the "
          'requested num_token_predictions %s.' %
          (seq_len, num_token_predictions))

    positions = tf.keras.layers.Input(
        shape=(num_token_predictions,),
        name='masked_lm_positions',
        dtype=tf.int32)
    model_inputs.append(positions)

    table = (embedding_table if embedding_table is not None
             else network.get_embedding_table())
    masked_lm = layers.MaskedLM(
        embedding_table=table,
        activation=activation,
        initializer=tf_utils.clone_initializer(initializer),
        output=output,
        name='cls/predictions')
    mlm_out = masked_lm(seq_out, masked_positions=positions)

    classification = networks.Classification(
        input_width=pooled_out.shape[-1],
        num_classes=num_classes,
        initializer=tf_utils.clone_initializer(initializer),
        output=output,
        name='classification')
    cls_out = classification(pooled_out)

    super().__init__(
        inputs=model_inputs,
        outputs={'masked_lm': mlm_out, 'classification': cls_out},
        **kwargs)

    # b/164516224: attributes may only be assigned to `self` after the
    # Functional-API constructor above has run.
    #
    # The config is kept as a namedtuple rather than a dict to stay
    # checkpoint-compatible with an earlier version of this model that did
    # not track a config attribute.
    config_cls = collections.namedtuple(
        'Config',
        ['network', 'num_classes', 'num_token_predictions', 'activation',
         'initializer', 'output'])
    self._config = config_cls(
        network=network,
        num_classes=num_classes,
        num_token_predictions=num_token_predictions,
        activation=activation,
        initializer=initializer,
        output=output)

    self.encoder = network
    self.classification = classification
    self.masked_lm = masked_lm

  def get_config(self):
    """Returns a fresh dict of the constructor arguments saved in `_config`."""
    saved = self._config._asdict()
    return {field: value for field, value in saved.items()}

  @classmethod
  def from_config(cls, config, custom_objects=None):
    """Rebuilds the model from a dict produced by `get_config`."""
    del custom_objects  # Unused.
    return cls(**config)


@tf.keras.utils.register_keras_serializable(package='Text')
@gin.configurable
class BertPretrainerV2(tf.keras.Model):
  """BERT pretraining model V2.

  Adds the masked language model head and optional classification heads upon
  the transformer encoder.

  Args:
    encoder_network: A transformer network. This network should output a
      sequence output and a classification output, either as a list or as a
      dict containing at least `sequence_output`.
    mlm_activation: The activation (if any) to use in the masked LM network. If
      None, no activation will be used.
    mlm_initializer: The initializer (if any) to use in the masked LM. Defaults
      to a Glorot uniform initializer.
    classification_heads: A list of optional head layers applied to the encoder
      sequence output. Head names must be unique, and each head must expose a
      `checkpoint_items` mapping.
    customized_masked_lm: A customized masked_lm layer. If None, will create
      a standard layer from `layers.MaskedLM`; if not None, will use the
      specified masked_lm layer. Above arguments `mlm_activation` and
      `mlm_initializer` will be ignored.
    name: The name of the model.
    **kwargs: Additional keyword arguments forwarded to `tf.keras.Model`.

  Inputs: Inputs defined by the encoder network, plus `masked_lm_positions` as a
    dictionary.

  Outputs: A dictionary containing:
    - the encoder outputs. For a list-returning encoder these are keyed by
      `sequence_output`, `pooled_output` and, when all layer outputs are
      returned, `encoder_outputs`. A dict-returning encoder's outputs are
      passed through unchanged.
    - `mlm_logits`, only when `masked_lm_positions` is present in the inputs.
    - the classification head outputs, keyed by head name. A head that returns
      a dict has that dict merged in instead.

  Raises:
    ValueError: If classification head names are not unique.
  """

  def __init__(
      self,
      encoder_network: tf.keras.Model,
      mlm_activation=None,
      mlm_initializer='glorot_uniform',
      classification_heads: Optional[List[tf.keras.layers.Layer]] = None,
      customized_masked_lm: Optional[tf.keras.layers.Layer] = None,
      name: str = 'bert',
      **kwargs):
    # `self` is bound implicitly by super(); it must not be passed positionally.
    super().__init__(name=name, **kwargs)
    # Everything needed by `from_config` to rebuild an equivalent model.
    self._config = {
        'encoder_network': encoder_network,
        'mlm_initializer': mlm_initializer,
        'mlm_activation': mlm_activation,
        'classification_heads': classification_heads,
        'customized_masked_lm': customized_masked_lm,
        'name': name,
    }
    self.encoder_network = encoder_network
    # Makes sure the weights are built.
    _ = self.encoder_network(self.encoder_network.inputs)
    # Shallow copy: `masked_lm_positions` is added below without touching the
    # encoder's own input structure.
    inputs = copy.copy(self.encoder_network.inputs)
    self.classification_heads = classification_heads or []
    head_names = {head.name for head in self.classification_heads}
    if len(head_names) != len(self.classification_heads):
      raise ValueError('Classification heads should have unique names.')

    if customized_masked_lm is not None:
      self.masked_lm = customized_masked_lm
    else:
      self.masked_lm = layers.MaskedLM(
          embedding_table=self.encoder_network.get_embedding_table(),
          activation=mlm_activation,
          initializer=mlm_initializer,
          name='cls/predictions')
    masked_lm_positions = tf.keras.layers.Input(
        shape=(None,), name='masked_lm_positions', dtype=tf.int32)
    if isinstance(inputs, dict):
      inputs['masked_lm_positions'] = masked_lm_positions
    else:
      inputs.append(masked_lm_positions)
    self.inputs = inputs

  def call(self, inputs):
    """Runs the encoder and the pre-training heads.

    Args:
      inputs: A dict of input tensors keyed by input name. A list is also
        accepted (with a warning). Each list tensor is paired with the `.name`
        of the corresponding entry in `self.inputs`.

    Returns:
      A dict of outputs as described in the class docstring.

    Raises:
      ValueError: If the encoder returns something other than a list or dict.
    """
    if isinstance(inputs, list):
      logging.warning('List inputs to BertPretrainer are discouraged.')
      inputs = {ref.name: tensor for ref, tensor in zip(self.inputs, inputs)}

    encoder_network_outputs = self.encoder_network(inputs)
    if isinstance(encoder_network_outputs, list):
      outputs = {'pooled_output': encoder_network_outputs[1]}
      # When `encoder_network` was instantiated with return_all_encoder_outputs
      # set to True, `encoder_network_outputs[0]` is a list containing
      # all transformer layers' output.
      if isinstance(encoder_network_outputs[0], list):
        outputs['encoder_outputs'] = encoder_network_outputs[0]
        outputs['sequence_output'] = encoder_network_outputs[0][-1]
      else:
        outputs['sequence_output'] = encoder_network_outputs[0]
    elif isinstance(encoder_network_outputs, dict):
      # Copy so that adding the MLM/head outputs below does not mutate the
      # dict returned by the encoder.
      outputs = dict(encoder_network_outputs)
    else:
      raise ValueError('encoder_network\'s output should be either a list '
                       'or a dict, but got %s' % encoder_network_outputs)
    sequence_output = outputs['sequence_output']
    # Inference may not have masked_lm_positions and mlm_logits is not needed.
    if 'masked_lm_positions' in inputs:
      masked_lm_positions = inputs['masked_lm_positions']
      outputs['mlm_logits'] = self.masked_lm(
          sequence_output, masked_positions=masked_lm_positions)
    for cls_head in self.classification_heads:
      cls_outputs = cls_head(sequence_output)
      if isinstance(cls_outputs, dict):
        outputs.update(cls_outputs)
      else:
        outputs[cls_head.name] = cls_outputs
    return outputs

  @property
  def checkpoint_items(self):
    """Returns a dictionary of items to be additionally checkpointed.

    Includes the encoder, the masked LM head, and every classification head's
    `checkpoint_items`, keyed as `<head name>.<item key>`.
    """
    items = dict(encoder=self.encoder_network, masked_lm=self.masked_lm)
    for head in self.classification_heads:
      for key, item in head.checkpoint_items.items():
        items['.'.join([head.name, key])] = item
    return items

  def get_config(self):
    """Returns a copy of the constructor arguments.

    A copy is returned so that callers cannot mutate the model's stored config.
    """
    return dict(self._config)

  @classmethod
  def from_config(cls, config, custom_objects=None):
    """Rebuilds the model from a dict produced by `get_config`."""
    del custom_objects  # Unused.
    return cls(**config)