sentence_prediction.py 10 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Sentence prediction (classification) task."""
17
18
import dataclasses
from typing import List, Union

from absl import logging
import numpy as np
import orbit
from scipy import stats
from sklearn import metrics as sklearn_metrics
import tensorflow as tf

from official.core import base_task
from official.core import config_definitions as cfg
from official.core import task_factory
from official.modeling import tf_utils
from official.modeling.hyperparams import base_config
from official.nlp.configs import encoders
from official.nlp.data import data_loader_factory
from official.nlp.modeling import models
from official.nlp.tasks import utils
# Evaluation metrics supported by `SentencePredictionConfig.metric_type`.
METRIC_TYPES = frozenset({
    'accuracy',
    'matthews_corrcoef',
    'pearson_spearman_corr',
})


Hongkun Yu's avatar
Hongkun Yu committed
41
42
43
44
45
@dataclasses.dataclass
class ModelConfig(base_config.Config):
  """A classifier/regressor configuration."""
  # num_classes == 1 selects the regression path (MSE loss); any other
  # value selects sparse-categorical classification.
  num_classes: int = 0
  # If True, reuse the encoder's own pooled output instead of a
  # task-specific pooler layer.
  use_encoder_pooler: bool = False
  encoder: encoders.EncoderConfig = encoders.EncoderConfig()
Hongkun Yu's avatar
Hongkun Yu committed
47
48


49
50
51
@dataclasses.dataclass
class SentencePredictionConfig(cfg.TaskConfig):
  """The model config."""
  # At most one of `init_checkpoint` and `hub_module_url` can be specified.
  init_checkpoint: str = ''
  # If True, also restore the pooler dense layer from the pretraining
  # checkpoint (see `SentencePredictionTask.initialize`).
  init_cls_pooler: bool = False
  hub_module_url: str = ''
  # One of METRIC_TYPES: 'accuracy', 'matthews_corrcoef' or
  # 'pearson_spearman_corr'.
  metric_type: str = 'accuracy'
  # Defines the concrete model config at instantiation time.
  model: ModelConfig = ModelConfig()
  train_data: cfg.DataConfig = cfg.DataConfig()
  validation_data: cfg.DataConfig = cfg.DataConfig()


Abdullah Rashwan's avatar
Abdullah Rashwan committed
64
@task_factory.register_task_cls(SentencePredictionConfig)
class SentencePredictionTask(base_task.Task):
  """Task object for sentence_prediction."""

  def __init__(self, params: cfg.TaskConfig, logging_dir=None, name=None):
    """Validates the configured metric type and stores it on the task."""
    super().__init__(params, logging_dir, name=name)
    metric_type = params.metric_type
    if metric_type not in METRIC_TYPES:
      raise ValueError('Invalid metric_type: {}'.format(metric_type))
    self.metric_type = metric_type

  def build_model(self):
    """Creates the classifier/regressor keras model for this task."""
    task_cfg = self.task_config
    if task_cfg.hub_module_url and task_cfg.init_checkpoint:
      raise ValueError('At most one of `hub_module_url` and '
                       '`init_checkpoint` can be specified.')
    if task_cfg.hub_module_url:
      encoder_network = utils.get_encoder_from_hub(task_cfg.hub_module_url)
    else:
      encoder_network = encoders.build_encoder(task_cfg.model.encoder)
    encoder_cfg = task_cfg.model.encoder.get()
    # Currently, we only support bert-style sentence prediction finetuning.
    return models.BertClassifier(
        network=encoder_network,
        num_classes=task_cfg.model.num_classes,
        initializer=tf.keras.initializers.TruncatedNormal(
            stddev=encoder_cfg.initializer_range),
        use_encoder_pooler=task_cfg.model.use_encoder_pooler)

  def build_losses(self, labels, model_outputs, aux_losses=None) -> tf.Tensor:
    """Returns the mean task loss, plus any auxiliary model losses."""
    if self.task_config.model.num_classes == 1:
      # A single output unit means the task is regression.
      per_example_loss = tf.keras.losses.mean_squared_error(
          labels, model_outputs)
    else:
      per_example_loss = tf.keras.losses.sparse_categorical_crossentropy(
          labels, tf.cast(model_outputs, tf.float32), from_logits=True)
    if aux_losses:
      per_example_loss += tf.add_n(aux_losses)
    return tf_utils.safe_mean(per_example_loss)

  def build_inputs(self, params, input_context=None):
    """Returns tf.data.Dataset for sentence_prediction task."""
    if params.input_path != 'dummy':
      return data_loader_factory.get_data_loader(params).load(input_context)

    def dummy_data(_):
      """Produces one all-zeros example shaped like the real features."""
      dummy_ids = tf.zeros((1, params.seq_length), dtype=tf.int32)
      x = dict(
          input_word_ids=dummy_ids,
          input_mask=dummy_ids,
          input_type_ids=dummy_ids)
      if self.task_config.model.num_classes == 1:
        y = tf.zeros((1,), dtype=tf.float32)
      else:
        y = tf.zeros((1, 1), dtype=tf.int32)
      return x, y

    return tf.data.Dataset.range(1).repeat().map(
        dummy_data, num_parallel_calls=tf.data.experimental.AUTOTUNE)

  def build_metrics(self, training=None):
    """Returns keras metrics: MSE for regression, accuracy otherwise."""
    del training  # The same metrics are used for train and eval.
    if self.task_config.model.num_classes == 1:
      return [tf.keras.metrics.MeanSquaredError()]
    return [tf.keras.metrics.SparseCategoricalAccuracy(name='cls_accuracy')]

  def process_metrics(self, metrics, labels, model_outputs):
    """Updates every metric with the current batch's labels and outputs."""
    for metric in metrics:
      metric.update_state(labels, model_outputs)

  def process_compiled_metrics(self, compiled_metrics, labels, model_outputs):
    """Updates keras compiled metrics with the current batch."""
    compiled_metrics.update_state(labels, model_outputs)

  def validation_step(self, inputs, model: tf.keras.Model, metrics=None):
    """Runs validation; exports raw outputs for correlation-style metrics."""
    if self.metric_type == 'accuracy':
      # Accuracy is handled by the keras metrics in the default step.
      return super().validation_step(inputs, model, metrics)
    features, labels = inputs
    outputs = self.inference_step(features, model)
    loss = self.build_losses(
        labels=labels, model_outputs=outputs, aux_losses=model.losses)
    logs = {self.loss: loss}
    if self.metric_type == 'matthews_corrcoef':
      # Ensure one prediction along batch dimension.
      logs['sentence_prediction'] = tf.expand_dims(
          tf.math.argmax(outputs, axis=1), axis=1)
      logs['labels'] = labels
    elif self.metric_type == 'pearson_spearman_corr':
      logs['sentence_prediction'] = outputs
      logs['labels'] = labels
    return logs

  def aggregate_logs(self, state=None, step_outputs=None):
    """Accumulates per-step predictions and labels as numpy arrays."""
    if self.metric_type == 'accuracy':
      # Keras metrics already aggregate; nothing to accumulate here.
      return None
    if state is None:
      state = {'sentence_prediction': [], 'labels': []}
    for key in ('sentence_prediction', 'labels'):
      # Each step output may hold one tensor per replica; merge them.
      state[key].append(
          np.concatenate([v.numpy() for v in step_outputs[key]], axis=0))
    return state

  def reduce_aggregated_logs(self, aggregated_logs):
    """Computes the final scalar metric from the accumulated logs."""
    if self.metric_type == 'accuracy':
      return None

    def _flat(key):
      # Merge all step arrays and flatten to a 1-D vector.
      return np.reshape(np.concatenate(aggregated_logs[key], axis=0), -1)

    if self.metric_type == 'matthews_corrcoef':
      preds = _flat('sentence_prediction')
      labels = _flat('labels')
      return {
          self.metric_type: sklearn_metrics.matthews_corrcoef(preds, labels)
      }
    elif self.metric_type == 'pearson_spearman_corr':
      preds = _flat('sentence_prediction')
      labels = _flat('labels')
      pearson_corr = stats.pearsonr(preds, labels)[0]
      spearman_corr = stats.spearmanr(preds, labels)[0]
      # Report the average of the two correlations (GLUE STS-B convention).
      corr_metric = (pearson_corr + spearman_corr) / 2
      return {self.metric_type: corr_metric}

  def initialize(self, model):
    """Load a pretrained checkpoint (if exists) and then train from iter 0."""
    ckpt_dir_or_file = self.task_config.init_checkpoint
    if tf.io.gfile.isdir(ckpt_dir_or_file):
      ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)
    if not ckpt_dir_or_file:
      # Nothing to restore; training starts from fresh initialization.
      return
    # Map pretraining checkpoint items onto the finetuning model.
    pretrain2finetune_mapping = {
        'encoder': model.checkpoint_items['encoder'],
    }
    if self.task_config.init_cls_pooler:
      # This option is valid when use_encoder_pooler is false.
      pretrain2finetune_mapping['next_sentence.pooler_dense'] = (
          model.checkpoint_items['sentence_prediction.pooler_dense'])
    ckpt = tf.train.Checkpoint(**pretrain2finetune_mapping)
    status = ckpt.read(ckpt_dir_or_file)
    status.expect_partial().assert_existing_objects_matched()
    logging.info('Finished loading pretrained checkpoint from %s',
                 ckpt_dir_or_file)
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240


def predict(task: SentencePredictionTask, params: cfg.DataConfig,
            model: tf.keras.Model) -> List[Union[int, float]]:
  """Predicts on the input data.

  Args:
    task: A `SentencePredictionTask` object.
    params: A `cfg.DataConfig` object.
    model: A keras.Model.

  Returns:
    A list of predictions with length of `num_examples`. For regression task,
      each element in the list is the predicted score; for classification task,
      each element is the predicted class id.
  """
  is_regression = task.task_config.model.num_classes == 1

  def predict_step(inputs):
    """Replicated prediction calculation."""
    features, _ = inputs
    # `example_id` is carried alongside predictions to restore ordering.
    example_id = features.pop('example_id')
    outputs = task.inference_step(features, model)
    predictions = outputs if is_regression else tf.argmax(outputs, axis=-1)
    return dict(example_id=example_id, predictions=predictions)

  def aggregate_fn(state, outputs):
    """Concatenates model's outputs."""
    state = [] if state is None else state
    # Each entry in `outputs` holds one tensor per replica; pair up the
    # example ids with their batch predictions replica by replica.
    for replica_ids, replica_predictions in zip(outputs['example_id'],
                                                outputs['predictions']):
      state.extend(zip(replica_ids, replica_predictions))
    return state

  dataset = orbit.utils.make_distributed_dataset(tf.distribute.get_strategy(),
                                                 task.build_inputs, params)
  outputs = utils.predict(predict_step, aggregate_fn, dataset)
  # When running on TPU POD, the order of output cannot be maintained,
  # so we need to sort by example_id.
  return [pred for _, pred in sorted(outputs, key=lambda pair: pair[0])]