# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Trainer network for ELECTRA models."""
# pylint: disable=g-classes-have-attributes

import copy
import tensorflow as tf

from official.modeling import tf_utils
from official.nlp.modeling import layers


@tf.keras.utils.register_keras_serializable(package='Text')
class ElectraPretrainer(tf.keras.Model):
  """ELECTRA network training model.

  This is an implementation of the network structure described in "ELECTRA:
  Pre-training Text Encoders as Discriminators Rather Than Generators" (
  https://arxiv.org/abs/2003.10555).

  The ElectraPretrainer allows a user to pass in two transformer models, one for
  generator, the other for discriminator, and instantiates the masked language
  model (at generator side) and classification networks (at discriminator side)
  that are used to create the training objectives.

  *Note* that the model is constructed by Keras Subclass API, where layers are
  defined inside `__init__` and `call()` implements the computation.

  Args:
    generator_network: A transformer network for generator, this network should
      output a sequence output and an optional classification output. It must
      expose `get_embedding_table()` and a `hidden_size` entry in its config.
    discriminator_network: A transformer network for discriminator, this network
      should output a sequence output. It must expose a `hidden_size` entry in
      its config.
    vocab_size: Size of generator output vocabulary. Used as the one-hot depth
      of the "disallow" mask when `disallow_correct` is True.
    num_classes: Number of classes predicted by the classification head that is
      applied to the generator's sequence output (returned as
      `sentence_outputs`).
    num_token_predictions: Number of tokens to predict from the masked LM.
      Recorded in the config; not otherwise read by this class.
    mlm_activation: The activation (if any) to use in the masked LM and in the
      discriminator projection layer. If None, no activation will be used.
    mlm_initializer: The initializer (if any) to use in the masked LM, the
      generator classification head and the discriminator layers. Defaults to a
      Glorot uniform initializer.
    output_type: The output style for the masked LM. Can be either `logits` or
      `predictions`.
    disallow_correct: Whether to disallow the generator to generate the exact
      same token in the original sentence.
    **kwargs: Extra keyword arguments. They are recorded in the config returned
      by `get_config()` but are not forwarded to `tf.keras.Model.__init__`.
  """

  def __init__(self,
               generator_network,
               discriminator_network,
               vocab_size,
               num_classes,
               num_token_predictions,
               mlm_activation=None,
               mlm_initializer='glorot_uniform',
               output_type='logits',
               disallow_correct=False,
               **kwargs):
    super(ElectraPretrainer, self).__init__()
    # Everything `from_config` needs to rebuild this model, including any
    # extra keyword arguments.
    self._config = {
        'generator_network': generator_network,
        'discriminator_network': discriminator_network,
        'vocab_size': vocab_size,
        'num_classes': num_classes,
        'num_token_predictions': num_token_predictions,
        'mlm_activation': mlm_activation,
        'mlm_initializer': mlm_initializer,
        'output_type': output_type,
        'disallow_correct': disallow_correct,
    }
    for k, v in kwargs.items():
      self._config[k] = v

    self.generator_network = generator_network
    self.discriminator_network = discriminator_network
    self.vocab_size = vocab_size
    self.num_classes = num_classes
    self.num_token_predictions = num_token_predictions
    self.mlm_activation = mlm_activation
    self.mlm_initializer = mlm_initializer
    self.output_type = output_type
    self.disallow_correct = disallow_correct
    # Generator heads: a masked LM built on the generator's embedding table,
    # and a classification head over the generator's sequence output.
    self.masked_lm = layers.MaskedLM(
        embedding_table=generator_network.get_embedding_table(),
        activation=mlm_activation,
        initializer=tf_utils.clone_initializer(mlm_initializer),
        output=output_type,
        name='generator_masked_lm')
    self.classification = layers.ClassificationHead(
        inner_dim=generator_network.get_config()['hidden_size'],
        num_classes=num_classes,
        initializer=tf_utils.clone_initializer(mlm_initializer),
        name='generator_classification_head')
    # Discriminator heads: a hidden-size projection followed by a single-unit
    # dense layer, giving one replaced-token-detection logit per position.
    self.discriminator_projection = tf.keras.layers.Dense(
        units=discriminator_network.get_config()['hidden_size'],
        activation=mlm_activation,
        kernel_initializer=tf_utils.clone_initializer(mlm_initializer),
        name='discriminator_projection_head')
    self.discriminator_head = tf.keras.layers.Dense(
        units=1,
        kernel_initializer=tf_utils.clone_initializer(mlm_initializer))

  def call(self, inputs):
    """ELECTRA forward pass.

    Args:
      inputs: A dict of all inputs, same as the standard BERT model. The keys
        `input_word_ids`, `input_mask`, `input_type_ids`,
        `masked_lm_positions` and `masked_lm_ids` are read.

    Returns:
      outputs: A dict of pretrainer model outputs, including
        (1) lm_outputs: A `[batch_size, num_token_predictions, vocab_size]`
        tensor indicating logits (or predictions, depending on `output_type`)
        on masked positions.
        (2) sentence_outputs: A `[batch_size, num_classes]` tensor indicating
        logits for nsp task.
        (3) disc_logits: A `[batch_size, sequence_length]` tensor indicating
        logits for discriminator replaced token detection task.
        (4) disc_label: A `[batch_size, sequence_length]` tensor indicating
        target labels for discriminator replaced token detection task.
    """
    input_word_ids = inputs['input_word_ids']
    input_mask = inputs['input_mask']
    input_type_ids = inputs['input_type_ids']
    masked_lm_positions = inputs['masked_lm_positions']

    ### Generator ###
    sequence_output = self.generator_network(
        [input_word_ids, input_mask, input_type_ids])['sequence_output']
    # The generator encoder network may get outputs from all layers.
    if isinstance(sequence_output, list):
      sequence_output = sequence_output[-1]

    lm_outputs = self.masked_lm(sequence_output, masked_lm_positions)
    sentence_outputs = self.classification(sequence_output)

    ### Sampling from generator ###
    fake_data = self._get_fake_data(inputs, lm_outputs, duplicate=True)

    ### Discriminator ###
    disc_input = fake_data['inputs']
    disc_label = fake_data['is_fake_tokens']
    disc_sequence_output = self.discriminator_network([
        disc_input['input_word_ids'], disc_input['input_mask'],
        disc_input['input_type_ids']
    ])['sequence_output']

    # The discriminator encoder network may get outputs from all layers.
    if isinstance(disc_sequence_output, list):
      disc_sequence_output = disc_sequence_output[-1]

    disc_logits = self.discriminator_head(
        self.discriminator_projection(disc_sequence_output))
    # Drop the trailing unit dimension produced by the 1-unit dense head.
    disc_logits = tf.squeeze(disc_logits, axis=-1)

    outputs = {
        'lm_outputs': lm_outputs,
        'sentence_outputs': sentence_outputs,
        'disc_logits': disc_logits,
        'disc_label': disc_label,
    }

    return outputs

  def _get_fake_data(self, inputs, mlm_logits, duplicate=True):
    """Generate corrupted data for discriminator.

    The masked positions are first restored to their original ids, then
    overwritten with tokens sampled from the generator.

    Args:
      inputs: A dict of all inputs, same as the input of `call()` function.
      mlm_logits: The generator's output logits.
      duplicate: Whether to copy the original inputs dict during modifications
        instead of updating it in place.

    Returns:
      A dict of generated fake data with keys:
        'inputs': the inputs dict with `input_word_ids` holding the sampled
          tokens at the masked positions.
        'is_fake_tokens': an int32 tensor that is 1 where a position was
          updated and the sampled token differs from the original token.
        'sampled_tokens': the one-hot sampled tokens (gradient stopped).
    """
    inputs = unmask(inputs, duplicate)

    if self.disallow_correct:
      # One-hot mask of the true tokens so the sampler can penalize them.
      disallow = tf.one_hot(
          inputs['masked_lm_ids'], depth=self.vocab_size, dtype=tf.float32)
    else:
      disallow = None

    # Sampling is not differentiated through; the generator is trained only by
    # its own masked LM loss.
    sampled_tokens = tf.stop_gradient(
        sample_from_softmax(mlm_logits, disallow=disallow))
    sampled_tokids = tf.argmax(sampled_tokens, -1, output_type=tf.int32)
    updated_input_ids, masked = scatter_update(inputs['input_word_ids'],
                                               sampled_tokids,
                                               inputs['masked_lm_positions'])
    # A position is "fake" only if it was replaced AND the replacement differs
    # from the original token.
    labels = masked * (1 - tf.cast(
        tf.equal(updated_input_ids, inputs['input_word_ids']), tf.int32))

    updated_inputs = get_updated_inputs(
        inputs, duplicate, input_word_ids=updated_input_ids)

    return {
        'inputs': updated_inputs,
        'is_fake_tokens': labels,
        'sampled_tokens': sampled_tokens
    }

  @property
  def checkpoint_items(self):
    """Returns a dictionary of items to be additionally checkpointed."""
    items = dict(encoder=self.discriminator_network)
    return items

  def get_config(self):
    """Returns a shallow copy of the constructor arguments.

    A copy is returned so that callers mutating the result cannot silently
    change this model's stored configuration.
    """
    return dict(self._config)

  @classmethod
  def from_config(cls, config, custom_objects=None):
    """Rebuilds the model from a dict produced by `get_config()`."""
    return cls(**config)


def scatter_update(sequence, updates, positions):
  """Scatter-update a sequence.

  Args:
    sequence: A `[batch_size, seq_len]` or `[batch_size, seq_len, depth]`
      tensor with a floating-point or integer dtype.
    updates: A tensor of size `batch_size*n_positions(*depth)` with the same
      dtype as `sequence`.
    positions: A `[batch_size, n_positions]` tensor of indices into the
      `seq_len` axis.

  Returns:
    updated_sequence: A `[batch_size, seq_len]` or
      `[batch_size, seq_len, depth]` tensor of "sequence" with elements at
      "positions" replaced by the values at "updates". Updates to index 0 are
      ignored. If a position appears several times, the mean of its updates
      is written (floor of the mean for integer dtypes).
    updates_mask: A `[batch_size, seq_len]` int32 mask tensor of which inputs
      were updated.

  Raises:
    ValueError: If `sequence` has neither a floating-point nor an integer
      dtype.
  """
  shape = tf_utils.get_shape_list(sequence, expected_rank=[2, 3])
  depth_dimension = (len(shape) == 3)
  if depth_dimension:
    batch_size, seq_len, depth = shape
  else:
    batch_size, seq_len = shape
    depth = 1
    sequence = tf.expand_dims(sequence, -1)
  n_positions = tf_utils.get_shape_list(positions)[1]

  # Offset row b's positions by b * seq_len so they index into the flattened
  # `[batch_size * seq_len]` sequence.
  shift = tf.expand_dims(seq_len * tf.range(batch_size), -1)
  flat_positions = tf.reshape(positions + shift, [-1, 1])
  flat_updates = tf.reshape(updates, [-1, depth])
  updates = tf.scatter_nd(flat_positions, flat_updates,
                          [batch_size * seq_len, depth])
  updates = tf.reshape(updates, [batch_size, seq_len, depth])

  # Count how many updates hit each position (scatter_nd sums duplicates).
  flat_updates_mask = tf.ones([batch_size * n_positions], tf.int32)
  updates_mask = tf.scatter_nd(flat_positions, flat_updates_mask,
                               [batch_size * seq_len])
  updates_mask = tf.reshape(updates_mask, [batch_size, seq_len])
  # Never update the first token.
  not_first_token = tf.concat([
      tf.zeros((batch_size, 1), tf.int32),
      tf.ones((batch_size, seq_len - 1), tf.int32)
  ], -1)
  updates_mask *= not_first_token
  # Do the blending arithmetic in the sequence dtype so every supported
  # floating-point / integer dtype type-checks.
  updates_mask_3d = tf.cast(tf.expand_dims(updates_mask, -1), sequence.dtype)
  one = tf.constant(1, dtype=sequence.dtype)

  # Account for duplicate positions: their updates were summed above, so
  # divide by the count.
  if sequence.dtype.is_floating:
    updates /= tf.maximum(one, updates_mask_3d)
  elif sequence.dtype.is_integer:
    updates = tf.math.floordiv(updates, tf.maximum(one, updates_mask_3d))
  else:
    raise ValueError(
        'scatter_update expects a floating-point or integer `sequence`, '
        'got dtype {}.'.format(sequence.dtype))
  updates_mask = tf.minimum(updates_mask, 1)
  updates_mask_3d = tf.minimum(updates_mask_3d, one)

  updated_sequence = (((1 - updates_mask_3d) * sequence) +
                      (updates_mask_3d * updates))
  if not depth_dimension:
    updated_sequence = tf.squeeze(updated_sequence, -1)

  return updated_sequence, updates_mask


def sample_from_softmax(logits, disallow=None):
  """Implement softmax sampling using gumbel softmax trick.

  Adds Gumbel noise to the logits and takes the argmax, which draws one token
  per position from softmax(logits) at temperature 1.0.

  Args:
    logits: A `[batch_size, num_token_predictions, vocab_size]` tensor
      indicating the generator output logits for each masked position. Any
      floating-point dtype is accepted; sampling is done in float32.
    disallow: If `None`, we directly sample tokens from the logits. Otherwise,
      this is a tensor of size `[batch_size, num_token_predictions, vocab_size]`
      indicating the true word id in each masked position. Those ids receive a
      -1000 penalty on their logits so they are effectively never sampled.

  Returns:
    sampled_tokens: A float32 `[batch_size, num_token_predictions, vocab_size]`
      one hot tensor indicating the sampled word id in each masked position.
  """
  # Work in float32. If the logits are float16 (e.g. under a mixed-precision
  # policy), combining them with the float32 mask and noise below would
  # otherwise raise a dtype mismatch. For float32 inputs this is a no-op.
  logits = tf.cast(logits, tf.float32)
  if disallow is not None:
    logits -= 1000.0 * tf.cast(disallow, tf.float32)
  uniform_noise = tf.random.uniform(
      tf_utils.get_shape_list(logits), minval=0, maxval=1)
  gumbel_noise = -tf.math.log(-tf.math.log(uniform_noise + 1e-9) + 1e-9)

  # Here we essentially follow the original paper and use temperature 1.0 for
  # generator output logits. Softmax is monotonic, so the argmax is taken
  # directly on the perturbed logits (Gumbel-max trick).
  sampled_ids = tf.argmax(logits + gumbel_noise, -1, output_type=tf.int32)
  # get_shape_list falls back to the dynamic size when the vocab dimension is
  # not statically known.
  vocab_size = tf_utils.get_shape_list(logits)[-1]
  sampled_tokens = tf.one_hot(sampled_ids, vocab_size)
  return sampled_tokens


def unmask(inputs, duplicate):
  """Restores the original token ids at the masked LM positions.

  Args:
    inputs: A dict of model inputs containing `input_word_ids`,
      `masked_lm_ids` and `masked_lm_positions`.
    duplicate: If True, a shallow copy of `inputs` is updated and returned;
      otherwise `inputs` itself is updated in place.

  Returns:
    The inputs dict whose `input_word_ids` have the entries at
    `masked_lm_positions` replaced by `masked_lm_ids` (via `scatter_update`,
    which ignores updates to index 0).
  """
  original_ids = scatter_update(inputs['input_word_ids'],
                                inputs['masked_lm_ids'],
                                inputs['masked_lm_positions'])[0]
  return get_updated_inputs(inputs, duplicate, input_word_ids=original_ids)


def get_updated_inputs(inputs, duplicate, **kwargs):
  """Returns `inputs` with the entries given in `kwargs` overwritten.

  Args:
    inputs: A dict of model inputs.
    duplicate: If True, a shallow copy of `inputs` is updated and returned,
      leaving the caller's dict untouched; otherwise `inputs` itself is
      updated in place and returned.
    **kwargs: Keys and values to set on the returned dict.

  Returns:
    The updated dict.
  """
  target = copy.copy(inputs) if duplicate else inputs
  target.update(kwargs)
  return target