# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Keras-based TransformerEncoder block layer."""
import tensorflow as tf

from official.nlp import modeling
from official.vision.beta.modeling.layers.nn_layers import StochasticDepth


class TransformerEncoderBlock(modeling.layers.TransformerEncoderBlock):
  """TransformerEncoderBlock layer with stochastic depth."""

  def __init__(self,
               *args,
               stochastic_depth_drop_rate=0.0,
               return_attention=False,
               **kwargs):
    """Initializes TransformerEncoderBlock."""
    super().__init__(*args, **kwargs)
    self._stochastic_depth_drop_rate = stochastic_depth_drop_rate
    self._return_attention = return_attention

  def build(self, input_shape):
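    # Use a StochasticDepth layer to randomly drop residual branches during
    # training; with a zero drop rate, fall back to an identity mapping.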
    if self._stochastic_depth_drop_rate:
      self._stochastic_depth = StochasticDepth(
          self._stochastic_depth_drop_rate)
    else:
      self._stochastic_depth = lambda x, *args, **kwargs: tf.identity(x)

    super().build(input_shape)

  def get_config(self):
    config = {"stochastic_depth_drop_rate": self._stochastic_depth_drop_rate}
    base_config = super().get_config()
    return dict(list(base_config.items()) + list(config.items()))

  def call(self, inputs, training=None):
    """Transformer self-attention encoder block call."""
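    # Inputs may be a single tensor, a (query, attention_mask) pair, or a
    # (query, key_value, attention_mask) triple.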
    if isinstance(inputs, (list, tuple)):
      if len(inputs) == 2:
        input_tensor, attention_mask = inputs
        key_value = None
      elif len(inputs) == 3:
        input_tensor, key_value, attention_mask = inputs
      else:
        raise ValueError("Unexpected inputs to %s with length at %d" %
                         (self.__class__, len(inputs)))
    else:
      input_tensor, key_value, attention_mask = (inputs, None, None)

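    # When `output_range` is set, only the first `output_range` positions are
    # used as the attention query, and the mask is sliced to match.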
    if self._output_range:
      if self._norm_first:
        source_tensor = input_tensor[:, 0:self._output_range, :]
        input_tensor = self._attention_layer_norm(input_tensor)
        if key_value is not None:
          key_value = self._attention_layer_norm(key_value)
      target_tensor = input_tensor[:, 0:self._output_range, :]
      if attention_mask is not None:
        attention_mask = attention_mask[:, 0:self._output_range, :]
    else:
      if self._norm_first:
        source_tensor = input_tensor
        input_tensor = self._attention_layer_norm(input_tensor)
        if key_value is not None:
          key_value = self._attention_layer_norm(key_value)
      target_tensor = input_tensor

    if key_value is None:
      key_value = input_tensor
    attention_output, attention_scores = self._attention_layer(
        query=target_tensor, value=key_value, attention_mask=attention_mask,
        return_attention_scores=True)
    attention_output = self._attention_dropout(attention_output)

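    # Residual connection around the attention sub-layer; stochastic depth may
    # randomly drop the attention branch during training.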
    if self._norm_first:
      attention_output = source_tensor + self._stochastic_depth(
          attention_output, training=training)
    else:
      attention_output = self._attention_layer_norm(
          target_tensor +
          self._stochastic_depth(attention_output, training=training))

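    # Feed-forward (MLP) sub-layer.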
    if self._norm_first:
      source_attention_output = attention_output
      attention_output = self._output_layer_norm(attention_output)
    inner_output = self._intermediate_dense(attention_output)
    inner_output = self._intermediate_activation_layer(inner_output)
    inner_output = self._inner_dropout_layer(inner_output)
    layer_output = self._output_dense(inner_output)
    layer_output = self._output_dropout(layer_output)

    if self._norm_first:
      if self._return_attention:
        return source_attention_output + self._stochastic_depth(
            layer_output, training=training), attention_scores
      else:
        return source_attention_output + self._stochastic_depth(
            layer_output, training=training)

    # During mixed precision training, the layer norm output is always fp32
    # for now. Cast to fp32 for the subsequent add.
    layer_output = tf.cast(layer_output, tf.float32)
    # Post-norm path: apply stochastic depth to the feed-forward branch (as in
    # the pre-norm path above) so the block reduces to a layer norm of its
    # input when the branch is dropped.
    if self._return_attention:
      return self._output_layer_norm(attention_output + self._stochastic_depth(
          layer_output, training=training)), attention_scores
    else:
      return self._output_layer_norm(attention_output + self._stochastic_depth(
          layer_output, training=training))
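

if __name__ == "__main__":
  # Minimal usage sketch / smoke test. The head count, dimensions, and drop
  # rate below are illustrative values picked for this example, not defaults
  # defined by this module.
  block = TransformerEncoderBlock(
      num_attention_heads=8,
      inner_dim=512,
      inner_activation="gelu",
      stochastic_depth_drop_rate=0.1,
      return_attention=True)
  features = tf.ones([2, 16, 64])  # [batch, sequence_length, hidden_size]
  outputs, scores = block(features, training=True)
  print(outputs.shape)  # (2, 16, 64)
  print(scores.shape)   # (2, 8, 16, 16): [batch, heads, query_len, key_len]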