Commit 6d458bcc authored by Scott Zhu, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 446219773
parent 1fd7aaaf
@@ -272,3 +272,14 @@ def cross_replica_concat(value, axis, name="cross_replica_concat"):
   if value.shape.as_list()[0] is None:
     raise RuntimeError(f"{value} has unknown batch.")
   return context.all_gather(value, axis=axis)
+
+
+def clone_initializer(initializer):
+  # Keras initializers are stateless, which means reusing the same
+  # initializer instance produces the same init values when the shapes match.
+  if isinstance(initializer, tf.keras.initializers.Initializer):
+    return initializer.__class__.from_config(initializer.get_config())
+  # When the input is a string/dict or another serialized config, the caller
+  # will create a new Keras Initializer instance from it, so there is nothing
+  # to do here.
+  return initializer
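
For orientation, here is a minimal sketch of how the new helper is intended to be used by the layer changes below; the layer and variable names are illustrative, not part of the commit:

```python
import tensorflow as tf

from official.modeling import tf_utils

# One initializer instance shared through a config. Per the comment in
# clone_initializer, a stateless initializer that is reused directly would
# hand every layer the same initial values for a given shape.
shared_init = tf.keras.initializers.TruncatedNormal(stddev=0.02)

# Each layer gets its own clone, built from the same config, so the layers
# no longer share one initializer instance (and, with no fixed seed in the
# config, draw their own initial values).
dense_a = tf.keras.layers.Dense(
    16, kernel_initializer=tf_utils.clone_initializer(shared_init))
dense_b = tf.keras.layers.Dense(
    16, kernel_initializer=tf_utils.clone_initializer(shared_init))
```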
@@ -18,6 +18,8 @@ from typing import Optional
 
 import tensorflow as tf
 
+from official.modeling import tf_utils
+
 
 class BlockDiagFeedforward(tf.keras.layers.Layer):
   """Block diagonal feedforward layer.
@@ -80,8 +82,6 @@ class BlockDiagFeedforward(tf.keras.layers.Layer):
     hidden_size = input_shape.as_list()[-1]
     common_kwargs = dict(
-        kernel_initializer=self._kernel_initializer,
-        bias_initializer=self._bias_initializer,
         kernel_regularizer=self._kernel_regularizer,
         bias_regularizer=self._bias_regularizer,
         activity_regularizer=self._activity_regularizer,
@@ -94,6 +94,8 @@ class BlockDiagFeedforward(tf.keras.layers.Layer):
                       self._intermediate_size // self._num_blocks),
         bias_axes="de",
         name="intermediate",
+        kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
+        bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
         **common_kwargs)
     policy = tf.keras.mixed_precision.global_policy()
@@ -110,6 +112,8 @@ class BlockDiagFeedforward(tf.keras.layers.Layer):
                       hidden_size // self._num_blocks),
         bias_axes="do",
         name="output",
+        kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
+        bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
         **common_kwargs)
 
     if self._apply_mixing:
@@ -118,6 +122,9 @@ class BlockDiagFeedforward(tf.keras.layers.Layer):
           output_shape=(None, self._num_blocks,
                         hidden_size // self._num_blocks),
           name="output_mixing",
+          kernel_initializer=tf_utils.clone_initializer(
+              self._kernel_initializer),
+          bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
          **common_kwargs)
      self._output_reshape = tf.keras.layers.Reshape((-1, hidden_size))
...
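The hunks above show the pattern repeated throughout the rest of this commit: initializers are dropped from the shared `common_kwargs` dict and passed per sub-layer through `tf_utils.clone_initializer`, while regularizers and constraints remain shared. A hypothetical miniature of that refactor (shapes and names are illustrative, not from any file in this commit):

```python
import tensorflow as tf

from official.modeling import tf_utils


def build_projections(kernel_initializer, bias_initializer, kernel_regularizer):
  # Regularizers are plain config and can be shared; initializers are cloned
  # so that each EinsumDense receives its own instance.
  common_kwargs = dict(kernel_regularizer=kernel_regularizer)
  query = tf.keras.layers.experimental.EinsumDense(
      "abc,cd->abd",
      output_shape=(None, 64),
      bias_axes="d",
      kernel_initializer=tf_utils.clone_initializer(kernel_initializer),
      bias_initializer=tf_utils.clone_initializer(bias_initializer),
      **common_kwargs)
  key = tf.keras.layers.experimental.EinsumDense(
      "abc,cd->abd",
      output_shape=(None, 64),
      bias_axes="d",
      kernel_initializer=tf_utils.clone_initializer(kernel_initializer),
      bias_initializer=tf_utils.clone_initializer(bias_initializer),
      **common_kwargs)
  return query, key
```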
@@ -57,12 +57,14 @@ class ClassificationHead(tf.keras.layers.Layer):
     self.dense = tf.keras.layers.Dense(
         units=self.inner_dim,
         activation=self.activation,
-        kernel_initializer=self.initializer,
+        kernel_initializer=tf_utils.clone_initializer(self.initializer),
         name="pooler_dense")
     self.dropout = tf.keras.layers.Dropout(rate=self.dropout_rate)
     self.out_proj = tf.keras.layers.Dense(
-        units=num_classes, kernel_initializer=self.initializer, name="logits")
+        units=num_classes,
+        kernel_initializer=tf_utils.clone_initializer(self.initializer),
+        name="logits")
 
   def call(self, features: tf.Tensor, only_project: bool = False):
     """Implements call().
@@ -146,14 +148,15 @@ class MultiClsHeads(tf.keras.layers.Layer):
     self.dense = tf.keras.layers.Dense(
         units=inner_dim,
         activation=self.activation,
-        kernel_initializer=self.initializer,
+        kernel_initializer=tf_utils.clone_initializer(self.initializer),
         name="pooler_dense")
     self.dropout = tf.keras.layers.Dropout(rate=self.dropout_rate)
     self.out_projs = []
     for name, num_classes in cls_list:
       self.out_projs.append(
           tf.keras.layers.Dense(
-              units=num_classes, kernel_initializer=self.initializer,
+              units=num_classes,
+              kernel_initializer=tf_utils.clone_initializer(self.initializer),
               name=name))
 
   def call(self, features: tf.Tensor, only_project: bool = False):
@@ -277,7 +280,7 @@ class GaussianProcessClassificationHead(ClassificationHead):
     if use_gp_layer:
       self.out_proj = gaussian_process.RandomFeatureGaussianProcess(
           self.num_classes,
-          kernel_initializer=self.initializer,
+          kernel_initializer=tf_utils.clone_initializer(self.initializer),
          name="logits",
          **self.gp_layer_kwargs)
...
@@ -18,6 +18,8 @@
 import gin
 import tensorflow as tf
 
+from official.modeling import tf_utils
+
 
 @tf.keras.utils.register_keras_serializable(package="Text")
 @gin.configurable
@@ -95,8 +97,6 @@ class GatedFeedforward(tf.keras.layers.Layer):
     hidden_size = input_shape.as_list()[-1]
     common_kwargs = dict(
-        kernel_initializer=self._kernel_initializer,
-        bias_initializer=self._bias_initializer,
         kernel_regularizer=self._kernel_regularizer,
         bias_regularizer=self._bias_regularizer,
         activity_regularizer=self._activity_regularizer,
@@ -121,6 +121,10 @@ class GatedFeedforward(tf.keras.layers.Layer):
               output_shape=(None, self._intermediate_size),
               bias_axes="d",
              name="intermediate_%d" % i,
+              kernel_initializer=tf_utils.clone_initializer(
+                  self._kernel_initializer),
+              bias_initializer=tf_utils.clone_initializer(
+                  self._bias_initializer),
              **common_kwargs))
       self._intermediate_activation_layers.append(
           tf.keras.layers.Activation(
@@ -132,6 +136,10 @@ class GatedFeedforward(tf.keras.layers.Layer):
               output_shape=(None, self._intermediate_size),
              bias_axes="d",
              name="gate_%d" % i,
+              kernel_initializer=tf_utils.clone_initializer(
+                  self._kernel_initializer),
+              bias_initializer=tf_utils.clone_initializer(
+                  self._bias_initializer),
              **common_kwargs))
       self._output_dense.append(
           tf.keras.layers.experimental.EinsumDense(
@@ -139,6 +147,10 @@ class GatedFeedforward(tf.keras.layers.Layer):
               output_shape=(None, hidden_size),
              bias_axes="d",
              name="output_%d" % i,
+              kernel_initializer=tf_utils.clone_initializer(
+                  self._kernel_initializer),
+              bias_initializer=tf_utils.clone_initializer(
+                  self._bias_initializer),
              **common_kwargs))
       self._output_dropout.append(tf.keras.layers.Dropout(rate=self._dropout))
       # Use float32 in layernorm for numeric stability.
...
@@ -15,6 +15,8 @@
 """MobileBERT embedding and transformer layers."""
 import tensorflow as tf
 
+from official.modeling import tf_utils
+
 from official.nlp.modeling.layers import on_device_embedding
 from official.nlp.modeling.layers import position_embedding
@@ -109,21 +111,21 @@ class MobileBertEmbedding(tf.keras.layers.Layer):
     self.word_embedding = on_device_embedding.OnDeviceEmbedding(
         self.word_vocab_size,
         self.word_embed_size,
-        initializer=initializer,
+        initializer=tf_utils.clone_initializer(self.initializer),
         name='word_embedding')
     self.type_embedding = on_device_embedding.OnDeviceEmbedding(
         self.type_vocab_size,
         self.output_embed_size,
-        initializer=initializer,
+        initializer=tf_utils.clone_initializer(self.initializer),
         name='type_embedding')
     self.pos_embedding = position_embedding.PositionEmbedding(
         max_length=max_sequence_length,
-        initializer=initializer,
+        initializer=tf_utils.clone_initializer(self.initializer),
         name='position_embedding')
     self.word_embedding_proj = tf.keras.layers.experimental.EinsumDense(
         'abc,cd->abd',
         output_shape=[None, self.output_embed_size],
-        kernel_initializer=initializer,
+        kernel_initializer=tf_utils.clone_initializer(self.initializer),
         bias_axes='d',
         name='embedding_projection')
     self.layer_norm = _get_norm_layer(normalization_type, 'embedding_norm')
@@ -246,7 +248,7 @@ class MobileBertTransformer(tf.keras.layers.Layer):
         'abc,cd->abd',
         output_shape=[None, self.intra_bottleneck_size],
         bias_axes='d',
-        kernel_initializer=initializer,
+        kernel_initializer=tf_utils.clone_initializer(self.initializer),
         name='bottleneck_input/dense')
     layer_norm = _get_norm_layer(self.normalization_type,
                                  name='bottleneck_input/norm')
@@ -258,7 +260,7 @@ class MobileBertTransformer(tf.keras.layers.Layer):
         'abc,cd->abd',
         output_shape=[None, self.intra_bottleneck_size],
         bias_axes='d',
-        kernel_initializer=initializer,
+        kernel_initializer=tf_utils.clone_initializer(self.initializer),
         name='kq_shared_bottleneck/dense')
     layer_norm = _get_norm_layer(self.normalization_type,
                                  name='kq_shared_bottleneck/norm')
@@ -272,7 +274,7 @@ class MobileBertTransformer(tf.keras.layers.Layer):
         value_dim=attention_head_size,
         dropout=self.attention_probs_dropout_prob,
         output_shape=self.intra_bottleneck_size,
-        kernel_initializer=initializer,
+        kernel_initializer=tf_utils.clone_initializer(self.initializer),
         name='attention')
     layer_norm = _get_norm_layer(self.normalization_type,
                                  name='attention/norm')
@@ -289,14 +291,14 @@ class MobileBertTransformer(tf.keras.layers.Layer):
         activation=self.intermediate_act_fn,
         output_shape=[None, self.intermediate_size],
         bias_axes='d',
-        kernel_initializer=initializer,
+        kernel_initializer=tf_utils.clone_initializer(self.initializer),
         name=layer_name)
     layer_name = layer_prefix + '/output_dense'
     output_layer = tf.keras.layers.experimental.EinsumDense(
         'abc,cd->abd',
         output_shape=[None, self.intra_bottleneck_size],
         bias_axes='d',
-        kernel_initializer=initializer,
+        kernel_initializer=tf_utils.clone_initializer(self.initializer),
         name=layer_name)
     layer_name = layer_prefix + '/norm'
     layer_norm = _get_norm_layer(self.normalization_type,
@@ -311,7 +313,7 @@ class MobileBertTransformer(tf.keras.layers.Layer):
         output_shape=[None, self.hidden_size],
         activation=None,
         bias_axes='d',
-        kernel_initializer=initializer,
+        kernel_initializer=tf_utils.clone_initializer(self.initializer),
         name='bottleneck_output/dense')
     dropout_layer = tf.keras.layers.Dropout(
         self.hidden_dropout_prob,
@@ -474,14 +476,14 @@ class MobileBertMaskedLM(tf.keras.layers.Layer):
     self.dense = tf.keras.layers.Dense(
         hidden_size,
         activation=self.activation,
-        kernel_initializer=self.initializer,
+        kernel_initializer=tf_utils.clone_initializer(self.initializer),
         name='transform/dense')
 
     if hidden_size > embedding_width:
       self.extra_output_weights = self.add_weight(
           'extra_output_weights',
           shape=(self._vocab_size, hidden_size - embedding_width),
-          initializer=self.initializer,
+          initializer=tf_utils.clone_initializer(self.initializer),
          trainable=True)
     elif hidden_size == embedding_width:
       self.extra_output_weights = None
...
@@ -18,6 +18,7 @@
 import math
 
 import tensorflow as tf
+
 from official.modeling import tf_utils
 from official.nlp.modeling.layers import masked_softmax
@@ -60,8 +61,6 @@ class VotingAttention(tf.keras.layers.Layer):
   def build(self, unused_input_shapes):
     common_kwargs = dict(
-        kernel_initializer=self._kernel_initializer,
-        bias_initializer=self._bias_initializer,
         kernel_regularizer=self._kernel_regularizer,
         bias_regularizer=self._bias_regularizer,
         activity_regularizer=self._activity_regularizer,
@@ -72,12 +71,16 @@ class VotingAttention(tf.keras.layers.Layer):
         output_shape=(None, self._num_heads, self._head_size),
         bias_axes="NH",
         name="query",
+        kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
+        bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
         **common_kwargs)
     self._key_dense = tf.keras.layers.experimental.EinsumDense(
         "BAE,ENH->BANH",
         output_shape=(None, self._num_heads, self._head_size),
         bias_axes="NH",
         name="key",
+        kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
+        bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
         **common_kwargs)
     super(VotingAttention, self).build(unused_input_shapes)
...
@@ -22,6 +22,8 @@ import string
 import numpy as np
 import tensorflow as tf
 
+from official.modeling import tf_utils
+
 
 _CHR_IDX = string.ascii_lowercase
@@ -347,8 +349,6 @@ class ReuseMultiHeadAttention(tf.keras.layers.Layer):
       self._key_shape = tf.TensorShape(key)
     common_kwargs = dict(
-        kernel_initializer=self._kernel_initializer,
-        bias_initializer=self._bias_initializer,
         kernel_regularizer=self._kernel_regularizer,
         bias_regularizer=self._bias_regularizer,
         activity_regularizer=self._activity_regularizer,
@@ -368,6 +368,10 @@ class ReuseMultiHeadAttention(tf.keras.layers.Layer):
             self._num_heads - self._reuse_heads, self._key_dim]),
         bias_axes=bias_axes if self._use_bias else None,
         name="query",
+        kernel_initializer=tf_utils.clone_initializer(
+            self._kernel_initializer),
+        bias_initializer=tf_utils.clone_initializer(
+            self._bias_initializer),
         **common_kwargs)
     einsum_equation, bias_axes, output_rank = _build_proj_equation(
         self._key_shape.rank - 1, bound_dims=1, output_dims=2)
@@ -377,6 +381,10 @@ class ReuseMultiHeadAttention(tf.keras.layers.Layer):
             self._num_heads - self._reuse_heads, self._key_dim]),
         bias_axes=bias_axes if self._use_bias else None,
         name="key",
+        kernel_initializer=tf_utils.clone_initializer(
+            self._kernel_initializer),
+        bias_initializer=tf_utils.clone_initializer(
+            self._bias_initializer),
         **common_kwargs)
     einsum_equation, bias_axes, output_rank = _build_proj_equation(
         self._value_shape.rank - 1, bound_dims=1, output_dims=2)
@@ -389,6 +397,10 @@ class ReuseMultiHeadAttention(tf.keras.layers.Layer):
              output_rank - 1, [self._reuse_heads, self._value_dim]),
          bias_axes=bias_axes if self._use_bias else None,
          name="value_reuse",
+          kernel_initializer=tf_utils.clone_initializer(
+              self._kernel_initializer),
+          bias_initializer=tf_utils.clone_initializer(
+              self._bias_initializer),
          **common_kwargs))
     if self._reuse_heads < self._num_heads:
       self._value_dense.append(tf.keras.layers.experimental.EinsumDense(
@@ -397,6 +409,10 @@ class ReuseMultiHeadAttention(tf.keras.layers.Layer):
              self._num_heads - self._reuse_heads, self._value_dim]),
          bias_axes=bias_axes if self._use_bias else None,
          name="value_new",
+          kernel_initializer=tf_utils.clone_initializer(
+              self._kernel_initializer),
+          bias_initializer=tf_utils.clone_initializer(
+              self._bias_initializer),
          **common_kwargs))
 
     # Builds the attention computations for multi-head dot product attention.
@@ -439,6 +455,10 @@ class ReuseMultiHeadAttention(tf.keras.layers.Layer):
         output_shape=_get_output_shape(output_rank - 1, output_shape),
         bias_axes=bias_axes if (use_bias and self._use_bias) else None,
         name=name,
+        kernel_initializer=tf_utils.clone_initializer(
+            self._kernel_initializer),
+        bias_initializer=tf_utils.clone_initializer(
+            self._bias_initializer),
         **common_kwargs)
 
   def _build_attention(self, rank):
...
@@ -14,6 +14,8 @@
 """Keras-based TransformerEncoder block layer."""
 
 import tensorflow as tf
 
+from official.modeling import tf_utils
+
 from official.nlp.modeling.layers import reuse_attention as attention
@@ -131,7 +133,8 @@ class ReuseTransformer(tf.keras.layers.Layer):
       self._attention_initializer = tf.keras.initializers.get(
          attention_initializer)
     else:
-      self._attention_initializer = self._kernel_initializer
+      self._attention_initializer = tf_utils.clone_initializer(
+          self._kernel_initializer)
     self._attention_axes = attention_axes
 
   def build(self, input_shape):
@@ -156,7 +159,6 @@ class ReuseTransformer(tf.keras.layers.Layer):
     else:
       self._attention_head_size = self._head_size
     common_kwargs = dict(
-        bias_initializer=self._bias_initializer,
         kernel_regularizer=self._kernel_regularizer,
         bias_regularizer=self._bias_regularizer,
         activity_regularizer=self._activity_regularizer,
@@ -168,6 +170,7 @@ class ReuseTransformer(tf.keras.layers.Layer):
         dropout=self._attention_dropout,
         use_bias=self._use_bias,
         kernel_initializer=self._attention_initializer,
+        bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
         attention_axes=self._attention_axes,
         reuse_attention=self._reuse_attention,
         use_relative_pe=self._use_relative_pe,
@@ -188,7 +191,8 @@ class ReuseTransformer(tf.keras.layers.Layer):
         einsum_equation,
         output_shape=(None, self._inner_dim),
         bias_axes="d",
-        kernel_initializer=self._kernel_initializer,
+        kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
+        bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
         name="intermediate",
         **common_kwargs)
     policy = tf.keras.mixed_precision.global_policy()
@@ -206,7 +210,8 @@ class ReuseTransformer(tf.keras.layers.Layer):
         output_shape=(None, hidden_size),
         bias_axes="d",
         name="output",
-        kernel_initializer=self._kernel_initializer,
+        kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
+        bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
         **common_kwargs)
     self._output_dropout = tf.keras.layers.Dropout(rate=self._output_dropout)
     # Use float32 in layernorm for numeric stability.
...
@@ -18,6 +18,7 @@
 import gin
 import tensorflow as tf
 
+from official.modeling import tf_utils
 from official.nlp.modeling.layers import util
@@ -121,8 +122,6 @@ class ReZeroTransformer(tf.keras.layers.Layer):
          "heads (%d)" % (hidden_size, self._num_heads))
     self._attention_head_size = int(hidden_size // self._num_heads)
     common_kwargs = dict(
-        kernel_initializer=self._kernel_initializer,
-        bias_initializer=self._bias_initializer,
         kernel_regularizer=self._kernel_regularizer,
         bias_regularizer=self._bias_regularizer,
         activity_regularizer=self._activity_regularizer,
@@ -133,6 +132,8 @@ class ReZeroTransformer(tf.keras.layers.Layer):
         key_dim=self._attention_head_size,
         dropout=self._attention_dropout_rate,
         name="self_attention",
+        kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
+        bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
         **common_kwargs)
     self._attention_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
     if self._use_layer_norm:
@@ -149,6 +150,8 @@ class ReZeroTransformer(tf.keras.layers.Layer):
         output_shape=(None, self._intermediate_size),
         bias_axes="d",
         name="intermediate",
+        kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
+        bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
         **common_kwargs)
     policy = tf.keras.mixed_precision.global_policy()
     if policy.name == "mixed_bfloat16":
@@ -163,6 +166,8 @@ class ReZeroTransformer(tf.keras.layers.Layer):
         output_shape=(None, hidden_size),
         bias_axes="d",
         name="output",
+        kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
+        bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
         **common_kwargs)
     self._output_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
     if self._use_layer_norm:
...
@@ -84,11 +84,10 @@ class SpectralNormalization(tf.keras.layers.Wrapper):
     self.w = self.layer.kernel
     self.w_shape = self.w.shape.as_list()
-    self.uv_initializer = tf.initializers.random_normal()
 
     self.v = self.add_weight(
         shape=(1, np.prod(self.w_shape[:-1])),
-        initializer=self.uv_initializer,
+        initializer=tf.initializers.random_normal(),
         trainable=False,
         name='v',
         dtype=self.dtype,
@@ -96,7 +95,7 @@ class SpectralNormalization(tf.keras.layers.Wrapper):
     self.u = self.add_weight(
         shape=(1, self.w_shape[-1]),
-        initializer=self.uv_initializer,
+        initializer=tf.initializers.random_normal(),
         trainable=False,
         name='u',
         dtype=self.dtype,
@@ -197,7 +196,8 @@ class SpectralNormalizationConv2D(tf.keras.layers.Wrapper):
     super(SpectralNormalizationConv2D, self).__init__(layer, **kwargs)
 
   def build(self, input_shape):
-    self.layer.build(input_shape)
+    if not self.layer.built:
+      self.layer.build(input_shape)
 
     self.layer.kernel._aggregation = self.aggregation  # pylint: disable=protected-access
     self._dtype = self.layer.kernel.dtype
@@ -221,11 +221,10 @@ class SpectralNormalizationConv2D(tf.keras.layers.Wrapper):
     self.in_shape = (uv_dim, in_height, in_width, in_channel)
     self.out_shape = (uv_dim, out_height, out_width, out_channel)
-    self.uv_initializer = tf.initializers.random_normal()
 
     self.v = self.add_weight(
         shape=self.in_shape,
-        initializer=self.uv_initializer,
+        initializer=tf.initializers.random_normal(),
         trainable=False,
         name='v',
         dtype=self.dtype,
@@ -233,7 +232,7 @@ class SpectralNormalizationConv2D(tf.keras.layers.Wrapper):
     self.u = self.add_weight(
         shape=self.out_shape,
-        initializer=self.uv_initializer,
+        initializer=tf.initializers.random_normal(),
         trainable=False,
         name='u',
         dtype=self.dtype,
...
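Aside from the initializer cleanup, the `SpectralNormalizationConv2D.build` hunk above adds an idempotence guard so the wrapped layer is built only once. A generic sketch of that guard in a Keras wrapper (the class and attribute names here are illustrative, not from this file):

```python
import tensorflow as tf


class KernelTrackingWrapper(tf.keras.layers.Wrapper):
  """Toy wrapper that needs the wrapped layer's kernel at build time."""

  def build(self, input_shape):
    # Only build the inner layer if Keras has not already built it; calling
    # build() again on a built layer would create its weights a second time.
    if not self.layer.built:
      self.layer.build(input_shape)
    self.kernel_shape = self.layer.kernel.shape
    super().build(input_shape)
```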
@@ -66,7 +66,7 @@ class NormalizationTest(tf.test.TestCase, parameterized.TestCase):
     spectral_norm_computed = _compute_spectral_norm(normalized_kernel)
     spectral_norm_expected = self.norm_multiplier
     self.assertAllClose(
-        spectral_norm_computed, spectral_norm_expected, atol=5e-2)
+        spectral_norm_computed, spectral_norm_expected, atol=1e-1)
 
     # Test that the normalized layer is K-Lipschitz. In particular, if the layer
     # is a function f, then ||f(x1) - f(x2)||_2 <= K * ||(x1 - x2)||_2, where K
...
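For background on the relaxed tolerance above: the quantity the test checks is the spectral norm of the normalized kernel, i.e. its largest singular value. A direct way to compute it (a reference sketch, not the test's actual `_compute_spectral_norm` helper):

```python
import numpy as np


def spectral_norm(kernel):
  # Collapse all input dimensions so the kernel becomes a 2-D matrix, then
  # return the largest singular value, which is the spectral (operator) norm.
  matrix = np.reshape(kernel, (-1, kernel.shape[-1]))
  return np.linalg.svd(matrix, compute_uv=False)[0]
```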
@@ -20,6 +20,8 @@ import string
 import gin
 import tensorflow as tf
 
+from official.modeling import tf_utils
+
 
 _CHR_IDX = string.ascii_lowercase
@@ -87,7 +89,7 @@ class TalkingHeadsAttention(tf.keras.layers.MultiHeadAttention):
     self._pre_softmax_weight = self.add_weight(
         "pre_softmax_weight",
         shape=(self._num_heads, self._num_heads),
-        initializer=self._kernel_initializer,
+        initializer=tf_utils.clone_initializer(self._kernel_initializer),
         regularizer=self._kernel_regularizer,
         constraint=self._kernel_constraint,
         dtype=self.dtype,
@@ -95,7 +97,7 @@ class TalkingHeadsAttention(tf.keras.layers.MultiHeadAttention):
     self._post_softmax_weight = self.add_weight(
         "post_softmax_weight",
         shape=(self._num_heads, self._num_heads),
-        initializer=self._kernel_initializer,
+        initializer=tf_utils.clone_initializer(self._kernel_initializer),
         regularizer=self._kernel_regularizer,
         constraint=self._kernel_constraint,
         dtype=self.dtype,
...
@@ -17,6 +17,8 @@
 from typing import List, Optional, Text, Any, Dict
 import tensorflow as tf
 
+from official.modeling import tf_utils
+
 Layer = tf.keras.layers.Layer
 activations = tf.keras.activations
 initializers = tf.keras.initializers
@@ -98,24 +100,24 @@ class TNExpandCondense(Layer):
         name='w1',
         shape=(input_shape[-1], input_shape[-1]),
         trainable=True,
-        initializer=self.kernel_initializer)
+        initializer=tf_utils.clone_initializer(self.kernel_initializer))
     self.w2 = self.add_weight(
         name='w2',
         shape=(128, (128 * (self.proj_size // input_shape[-1]))),
         trainable=True,
-        initializer=self.kernel_initializer)
+        initializer=tf_utils.clone_initializer(self.kernel_initializer))
     self.w3 = self.add_weight(
         name='w3',
         shape=(128 * (self.proj_size // input_shape[-1]), 128),
         trainable=True,
-        initializer=self.kernel_initializer)
+        initializer=tf_utils.clone_initializer(self.kernel_initializer))
     self.w4 = self.add_weight(
         name='w4',
         shape=(input_shape[-1] // 128, 128, input_shape[-1]),
         trainable=True,
-        initializer=self.kernel_initializer)
+        initializer=tf_utils.clone_initializer(self.kernel_initializer))
 
     if self.use_bias:
       self.bias = self.add_weight(
...
@@ -19,6 +19,7 @@
 import gin
 import tensorflow as tf
 
+from official.modeling import tf_utils
 from official.nlp.modeling.layers.tn_expand_condense import TNExpandCondense
@@ -100,7 +101,8 @@ class TNTransformerExpandCondense(tf.keras.layers.Layer):
       self._attention_initializer = tf.keras.initializers.get(
          attention_initializer)
     else:
-      self._attention_initializer = self._kernel_initializer
+      self._attention_initializer = tf_utils.clone_initializer(
+          self._kernel_initializer)
 
   def build(self, input_shape):
     input_tensor = input_shape[0] if len(input_shape) == 2 else input_shape
@@ -128,7 +130,6 @@ class TNTransformerExpandCondense(tf.keras.layers.Layer):
          "heads (%d)" % (hidden_size, self._num_heads))
     self._attention_head_size = int(hidden_size // self._num_heads)
     common_kwargs = dict(
-        bias_initializer=self._bias_initializer,
         kernel_regularizer=self._kernel_regularizer,
         bias_regularizer=self._bias_regularizer,
         activity_regularizer=self._activity_regularizer,
@@ -140,6 +141,7 @@ class TNTransformerExpandCondense(tf.keras.layers.Layer):
         dropout=self._attention_dropout_rate,
         use_bias=self._use_bias,
         kernel_initializer=self._attention_initializer,
+        bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
         name="self_attention",
         **common_kwargs)
     self._attention_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
...
@@ -18,6 +18,7 @@
 import gin
 import tensorflow as tf
 
+from official.modeling import tf_utils
 from official.nlp.modeling.layers import attention
 from official.nlp.modeling.layers import multi_channel_attention
 from official.nlp.modeling.layers import transformer_encoder_block
@@ -226,7 +227,8 @@ class TransformerDecoderBlock(tf.keras.layers.Layer):
      self._attention_initializer = tf.keras.initializers.get(
          attention_initializer)
     else:
-      self._attention_initializer = self._kernel_initializer
+      self._attention_initializer = tf_utils.clone_initializer(
+          self._kernel_initializer)
     if self.multi_channel_cross_attention:
       self._cross_attention_cls = multi_channel_attention.MultiChannelAttention
     else:
@@ -244,7 +246,6 @@ class TransformerDecoderBlock(tf.keras.layers.Layer):
          "heads (%d)" % (hidden_size, self.num_attention_heads))
     self.attention_head_size = int(hidden_size) // self.num_attention_heads
     common_kwargs = dict(
-        bias_initializer=self._bias_initializer,
         kernel_regularizer=self._kernel_regularizer,
         bias_regularizer=self._bias_regularizer,
         activity_regularizer=self._activity_regularizer,
@@ -256,14 +257,17 @@ class TransformerDecoderBlock(tf.keras.layers.Layer):
         key_dim=self.attention_head_size,
         dropout=self.attention_dropout_rate,
         use_bias=self._use_bias,
-        kernel_initializer=self._attention_initializer,
+        kernel_initializer=tf_utils.clone_initializer(
+            self._attention_initializer),
+        bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
         name="self_attention",
         **common_kwargs)
     self.self_attention_output_dense = tf.keras.layers.experimental.EinsumDense(
         "abc,cd->abd",
         output_shape=(None, hidden_size),
         bias_axes="d",
-        kernel_initializer=self._kernel_initializer,
+        kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
+        bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
         name="output",
         **common_kwargs)
     self.self_attention_dropout = tf.keras.layers.Dropout(
@@ -281,7 +285,9 @@ class TransformerDecoderBlock(tf.keras.layers.Layer):
         dropout=self.attention_dropout_rate,
         output_shape=hidden_size,
         use_bias=self._use_bias,
-        kernel_initializer=self._attention_initializer,
+        kernel_initializer=tf_utils.clone_initializer(
+            self._attention_initializer),
+        bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
         name="attention/encdec",
         **common_kwargs)
@@ -299,7 +305,8 @@ class TransformerDecoderBlock(tf.keras.layers.Layer):
         "abc,cd->abd",
         output_shape=(None, self.intermediate_size),
         bias_axes="d",
-        kernel_initializer=self._kernel_initializer,
+        kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
+        bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
         name="intermediate",
         **common_kwargs)
     self.intermediate_activation_layer = tf.keras.layers.Activation(
@@ -310,7 +317,8 @@ class TransformerDecoderBlock(tf.keras.layers.Layer):
         "abc,cd->abd",
         output_shape=(None, hidden_size),
         bias_axes="d",
-        kernel_initializer=self._kernel_initializer,
+        kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
+        bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
         name="output",
         **common_kwargs)
     self.output_dropout = tf.keras.layers.Dropout(rate=self.dropout_rate)
...
@@ -17,6 +17,7 @@
 from absl import logging
 import tensorflow as tf
 
+from official.modeling import tf_utils
 from official.nlp.modeling.layers import util
@@ -156,7 +157,8 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
      self._attention_initializer = tf.keras.initializers.get(
          attention_initializer)
     else:
-      self._attention_initializer = self._kernel_initializer
+      self._attention_initializer = tf_utils.clone_initializer(
+          self._kernel_initializer)
     self._attention_axes = attention_axes
 
     if self._diff_q_kv_att_layer_norm and not self._norm_first:
@@ -188,8 +190,6 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
       last_output_shape = self._output_last_dim
 
     common_kwargs = dict(
-        bias_initializer=self._bias_initializer,
-        kernel_regularizer=self._kernel_regularizer,
         bias_regularizer=self._bias_regularizer,
         activity_regularizer=self._activity_regularizer,
         kernel_constraint=self._kernel_constraint,
@@ -201,6 +201,7 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
         dropout=self._attention_dropout,
         use_bias=self._use_bias,
         kernel_initializer=self._attention_initializer,
+        bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
         attention_axes=self._attention_axes,
         output_shape=self._output_last_dim,
         name="self_attention",
@@ -227,7 +228,8 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
         einsum_equation,
         output_shape=(None, self._inner_dim),
         bias_axes="d",
-        kernel_initializer=self._kernel_initializer,
+        kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
+        bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
         name="intermediate",
         **common_kwargs)
     policy = tf.keras.mixed_precision.global_policy()
@@ -245,7 +247,8 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
         output_shape=(None, last_output_shape),
         bias_axes="d",
         name="output",
-        kernel_initializer=self._kernel_initializer,
+        kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
+        bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
         **common_kwargs)
     self._output_dropout = tf.keras.layers.Dropout(rate=self._output_dropout)
     # Use float32 in layernorm for numeric stability.
...
@@ -19,6 +19,7 @@ from absl import logging
 import gin
 import tensorflow as tf
 
+from official.modeling import tf_utils
 from official.nlp.modeling.layers import attention
@@ -127,8 +128,6 @@ class TransformerScaffold(tf.keras.layers.Layer):
     self._attention_head_size = int(hidden_size // self._num_heads)
     common_kwargs = dict(
-        kernel_initializer=self._kernel_initializer,
-        bias_initializer=self._bias_initializer,
         kernel_regularizer=self._kernel_regularizer,
         bias_regularizer=self._bias_regularizer,
         activity_regularizer=self._activity_regularizer,
@@ -145,6 +144,9 @@ class TransformerScaffold(tf.keras.layers.Layer):
         return instance_or_cls(**config)
 
     default_attention_cfg = {
+        "kernel_initializer": tf_utils.clone_initializer(
+            self._kernel_initializer),
+        "bias_initializer": tf_utils.clone_initializer(self._bias_initializer),
         "num_heads": self._num_heads,
         "key_dim": self._attention_head_size,
         "dropout": self._attention_dropout_rate,
@@ -158,6 +160,10 @@ class TransformerScaffold(tf.keras.layers.Layer):
     if self._feedforward_cls is not None:
       default_feedforward_cfg = {
+          "kernel_initializer": tf_utils.clone_initializer(
+              self._kernel_initializer),
+          "bias_initializer": tf_utils.clone_initializer(
+              self._bias_initializer),
          "intermediate_size": self._intermediate_size,
          "intermediate_activation": self._intermediate_activation,
          "dropout": self._dropout_rate,
@@ -189,6 +195,9 @@ class TransformerScaffold(tf.keras.layers.Layer):
         output_shape=(None, self._intermediate_size),
         bias_axes="d",
         name="intermediate",
+        kernel_initializer=tf_utils.clone_initializer(
+            self._kernel_initializer),
+        bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
         **common_kwargs)
     policy = tf.keras.mixed_precision.global_policy()
     if policy.name == "mixed_bfloat16":
@@ -203,6 +212,9 @@ class TransformerScaffold(tf.keras.layers.Layer):
         output_shape=(None, hidden_size),
         bias_axes="d",
         name="output",
+        kernel_initializer=tf_utils.clone_initializer(
+            self._kernel_initializer),
+        bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
         **common_kwargs)
     self._output_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
...
@@ -18,6 +18,7 @@ from absl import logging
 import tensorflow as tf
 
+from official.modeling import tf_utils
 from official.nlp.modeling.layers import relative_attention
@@ -148,7 +149,7 @@ class TransformerXLBlock(tf.keras.layers.Layer):
         value_dim=self._head_size,
         dropout=self._attention_dropout_rate,
         use_bias=False,
-        kernel_initializer=self._kernel_initializer,
+        kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
         name="rel_attn")
     self._attention_dropout = tf.keras.layers.Dropout(
         rate=self._attention_dropout_rate)
@@ -161,7 +162,7 @@ class TransformerXLBlock(tf.keras.layers.Layer):
         "abc,cd->abd",
         output_shape=(None, self._inner_size),
         bias_axes="d",
-        kernel_initializer=self._kernel_initializer,
+        kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
         name="inner")
     self._inner_activation_layer = tf.keras.layers.Activation(
@@ -173,7 +174,7 @@ class TransformerXLBlock(tf.keras.layers.Layer):
         output_shape=(None, hidden_size),
         bias_axes="d",
         name="output",
-        kernel_initializer=self._kernel_initializer)
+        kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer))
     self._output_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
     self._output_layer_norm = tf.keras.layers.LayerNormalization(
         name="output_layer_norm",
@@ -398,17 +399,17 @@ class TransformerXL(tf.keras.layers.Layer):
         "content_attention_bias",
         shape=attention_bias_shape,
         dtype=tf.float32,
-        initializer=self._initializer)
+        initializer=tf_utils.clone_initializer(self._initializer))
     self.positional_attention_bias = self.add_weight(
         "positional_attention_bias",
         shape=attention_bias_shape,
         dtype=tf.float32,
-        initializer=self._initializer)
+        initializer=tf_utils.clone_initializer(self._initializer))
     self.segment_attention_bias = self.add_weight(
         "segment_attention_bias",
         shape=attention_bias_shape,
         dtype=tf.float32,
-        initializer=self._initializer)
+        initializer=tf_utils.clone_initializer(self._initializer))
 
     self.transformer_xl_layers = []
     for i in range(self._num_layers):
...
@@ -22,6 +22,7 @@ from absl import logging
 import gin
 import tensorflow as tf
 
+from official.modeling import tf_utils
 from official.nlp.modeling import layers
 from official.nlp.modeling import networks
@@ -102,7 +103,7 @@ class BertPretrainer(tf.keras.Model):
     masked_lm = layers.MaskedLM(
         embedding_table=embedding_table,
         activation=activation,
-        initializer=initializer,
+        initializer=tf_utils.clone_initializer(initializer),
         output=output,
         name='cls/predictions')
     lm_outputs = masked_lm(
@@ -111,7 +112,7 @@ class BertPretrainer(tf.keras.Model):
     classification = networks.Classification(
         input_width=cls_output.shape[-1],
         num_classes=num_classes,
-        initializer=initializer,
+        initializer=tf_utils.clone_initializer(initializer),
         output=output,
         name='classification')
     sentence_outputs = classification(cls_output)
...
@@ -96,21 +96,22 @@ class ElectraPretrainer(tf.keras.Model):
     self.masked_lm = layers.MaskedLM(
         embedding_table=generator_network.get_embedding_table(),
         activation=mlm_activation,
-        initializer=mlm_initializer,
+        initializer=tf_utils.clone_initializer(mlm_initializer),
         output=output_type,
         name='generator_masked_lm')
     self.classification = layers.ClassificationHead(
         inner_dim=generator_network.get_config()['hidden_size'],
         num_classes=num_classes,
-        initializer=mlm_initializer,
+        initializer=tf_utils.clone_initializer(mlm_initializer),
         name='generator_classification_head')
     self.discriminator_projection = tf.keras.layers.Dense(
         units=discriminator_network.get_config()['hidden_size'],
         activation=mlm_activation,
-        kernel_initializer=mlm_initializer,
+        kernel_initializer=tf_utils.clone_initializer(mlm_initializer),
         name='discriminator_projection_head')
     self.discriminator_head = tf.keras.layers.Dense(
-        units=1, kernel_initializer=mlm_initializer)
+        units=1,
+        kernel_initializer=tf_utils.clone_initializer(mlm_initializer))
 
   def call(self, inputs):
     """ELECTRA forward pass.
...