cls_head.py 12.8 KB
Newer Older
Frederick Liu's avatar
Frederick Liu committed
1
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Hongkun Yu's avatar
Hongkun Yu committed
2
3
4
5
6
7
8
9
10
11
12
13
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Frederick Liu's avatar
Frederick Liu committed
14

Hongkun Yu's avatar
Hongkun Yu committed
15
16
17
18
19
20
"""A classification head layer commonly used with sequence encoders."""

import tensorflow as tf

from official.modeling import tf_utils

21
22
23
from official.nlp.modeling.layers import gaussian_process
from official.nlp.modeling.layers import spectral_normalization

Hongkun Yu's avatar
Hongkun Yu committed
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38

class ClassificationHead(tf.keras.layers.Layer):
  """Pooling head for sentence-level classification tasks."""

  def __init__(self,
               inner_dim,
               num_classes,
               cls_token_idx=0,
               activation="tanh",
               dropout_rate=0.0,
               initializer="glorot_uniform",
               **kwargs):
    """Initializes the `ClassificationHead`.

    Args:
      inner_dim: The dimensionality of inner projection layer. If 0 or `None`
        then only the output projection layer is created.
      num_classes: Number of output classes.
      cls_token_idx: The index inside the sequence to pool.
      activation: Dense layer activation.
      dropout_rate: Dropout probability.
      initializer: Initializer for dense layer kernels.
      **kwargs: Keyword arguments forwarded to `tf.keras.layers.Layer`.
    """
    super().__init__(**kwargs)
    self.dropout_rate = dropout_rate
    self.inner_dim = inner_dim
    self.num_classes = num_classes
    self.activation = tf_utils.get_activation(activation)
    self.initializer = tf.keras.initializers.get(initializer)
    self.cls_token_idx = cls_token_idx

    # The pooler projection only exists when `inner_dim` is truthy, so any
    # code touching `self.dense` must check `self.inner_dim` first.
    if self.inner_dim:
      self.dense = tf.keras.layers.Dense(
          units=self.inner_dim,
          activation=self.activation,
          kernel_initializer=self.initializer,
          name="pooler_dense")
    self.dropout = tf.keras.layers.Dropout(rate=self.dropout_rate)

    self.out_proj = tf.keras.layers.Dense(
        units=num_classes, kernel_initializer=self.initializer, name="logits")

  def call(self, features: tf.Tensor, only_project: bool = False):
    """Implements call().

    Args:
      features: a rank-3 Tensor when self.inner_dim is specified, otherwise
        it is a rank-2 Tensor.
      only_project: a boolean. If True, we return the intermediate Tensor
        before projecting to class logits.

    Returns:
      a Tensor, if only_project is True, shape= [batch size, hidden size].
      If only_project is False, shape= [batch size, num classes].
    """
    if not self.inner_dim:
      # No pooler: the input is used as-is.
      x = features
    else:
      x = features[:, self.cls_token_idx, :]  # take <CLS> token.
      x = self.dense(x)

    if only_project:
      return x
    x = self.dropout(x)
    x = self.out_proj(x)
    return x

  def get_config(self):
    config = {
        "cls_token_idx": self.cls_token_idx,
        "dropout_rate": self.dropout_rate,
        "num_classes": self.num_classes,
        "inner_dim": self.inner_dim,
        "activation": tf.keras.activations.serialize(self.activation),
        "initializer": tf.keras.initializers.serialize(self.initializer),
    }
    config.update(super().get_config())
    return config

  @classmethod
  def from_config(cls, config, custom_objects=None):
    return cls(**config)

  @property
  def checkpoint_items(self):
    """Returns the pooler dense layer keyed by its name, for checkpointing.

    When no pooler layer was created (`inner_dim` is 0 or `None`), returns an
    empty dict instead of failing with an AttributeError.
    """
    if not self.inner_dim:
      return {}
    return {self.dense.name: self.dense}


class MultiClsHeads(tf.keras.layers.Layer):
  """Pooling heads sharing the same pooling stem."""

  def __init__(self,
               inner_dim,
               cls_list,
               cls_token_idx=0,
               activation="tanh",
               dropout_rate=0.0,
               initializer="glorot_uniform",
               **kwargs):
    """Initializes the `MultiClsHeads`.

    Args:
      inner_dim: The dimensionality of inner projection layer. If 0 or `None`
        then only the output projection layer is created.
      cls_list: a list of pairs (classification problem name, number of
        classes). One output projection layer is created per pair, named after
        the problem.
      cls_token_idx: The index inside the sequence to pool.
      activation: Dense layer activation.
      dropout_rate: Dropout probability.
      initializer: Initializer for dense layer kernels.
      **kwargs: Keyword arguments forwarded to `tf.keras.layers.Layer`.
    """
    super().__init__(**kwargs)
    self.dropout_rate = dropout_rate
    self.inner_dim = inner_dim
    self.cls_list = cls_list
    self.activation = tf_utils.get_activation(activation)
    self.initializer = tf.keras.initializers.get(initializer)
    self.cls_token_idx = cls_token_idx

    # The shared pooler projection only exists when `inner_dim` is truthy.
    if self.inner_dim:
      self.dense = tf.keras.layers.Dense(
          units=self.inner_dim,
          activation=self.activation,
          kernel_initializer=self.initializer,
          name="pooler_dense")
    self.dropout = tf.keras.layers.Dropout(rate=self.dropout_rate)

    self.out_projs = []
    for name, num_classes in cls_list:
      self.out_projs.append(
          tf.keras.layers.Dense(
              units=num_classes, kernel_initializer=self.initializer,
              name=name))

  def call(self, features: tf.Tensor, only_project: bool = False):
    """Implements call().

    Args:
      features: a rank-3 Tensor when self.inner_dim is specified, otherwise
        it is a rank-2 Tensor.
      only_project: a boolean. If True, we return the intermediate Tensor
        before projecting to class logits.

    Returns:
      If only_project is True, a Tensor with shape= [batch size, hidden size].
      If only_project is False, a dictionary of Tensors keyed by the name of
      each output projection layer.
    """
    if not self.inner_dim:
      # No pooler: the input is used as-is.
      x = features
    else:
      x = features[:, self.cls_token_idx, :]  # take <CLS> token.
      x = self.dense(x)

    if only_project:
      return x
    x = self.dropout(x)

    outputs = {}
    for proj_layer in self.out_projs:
      outputs[proj_layer.name] = proj_layer(x)
    return outputs

  def get_config(self):
    config = {
        "dropout_rate": self.dropout_rate,
        "cls_token_idx": self.cls_token_idx,
        "cls_list": self.cls_list,
        "inner_dim": self.inner_dim,
        "activation": tf.keras.activations.serialize(self.activation),
        "initializer": tf.keras.initializers.serialize(self.initializer),
    }
    config.update(super().get_config())
    return config

  @classmethod
  def from_config(cls, config, custom_objects=None):
    return cls(**config)

  @property
  def checkpoint_items(self):
    """Returns the pooler and output projection layers keyed by name.

    The pooler dense layer is included only when it was created (`inner_dim`
    is truthy); otherwise accessing it would raise an AttributeError.
    """
    items = {self.dense.name: self.dense} if self.inner_dim else {}
    items.update({v.name: v for v in self.out_projs})
    return items


class GaussianProcessClassificationHead(ClassificationHead):
  """Gaussian process-based pooling head for sentence classification.

  This class implements a classifier head for BERT encoder that is based on the
  spectral-normalized neural Gaussian process (SNGP) [1]. SNGP is a simple
  method to improve a neural network's uncertainty quantification ability
  without sacrificing accuracy or latency. It applies spectral normalization to
  the hidden pooler layer, and then replaces the dense output layer with a
  Gaussian process.

  [1]: Jeremiah Liu et al. Simple and Principled Uncertainty Estimation with
       Deterministic Deep Learning via Distance Awareness.
       In _Neural Information Processing Systems_, 2020.
       https://arxiv.org/abs/2006.10108
  """

  def __init__(self,
               inner_dim,
               num_classes,
               cls_token_idx=0,
               activation="tanh",
               dropout_rate=0.0,
               initializer="glorot_uniform",
               use_spec_norm=True,
               use_gp_layer=True,
               temperature=None,
               **kwargs):
    """Initializes the `GaussianProcessClassificationHead`.

    Args:
      inner_dim: The dimensionality of inner projection layer. If 0 or `None`
        then only the output projection layer is created.
      num_classes: Number of output classes.
      cls_token_idx: The index inside the sequence to pool.
      activation: Dense layer activation.
      dropout_rate: Dropout probability.
      initializer: Initializer for dense layer kernels.
      use_spec_norm: Whether to apply spectral normalization to pooler layer.
      use_gp_layer: Whether to use Gaussian process as the output layer.
      temperature: The temperature parameter to be used for mean-field
        approximation during inference. If None then no mean-field adjustment is
        applied.
      **kwargs: Additional keyword arguments. Spectral normalization options
        (see `extract_spec_norm_kwargs`) and Gaussian process options (see
        `extract_gp_layer_kwargs`) are popped from here; the remaining entries
        are forwarded to `ClassificationHead`.
    """
    # Pop the spectral normalization and Gaussian process options out of
    # `kwargs` so that only the remaining entries reach the parent constructor.
    self.use_spec_norm = use_spec_norm
    self.use_gp_layer = use_gp_layer
    self.spec_norm_kwargs = extract_spec_norm_kwargs(kwargs)
    self.gp_layer_kwargs = extract_gp_layer_kwargs(kwargs)
    self.temperature = temperature

    super().__init__(
        inner_dim=inner_dim,
        num_classes=num_classes,
        cls_token_idx=cls_token_idx,
        activation=activation,
        dropout_rate=dropout_rate,
        initializer=initializer,
        **kwargs)

    # Applies spectral normalization to the dense pooler layer. The parent only
    # creates `self.dense` when `inner_dim` is truthy, hence the hasattr check.
    if self.use_spec_norm and hasattr(self, "dense"):
      self.dense = spectral_normalization.SpectralNormalization(
          self.dense, inhere_layer_name=True, **self.spec_norm_kwargs)

    # Replaces the parent's Dense output layer with the Gaussian process layer.
    if self.use_gp_layer:
      self.out_proj = gaussian_process.RandomFeatureGaussianProcess(
          self.num_classes,
          kernel_initializer=self.initializer,
          name="logits",
          **self.gp_layer_kwargs)

  def call(self, features, training=False, return_covmat=False):
    """Returns model output.

    During training, the model returns raw logits. During evaluation, the model
    returns uncertainty-adjusted logits, and (optionally) the covariance matrix.

    Args:
      features: Input tensor forwarded to `ClassificationHead.call`. It is a
        rank-3 tensor when `inner_dim` is set (the `cls_token_idx` position is
        pooled); otherwise it is a rank-2 tensor.
      training: Whether the model is in training mode.
      return_covmat: Whether the model should also return covariance matrix if
        `use_gp_layer=True`. During training, it is recommended to set
        `return_covmat=False` to be compatible with the standard Keras pipelines
        (e.g., `model.fit()`).

    Returns:
      logits: Uncertainty-adjusted predictive logits, shape
        (batch_size, num_classes).
      covmat: (Optional) Covariance matrix, shape (batch_size, batch_size).
        Returned only when return_covmat=True and a covariance matrix is
        available (i.e. `use_gp_layer=True`); otherwise only logits are
        returned.
    """
    # NOTE(review): `training` is not passed to the parent call explicitly; the
    # nested Dropout presumably picks it up from the Keras call context.
    logits = super().call(features)

    # Extracts logits and covariance matrix from model output.
    if self.use_gp_layer:
      logits, covmat = logits
    else:
      covmat = None

    # Computes the uncertainty-adjusted logits during evaluation.
    if not training:
      logits = gaussian_process.mean_field_logits(
          logits, covmat, mean_field_factor=self.temperature)

    if return_covmat and covmat is not None:
      return logits, covmat
    return logits

  def reset_covariance_matrix(self):
    """Resets covariance matrix of the Gaussian process layer."""
    # A no-op when the output layer is a plain Dense layer without this method.
    if hasattr(self.out_proj, "reset_covariance_matrix"):
      self.out_proj.reset_covariance_matrix()

  def get_config(self):
    config = dict(
        use_spec_norm=self.use_spec_norm, use_gp_layer=self.use_gp_layer)

    # Flatten the extracted option dicts back into top-level keys so that
    # `from_config` -> `__init__` can pop them out of kwargs again.
    config.update(self.spec_norm_kwargs)
    config.update(self.gp_layer_kwargs)
    config["temperature"] = self.temperature

    config.update(super().get_config())
    return config


def extract_gp_layer_kwargs(kwargs):
  """Pops Gaussian process layer options out of `kwargs`.

  Every recognized key is removed from `kwargs` in place, so the caller can
  forward the leftover entries elsewhere. Keys missing from `kwargs` take the
  default values listed in the table below.

  Args:
    kwargs: A mutable mapping of keyword arguments.

  Returns:
    A new dict with one entry per recognized Gaussian process option, in the
    order of the table below.
  """
  option_defaults = (
      ("num_inducing", 1024),
      ("normalize_input", True),
      ("gp_cov_momentum", 0.999),
      ("gp_cov_ridge_penalty", 1.),
      ("scale_random_features", False),
      ("l2_regularization", 1e-6),
      ("gp_cov_likelihood", "gaussian"),
      ("return_gp_cov", True),
      ("return_random_features", False),
      ("use_custom_random_features", True),
      ("custom_random_features_initializer", "random_normal"),
      ("custom_random_features_activation", None),
  )
  return {
      option: kwargs.pop(option, fallback)
      for option, fallback in option_defaults
  }


def extract_spec_norm_kwargs(kwargs):
  """Pops spectral normalization options out of `kwargs`.

  The `iteration` and `norm_multiplier` keys are removed from `kwargs` in
  place; absent keys fall back to 1 and .99 respectively.

  Args:
    kwargs: A mutable mapping of keyword arguments.

  Returns:
    A new dict with the `iteration` and `norm_multiplier` entries.
  """
  spec_norm_options = {}
  spec_norm_options["iteration"] = kwargs.pop("iteration", 1)
  spec_norm_options["norm_multiplier"] = kwargs.pop("norm_multiplier", .99)
  return spec_norm_options