teams_pretrainer.py 18 KB
Newer Older
A. Unique TensorFlower's avatar
A. Unique TensorFlower committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Jialu Liu's avatar
Jialu Liu committed
15
"""Trainer network for TEAMS models."""
A. Unique TensorFlower's avatar
A. Unique TensorFlower committed
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
# pylint: disable=g-classes-have-attributes

import tensorflow as tf

from official.modeling import tf_utils
from official.nlp.modeling import layers
from official.nlp.modeling import models


class ReplacedTokenDetectionHead(tf.keras.layers.Layer):
  """Replaced token detection discriminator head.

  Re-creates the discriminator transformer layers with indices
  `num_task_agnostic_layers` .. `num_hidden_instances - 1`, stacks them on top
  of the given sequence representation, and projects every token to a single
  logit.

  Arguments:
    encoder_cfg: Encoder config, used to create hidden layers and head. Reads
      `embedding_cfg['hidden_size']`, `num_hidden_instances` and `hidden_cfg`
      (`num_attention_heads`, `intermediate_size`, `intermediate_activation`,
      `dropout_rate`, `attention_dropout_rate`, `kernel_initializer`).
    num_task_agnostic_layers: Number of task agnostic layers in the
      discriminator.
    output: The output style for this network. Can be either 'logits' or
      'predictions'. The value is validated, but `call()` returns logits in
      both cases.
    name: The name of this layer.
    **kwargs: Forwarded to `tf.keras.layers.Layer`.
  """

  def __init__(self,
               encoder_cfg,
               num_task_agnostic_layers,
               output='logits',
               name='rtd',
               **kwargs):
    super(ReplacedTokenDetectionHead, self).__init__(name=name, **kwargs)
    # Validate before building any sub-layers.
    if output not in ('predictions', 'logits'):
      raise ValueError(
          ('Unknown `output` value "%s". `output` can be either "logits" or '
           '"predictions"') % output)
    self._output_type = output

    self.num_task_agnostic_layers = num_task_agnostic_layers
    self.hidden_size = encoder_cfg['embedding_cfg']['hidden_size']
    self.num_hidden_instances = encoder_cfg['num_hidden_instances']
    self.hidden_cfg = encoder_cfg['hidden_cfg']
    self.activation = self.hidden_cfg['intermediate_activation']
    self.initializer = self.hidden_cfg['kernel_initializer']

    # One transformer layer per discriminator layer that is not shared with
    # the multi-word selection task.
    self.hidden_layers = []
    for i in range(self.num_task_agnostic_layers, self.num_hidden_instances):
      self.hidden_layers.append(
          layers.Transformer(
              num_attention_heads=self.hidden_cfg['num_attention_heads'],
              intermediate_size=self.hidden_cfg['intermediate_size'],
              intermediate_activation=self.activation,
              dropout_rate=self.hidden_cfg['dropout_rate'],
              attention_dropout_rate=self.hidden_cfg['attention_dropout_rate'],
              kernel_initializer=self.initializer,
              name='transformer/layer_%d_rtd' % i))
    self.dense = tf.keras.layers.Dense(
        self.hidden_size,
        activation=self.activation,
        kernel_initializer=self.initializer,
        name='transform/rtd_dense')
    self.rtd_head = tf.keras.layers.Dense(
        units=1, kernel_initializer=self.initializer,
        name='transform/rtd_head')

  def call(self, sequence_data, input_mask):
    """Computes a replaced-token-detection logit for every token.

    Args:
      sequence_data: A [batch_size, seq_length, num_hidden] tensor. In
        `TeamsPretrainer` this is the output of the last task agnostic
        discriminator layer.
      input_mask: A [batch_size, seq_length] binary mask to separate the input
        from the padding.

    Returns:
      A [batch_size, seq_length] tensor of logits.
    """
    attention_mask = layers.SelfAttentionMask()([sequence_data, input_mask])
    data = sequence_data
    for hidden_layer in self.hidden_layers:
      # Chain the layers: each one consumes the previous layer's output.
      data = hidden_layer([data, attention_mask])
    rtd_logits = self.rtd_head(self.dense(data))
    return tf.squeeze(rtd_logits, axis=-1)


class MultiWordSelectionHead(tf.keras.layers.Layer):
  """Multi-word selection discriminator head.

  For each masked position, the hidden vector is passed through a dense layer
  and layer normalization. It is then dotted with the embedding of every token
  in that position's candidate set, giving one score per candidate.

  Arguments:
    embedding_table: The embedding table; its shape is unpacked as
      (vocab_size, embed_size).
    activation: The activation, if any, for the dense layer.
    initializer: The initializer for the dense layer. Defaults to a Glorot
      uniform initializer.
    output: The output style for this network. Can be either 'logits' or
      'predictions'. With 'predictions' the scores are log-softmax normalized
      over the candidate axis.
    name: The name of this layer.
    **kwargs: Forwarded to `tf.keras.layers.Layer`.
  """

  def __init__(self,
               embedding_table,
               activation=None,
               initializer='glorot_uniform',
               output='logits',
               name='mws',
               **kwargs):
    super(MultiWordSelectionHead, self).__init__(name=name, **kwargs)
    self.embedding_table = embedding_table
    self.activation = activation
    self.initializer = tf.keras.initializers.get(initializer)

    vocab_size, embed_size = self.embedding_table.shape
    self._vocab_size = vocab_size
    self.embed_size = embed_size

    self.dense = tf.keras.layers.Dense(
        self.embed_size,
        activation=self.activation,
        kernel_initializer=self.initializer,
        name='transform/mws_dense')
    self.layer_norm = tf.keras.layers.LayerNormalization(
        axis=-1, epsilon=1e-12, name='transform/mws_layernorm')

    if output not in ('predictions', 'logits'):
      raise ValueError(
          ('Unknown `output` value "%s". `output` can be either "logits" or '
           '"predictions"') % output)
    self._output_type = output

  def call(self, sequence_data, masked_positions, candidate_sets):
    """Compute inner-products of hidden vectors with sampled element embeddings.

    Args:
      sequence_data: A [batch_size, seq_length, num_hidden] tensor.
      masked_positions: A [batch_size, num_prediction] tensor.
      candidate_sets: A [batch_size, num_prediction, k] tensor.

    Returns:
      A [batch_size, num_prediction, k] tensor of logits, or of log-softmax
      scores when the layer was built with output='predictions'.
    """
    num_prediction = tf_utils.get_shape_list(candidate_sets)[1]

    # Transformed hidden vector per masked position, laid out as
    # (batch_size, num_prediction, 1, embed_size).
    gathered = self._gather_indexes(sequence_data, masked_positions)
    hidden = self.layer_norm(self.dense(gathered))
    hidden = tf.reshape(hidden, [-1, num_prediction, self.embed_size])
    hidden = tf.expand_dims(hidden, 2)

    # Candidate embeddings laid out as
    # (batch_size, num_prediction, embed_size, k).
    embeddings = tf.gather(self.embedding_table,
                           tf.reshape(candidate_sets, [-1]))
    target_shape = tf.concat(
        [tf.shape(candidate_sets), [self.embed_size]], axis=0)
    embeddings = tf.reshape(embeddings, target_shape)
    embeddings.set_shape(candidate_sets.shape.as_list() + [self.embed_size])
    embeddings = tf.transpose(embeddings, [0, 1, 3, 2])

    # (b, p, 1, e) @ (b, p, e, k) -> (b, p, 1, k), then drop the unit axis.
    scores = tf.squeeze(tf.matmul(hidden, embeddings), 2)

    if self._output_type != 'logits':
      return tf.nn.log_softmax(scores)
    return scores

  def _gather_indexes(self, sequence_tensor, positions):
    """Gathers the vectors at the specific positions.

    Args:
        sequence_tensor: Sequence output of shape
          (`batch_size`, `seq_length`, `num_hidden`) where `num_hidden` is
          number of hidden units.
        positions: Positions ids of tokens in batched sequences.

    Returns:
        Sequence tensor of shape (batch_size * num_predictions,
        num_hidden).
    """
    batch_size, seq_length, width = tf_utils.get_shape_list(
        sequence_tensor, name='sequence_output_tensor')

    # Index of each example's first token once the batch is flattened.
    example_offsets = tf.reshape(
        tf.range(0, batch_size, dtype=tf.int32) * seq_length, [-1, 1])
    flat_positions = tf.reshape(positions + example_offsets, [-1])
    flat_sequence = tf.reshape(sequence_tensor,
                               [batch_size * seq_length, width])
    return tf.gather(flat_sequence, flat_positions)


@tf.keras.utils.register_keras_serializable(package='Text')
class TeamsPretrainer(tf.keras.Model):
  """TEAMS network training model.

  This is an implementation of the network structure described in "Training
  ELECTRA Augmented with Multi-word Selection"
  (https://arxiv.org/abs/2106.00139).

  The TeamsPretrainer allows a user to pass in two transformer encoders, one
  for generator, the other for discriminator (multi-word selection). The
  pretrainer then instantiates the masked language model (at generator side) and
  classification networks (including both multi-word selection head and replaced
  token detection head) that are used to create the training objectives.

  *Note* that the model is constructed by Keras Subclass API, where layers are
  defined inside `__init__` and `call()` implements the computation.

  Args:
    generator_network: A transformer encoder for generator, this network should
      output a sequence output.
    discriminator_mws_network: A transformer encoder for multi-word selection
      discriminator, this network should output a sequence output.
    num_discriminator_task_agnostic_layers: Number of layers shared between
      multi-word selection and replaced token detection discriminators. Must
      be at least 1, because the replaced token detection head reads the
      output of discriminator layer
      `num_discriminator_task_agnostic_layers - 1`.
    vocab_size: Size of generator output vocabulary.
    candidate_size: Candidate size for multi-word selection task,
      including the correct word.
    mlm_activation: The activation (if any) to use in the generator's masked
      LM head. If None, no activation will be used. The discriminator heads
      use the activation from the discriminator's `hidden_cfg` instead.
    mlm_initializer: The initializer (if any) to use in the generator's masked
      LM head. Defaults to a Glorot uniform initializer. The discriminator
      heads use the initializer from the discriminator's `hidden_cfg` instead.
    output_type: The output style for this network. Can be either `logits` or
      `predictions`.
    **kwargs: Extra entries stored in the config returned by `get_config()`.
      They are not forwarded to `tf.keras.Model.__init__`.

  Raises:
    ValueError: If `num_discriminator_task_agnostic_layers` is less than 1.
  """

  def __init__(self,
               generator_network,
               discriminator_mws_network,
               num_discriminator_task_agnostic_layers,
               vocab_size,
               candidate_size=5,
               mlm_activation=None,
               mlm_initializer='glorot_uniform',
               output_type='logits',
               **kwargs):
    super().__init__()
    # With 0 the RTD head would read `encoder_outputs[-1]` (the final layer)
    # while also re-creating every discriminator layer on top of it.
    if num_discriminator_task_agnostic_layers < 1:
      raise ValueError(
          '`num_discriminator_task_agnostic_layers` must be >= 1, got %s' %
          num_discriminator_task_agnostic_layers)
    self._config = {
        'generator_network':
            generator_network,
        'discriminator_mws_network':
            discriminator_mws_network,
        'num_discriminator_task_agnostic_layers':
            num_discriminator_task_agnostic_layers,
        'vocab_size':
            vocab_size,
        'candidate_size':
            candidate_size,
        'mlm_activation':
            mlm_activation,
        'mlm_initializer':
            mlm_initializer,
        'output_type':
            output_type,
    }
    for k, v in kwargs.items():
      self._config[k] = v

    self.generator_network = generator_network
    self.discriminator_mws_network = discriminator_mws_network
    self.vocab_size = vocab_size
    self.candidate_size = candidate_size
    self.mlm_activation = mlm_activation
    self.mlm_initializer = mlm_initializer
    self.output_type = output_type

    # The generator MLM head and the MWS head both use the discriminator's
    # embedding table.
    self.embedding_table = (
        self.discriminator_mws_network.embedding_network.get_embedding_table())
    self.masked_lm = layers.MaskedLM(
        embedding_table=self.embedding_table,
        activation=mlm_activation,
        initializer=mlm_initializer,
        output=output_type,
        name='generator_masked_lm')
    discriminator_cfg = self.discriminator_mws_network.get_config()
    self.num_task_agnostic_layers = num_discriminator_task_agnostic_layers
    self.discriminator_rtd_head = ReplacedTokenDetectionHead(
        encoder_cfg=discriminator_cfg,
        num_task_agnostic_layers=self.num_task_agnostic_layers,
        output=output_type,
        name='discriminator_rtd')
    hidden_cfg = discriminator_cfg['hidden_cfg']
    self.discriminator_mws_head = MultiWordSelectionHead(
        embedding_table=self.embedding_table,
        activation=hidden_cfg['intermediate_activation'],
        initializer=hidden_cfg['kernel_initializer'],
        output=output_type,
        name='discriminator_mws')

  def call(self, inputs):
    """TEAMS forward pass.

    Args:
      inputs: A dict of all inputs, same as the standard BERT model.

    Returns:
      outputs: A dict of pretrainer model outputs, including
        (1) lm_outputs: A `[batch_size, num_token_predictions, vocab_size]`
        tensor indicating logits on masked positions.
        (2) disc_rtd_logits: A `[batch_size, sequence_length]` tensor indicating
        logits for discriminator replaced token detection task.
        (3) disc_rtd_label: A `[batch_size, sequence_length]` tensor indicating
        target labels for discriminator replaced token detection task.
        (4) disc_mws_logits: A `[batch_size, num_token_predictions,
        candidate_size]` tensor indicating logits for discriminator multi-word
        selection task.
        (5) disc_mws_label: A `[batch_size, num_token_predictions]` tensor
        indicating target labels for discriminator multi-word selection task.
        It is all zeros, because the true token is always placed first in
        each candidate set.
    """
    input_word_ids = inputs['input_word_ids']
    input_mask = inputs['input_mask']
    input_type_ids = inputs['input_type_ids']
    masked_lm_positions = inputs['masked_lm_positions']

    # Runs generator.
    sequence_output = self.generator_network(
        [input_word_ids, input_mask, input_type_ids])['sequence_output']

    lm_outputs = self.masked_lm(sequence_output, masked_lm_positions)

    # Samples tokens from generator.
    fake_data = self._get_fake_data(inputs, lm_outputs)

    # Runs discriminator.
    disc_input = fake_data['inputs']
    disc_rtd_label = fake_data['is_fake_tokens']
    disc_mws_candidates = fake_data['candidate_set']
    mws_sequence_outputs = self.discriminator_mws_network([
        disc_input['input_word_ids'], disc_input['input_mask'],
        disc_input['input_type_ids']
    ])['encoder_outputs']

    # Applies replaced token detection to the output of the last task
    # agnostic layer (index self.num_task_agnostic_layers - 1).
    disc_rtd_logits = self.discriminator_rtd_head(
        mws_sequence_outputs[self.num_task_agnostic_layers - 1], input_mask)

    # Applies multi-word selection.
    disc_mws_logits = self.discriminator_mws_head(mws_sequence_outputs[-1],
                                                  masked_lm_positions,
                                                  disc_mws_candidates)
    # The true token is at index 0 of every candidate set (see
    # `_get_fake_data`).
    disc_mws_label = tf.zeros_like(masked_lm_positions, dtype=tf.int32)

    outputs = {
        'lm_outputs': lm_outputs,
        'disc_rtd_logits': disc_rtd_logits,
        'disc_rtd_label': disc_rtd_label,
        'disc_mws_logits': disc_mws_logits,
        'disc_mws_label': disc_mws_label,
    }

    return outputs

  def _get_fake_data(self, inputs, mlm_logits):
    """Generate corrupted data for discriminator.

    Note it is possible for a sampled token to be the same as the correct one;
    such positions are not labeled as replaced.

    Args:
      inputs: A dict of all inputs, same as the input of `call()` function.
      mlm_logits: The generator's output logits.

    Returns:
      A dict of generated fake data with keys `inputs` (corrupted
      discriminator inputs), `is_fake_tokens` (replaced token detection
      labels), `sampled_tokens` (generator samples) and `candidate_set` (the
      true token id followed by `candidate_size - 1` sampled negatives).
    """
    inputs = models.electra_pretrainer.unmask(inputs, duplicate=True)

    # Samples replaced token.
    sampled_tokens = tf.stop_gradient(
        models.electra_pretrainer.sample_from_softmax(
            mlm_logits, disallow=None))
    sampled_tokids = tf.argmax(sampled_tokens, axis=-1, output_type=tf.int32)

    # Prepares input and label for replaced token detection task. A position
    # is labeled 1 only if it was masked and the sampled token differs from
    # the original one.
    updated_input_ids, masked = models.electra_pretrainer.scatter_update(
        inputs['input_word_ids'], sampled_tokids, inputs['masked_lm_positions'])
    rtd_labels = masked * (1 - tf.cast(
        tf.equal(updated_input_ids, inputs['input_word_ids']), tf.int32))
    updated_inputs = models.electra_pretrainer.get_updated_inputs(
        inputs, duplicate=True, input_word_ids=updated_input_ids)

    # Samples (candidate_size-1) negatives, excluding the true token, and puts
    # the true token first in each candidate set.
    disallow = tf.one_hot(
        inputs['masked_lm_ids'], depth=self.vocab_size, dtype=tf.float32)
    sampled_candidates = tf.stop_gradient(
        sample_k_from_softmax(mlm_logits, k=self.candidate_size-1,
                              disallow=disallow))
    true_token_id = tf.expand_dims(inputs['masked_lm_ids'], -1)
    candidate_set = tf.concat([true_token_id, sampled_candidates], -1)

    return {
        'inputs': updated_inputs,
        'is_fake_tokens': rtd_labels,
        'sampled_tokens': sampled_tokens,
        'candidate_set': candidate_set
    }

  @property
  def checkpoint_items(self):
    """Returns a dictionary of items to be additionally checkpointed."""
    items = dict(encoder=self.discriminator_mws_network)
    return items

  def get_config(self):
    return self._config

  @classmethod
  def from_config(cls, config, custom_objects=None):
    return cls(**config)


A. Unique TensorFlower's avatar
A. Unique TensorFlower committed
420
def sample_k_from_softmax(logits, k, disallow=None, use_topk=False):
  """Implement softmax sampling using gumbel softmax trick to select k items.

  Gumbel noise is added once to the (optionally masked) logits. The k largest
  perturbed logits are then selected, either with `tf.nn.top_k` or by
  repeatedly taking the argmax and pushing the chosen entry down by 10000.
  In both cases sampling is without replacement.

  Args:
    logits: A [batch_size, num_token_predictions, vocab_size] tensor indicating
      the generator output logits for each masked position.
    k: Number of samples
    disallow: If `None`, we directly sample tokens from the logits. Otherwise,
      this is a tensor of size [batch_size, num_token_predictions, vocab_size]
      indicating the true word id in each masked position. Entries set to 1
      have their logits lowered by 10000 so they are effectively never picked.
      It is cast to `logits.dtype`.
    use_topk: Whether to use tf.nn.top_k or using iterative approach where the
      latter is empirically faster.

  Returns:
    sampled_tokens: A [batch_size, num_token_predictions, k] int32 tensor
    indicating the sampled word id in each masked position.
  """
  # Masks and noise follow the dtype of `logits`, so non-float32 logits
  # (e.g. under mixed precision) do not hit a dtype mismatch. For float32
  # logits these casts are no-ops.
  if disallow is not None:
    logits -= 10000.0 * tf.cast(disallow, logits.dtype)

  # The noise is computed in float32 for numerical stability (the 1e-9 guards
  # underflow in float16) and only then cast to the logits dtype.
  uniform_noise = tf.random.uniform(
      tf_utils.get_shape_list(logits), minval=0, maxval=1)
  gumbel_noise = -tf.math.log(-tf.math.log(uniform_noise + 1e-9) + 1e-9)
  perturbed_logits = logits + tf.cast(gumbel_noise, logits.dtype)

  if use_topk:
    _, sampled_tokens = tf.nn.top_k(perturbed_logits, k=k, sorted=False)
    return sampled_tokens

  vocab_size = tf_utils.get_shape_list(logits)[-1]
  sampled_tokens_list = []
  for _ in range(k):
    token_ids = tf.argmax(perturbed_logits, -1, output_type=tf.int32)
    sampled_tokens_list.append(token_ids)
    # Push the chosen token down so the next argmax picks a different one.
    perturbed_logits -= 10000.0 * tf.one_hot(
        token_ids, depth=vocab_size, dtype=perturbed_logits.dtype)
  return tf.stack(sampled_tokens_list, -1)