sentence_prediction.py 11.2 KB
Newer Older
Frederick Liu's avatar
Frederick Liu committed
1
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3
4
5
6
7
8
9
10
11
12
13
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Frederick Liu's avatar
Frederick Liu committed
14

15
"""Sentence prediction (classification) task."""
16
from typing import List, Union, Optional
17

18
from absl import logging
19
import dataclasses
20
import numpy as np
21
import orbit
22
23
from scipy import stats
from sklearn import metrics as sklearn_metrics
24
25
26
import tensorflow as tf

from official.core import base_task
27
from official.core import config_definitions as cfg
Abdullah Rashwan's avatar
Abdullah Rashwan committed
28
from official.core import task_factory
Chen Chen's avatar
Chen Chen committed
29
from official.modeling import tf_utils
Hongkun Yu's avatar
Hongkun Yu committed
30
31
from official.modeling.hyperparams import base_config
from official.nlp.configs import encoders
Chen Chen's avatar
Chen Chen committed
32
from official.nlp.data import data_loader_factory
Hongkun Yu's avatar
Hongkun Yu committed
33
from official.nlp.modeling import models
Chen Chen's avatar
Chen Chen committed
34
from official.nlp.tasks import utils
35

A. Unique TensorFlower's avatar
A. Unique TensorFlower committed
36
37
38
39
# Validation metrics accepted by `SentencePredictionConfig.metric_type`.
METRIC_TYPES = frozenset({
    'accuracy',
    'matthews_corrcoef',
    'pearson_spearman_corr',
})


Hongkun Yu's avatar
Hongkun Yu committed
40
41
42
43
44
@dataclasses.dataclass
class ModelConfig(base_config.Config):
  """A classifier/regressor configuration."""
  # Number of output units. `SentencePredictionTask` treats 1 as regression
  # (MSE loss) and any other value as classification.
  num_classes: int = 0
  # Forwarded to `models.BertClassifier`; not used for XLNet encoders.
  use_encoder_pooler: bool = False
  # Built per instance via `default_factory`: a shared, unhashable dataclass
  # instance as a field default is rejected by `dataclasses` on Python 3.11+
  # and would otherwise be shared between configs.
  encoder: encoders.EncoderConfig = dataclasses.field(
      default_factory=encoders.EncoderConfig)
Hongkun Yu's avatar
Hongkun Yu committed
46
47


48
49
50
@dataclasses.dataclass
class SentencePredictionConfig(cfg.TaskConfig):
  """The task config for sentence prediction (classification/regression)."""
  # At most one of `init_checkpoint` and `hub_module_url` can
  # be specified.
  init_checkpoint: str = ''
  # When true, `initialize` also restores the pretrained
  # `next_sentence.pooler_dense` weights; only valid when
  # `model.use_encoder_pooler` is false.
  init_cls_pooler: bool = False
  hub_module_url: str = ''
  # Must be one of `METRIC_TYPES`.
  metric_type: str = 'accuracy'
  # Defines the concrete model config at instantiation time.
  # Nested configs use `default_factory` so that every task config gets its
  # own instances (shared unhashable defaults are rejected on Python 3.11+).
  model: ModelConfig = dataclasses.field(default_factory=ModelConfig)
  train_data: cfg.DataConfig = dataclasses.field(
      default_factory=cfg.DataConfig)
  validation_data: cfg.DataConfig = dataclasses.field(
      default_factory=cfg.DataConfig)


Abdullah Rashwan's avatar
Abdullah Rashwan committed
63
@task_factory.register_task_cls(SentencePredictionConfig)
class SentencePredictionTask(base_task.Task):
  """Task object for sentence_prediction.

  `model.num_classes == 1` selects regression (MSE loss and metric); any other
  value selects classification (sparse categorical cross-entropy). For a
  `metric_type` other than 'accuracy', validation predictions and labels are
  collected across steps and reduced into a single corpus-level score.
  """

  def __init__(self, params: cfg.TaskConfig, logging_dir=None, name=None):
    """Initializes the task and validates `params.metric_type`.

    Raises:
      ValueError: if `params.metric_type` is not in `METRIC_TYPES`.
    """
    super().__init__(params, logging_dir, name=name)
    if params.metric_type not in METRIC_TYPES:
      raise ValueError('Invalid metric_type: {}'.format(params.metric_type))
    self.metric_type = params.metric_type

  def build_model(self):
    """Builds a classifier on top of the configured (or hub) encoder.

    Returns:
      A `models.XLNetClassifier` when the encoder type is 'xlnet', otherwise a
      `models.BertClassifier`.

    Raises:
      ValueError: if both `hub_module_url` and `init_checkpoint` are set.
    """
    if self.task_config.hub_module_url and self.task_config.init_checkpoint:
      raise ValueError('At most one of `hub_module_url` and '
                       '`init_checkpoint` can be specified.')
    if self.task_config.hub_module_url:
      encoder_network = utils.get_encoder_from_hub(
          self.task_config.hub_module_url)
    else:
      encoder_network = encoders.build_encoder(self.task_config.model.encoder)
    encoder_cfg = self.task_config.model.encoder.get()
    if self.task_config.model.encoder.type == 'xlnet':
      return models.XLNetClassifier(
          network=encoder_network,
          num_classes=self.task_config.model.num_classes,
          initializer=tf.keras.initializers.RandomNormal(
              stddev=encoder_cfg.initializer_range))
    else:
      return models.BertClassifier(
          network=encoder_network,
          num_classes=self.task_config.model.num_classes,
          initializer=tf.keras.initializers.TruncatedNormal(
              stddev=encoder_cfg.initializer_range),
          use_encoder_pooler=self.task_config.model.use_encoder_pooler)

  def build_losses(self, labels, model_outputs, aux_losses=None) -> tf.Tensor:
    """Computes the training loss.

    Args:
      labels: A dict holding the targets under 'label_ids'.
      model_outputs: The model output tensor (logits for classification; the
        predicted score when `num_classes == 1`).
      aux_losses: Optional list of extra loss tensors (e.g. `model.losses`)
        that are added to the per-example loss.

    Returns:
      The per-example loss (plus summed `aux_losses`) reduced with
      `tf_utils.safe_mean`.
    """
    label_ids = labels['label_ids']
    if self.task_config.model.num_classes == 1:
      # The raw `mean_squared_error` function does not squeeze/expand its
      # arguments. Labels presumably shaped (batch,) against outputs shaped
      # (batch, 1) would broadcast to (batch, batch) and compare every
      # prediction with every label, so align the label shape first.
      label_ids = tf.reshape(label_ids, tf.shape(model_outputs))
      loss = tf.keras.losses.mean_squared_error(label_ids, model_outputs)
    else:
      loss = tf.keras.losses.sparse_categorical_crossentropy(
          label_ids, tf.cast(model_outputs, tf.float32), from_logits=True)

    if aux_losses:
      loss += tf.add_n(aux_losses)
    return tf_utils.safe_mean(loss)

  def build_inputs(self, params, input_context=None):
    """Returns tf.data.Dataset for sentence_prediction task."""
    if params.input_path == 'dummy':

      def dummy_data(_):
        dummy_ids = tf.zeros((1, params.seq_length), dtype=tf.int32)
        x = dict(
            input_word_ids=dummy_ids,
            input_mask=dummy_ids,
            input_type_ids=dummy_ids)

        if self.task_config.model.num_classes == 1:
          y = tf.zeros((1,), dtype=tf.float32)
        else:
          y = tf.zeros((1, 1), dtype=tf.int32)
        x['label_ids'] = y
        return x

      dataset = tf.data.Dataset.range(1)
      dataset = dataset.repeat()
      dataset = dataset.map(
          dummy_data, num_parallel_calls=tf.data.experimental.AUTOTUNE)
      return dataset

    return data_loader_factory.get_data_loader(params).load(input_context)

  def build_metrics(self, training=None):
    """Returns streaming Keras metrics: MSE for regression, else accuracy."""
    del training
    if self.task_config.model.num_classes == 1:
      metrics = [tf.keras.metrics.MeanSquaredError()]
    else:
      metrics = [
          tf.keras.metrics.SparseCategoricalAccuracy(name='cls_accuracy')
      ]
    return metrics

  def process_metrics(self, metrics, labels, model_outputs):
    """Updates every metric with `labels['label_ids']` and the outputs."""
    for metric in metrics:
      metric.update_state(labels['label_ids'], model_outputs)

  def process_compiled_metrics(self, compiled_metrics, labels, model_outputs):
    """Updates Keras compiled metrics with the raw labels dict."""
    compiled_metrics.update_state(labels, model_outputs)

  def validation_step(self, inputs, model: tf.keras.Model, metrics=None):
    """Runs one validation step.

    For 'accuracy' this defers to the base implementation. Otherwise it also
    returns the step's predictions and labels so that `aggregate_logs` can
    collect them for a corpus-level metric.
    """
    if self.metric_type == 'accuracy':
      return super().validation_step(inputs, model, metrics)
    features, labels = inputs, inputs
    outputs = self.inference_step(features, model)
    loss = self.build_losses(
        labels=labels, model_outputs=outputs, aux_losses=model.losses)
    logs = {self.loss: loss}
    if self.metric_type == 'matthews_corrcoef':
      logs.update({
          'sentence_prediction':  # Ensure one prediction along batch dimension.
              tf.expand_dims(tf.math.argmax(outputs, axis=1), axis=1),
          'labels':
              labels['label_ids'],
      })
    if self.metric_type == 'pearson_spearman_corr':
      logs.update({
          'sentence_prediction': outputs,
          'labels': labels['label_ids'],
      })
    return logs

  def aggregate_logs(self, state=None, step_outputs=None):
    """Appends one step's (per-replica) predictions and labels to `state`.

    Returns None for 'accuracy', which needs no cross-step aggregation.
    """
    if self.metric_type == 'accuracy':
      return None
    if state is None:
      state = {'sentence_prediction': [], 'labels': []}
    state['sentence_prediction'].append(
        np.concatenate([v.numpy() for v in step_outputs['sentence_prediction']],
                       axis=0))
    state['labels'].append(
        np.concatenate([v.numpy() for v in step_outputs['labels']], axis=0))
    return state

  def reduce_aggregated_logs(self, aggregated_logs, global_step=None):
    """Computes the corpus-level metric from the aggregated logs.

    Returns:
      None for 'accuracy'; otherwise a dict mapping `metric_type` to the
      Matthews correlation, or to the mean of the Pearson and Spearman
      correlations.
    """
    if self.metric_type == 'accuracy':
      return None

    # Flatten the per-step arrays into 1-D vectors.
    preds = np.reshape(
        np.concatenate(aggregated_logs['sentence_prediction'], axis=0), -1)
    labels = np.reshape(np.concatenate(aggregated_logs['labels'], axis=0), -1)

    if self.metric_type == 'matthews_corrcoef':
      # sklearn's signature is (y_true, y_pred).
      return {
          self.metric_type: sklearn_metrics.matthews_corrcoef(labels, preds)
      }
    elif self.metric_type == 'pearson_spearman_corr':
      pearson_corr = stats.pearsonr(preds, labels)[0]
      spearman_corr = stats.spearmanr(preds, labels)[0]
      corr_metric = (pearson_corr + spearman_corr) / 2
      return {self.metric_type: corr_metric}

  def initialize(self, model):
    """Load a pretrained checkpoint (if exists) and then train from iter 0.

    Restores the encoder and, when `init_cls_pooler` is set, maps the
    checkpoint's `next_sentence.pooler_dense` onto the model's
    `sentence_prediction.pooler_dense`. Unmatched checkpoint values are
    tolerated (`expect_partial`), but every restored object must match.
    """
    ckpt_dir_or_file = self.task_config.init_checkpoint
    if not ckpt_dir_or_file:
      return
    if tf.io.gfile.isdir(ckpt_dir_or_file):
      ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)

    pretrain2finetune_mapping = {
        'encoder': model.checkpoint_items['encoder'],
    }
    if self.task_config.init_cls_pooler:
      # This option is valid when use_encoder_pooler is false.
      pretrain2finetune_mapping[
          'next_sentence.pooler_dense'] = model.checkpoint_items[
              'sentence_prediction.pooler_dense']
    ckpt = tf.train.Checkpoint(**pretrain2finetune_mapping)
    status = ckpt.read(ckpt_dir_or_file)
    status.expect_partial().assert_existing_objects_matched()
    logging.info('Finished loading pretrained checkpoint from %s',
                 ckpt_dir_or_file)
229
230


231
232
233
234
235
def predict(task: SentencePredictionTask,
            params: cfg.DataConfig,
            model: tf.keras.Model,
            params_aug: Optional[cfg.DataConfig] = None,
            test_time_aug_wgt: float = 0.3) -> List[Union[int, float]]:
  """Predicts on the input data.

  Args:
    task: A `SentencePredictionTask` object.
    params: A `cfg.DataConfig` object.
    model: A keras.Model.
    params_aug: A `cfg.DataConfig` object for augmented data.
    test_time_aug_wgt: Test time augmentation weight. The prediction score will
      use (1. - test_time_aug_wgt) original prediction plus test_time_aug_wgt
      augmented prediction.

  Returns:
    A list of predictions, one per example, ordered by `example_id`. For a
    regression task each element is the example's model output (blended with
    the augmented output when `params_aug` is given); for a classification
    task each element is the `tf.argmax` class id of that score. Elements are
    tensors, not Python scalars.

  Raises:
    ValueError: if `params_aug` yields a different number of examples than
      `params`, which would otherwise silently misalign the blend.
  """

  def predict_step(inputs):
    """Replicated prediction calculation."""
    x = inputs
    example_id = x.pop('example_id')
    outputs = task.inference_step(x, model)
    return dict(example_id=example_id, predictions=outputs)

  def aggregate_fn(state, outputs):
    """Concatenates model's outputs as (example_id, prediction) pairs."""
    if state is None:
      state = []

    for per_replica_example_id, per_replica_batch_predictions in zip(
        outputs['example_id'], outputs['predictions']):
      state.extend(zip(per_replica_example_id, per_replica_batch_predictions))
    return state

  def sorted_predictions(data_params):
    """Runs distributed prediction on `data_params`, sorted by example_id."""
    dataset = orbit.utils.make_distributed_dataset(
        tf.distribute.get_strategy(), task.build_inputs, data_params)
    outputs = utils.predict(predict_step, aggregate_fn, dataset)
    # When running on TPU POD, the order of output cannot be maintained,
    # so we need to sort by example_id.
    return sorted(outputs, key=lambda x: x[0])

  outputs = sorted_predictions(params)
  is_regression = task.task_config.model.num_classes == 1

  if params_aug is None:
    scores = [x[1] for x in outputs]
  else:
    outputs_aug = sorted_predictions(params_aug)
    # `zip` would silently drop the tail of the longer list and misalign
    # examples, so require both datasets to cover the same examples.
    if len(outputs_aug) != len(outputs):
      raise ValueError(
          'Augmented data produced {} predictions but original data produced '
          '{}.'.format(len(outputs_aug), len(outputs)))
    scores = [(1. - test_time_aug_wgt) * x[1] + test_time_aug_wgt * y[1]
              for x, y in zip(outputs, outputs_aug)]

  if is_regression:
    return scores
  return [tf.argmax(score, axis=-1) for score in scores]