masked_lm.py 7.19 KB
Newer Older
Hongkun Yu's avatar
Hongkun Yu committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Masked language task."""
import dataclasses
import tensorflow as tf

from official.core import base_task
Abdullah Rashwan's avatar
Abdullah Rashwan committed
21
from official.core import task_factory
Hongkun Yu's avatar
Hongkun Yu committed
22
from official.modeling import tf_utils
Hongkun Yu's avatar
Hongkun Yu committed
23
24
from official.modeling.hyperparams import config_definitions as cfg
from official.nlp.configs import bert
Hongkun Yu's avatar
Hongkun Yu committed
25
from official.nlp.configs import encoders
Chen Chen's avatar
Chen Chen committed
26
from official.nlp.data import data_loader_factory
Hongkun Yu's avatar
Hongkun Yu committed
27
28
from official.nlp.modeling import layers
from official.nlp.modeling import models
Hongkun Yu's avatar
Hongkun Yu committed
29
30
31
32
33


@dataclasses.dataclass
class MaskedLMConfig(cfg.TaskConfig):
  """Task config for masked language model pretraining.

  By default the model is a `bert.PretrainerConfig` carrying a single
  two-class `next_sentence` classification head.
  """
  # Config-object defaults are built per instance via `default_factory`: a
  # plain instance default is one object shared by every MaskedLMConfig
  # (unless the base class copies it), and Python 3.11+ dataclasses reject
  # unhashable instance defaults outright.
  model: bert.PretrainerConfig = dataclasses.field(
      default_factory=lambda: bert.PretrainerConfig(cls_heads=[  # pylint: disable=g-long-lambda
          bert.ClsHeadConfig(
              inner_dim=768,
              num_classes=2,
              dropout_rate=0.1,
              name='next_sentence')
      ]))
  train_data: cfg.DataConfig = dataclasses.field(
      default_factory=cfg.DataConfig)
  validation_data: cfg.DataConfig = dataclasses.field(
      default_factory=cfg.DataConfig)


Abdullah Rashwan's avatar
Abdullah Rashwan committed
42
@task_factory.register_task_cls(MaskedLMConfig)
class MaskedLMTask(base_task.Task):
  """Task object for masked language modeling.

  Optionally adds a next-sentence-prediction loss when the inputs carry
  `next_sentence_labels`.
  """

  @staticmethod
  def _metrics_by_name(metrics):
    """Indexes a sequence of Keras metrics by their `name` attribute."""
    return {metric.name: metric for metric in metrics}

  def build_model(self, params=None):
    """Builds a `BertPretrainerV2` with an MLM head and optional cls heads.

    Args:
      params: optional model config; when falsy, `self.task_config.model`
        is used instead.

    Returns:
      A `models.BertPretrainerV2` instance.
    """
    config = params or self.task_config.model
    encoder_network = encoders.build_encoder(config.encoder)
    # Named `head_cfg` so it does not shadow the module-level `cfg` import.
    cls_heads = [
        layers.ClassificationHead(**head_cfg.as_dict())
        for head_cfg in config.cls_heads
    ] if config.cls_heads else []
    return models.BertPretrainerV2(
        mlm_activation=tf_utils.get_activation(config.mlm_activation),
        mlm_initializer=tf.keras.initializers.TruncatedNormal(
            stddev=config.mlm_initializer_range),
        encoder_network=encoder_network,
        classification_heads=cls_heads)

  def build_losses(self,
                   labels,
                   model_outputs,
                   metrics,
                   aux_losses=None) -> tf.Tensor:
    """Computes the pretraining loss and updates the loss metrics.

    Args:
      labels: dict of input tensors. It must contain `masked_lm_ids` and
        `masked_lm_weights`, and may contain `next_sentence_labels`.
      model_outputs: dict of model outputs. It must contain `lm_output`,
        plus `next_sentence` whenever `labels` has `next_sentence_labels`.
      metrics: sequence of Keras metrics, looked up by name.
      aux_losses: optional list of extra losses (callers pass
        `model.losses`) that are summed into the total.

    Returns:
      The total loss: the weighted MLM loss, plus the next-sentence loss
      when labels for it are present, plus any auxiliary losses.
    """
    metrics = self._metrics_by_name(metrics)
    lm_prediction_losses = tf.keras.losses.sparse_categorical_crossentropy(
        labels['masked_lm_ids'],
        tf.cast(model_outputs['lm_output'], tf.float32),
        from_logits=True)
    lm_label_weights = labels['masked_lm_weights']
    # Weighted mean over predicted positions; divide_no_nan yields 0 rather
    # than NaN when every weight is zero (as in the all-zero dummy batch).
    lm_numerator_loss = tf.reduce_sum(lm_prediction_losses * lm_label_weights)
    lm_denominator_loss = tf.reduce_sum(lm_label_weights)
    mlm_loss = tf.math.divide_no_nan(lm_numerator_loss, lm_denominator_loss)
    metrics['lm_example_loss'].update_state(mlm_loss)

    total_loss = mlm_loss
    if 'next_sentence_labels' in labels:
      sentence_labels = labels['next_sentence_labels']
      sentence_outputs = tf.cast(
          model_outputs['next_sentence'], dtype=tf.float32)
      sentence_loss = tf.reduce_mean(
          tf.keras.losses.sparse_categorical_crossentropy(
              sentence_labels, sentence_outputs, from_logits=True))
      # build_metrics only creates `next_sentence_loss` when
      # `use_next_sentence_label` is set, but inputs may still carry the
      # labels (the dummy dataset always does), so don't assume it exists.
      if 'next_sentence_loss' in metrics:
        metrics['next_sentence_loss'].update_state(sentence_loss)
      total_loss = mlm_loss + sentence_loss

    if aux_losses:
      total_loss += tf.add_n(aux_losses)
    return total_loss

  def build_inputs(self, params, input_context=None):
    """Returns tf.data.Dataset for pretraining.

    When `params.input_path == 'dummy'`, this returns an infinitely repeating
    dataset of all-zero examples whose shapes come from `params.seq_length`
    and `params.max_predictions_per_seq`. Otherwise it delegates to the
    loader returned by `data_loader_factory.get_data_loader(params)`.

    Args:
      params: data config for this split.
      input_context: optional input context forwarded to the data loader.

    Returns:
      A `tf.data.Dataset` of feature dicts.
    """
    if params.input_path == 'dummy':

      def dummy_data(_):
        dummy_ids = tf.zeros((1, params.seq_length), dtype=tf.int32)
        dummy_lm = tf.zeros((1, params.max_predictions_per_seq), dtype=tf.int32)
        return dict(
            input_word_ids=dummy_ids,
            input_mask=dummy_ids,
            input_type_ids=dummy_ids,
            masked_lm_positions=dummy_lm,
            masked_lm_ids=dummy_lm,
            masked_lm_weights=tf.cast(dummy_lm, dtype=tf.float32),
            next_sentence_labels=tf.zeros((1, 1), dtype=tf.int32))

      dataset = tf.data.Dataset.range(1)
      dataset = dataset.repeat()
      dataset = dataset.map(
          dummy_data, num_parallel_calls=tf.data.experimental.AUTOTUNE)
      return dataset

    return data_loader_factory.get_data_loader(params).load(input_context)

  def build_metrics(self, training=None):
    """Creates the metrics tracked by this task.

    Args:
      training: unused.

    Returns:
      A list of Keras metrics. MLM accuracy and loss are always included.
      Next-sentence accuracy and loss are included only when
      `task_config.train_data.use_next_sentence_label` is set.
    """
    del training
    metrics = [
        tf.keras.metrics.SparseCategoricalAccuracy(name='masked_lm_accuracy'),
        tf.keras.metrics.Mean(name='lm_example_loss')
    ]
    # TODO(hongkuny): rethink how to manage metrics creation with heads.
    if self.task_config.train_data.use_next_sentence_label:
      metrics.append(
          tf.keras.metrics.SparseCategoricalAccuracy(
              name='next_sentence_accuracy'))
      metrics.append(tf.keras.metrics.Mean(name='next_sentence_loss'))
    return metrics

  def process_metrics(self, metrics, labels, model_outputs):
    """Updates the accuracy metrics that are present in `metrics`.

    Args:
      metrics: sequence of Keras metrics, looked up by name.
      labels: dict of input tensors holding the label features.
      model_outputs: dict of model outputs (`lm_output`, `next_sentence`).
    """
    metrics = self._metrics_by_name(metrics)
    if 'masked_lm_accuracy' in metrics:
      # Per-position weights exclude padded prediction slots from accuracy.
      metrics['masked_lm_accuracy'].update_state(labels['masked_lm_ids'],
                                                 model_outputs['lm_output'],
                                                 labels['masked_lm_weights'])
    if 'next_sentence_accuracy' in metrics:
      metrics['next_sentence_accuracy'].update_state(
          labels['next_sentence_labels'], model_outputs['next_sentence'])

  def train_step(self, inputs, model: tf.keras.Model,
                 optimizer: tf.keras.optimizers.Optimizer, metrics):
    """Does forward and backward.

    Args:
      inputs: a dictionary of input tensors.
      model: the model, forward pass definition.
      optimizer: the optimizer for this training step.
      metrics: a nested structure of metrics objects.

    Returns:
      A dictionary of logs.
    """
    with tf.GradientTape() as tape:
      outputs = model(inputs, training=True)
      # Computes per-replica loss.
      loss = self.build_losses(
          labels=inputs,
          model_outputs=outputs,
          metrics=metrics,
          aux_losses=model.losses)
      # Scales loss as the default gradients allreduce performs sum inside the
      # optimizer.
      # TODO(b/154564893): enable loss scaling.
      # scaled_loss = loss / tf.distribute.get_strategy().num_replicas_in_sync
    tvars = model.trainable_variables
    grads = tape.gradient(loss, tvars)
    optimizer.apply_gradients(list(zip(grads, tvars)))
    self.process_metrics(metrics, inputs, outputs)
    return {self.loss: loss}

  def validation_step(self, inputs, model: tf.keras.Model, metrics):
    """Validation step.

    Args:
      inputs: a dictionary of input tensors.
      model: the keras.Model.
      metrics: a nested structure of metrics objects.

    Returns:
      A dictionary of logs.
    """
    outputs = self.inference_step(inputs, model)
    loss = self.build_losses(
        labels=inputs,
        model_outputs=outputs,
        metrics=metrics,
        aux_losses=model.losses)
    self.process_metrics(metrics, inputs, outputs)
    return {self.loss: loss}