# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Sentence prediction (classification) task."""
from absl import logging
import dataclasses

import numpy as np
from scipy import stats
from sklearn import metrics as sklearn_metrics

import tensorflow as tf
import tensorflow_hub as hub

from official.core import base_task
from official.modeling.hyperparams import base_config
from official.modeling.hyperparams import config_definitions as cfg
from official.nlp.configs import encoders
from official.nlp.data import data_loader_factory
from official.nlp.modeling import models
from official.nlp.tasks import utils

@dataclasses.dataclass
class ModelConfig(base_config.Config):
  """A classifier/regressor configuration."""
  # Size of the classification head's output; 0 is the unset default —
  # presumably overridden per task at instantiation time (verify at call site).
  num_classes: int = 0
  # If True, build_model passes this through to models.BertClassifier so the
  # encoder's own pooler output is used instead of a classifier-side pooler.
  use_encoder_pooler: bool = False
  # Encoder (BERT-style transformer) configuration; consumed by
  # encoders.instantiate_encoder_from_cfg in SentencePredictionTask.build_model.
  encoder: encoders.TransformerEncoderConfig = (
      encoders.TransformerEncoderConfig())


@dataclasses.dataclass
class SentencePredictionConfig(cfg.TaskConfig):
  """Task configuration for sentence prediction (classification/regression)."""
  # At most one of `init_checkpoint` and `hub_module_url` may be set; the
  # task raises ValueError when both are given.
  init_checkpoint: str = ''
  init_cls_pooler: bool = False
  hub_module_url: str = ''
  # One of 'accuracy', 'matthews_corrcoef' or 'pearson_spearman_corr'.
  metric_type: str = 'accuracy'
  # Concrete model configuration, filled in at instantiation time.
  model: ModelConfig = ModelConfig()
  train_data: cfg.DataConfig = cfg.DataConfig()
  validation_data: cfg.DataConfig = cfg.DataConfig()


@base_task.register_task_cls(SentencePredictionConfig)
class SentencePredictionTask(base_task.Task):
  """Task object for sentence_prediction."""

  def __init__(self, params=cfg.TaskConfig, logging_dir=None):
    """Initializes the task and loads the TF-Hub module when configured.

    Args:
      params: A `SentencePredictionConfig` instance.
      logging_dir: Optional logging directory, forwarded to the base task.

    Raises:
      ValueError: If both `hub_module_url` and `init_checkpoint` are set.
    """
    super().__init__(params, logging_dir)
    if params.hub_module_url and params.init_checkpoint:
      raise ValueError('At most one of `hub_module_url` and '
                       '`init_checkpoint` can be specified.')
    # The hub module (if any) supplies the encoder in build_model.
    self._hub_module = (
        hub.load(params.hub_module_url) if params.hub_module_url else None)
    self.metric_type = params.metric_type

  def build_model(self):
    """Builds a BERT-style classifier over the configured or hub encoder."""
    model_cfg = self.task_config.model
    if self._hub_module:
      encoder_network = utils.get_encoder_from_hub(self._hub_module)
    else:
      encoder_network = encoders.instantiate_encoder_from_cfg(
          model_cfg.encoder)

    # Currently, we only supports bert-style sentence prediction finetuning.
    return models.BertClassifier(
        network=encoder_network,
        num_classes=model_cfg.num_classes,
        initializer=tf.keras.initializers.TruncatedNormal(
            stddev=model_cfg.encoder.initializer_range),
        use_encoder_pooler=model_cfg.use_encoder_pooler)

  def build_losses(self, labels, model_outputs, aux_losses=None) -> tf.Tensor:
    """Sparse categorical cross-entropy on logits, plus optional aux losses."""
    total_loss = tf.keras.losses.sparse_categorical_crossentropy(
        labels, tf.cast(model_outputs, tf.float32), from_logits=True)
    if aux_losses:
      total_loss += tf.add_n(aux_losses)
    return total_loss

  def build_inputs(self, params, input_context=None):
    """Returns tf.data.Dataset for sentence_prediction task."""
    if params.input_path != 'dummy':
      return data_loader_factory.get_data_loader(params).load(input_context)

    # 'dummy' input path: an infinite stream of all-zero examples, useful for
    # smoke tests without real data.
    def _dummy_example(_):
      zero_ids = tf.zeros((1, params.seq_length), dtype=tf.int32)
      features = dict(
          input_word_ids=zero_ids,
          input_mask=zero_ids,
          input_type_ids=zero_ids)
      return features, tf.zeros((1, 1), dtype=tf.int32)

    return tf.data.Dataset.range(1).repeat().map(
        _dummy_example, num_parallel_calls=tf.data.experimental.AUTOTUNE)

  def build_metrics(self, training=None):
    """Returns the Keras metrics tracked during train/eval."""
    del training  # Same metric set regardless of mode.
    return [tf.keras.metrics.SparseCategoricalAccuracy(name='cls_accuracy')]

  def process_metrics(self, metrics, labels, model_outputs):
    """Updates each metric with the current labels and model outputs."""
    for metric in metrics:
      metric.update_state(labels, model_outputs)

  def process_compiled_metrics(self, compiled_metrics, labels, model_outputs):
    """Updates Keras compiled metrics with labels and model outputs."""
    compiled_metrics.update_state(labels, model_outputs)

  def validation_step(self, inputs, model: tf.keras.Model, metrics=None):
    """Runs one eval step; emits predictions/labels for correlation metrics."""
    if self.metric_type == 'accuracy':
      # Accuracy is handled entirely by the base class via Keras metrics.
      return super().validation_step(inputs, model, metrics)
    features, labels = inputs
    outputs = self.inference_step(features, model)
    logs = {
        self.loss:
            self.build_losses(
                labels=labels, model_outputs=outputs, aux_losses=model.losses)
    }
    if self.metric_type == 'matthews_corrcoef':
      # Class ids, with a leading axis so per-step outputs concatenate later.
      logs['sentence_prediction'] = tf.expand_dims(
          tf.math.argmax(outputs, axis=1), axis=0)
      logs['labels'] = labels
    if self.metric_type == 'pearson_spearman_corr':
      # Regression-style raw outputs.
      logs['sentence_prediction'] = outputs
      logs['labels'] = labels
    return logs

  def aggregate_logs(self, state=None, step_outputs=None):
    """Accumulates per-step predictions and labels as numpy arrays."""
    if self.metric_type == 'accuracy':
      return None
    if state is None:
      state = {'sentence_prediction': [], 'labels': []}
    for key in ('sentence_prediction', 'labels'):
      state[key].append(
          np.concatenate([v.numpy() for v in step_outputs[key]], axis=0))
    return state

  def reduce_aggregated_logs(self, aggregated_logs):
    """Computes the final correlation metric from the aggregated state."""
    if self.metric_type not in ('matthews_corrcoef', 'pearson_spearman_corr'):
      # Accuracy (and any other type) is reported via Keras metrics instead.
      return None
    preds = np.concatenate(aggregated_logs['sentence_prediction'], axis=0)
    labels = np.concatenate(aggregated_logs['labels'], axis=0)
    if self.metric_type == 'matthews_corrcoef':
      return {
          self.metric_type: sklearn_metrics.matthews_corrcoef(preds, labels)
      }
    pearson_corr = stats.pearsonr(preds, labels)[0]
    spearman_corr = stats.spearmanr(preds, labels)[0]
    return {self.metric_type: (pearson_corr + spearman_corr) / 2}

  def initialize(self, model):
    """Load a pretrained checkpoint (if exists) and then train from iter 0."""
    ckpt_dir_or_file = self.task_config.init_checkpoint
    if tf.io.gfile.isdir(ckpt_dir_or_file):
      ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)
    if not ckpt_dir_or_file:
      # No checkpoint configured (or none found in the directory).
      return

    # Map pretraining variable names onto the finetuning model's items.
    mapping = {'encoder': model.checkpoint_items['encoder']}
    # TODO(b/160251903): Investigate why no pooler dense improves finetuning
    # accuracies.
    if self.task_config.init_cls_pooler:
      mapping['next_sentence.pooler_dense'] = model.checkpoint_items[
          'sentence_prediction.pooler_dense']
    status = tf.train.Checkpoint(**mapping).read(ckpt_dir_or_file)
    # Partial restore is expected; only the mapped objects must match.
    status.expect_partial().assert_existing_objects_matched()
    logging.info('Finished loading pretrained checkpoint from %s',
                 ckpt_dir_or_file)