"...git@developer.sourcefind.cn:OpenDAS/mmdetection3d.git" did not exist on "dec6f622b4f8cb4857d3cffcfb8b6259dda09340"
Commit 46a373af authored by Yuexin Wu, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 471177782
parent 14f78b2c
@@ -14,7 +14,9 @@
 """Keras-based rezero-transformer block layer (Transformer with ReZero)."""
 # pylint: disable=g-classes-have-attributes
+from typing import Optional
+from absl import logging
 import gin
 import tensorflow as tf
@@ -80,6 +82,11 @@ class ReZeroTransformer(tf.keras.layers.Layer):
     util.filter_kwargs(kwargs)
     super().__init__(**kwargs)
+    # Deprecation warning.
+    if output_range is not None:
+      logging.warning("`output_range` is available as an argument for `call()`. "
+                      "The `output_range` as an `__init__` argument is deprecated.")
     self._num_heads = num_attention_heads
     self._inner_dim = inner_dim
     self._inner_activation = inner_activation
@@ -237,7 +244,7 @@ class ReZeroTransformer(tf.keras.layers.Layer):
     if not self._share_rezero:
       self._rezero_a_ffn.assign(0.)
 
-  def call(self, inputs):
+  def call(self, inputs, output_range: Optional[tf.Tensor] = None) -> tf.Tensor:
     if isinstance(inputs, (list, tuple)):
       if len(inputs) == 2:
         input_tensor, attention_mask = inputs
@@ -250,10 +257,12 @@ class ReZeroTransformer(tf.keras.layers.Layer):
     else:
       input_tensor, key_value, attention_mask = (inputs, None, None)
-    if self._output_range:
-      target_tensor = input_tensor[:, 0:self._output_range, :]
+    if output_range is None:
+      output_range = self._output_range
+    if output_range:
+      target_tensor = input_tensor[:, 0:output_range, :]
       if attention_mask is not None:
-        attention_mask = attention_mask[:, 0:self._output_range, :]
+        attention_mask = attention_mask[:, 0:output_range, :]
     else:
       target_tensor = input_tensor
@@ -270,8 +279,7 @@ class ReZeroTransformer(tf.keras.layers.Layer):
       attention_output = tf.cast(attention_output, tf.float32)
     intermediate_output = self._intermediate_dense(attention_output)
-    intermediate_output = self._inner_activation_layer(
-        intermediate_output)
+    intermediate_output = self._inner_activation_layer(intermediate_output)
     layer_output = self._output_dense(intermediate_output)
     layer_output = self._output_dropout(layer_output)
     # During mixed precision training, attention_output is from layer norm and
...
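For context, a minimal usage sketch of the new call-time `output_range` (not part of the commit; the import path, `inner_dim`, activation, and tensor shapes below are assumptions for illustration only):

```python
# Illustrative sketch only; the import path and shapes are assumed, not from the diff.
import numpy as np
import tensorflow as tf
from official.nlp.modeling.layers import rezero_transformer

layer = rezero_transformer.ReZeroTransformer(
    num_attention_heads=2,     # argument names match the layer's __init__ above
    inner_dim=128,             # illustrative value
    inner_activation='relu')   # illustrative value

data = tf.constant(np.random.rand(2, 10, 16), dtype=tf.float32)  # (batch, seq, width)
mask = tf.ones((2, 10, 10), dtype=tf.float32)                     # (batch, from_seq, to_seq)

full_output = layer([data, mask])                  # full sequence: shape (2, 10, 16)
# After this change, output_range is supplied per call instead of at construction
# time; only the first position is computed here.
first_token = layer([data, mask], output_range=1)  # shape (2, 1, 16)

# Passing output_range to __init__ still works, but now logs a deprecation warning
# and only sets the default used when call() receives no output_range.
```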
@@ -128,6 +128,9 @@ class TransformerWithReZeroLayerTest(keras_parameterized.TestCase):
     new_output_tensor = new_layer([input_data, mask_data])
     self.assertAllClose(new_output_tensor, output_tensor[:, 0:1, :])
 
+    output_tensor = test_layer([input_data, mask_data], output_range=1)
+    self.assertAllClose(new_output_tensor, output_tensor, atol=5e-5, rtol=0.003)
+
   def test_separate_qkv(self):
     test_layer = rezero_transformer.ReZeroTransformer(
         num_attention_heads=2,
...
@@ -349,7 +349,8 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
     else:
       input_tensor, key_value, attention_mask = (inputs, None, None)
-    output_range = output_range or self._output_range
+    if output_range is None:
+      output_range = self._output_range
     if output_range:
       if self._norm_first:
         source_tensor = input_tensor[:, 0:output_range, :]
...
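One behavioral difference worth noting (an observation about Python semantics, not a rationale stated in the commit): the previous `output_range or self._output_range` fallback in `TransformerEncoderBlock` treats any falsy value the same as "not provided", whereas the explicit `is None` check only falls back when no value was passed:

```python
# Plain-Python illustration of the fallback change in TransformerEncoderBlock.
def resolve_with_or(output_range, default):
  # Old form: a caller-supplied 0 is falsy, so it silently becomes the default.
  return output_range or default

def resolve_with_none_check(output_range, default):
  # New form: only an omitted value (None) falls back to the default.
  if output_range is None:
    output_range = default
  return output_range

assert resolve_with_or(0, 4) == 4             # caller's 0 is lost
assert resolve_with_none_check(0, 4) == 0     # caller's 0 is preserved
assert resolve_with_none_check(None, 4) == 4  # omitted value still falls back
```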