Commit 8d340ab0 authored by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 430772605
parent a0b548e2
@@ -22,10 +22,15 @@ from official.vision.beta.modeling.layers.nn_layers import StochasticDepth
 class TransformerEncoderBlock(modeling.layers.TransformerEncoderBlock):
   """TransformerEncoderBlock layer with stochastic depth."""
 
-  def __init__(self, *args, stochastic_depth_drop_rate=0.0, **kwargs):
+  def __init__(self,
+               *args,
+               stochastic_depth_drop_rate=0.0,
+               return_attention=False,
+               **kwargs):
     """Initializes TransformerEncoderBlock."""
     super().__init__(*args, **kwargs)
     self._stochastic_depth_drop_rate = stochastic_depth_drop_rate
+    self._return_attention = return_attention
 
   def build(self, input_shape):
     if self._stochastic_depth_drop_rate:
@@ -73,8 +78,9 @@ class TransformerEncoderBlock(modeling.layers.TransformerEncoderBlock):
 
     if key_value is None:
       key_value = input_tensor
-    attention_output = self._attention_layer(
-        query=target_tensor, value=key_value, attention_mask=attention_mask)
+    attention_output, attention_scores = self._attention_layer(
+        query=target_tensor, value=key_value, attention_mask=attention_mask,
+        return_attention_scores=True)
     attention_output = self._attention_dropout(attention_output)
 
     if self._norm_first:
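Note: the tuple unpacking above relies on the return_attention_scores flag of the underlying Keras attention layer; tf.keras.layers.MultiHeadAttention returns an (attention_output, attention_scores) pair when called with return_attention_scores=True. A minimal standalone sketch of that contract, with layer sizes and shapes made up for the example rather than taken from this commit:

import tensorflow as tf

# Sketch only: demonstrates the (output, scores) tuple returned when
# return_attention_scores=True, as relied on in the diff above.
mha = tf.keras.layers.MultiHeadAttention(num_heads=4, key_dim=16)
query = tf.random.normal([2, 8, 64])  # [batch, target_seq_len, hidden_dim]
value = tf.random.normal([2, 8, 64])  # [batch, source_seq_len, hidden_dim]

output, scores = mha(query=query, value=value, return_attention_scores=True)
print(output.shape)  # (2, 8, 64)
print(scores.shape)  # (2, 4, 8, 8) -> [batch, heads, target_len, source_len]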
@@ -95,12 +101,19 @@ class TransformerEncoderBlock(modeling.layers.TransformerEncoderBlock):
     layer_output = self._output_dropout(layer_output)
 
     if self._norm_first:
-      return source_attention_output + self._stochastic_depth(
-          layer_output, training=training)
+      if self._return_attention:
+        return source_attention_output + self._stochastic_depth(
+            layer_output, training=training), attention_scores
+      else:
+        return source_attention_output + self._stochastic_depth(
+            layer_output, training=training)
 
     # During mixed precision training, layer norm output is always fp32 for now.
     # Casts fp32 for the subsequent add.
     layer_output = tf.cast(layer_output, tf.float32)
-    return self._output_layer_norm(
-        layer_output +
-        self._stochastic_depth(attention_output, training=training))
+    if self._return_attention:
+      return self._output_layer_norm(layer_output + self._stochastic_depth(
+          attention_output, training=training)), attention_scores
+    else:
+      return self._output_layer_norm(layer_output + self._stochastic_depth(
+          attention_output, training=training))
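A usage sketch of the patched class. The module path of this subclass is not visible in the diff, so no import is shown; the constructor arguments below are inherited from the base TransformerEncoderBlock and the values are illustrative:

import tensorflow as tf

# Hypothetical usage; assumes TransformerEncoderBlock is the stochastic-depth
# subclass patched above (its module path is not shown in this diff).
block = TransformerEncoderBlock(
    num_attention_heads=4,
    inner_dim=256,
    inner_activation='gelu',
    stochastic_depth_drop_rate=0.1,
    return_attention=True)

inputs = tf.random.normal([2, 8, 64])  # [batch, seq_len, hidden_dim]
outputs, attention_scores = block(inputs, training=False)

# With return_attention=False (the default), the call returns only `outputs`,
# so existing call sites keep working unchanged.

The else branches in both return paths preserve the original single-tensor return for callers that do not request attention scores, keeping the change backward compatible.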