Commit a2800201 authored by Simon Geisler

subclass nlp_keras transformer encoder for deit

parent c0e1e3f9
-from official.vision.beta.projects.vit.modeling.layers.vit_transformer_encoder_block import TransformerEncoderBlock
\ No newline at end of file
+from official.vision.beta.projects.vit.modeling.layers.transformer_encoder_block import StochsticDepthTransformerEncoderBlock
\ No newline at end of file
@@ -16,17 +16,18 @@
 import tensorflow as tf
+from official.nlp.keras_nlp.layers import TransformerEncoderBlock
 from official.vision.beta.modeling.layers.nn_layers import StochasticDepth


 @tf.keras.utils.register_keras_serializable(package="Vision")
-class TransformerEncoderBlock(tf.keras.layers.Layer):
+class StochsticDepthTransformerEncoderBlock(TransformerEncoderBlock):
   """TransformerEncoderBlock layer.

   This layer implements the Transformer Encoder from
   "Attention Is All You Need". (https://arxiv.org/abs/1706.03762),
   which combines a `tf.keras.layers.MultiHeadAttention` layer with a
-  two-layer feedforward network. Here we ass support for stochastic depth.
+  two-layer feedforward network. Here we add support for stochastic depth.

   References:
     [Attention Is All You Need](https://arxiv.org/abs/1706.03762)
@@ -35,218 +36,35 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
   """

   def __init__(self,
-               num_attention_heads,
+               *args,
-               inner_dim,
-               inner_activation,
-               output_range=None,
-               kernel_initializer="glorot_uniform",
-               bias_initializer="zeros",
-               kernel_regularizer=None,
-               bias_regularizer=None,
-               activity_regularizer=None,
-               kernel_constraint=None,
-               bias_constraint=None,
-               use_bias=True,
-               norm_first=False,
-               norm_epsilon=1e-12,
-               output_dropout=0.0,
-               attention_dropout=0.0,
-               inner_dropout=0.0,
                stochastic_depth_drop_rate=0.0,
-               attention_initializer=None,
-               attention_axes=None,
                **kwargs):
"""Initializes `TransformerEncoderBlock`. """Initializes `TransformerEncoderBlock`.
Args: Args:
num_attention_heads: Number of attention heads. *args: positional arguments/
inner_dim: The output dimension of the first Dense layer in a two-layer
feedforward network.
inner_activation: The activation for the first Dense layer in a two-layer
feedforward network.
output_range: the sequence output range, [0, output_range) for slicing the
target sequence. `None` means the target sequence is not sliced.
kernel_initializer: Initializer for dense layer kernels.
bias_initializer: Initializer for dense layer biases.
kernel_regularizer: Regularizer for dense layer kernels.
bias_regularizer: Regularizer for dense layer biases.
activity_regularizer: Regularizer for dense layer activity.
kernel_constraint: Constraint for dense layer kernels.
bias_constraint: Constraint for dense layer kernels.
use_bias: Whether to enable use_bias in attention layer. If set False,
use_bias in attention layer is disabled.
norm_first: Whether to normalize inputs to attention and intermediate
dense layers. If set False, output of attention and intermediate dense
layers is normalized.
norm_epsilon: Epsilon value to initialize normalization layers.
output_dropout: Dropout probability for the post-attention and output
dropout.
attention_dropout: Dropout probability for within the attention layer.
inner_dropout: Dropout probability for the first Dense layer in a
two-layer feedforward network.
stochastic_depth_drop_rate: Dropout propobability for the stochastic depth stochastic_depth_drop_rate: Dropout propobability for the stochastic depth
regularization. regularization.
attention_initializer: Initializer for kernels of attention layers. If set
`None`, attention layers use kernel_initializer as initializer for
kernel.
attention_axes: axes over which the attention is applied. `None` means
attention over all axes, but batch, heads, and features.
**kwargs: keyword arguments/ **kwargs: keyword arguments/
""" """
-    super().__init__(**kwargs)
+    super().__init__(*args, **kwargs)
-    self._num_heads = num_attention_heads
-    self._inner_dim = inner_dim
-    self._inner_activation = inner_activation
-    self._attention_dropout = attention_dropout
-    self._attention_dropout_rate = attention_dropout
-    self._output_dropout = output_dropout
-    self._output_dropout_rate = output_dropout
-    self._output_range = output_range
-    self._kernel_initializer = tf.keras.initializers.get(kernel_initializer)
-    self._bias_initializer = tf.keras.initializers.get(bias_initializer)
-    self._kernel_regularizer = tf.keras.regularizers.get(kernel_regularizer)
-    self._bias_regularizer = tf.keras.regularizers.get(bias_regularizer)
-    self._activity_regularizer = tf.keras.regularizers.get(activity_regularizer)
-    self._kernel_constraint = tf.keras.constraints.get(kernel_constraint)
-    self._bias_constraint = tf.keras.constraints.get(bias_constraint)
-    self._use_bias = use_bias
-    self._norm_first = norm_first
-    self._norm_epsilon = norm_epsilon
-    self._inner_dropout = inner_dropout
     self._stochastic_depth_drop_rate = stochastic_depth_drop_rate
-    if attention_initializer:
-      self._attention_initializer = tf.keras.initializers.get(
-          attention_initializer)
-    else:
-      self._attention_initializer = self._kernel_initializer
-    self._attention_axes = attention_axes
   def build(self, input_shape):
-    if isinstance(input_shape, tf.TensorShape):
-      input_tensor_shape = input_shape
-    elif isinstance(input_shape, (list, tuple)):
-      input_tensor_shape = tf.TensorShape(input_shape[0])
-    else:
-      raise ValueError(
-          "The type of input shape argument is not supported, got: %s" %
-          type(input_shape))
-    einsum_equation = "abc,cd->abd"
-    if len(input_tensor_shape.as_list()) > 3:
-      einsum_equation = "...bc,cd->...bd"
-    hidden_size = input_tensor_shape[-1]
-    if hidden_size % self._num_heads != 0:
-      raise ValueError(
-          "The input size (%d) is not a multiple of the number of attention "
-          "heads (%d)" % (hidden_size, self._num_heads))
-    self._attention_head_size = int(hidden_size // self._num_heads)
-    common_kwargs = dict(
-        bias_initializer=self._bias_initializer,
-        kernel_regularizer=self._kernel_regularizer,
-        bias_regularizer=self._bias_regularizer,
-        activity_regularizer=self._activity_regularizer,
-        kernel_constraint=self._kernel_constraint,
-        bias_constraint=self._bias_constraint)
-    self._attention_layer = tf.keras.layers.MultiHeadAttention(
-        num_heads=self._num_heads,
-        key_dim=self._attention_head_size,
-        dropout=self._attention_dropout,
-        use_bias=self._use_bias,
-        kernel_initializer=self._attention_initializer,
-        attention_axes=self._attention_axes,
-        name="self_attention",
-        **common_kwargs)
-    self._attention_dropout = tf.keras.layers.Dropout(rate=self._output_dropout)
-    # Use float32 in layernorm for numeric stability.
-    # It is probably safe in mixed_float16, but we haven't validated this yet.
-    self._attention_layer_norm = (
-        tf.keras.layers.LayerNormalization(
-            name="self_attention_layer_norm",
-            axis=-1,
-            epsilon=self._norm_epsilon,
-            dtype=tf.float32))
-    self._intermediate_dense = tf.keras.layers.experimental.EinsumDense(
-        einsum_equation,
-        output_shape=(None, self._inner_dim),
-        bias_axes="d",
-        kernel_initializer=self._kernel_initializer,
-        name="intermediate",
-        **common_kwargs)
-    policy = tf.keras.mixed_precision.global_policy()
-    if policy.name == "mixed_bfloat16":
-      # bfloat16 causes BERT with the LAMB optimizer to not converge
-      # as well, so we use float32.
-      # TODO(b/154538392): Investigate this.
-      policy = tf.float32
-    self._intermediate_activation_layer = tf.keras.layers.Activation(
-        self._inner_activation, dtype=policy)
-    self._inner_dropout_layer = tf.keras.layers.Dropout(
-        rate=self._inner_dropout)
-    self._output_dense = tf.keras.layers.experimental.EinsumDense(
-        einsum_equation,
-        output_shape=(None, hidden_size),
-        bias_axes="d",
-        name="output",
-        kernel_initializer=self._kernel_initializer,
-        **common_kwargs)
-    self._output_dropout = tf.keras.layers.Dropout(rate=self._output_dropout)
-    # Use float32 in layernorm for numeric stability.
-    self._output_layer_norm = tf.keras.layers.LayerNormalization(
-        name="output_layer_norm",
-        axis=-1,
-        epsilon=self._norm_epsilon,
-        dtype=tf.float32)
     if self._stochastic_depth_drop_rate:
       self._stochastic_depth = StochasticDepth(
           self._stochastic_depth_drop_rate)
     else:
-      self._stochastic_depth = None
+      self._stochastic_depth = lambda x, *args, **kwargs: tf.identity(x)
-    super(TransformerEncoderBlock, self).build(input_shape)
+    super(StochsticDepthTransformerEncoderBlock, self).build(input_shape)
   def get_config(self):
     config = {
-        "num_attention_heads":
-            self._num_heads,
-        "inner_dim":
-            self._inner_dim,
-        "inner_activation":
-            self._inner_activation,
-        "output_dropout":
-            self._output_dropout_rate,
-        "attention_dropout":
-            self._attention_dropout_rate,
-        "output_range":
-            self._output_range,
-        "kernel_initializer":
-            tf.keras.initializers.serialize(self._kernel_initializer),
-        "bias_initializer":
-            tf.keras.initializers.serialize(self._bias_initializer),
-        "kernel_regularizer":
-            tf.keras.regularizers.serialize(self._kernel_regularizer),
-        "bias_regularizer":
-            tf.keras.regularizers.serialize(self._bias_regularizer),
-        "activity_regularizer":
-            tf.keras.regularizers.serialize(self._activity_regularizer),
-        "kernel_constraint":
-            tf.keras.constraints.serialize(self._kernel_constraint),
-        "bias_constraint":
-            tf.keras.constraints.serialize(self._bias_constraint),
-        "use_bias":
-            self._use_bias,
-        "norm_first":
-            self._norm_first,
-        "norm_epsilon":
-            self._norm_epsilon,
-        "inner_dropout":
-            self._inner_dropout,
-        "stochastic_depth_drop_rate":
-            self._stochastic_depth_drop_rate,
-        "attention_initializer":
-            tf.keras.initializers.serialize(self._attention_initializer),
-        "attention_axes": self._attention_axes,
+        "stochastic_depth_drop_rate": self._stochastic_depth_drop_rate
     }
-    base_config = super(TransformerEncoderBlock, self).get_config()
+    base_config = super(StochsticDepthTransformerEncoderBlock, self)\
+        .get_config()
     return dict(list(base_config.items()) + list(config.items()))

   def call(self, inputs, training=None):
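Note on the build() change above: when `stochastic_depth_drop_rate` is zero, the subclass installs a lambda that is just `tf.identity`, so `call()` can invoke `self._stochastic_depth` unconditionally instead of branching on a flag. For readers unfamiliar with the regularizer, here is a minimal, self-contained sketch of the per-sample behavior a stochastic depth layer provides; it is illustrative only, not the Model Garden `nn_layers.StochasticDepth` implementation, and the class name is made up.

```python
import tensorflow as tf


class StochasticDepthSketch(tf.keras.layers.Layer):
  """Illustrative stochastic depth: drops the whole residual branch per
  example during training and rescales by the survival probability."""

  def __init__(self, stochastic_depth_drop_rate=0.0, **kwargs):
    super().__init__(**kwargs)
    self._drop_rate = stochastic_depth_drop_rate

  def call(self, inputs, training=None):
    if not training or self._drop_rate == 0.0:
      return inputs  # Inference (or zero rate): plain identity.
    keep_prob = 1.0 - self._drop_rate
    # One Bernoulli draw per example, broadcast over tokens and features.
    batch_size = tf.shape(inputs)[0]
    random_tensor = keep_prob + tf.random.uniform(
        [batch_size] + [1] * (inputs.shape.rank - 1), dtype=inputs.dtype)
    binary_mask = tf.floor(random_tensor)
    return inputs / keep_prob * binary_mask
```

The identity lambda in build() accepts the same `(x, training=...)` calling convention, which is why the call sites below can always pass `training` through.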
@@ -274,8 +92,6 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
     else:
       input_tensor, key_value, attention_mask = (inputs, None, None)
-    with_stochastic_depth = training and self._stochastic_depth
     if self._output_range:
       if self._norm_first:
         source_tensor = input_tensor[:, 0:self._output_range, :]
@@ -301,11 +117,11 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
     if self._norm_first:
       attention_output = source_tensor + self._stochastic_depth(
-          attention_output, training=with_stochastic_depth)
+          attention_output, training=training)
     else:
       attention_output = self._attention_layer_norm(
-          target_tensor +
-          self._stochastic_depth(attention_output, training=with_stochastic_depth)
+          target_tensor
+          + self._stochastic_depth(attention_output, training=training)
       )
     if self._norm_first:
@@ -319,13 +135,13 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
     if self._norm_first:
       return source_attention_output + self._stochastic_depth(
-          layer_output, training=with_stochastic_depth)
+          layer_output, training=training)
     # During mixed precision training, layer norm output is always fp32 for now.
     # Casts fp32 for the subsequent add.
     layer_output = tf.cast(layer_output, tf.float32)
     return self._output_layer_norm(
         layer_output
-        + self._stochastic_depth(attention_output, training=with_stochastic_depth)
+        + self._stochastic_depth(attention_output, training=training)
     )
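With the old `with_stochastic_depth` flag gone, `call()` now forwards `training` straight to `self._stochastic_depth`; at inference the layer (or the identity lambda) is a no-op. A rough usage sketch, assuming the package shown in this diff is importable; the argument values are made up for illustration:

```python
import tensorflow as tf
from official.vision.beta.projects.vit.modeling.layers import (
    StochsticDepthTransformerEncoderBlock)

block = StochsticDepthTransformerEncoderBlock(
    num_attention_heads=12,
    inner_dim=3072,
    inner_activation="gelu",
    stochastic_depth_drop_rate=0.1)

x = tf.random.normal([2, 197, 768])  # [batch, tokens, hidden]
y_train = block(x, training=True)    # residual branches may be dropped
y_eval = block(x, training=False)    # deterministic identity path

# get_config() only adds stochastic_depth_drop_rate on top of the base
# TransformerEncoderBlock config, so the layer round-trips as usual.
clone = StochsticDepthTransformerEncoderBlock.from_config(block.get_config())
```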
@@ -17,10 +17,9 @@
 import tensorflow as tf

 from official.modeling import activations
-from official.nlp import keras_nlp
 from official.vision.beta.modeling.backbones import factory
 from official.vision.beta.modeling.layers import nn_layers
-from official.vision.beta.projects.vit.modeling.layers import TransformerEncoderBlock
+from official.vision.beta.projects.vit.modeling.layers import StochsticDepthTransformerEncoderBlock

 layers = tf.keras.layers
@@ -150,7 +149,7 @@ class Encoder(tf.keras.layers.Layer):
     # Set layer norm epsilons to 1e-6 to be consistent with JAX implementation.
     # https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.LayerNorm.html
     for i in range(self._num_layers):
-      encoder_layer = TransformerEncoderBlock(
+      encoder_layer = StochsticDepthTransformerEncoderBlock(
           inner_activation=activations.gelu,
           num_attention_heads=self._num_heads,
           inner_dim=self._mlp_dim,
...
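The Encoder hunk is truncated here, but the visible pattern is one encoder block per layer with GELU and the JAX-consistent layer norm epsilon of 1e-6. Below is a hedged sketch of that stacking pattern with made-up ViT-Base-like hyperparameters; the real Encoder pulls these from its own attributes, and the drop rate used here is an assumption, not something shown in the hunk:

```python
import tensorflow as tf
from official.vision.beta.projects.vit.modeling.layers import (
    StochsticDepthTransformerEncoderBlock)

# Hypothetical ViT-Base-like settings, purely for illustration.
num_layers, num_heads, hidden_dim, mlp_dim = 12, 12, 768, 3072

blocks = [
    StochsticDepthTransformerEncoderBlock(
        num_attention_heads=num_heads,
        inner_dim=mlp_dim,
        inner_activation="gelu",
        norm_epsilon=1e-6,               # matches the JAX/Flax default noted above
        stochastic_depth_drop_rate=0.1)  # assumed value; not shown in the hunk
    for _ in range(num_layers)
]

tokens = tf.random.normal([2, 197, hidden_dim])  # [batch, tokens, hidden]
for block in blocks:
  tokens = block(tokens, training=True)
```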