Commit 0dadbbc8 authored by Frederick Liu's avatar Frederick Liu Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 420497751
parent 993dbf54
...@@ -21,6 +21,7 @@ from official.vision.beta.modeling.backbones import factory ...@@ -21,6 +21,7 @@ from official.vision.beta.modeling.backbones import factory
from official.vision.beta.modeling.layers import nn_layers from official.vision.beta.modeling.layers import nn_layers
from official.vision.beta.projects.vit.modeling import nn_blocks from official.vision.beta.projects.vit.modeling import nn_blocks
layers = tf.keras.layers layers = tf.keras.layers
VIT_SPECS = { VIT_SPECS = {
...@@ -121,6 +122,7 @@ class Encoder(tf.keras.layers.Layer): ...@@ -121,6 +122,7 @@ class Encoder(tf.keras.layers.Layer):
inputs_positions=None, inputs_positions=None,
init_stochastic_depth_rate=0.0, init_stochastic_depth_rate=0.0,
kernel_initializer='glorot_uniform', kernel_initializer='glorot_uniform',
add_pos_embed=True,
**kwargs): **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self._num_layers = num_layers self._num_layers = num_layers
...@@ -132,11 +134,13 @@ class Encoder(tf.keras.layers.Layer): ...@@ -132,11 +134,13 @@ class Encoder(tf.keras.layers.Layer):
self._inputs_positions = inputs_positions self._inputs_positions = inputs_positions
self._init_stochastic_depth_rate = init_stochastic_depth_rate self._init_stochastic_depth_rate = init_stochastic_depth_rate
self._kernel_initializer = kernel_initializer self._kernel_initializer = kernel_initializer
self._add_pos_embed = add_pos_embed
def build(self, input_shape): def build(self, input_shape):
self._pos_embed = AddPositionEmbs( if self._add_pos_embed:
posemb_init=tf.keras.initializers.RandomNormal(stddev=0.02), self._pos_embed = AddPositionEmbs(
name='posembed_input') posemb_init=tf.keras.initializers.RandomNormal(stddev=0.02),
name='posembed_input')
self._dropout = layers.Dropout(rate=self._dropout_rate) self._dropout = layers.Dropout(rate=self._dropout_rate)
self._encoder_layers = [] self._encoder_layers = []
...@@ -160,7 +164,9 @@ class Encoder(tf.keras.layers.Layer): ...@@ -160,7 +164,9 @@ class Encoder(tf.keras.layers.Layer):
super().build(input_shape) super().build(input_shape)
def call(self, inputs, training=None): def call(self, inputs, training=None):
x = self._pos_embed(inputs, inputs_positions=self._inputs_positions) x = inputs
if self._add_pos_embed:
x = self._pos_embed(x, inputs_positions=self._inputs_positions)
x = self._dropout(x, training=training) x = self._dropout(x, training=training)
for encoder_layer in self._encoder_layers: for encoder_layer in self._encoder_layers:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment