# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Sentence prediction (classification) task."""
import dataclasses

from absl import logging
import numpy as np
from scipy import stats
from sklearn import metrics as sklearn_metrics
import tensorflow as tf
import tensorflow_hub as hub

from official.core import base_task
from official.modeling.hyperparams import config_definitions as cfg
from official.nlp.configs import bert
from official.nlp.data import sentence_prediction_dataloader
from official.nlp.tasks import utils


@dataclasses.dataclass
class SentencePredictionConfig(cfg.TaskConfig):
  """The model config."""
  # At most one of `init_checkpoint` and `hub_module_url` can
  # be specified.
  init_checkpoint: str = ''
  init_cls_pooler: bool = False
  hub_module_url: str = ''
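  # One of 'accuracy', 'matthews_corrcoef', or 'pearson_spearman_corr'.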
  metric_type: str = 'accuracy'
  model: bert.BertPretrainerConfig = bert.BertPretrainerConfig(
      num_masked_tokens=0,  # No masked language modeling head.
      cls_heads=[
          bert.ClsHeadConfig(
              inner_dim=768,
              num_classes=3,
              dropout_rate=0.1,
              name='sentence_prediction')
      ])
  train_data: cfg.DataConfig = cfg.DataConfig()
  validation_data: cfg.DataConfig = cfg.DataConfig()


@base_task.register_task_cls(SentencePredictionConfig)
class SentencePredictionTask(base_task.Task):
  """Task object for sentence_prediction."""

  def __init__(self, params: cfg.TaskConfig):
    super(SentencePredictionTask, self).__init__(params)
    if params.hub_module_url and params.init_checkpoint:
      raise ValueError('At most one of `hub_module_url` and '
                       '`init_checkpoint` can be specified.')
    if params.hub_module_url:
      self._hub_module = hub.load(params.hub_module_url)
    else:
      self._hub_module = None
    self.metric_type = params.metric_type

  def build_model(self):
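    """Builds the model from the config, reusing the TF-Hub encoder if given."""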
    if self._hub_module:
      encoder_from_hub = utils.get_encoder_from_hub(self._hub_module)
      return bert.instantiate_bertpretrainer_from_cfg(
          self.task_config.model, encoder_network=encoder_from_hub)
    else:
      return bert.instantiate_bertpretrainer_from_cfg(self.task_config.model)

  def build_losses(self, labels, model_outputs, aux_losses=None) -> tf.Tensor:
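    """Computes sparse categorical cross-entropy, plus any auxiliary losses."""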
    loss = tf.keras.losses.sparse_categorical_crossentropy(
        labels,
        tf.cast(model_outputs['sentence_prediction'], tf.float32),
        from_logits=True)

    if aux_losses:
      loss += tf.add_n(aux_losses)
    return loss

  def build_inputs(self, params, input_context=None):
    """Returns tf.data.Dataset for sentence_prediction task."""
    if params.input_path == 'dummy':
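      # 'dummy' builds an all-zeros synthetic dataset, useful for smoke tests.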

      def dummy_data(_):
        dummy_ids = tf.zeros((1, params.seq_length), dtype=tf.int32)
        x = dict(
            input_word_ids=dummy_ids,
            input_mask=dummy_ids,
            input_type_ids=dummy_ids)
        y = tf.zeros((1, 1), dtype=tf.int32)
        return (x, y)

      dataset = tf.data.Dataset.range(1)
      dataset = dataset.repeat()
      dataset = dataset.map(
          dummy_data, num_parallel_calls=tf.data.experimental.AUTOTUNE)
      return dataset

    return sentence_prediction_dataloader.SentencePredictionDataLoader(
        params).load(input_context)

  def build_metrics(self, training=None):
    del training
    metrics = [tf.keras.metrics.SparseCategoricalAccuracy(name='cls_accuracy')]
    return metrics

  def process_metrics(self, metrics, labels, model_outputs):
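    """Updates the Keras metrics with the sentence prediction outputs."""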
    for metric in metrics:
      metric.update_state(labels, model_outputs['sentence_prediction'])

  def process_compiled_metrics(self, compiled_metrics, labels, model_outputs):
    compiled_metrics.update_state(labels, model_outputs['sentence_prediction'])

  def validation_step(self, inputs, model: tf.keras.Model, metrics=None):
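    """Runs one validation step.

    For the non-accuracy metric types ('matthews_corrcoef',
    'pearson_spearman_corr'), raw predictions and labels are added to the
    logs so they can be aggregated over the full eval set.
    """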
    if self.metric_type == 'accuracy':
      return super(SentencePredictionTask,
                   self).validation_step(inputs, model, metrics)
    features, labels = inputs
    outputs = self.inference_step(features, model)
    loss = self.build_losses(
        labels=labels, model_outputs=outputs, aux_losses=model.losses)
    logs = {self.loss: loss}
    if self.metric_type == 'matthews_corrcoef':
      logs.update({
          'sentence_prediction':
              tf.expand_dims(
                  tf.math.argmax(outputs['sentence_prediction'], axis=1),
                  axis=0),
          'labels':
              labels,
      })
    if self.metric_type == 'pearson_spearman_corr':
      logs.update({
          'sentence_prediction': outputs['sentence_prediction'],
          'labels': labels,
      })
    return logs

  def aggregate_logs(self, state=None, step_outputs=None):
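    """Accumulates per-step predictions and labels for offline metrics."""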
    if self.metric_type == 'accuracy':
      return None
    if state is None:
      state = {'sentence_prediction': [], 'labels': []}
    state['sentence_prediction'].append(
        np.concatenate([v.numpy() for v in step_outputs['sentence_prediction']],
                       axis=0))
    state['labels'].append(
        np.concatenate([v.numpy() for v in step_outputs['labels']], axis=0))
    return state

  def reduce_aggregated_logs(self, aggregated_logs):
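    """Computes the final correlation metric from the aggregated outputs."""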
    if self.metric_type == 'matthews_corrcoef':
      preds = np.concatenate(aggregated_logs['sentence_prediction'], axis=0)
      labels = np.concatenate(aggregated_logs['labels'], axis=0)
      return {
          self.metric_type: sklearn_metrics.matthews_corrcoef(labels, preds)
      }
    if self.metric_type == 'pearson_spearman_corr':
      preds = np.concatenate(aggregated_logs['sentence_prediction'], axis=0)
      labels = np.concatenate(aggregated_logs['labels'], axis=0)
      pearson_corr = stats.pearsonr(preds, labels)[0]
      spearman_corr = stats.spearmanr(preds, labels)[0]
      corr_metric = (pearson_corr + spearman_corr) / 2
      return {self.metric_type: corr_metric}

  def initialize(self, model):
    """Load a pretrained checkpoint (if exists) and then train from iter 0."""
    ckpt_dir_or_file = self.task_config.init_checkpoint
    if tf.io.gfile.isdir(ckpt_dir_or_file):
      ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)
    if not ckpt_dir_or_file:
      return

    pretrain2finetune_mapping = {
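        # Only the pretrained encoder is restored; the classification head
        # starts from fresh initialization.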
        'encoder': model.checkpoint_items['encoder'],
    }
    # TODO(b/160251903): Investigate why omitting the pooler dense layer
    # improves fine-tuning accuracies.
    if self.task_config.init_cls_pooler:
      pretrain2finetune_mapping[
          'next_sentence.pooler_dense'] = model.checkpoint_items[
              'sentence_prediction.pooler_dense']
    ckpt = tf.train.Checkpoint(**pretrain2finetune_mapping)
    status = ckpt.read(ckpt_dir_or_file)
    status.expect_partial().assert_existing_objects_matched()
    logging.info('Finished loading pretrained checkpoint from %s',
                 ckpt_dir_or_file)
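

# Illustrative usage (a minimal sketch; the checkpoint path and data configs
# below are placeholders, not part of this module):
#
#   config = SentencePredictionConfig(
#       init_checkpoint='/path/to/pretrained_ckpt',
#       train_data=...,       # a populated cfg.DataConfig
#       validation_data=...)  # a populated cfg.DataConfig
#   task = SentencePredictionTask(config)
#   model = task.build_model()
#   task.initialize(model)  # Restores the pretrained encoder weights.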