# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A Classification head layer which is common used with sequence encoders."""

import tensorflow as tf

from official.modeling import tf_utils

from official.nlp.modeling.layers import gaussian_process
from official.nlp.modeling.layers import spectral_normalization


class ClassificationHead(tf.keras.layers.Layer):
  """Pooling head for sentence-level classification tasks."""

  def __init__(self,
               inner_dim,
               num_classes,
               cls_token_idx=0,
               activation="tanh",
               dropout_rate=0.0,
               initializer="glorot_uniform",
               **kwargs):
    """Initializes the `ClassificationHead`.

    Args:
      inner_dim: The dimensionality of inner projection layer. If 0 or `None`
        then only the output projection layer is created.
      num_classes: Number of output classes.
      cls_token_idx: The index inside the sequence to pool.
      activation: Dense layer activation.
      dropout_rate: Dropout probability.
      initializer: Initializer for dense layer kernels.
      **kwargs: Keyword arguments.
    """
    super().__init__(**kwargs)
    self.dropout_rate = dropout_rate
    self.inner_dim = inner_dim
    self.num_classes = num_classes
    self.activation = tf_utils.get_activation(activation)
    self.initializer = tf.keras.initializers.get(initializer)
    self.cls_token_idx = cls_token_idx

    # The inner (pooler) projection is optional: it only exists when
    # `inner_dim` is truthy, so all later accesses must tolerate its absence.
    if self.inner_dim:
      self.dense = tf.keras.layers.Dense(
          units=self.inner_dim,
          activation=self.activation,
          kernel_initializer=tf_utils.clone_initializer(self.initializer),
          name="pooler_dense")
    self.dropout = tf.keras.layers.Dropout(rate=self.dropout_rate)

    self.out_proj = tf.keras.layers.Dense(
        units=num_classes,
        kernel_initializer=tf_utils.clone_initializer(self.initializer),
        name="logits")

  def call(self, features: tf.Tensor, only_project: bool = False):
    """Implements call().

    Args:
      features: a rank-3 Tensor when self.inner_dim is specified, otherwise
        it is a rank-2 Tensor.
      only_project: a boolean. If True, we return the intermediate Tensor
        before projecting to class logits.

    Returns:
      a Tensor, if only_project is True, shape= [batch size, hidden size].
      If only_project is False, shape= [batch size, num classes].
    """
    if not self.inner_dim:
      # No pooler layer: the caller already provides pooled (rank-2) features.
      x = features
    else:
      x = features[:, self.cls_token_idx, :]  # take <CLS> token.
      x = self.dense(x)

    if only_project:
      return x
    x = self.dropout(x)
    x = self.out_proj(x)
    return x

  def get_config(self):
    config = {
        "cls_token_idx": self.cls_token_idx,
        "dropout_rate": self.dropout_rate,
        "num_classes": self.num_classes,
        "inner_dim": self.inner_dim,
        "activation": tf.keras.activations.serialize(self.activation),
        "initializer": tf.keras.initializers.serialize(self.initializer),
    }
    config.update(super().get_config())
    return config

  @classmethod
  def from_config(cls, config, custom_objects=None):
    return cls(**config)

  @property
  def checkpoint_items(self):
    """Returns a dict of sub-layers to be saved in checkpoints.

    When `inner_dim` is 0/None the pooler dense layer is never created, so
    guard the attribute access instead of raising AttributeError.
    """
    items = {}
    if hasattr(self, "dense"):
      items[self.dense.name] = self.dense
    return items


class MultiClsHeads(tf.keras.layers.Layer):
  """Pooling heads sharing the same pooling stem."""

  def __init__(self,
               inner_dim,
               cls_list,
               cls_token_idx=0,
               activation="tanh",
               dropout_rate=0.0,
               initializer="glorot_uniform",
               **kwargs):
    """Initializes the `MultiClsHeads`.

    Args:
      inner_dim: The dimensionality of inner projection layer. If 0 or `None`
        then only the output projection layers are created.
      cls_list: a list of (classification problem name, number of classes)
        pairs; one output projection head is created per entry.
      cls_token_idx: The index inside the sequence to pool.
      activation: Dense layer activation.
      dropout_rate: Dropout probability.
      initializer: Initializer for dense layer kernels.
      **kwargs: Keyword arguments.
    """
    super().__init__(**kwargs)
    self.dropout_rate = dropout_rate
    self.inner_dim = inner_dim
    self.cls_list = cls_list
    self.activation = tf_utils.get_activation(activation)
    self.initializer = tf.keras.initializers.get(initializer)
    self.cls_token_idx = cls_token_idx

    # The shared pooler projection is optional: it only exists when
    # `inner_dim` is truthy, so all later accesses must tolerate its absence.
    if self.inner_dim:
      self.dense = tf.keras.layers.Dense(
          units=inner_dim,
          activation=self.activation,
          kernel_initializer=tf_utils.clone_initializer(self.initializer),
          name="pooler_dense")
    self.dropout = tf.keras.layers.Dropout(rate=self.dropout_rate)
    # One independent logits projection per classification problem.
    self.out_projs = []
    for name, num_classes in cls_list:
      self.out_projs.append(
          tf.keras.layers.Dense(
              units=num_classes,
              kernel_initializer=tf_utils.clone_initializer(self.initializer),
              name=name))

  def call(self, features: tf.Tensor, only_project: bool = False):
    """Implements call().

    Args:
      features: a rank-3 Tensor when self.inner_dim is specified, otherwise
        it is a rank-2 Tensor.
      only_project: a boolean. If True, we return the intermediate Tensor
        before projecting to class logits.

    Returns:
      If only_project is True, a Tensor with shape= [batch size, hidden size].
      If only_project is False, a dictionary of Tensors.
    """
    if not self.inner_dim:
      # No pooler layer: the caller already provides pooled (rank-2) features.
      x = features
    else:
      x = features[:, self.cls_token_idx, :]  # take <CLS> token.
      x = self.dense(x)

    if only_project:
      return x
    x = self.dropout(x)

    outputs = {}
    for proj_layer in self.out_projs:
      outputs[proj_layer.name] = proj_layer(x)
    return outputs

  def get_config(self):
    config = {
        "dropout_rate": self.dropout_rate,
        "cls_token_idx": self.cls_token_idx,
        "cls_list": self.cls_list,
        "inner_dim": self.inner_dim,
        "activation": tf.keras.activations.serialize(self.activation),
        "initializer": tf.keras.initializers.serialize(self.initializer),
    }
    config.update(super().get_config())
    return config

  @classmethod
  def from_config(cls, config, custom_objects=None):
    return cls(**config)

  @property
  def checkpoint_items(self):
    """Returns a dict of sub-layers to be saved in checkpoints.

    When `inner_dim` is 0/None the pooler dense layer is never created, so
    guard the attribute access instead of raising AttributeError.
    """
    items = {}
    if hasattr(self, "dense"):
      items[self.dense.name] = self.dense
    items.update({v.name: v for v in self.out_projs})
    return items


class GaussianProcessClassificationHead(ClassificationHead):
  """Gaussian process-based pooling head for sentence classification.

  This class implements a classifier head for BERT encoder that is based on the
  spectral-normalized neural Gaussian process (SNGP) [1]. SNGP is a simple
  method to improve a neural network's uncertainty quantification ability
  without sacrificing accuracy or latency. It applies spectral normalization to
  the hidden pooler layer, and then replaces the dense output layer with a
  Gaussian process.


  [1]: Jeremiah Liu et al. Simple and Principled Uncertainty Estimation with
       Deterministic Deep Learning via Distance Awareness.
       In _Neural Information Processing Systems_, 2020.
       https://arxiv.org/abs/2006.10108
  """

  def __init__(self,
               inner_dim,
               num_classes,
               cls_token_idx=0,
               activation="tanh",
               dropout_rate=0.0,
               initializer="glorot_uniform",
               use_spec_norm=True,
               use_gp_layer=True,
               temperature=None,
               **kwargs):
    """Initializes the `GaussianProcessClassificationHead`.

    Args:
      inner_dim: The dimensionality of inner projection layer. If 0 or `None`
        then only the output projection layer is created.
      num_classes: Number of output classes.
      cls_token_idx: The index inside the sequence to pool.
      activation: Dense layer activation.
      dropout_rate: Dropout probability.
      initializer: Initializer for dense layer kernels.
      use_spec_norm: Whether to apply spectral normalization to pooler layer.
      use_gp_layer: Whether to use Gaussian process as the output layer.
      temperature: The temperature parameter to be used for mean-field
        approximation during inference. If None then no mean-field adjustment is
        applied.
      **kwargs: Additional keyword arguments.
    """
    # Collects spectral normalization and Gaussian process args from kwargs.
    # These must be popped before super().__init__ receives **kwargs so the
    # base Layer is not passed unknown keyword arguments.
    self.use_spec_norm = use_spec_norm
    self.use_gp_layer = use_gp_layer
    self.spec_norm_kwargs = extract_spec_norm_kwargs(kwargs)
    self.gp_layer_kwargs = extract_gp_layer_kwargs(kwargs)
    self.temperature = temperature

    super().__init__(
        inner_dim=inner_dim,
        num_classes=num_classes,
        cls_token_idx=cls_token_idx,
        activation=activation,
        dropout_rate=dropout_rate,
        initializer=initializer,
        **kwargs)

    # Applies spectral normalization to the dense pooler layer. The `hasattr`
    # check matters because the base class only creates `dense` when
    # `inner_dim` is truthy.
    if self.use_spec_norm and hasattr(self, "dense"):
      self.dense = spectral_normalization.SpectralNormalization(
          self.dense, inhere_layer_name=True, **self.spec_norm_kwargs)

    # Replace Dense output layer with the Gaussian process layer.
    if use_gp_layer:
      self.out_proj = gaussian_process.RandomFeatureGaussianProcess(
          self.num_classes,
          kernel_initializer=tf_utils.clone_initializer(self.initializer),
          name="logits",
          **self.gp_layer_kwargs)

  def call(self, features, training=False, return_covmat=False):
    """Returns model output.

    During training, the model returns raw logits. During evaluation, the model
    returns uncertainty adjusted logits, and (optionally) the covariance matrix.

    Args:
      features: A tensor of input features, shape (batch_size, feature_dim).
      training: Whether the model is in training mode.
      return_covmat: Whether the model should also return covariance matrix if
        `use_gp_layer=True`. During training, it is recommended to set
        `return_covmat=False` to be compatible with the standard Keras pipelines
        (e.g., `model.fit()`).

    Returns:
      logits: Uncertainty-adjusted predictive logits, shape
        (batch_size, num_classes).
      covmat: (Optional) Covariance matrix, shape (batch_size, batch_size).
        Returned only when return_covmat=True.
    """
    logits = super().call(features)

    # Extracts logits and covariance matrix from model output: the GP layer
    # returns a (logits, covmat) pair, a Dense layer returns logits only.
    if self.use_gp_layer:
      logits, covmat = logits
    else:
      covmat = None

    # Computes the uncertainty-adjusted logits during evaluation.
    if not training:
      logits = gaussian_process.mean_field_logits(
          logits, covmat, mean_field_factor=self.temperature)

    if return_covmat and covmat is not None:
      return logits, covmat
    return logits

  def reset_covariance_matrix(self):
    """Resets covariance matrix of the Gaussian process layer.

    No-op when the output layer is a plain Dense layer (use_gp_layer=False).
    """
    if hasattr(self.out_proj, "reset_covariance_matrix"):
      self.out_proj.reset_covariance_matrix()

  def get_config(self):
    config = dict(
        use_spec_norm=self.use_spec_norm, use_gp_layer=self.use_gp_layer)

    # Flatten the extracted spec-norm / GP kwargs back into the config so
    # `from_config` can rebuild the layer from a single flat dict.
    config.update(self.spec_norm_kwargs)
    config.update(self.gp_layer_kwargs)
    config["temperature"] = self.temperature

    config.update(super().get_config())
    return config


def extract_gp_layer_kwargs(kwargs):
  """Removes Gaussian process layer settings from `kwargs` and returns them.

  Note: `kwargs` is mutated in place — each recognized key is popped so the
  remaining dict can be forwarded to other constructors.
  """
  gp_defaults = (
      ("num_inducing", 1024),
      ("normalize_input", True),
      ("gp_cov_momentum", 0.999),
      ("gp_cov_ridge_penalty", 1.),
      ("scale_random_features", False),
      ("l2_regularization", 1e-6),
      ("gp_cov_likelihood", "gaussian"),
      ("return_gp_cov", True),
      ("return_random_features", False),
      ("use_custom_random_features", True),
      ("custom_random_features_initializer", "random_normal"),
      ("custom_random_features_activation", None),
  )
  return {key: kwargs.pop(key, default) for key, default in gp_defaults}


def extract_spec_norm_kwargs(kwargs):
  """Removes spectral normalization settings from `kwargs` and returns them.

  Note: `kwargs` is mutated in place — each recognized key is popped so the
  remaining dict can be forwarded to other constructors.
  """
  sn_defaults = (("iteration", 1), ("norm_multiplier", .99))
  return {key: kwargs.pop(key, default) for key, default in sn_defaults}


class PerQueryDenseHead(tf.keras.layers.Layer):
  """Pooling head used for EncT5 style models.

    This module projects each query to use a different projection.

    For a input shape= [bs, num_queries, hidden_size], it projects each query to
    (features). Ending up with shape= [bs, num_queries, features].

    For example, for classification with a few classes, one may use num_queries
    as 1 and features as number of classes. For multilabel classification, one
    may use num_queries as number of classes and features as 2. So each query
    represents a binary classification of one label.
  """

  def __init__(self,
               num_queries: int,
               features: int,
               use_bias: bool = False,
               kernel_initializer: str = "glorot_uniform",
               **kwargs):
    """Initializes the `PerQueryDenseHead`.

    Args:
      num_queries: number of queries (the learnable embeddings in the input
        sequences) from the decoder.
      features: int with numbers of output features. Each query with be
        projected to this number with a different projection.
      use_bias: whether to add a bias to the output.
      kernel_initializer: Initializer for dense layer kernels.
      **kwargs: Keyword arguments.
    """
    super().__init__(**kwargs)
    self.num_queries = num_queries
    self.features = features

    self.use_bias = use_bias
    self.kernel_initializer = tf.keras.initializers.get(kernel_initializer)

  def build(self, input_shape):
    """Creates the per-query projection kernel (and optional bias)."""
    input_shape = tf.TensorShape(input_shape)
    # Hidden size.
    last_dim = tf.compat.dimension_value(input_shape[-1])

    self.hidden_size = last_dim
    # One independent [hidden_size, features] projection matrix per query.
    self.kernel = self.add_weight(
        "kernel",
        shape=[self.num_queries, last_dim, self.features],
        initializer=self.kernel_initializer,
        dtype=self.dtype,
        trainable=True)
    if self.use_bias:
      # NOTE(review): no explicit initializer is given here, so the bias uses
      # `add_weight`'s default rather than zeros — confirm this is intended.
      self.bias = self.add_weight(
          "bias",
          shape=[
              self.num_queries,
              self.features,
          ],
          dtype=self.dtype,
          trainable=True)
    else:
      self.bias = None

  def call(self, inputs: tf.Tensor) -> tf.Tensor:
    """Implements call().

    Args:
      inputs: a rank-3 Tensor of shape= [bs, num_queries, hidden_size].

    Returns:
      A Tensor, shape= [batch size, num_queries, features].
    """
    # Contract the hidden axis against each query's own projection matrix.
    outputs = tf.einsum("bqh,qhf->bqf", inputs, self.kernel)
    if self.use_bias:
      outputs += self.bias
    return outputs

  def get_config(self):
    config = {
        "num_queries": self.num_queries,
        "features": self.features,
        # Include `use_bias` so a `from_config` round-trip preserves it
        # (it was previously dropped, silently resetting it to False).
        "use_bias": self.use_bias,
        # Fix: serialize the initializer with the initializers module; the
        # original mistakenly used `tf.keras.activations.serialize` here.
        "kernel_initializer":
            tf.keras.initializers.serialize(self.kernel_initializer),
    }
    config.update(super().get_config())
    return config

  @classmethod
  def from_config(cls, config, custom_objects=None):
    return cls(**config)