# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Trainer network for BERT-style models."""

# pylint: disable=g-classes-have-attributes

from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function

import copy
from typing import List, Optional

import gin
import tensorflow as tf

from official.nlp.modeling import networks


@tf.keras.utils.register_keras_serializable(package='Text')
class BertPretrainer(tf.keras.Model):
  """BERT network training model.

  This is an implementation of the network structure surrounding a transformer
  encoder as described in "BERT: Pre-training of Deep Bidirectional Transformers
  for Language Understanding" (https://arxiv.org/abs/1810.04805).

  The BertPretrainer allows a user to pass in a transformer stack, and
  instantiates the masked language model and classification networks that are
  used to create the training objectives.

  Arguments:
    network: A transformer network. This network should output a sequence output
      and a classification output. Either output may also be a list (e.g. one
      entry per layer), in which case the last entry is used.
    num_classes: Number of classes to predict from the classification network.
    num_token_predictions: Number of tokens to predict from the masked LM. When
      the network's sequence length is statically known, it must not be smaller
      than this value.
    embedding_table: Embedding table of a network. If None, the
      "network.get_embedding_table()" is used.
    activation: The activation (if any) to use in the masked LM network.
      If None, no activation will be used.
    initializer: The initializer (if any) to use in the masked LM and
      classification networks. Defaults to a Glorot uniform initializer.
    output: The output style for this network. Can be either 'logits' or
      'predictions'.
    **kwargs: Additional keyword arguments forwarded to `tf.keras.Model`.

  Inputs: The inputs of `network`, followed by an int32 `masked_lm_positions`
    input of per-example shape `(num_token_predictions,)`.

  Outputs: A list `[lm_outputs, sentence_outputs]` produced by the masked LM
    head and the classification head respectively.

  Raises:
    ValueError: If the network's statically known sequence output length is
      smaller than `num_token_predictions`.
  """

  def __init__(self,
               network,
               num_classes,
               num_token_predictions,
               embedding_table=None,
               activation=None,
               initializer='glorot_uniform',
               output='logits',
               **kwargs):
    # NOTE(review): presumably turns off Keras' automatic attribute tracking so
    # the sub-networks assigned below are not tracked before the functional
    # `super().__init__` call at the end of this method; it is not re-enabled
    # here.
    self._self_setattr_tracking = False
    # Constructor arguments replayed by `from_config`. NOTE(review):
    # `embedding_table` is not recorded, so a model rebuilt via `from_config`
    # uses the default (None) for it.
    self._config = {
        'network': network,
        'num_classes': num_classes,
        'num_token_predictions': num_token_predictions,
        'activation': activation,
        'initializer': initializer,
        'output': output,
    }

    self.encoder = network
    # We want to use the inputs of the passed network as the inputs to this
    # Model. To do this, we need to keep a copy of the network inputs for use
    # when we construct the Model object at the end of init. (We keep a copy
    # because we'll be adding another tensor to the copy later.)
    network_inputs = self.encoder.inputs
    inputs = copy.copy(network_inputs)

    # Because we have a copy of inputs to create this Model object, we can
    # invoke the Network object with its own input tensors to start the Model.
    # Note that, because of how deferred construction happens, we can't use
    # the copy of the list here - by the time the network is invoked, the list
    # object contains the additional input added below.
    sequence_output, cls_output = self.encoder(network_inputs)

    # The encoder network may get outputs from all layers; keep the last one.
    if isinstance(sequence_output, list):
      sequence_output = sequence_output[-1]
    if isinstance(cls_output, list):
      cls_output = cls_output[-1]

    # The sequence dimension is None when the encoder was built with a dynamic
    # sequence length; comparing None with an int would raise TypeError, so the
    # check only runs when the length is statically known.
    sequence_output_length = sequence_output.shape.as_list()[1]
    if (sequence_output_length is not None and
        sequence_output_length < num_token_predictions):
      raise ValueError(
          "The passed network's output length is %s, which is less than the "
          'requested num_token_predictions %s.' %
          (sequence_output_length, num_token_predictions))

    # Extra model input: positions of the tokens the masked LM predicts.
    masked_lm_positions = tf.keras.layers.Input(
        shape=(num_token_predictions,),
        name='masked_lm_positions',
        dtype=tf.int32)
    inputs.append(masked_lm_positions)

    self.masked_lm = networks.MaskedLM(
        num_predictions=num_token_predictions,
        input_width=sequence_output.shape[-1],
        source_network=network,
        embedding_table=embedding_table,
        activation=activation,
        initializer=initializer,
        output=output,
        name='masked_lm')
    lm_outputs = self.masked_lm([sequence_output, masked_lm_positions])

    self.classification = networks.Classification(
        input_width=cls_output.shape[-1],
        num_classes=num_classes,
        initializer=initializer,
        output=output,
        name='classification')
    sentence_outputs = self.classification(cls_output)

    super(BertPretrainer, self).__init__(
        inputs=inputs, outputs=[lm_outputs, sentence_outputs], **kwargs)

  def get_config(self):
    """Returns the constructor arguments recorded in `__init__`.

    A shallow copy is returned so callers cannot mutate this model's config.
    """
    return dict(self._config)

  @classmethod
  def from_config(cls, config, custom_objects=None):
    """Re-creates the model by passing `config` as keyword arguments.

    Args:
      config: A dict as returned by `get_config`.
      custom_objects: Unused; accepted for Keras API compatibility.

    Returns:
      A new `BertPretrainer` instance.
    """
    return cls(**config)


# TODO(hongkuny): Migrate to BertPretrainerV2 for all usages.
@tf.keras.utils.register_keras_serializable(package='Text')
@gin.configurable
class BertPretrainerV2(tf.keras.Model):
  """BERT pretraining model V2.

  (Experimental).
  Adds the masked language model head and optional classification heads upon the
  transformer encoder. When num_masked_tokens <= 0, there won't be a MaskedLM
  head.

  Arguments:
    num_masked_tokens: Number of tokens to predict from the masked LM.
    encoder_network: A transformer network. This network should output a
      sequence output and a classification output. If the sequence output is a
      list (e.g. one entry per layer), the last entry is used.
    mlm_activation: The activation (if any) to use in the masked LM network.
      If None, no activation will be used.
    mlm_initializer: The initializer (if any) to use in the masked LM. Default
      to a Glorot uniform initializer.
    classification_heads: A list of optional head layers to transform on encoder
      sequence outputs. Head names must be unique; each head must also expose a
      `checkpoint_items` mapping, which the `checkpoint_items` property reads.
    name: The name of the model.
    **kwargs: Additional keyword arguments forwarded to `tf.keras.Model`.

  Inputs: The inputs defined by the encoder network, plus a
    `masked_lm_positions` input when a masked LM head is built.
  Outputs: A dictionary with key `lm_output` (only when a masked LM head is
    built) plus one entry per classification head, keyed by head name.

  Raises:
    ValueError: If two classification heads share the same name.
  """

  def __init__(
      self,
      num_masked_tokens: int,
      encoder_network: tf.keras.Model,
      mlm_activation=None,
      mlm_initializer='glorot_uniform',
      classification_heads: Optional[List[tf.keras.layers.Layer]] = None,
      name: str = 'bert',
      **kwargs):
    # NOTE(review): presumably turns off Keras' automatic attribute tracking so
    # the sub-networks assigned below are not tracked before the functional
    # `super().__init__` call at the end of this method.
    self._self_setattr_tracking = False
    # Constructor arguments replayed by `from_config`. `mlm_activation` must be
    # recorded, otherwise a model rebuilt from its config loses the activation.
    self._config = {
        'encoder_network': encoder_network,
        'num_masked_tokens': num_masked_tokens,
        'mlm_activation': mlm_activation,
        'mlm_initializer': mlm_initializer,
        'classification_heads': classification_heads,
        'name': name,
    }

    self.encoder_network = encoder_network
    # `inputs` is a copy because `masked_lm_positions` may be appended to it
    # below. The encoder is invoked on its own, never-mutated input list so that
    # the call does not keep a reference to a list that changes afterwards.
    inputs = copy.copy(self.encoder_network.inputs)
    sequence_output, _ = self.encoder_network(self.encoder_network.inputs)
    # The encoder network may return outputs from all layers; keep the last one.
    if isinstance(sequence_output, list):
      sequence_output = sequence_output[-1]

    self.classification_heads = classification_heads or []
    head_names = {head.name for head in self.classification_heads}
    if len(head_names) != len(self.classification_heads):
      raise ValueError('Classification heads should have unique names.')

    outputs = dict()
    if num_masked_tokens > 0:
      self.masked_lm = networks.MaskedLM(
          num_predictions=num_masked_tokens,
          input_width=sequence_output.shape[-1],
          source_network=self.encoder_network,
          activation=mlm_activation,
          initializer=mlm_initializer,
          name='masked_lm')
      # Expose (a copy of) the masked LM head's last input as a model input.
      masked_lm_positions = copy.copy(self.masked_lm.inputs[-1])
      inputs.append(masked_lm_positions)
      outputs['lm_output'] = self.masked_lm(
          [sequence_output, masked_lm_positions])
    for cls_head in self.classification_heads:
      outputs[cls_head.name] = cls_head(sequence_output)

    super(BertPretrainerV2, self).__init__(
        inputs=inputs, outputs=outputs, name=name, **kwargs)

  @property
  def checkpoint_items(self):
    """Returns a dictionary of items to be additionally checkpointed.

    Contains the encoder under key `encoder`, plus every entry of each
    classification head's own `checkpoint_items`, keyed as
    `<head name>.<key>`.
    """
    items = dict(encoder=self.encoder_network)
    for head in self.classification_heads:
      for key, item in head.checkpoint_items.items():
        items['.'.join([head.name, key])] = item
    return items

  def get_config(self):
    """Returns the constructor arguments recorded in `__init__`.

    A shallow copy is returned so callers cannot mutate this model's config.
    """
    return dict(self._config)

  @classmethod
  def from_config(cls, config, custom_objects=None):
    """Re-creates the model by passing `config` as keyword arguments.

    Args:
      config: A dict as returned by `get_config`.
      custom_objects: Unused; accepted for Keras API compatibility.

    Returns:
      A new `BertPretrainerV2` instance.
    """
    return cls(**config)