"ios/git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "11bd2eaa6d6976129836b329b01d1300babddcc9"
Commit b1aa44d9 authored by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 368260712
parent 782a0299
@@ -249,7 +249,7 @@ class TransformerScaffold(tf.keras.layers.Layer):
     base_config = super(TransformerScaffold, self).get_config()
     return dict(list(base_config.items()) + list(config.items()))

-  def call(self, inputs):
+  def call(self, inputs, training=None):
     if isinstance(inputs, (list, tuple)) and len(inputs) == 2:
       input_tensor, attention_mask = inputs
     else:
@@ -257,27 +257,31 @@ class TransformerScaffold(tf.keras.layers.Layer):

     if self._norm_first:
       source_tensor = input_tensor
-      input_tensor = self._attention_layer_norm(input_tensor)
+      input_tensor = self._attention_layer_norm(input_tensor, training=training)

     attention_output = self._attention_layer(
-        query=input_tensor, value=input_tensor, attention_mask=attention_mask)
-    attention_output = self._attention_dropout(attention_output)
+        query=input_tensor, value=input_tensor, attention_mask=attention_mask,
+        training=training)
+    attention_output = self._attention_dropout(attention_output,
+                                               training=training)

     if self._norm_first:
       attention_output = source_tensor + attention_output
     else:
       attention_output = self._attention_layer_norm(input_tensor +
-                                                    attention_output)
+                                                    attention_output,
+                                                    training=training)
     if self._norm_first:
       source_attention_output = attention_output
-      attention_output = self._output_layer_norm(attention_output)
+      attention_output = self._output_layer_norm(attention_output,
+                                                 training=training)

     if self._feedforward_block is None:
       intermediate_output = self._intermediate_dense(attention_output)
       intermediate_output = self._intermediate_activation_layer(
           intermediate_output)
-      layer_output = self._output_dense(intermediate_output)
-      layer_output = self._output_dropout(layer_output)
+      layer_output = self._output_dense(intermediate_output, training=training)
+      layer_output = self._output_dropout(layer_output, training=training)
       # During mixed precision training, attention_output is from layer norm
       # and is always fp32 for now. Cast layer_output to fp32 for the subsequent
       # add.
@@ -285,14 +289,17 @@ class TransformerScaffold(tf.keras.layers.Layer):
       if self._norm_first:
         layer_output = source_attention_output + layer_output
       else:
-        layer_output = self._output_layer_norm(layer_output + attention_output)
+        layer_output = self._output_layer_norm(layer_output + attention_output,
+                                               training=training)
     else:
       if self._norm_first:
         # if norm_first, assume the feedforward block will not apply layer norm
-        layer_output = self._feedforward_block(attention_output)
+        layer_output = self._feedforward_block(attention_output,
+                                               training=training)
         layer_output += source_attention_output
       else:
         # if not norm_first, assume that the feedforwad does apply layer norm
-        layer_output = self._feedforward_block(attention_output)
+        layer_output = self._feedforward_block(attention_output,
+                                               training=training)

     return layer_output
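The hunks above thread the Keras `training` flag from `TransformerScaffold.call` into its sublayers (attention, dropout, layer normalization, and the optional feedforward block), so the flag is passed explicitly rather than relying on Keras' implicit propagation of the learning phase. A minimal sketch of the same pattern in a stand-alone custom layer (`TinyBlock` and its sublayer names are illustrative, not part of this commit):

import tensorflow as tf

class TinyBlock(tf.keras.layers.Layer):
  """Toy layer that forwards `training` to its sublayers, as in the diff above."""

  def __init__(self, units, dropout_rate=0.1, **kwargs):
    super().__init__(**kwargs)
    self._dense = tf.keras.layers.Dense(units, activation="relu")
    self._dropout = tf.keras.layers.Dropout(dropout_rate)
    self._norm = tf.keras.layers.LayerNormalization()

  def call(self, inputs, training=None):
    x = self._dense(inputs)
    # Forwarding `training` makes dropout active only during training.
    x = self._dropout(x, training=training)
    return self._norm(x)

block = TinyBlock(units=16)
y_train = block(tf.random.normal([2, 8]), training=True)   # dropout applied
y_eval = block(tf.random.normal([2, 8]), training=False)   # dropout disabled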
@@ -28,7 +28,7 @@ from official.nlp.projects.bigbird import recomputing_dropout
 class RecomputeTransformerLayer(layers.TransformerScaffold):
   """Transformer layer that recomputes the forward pass during backpropagation."""

-  def call(self, inputs):
+  def call(self, inputs, training=None):
     emb, mask = inputs

     def f(*args):
       # recompute_grad can only handle tensor inputs. so we enumerate the
@@ -39,7 +39,8 @@ class RecomputeTransformerLayer(layers.TransformerScaffold):
       # args[3]: mask[2] = encoder_to_mask
       # args[4]: mask[3] = blocked_encoder_mask
       x = super(RecomputeTransformerLayer,
-                self).call([args[0], [args[1], args[2], args[3], args[4]]])
+                self).call([args[0], [args[1], args[2], args[3], args[4]]],
+                           training=training)
       return x

     f = recompute_grad.recompute_grad(f)
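`RecomputeTransformerLayer` wraps the parent `TransformerScaffold.call` in `recompute_grad`, so activations are recomputed during the backward pass instead of being stored; the commit also passes the `training` flag through that wrapper explicitly, since the wrapped plain function sits outside Keras' automatic `training` propagation. A self-contained sketch of the same gradient-checkpointing idea using stock `tf.recompute_grad` (the layers and shapes are illustrative; the model-garden `recompute_grad` module used above is its own variant):

import tensorflow as tf

# Build the sublayers up front so no variables are created inside the
# recomputed function.
dense1 = tf.keras.layers.Dense(128, activation="relu")
dense2 = tf.keras.layers.Dense(128, activation="relu")
dense1.build([None, 128])
dense2.build([None, 128])

@tf.recompute_grad
def block(x):
  # Intermediate activations are discarded after the forward pass and
  # recomputed when gradients are taken.
  return dense2(dense1(x))

x = tf.random.normal([4, 128])
with tf.GradientTape() as tape:
  loss = tf.reduce_sum(block(x))
grads = tape.gradient(
    loss, dense1.trainable_variables + dense2.trainable_variables)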