# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Keras-based softmax layer with optional masking."""
# pylint: disable=g-classes-have-attributes

import tensorflow as tf


def _large_compatible_negative(tensor_type):
  """Returns a large negative constant that is usable with `tensor_type`.

  The masking value normally used in this module is -1e9. That magnitude
  cannot be represented in tf.float16, so for that dtype the dtype's own
  minimum (`tf.float16.min`) is returned instead.

  Args:
    tensor_type: The dtype the constant is going to be combined with.

  Returns:
    `tf.float16.min` when `tensor_type` is tf.float16, otherwise -1e9.
  """
  return tf.float16.min if tensor_type == tf.float16 else -1e9


@tf.keras.utils.register_keras_serializable(package='Text')
class MaskedSoftmax(tf.keras.layers.Layer):
  """Performs a softmax with optional masking on a tensor.

  Args:
    mask_expansion_axes: The `axis` argument handed to `tf.expand_dims` to pad
      the mask tensor until its rank matches the rank of the scores. Only
      needed when the mask has a lower rank than the scores.
    normalization_axes: The axis or axes over which the softmax is computed.
      Either a single int or a sequence of ints. Defaults to `(-1,)`.
    **kwargs: Keyword arguments forwarded to `tf.keras.layers.Layer`.
  """

  def __init__(self,
               mask_expansion_axes=None,
               normalization_axes=None,
               **kwargs):
    self._mask_expansion_axes = mask_expansion_axes
    if normalization_axes is None:
      self._normalization_axes = (-1,)
    elif isinstance(normalization_axes, int):
      # A bare int is shorthand for a single axis. `call` takes `len()` of
      # this attribute, so it is stored as a 1-tuple.
      self._normalization_axes = (normalization_axes,)
    else:
      self._normalization_axes = normalization_axes
    super().__init__(**kwargs)

  def call(self, scores, mask=None):
    """Applies the softmax to `scores`, suppressing masked positions.

    Args:
      scores: Tensor of raw (pre-softmax) scores.
      mask: Optional tensor that is 1 at positions to keep and 0 at positions
        to mask out. When its rank is lower than the rank of `scores`, it is
        expanded with `tf.expand_dims` at `mask_expansion_axes`, once per
        missing dimension.

    Returns:
      The softmax of the (optionally masked) scores over the normalization
      axes.

    Raises:
      ValueError: If the mask has a lower rank than `scores` and the layer was
        built without `mask_expansion_axes`.
    """
    if mask is not None:
      missing_dims = len(scores.shape) - len(mask.shape)
      if missing_dims > 0 and self._mask_expansion_axes is None:
        # Fail early with a readable message: tf.expand_dims has no usable
        # behavior without an axis.
        raise ValueError(
            'The mask has rank %d but the scores have rank %d; '
            '`mask_expansion_axes` must be set so the mask can be expanded.' %
            (len(mask.shape), len(scores.shape)))
      for _ in range(missing_dims):
        mask = tf.expand_dims(mask, axis=self._mask_expansion_axes)

      # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
      # masked positions, this operation will create a tensor which is 0.0 for
      # positions we want to attend and a large negative value (-1.e9, or
      # tf.float16.min for float16 scores) for masked positions.
      adder = (1.0 - tf.cast(mask, scores.dtype)) * _large_compatible_negative(
          scores.dtype)

      # Since we are adding it to the raw scores before the softmax, this is
      # effectively the same as removing these entirely.
      scores += adder

    if len(self._normalization_axes) == 1:
      return tf.nn.softmax(scores, axis=self._normalization_axes[0])

    # Softmax over several axes at once: exp(x - logsumexp(x)) taken jointly
    # over all normalization axes.
    return tf.math.exp(scores - tf.math.reduce_logsumexp(
        scores, axis=self._normalization_axes, keepdims=True))

  def get_config(self):
    """Returns the layer configuration, including both axes settings."""
    config = {
        'mask_expansion_axes': self._mask_expansion_axes,
        'normalization_axes': self._normalization_axes
    }
    base_config = super().get_config()
    # Layer-specific entries are merged last, so they win on a key clash.
    return {**base_config, **config}