Commit 6d458bcc authored by Scott Zhu, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 446219773
parent 1fd7aaaf
......@@ -272,3 +272,14 @@ def cross_replica_concat(value, axis, name="cross_replica_concat"):
if value.shape.as_list()[0] is None:
raise RuntimeError(f"{value} has unknown batch.")
return context.all_gather(value, axis=axis)
def clone_initializer(initializer):
  # Keras initializers are going to be stateless, which means reusing the same
  # initializer instance will produce the same init values when the shapes are
  # the same.
  if isinstance(initializer, tf.keras.initializers.Initializer):
    return initializer.__class__.from_config(initializer.get_config())
  # When the input is a string/dict or another serialized config, the caller
  # will create a new Keras initializer instance from it, so there is nothing
  # to do here.
  return initializer
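
A minimal usage sketch of the new helper (assuming the official.modeling package that contains it is importable after this commit); a clone is a fresh instance built from the same configuration, which is what the layer changes below rely on:

import tensorflow as tf

from official.modeling import tf_utils

shared_init = tf.keras.initializers.GlorotUniform()
clone = tf_utils.clone_initializer(shared_init)

# The clone is a new object created from the same config, so sublayers built
# from separate clones are not tied to one shared initializer instance.
assert clone is not shared_init
assert clone.get_config() == shared_init.get_config()

# Strings and dicts pass through untouched; Keras resolves them later.
assert tf_utils.clone_initializer("glorot_uniform") == "glorot_uniform"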
......@@ -18,6 +18,8 @@ from typing import Optional
import tensorflow as tf
from official.modeling import tf_utils
class BlockDiagFeedforward(tf.keras.layers.Layer):
"""Block diagonal feedforward layer.
......@@ -80,8 +82,6 @@ class BlockDiagFeedforward(tf.keras.layers.Layer):
hidden_size = input_shape.as_list()[-1]
common_kwargs = dict(
kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activity_regularizer=self._activity_regularizer,
......@@ -94,6 +94,8 @@ class BlockDiagFeedforward(tf.keras.layers.Layer):
self._intermediate_size // self._num_blocks),
bias_axes="de",
name="intermediate",
kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
**common_kwargs)
policy = tf.keras.mixed_precision.global_policy()
......@@ -110,6 +112,8 @@ class BlockDiagFeedforward(tf.keras.layers.Layer):
hidden_size // self._num_blocks),
bias_axes="do",
name="output",
kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
**common_kwargs)
if self._apply_mixing:
......@@ -118,6 +122,9 @@ class BlockDiagFeedforward(tf.keras.layers.Layer):
output_shape=(None, self._num_blocks,
hidden_size // self._num_blocks),
name="output_mixing",
kernel_initializer=tf_utils.clone_initializer(
self._kernel_initializer),
bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
**common_kwargs)
self._output_reshape = tf.keras.layers.Reshape((-1, hidden_size))
......
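The remaining hunks repeat this refactor layer by layer: kernel and bias initializers come out of the shared common_kwargs dict and each sublayer is built with its own clone. A toy layer sketching that pattern (the class name, shapes, and defaults here are illustrative, not part of the commit):

import tensorflow as tf

from official.modeling import tf_utils


class TwoDenseBlock(tf.keras.layers.Layer):
  """Hypothetical layer showing per-sublayer initializer cloning."""

  def __init__(self, units, kernel_initializer="glorot_uniform", **kwargs):
    super().__init__(**kwargs)
    self._units = units
    self._kernel_initializer = tf.keras.initializers.get(kernel_initializer)

  def build(self, input_shape):
    # Regularizers/constraints can still be shared, but each sublayer gets its
    # own clone of the initializer instead of reusing one instance.
    self._dense_a = tf.keras.layers.Dense(
        self._units,
        kernel_initializer=tf_utils.clone_initializer(
            self._kernel_initializer))
    self._dense_b = tf.keras.layers.Dense(
        self._units,
        kernel_initializer=tf_utils.clone_initializer(
            self._kernel_initializer))
    super().build(input_shape)

  def call(self, inputs):
    return self._dense_b(self._dense_a(inputs))
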
......@@ -57,12 +57,14 @@ class ClassificationHead(tf.keras.layers.Layer):
self.dense = tf.keras.layers.Dense(
units=self.inner_dim,
activation=self.activation,
kernel_initializer=self.initializer,
kernel_initializer=tf_utils.clone_initializer(self.initializer),
name="pooler_dense")
self.dropout = tf.keras.layers.Dropout(rate=self.dropout_rate)
self.out_proj = tf.keras.layers.Dense(
units=num_classes, kernel_initializer=self.initializer, name="logits")
units=num_classes,
kernel_initializer=tf_utils.clone_initializer(self.initializer),
name="logits")
def call(self, features: tf.Tensor, only_project: bool = False):
"""Implements call().
......@@ -146,14 +148,15 @@ class MultiClsHeads(tf.keras.layers.Layer):
self.dense = tf.keras.layers.Dense(
units=inner_dim,
activation=self.activation,
kernel_initializer=self.initializer,
kernel_initializer=tf_utils.clone_initializer(self.initializer),
name="pooler_dense")
self.dropout = tf.keras.layers.Dropout(rate=self.dropout_rate)
self.out_projs = []
for name, num_classes in cls_list:
self.out_projs.append(
tf.keras.layers.Dense(
units=num_classes, kernel_initializer=self.initializer,
units=num_classes,
kernel_initializer=tf_utils.clone_initializer(self.initializer),
name=name))
def call(self, features: tf.Tensor, only_project: bool = False):
......@@ -277,7 +280,7 @@ class GaussianProcessClassificationHead(ClassificationHead):
if use_gp_layer:
self.out_proj = gaussian_process.RandomFeatureGaussianProcess(
self.num_classes,
kernel_initializer=self.initializer,
kernel_initializer=tf_utils.clone_initializer(self.initializer),
name="logits",
**self.gp_layer_kwargs)
......
......@@ -18,6 +18,8 @@
import gin
import tensorflow as tf
from official.modeling import tf_utils
@tf.keras.utils.register_keras_serializable(package="Text")
@gin.configurable
......@@ -95,8 +97,6 @@ class GatedFeedforward(tf.keras.layers.Layer):
hidden_size = input_shape.as_list()[-1]
common_kwargs = dict(
kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activity_regularizer=self._activity_regularizer,
......@@ -121,6 +121,10 @@ class GatedFeedforward(tf.keras.layers.Layer):
output_shape=(None, self._intermediate_size),
bias_axes="d",
name="intermediate_%d" % i,
kernel_initializer=tf_utils.clone_initializer(
self._kernel_initializer),
bias_initializer=tf_utils.clone_initializer(
self._bias_initializer),
**common_kwargs))
self._intermediate_activation_layers.append(
tf.keras.layers.Activation(
......@@ -132,6 +136,10 @@ class GatedFeedforward(tf.keras.layers.Layer):
output_shape=(None, self._intermediate_size),
bias_axes="d",
name="gate_%d" % i,
kernel_initializer=tf_utils.clone_initializer(
self._kernel_initializer),
bias_initializer=tf_utils.clone_initializer(
self._bias_initializer),
**common_kwargs))
self._output_dense.append(
tf.keras.layers.experimental.EinsumDense(
......@@ -139,6 +147,10 @@ class GatedFeedforward(tf.keras.layers.Layer):
output_shape=(None, hidden_size),
bias_axes="d",
name="output_%d" % i,
kernel_initializer=tf_utils.clone_initializer(
self._kernel_initializer),
bias_initializer=tf_utils.clone_initializer(
self._bias_initializer),
**common_kwargs))
self._output_dropout.append(tf.keras.layers.Dropout(rate=self._dropout))
# Use float32 in layernorm for numeric stability.
......
......@@ -15,6 +15,8 @@
"""MobileBERT embedding and transformer layers."""
import tensorflow as tf
from official.modeling import tf_utils
from official.nlp.modeling.layers import on_device_embedding
from official.nlp.modeling.layers import position_embedding
......@@ -109,21 +111,21 @@ class MobileBertEmbedding(tf.keras.layers.Layer):
self.word_embedding = on_device_embedding.OnDeviceEmbedding(
self.word_vocab_size,
self.word_embed_size,
initializer=initializer,
initializer=tf_utils.clone_initializer(self.initializer),
name='word_embedding')
self.type_embedding = on_device_embedding.OnDeviceEmbedding(
self.type_vocab_size,
self.output_embed_size,
initializer=initializer,
initializer=tf_utils.clone_initializer(self.initializer),
name='type_embedding')
self.pos_embedding = position_embedding.PositionEmbedding(
max_length=max_sequence_length,
initializer=initializer,
initializer=tf_utils.clone_initializer(self.initializer),
name='position_embedding')
self.word_embedding_proj = tf.keras.layers.experimental.EinsumDense(
'abc,cd->abd',
output_shape=[None, self.output_embed_size],
kernel_initializer=initializer,
kernel_initializer=tf_utils.clone_initializer(self.initializer),
bias_axes='d',
name='embedding_projection')
self.layer_norm = _get_norm_layer(normalization_type, 'embedding_norm')
......@@ -246,7 +248,7 @@ class MobileBertTransformer(tf.keras.layers.Layer):
'abc,cd->abd',
output_shape=[None, self.intra_bottleneck_size],
bias_axes='d',
kernel_initializer=initializer,
kernel_initializer=tf_utils.clone_initializer(self.initializer),
name='bottleneck_input/dense')
layer_norm = _get_norm_layer(self.normalization_type,
name='bottleneck_input/norm')
......@@ -258,7 +260,7 @@ class MobileBertTransformer(tf.keras.layers.Layer):
'abc,cd->abd',
output_shape=[None, self.intra_bottleneck_size],
bias_axes='d',
kernel_initializer=initializer,
kernel_initializer=tf_utils.clone_initializer(self.initializer),
name='kq_shared_bottleneck/dense')
layer_norm = _get_norm_layer(self.normalization_type,
name='kq_shared_bottleneck/norm')
......@@ -272,7 +274,7 @@ class MobileBertTransformer(tf.keras.layers.Layer):
value_dim=attention_head_size,
dropout=self.attention_probs_dropout_prob,
output_shape=self.intra_bottleneck_size,
kernel_initializer=initializer,
kernel_initializer=tf_utils.clone_initializer(self.initializer),
name='attention')
layer_norm = _get_norm_layer(self.normalization_type,
name='attention/norm')
......@@ -289,14 +291,14 @@ class MobileBertTransformer(tf.keras.layers.Layer):
activation=self.intermediate_act_fn,
output_shape=[None, self.intermediate_size],
bias_axes='d',
kernel_initializer=initializer,
kernel_initializer=tf_utils.clone_initializer(self.initializer),
name=layer_name)
layer_name = layer_prefix + '/output_dense'
output_layer = tf.keras.layers.experimental.EinsumDense(
'abc,cd->abd',
output_shape=[None, self.intra_bottleneck_size],
bias_axes='d',
kernel_initializer=initializer,
kernel_initializer=tf_utils.clone_initializer(self.initializer),
name=layer_name)
layer_name = layer_prefix + '/norm'
layer_norm = _get_norm_layer(self.normalization_type,
......@@ -311,7 +313,7 @@ class MobileBertTransformer(tf.keras.layers.Layer):
output_shape=[None, self.hidden_size],
activation=None,
bias_axes='d',
kernel_initializer=initializer,
kernel_initializer=tf_utils.clone_initializer(self.initializer),
name='bottleneck_output/dense')
dropout_layer = tf.keras.layers.Dropout(
self.hidden_dropout_prob,
......@@ -474,14 +476,14 @@ class MobileBertMaskedLM(tf.keras.layers.Layer):
self.dense = tf.keras.layers.Dense(
hidden_size,
activation=self.activation,
kernel_initializer=self.initializer,
kernel_initializer=tf_utils.clone_initializer(self.initializer),
name='transform/dense')
if hidden_size > embedding_width:
self.extra_output_weights = self.add_weight(
'extra_output_weights',
shape=(self._vocab_size, hidden_size - embedding_width),
initializer=self.initializer,
initializer=tf_utils.clone_initializer(self.initializer),
trainable=True)
elif hidden_size == embedding_width:
self.extra_output_weights = None
......
......@@ -18,6 +18,7 @@
import math
import tensorflow as tf
from official.modeling import tf_utils
from official.nlp.modeling.layers import masked_softmax
......@@ -60,8 +61,6 @@ class VotingAttention(tf.keras.layers.Layer):
def build(self, unused_input_shapes):
common_kwargs = dict(
kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activity_regularizer=self._activity_regularizer,
......@@ -72,12 +71,16 @@ class VotingAttention(tf.keras.layers.Layer):
output_shape=(None, self._num_heads, self._head_size),
bias_axes="NH",
name="query",
kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
**common_kwargs)
self._key_dense = tf.keras.layers.experimental.EinsumDense(
"BAE,ENH->BANH",
output_shape=(None, self._num_heads, self._head_size),
bias_axes="NH",
name="key",
kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
**common_kwargs)
super(VotingAttention, self).build(unused_input_shapes)
......
......@@ -22,6 +22,8 @@ import string
import numpy as np
import tensorflow as tf
from official.modeling import tf_utils
_CHR_IDX = string.ascii_lowercase
......@@ -347,8 +349,6 @@ class ReuseMultiHeadAttention(tf.keras.layers.Layer):
self._key_shape = tf.TensorShape(key)
common_kwargs = dict(
kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activity_regularizer=self._activity_regularizer,
......@@ -368,6 +368,10 @@ class ReuseMultiHeadAttention(tf.keras.layers.Layer):
self._num_heads - self._reuse_heads, self._key_dim]),
bias_axes=bias_axes if self._use_bias else None,
name="query",
kernel_initializer=tf_utils.clone_initializer(
self._kernel_initializer),
bias_initializer=tf_utils.clone_initializer(
self._bias_initializer),
**common_kwargs)
einsum_equation, bias_axes, output_rank = _build_proj_equation(
self._key_shape.rank - 1, bound_dims=1, output_dims=2)
......@@ -377,6 +381,10 @@ class ReuseMultiHeadAttention(tf.keras.layers.Layer):
self._num_heads - self._reuse_heads, self._key_dim]),
bias_axes=bias_axes if self._use_bias else None,
name="key",
kernel_initializer=tf_utils.clone_initializer(
self._kernel_initializer),
bias_initializer=tf_utils.clone_initializer(
self._bias_initializer),
**common_kwargs)
einsum_equation, bias_axes, output_rank = _build_proj_equation(
self._value_shape.rank - 1, bound_dims=1, output_dims=2)
......@@ -389,6 +397,10 @@ class ReuseMultiHeadAttention(tf.keras.layers.Layer):
output_rank - 1, [self._reuse_heads, self._value_dim]),
bias_axes=bias_axes if self._use_bias else None,
name="value_reuse",
kernel_initializer=tf_utils.clone_initializer(
self._kernel_initializer),
bias_initializer=tf_utils.clone_initializer(
self._bias_initializer),
**common_kwargs))
if self._reuse_heads < self._num_heads:
self._value_dense.append(tf.keras.layers.experimental.EinsumDense(
......@@ -397,6 +409,10 @@ class ReuseMultiHeadAttention(tf.keras.layers.Layer):
self._num_heads - self._reuse_heads, self._value_dim]),
bias_axes=bias_axes if self._use_bias else None,
name="value_new",
kernel_initializer=tf_utils.clone_initializer(
self._kernel_initializer),
bias_initializer=tf_utils.clone_initializer(
self._bias_initializer),
**common_kwargs))
# Builds the attention computations for multi-head dot product attention.
......@@ -439,6 +455,10 @@ class ReuseMultiHeadAttention(tf.keras.layers.Layer):
output_shape=_get_output_shape(output_rank - 1, output_shape),
bias_axes=bias_axes if (use_bias and self._use_bias) else None,
name=name,
kernel_initializer=tf_utils.clone_initializer(
self._kernel_initializer),
bias_initializer=tf_utils.clone_initializer(
self._bias_initializer),
**common_kwargs)
def _build_attention(self, rank):
......
......@@ -14,6 +14,8 @@
"""Keras-based TransformerEncoder block layer."""
import tensorflow as tf
from official.modeling import tf_utils
from official.nlp.modeling.layers import reuse_attention as attention
......@@ -131,7 +133,8 @@ class ReuseTransformer(tf.keras.layers.Layer):
self._attention_initializer = tf.keras.initializers.get(
attention_initializer)
else:
self._attention_initializer = self._kernel_initializer
self._attention_initializer = tf_utils.clone_initializer(
self._kernel_initializer)
self._attention_axes = attention_axes
def build(self, input_shape):
......@@ -156,7 +159,6 @@ class ReuseTransformer(tf.keras.layers.Layer):
else:
self._attention_head_size = self._head_size
common_kwargs = dict(
bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activity_regularizer=self._activity_regularizer,
......@@ -168,6 +170,7 @@ class ReuseTransformer(tf.keras.layers.Layer):
dropout=self._attention_dropout,
use_bias=self._use_bias,
kernel_initializer=self._attention_initializer,
bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
attention_axes=self._attention_axes,
reuse_attention=self._reuse_attention,
use_relative_pe=self._use_relative_pe,
......@@ -188,7 +191,8 @@ class ReuseTransformer(tf.keras.layers.Layer):
einsum_equation,
output_shape=(None, self._inner_dim),
bias_axes="d",
kernel_initializer=self._kernel_initializer,
kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
name="intermediate",
**common_kwargs)
policy = tf.keras.mixed_precision.global_policy()
......@@ -206,7 +210,8 @@ class ReuseTransformer(tf.keras.layers.Layer):
output_shape=(None, hidden_size),
bias_axes="d",
name="output",
kernel_initializer=self._kernel_initializer,
kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
**common_kwargs)
self._output_dropout = tf.keras.layers.Dropout(rate=self._output_dropout)
# Use float32 in layernorm for numeric stability.
......
......@@ -18,6 +18,7 @@
import gin
import tensorflow as tf
from official.modeling import tf_utils
from official.nlp.modeling.layers import util
......@@ -121,8 +122,6 @@ class ReZeroTransformer(tf.keras.layers.Layer):
"heads (%d)" % (hidden_size, self._num_heads))
self._attention_head_size = int(hidden_size // self._num_heads)
common_kwargs = dict(
kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activity_regularizer=self._activity_regularizer,
......@@ -133,6 +132,8 @@ class ReZeroTransformer(tf.keras.layers.Layer):
key_dim=self._attention_head_size,
dropout=self._attention_dropout_rate,
name="self_attention",
kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
**common_kwargs)
self._attention_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
if self._use_layer_norm:
......@@ -149,6 +150,8 @@ class ReZeroTransformer(tf.keras.layers.Layer):
output_shape=(None, self._intermediate_size),
bias_axes="d",
name="intermediate",
kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
**common_kwargs)
policy = tf.keras.mixed_precision.global_policy()
if policy.name == "mixed_bfloat16":
......@@ -163,6 +166,8 @@ class ReZeroTransformer(tf.keras.layers.Layer):
output_shape=(None, hidden_size),
bias_axes="d",
name="output",
kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
**common_kwargs)
self._output_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
if self._use_layer_norm:
......
......@@ -84,11 +84,10 @@ class SpectralNormalization(tf.keras.layers.Wrapper):
self.w = self.layer.kernel
self.w_shape = self.w.shape.as_list()
self.uv_initializer = tf.initializers.random_normal()
self.v = self.add_weight(
shape=(1, np.prod(self.w_shape[:-1])),
initializer=self.uv_initializer,
initializer=tf.initializers.random_normal(),
trainable=False,
name='v',
dtype=self.dtype,
......@@ -96,7 +95,7 @@ class SpectralNormalization(tf.keras.layers.Wrapper):
self.u = self.add_weight(
shape=(1, self.w_shape[-1]),
initializer=self.uv_initializer,
initializer=tf.initializers.random_normal(),
trainable=False,
name='u',
dtype=self.dtype,
......@@ -197,7 +196,8 @@ class SpectralNormalizationConv2D(tf.keras.layers.Wrapper):
super(SpectralNormalizationConv2D, self).__init__(layer, **kwargs)
def build(self, input_shape):
self.layer.build(input_shape)
if not self.layer.built:
self.layer.build(input_shape)
self.layer.kernel._aggregation = self.aggregation # pylint: disable=protected-access
self._dtype = self.layer.kernel.dtype
......@@ -221,11 +221,10 @@ class SpectralNormalizationConv2D(tf.keras.layers.Wrapper):
self.in_shape = (uv_dim, in_height, in_width, in_channel)
self.out_shape = (uv_dim, out_height, out_width, out_channel)
self.uv_initializer = tf.initializers.random_normal()
self.v = self.add_weight(
shape=self.in_shape,
initializer=self.uv_initializer,
initializer=tf.initializers.random_normal(),
trainable=False,
name='v',
dtype=self.dtype,
......@@ -233,7 +232,7 @@ class SpectralNormalizationConv2D(tf.keras.layers.Wrapper):
self.u = self.add_weight(
shape=self.out_shape,
initializer=self.uv_initializer,
initializer=tf.initializers.random_normal(),
trainable=False,
name='u',
dtype=self.dtype,
......
......@@ -66,7 +66,7 @@ class NormalizationTest(tf.test.TestCase, parameterized.TestCase):
spectral_norm_computed = _compute_spectral_norm(normalized_kernel)
spectral_norm_expected = self.norm_multiplier
self.assertAllClose(
spectral_norm_computed, spectral_norm_expected, atol=5e-2)
spectral_norm_computed, spectral_norm_expected, atol=1e-1)
# Test that the normalized layer is K-Lipschitz. In particular, if the layer
# is a function f, then ||f(x1) - f(x2)||_2 <= K * ||(x1 - x2)||_2, where K
......
......@@ -20,6 +20,8 @@ import string
import gin
import tensorflow as tf
from official.modeling import tf_utils
_CHR_IDX = string.ascii_lowercase
......@@ -87,7 +89,7 @@ class TalkingHeadsAttention(tf.keras.layers.MultiHeadAttention):
self._pre_softmax_weight = self.add_weight(
"pre_softmax_weight",
shape=(self._num_heads, self._num_heads),
initializer=self._kernel_initializer,
initializer=tf_utils.clone_initializer(self._kernel_initializer),
regularizer=self._kernel_regularizer,
constraint=self._kernel_constraint,
dtype=self.dtype,
......@@ -95,7 +97,7 @@ class TalkingHeadsAttention(tf.keras.layers.MultiHeadAttention):
self._post_softmax_weight = self.add_weight(
"post_softmax_weight",
shape=(self._num_heads, self._num_heads),
initializer=self._kernel_initializer,
initializer=tf_utils.clone_initializer(self._kernel_initializer),
regularizer=self._kernel_regularizer,
constraint=self._kernel_constraint,
dtype=self.dtype,
......
......@@ -17,6 +17,8 @@
from typing import List, Optional, Text, Any, Dict
import tensorflow as tf
from official.modeling import tf_utils
Layer = tf.keras.layers.Layer
activations = tf.keras.activations
initializers = tf.keras.initializers
......@@ -98,24 +100,24 @@ class TNExpandCondense(Layer):
name='w1',
shape=(input_shape[-1], input_shape[-1]),
trainable=True,
initializer=self.kernel_initializer)
initializer=tf_utils.clone_initializer(self.kernel_initializer))
self.w2 = self.add_weight(
name='w2',
shape=(128, (128 * (self.proj_size // input_shape[-1]))),
trainable=True,
initializer=self.kernel_initializer)
initializer=tf_utils.clone_initializer(self.kernel_initializer))
self.w3 = self.add_weight(
name='w3',
shape=(128 * (self.proj_size // input_shape[-1]), 128),
trainable=True,
initializer=self.kernel_initializer)
initializer=tf_utils.clone_initializer(self.kernel_initializer))
self.w4 = self.add_weight(
name='w4',
shape=(input_shape[-1] // 128, 128, input_shape[-1]),
trainable=True,
initializer=self.kernel_initializer)
initializer=tf_utils.clone_initializer(self.kernel_initializer))
if self.use_bias:
self.bias = self.add_weight(
......
......@@ -19,6 +19,7 @@
import gin
import tensorflow as tf
from official.modeling import tf_utils
from official.nlp.modeling.layers.tn_expand_condense import TNExpandCondense
......@@ -100,7 +101,8 @@ class TNTransformerExpandCondense(tf.keras.layers.Layer):
self._attention_initializer = tf.keras.initializers.get(
attention_initializer)
else:
self._attention_initializer = self._kernel_initializer
self._attention_initializer = tf_utils.clone_initializer(
self._kernel_initializer)
def build(self, input_shape):
input_tensor = input_shape[0] if len(input_shape) == 2 else input_shape
......@@ -128,7 +130,6 @@ class TNTransformerExpandCondense(tf.keras.layers.Layer):
"heads (%d)" % (hidden_size, self._num_heads))
self._attention_head_size = int(hidden_size // self._num_heads)
common_kwargs = dict(
bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activity_regularizer=self._activity_regularizer,
......@@ -140,6 +141,7 @@ class TNTransformerExpandCondense(tf.keras.layers.Layer):
dropout=self._attention_dropout_rate,
use_bias=self._use_bias,
kernel_initializer=self._attention_initializer,
bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
name="self_attention",
**common_kwargs)
self._attention_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
......
......@@ -18,6 +18,7 @@
import gin
import tensorflow as tf
from official.modeling import tf_utils
from official.nlp.modeling.layers import attention
from official.nlp.modeling.layers import multi_channel_attention
from official.nlp.modeling.layers import transformer_encoder_block
......@@ -226,7 +227,8 @@ class TransformerDecoderBlock(tf.keras.layers.Layer):
self._attention_initializer = tf.keras.initializers.get(
attention_initializer)
else:
self._attention_initializer = self._kernel_initializer
self._attention_initializer = tf_utils.clone_initializer(
self._kernel_initializer)
if self.multi_channel_cross_attention:
self._cross_attention_cls = multi_channel_attention.MultiChannelAttention
else:
......@@ -244,7 +246,6 @@ class TransformerDecoderBlock(tf.keras.layers.Layer):
"heads (%d)" % (hidden_size, self.num_attention_heads))
self.attention_head_size = int(hidden_size) // self.num_attention_heads
common_kwargs = dict(
bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activity_regularizer=self._activity_regularizer,
......@@ -256,14 +257,17 @@ class TransformerDecoderBlock(tf.keras.layers.Layer):
key_dim=self.attention_head_size,
dropout=self.attention_dropout_rate,
use_bias=self._use_bias,
kernel_initializer=self._attention_initializer,
kernel_initializer=tf_utils.clone_initializer(
self._attention_initializer),
bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
name="self_attention",
**common_kwargs)
self.self_attention_output_dense = tf.keras.layers.experimental.EinsumDense(
"abc,cd->abd",
output_shape=(None, hidden_size),
bias_axes="d",
kernel_initializer=self._kernel_initializer,
kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
name="output",
**common_kwargs)
self.self_attention_dropout = tf.keras.layers.Dropout(
......@@ -281,7 +285,9 @@ class TransformerDecoderBlock(tf.keras.layers.Layer):
dropout=self.attention_dropout_rate,
output_shape=hidden_size,
use_bias=self._use_bias,
kernel_initializer=self._attention_initializer,
kernel_initializer=tf_utils.clone_initializer(
self._attention_initializer),
bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
name="attention/encdec",
**common_kwargs)
......@@ -299,7 +305,8 @@ class TransformerDecoderBlock(tf.keras.layers.Layer):
"abc,cd->abd",
output_shape=(None, self.intermediate_size),
bias_axes="d",
kernel_initializer=self._kernel_initializer,
kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
name="intermediate",
**common_kwargs)
self.intermediate_activation_layer = tf.keras.layers.Activation(
......@@ -310,7 +317,8 @@ class TransformerDecoderBlock(tf.keras.layers.Layer):
"abc,cd->abd",
output_shape=(None, hidden_size),
bias_axes="d",
kernel_initializer=self._kernel_initializer,
kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
name="output",
**common_kwargs)
self.output_dropout = tf.keras.layers.Dropout(rate=self.dropout_rate)
......
......@@ -17,6 +17,7 @@
from absl import logging
import tensorflow as tf
from official.modeling import tf_utils
from official.nlp.modeling.layers import util
......@@ -156,7 +157,8 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
self._attention_initializer = tf.keras.initializers.get(
attention_initializer)
else:
self._attention_initializer = self._kernel_initializer
self._attention_initializer = tf_utils.clone_initializer(
self._kernel_initializer)
self._attention_axes = attention_axes
if self._diff_q_kv_att_layer_norm and not self._norm_first:
......@@ -188,8 +190,6 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
last_output_shape = self._output_last_dim
common_kwargs = dict(
bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activity_regularizer=self._activity_regularizer,
kernel_constraint=self._kernel_constraint,
......@@ -201,6 +201,7 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
dropout=self._attention_dropout,
use_bias=self._use_bias,
kernel_initializer=self._attention_initializer,
bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
attention_axes=self._attention_axes,
output_shape=self._output_last_dim,
name="self_attention",
......@@ -227,7 +228,8 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
einsum_equation,
output_shape=(None, self._inner_dim),
bias_axes="d",
kernel_initializer=self._kernel_initializer,
kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
name="intermediate",
**common_kwargs)
policy = tf.keras.mixed_precision.global_policy()
......@@ -245,7 +247,8 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
output_shape=(None, last_output_shape),
bias_axes="d",
name="output",
kernel_initializer=self._kernel_initializer,
kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
**common_kwargs)
self._output_dropout = tf.keras.layers.Dropout(rate=self._output_dropout)
# Use float32 in layernorm for numeric stability.
......
......@@ -19,6 +19,7 @@ from absl import logging
import gin
import tensorflow as tf
from official.modeling import tf_utils
from official.nlp.modeling.layers import attention
......@@ -127,8 +128,6 @@ class TransformerScaffold(tf.keras.layers.Layer):
self._attention_head_size = int(hidden_size // self._num_heads)
common_kwargs = dict(
kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activity_regularizer=self._activity_regularizer,
......@@ -145,6 +144,9 @@ class TransformerScaffold(tf.keras.layers.Layer):
return instance_or_cls(**config)
default_attention_cfg = {
"kernel_initializer": tf_utils.clone_initializer(
self._kernel_initializer),
"bias_initializer": tf_utils.clone_initializer(self._bias_initializer),
"num_heads": self._num_heads,
"key_dim": self._attention_head_size,
"dropout": self._attention_dropout_rate,
......@@ -158,6 +160,10 @@ class TransformerScaffold(tf.keras.layers.Layer):
if self._feedforward_cls is not None:
default_feedforward_cfg = {
"kernel_initializer": tf_utils.clone_initializer(
self._kernel_initializer),
"bias_initializer": tf_utils.clone_initializer(
self._bias_initializer),
"intermediate_size": self._intermediate_size,
"intermediate_activation": self._intermediate_activation,
"dropout": self._dropout_rate,
......@@ -189,6 +195,9 @@ class TransformerScaffold(tf.keras.layers.Layer):
output_shape=(None, self._intermediate_size),
bias_axes="d",
name="intermediate",
kernel_initializer=tf_utils.clone_initializer(
self._kernel_initializer),
bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
**common_kwargs)
policy = tf.keras.mixed_precision.global_policy()
if policy.name == "mixed_bfloat16":
......@@ -203,6 +212,9 @@ class TransformerScaffold(tf.keras.layers.Layer):
output_shape=(None, hidden_size),
bias_axes="d",
name="output",
kernel_initializer=tf_utils.clone_initializer(
self._kernel_initializer),
bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
**common_kwargs)
self._output_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
......
......@@ -18,6 +18,7 @@ from absl import logging
import tensorflow as tf
from official.modeling import tf_utils
from official.nlp.modeling.layers import relative_attention
......@@ -148,7 +149,7 @@ class TransformerXLBlock(tf.keras.layers.Layer):
value_dim=self._head_size,
dropout=self._attention_dropout_rate,
use_bias=False,
kernel_initializer=self._kernel_initializer,
kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
name="rel_attn")
self._attention_dropout = tf.keras.layers.Dropout(
rate=self._attention_dropout_rate)
......@@ -161,7 +162,7 @@ class TransformerXLBlock(tf.keras.layers.Layer):
"abc,cd->abd",
output_shape=(None, self._inner_size),
bias_axes="d",
kernel_initializer=self._kernel_initializer,
kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
name="inner")
self._inner_activation_layer = tf.keras.layers.Activation(
......@@ -173,7 +174,7 @@ class TransformerXLBlock(tf.keras.layers.Layer):
output_shape=(None, hidden_size),
bias_axes="d",
name="output",
kernel_initializer=self._kernel_initializer)
kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer))
self._output_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
self._output_layer_norm = tf.keras.layers.LayerNormalization(
name="output_layer_norm",
......@@ -398,17 +399,17 @@ class TransformerXL(tf.keras.layers.Layer):
"content_attention_bias",
shape=attention_bias_shape,
dtype=tf.float32,
initializer=self._initializer)
initializer=tf_utils.clone_initializer(self._initializer))
self.positional_attention_bias = self.add_weight(
"positional_attention_bias",
shape=attention_bias_shape,
dtype=tf.float32,
initializer=self._initializer)
initializer=tf_utils.clone_initializer(self._initializer))
self.segment_attention_bias = self.add_weight(
"segment_attention_bias",
shape=attention_bias_shape,
dtype=tf.float32,
initializer=self._initializer)
initializer=tf_utils.clone_initializer(self._initializer))
self.transformer_xl_layers = []
for i in range(self._num_layers):
......
......@@ -22,6 +22,7 @@ from absl import logging
import gin
import tensorflow as tf
from official.modeling import tf_utils
from official.nlp.modeling import layers
from official.nlp.modeling import networks
......@@ -102,7 +103,7 @@ class BertPretrainer(tf.keras.Model):
masked_lm = layers.MaskedLM(
embedding_table=embedding_table,
activation=activation,
initializer=initializer,
initializer=tf_utils.clone_initializer(initializer),
output=output,
name='cls/predictions')
lm_outputs = masked_lm(
......@@ -111,7 +112,7 @@ class BertPretrainer(tf.keras.Model):
classification = networks.Classification(
input_width=cls_output.shape[-1],
num_classes=num_classes,
initializer=initializer,
initializer=tf_utils.clone_initializer(initializer),
output=output,
name='classification')
sentence_outputs = classification(cls_output)
......
......@@ -96,21 +96,22 @@ class ElectraPretrainer(tf.keras.Model):
self.masked_lm = layers.MaskedLM(
embedding_table=generator_network.get_embedding_table(),
activation=mlm_activation,
initializer=mlm_initializer,
initializer=tf_utils.clone_initializer(mlm_initializer),
output=output_type,
name='generator_masked_lm')
self.classification = layers.ClassificationHead(
inner_dim=generator_network.get_config()['hidden_size'],
num_classes=num_classes,
initializer=mlm_initializer,
initializer=tf_utils.clone_initializer(mlm_initializer),
name='generator_classification_head')
self.discriminator_projection = tf.keras.layers.Dense(
units=discriminator_network.get_config()['hidden_size'],
activation=mlm_activation,
kernel_initializer=mlm_initializer,
kernel_initializer=tf_utils.clone_initializer(mlm_initializer),
name='discriminator_projection_head')
self.discriminator_head = tf.keras.layers.Dense(
units=1, kernel_initializer=mlm_initializer)
units=1,
kernel_initializer=tf_utils.clone_initializer(mlm_initializer))
def call(self, inputs):
"""ELECTRA forward pass.
......