sentence_prediction.py 10.1 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Sentence prediction (classification) task."""
17
18
from typing import List, Union

19
from absl import logging
20
import dataclasses
21
import numpy as np
22
import orbit
23
24
from scipy import stats
from sklearn import metrics as sklearn_metrics
25
26
27
28
import tensorflow as tf
import tensorflow_hub as hub

from official.core import base_task
Abdullah Rashwan's avatar
Abdullah Rashwan committed
29
from official.core import task_factory
Hongkun Yu's avatar
Hongkun Yu committed
30
from official.modeling.hyperparams import base_config
31
from official.modeling.hyperparams import config_definitions as cfg
Hongkun Yu's avatar
Hongkun Yu committed
32
from official.nlp.configs import encoders
Chen Chen's avatar
Chen Chen committed
33
from official.nlp.data import data_loader_factory
Hongkun Yu's avatar
Hongkun Yu committed
34
from official.nlp.modeling import models
Chen Chen's avatar
Chen Chen committed
35
from official.nlp.tasks import utils
36
37


A. Unique TensorFlower's avatar
A. Unique TensorFlower committed
38
39
40
41
# Metric types accepted by `SentencePredictionConfig.metric_type`. `accuracy`
# relies on the streaming Keras metrics; the other two are computed once over
# the whole validation set from aggregated predictions and labels.
METRIC_TYPES = frozenset({
    'accuracy',
    'matthews_corrcoef',
    'pearson_spearman_corr',
})


Hongkun Yu's avatar
Hongkun Yu committed
42
43
44
45
46
47
48
49
50
@dataclasses.dataclass
class ModelConfig(base_config.Config):
  """A classifier/regressor configuration.

  Attributes:
    num_classes: Number of output units of the prediction head. The task
      treats a value of 1 as regression and anything else as classification.
    use_encoder_pooler: Forwarded to `BertClassifier` as `use_encoder_pooler`.
    encoder: Configuration used to instantiate the encoder when no hub module
      is given.
  """
  num_classes: int = 0
  use_encoder_pooler: bool = False
  # Use a factory so every ModelConfig gets its own encoder config. A single
  # instance as the class-level default would be shared by all ModelConfig
  # objects, and `dataclasses` rejects unhashable defaults like that on
  # Python 3.11+.
  encoder: encoders.TransformerEncoderConfig = dataclasses.field(
      default_factory=encoders.TransformerEncoderConfig)


51
52
53
@dataclasses.dataclass
class SentencePredictionConfig(cfg.TaskConfig):
  """Task config for sentence prediction (classification or regression)."""
  # At most one of `init_checkpoint` and `hub_module_url` can be specified;
  # `SentencePredictionTask.__init__` raises a ValueError otherwise.
  init_checkpoint: str = ''
  # If True, also restore the pretraining `next_sentence.pooler_dense` weights
  # from `init_checkpoint` into the classifier's pooler dense layer.
  init_cls_pooler: bool = False
  hub_module_url: str = ''
  # Must be one of METRIC_TYPES.
  metric_type: str = 'accuracy'
  # Defines the concrete model config at instantiation time.
  # Factories (rather than shared instances) give each config its own
  # sub-configs; shared unhashable defaults are rejected on Python 3.11+.
  model: ModelConfig = dataclasses.field(default_factory=ModelConfig)
  train_data: cfg.DataConfig = dataclasses.field(
      default_factory=cfg.DataConfig)
  validation_data: cfg.DataConfig = dataclasses.field(
      default_factory=cfg.DataConfig)


Abdullah Rashwan's avatar
Abdullah Rashwan committed
66
@task_factory.register_task_cls(SentencePredictionConfig)
class SentencePredictionTask(base_task.Task):
  """Task object for sentence_prediction.

  Fine-tunes a BERT-style encoder with a `BertClassifier` head. When
  `model.num_classes == 1` the task is a regression trained with mean squared
  error; otherwise it is a classification trained with sparse categorical
  cross-entropy. Validation can additionally compute a corpus-level metric
  selected by `metric_type` (one of `METRIC_TYPES`).
  """

  def __init__(self, params=cfg.TaskConfig, logging_dir=None):
    # NOTE(review): the default `params` is the `cfg.TaskConfig` class, not an
    # instance, and presumably lacks `hub_module_url`; callers are expected to
    # always pass a `SentencePredictionConfig`. Kept for interface
    # compatibility.
    super(SentencePredictionTask, self).__init__(params, logging_dir)
    if params.hub_module_url and params.init_checkpoint:
      raise ValueError('At most one of `hub_module_url` and '
                       '`init_checkpoint` can be specified.')
    # Validate cheap config fields before the potentially slow hub load.
    if params.metric_type not in METRIC_TYPES:
      raise ValueError('Invalid metric_type: {}'.format(params.metric_type))
    self.metric_type = params.metric_type

    if params.hub_module_url:
      self._hub_module = hub.load(params.hub_module_url)
    else:
      self._hub_module = None

  def build_model(self):
    """Builds a `BertClassifier` on top of a hub or config-defined encoder."""
    if self._hub_module:
      encoder_network = utils.get_encoder_from_hub(self._hub_module)
    else:
      encoder_network = encoders.instantiate_encoder_from_cfg(
          self.task_config.model.encoder)

    # Currently, we only support bert-style sentence prediction finetuning.
    return models.BertClassifier(
        network=encoder_network,
        num_classes=self.task_config.model.num_classes,
        initializer=tf.keras.initializers.TruncatedNormal(
            stddev=self.task_config.model.encoder.initializer_range),
        use_encoder_pooler=self.task_config.model.use_encoder_pooler)

  def build_losses(self, labels, model_outputs, aux_losses=None) -> tf.Tensor:
    """Computes the scalar training loss.

    Args:
      labels: Ground truth: target scores for regression, integer class ids
        for classification.
      model_outputs: Model outputs: predicted scores for regression, logits
        for classification.
      aux_losses: Optional list of auxiliary loss tensors (e.g. `model.losses`)
        that are summed and added to the loss.

    Returns:
      The loss averaged with `tf.reduce_mean`, as a scalar tensor.
    """
    if self.task_config.model.num_classes == 1:
      # Compute in float32 (as the classification branch does) and give the
      # labels the same shape as the predictions. `mean_squared_error` does
      # not squeeze, so (batch,) labels against (batch, 1) predictions would
      # broadcast to (batch, batch) and produce a wrong loss.
      predictions = tf.cast(model_outputs, tf.float32)
      labels = tf.reshape(tf.cast(labels, tf.float32), tf.shape(predictions))
      loss = tf.keras.losses.mean_squared_error(labels, predictions)
    else:
      loss = tf.keras.losses.sparse_categorical_crossentropy(
          labels, tf.cast(model_outputs, tf.float32), from_logits=True)

    if aux_losses:
      loss += tf.add_n(aux_losses)
    return tf.reduce_mean(loss)

  def build_inputs(self, params, input_context=None):
    """Returns tf.data.Dataset for sentence_prediction task."""
    if params.input_path == 'dummy':
      # Infinite dataset of all-zero features/labels, useful for smoke tests.

      def dummy_data(_):
        dummy_ids = tf.zeros((1, params.seq_length), dtype=tf.int32)
        x = dict(
            input_word_ids=dummy_ids,
            input_mask=dummy_ids,
            input_type_ids=dummy_ids)

        if self.task_config.model.num_classes == 1:
          y = tf.zeros((1,), dtype=tf.float32)
        else:
          y = tf.zeros((1, 1), dtype=tf.int32)
        return x, y

      dataset = tf.data.Dataset.range(1)
      dataset = dataset.repeat()
      dataset = dataset.map(
          dummy_data, num_parallel_calls=tf.data.experimental.AUTOTUNE)
      return dataset

    return data_loader_factory.get_data_loader(params).load(input_context)

  def build_metrics(self, training=None):
    """Returns MSE for regression, else sparse categorical accuracy."""
    del training
    if self.task_config.model.num_classes == 1:
      metrics = [tf.keras.metrics.MeanSquaredError()]
    else:
      metrics = [
          tf.keras.metrics.SparseCategoricalAccuracy(name='cls_accuracy')]
    return metrics

  def process_metrics(self, metrics, labels, model_outputs):
    """Updates every metric in `metrics` with this batch."""
    for metric in metrics:
      metric.update_state(labels, model_outputs)

  def process_compiled_metrics(self, compiled_metrics, labels, model_outputs):
    """Updates the model's compiled metrics with this batch."""
    compiled_metrics.update_state(labels, model_outputs)

  def validation_step(self, inputs, model: tf.keras.Model, metrics=None):
    """Runs one validation step.

    For `accuracy` this defers to the base task. For the corpus-level metric
    types it also emits per-example predictions and labels, which
    `aggregate_logs` collects and `reduce_aggregated_logs` turns into the final
    metric.

    Args:
      inputs: A `(features, labels)` pair.
      model: The Keras model being evaluated.
      metrics: Optional streaming metrics (from `build_metrics`) to update.

    Returns:
      A dict of logs containing the loss and, for non-accuracy metric types,
      `sentence_prediction` and `labels` tensors.
    """
    if self.metric_type == 'accuracy':
      return super(SentencePredictionTask,
                   self).validation_step(inputs, model, metrics)
    features, labels = inputs
    outputs = self.inference_step(features, model)
    loss = self.build_losses(
        labels=labels, model_outputs=outputs, aux_losses=model.losses)
    logs = {self.loss: loss}
    # Keep the streaming metrics up to date on this path too; otherwise they
    # would never be updated for non-accuracy metric types.
    if metrics:
      self.process_metrics(metrics, labels, outputs)
    if self.metric_type == 'matthews_corrcoef':
      logs.update({
          # Batch-major (batch, 1) class ids. Expanding on axis 0 would give
          # (1, batch), which cannot be concatenated along axis 0 once a
          # partial final batch or an uneven replica shard appears.
          'sentence_prediction':
              tf.expand_dims(tf.math.argmax(outputs, axis=1), axis=1),
          'labels':
              labels,
      })
    if self.metric_type == 'pearson_spearman_corr':
      logs.update({
          'sentence_prediction': outputs,
          'labels': labels,
      })
    return logs

  def aggregate_logs(self, state=None, step_outputs=None):
    """Accumulates per-step predictions and labels across validation steps.

    Args:
      state: The accumulated state from previous calls, or None on the first
        call.
      step_outputs: Logs from one step; `sentence_prediction` and `labels`
        hold one tensor per local replica.

    Returns:
      None for `accuracy`; otherwise a dict of lists, where each list gains the
      current step's per-replica values concatenated along axis 0.
    """
    if self.metric_type == 'accuracy':
      return None
    if state is None:
      state = {'sentence_prediction': [], 'labels': []}
    # Predictions and labels are batch-major (see `validation_step`), so
    # replicas and partial batches concatenate cleanly along axis 0.
    for key in ('sentence_prediction', 'labels'):
      state[key].append(
          np.concatenate([v.numpy() for v in step_outputs[key]], axis=0))
    return state

  def reduce_aggregated_logs(self, aggregated_logs):
    """Computes the corpus-level metric from the state built by aggregate_logs.

    Returns:
      None for `accuracy`; otherwise `{metric_type: value}`. For
      `pearson_spearman_corr` the value is the mean of the Pearson and
      Spearman correlation coefficients.
    """
    if self.metric_type == 'accuracy':
      return None
    preds = np.reshape(
        np.concatenate(aggregated_logs['sentence_prediction'], axis=0), -1)
    labels = np.reshape(np.concatenate(aggregated_logs['labels'], axis=0), -1)
    if self.metric_type == 'matthews_corrcoef':
      # sklearn's signature is matthews_corrcoef(y_true, y_pred).
      return {
          self.metric_type: sklearn_metrics.matthews_corrcoef(labels, preds)
      }
    elif self.metric_type == 'pearson_spearman_corr':
      pearson_corr = stats.pearsonr(preds, labels)[0]
      spearman_corr = stats.spearmanr(preds, labels)[0]
      corr_metric = (pearson_corr + spearman_corr) / 2
      return {self.metric_type: corr_metric}

  def initialize(self, model):
    """Load a pretrained checkpoint (if exists) and then train from iter 0."""
    ckpt_dir_or_file = self.task_config.init_checkpoint
    if tf.io.gfile.isdir(ckpt_dir_or_file):
      ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)
    if not ckpt_dir_or_file:
      return

    pretrain2finetune_mapping = {
        'encoder': model.checkpoint_items['encoder'],
    }
    # TODO(b/160251903): Investigate why no pooler dense improves finetuning
    # accuracies.
    if self.task_config.init_cls_pooler:
      # Map the pretraining next-sentence pooler onto the classifier pooler.
      pretrain2finetune_mapping[
          'next_sentence.pooler_dense'] = model.checkpoint_items[
              'sentence_prediction.pooler_dense']
    ckpt = tf.train.Checkpoint(**pretrain2finetune_mapping)
    status = ckpt.read(ckpt_dir_or_file)
    # Only the mapped subset is restored; the rest of the checkpoint is
    # expected to be unused.
    status.expect_partial().assert_existing_objects_matched()
    logging.info('Finished loading pretrained checkpoint from %s',
                 ckpt_dir_or_file)
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278


def predict(task: SentencePredictionTask, params: cfg.DataConfig,
            model: tf.keras.Model) -> List[Union[int, float]]:
  """Predicts on the input data.

  Args:
    task: A `SentencePredictionTask` object.
    params: A `cfg.DataConfig` object.
    model: A keras.Model.

  Returns:
    A list of predictions with length of `num_examples`. For regression task,
      each element in the list is the predicted score; for classification task,
      each element is the predicted class id. Each element is the per-example
      scalar obtained by iterating a batch of step outputs.
  """
  is_regression = task.task_config.model.num_classes == 1
  strategy = tf.distribute.get_strategy()

  @tf.function
  def predict_step(iterator):
    """Runs one prediction step on every replica and gathers local results."""

    def _replicated_step(inputs):
      """Computes per-example predictions for one replica's batch."""
      # The dataset yields (features, labels); labels are unused here.
      x, _ = inputs
      outputs = task.inference_step(x, model)
      if is_regression:
        # Flatten the single-unit regression head (presumably (batch, 1)) to
        # (batch,) so each example contributes one scalar score rather than a
        # 1-element vector.
        return tf.reshape(outputs, [-1])
      return tf.argmax(outputs, axis=-1)

    outputs = strategy.run(_replicated_step, args=(next(iterator),))
    return tf.nest.map_structure(strategy.experimental_local_results, outputs)

  def reduce_fn(state, outputs):
    """Appends each replica's per-example predictions to `state`."""
    for per_replica_batch_predictions in outputs:
      state.extend(per_replica_batch_predictions)
    return state

  loop_fn = orbit.utils.create_loop_fn(predict_step)
  dataset = orbit.utils.make_distributed_dataset(strategy, task.build_inputs,
                                                 params)
  # Set `num_steps` to -1 to exhaust the dataset.
  predictions = loop_fn(
      iter(dataset), num_steps=-1, state=[], reduce_fn=reduce_fn)
  return predictions