Unverified Commit aef943ed authored by SunJong Park, committed by GitHub

Merge branch 'tensorflow:master' into master

parents 67ad909d 930abe21
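The recurring change in this commit is to stop passing one shared initializer object directly into sublayers and weights, and instead pass `tf_utils.clone_initializer(...)` so every consumer gets its own instance. As a rough, hedged sketch of what that helper is assumed to do (inferred from how it is used in this diff, not copied from `official/modeling/tf_utils.py`):

```python
import tensorflow as tf


def clone_initializer(initializer):
  """Sketch of the assumed behavior of tf_utils.clone_initializer.

  Keras initializers are effectively stateless, so one shared instance can
  produce identical values for identically shaped weights. Rebuilding the
  initializer from its config yields an independent instance with the same
  configuration.
  """
  if isinstance(initializer, tf.keras.initializers.Initializer):
    return initializer.__class__.from_config(initializer.get_config())
  # Strings and dicts are resolved to fresh instances by Keras anyway.
  return initializer
```

This is also why the `attention_initializer` fallback below becomes a clone of `kernel_initializer`: the attention projections should not share the exact initializer instance with the feed-forward blocks.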
@@ -14,6 +14,8 @@
 """Keras-based TransformerEncoder block layer."""
 import tensorflow as tf
 
+from official.modeling import tf_utils
+
 from official.nlp.modeling.layers import reuse_attention as attention
@@ -131,7 +133,8 @@ class ReuseTransformer(tf.keras.layers.Layer):
       self._attention_initializer = tf.keras.initializers.get(
           attention_initializer)
     else:
-      self._attention_initializer = self._kernel_initializer
+      self._attention_initializer = tf_utils.clone_initializer(
+          self._kernel_initializer)
     self._attention_axes = attention_axes
 
   def build(self, input_shape):
@@ -156,7 +159,6 @@ class ReuseTransformer(tf.keras.layers.Layer):
     else:
       self._attention_head_size = self._head_size
     common_kwargs = dict(
-        bias_initializer=self._bias_initializer,
         kernel_regularizer=self._kernel_regularizer,
         bias_regularizer=self._bias_regularizer,
         activity_regularizer=self._activity_regularizer,
@@ -168,6 +170,7 @@ class ReuseTransformer(tf.keras.layers.Layer):
         dropout=self._attention_dropout,
         use_bias=self._use_bias,
         kernel_initializer=self._attention_initializer,
+        bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
         attention_axes=self._attention_axes,
         reuse_attention=self._reuse_attention,
         use_relative_pe=self._use_relative_pe,
@@ -188,7 +191,8 @@ class ReuseTransformer(tf.keras.layers.Layer):
         einsum_equation,
         output_shape=(None, self._inner_dim),
         bias_axes="d",
-        kernel_initializer=self._kernel_initializer,
+        kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
+        bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
         name="intermediate",
         **common_kwargs)
     policy = tf.keras.mixed_precision.global_policy()
@@ -206,7 +210,8 @@ class ReuseTransformer(tf.keras.layers.Layer):
         output_shape=(None, hidden_size),
         bias_axes="d",
         name="output",
-        kernel_initializer=self._kernel_initializer,
+        kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
+        bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
         **common_kwargs)
     self._output_dropout = tf.keras.layers.Dropout(rate=self._output_dropout)
     # Use float32 in layernorm for numeric stability.
...
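For reference, the pattern applied in this file (and repeated in the files that follow) looks roughly like this when two sublayers are built from one configured initializer; the layer sizes and names here are illustrative only, not taken from the diff:

```python
import tensorflow as tf

from official.modeling import tf_utils

# One configured initializer, no explicit seed.
kernel_init = tf.keras.initializers.TruncatedNormal(stddev=0.02)

# Each sublayer gets its own clone rather than the shared instance.
intermediate = tf.keras.layers.Dense(
    1024, kernel_initializer=tf_utils.clone_initializer(kernel_init))
output = tf.keras.layers.Dense(
    768, kernel_initializer=tf_utils.clone_initializer(kernel_init))
```

With stateless Keras initializers, handing the same instance to several layers whose kernels have the same shape can yield identical initial values; cloning keeps the configuration (distribution, stddev) while letting each layer draw independently.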
@@ -18,6 +18,7 @@
 import gin
 import tensorflow as tf
 
+from official.modeling import tf_utils
 from official.nlp.modeling.layers import util
@@ -121,8 +122,6 @@ class ReZeroTransformer(tf.keras.layers.Layer):
           "heads (%d)" % (hidden_size, self._num_heads))
     self._attention_head_size = int(hidden_size // self._num_heads)
     common_kwargs = dict(
-        kernel_initializer=self._kernel_initializer,
-        bias_initializer=self._bias_initializer,
         kernel_regularizer=self._kernel_regularizer,
         bias_regularizer=self._bias_regularizer,
         activity_regularizer=self._activity_regularizer,
@@ -133,6 +132,8 @@ class ReZeroTransformer(tf.keras.layers.Layer):
         key_dim=self._attention_head_size,
         dropout=self._attention_dropout_rate,
         name="self_attention",
+        kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
+        bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
         **common_kwargs)
     self._attention_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
     if self._use_layer_norm:
@@ -149,6 +150,8 @@ class ReZeroTransformer(tf.keras.layers.Layer):
         output_shape=(None, self._intermediate_size),
         bias_axes="d",
         name="intermediate",
+        kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
+        bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
         **common_kwargs)
     policy = tf.keras.mixed_precision.global_policy()
     if policy.name == "mixed_bfloat16":
@@ -163,6 +166,8 @@ class ReZeroTransformer(tf.keras.layers.Layer):
         output_shape=(None, hidden_size),
         bias_axes="d",
         name="output",
+        kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
+        bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
         **common_kwargs)
     self._output_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
     if self._use_layer_norm:
...
@@ -84,11 +84,10 @@ class SpectralNormalization(tf.keras.layers.Wrapper):
     self.w = self.layer.kernel
     self.w_shape = self.w.shape.as_list()
-    self.uv_initializer = tf.initializers.random_normal()
 
     self.v = self.add_weight(
         shape=(1, np.prod(self.w_shape[:-1])),
-        initializer=self.uv_initializer,
+        initializer=tf.initializers.random_normal(),
         trainable=False,
         name='v',
         dtype=self.dtype,
@@ -96,7 +95,7 @@ class SpectralNormalization(tf.keras.layers.Wrapper):
     self.u = self.add_weight(
         shape=(1, self.w_shape[-1]),
-        initializer=self.uv_initializer,
+        initializer=tf.initializers.random_normal(),
         trainable=False,
         name='u',
         dtype=self.dtype,
@@ -197,7 +196,8 @@ class SpectralNormalizationConv2D(tf.keras.layers.Wrapper):
     super(SpectralNormalizationConv2D, self).__init__(layer, **kwargs)
 
   def build(self, input_shape):
-    self.layer.build(input_shape)
+    if not self.layer.built:
+      self.layer.build(input_shape)
     self.layer.kernel._aggregation = self.aggregation  # pylint: disable=protected-access
     self._dtype = self.layer.kernel.dtype
@@ -221,11 +221,10 @@ class SpectralNormalizationConv2D(tf.keras.layers.Wrapper):
     self.in_shape = (uv_dim, in_height, in_width, in_channel)
     self.out_shape = (uv_dim, out_height, out_width, out_channel)
-    self.uv_initializer = tf.initializers.random_normal()
 
     self.v = self.add_weight(
         shape=self.in_shape,
-        initializer=self.uv_initializer,
+        initializer=tf.initializers.random_normal(),
         trainable=False,
         name='v',
         dtype=self.dtype,
@@ -233,7 +232,7 @@ class SpectralNormalizationConv2D(tf.keras.layers.Wrapper):
     self.u = self.add_weight(
         shape=self.out_shape,
-        initializer=self.uv_initializer,
+        initializer=tf.initializers.random_normal(),
         trainable=False,
         name='u',
         dtype=self.dtype,
...
@@ -66,7 +66,7 @@ class NormalizationTest(tf.test.TestCase, parameterized.TestCase):
     spectral_norm_computed = _compute_spectral_norm(normalized_kernel)
     spectral_norm_expected = self.norm_multiplier
     self.assertAllClose(
-        spectral_norm_computed, spectral_norm_expected, atol=5e-2)
+        spectral_norm_computed, spectral_norm_expected, atol=1e-1)
 
     # Test that the normalized layer is K-Lipschitz. In particular, if the layer
     # is a function f, then ||f(x1) - f(x2)||_2 <= K * ||(x1 - x2)||_2, where K
...
@@ -20,6 +20,8 @@ import string
 import gin
 import tensorflow as tf
 
+from official.modeling import tf_utils
+
 _CHR_IDX = string.ascii_lowercase
@@ -87,7 +89,7 @@ class TalkingHeadsAttention(tf.keras.layers.MultiHeadAttention):
     self._pre_softmax_weight = self.add_weight(
         "pre_softmax_weight",
         shape=(self._num_heads, self._num_heads),
-        initializer=self._kernel_initializer,
+        initializer=tf_utils.clone_initializer(self._kernel_initializer),
         regularizer=self._kernel_regularizer,
         constraint=self._kernel_constraint,
         dtype=self.dtype,
@@ -95,7 +97,7 @@ class TalkingHeadsAttention(tf.keras.layers.MultiHeadAttention):
     self._post_softmax_weight = self.add_weight(
         "post_softmax_weight",
         shape=(self._num_heads, self._num_heads),
-        initializer=self._kernel_initializer,
+        initializer=tf_utils.clone_initializer(self._kernel_initializer),
         regularizer=self._kernel_regularizer,
         constraint=self._kernel_constraint,
         dtype=self.dtype,
...
@@ -17,6 +17,8 @@
 from typing import List, Optional, Text, Any, Dict
 import tensorflow as tf
 
+from official.modeling import tf_utils
+
 Layer = tf.keras.layers.Layer
 activations = tf.keras.activations
 initializers = tf.keras.initializers
@@ -98,24 +100,24 @@ class TNExpandCondense(Layer):
         name='w1',
         shape=(input_shape[-1], input_shape[-1]),
         trainable=True,
-        initializer=self.kernel_initializer)
+        initializer=tf_utils.clone_initializer(self.kernel_initializer))
     self.w2 = self.add_weight(
         name='w2',
         shape=(128, (128 * (self.proj_size // input_shape[-1]))),
         trainable=True,
-        initializer=self.kernel_initializer)
+        initializer=tf_utils.clone_initializer(self.kernel_initializer))
     self.w3 = self.add_weight(
         name='w3',
         shape=(128 * (self.proj_size // input_shape[-1]), 128),
         trainable=True,
-        initializer=self.kernel_initializer)
+        initializer=tf_utils.clone_initializer(self.kernel_initializer))
     self.w4 = self.add_weight(
         name='w4',
         shape=(input_shape[-1] // 128, 128, input_shape[-1]),
         trainable=True,
-        initializer=self.kernel_initializer)
+        initializer=tf_utils.clone_initializer(self.kernel_initializer))
 
     if self.use_bias:
       self.bias = self.add_weight(
...
@@ -19,6 +19,7 @@
 import gin
 import tensorflow as tf
 
+from official.modeling import tf_utils
 from official.nlp.modeling.layers.tn_expand_condense import TNExpandCondense
@@ -100,7 +101,8 @@ class TNTransformerExpandCondense(tf.keras.layers.Layer):
      self._attention_initializer = tf.keras.initializers.get(
          attention_initializer)
    else:
-      self._attention_initializer = tf_utils.clone_initializer(
-          self._kernel_initializer)
+      self._attention_initializer = tf_utils.clone_initializer(
+          self._kernel_initializer)
 
  def build(self, input_shape):
    input_tensor = input_shape[0] if len(input_shape) == 2 else input_shape
@@ -128,7 +130,6 @@ class TNTransformerExpandCondense(tf.keras.layers.Layer):
          "heads (%d)" % (hidden_size, self._num_heads))
    self._attention_head_size = int(hidden_size // self._num_heads)
    common_kwargs = dict(
-        bias_initializer=self._bias_initializer,
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer,
        activity_regularizer=self._activity_regularizer,
@@ -140,6 +141,7 @@ class TNTransformerExpandCondense(tf.keras.layers.Layer):
        dropout=self._attention_dropout_rate,
        use_bias=self._use_bias,
        kernel_initializer=self._attention_initializer,
+        bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
        name="self_attention",
        **common_kwargs)
    self._attention_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
...
@@ -18,6 +18,7 @@
 import gin
 import tensorflow as tf
 
+from official.modeling import tf_utils
 from official.nlp.modeling.layers import attention
 from official.nlp.modeling.layers import multi_channel_attention
 from official.nlp.modeling.layers import transformer_encoder_block
@@ -226,7 +227,8 @@ class TransformerDecoderBlock(tf.keras.layers.Layer):
       self._attention_initializer = tf.keras.initializers.get(
           attention_initializer)
     else:
-      self._attention_initializer = self._kernel_initializer
+      self._attention_initializer = tf_utils.clone_initializer(
+          self._kernel_initializer)
     if self.multi_channel_cross_attention:
       self._cross_attention_cls = multi_channel_attention.MultiChannelAttention
     else:
@@ -244,7 +246,6 @@ class TransformerDecoderBlock(tf.keras.layers.Layer):
           "heads (%d)" % (hidden_size, self.num_attention_heads))
     self.attention_head_size = int(hidden_size) // self.num_attention_heads
     common_kwargs = dict(
-        bias_initializer=self._bias_initializer,
         kernel_regularizer=self._kernel_regularizer,
         bias_regularizer=self._bias_regularizer,
         activity_regularizer=self._activity_regularizer,
@@ -256,14 +257,17 @@ class TransformerDecoderBlock(tf.keras.layers.Layer):
         key_dim=self.attention_head_size,
         dropout=self.attention_dropout_rate,
         use_bias=self._use_bias,
-        kernel_initializer=self._attention_initializer,
+        kernel_initializer=tf_utils.clone_initializer(
+            self._attention_initializer),
+        bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
         name="self_attention",
         **common_kwargs)
     self.self_attention_output_dense = tf.keras.layers.experimental.EinsumDense(
         "abc,cd->abd",
         output_shape=(None, hidden_size),
         bias_axes="d",
-        kernel_initializer=self._kernel_initializer,
+        kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
+        bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
         name="output",
         **common_kwargs)
     self.self_attention_dropout = tf.keras.layers.Dropout(
@@ -281,7 +285,9 @@ class TransformerDecoderBlock(tf.keras.layers.Layer):
         dropout=self.attention_dropout_rate,
         output_shape=hidden_size,
         use_bias=self._use_bias,
-        kernel_initializer=self._attention_initializer,
+        kernel_initializer=tf_utils.clone_initializer(
+            self._attention_initializer),
+        bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
         name="attention/encdec",
         **common_kwargs)
@@ -299,7 +305,8 @@ class TransformerDecoderBlock(tf.keras.layers.Layer):
         "abc,cd->abd",
         output_shape=(None, self.intermediate_size),
         bias_axes="d",
-        kernel_initializer=self._kernel_initializer,
+        kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
+        bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
         name="intermediate",
         **common_kwargs)
     self.intermediate_activation_layer = tf.keras.layers.Activation(
@@ -310,7 +317,8 @@ class TransformerDecoderBlock(tf.keras.layers.Layer):
         "abc,cd->abd",
         output_shape=(None, hidden_size),
         bias_axes="d",
-        kernel_initializer=self._kernel_initializer,
+        kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
+        bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
         name="output",
         **common_kwargs)
     self.output_dropout = tf.keras.layers.Dropout(rate=self.dropout_rate)
...
@@ -17,6 +17,7 @@
 from absl import logging
 import tensorflow as tf
 
+from official.modeling import tf_utils
 from official.nlp.modeling.layers import util
@@ -156,7 +157,8 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
       self._attention_initializer = tf.keras.initializers.get(
           attention_initializer)
     else:
-      self._attention_initializer = self._kernel_initializer
+      self._attention_initializer = tf_utils.clone_initializer(
+          self._kernel_initializer)
     self._attention_axes = attention_axes
 
     if self._diff_q_kv_att_layer_norm and not self._norm_first:
@@ -188,8 +190,6 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
       last_output_shape = self._output_last_dim
 
     common_kwargs = dict(
-        bias_initializer=self._bias_initializer,
-        kernel_regularizer=self._kernel_regularizer,
         bias_regularizer=self._bias_regularizer,
         activity_regularizer=self._activity_regularizer,
         kernel_constraint=self._kernel_constraint,
@@ -201,6 +201,7 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
         dropout=self._attention_dropout,
         use_bias=self._use_bias,
         kernel_initializer=self._attention_initializer,
+        bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
         attention_axes=self._attention_axes,
         output_shape=self._output_last_dim,
         name="self_attention",
@@ -227,7 +228,8 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
         einsum_equation,
         output_shape=(None, self._inner_dim),
         bias_axes="d",
-        kernel_initializer=self._kernel_initializer,
+        kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
+        bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
         name="intermediate",
         **common_kwargs)
     policy = tf.keras.mixed_precision.global_policy()
@@ -245,7 +247,8 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
         output_shape=(None, last_output_shape),
         bias_axes="d",
         name="output",
-        kernel_initializer=self._kernel_initializer,
+        kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
+        bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
         **common_kwargs)
     self._output_dropout = tf.keras.layers.Dropout(rate=self._output_dropout)
     # Use float32 in layernorm for numeric stability.
...
@@ -19,6 +19,7 @@ from absl import logging
 import gin
 import tensorflow as tf
 
+from official.modeling import tf_utils
 from official.nlp.modeling.layers import attention
@@ -127,8 +128,6 @@ class TransformerScaffold(tf.keras.layers.Layer):
     self._attention_head_size = int(hidden_size // self._num_heads)
     common_kwargs = dict(
-        kernel_initializer=self._kernel_initializer,
-        bias_initializer=self._bias_initializer,
         kernel_regularizer=self._kernel_regularizer,
         bias_regularizer=self._bias_regularizer,
         activity_regularizer=self._activity_regularizer,
@@ -145,6 +144,9 @@ class TransformerScaffold(tf.keras.layers.Layer):
         return instance_or_cls(**config)
 
     default_attention_cfg = {
+        "kernel_initializer": tf_utils.clone_initializer(
+            self._kernel_initializer),
+        "bias_initializer": tf_utils.clone_initializer(self._bias_initializer),
         "num_heads": self._num_heads,
         "key_dim": self._attention_head_size,
         "dropout": self._attention_dropout_rate,
@@ -158,6 +160,10 @@ class TransformerScaffold(tf.keras.layers.Layer):
     if self._feedforward_cls is not None:
       default_feedforward_cfg = {
+          "kernel_initializer": tf_utils.clone_initializer(
+              self._kernel_initializer),
+          "bias_initializer": tf_utils.clone_initializer(
+              self._bias_initializer),
          "intermediate_size": self._intermediate_size,
          "intermediate_activation": self._intermediate_activation,
          "dropout": self._dropout_rate,
@@ -189,6 +195,9 @@ class TransformerScaffold(tf.keras.layers.Layer):
         output_shape=(None, self._intermediate_size),
         bias_axes="d",
         name="intermediate",
+        kernel_initializer=tf_utils.clone_initializer(
+            self._kernel_initializer),
+        bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
         **common_kwargs)
     policy = tf.keras.mixed_precision.global_policy()
     if policy.name == "mixed_bfloat16":
@@ -203,6 +212,9 @@ class TransformerScaffold(tf.keras.layers.Layer):
         output_shape=(None, hidden_size),
         bias_axes="d",
         name="output",
+        kernel_initializer=tf_utils.clone_initializer(
+            self._kernel_initializer),
+        bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
         **common_kwargs)
     self._output_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
...
@@ -18,6 +18,7 @@ from absl import logging
 import tensorflow as tf
 
+from official.modeling import tf_utils
 from official.nlp.modeling.layers import relative_attention
@@ -148,7 +149,7 @@ class TransformerXLBlock(tf.keras.layers.Layer):
         value_dim=self._head_size,
         dropout=self._attention_dropout_rate,
         use_bias=False,
-        kernel_initializer=self._kernel_initializer,
+        kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
         name="rel_attn")
     self._attention_dropout = tf.keras.layers.Dropout(
         rate=self._attention_dropout_rate)
@@ -161,7 +162,7 @@ class TransformerXLBlock(tf.keras.layers.Layer):
         "abc,cd->abd",
         output_shape=(None, self._inner_size),
         bias_axes="d",
-        kernel_initializer=self._kernel_initializer,
+        kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
         name="inner")
     self._inner_activation_layer = tf.keras.layers.Activation(
@@ -173,7 +174,7 @@ class TransformerXLBlock(tf.keras.layers.Layer):
         output_shape=(None, hidden_size),
         bias_axes="d",
         name="output",
-        kernel_initializer=self._kernel_initializer)
+        kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer))
     self._output_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
     self._output_layer_norm = tf.keras.layers.LayerNormalization(
         name="output_layer_norm",
@@ -398,17 +399,17 @@ class TransformerXL(tf.keras.layers.Layer):
         "content_attention_bias",
         shape=attention_bias_shape,
         dtype=tf.float32,
-        initializer=self._initializer)
+        initializer=tf_utils.clone_initializer(self._initializer))
     self.positional_attention_bias = self.add_weight(
         "positional_attention_bias",
         shape=attention_bias_shape,
         dtype=tf.float32,
-        initializer=self._initializer)
+        initializer=tf_utils.clone_initializer(self._initializer))
     self.segment_attention_bias = self.add_weight(
         "segment_attention_bias",
         shape=attention_bias_shape,
         dtype=tf.float32,
-        initializer=self._initializer)
+        initializer=tf_utils.clone_initializer(self._initializer))
 
     self.transformer_xl_layers = []
     for i in range(self._num_layers):
...
@@ -22,6 +22,7 @@ from absl import logging
 import gin
 import tensorflow as tf
 
+from official.modeling import tf_utils
 from official.nlp.modeling import layers
 from official.nlp.modeling import networks
@@ -102,7 +103,7 @@ class BertPretrainer(tf.keras.Model):
     masked_lm = layers.MaskedLM(
         embedding_table=embedding_table,
         activation=activation,
-        initializer=initializer,
+        initializer=tf_utils.clone_initializer(initializer),
         output=output,
         name='cls/predictions')
     lm_outputs = masked_lm(
@@ -111,7 +112,7 @@ class BertPretrainer(tf.keras.Model):
     classification = networks.Classification(
         input_width=cls_output.shape[-1],
         num_classes=num_classes,
-        initializer=initializer,
+        initializer=tf_utils.clone_initializer(initializer),
         output=output,
         name='classification')
     sentence_outputs = classification(cls_output)
...
@@ -96,21 +96,22 @@ class ElectraPretrainer(tf.keras.Model):
     self.masked_lm = layers.MaskedLM(
         embedding_table=generator_network.get_embedding_table(),
         activation=mlm_activation,
-        initializer=mlm_initializer,
+        initializer=tf_utils.clone_initializer(mlm_initializer),
         output=output_type,
         name='generator_masked_lm')
     self.classification = layers.ClassificationHead(
         inner_dim=generator_network.get_config()['hidden_size'],
         num_classes=num_classes,
-        initializer=mlm_initializer,
+        initializer=tf_utils.clone_initializer(mlm_initializer),
         name='generator_classification_head')
     self.discriminator_projection = tf.keras.layers.Dense(
         units=discriminator_network.get_config()['hidden_size'],
         activation=mlm_activation,
-        kernel_initializer=mlm_initializer,
+        kernel_initializer=tf_utils.clone_initializer(mlm_initializer),
         name='discriminator_projection_head')
     self.discriminator_head = tf.keras.layers.Dense(
-        units=1, kernel_initializer=mlm_initializer)
+        units=1,
+        kernel_initializer=tf_utils.clone_initializer(mlm_initializer))
 
   def call(self, inputs):
     """ELECTRA forward pass.
...
@@ -55,6 +55,7 @@ class Module(tf.Module):
                       initializer: Initializer,
                       dtype: tf.DType = tf.float32,
                       **kwargs):
+    initializer = tf_utils.clone_initializer(initializer)
     return tf.Variable(initializer(shape, dtype=dtype, **kwargs), name=name)
 
   def read_variable(self,
@@ -588,7 +589,8 @@ class MultiHeadAttention(Module):
     init_std_rescaling = tf.math.sqrt(tf.cast(self.d_kv, dtype=self.dtype))
     query_w_init = (
         lambda *args, **kwargs: (  # pylint: disable=g-long-lambda
-            weight_initializer(*args, **kwargs) / init_std_rescaling))
+            tf_utils.clone_initializer(weight_initializer)(
+                *args, **kwargs) / init_std_rescaling))
     self.q = Linear3D(
         self.d_model,
         self.d_kv,
...
@@ -18,6 +18,7 @@ import collections
 import tensorflow as tf
 
 from official.modeling import activations
+from official.modeling import tf_utils
 from official.nlp.modeling import layers
@@ -92,13 +93,13 @@ class AlbertEncoder(tf.keras.Model):
     embedding_layer = layers.OnDeviceEmbedding(
         vocab_size=vocab_size,
         embedding_width=embedding_width,
-        initializer=initializer,
+        initializer=tf_utils.clone_initializer(initializer),
         name='word_embeddings')
     word_embeddings = embedding_layer(word_ids)
 
     # Always uses dynamic slicing for simplicity.
     position_embedding_layer = layers.PositionEmbedding(
-        initializer=initializer,
+        initializer=tf_utils.clone_initializer(initializer),
         max_length=max_sequence_length,
         name='position_embedding')
     position_embeddings = position_embedding_layer(word_embeddings)
@@ -107,7 +108,7 @@ class AlbertEncoder(tf.keras.Model):
         layers.OnDeviceEmbedding(
             vocab_size=type_vocab_size,
             embedding_width=embedding_width,
-            initializer=initializer,
+            initializer=tf_utils.clone_initializer(initializer),
             use_one_hot=True,
             name='type_embeddings')(type_ids))
@@ -127,7 +128,7 @@ class AlbertEncoder(tf.keras.Model):
         '...x,xy->...y',
         output_shape=hidden_size,
         bias_axes='y',
-        kernel_initializer=initializer,
+        kernel_initializer=tf_utils.clone_initializer(initializer),
         name='embedding_projection')(
             embeddings)
@@ -139,7 +140,7 @@ class AlbertEncoder(tf.keras.Model):
         inner_activation=activation,
         output_dropout=dropout_rate,
         attention_dropout=attention_dropout_rate,
-        kernel_initializer=initializer,
+        kernel_initializer=tf_utils.clone_initializer(initializer),
         name='transformer')
     encoder_outputs = []
     for _ in range(num_layers):
@@ -153,7 +154,7 @@ class AlbertEncoder(tf.keras.Model):
     cls_output = tf.keras.layers.Dense(
         units=hidden_size,
         activation='tanh',
-        kernel_initializer=initializer,
+        kernel_initializer=tf_utils.clone_initializer(initializer),
         name='pooler_transform')(
             first_token_tensor)
     if dict_outputs:
...
@@ -19,6 +19,7 @@ from typing import Any, Callable, Optional, Union
 from absl import logging
 import tensorflow as tf
 
+from official.modeling import tf_utils
 from official.nlp.modeling import layers
@@ -122,20 +123,20 @@ class BertEncoderV2(tf.keras.layers.Layer):
       self._embedding_layer = layers.OnDeviceEmbedding(
           vocab_size=vocab_size,
           embedding_width=embedding_width,
-          initializer=initializer,
+          initializer=tf_utils.clone_initializer(initializer),
           name='word_embeddings')
     else:
       self._embedding_layer = embedding_layer
 
     self._position_embedding_layer = layers.PositionEmbedding(
-        initializer=initializer,
+        initializer=tf_utils.clone_initializer(initializer),
         max_length=max_sequence_length,
         name='position_embedding')
 
     self._type_embedding_layer = layers.OnDeviceEmbedding(
         vocab_size=type_vocab_size,
         embedding_width=embedding_width,
-        initializer=initializer,
+        initializer=tf_utils.clone_initializer(initializer),
         use_one_hot=True,
         name='type_embeddings')
@@ -153,7 +154,7 @@ class BertEncoderV2(tf.keras.layers.Layer):
          '...x,xy->...y',
          output_shape=hidden_size,
          bias_axes='y',
-          kernel_initializer=initializer,
+          kernel_initializer=tf_utils.clone_initializer(initializer),
          name='embedding_projection')
 
     self._transformer_layers = []
@@ -168,14 +169,14 @@ class BertEncoderV2(tf.keras.layers.Layer):
          attention_dropout=attention_dropout,
          norm_first=norm_first,
          output_range=output_range if i == num_layers - 1 else None,
-          kernel_initializer=initializer,
+          kernel_initializer=tf_utils.clone_initializer(initializer),
          name='transformer/layer_%d' % i)
      self._transformer_layers.append(layer)
 
     self._pooler_layer = tf.keras.layers.Dense(
         units=hidden_size,
         activation='tanh',
-        kernel_initializer=initializer,
+        kernel_initializer=tf_utils.clone_initializer(initializer),
         name='pooler_transform')
 
     self._config = {
@@ -409,7 +410,7 @@ class BertEncoder(tf.keras.Model):
      embedding_layer_inst = layers.OnDeviceEmbedding(
          vocab_size=vocab_size,
          embedding_width=embedding_width,
-          initializer=initializer,
+          initializer=tf_utils.clone_initializer(initializer),
          name='word_embeddings')
    else:
      embedding_layer_inst = embedding_layer
@@ -417,14 +418,14 @@ class BertEncoder(tf.keras.Model):
    # Always uses dynamic slicing for simplicity.
    position_embedding_layer = layers.PositionEmbedding(
-        initializer=initializer,
+        initializer=tf_utils.clone_initializer(initializer),
        max_length=max_sequence_length,
        name='position_embedding')
    position_embeddings = position_embedding_layer(word_embeddings)
    type_embedding_layer = layers.OnDeviceEmbedding(
        vocab_size=type_vocab_size,
        embedding_width=embedding_width,
-        initializer=initializer,
+        initializer=tf_utils.clone_initializer(initializer),
        use_one_hot=True,
        name='type_embeddings')
    type_embeddings = type_embedding_layer(type_ids)
@@ -445,7 +446,7 @@ class BertEncoder(tf.keras.Model):
          '...x,xy->...y',
          output_shape=hidden_size,
          bias_axes='y',
-          kernel_initializer=initializer,
+          kernel_initializer=tf_utils.clone_initializer(initializer),
          name='embedding_projection')
      embeddings = embedding_projection(embeddings)
    else:
@@ -468,7 +469,7 @@ class BertEncoder(tf.keras.Model):
          attention_dropout=attention_dropout,
          norm_first=norm_first,
          output_range=transformer_output_range,
-          kernel_initializer=initializer,
+          kernel_initializer=tf_utils.clone_initializer(initializer),
          name='transformer/layer_%d' % i)
      transformer_layers.append(layer)
      data = layer([data, attention_mask])
@@ -482,7 +483,7 @@ class BertEncoder(tf.keras.Model):
    pooler_layer = tf.keras.layers.Dense(
        units=hidden_size,
        activation='tanh',
-        kernel_initializer=initializer,
+        kernel_initializer=tf_utils.clone_initializer(initializer),
        name='pooler_transform')
    cls_output = pooler_layer(first_token_tensor)
...
@@ -21,6 +21,7 @@ from absl import logging
 import gin
 import tensorflow as tf
 
+from official.modeling import tf_utils
 from official.nlp.modeling import layers
@@ -153,14 +154,14 @@ class EncoderScaffold(tf.keras.Model):
      embedding_layer = layers.OnDeviceEmbedding(
          vocab_size=embedding_cfg['vocab_size'],
          embedding_width=embedding_cfg['hidden_size'],
-          initializer=embedding_cfg['initializer'],
+          initializer=tf_utils.clone_initializer(embedding_cfg['initializer']),
          name='word_embeddings')
      word_embeddings = embedding_layer(word_ids)
 
      # Always uses dynamic slicing for simplicity.
      position_embedding_layer = layers.PositionEmbedding(
-          initializer=embedding_cfg['initializer'],
+          initializer=tf_utils.clone_initializer(embedding_cfg['initializer']),
          max_length=embedding_cfg['max_seq_length'],
          name='position_embedding')
      position_embeddings = position_embedding_layer(word_embeddings)
@@ -168,7 +169,7 @@ class EncoderScaffold(tf.keras.Model):
      type_embedding_layer = layers.OnDeviceEmbedding(
          vocab_size=embedding_cfg['type_vocab_size'],
          embedding_width=embedding_cfg['hidden_size'],
-          initializer=embedding_cfg['initializer'],
+          initializer=tf_utils.clone_initializer(embedding_cfg['initializer']),
          use_one_hot=True,
          name='type_embeddings')
      type_embeddings = type_embedding_layer(type_ids)
@@ -243,6 +244,8 @@ class EncoderScaffold(tf.keras.Model):
    # like this will create a SliceOpLambda layer. This is better than a Lambda
    # layer with Python code, because that is fundamentally less portable.
    first_token_tensor = last_layer_output[:, 0, :]
+    pooler_layer_initializer = tf.keras.initializers.get(
+        pooler_layer_initializer)
    pooler_layer = tf.keras.layers.Dense(
        units=pooled_output_dim,
        activation='tanh',
@@ -303,7 +306,8 @@ class EncoderScaffold(tf.keras.Model):
    config_dict = {
        'num_hidden_instances': self._num_hidden_instances,
        'pooled_output_dim': self._pooled_output_dim,
-        'pooler_layer_initializer': self._pooler_layer_initializer,
+        'pooler_layer_initializer': tf.keras.initializers.serialize(
+            self._pooler_layer_initializer),
        'embedding_cls': self._embedding_network,
        'embedding_cfg': self._embedding_cfg,
        'layer_norm_before_pooling': self._layer_norm_before_pooling,
...
@@ -20,6 +20,7 @@ from absl import logging
 import numpy as np
 import tensorflow as tf
 
+from official.modeling import tf_utils
 from official.nlp.modeling import layers
 
 _Initializer = Union[str, tf.keras.initializers.Initializer]
@@ -265,20 +266,20 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
      self._embedding_layer = layers.OnDeviceEmbedding(
          vocab_size=vocab_size,
          embedding_width=embedding_width,
-          initializer=initializer,
+          initializer=tf_utils.clone_initializer(initializer),
          name='word_embeddings')
    else:
      self._embedding_layer = embedding_layer
 
    self._position_embedding_layer = layers.PositionEmbedding(
-        initializer=initializer,
+        initializer=tf_utils.clone_initializer(initializer),
        max_length=max_sequence_length,
        name='position_embedding')
 
    self._type_embedding_layer = layers.OnDeviceEmbedding(
        vocab_size=type_vocab_size,
        embedding_width=embedding_width,
-        initializer=initializer,
+        initializer=tf_utils.clone_initializer(initializer),
        use_one_hot=True,
        name='type_embeddings')
@@ -296,7 +297,7 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
          '...x,xy->...y',
          output_shape=hidden_size,
          bias_axes='y',
-          kernel_initializer=initializer,
+          kernel_initializer=tf_utils.clone_initializer(initializer),
          name='embedding_projection')
 
    self._transformer_layers = []
@@ -316,7 +317,7 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
          attention_dropout=attention_dropout,
          norm_first=norm_first,
          output_range=output_range if i == num_layers - 1 else None,
-          kernel_initializer=initializer,
+          kernel_initializer=tf_utils.clone_initializer(initializer),
          share_rezero=share_rezero,
          name='transformer/layer_%d' % i)
      self._transformer_layers.append(layer)
@@ -324,7 +325,7 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
    self._pooler_layer = tf.keras.layers.Dense(
        units=hidden_size,
        activation='tanh',
-        kernel_initializer=initializer,
+        kernel_initializer=tf_utils.clone_initializer(initializer),
        name='pooler_transform')
    if isinstance(pool_stride, int):
      # TODO(b/197133196): Pooling layer can be shared.
@@ -342,9 +343,6 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
      # TODO(b/203665205): unpool_length should be implemented.
      if unpool_length != 0:
        raise ValueError('unpool_length is not supported by truncated_avg now.')
-      # Compute the attention masks and pooling transforms.
-      self._pooling_transforms = _create_truncated_avg_transforms(
-          max_sequence_length, pool_strides)
    else:
      raise ValueError('pool_type not supported.')
@@ -358,6 +356,7 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
            name='att_input_pool_layer')
        self._att_input_pool_layers.append(att_input_pool_layer)
 
+    self._max_sequence_length = max_sequence_length
    self._pool_strides = pool_strides  # This is a list here.
    self._unpool_length = unpool_length
    self._pool_type = pool_type
@@ -488,8 +487,13 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
          axes=[1, 2])
      encoder_outputs.append(x)
    elif self._pool_type == _TRUNCATED_AVG:
+      # Compute the attention masks and pooling transforms.
+      # Note we do not compute this in __init__ due to inference converter issue
+      # b/215659399.
+      pooling_transforms = _create_truncated_avg_transforms(
+          self._max_sequence_length, self._pool_strides)
      attention_masks = _create_truncated_avg_masks(mask, self._pool_strides,
-                                                    self._pooling_transforms)
+                                                    pooling_transforms)
      for i, layer in enumerate(self._transformer_layers):
        attention_mask = attention_masks[i]
        # Bypass no pooling cases.
@@ -500,7 +504,7 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
            'BFD,FT->BTD',
            tf.cast(x[:, self._unpool_length:, :], _get_policy_dtype()
                   ),  # extra casting for faster mixed computation.
-            self._pooling_transforms[i])
+            pooling_transforms[i])
        query_inputs = tf.concat(
            values=(tf.cast(
                x[:, :self._unpool_length, :],
...
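The FunnelTransformerEncoder hunks above also move construction of the constant truncated-average pooling transforms out of `__init__` and into `call`, keeping only `max_sequence_length` and `pool_strides` as attributes (per the b/215659399 note). A minimal sketch of that deferral pattern, using a hypothetical `_make_transform` in place of `_create_truncated_avg_transforms`:

```python
import tensorflow as tf


def _make_transform(seq_len: int, stride: int) -> tf.Tensor:
  """Hypothetical stand-in for _create_truncated_avg_transforms."""
  out_len = seq_len // stride
  # Average each non-overlapping window of `stride` positions.
  return tf.repeat(tf.eye(out_len), repeats=stride, axis=0) / stride


class TruncatedAvgPool(tf.keras.layers.Layer):
  """Illustrative layer: the constant transform is built in call, not __init__."""

  def __init__(self, max_sequence_length: int, pool_stride: int, **kwargs):
    super().__init__(**kwargs)
    # Store only the configuration; no constant tensors are created here.
    self._max_sequence_length = max_sequence_length
    self._pool_stride = pool_stride

  def call(self, x):
    # x: (batch, seq_len, dim) with seq_len == max_sequence_length.
    transform = _make_transform(self._max_sequence_length, self._pool_stride)
    return tf.einsum('BFD,FT->BTD', x, transform)
```

Building the constant inside `call` trades a little recomputation for a layer whose `__init__` carries no captured tensors, which is the property the diff's comment is after.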
@@ -97,13 +97,13 @@ class PackedSequenceEmbedding(tf.keras.Model):
    embedding_layer = layers.OnDeviceEmbedding(
        vocab_size=vocab_size,
        embedding_width=embedding_width,
-        initializer=initializer,
+        initializer=tf_utils.clone_initializer(initializer),
        name='word_embeddings')
    word_embeddings = embedding_layer(word_ids)
 
    # Always uses dynamic slicing for simplicity.
    position_embedding_layer = PositionEmbeddingWithSubSeqMask(
-        initializer=initializer,
+        initializer=tf_utils.clone_initializer(initializer),
        use_dynamic_slicing=True,
        max_sequence_length=max_seq_length,
        name='position_embedding')
@@ -114,7 +114,7 @@ class PackedSequenceEmbedding(tf.keras.Model):
        layers.OnDeviceEmbedding(
            vocab_size=type_vocab_size,
            embedding_width=embedding_width,
-            initializer=initializer,
+            initializer=tf_utils.clone_initializer(initializer),
            use_one_hot=True,
            name='type_embeddings')(type_ids))
@@ -132,7 +132,7 @@ class PackedSequenceEmbedding(tf.keras.Model):
        '...x,xy->...y',
        output_shape=hidden_size,
        bias_axes=None,
-        kernel_initializer=initializer,
+        kernel_initializer=tf_utils.clone_initializer(initializer),
        name='embedding_projection')(
            embeddings)
...
@@ -17,6 +17,8 @@
 import collections
 import tensorflow as tf
 
+from official.modeling import tf_utils
+
 
 def _apply_paragraph_mask(logits, paragraph_mask):
   """Applies a position mask to calculated logits."""
@@ -156,12 +158,12 @@ class XLNetSpanLabeling(tf.keras.layers.Layer):
    self._end_n_top = end_n_top
    self.start_logits_dense = tf.keras.layers.Dense(
        units=1,
-        kernel_initializer=initializer,
+        kernel_initializer=tf_utils.clone_initializer(initializer),
        name='predictions/transform/start_logits')
 
    self.end_logits_inner_dense = tf.keras.layers.Dense(
        units=input_width,
-        kernel_initializer=initializer,
+        kernel_initializer=tf_utils.clone_initializer(initializer),
        activation=activation,
        name='predictions/transform/end_logits/inner')
    self.end_logits_layer_norm = tf.keras.layers.LayerNormalization(
@@ -169,18 +171,18 @@ class XLNetSpanLabeling(tf.keras.layers.Layer):
        name='predictions/transform/end_logits/layernorm')
    self.end_logits_output_dense = tf.keras.layers.Dense(
        units=1,
-        kernel_initializer=initializer,
+        kernel_initializer=tf_utils.clone_initializer(initializer),
        name='predictions/transform/end_logits/output')
 
    self.answer_logits_inner = tf.keras.layers.Dense(
        units=input_width,
-        kernel_initializer=initializer,
+        kernel_initializer=tf_utils.clone_initializer(initializer),
        activation=activation,
        name='predictions/transform/answer_logits/inner')
    self.answer_logits_dropout = tf.keras.layers.Dropout(rate=dropout_rate)
    self.answer_logits_output = tf.keras.layers.Dense(
        units=1,
-        kernel_initializer=initializer,
+        kernel_initializer=tf_utils.clone_initializer(initializer),
        use_bias=False,
        name='predictions/transform/answer_logits/output')
...