"tutorials/models/vscode:/vscode.git/clone" did not exist on "724aa0caf0f63a0887f2d3bd2addfd9bca7ef890"
Commit a26e4649 authored by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 430772605
parent 4585ff77
@@ -22,10 +22,15 @@ from official.vision.beta.modeling.layers.nn_layers import StochasticDepth
 class TransformerEncoderBlock(modeling.layers.TransformerEncoderBlock):
   """TransformerEncoderBlock layer with stochastic depth."""
 
-  def __init__(self, *args, stochastic_depth_drop_rate=0.0, **kwargs):
+  def __init__(self,
+               *args,
+               stochastic_depth_drop_rate=0.0,
+               return_attention=False,
+               **kwargs):
     """Initializes TransformerEncoderBlock."""
     super().__init__(*args, **kwargs)
     self._stochastic_depth_drop_rate = stochastic_depth_drop_rate
+    self._return_attention = return_attention
 
   def build(self, input_shape):
     if self._stochastic_depth_drop_rate:
@@ -73,8 +78,9 @@ class TransformerEncoderBlock(modeling.layers.TransformerEncoderBlock):
     if key_value is None:
       key_value = input_tensor
-    attention_output = self._attention_layer(
-        query=target_tensor, value=key_value, attention_mask=attention_mask)
+    attention_output, attention_scores = self._attention_layer(
+        query=target_tensor, value=key_value, attention_mask=attention_mask,
+        return_attention_scores=True)
     attention_output = self._attention_dropout(attention_output)
 
     if self._norm_first:
@@ -95,12 +101,19 @@ class TransformerEncoderBlock(modeling.layers.TransformerEncoderBlock):
     layer_output = self._output_dropout(layer_output)
 
     if self._norm_first:
-      return source_attention_output + self._stochastic_depth(
-          layer_output, training=training)
+      if self._return_attention:
+        return source_attention_output + self._stochastic_depth(
+            layer_output, training=training), attention_scores
+      else:
+        return source_attention_output + self._stochastic_depth(
+            layer_output, training=training)
 
     # During mixed precision training, layer norm output is always fp32 for now.
     # Casts fp32 for the subsequent add.
     layer_output = tf.cast(layer_output, tf.float32)
-    return self._output_layer_norm(
-        layer_output +
-        self._stochastic_depth(attention_output, training=training))
+    if self._return_attention:
+      return self._output_layer_norm(layer_output + self._stochastic_depth(
+          attention_output, training=training)), attention_scores
+    else:
+      return self._output_layer_norm(layer_output + self._stochastic_depth(
+          attention_output, training=training))
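
For context, a minimal usage sketch of the patched layer (not part of the commit): it assumes the TransformerEncoderBlock subclass defined in this diff is already in scope (the file path is not shown on this page), and that the constructor arguments below are forwarded unchanged to the parent tf-models TransformerEncoderBlock via *args/**kwargs, as the diff indicates.

import tensorflow as tf

# Assumes the patched TransformerEncoderBlock subclass from this diff is in scope.
block = TransformerEncoderBlock(
    num_attention_heads=8,            # forwarded to the parent layer via **kwargs
    inner_dim=2048,
    inner_activation='gelu',
    stochastic_depth_drop_rate=0.1,   # enables the StochasticDepth residual branch
    return_attention=True)            # new flag introduced by this commit

x = tf.random.uniform((2, 16, 512))   # (batch, sequence_length, hidden_size)
outputs, attention_scores = block(x, training=True)

With return_attention=False (the default) the call returns only the block output, so existing callers of the layer are unaffected by this change.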