# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Sentence prediction (classification) task."""
A. Unique TensorFlower's avatar
A. Unique TensorFlower committed
16
import dataclasses
17
from typing import List, Union, Optional
18

19
20
from absl import logging
import numpy as np
21
import orbit
22
23
from scipy import stats
from sklearn import metrics as sklearn_metrics
24
25
26
import tensorflow as tf

from official.core import base_task
27
from official.core import config_definitions as cfg
Abdullah Rashwan's avatar
Abdullah Rashwan committed
28
from official.core import task_factory
Chen Chen's avatar
Chen Chen committed
29
from official.modeling import tf_utils
Hongkun Yu's avatar
Hongkun Yu committed
30
31
from official.modeling.hyperparams import base_config
from official.nlp.configs import encoders
Chen Chen's avatar
Chen Chen committed
32
from official.nlp.data import data_loader_factory
Hongkun Yu's avatar
Hongkun Yu committed
33
from official.nlp.modeling import models
Chen Chen's avatar
Chen Chen committed
34
from official.nlp.tasks import utils
35

A. Unique TensorFlower's avatar
A. Unique TensorFlower committed
36
37
38
39
# Evaluation metric names accepted as `metric_type` by the sentence prediction
# task; the task constructor raises ValueError for any other value.
METRIC_TYPES = frozenset({
    'accuracy',
    'matthews_corrcoef',
    'pearson_spearman_corr',
})


Hongkun Yu's avatar
Hongkun Yu committed
40
41
42
43
44
@dataclasses.dataclass
class ModelConfig(base_config.Config):
  """A classifier/regressor configuration.

  Attributes:
    num_classes: Number of model outputs. The sentence prediction task in this
      file treats 1 as regression (mean squared error loss) and any other
      value as classification (sparse categorical cross-entropy loss).
    use_encoder_pooler: Forwarded to `models.BertClassifier` as its
      `use_encoder_pooler` argument.
    encoder: Configuration of the encoder network.
  """
  num_classes: int = 0
  use_encoder_pooler: bool = False
  # Use a factory so each config owns its encoder settings. A single default
  # instance would be shared (and mutated) across every config, and unhashable
  # defaults are rejected as mutable by `dataclasses` on Python 3.11+.
  encoder: encoders.EncoderConfig = dataclasses.field(
      default_factory=encoders.EncoderConfig)
Hongkun Yu's avatar
Hongkun Yu committed
46
47


48
49
50
@dataclasses.dataclass
class SentencePredictionConfig(cfg.TaskConfig):
  """The sentence prediction task config.

  Attributes:
    init_checkpoint: Checkpoint file or directory to initialize from; empty
      means no checkpoint is loaded.
    init_cls_pooler: Whether to also restore the pooler dense layer from the
      pretrained checkpoint.
    hub_module_url: URL of a hub module to load the encoder from.
    metric_type: Evaluation metric name; must be one of `METRIC_TYPES`.
    model: The concrete model config.
    train_data: Training data config.
    validation_data: Validation data config.
  """
  # At most one of `init_checkpoint` and `hub_module_url` can
  # be specified.
  init_checkpoint: str = ''
  init_cls_pooler: bool = False
  hub_module_url: str = ''
  metric_type: str = 'accuracy'
  # Defines the concrete model config at instantiation time.
  # Factories give each task config its own sub-configs. Shared default
  # instances would leak edits between configs, and unhashable defaults are
  # rejected as mutable by `dataclasses` on Python 3.11+.
  model: ModelConfig = dataclasses.field(default_factory=ModelConfig)
  train_data: cfg.DataConfig = dataclasses.field(
      default_factory=cfg.DataConfig)
  validation_data: cfg.DataConfig = dataclasses.field(
      default_factory=cfg.DataConfig)


Abdullah Rashwan's avatar
Abdullah Rashwan committed
63
@task_factory.register_task_cls(SentencePredictionConfig)
64
65
66
class SentencePredictionTask(base_task.Task):
  """Task object for sentence_prediction."""

Hongkun Yu's avatar
Hongkun Yu committed
67
68
  def __init__(self, params: cfg.TaskConfig, logging_dir=None, name=None):
    """Creates the task and validates the requested evaluation metric.

    Args:
      params: The task config. Its `metric_type` must be in `METRIC_TYPES`.
      logging_dir: Forwarded to the base task constructor.
      name: Forwarded to the base task constructor.

    Raises:
      ValueError: If `params.metric_type` is not in `METRIC_TYPES`.
    """
    super().__init__(params, logging_dir, name=name)
    requested_metric = params.metric_type
    if requested_metric not in METRIC_TYPES:
      raise ValueError('Invalid metric_type: {}'.format(requested_metric))
    self.metric_type = requested_metric
    # Data configs that do not declare `label_field` use the default key.
    self.label_field = getattr(params.train_data, 'label_field', 'label_ids')
76
77

  def build_model(self):
    """Builds the classifier model described by the task config.

    The encoder comes from the hub module when `hub_module_url` is set,
    otherwise it is built from the encoder config.

    Returns:
      A `models.XLNetClassifier` when the encoder type is 'xlnet', otherwise a
      `models.BertClassifier`.

    Raises:
      ValueError: If both `hub_module_url` and `init_checkpoint` are set.
    """
    task_cfg = self.task_config
    hub_url = task_cfg.hub_module_url
    if hub_url and task_cfg.init_checkpoint:
      raise ValueError('At most one of `hub_module_url` and '
                       '`init_checkpoint` can be specified.')

    model_cfg = task_cfg.model
    if hub_url:
      encoder_network = utils.get_encoder_from_hub(hub_url)
    else:
      encoder_network = encoders.build_encoder(model_cfg.encoder)

    encoder_cfg = model_cfg.encoder.get()
    if model_cfg.encoder.type != 'xlnet':
      return models.BertClassifier(
          network=encoder_network,
          num_classes=model_cfg.num_classes,
          initializer=tf.keras.initializers.TruncatedNormal(
              stddev=encoder_cfg.initializer_range),
          use_encoder_pooler=model_cfg.use_encoder_pooler)

    return models.XLNetClassifier(
        network=encoder_network,
        num_classes=model_cfg.num_classes,
        initializer=tf.keras.initializers.RandomNormal(
            stddev=encoder_cfg.initializer_range))
100

101
  def build_losses(self, labels, model_outputs, aux_losses=None) -> tf.Tensor:
    """Computes the training/evaluation loss.

    Args:
      labels: Mapping that holds the targets under `self.label_field`.
      model_outputs: Output of the model for the same batch.
      aux_losses: Optional list of extra loss tensors to add.

    Returns:
      The `tf_utils.safe_mean` of the loss: mean squared error when
      `num_classes` is 1, otherwise sparse categorical cross-entropy computed
      from logits, plus the sum of `aux_losses` when provided.
    """
    targets = labels[self.label_field]
    if self.task_config.model.num_classes != 1:
      per_example_loss = tf.keras.losses.sparse_categorical_crossentropy(
          targets, tf.cast(model_outputs, tf.float32), from_logits=True)
    else:
      per_example_loss = tf.keras.losses.mean_squared_error(
          targets, model_outputs)

    if aux_losses:
      per_example_loss += tf.add_n(aux_losses)
    return tf_utils.safe_mean(per_example_loss)
112
113
114
115

  def build_inputs(self, params, input_context=None):
    """Returns tf.data.Dataset for sentence_prediction task.

    Args:
      params: Data config. An `input_path` of 'dummy' selects an endless
        dataset of all-zero examples instead of a real data loader.
      input_context: Forwarded to the data loader's `load` method.

    Returns:
      The dataset from the data loader, or the all-zero dummy dataset.
    """
    if params.input_path != 'dummy':
      return data_loader_factory.get_data_loader(params).load(input_context)

    def dummy_data(_):
      """Builds one all-zero example with a label matching the task type."""
      zero_ids = tf.zeros((1, params.seq_length), dtype=tf.int32)
      # Regression uses a float label, classification an integer class id.
      if self.task_config.model.num_classes == 1:
        zero_label = tf.zeros((1,), dtype=tf.float32)
      else:
        zero_label = tf.zeros((1, 1), dtype=tf.int32)
      example = {
          'input_word_ids': zero_ids,
          'input_mask': zero_ids,
          'input_type_ids': zero_ids,
      }
      example[self.label_field] = zero_label
      return example

    return tf.data.Dataset.range(1).repeat().map(
        dummy_data, num_parallel_calls=tf.data.experimental.AUTOTUNE)
138
139
140

  def build_metrics(self, training=None):
    """Creates the Keras metrics for the configured number of classes.

    Args:
      training: Unused.

    Returns:
      `[MeanSquaredError]` when `num_classes` is 1; otherwise a list starting
      with a 'cls_accuracy' metric, followed by a PR-curve 'auc' metric when
      `num_classes` is 2.
    """
    del training
    num_classes = self.task_config.model.num_classes
    if num_classes == 1:
      return [tf.keras.metrics.MeanSquaredError()]

    metrics = [tf.keras.metrics.SparseCategoricalAccuracy(name='cls_accuracy')]
    if num_classes == 2:
      metrics.append(tf.keras.metrics.AUC(name='auc', curve='PR'))
    return metrics

154
  def process_metrics(self, metrics, labels, model_outputs):
    """Updates the 'auc' and 'cls_accuracy' metrics for one batch.

    Metrics with any other name are left untouched by this method.
    NOTE(review): that includes the `MeanSquaredError` metric created for
    regression in `build_metrics` -- confirm this is intended.

    Args:
      metrics: Iterable of Keras metrics to update.
      labels: Mapping that holds the targets under `self.label_field`.
      model_outputs: Model logits for the batch.
    """
    for metric in metrics:
      metric_name = metric.name
      if metric_name == 'auc':
        targets = labels[self.label_field]
        # Convert the logits to probabilities and keep the probability of
        # the positive class (index 1).
        positive_prob = tf.nn.softmax(model_outputs)[:, 1]
        metric.update_state(targets, tf.expand_dims(positive_prob, axis=1))
      elif metric_name == 'cls_accuracy':
        metric.update_state(labels[self.label_field], model_outputs)
163

164
  def process_compiled_metrics(self, compiled_metrics, labels, model_outputs):
    """Updates compiled metrics with the batch targets and model outputs.

    Args:
      compiled_metrics: Object exposing `update_state(y_true, y_pred)`.
      labels: Mapping that holds the targets under `self.label_field`.
      model_outputs: Output of the model for the same batch.
    """
    targets = labels[self.label_field]
    compiled_metrics.update_state(targets, model_outputs)
166

167
168
169
170
  def validation_step(self, inputs, model: tf.keras.Model, metrics=None):
    """Runs one validation step.

    For the 'accuracy' metric type the base task implementation is used.
    For the other metric types the step returns the loss together with the
    predictions and labels, which are needed to compute the metric later.

    Args:
      inputs: Batch mapping; it serves as both the features and the labels.
      model: The model to evaluate.
      metrics: Forwarded to the base implementation for 'accuracy'.

    Returns:
      The base implementation's result for 'accuracy'; otherwise a dict with
      the loss under `self.loss` plus 'sentence_prediction' and 'labels'.
    """
    metric_type = self.metric_type
    if metric_type == 'accuracy':
      return super().validation_step(inputs, model, metrics)

    predictions = self.inference_step(inputs, model)
    step_loss = self.build_losses(
        labels=inputs, model_outputs=predictions, aux_losses=model.losses)
    logs = {self.loss: step_loss}

    if metric_type == 'matthews_corrcoef':
      # Ensure one prediction along batch dimension.
      predicted_class = tf.math.argmax(predictions, axis=1)
      logs['sentence_prediction'] = tf.expand_dims(predicted_class, axis=1)
      logs['labels'] = inputs[self.label_field]
    elif metric_type == 'pearson_spearman_corr':
      logs['sentence_prediction'] = predictions
      logs['labels'] = inputs[self.label_field]
    return logs
189
190

  def aggregate_logs(self, state=None, step_outputs=None):
    """Accumulates predictions and labels from one validation step.

    Args:
      state: Accumulator from earlier steps, or None on the first step.
      step_outputs: Mapping with 'sentence_prediction' and 'labels', each an
        iterable of values exposing `.numpy()`.

    Returns:
      None for the 'accuracy' metric type; otherwise the accumulator, a dict
      mapping 'sentence_prediction' and 'labels' to lists of numpy arrays
      (one array appended per call).
    """
    if self.metric_type == 'accuracy':
      return None

    if state is None:
      state = {'sentence_prediction': [], 'labels': []}
    for key in ('sentence_prediction', 'labels'):
      pieces = [value.numpy() for value in step_outputs[key]]
      state[key].append(np.concatenate(pieces, axis=0))
    return state

202
  def reduce_aggregated_logs(self, aggregated_logs, global_step=None):
    """Computes the final metric from the accumulated predictions and labels.

    Args:
      aggregated_logs: Dict with 'sentence_prediction' and 'labels', each a
        list of numpy arrays.
      global_step: Unused.

    Returns:
      None for the 'accuracy' metric type (or any unrecognized one);
      otherwise a single-entry dict keyed by the metric type. For
      'pearson_spearman_corr' the value is the mean of the Pearson and
      Spearman correlation coefficients.
    """
    metric_type = self.metric_type
    if metric_type not in ('matthews_corrcoef', 'pearson_spearman_corr'):
      return None

    def flatten(key):
      """Joins the arrays stored under `key` into one 1-D array."""
      return np.reshape(np.concatenate(aggregated_logs[key], axis=0), -1)

    preds = flatten('sentence_prediction')
    labels = flatten('labels')

    if metric_type == 'matthews_corrcoef':
      score = sklearn_metrics.matthews_corrcoef(preds, labels)
    else:
      pearson_corr = stats.pearsonr(preds, labels)[0]
      spearman_corr = stats.spearmanr(preds, labels)[0]
      score = (pearson_corr + spearman_corr) / 2
    return {metric_type: score}

223
224
  def initialize(self, model):
    """Load a pretrained checkpoint (if exists) and then train from iter 0.

    Does nothing when `init_checkpoint` is empty. When it names a directory,
    the latest checkpoint in that directory is used.

    Args:
      model: The model to restore into; it must expose `checkpoint_items`
        with an 'encoder' entry, and a 'sentence_prediction.pooler_dense'
        entry when `init_cls_pooler` is set.
    """
    checkpoint_path = self.task_config.init_checkpoint
    if not checkpoint_path:
      return
    if tf.io.gfile.isdir(checkpoint_path):
      checkpoint_path = tf.train.latest_checkpoint(checkpoint_path)

    # Keys are the object names in the pretrained checkpoint; values are the
    # matching objects of the fine-tuning model.
    restore_targets = dict(encoder=model.checkpoint_items['encoder'])
    if self.task_config.init_cls_pooler:
      # This option is valid when use_encoder_pooler is false.
      restore_targets['next_sentence.pooler_dense'] = (
          model.checkpoint_items['sentence_prediction.pooler_dense'])

    load_status = tf.train.Checkpoint(**restore_targets).read(checkpoint_path)
    load_status.expect_partial().assert_existing_objects_matched()
    logging.info('Finished loading pretrained checkpoint from %s',
                 checkpoint_path)
244
245


246
247
248
249
250
def predict(task: SentencePredictionTask,
            params: cfg.DataConfig,
            model: tf.keras.Model,
            params_aug: Optional[cfg.DataConfig] = None,
            test_time_aug_wgt: float = 0.3) -> List[Union[int, float]]:
  """Predicts on the input data.

  Args:
    task: A `SentencePredictionTask` object.
    params: A `cfg.DataConfig` object.
    model: A keras.Model.
    params_aug: A `cfg.DataConfig` object for augmented data.
    test_time_aug_wgt: Test time augmentation weight. The prediction score will
      use (1. - test_time_aug_wgt) original prediction plus test_time_aug_wgt
      augmented prediction.

  Returns:
    A list of predictions, ordered by example id. For regression task
      (`num_classes` == 1), each element in the list is the predicted score;
      for classification task, each element is the predicted class id
      (the `tf.argmax` over the last axis of the score).
  """

  def predict_step(inputs):
    """Replicated prediction calculation."""
    # `example_id` is removed from the features before they reach the model.
    example_id = inputs.pop('example_id')
    predictions = task.inference_step(inputs, model)
    return dict(example_id=example_id, predictions=predictions)

  def aggregate_fn(state, outputs):
    """Concatenates model's outputs."""
    if state is None:
      state = []
    replica_pairs = zip(outputs['example_id'], outputs['predictions'])
    for replica_ids, replica_predictions in replica_pairs:
      state.extend(zip(replica_ids, replica_predictions))
    return state

  def run_sorted(data_params):
    """Predicts over `data_params` and orders the results by example id."""
    dataset = orbit.utils.make_distributed_dataset(
        tf.distribute.get_strategy(), task.build_inputs, data_params)
    results = utils.predict(predict_step, aggregate_fn, dataset)
    # When running on TPU POD, the order of output cannot be maintained,
    # so we need to sort by example_id.
    return sorted(results, key=lambda pair: pair[0])

  outputs = run_sorted(params)
  is_regression = task.task_config.model.num_classes == 1

  if params_aug is None:
    scores = [item[1] for item in outputs]
  else:
    outputs_aug = run_sorted(params_aug)
    # Blend the original and augmented predictions example by example.
    scores = [
        (1. - test_time_aug_wgt) * plain[1] + test_time_aug_wgt * augmented[1]
        for plain, augmented in zip(outputs, outputs_aug)
    ]

  if is_regression:
    return scores
  return [tf.argmax(score, axis=-1) for score in scores]