# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Keras-based TransformerEncoder block layer."""
import tensorflow as tf

from official.nlp import modeling
from official.vision.beta.modeling.layers.nn_layers import StochasticDepth


class TransformerEncoderBlock(modeling.layers.TransformerEncoderBlock):
  """TransformerEncoderBlock layer with stochastic depth."""

  def __init__(self, *args, stochastic_depth_drop_rate=0.0, **kwargs):
    """Initializes TransformerEncoderBlock."""
    super().__init__(*args, **kwargs)
    self._stochastic_depth_drop_rate = stochastic_depth_drop_rate

  def build(self, input_shape):
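    # Use a real StochasticDepth layer only when a non-zero drop rate is
    # configured; otherwise fall back to an identity pass-through so call()
    # can invoke it unconditionally.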
    if self._stochastic_depth_drop_rate:
      self._stochastic_depth = StochasticDepth(self._stochastic_depth_drop_rate)
    else:
      self._stochastic_depth = lambda x, *args, **kwargs: tf.identity(x)

    super().build(input_shape)

  def get_config(self):
    config = {"stochastic_depth_drop_rate": self._stochastic_depth_drop_rate}
    base_config = super().get_config()
    return dict(list(base_config.items()) + list(config.items()))

  def call(self, inputs, training=None):
    """Transformer self-attention encoder block call."""
    if isinstance(inputs, (list, tuple)):
      if len(inputs) == 2:
        input_tensor, attention_mask = inputs
        key_value = None
      elif len(inputs) == 3:
        input_tensor, key_value, attention_mask = inputs
      else:
        raise ValueError("Unexpected inputs to %s with length %d." %
                         (self.__class__, len(inputs)))
    else:
      input_tensor, key_value, attention_mask = (inputs, None, None)

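    # When `_output_range` is set, only the first `_output_range` query
    # positions (and the matching rows of the attention mask) are kept, so the
    # block computes outputs for that prefix only.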
    if self._output_range:
      if self._norm_first:
        source_tensor = input_tensor[:, 0:self._output_range, :]
        input_tensor = self._attention_layer_norm(input_tensor)
        if key_value is not None:
          key_value = self._attention_layer_norm(key_value)
      target_tensor = input_tensor[:, 0:self._output_range, :]
      if attention_mask is not None:
        attention_mask = attention_mask[:, 0:self._output_range, :]
    else:
      if self._norm_first:
        source_tensor = input_tensor
        input_tensor = self._attention_layer_norm(input_tensor)
        if key_value is not None:
          key_value = self._attention_layer_norm(key_value)
      target_tensor = input_tensor

    if key_value is None:
      key_value = input_tensor
    attention_output = self._attention_layer(
        query=target_tensor, value=key_value, attention_mask=attention_mask)
    attention_output = self._attention_dropout(attention_output)

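    # Residual connection for the attention sub-layer. The attention branch is
    # passed through stochastic depth, so during training it may be dropped
    # entirely and only the skip connection is kept.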
    if self._norm_first:
      attention_output = source_tensor + self._stochastic_depth(
          attention_output, training=training)
    else:
      attention_output = self._attention_layer_norm(
          target_tensor +
          self._stochastic_depth(attention_output, training=training))

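    # Feed-forward (MLP) sub-layer: dense -> activation -> dropout -> dense ->
    # dropout, again combined with its input through a stochastic-depth
    # residual connection.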
    if self._norm_first:
      source_attention_output = attention_output
      attention_output = self._output_layer_norm(attention_output)
    inner_output = self._intermediate_dense(attention_output)
    inner_output = self._intermediate_activation_layer(inner_output)
    inner_output = self._inner_dropout_layer(inner_output)
    layer_output = self._output_dense(inner_output)
    layer_output = self._output_dropout(layer_output)

    if self._norm_first:
      return source_attention_output + self._stochastic_depth(
          layer_output, training=training)

    # During mixed precision training, the layer norm output is always fp32
    # for now. Cast layer_output to fp32 for the subsequent add.
    layer_output = tf.cast(layer_output, tf.float32)
    return self._output_layer_norm(
        attention_output +
        self._stochastic_depth(layer_output, training=training))