sentence_prediction.py 9.77 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Sentence prediction (classification) task."""
17
18
from typing import List, Union

19
from absl import logging
20
import dataclasses
21
import numpy as np
22
import orbit
23
24
from scipy import stats
from sklearn import metrics as sklearn_metrics
25
26
27
28
import tensorflow as tf
import tensorflow_hub as hub

from official.core import base_task
Abdullah Rashwan's avatar
Abdullah Rashwan committed
29
from official.core import task_factory
Hongkun Yu's avatar
Hongkun Yu committed
30
from official.modeling.hyperparams import base_config
31
from official.modeling.hyperparams import config_definitions as cfg
Hongkun Yu's avatar
Hongkun Yu committed
32
from official.nlp.configs import encoders
Chen Chen's avatar
Chen Chen committed
33
from official.nlp.data import data_loader_factory
Hongkun Yu's avatar
Hongkun Yu committed
34
from official.nlp.modeling import models
Chen Chen's avatar
Chen Chen committed
35
from official.nlp.tasks import utils
36
37


A. Unique TensorFlower's avatar
A. Unique TensorFlower committed
38
39
40
41
# Validation metric families supported by `SentencePredictionTask`; the task
# constructor rejects any `metric_type` outside this set.
METRIC_TYPES = frozenset(
    ['accuracy', 'matthews_corrcoef', 'pearson_spearman_corr'])


Hongkun Yu's avatar
Hongkun Yu committed
42
43
44
45
46
@dataclasses.dataclass
class ModelConfig(base_config.Config):
  """A classifier/regressor configuration."""
  # Size of the output head; 1 selects the regression (mean-squared-error)
  # path in the task's loss/metrics, any other value selects classification.
  num_classes: int = 0
  # Forwarded to `models.BertClassifier`; when True the classifier reuses the
  # encoder's pooled output instead of adding its own pooling layer.
  use_encoder_pooler: bool = False
  # Config of the text encoder network built by `encoders.build_encoder`.
  encoder: encoders.EncoderConfig = encoders.EncoderConfig()
Hongkun Yu's avatar
Hongkun Yu committed
48
49


50
51
52
@dataclasses.dataclass
class SentencePredictionConfig(cfg.TaskConfig):
  """The model config."""
  # At most one of `init_checkpoint` and `hub_module_url` can
  # be specified.
  init_checkpoint: str = ''
  # Whether `initialize` also maps the pretrained pooler dense layer into the
  # finetuning model (see `SentencePredictionTask.initialize`).
  init_cls_pooler: bool = False
  hub_module_url: str = ''
  # One of METRIC_TYPES; selects how validation outputs are evaluated.
  metric_type: str = 'accuracy'
  # Defines the concrete model config at instantiation time.
  model: ModelConfig = ModelConfig()
  train_data: cfg.DataConfig = cfg.DataConfig()
  validation_data: cfg.DataConfig = cfg.DataConfig()


Abdullah Rashwan's avatar
Abdullah Rashwan committed
65
@task_factory.register_task_cls(SentencePredictionConfig)
class SentencePredictionTask(base_task.Task):
  """Task object for sentence_prediction."""

  def __init__(self, params=cfg.TaskConfig, logging_dir=None):
    """Validates the config and, if requested, loads the TF-Hub module.

    Args:
      params: A `SentencePredictionConfig` instance.
      logging_dir: Forwarded to `base_task.Task.__init__`.

    Raises:
      ValueError: If both `hub_module_url` and `init_checkpoint` are set, or
        if `metric_type` is not one of `METRIC_TYPES`.
    """
    super(SentencePredictionTask, self).__init__(params, logging_dir)
    if params.hub_module_url and params.init_checkpoint:
      raise ValueError('At most one of `hub_module_url` and '
                       '`init_checkpoint` can be specified.')
    if params.hub_module_url:
      self._hub_module = hub.load(params.hub_module_url)
    else:
      self._hub_module = None

    if params.metric_type not in METRIC_TYPES:
      raise ValueError('Invalid metric_type: {}'.format(params.metric_type))
    self.metric_type = params.metric_type

  def build_model(self):
    """Builds a `models.BertClassifier` around the configured encoder."""
    if self._hub_module:
      encoder_network = utils.get_encoder_from_hub(self._hub_module)
    else:
      encoder_network = encoders.build_encoder(self.task_config.model.encoder)
    encoder_cfg = self.task_config.model.encoder.get()
    # Currently, we only support bert-style sentence prediction finetuning.
    return models.BertClassifier(
        network=encoder_network,
        num_classes=self.task_config.model.num_classes,
        initializer=tf.keras.initializers.TruncatedNormal(
            stddev=encoder_cfg.initializer_range),
        use_encoder_pooler=self.task_config.model.use_encoder_pooler)

  def build_losses(self, labels, model_outputs, aux_losses=None) -> tf.Tensor:
    """Returns the scalar loss: MSE for regression (num_classes == 1),
    sparse categorical cross-entropy on logits otherwise; `aux_losses`
    (e.g. regularization terms) are summed in before the final mean."""
    if self.task_config.model.num_classes == 1:
      loss = tf.keras.losses.mean_squared_error(labels, model_outputs)
    else:
      loss = tf.keras.losses.sparse_categorical_crossentropy(
          labels, tf.cast(model_outputs, tf.float32), from_logits=True)

    if aux_losses:
      loss += tf.add_n(aux_losses)
    return tf.reduce_mean(loss)

  def build_inputs(self, params, input_context=None):
    """Returns tf.data.Dataset for sentence_prediction task."""
    # 'dummy' short-circuits the data loader with an infinite stream of
    # all-zero features/labels, useful for smoke tests and benchmarks.
    if params.input_path == 'dummy':

      def dummy_data(_):
        dummy_ids = tf.zeros((1, params.seq_length), dtype=tf.int32)
        x = dict(
            input_word_ids=dummy_ids,
            input_mask=dummy_ids,
            input_type_ids=dummy_ids)

        # Label dtype/shape must match the loss path chosen in build_losses.
        if self.task_config.model.num_classes == 1:
          y = tf.zeros((1,), dtype=tf.float32)
        else:
          y = tf.zeros((1, 1), dtype=tf.int32)
        return x, y

      dataset = tf.data.Dataset.range(1)
      dataset = dataset.repeat()
      dataset = dataset.map(
          dummy_data, num_parallel_calls=tf.data.experimental.AUTOTUNE)
      return dataset

    return data_loader_factory.get_data_loader(params).load(input_context)

  def build_metrics(self, training=None):
    """Returns keras metrics: MSE for regression, accuracy otherwise."""
    del training
    if self.task_config.model.num_classes == 1:
      metrics = [tf.keras.metrics.MeanSquaredError()]
    else:
      metrics = [
          tf.keras.metrics.SparseCategoricalAccuracy(name='cls_accuracy')]
    return metrics

  def process_metrics(self, metrics, labels, model_outputs):
    """Updates every metric's state with the current batch."""
    for metric in metrics:
      metric.update_state(labels, model_outputs)

  def process_compiled_metrics(self, compiled_metrics, labels, model_outputs):
    """Updates keras compiled metrics with the current batch."""
    compiled_metrics.update_state(labels, model_outputs)

  def validation_step(self, inputs, model: tf.keras.Model, metrics=None):
    """Runs one validation step.

    For 'accuracy' the base-class implementation is used. Otherwise the loss
    is computed here and the per-batch predictions/labels are put in the logs
    so `aggregate_logs`/`reduce_aggregated_logs` can compute the corpus-level
    correlation metric after the loop.
    """
    if self.metric_type == 'accuracy':
      return super(SentencePredictionTask,
                   self).validation_step(inputs, model, metrics)
    features, labels = inputs
    outputs = self.inference_step(features, model)
    loss = self.build_losses(
        labels=labels, model_outputs=outputs, aux_losses=model.losses)
    logs = {self.loss: loss}
    if self.metric_type == 'matthews_corrcoef':
      logs.update({
          'sentence_prediction':
              # Class ids with a leading batch-like axis, matching how
              # aggregate_logs concatenates per-step outputs.
              tf.expand_dims(tf.math.argmax(outputs, axis=1), axis=0),
          'labels':
              labels,
      })
    if self.metric_type == 'pearson_spearman_corr':
      logs.update({
          'sentence_prediction': outputs,
          'labels': labels,
      })
    return logs

  def aggregate_logs(self, state=None, step_outputs=None):
    """Accumulates per-step predictions/labels into numpy arrays in `state`."""
    if self.metric_type == 'accuracy':
      # Accuracy is handled by keras metrics; nothing to aggregate.
      return None
    if state is None:
      state = {'sentence_prediction': [], 'labels': []}
    # TODO(b/160712818): Add support for concatenating partial batches.
    state['sentence_prediction'].append(
        np.concatenate([v.numpy() for v in step_outputs['sentence_prediction']],
                       axis=0))
    state['labels'].append(
        np.concatenate([v.numpy() for v in step_outputs['labels']], axis=0))
    return state

  def reduce_aggregated_logs(self, aggregated_logs):
    """Computes the final corpus-level metric from aggregated predictions.

    Returns:
      `{metric_type: value}` for the correlation metrics, or None for
      'accuracy' (which is reported via keras metrics instead).
    """
    if self.metric_type == 'accuracy':
      return None
    elif self.metric_type == 'matthews_corrcoef':
      preds = np.concatenate(aggregated_logs['sentence_prediction'], axis=0)
      preds = np.reshape(preds, -1)
      labels = np.concatenate(aggregated_logs['labels'], axis=0)
      labels = np.reshape(labels, -1)
      return {
          self.metric_type: sklearn_metrics.matthews_corrcoef(preds, labels)
      }
    elif self.metric_type == 'pearson_spearman_corr':
      preds = np.concatenate(aggregated_logs['sentence_prediction'], axis=0)
      preds = np.reshape(preds, -1)
      labels = np.concatenate(aggregated_logs['labels'], axis=0)
      labels = np.reshape(labels, -1)
      # GLUE STS-B convention: report the mean of Pearson and Spearman.
      pearson_corr = stats.pearsonr(preds, labels)[0]
      spearman_corr = stats.spearmanr(preds, labels)[0]
      corr_metric = (pearson_corr + spearman_corr) / 2
      return {self.metric_type: corr_metric}

  def initialize(self, model):
    """Load a pretrained checkpoint (if exists) and then train from iter 0."""
    ckpt_dir_or_file = self.task_config.init_checkpoint
    if tf.io.gfile.isdir(ckpt_dir_or_file):
      ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)
    if not ckpt_dir_or_file:
      # No checkpoint configured/found: train from random initialization.
      return

    # Map pretraining checkpoint names to this finetuning model's objects;
    # only the encoder weights are restored by default.
    pretrain2finetune_mapping = {
        'encoder': model.checkpoint_items['encoder'],
    }
    # TODO(b/160251903): Investigate why no pooler dense improves finetuning
    # accuracies.
    if self.task_config.init_cls_pooler:
      pretrain2finetune_mapping[
          'next_sentence.pooler_dense'] = model.checkpoint_items[
              'sentence_prediction.pooler_dense']
    ckpt = tf.train.Checkpoint(**pretrain2finetune_mapping)
    status = ckpt.read(ckpt_dir_or_file)
    # Partial restore is expected (optimizer slots etc. are absent), but every
    # object we asked for must have matched something in the checkpoint.
    status.expect_partial().assert_existing_objects_matched()
    logging.info('Finished loading pretrained checkpoint from %s',
                 ckpt_dir_or_file)
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245


def predict(task: SentencePredictionTask, params: cfg.DataConfig,
            model: tf.keras.Model) -> List[Union[int, float]]:
  """Predicts on the input data.

  Args:
    task: A `SentencePredictionTask` object.
    params: A `cfg.DataConfig` object.
    model: A keras.Model.

  Returns:
    A list of predictions with length of `num_examples`. For regression task,
      each element in the list is the predicted score; for classification task,
      each element is the predicted class id.
  """
  is_regression = task.task_config.model.num_classes == 1

  def predict_step(inputs):
    """Replicated prediction calculation."""
    features, _ = inputs
    logits = task.inference_step(features, model)
    # Regression emits raw scores; classification emits the argmax class id.
    return logits if is_regression else tf.argmax(logits, axis=-1)

  def aggregate_fn(state, outputs):
    """Concatenates model's outputs."""
    state = {'predictions': []} if state is None else state
    for replica_batch in outputs:
      state['predictions'].extend(replica_batch)
    return state

  strategy = tf.distribute.get_strategy()
  dataset = orbit.utils.make_distributed_dataset(strategy, task.build_inputs,
                                                 params)
  return utils.predict(predict_step, aggregate_fn, dataset)['predictions']