Unverified commit aef943ed authored by SunJong Park, committed by GitHub

Merge branch 'tensorflow:master' into master

parents 67ad909d 930abe21
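
The hunks below consistently replace direct reuse of a shared initializer object (e.g. `kernel_initializer=self._kernel_initializer`) with `tf_utils.clone_initializer(self._kernel_initializer)`, so every sub-layer gets its own initializer instance instead of sharing one object. A minimal sketch of that pattern follows; it is an illustrative re-implementation only, not the exact `official.modeling.tf_utils` helper used in the diff.

```python
import tensorflow as tf


def clone_initializer(initializer):
  """Returns an independent copy of a Keras initializer.

  Illustrative sketch of the pattern; the model garden code imports the
  real helper from official.modeling.tf_utils. Strings and other
  non-Initializer values pass through unchanged.
  """
  if isinstance(initializer, tf.keras.initializers.Initializer):
    return initializer.__class__.from_config(initializer.get_config())
  return initializer


# Each layer receives its own initializer object instead of sharing one.
shared = tf.keras.initializers.TruncatedNormal(stddev=0.02)
dense_a = tf.keras.layers.Dense(64, kernel_initializer=clone_initializer(shared))
dense_b = tf.keras.layers.Dense(64, kernel_initializer=clone_initializer(shared))
```
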
......@@ -14,6 +14,8 @@
"""Keras-based TransformerEncoder block layer."""
import tensorflow as tf
from official.modeling import tf_utils
from official.nlp.modeling.layers import reuse_attention as attention
......@@ -131,7 +133,8 @@ class ReuseTransformer(tf.keras.layers.Layer):
self._attention_initializer = tf.keras.initializers.get(
attention_initializer)
else:
self._attention_initializer = self._kernel_initializer
self._attention_initializer = tf_utils.clone_initializer(
self._kernel_initializer)
self._attention_axes = attention_axes
def build(self, input_shape):
......@@ -156,7 +159,6 @@ class ReuseTransformer(tf.keras.layers.Layer):
else:
self._attention_head_size = self._head_size
common_kwargs = dict(
bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activity_regularizer=self._activity_regularizer,
......@@ -168,6 +170,7 @@ class ReuseTransformer(tf.keras.layers.Layer):
dropout=self._attention_dropout,
use_bias=self._use_bias,
kernel_initializer=self._attention_initializer,
bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
attention_axes=self._attention_axes,
reuse_attention=self._reuse_attention,
use_relative_pe=self._use_relative_pe,
......@@ -188,7 +191,8 @@ class ReuseTransformer(tf.keras.layers.Layer):
einsum_equation,
output_shape=(None, self._inner_dim),
bias_axes="d",
kernel_initializer=self._kernel_initializer,
kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
name="intermediate",
**common_kwargs)
policy = tf.keras.mixed_precision.global_policy()
......@@ -206,7 +210,8 @@ class ReuseTransformer(tf.keras.layers.Layer):
output_shape=(None, hidden_size),
bias_axes="d",
name="output",
kernel_initializer=self._kernel_initializer,
kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
**common_kwargs)
self._output_dropout = tf.keras.layers.Dropout(rate=self._output_dropout)
# Use float32 in layernorm for numeric stability.
......
......@@ -18,6 +18,7 @@
import gin
import tensorflow as tf
from official.modeling import tf_utils
from official.nlp.modeling.layers import util
......@@ -121,8 +122,6 @@ class ReZeroTransformer(tf.keras.layers.Layer):
"heads (%d)" % (hidden_size, self._num_heads))
self._attention_head_size = int(hidden_size // self._num_heads)
common_kwargs = dict(
kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activity_regularizer=self._activity_regularizer,
......@@ -133,6 +132,8 @@ class ReZeroTransformer(tf.keras.layers.Layer):
key_dim=self._attention_head_size,
dropout=self._attention_dropout_rate,
name="self_attention",
kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
**common_kwargs)
self._attention_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
if self._use_layer_norm:
......@@ -149,6 +150,8 @@ class ReZeroTransformer(tf.keras.layers.Layer):
output_shape=(None, self._intermediate_size),
bias_axes="d",
name="intermediate",
kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
**common_kwargs)
policy = tf.keras.mixed_precision.global_policy()
if policy.name == "mixed_bfloat16":
......@@ -163,6 +166,8 @@ class ReZeroTransformer(tf.keras.layers.Layer):
output_shape=(None, hidden_size),
bias_axes="d",
name="output",
kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
**common_kwargs)
self._output_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
if self._use_layer_norm:
......
......@@ -84,11 +84,10 @@ class SpectralNormalization(tf.keras.layers.Wrapper):
self.w = self.layer.kernel
self.w_shape = self.w.shape.as_list()
self.uv_initializer = tf.initializers.random_normal()
self.v = self.add_weight(
shape=(1, np.prod(self.w_shape[:-1])),
initializer=self.uv_initializer,
initializer=tf.initializers.random_normal(),
trainable=False,
name='v',
dtype=self.dtype,
......@@ -96,7 +95,7 @@ class SpectralNormalization(tf.keras.layers.Wrapper):
self.u = self.add_weight(
shape=(1, self.w_shape[-1]),
initializer=self.uv_initializer,
initializer=tf.initializers.random_normal(),
trainable=False,
name='u',
dtype=self.dtype,
......@@ -197,7 +196,8 @@ class SpectralNormalizationConv2D(tf.keras.layers.Wrapper):
super(SpectralNormalizationConv2D, self).__init__(layer, **kwargs)
def build(self, input_shape):
self.layer.build(input_shape)
if not self.layer.built:
self.layer.build(input_shape)
self.layer.kernel._aggregation = self.aggregation # pylint: disable=protected-access
self._dtype = self.layer.kernel.dtype
......@@ -221,11 +221,10 @@ class SpectralNormalizationConv2D(tf.keras.layers.Wrapper):
self.in_shape = (uv_dim, in_height, in_width, in_channel)
self.out_shape = (uv_dim, out_height, out_width, out_channel)
self.uv_initializer = tf.initializers.random_normal()
self.v = self.add_weight(
shape=self.in_shape,
initializer=self.uv_initializer,
initializer=tf.initializers.random_normal(),
trainable=False,
name='v',
dtype=self.dtype,
......@@ -233,7 +232,7 @@ class SpectralNormalizationConv2D(tf.keras.layers.Wrapper):
self.u = self.add_weight(
shape=self.out_shape,
initializer=self.uv_initializer,
initializer=tf.initializers.random_normal(),
trainable=False,
name='u',
dtype=self.dtype,
......
......@@ -66,7 +66,7 @@ class NormalizationTest(tf.test.TestCase, parameterized.TestCase):
spectral_norm_computed = _compute_spectral_norm(normalized_kernel)
spectral_norm_expected = self.norm_multiplier
self.assertAllClose(
spectral_norm_computed, spectral_norm_expected, atol=5e-2)
spectral_norm_computed, spectral_norm_expected, atol=1e-1)
# Test that the normalized layer is K-Lipschitz. In particular, if the layer
# is a function f, then ||f(x1) - f(x2)||_2 <= K * ||(x1 - x2)||_2, where K
......
......@@ -20,6 +20,8 @@ import string
import gin
import tensorflow as tf
from official.modeling import tf_utils
_CHR_IDX = string.ascii_lowercase
......@@ -87,7 +89,7 @@ class TalkingHeadsAttention(tf.keras.layers.MultiHeadAttention):
self._pre_softmax_weight = self.add_weight(
"pre_softmax_weight",
shape=(self._num_heads, self._num_heads),
initializer=self._kernel_initializer,
initializer=tf_utils.clone_initializer(self._kernel_initializer),
regularizer=self._kernel_regularizer,
constraint=self._kernel_constraint,
dtype=self.dtype,
......@@ -95,7 +97,7 @@ class TalkingHeadsAttention(tf.keras.layers.MultiHeadAttention):
self._post_softmax_weight = self.add_weight(
"post_softmax_weight",
shape=(self._num_heads, self._num_heads),
initializer=self._kernel_initializer,
initializer=tf_utils.clone_initializer(self._kernel_initializer),
regularizer=self._kernel_regularizer,
constraint=self._kernel_constraint,
dtype=self.dtype,
......
......@@ -17,6 +17,8 @@
from typing import List, Optional, Text, Any, Dict
import tensorflow as tf
from official.modeling import tf_utils
Layer = tf.keras.layers.Layer
activations = tf.keras.activations
initializers = tf.keras.initializers
......@@ -98,24 +100,24 @@ class TNExpandCondense(Layer):
name='w1',
shape=(input_shape[-1], input_shape[-1]),
trainable=True,
initializer=self.kernel_initializer)
initializer=tf_utils.clone_initializer(self.kernel_initializer))
self.w2 = self.add_weight(
name='w2',
shape=(128, (128 * (self.proj_size // input_shape[-1]))),
trainable=True,
initializer=self.kernel_initializer)
initializer=tf_utils.clone_initializer(self.kernel_initializer))
self.w3 = self.add_weight(
name='w3',
shape=(128 * (self.proj_size // input_shape[-1]), 128),
trainable=True,
initializer=self.kernel_initializer)
initializer=tf_utils.clone_initializer(self.kernel_initializer))
self.w4 = self.add_weight(
name='w4',
shape=(input_shape[-1] // 128, 128, input_shape[-1]),
trainable=True,
initializer=self.kernel_initializer)
initializer=tf_utils.clone_initializer(self.kernel_initializer))
if self.use_bias:
self.bias = self.add_weight(
......
......@@ -19,6 +19,7 @@
import gin
import tensorflow as tf
from official.modeling import tf_utils
from official.nlp.modeling.layers.tn_expand_condense import TNExpandCondense
......@@ -100,7 +101,8 @@ class TNTransformerExpandCondense(tf.keras.layers.Layer):
self._attention_initializer = tf.keras.initializers.get(
attention_initializer)
else:
self._attention_initializer = self._kernel_initializer
self._attention_initializer = tf_utils.clone_initializer(
self._kernel_initializer)
def build(self, input_shape):
input_tensor = input_shape[0] if len(input_shape) == 2 else input_shape
......@@ -128,7 +130,6 @@ class TNTransformerExpandCondense(tf.keras.layers.Layer):
"heads (%d)" % (hidden_size, self._num_heads))
self._attention_head_size = int(hidden_size // self._num_heads)
common_kwargs = dict(
bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activity_regularizer=self._activity_regularizer,
......@@ -140,6 +141,7 @@ class TNTransformerExpandCondense(tf.keras.layers.Layer):
dropout=self._attention_dropout_rate,
use_bias=self._use_bias,
kernel_initializer=self._attention_initializer,
bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
name="self_attention",
**common_kwargs)
self._attention_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
......
......@@ -18,6 +18,7 @@
import gin
import tensorflow as tf
from official.modeling import tf_utils
from official.nlp.modeling.layers import attention
from official.nlp.modeling.layers import multi_channel_attention
from official.nlp.modeling.layers import transformer_encoder_block
......@@ -226,7 +227,8 @@ class TransformerDecoderBlock(tf.keras.layers.Layer):
self._attention_initializer = tf.keras.initializers.get(
attention_initializer)
else:
self._attention_initializer = self._kernel_initializer
self._attention_initializer = tf_utils.clone_initializer(
self._kernel_initializer)
if self.multi_channel_cross_attention:
self._cross_attention_cls = multi_channel_attention.MultiChannelAttention
else:
......@@ -244,7 +246,6 @@ class TransformerDecoderBlock(tf.keras.layers.Layer):
"heads (%d)" % (hidden_size, self.num_attention_heads))
self.attention_head_size = int(hidden_size) // self.num_attention_heads
common_kwargs = dict(
bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activity_regularizer=self._activity_regularizer,
......@@ -256,14 +257,17 @@ class TransformerDecoderBlock(tf.keras.layers.Layer):
key_dim=self.attention_head_size,
dropout=self.attention_dropout_rate,
use_bias=self._use_bias,
kernel_initializer=self._attention_initializer,
kernel_initializer=tf_utils.clone_initializer(
self._attention_initializer),
bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
name="self_attention",
**common_kwargs)
self.self_attention_output_dense = tf.keras.layers.experimental.EinsumDense(
"abc,cd->abd",
output_shape=(None, hidden_size),
bias_axes="d",
kernel_initializer=self._kernel_initializer,
kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
name="output",
**common_kwargs)
self.self_attention_dropout = tf.keras.layers.Dropout(
......@@ -281,7 +285,9 @@ class TransformerDecoderBlock(tf.keras.layers.Layer):
dropout=self.attention_dropout_rate,
output_shape=hidden_size,
use_bias=self._use_bias,
kernel_initializer=self._attention_initializer,
kernel_initializer=tf_utils.clone_initializer(
self._attention_initializer),
bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
name="attention/encdec",
**common_kwargs)
......@@ -299,7 +305,8 @@ class TransformerDecoderBlock(tf.keras.layers.Layer):
"abc,cd->abd",
output_shape=(None, self.intermediate_size),
bias_axes="d",
kernel_initializer=self._kernel_initializer,
kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
name="intermediate",
**common_kwargs)
self.intermediate_activation_layer = tf.keras.layers.Activation(
......@@ -310,7 +317,8 @@ class TransformerDecoderBlock(tf.keras.layers.Layer):
"abc,cd->abd",
output_shape=(None, hidden_size),
bias_axes="d",
kernel_initializer=self._kernel_initializer,
kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
name="output",
**common_kwargs)
self.output_dropout = tf.keras.layers.Dropout(rate=self.dropout_rate)
......
......@@ -17,6 +17,7 @@
from absl import logging
import tensorflow as tf
from official.modeling import tf_utils
from official.nlp.modeling.layers import util
......@@ -156,7 +157,8 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
self._attention_initializer = tf.keras.initializers.get(
attention_initializer)
else:
self._attention_initializer = self._kernel_initializer
self._attention_initializer = tf_utils.clone_initializer(
self._kernel_initializer)
self._attention_axes = attention_axes
if self._diff_q_kv_att_layer_norm and not self._norm_first:
......@@ -188,8 +190,6 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
last_output_shape = self._output_last_dim
common_kwargs = dict(
bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activity_regularizer=self._activity_regularizer,
kernel_constraint=self._kernel_constraint,
......@@ -201,6 +201,7 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
dropout=self._attention_dropout,
use_bias=self._use_bias,
kernel_initializer=self._attention_initializer,
bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
attention_axes=self._attention_axes,
output_shape=self._output_last_dim,
name="self_attention",
......@@ -227,7 +228,8 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
einsum_equation,
output_shape=(None, self._inner_dim),
bias_axes="d",
kernel_initializer=self._kernel_initializer,
kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
name="intermediate",
**common_kwargs)
policy = tf.keras.mixed_precision.global_policy()
......@@ -245,7 +247,8 @@ class TransformerEncoderBlock(tf.keras.layers.Layer):
output_shape=(None, last_output_shape),
bias_axes="d",
name="output",
kernel_initializer=self._kernel_initializer,
kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
**common_kwargs)
self._output_dropout = tf.keras.layers.Dropout(rate=self._output_dropout)
# Use float32 in layernorm for numeric stability.
......
......@@ -19,6 +19,7 @@ from absl import logging
import gin
import tensorflow as tf
from official.modeling import tf_utils
from official.nlp.modeling.layers import attention
......@@ -127,8 +128,6 @@ class TransformerScaffold(tf.keras.layers.Layer):
self._attention_head_size = int(hidden_size // self._num_heads)
common_kwargs = dict(
kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activity_regularizer=self._activity_regularizer,
......@@ -145,6 +144,9 @@ class TransformerScaffold(tf.keras.layers.Layer):
return instance_or_cls(**config)
default_attention_cfg = {
"kernel_initializer": tf_utils.clone_initializer(
self._kernel_initializer),
"bias_initializer": tf_utils.clone_initializer(self._bias_initializer),
"num_heads": self._num_heads,
"key_dim": self._attention_head_size,
"dropout": self._attention_dropout_rate,
......@@ -158,6 +160,10 @@ class TransformerScaffold(tf.keras.layers.Layer):
if self._feedforward_cls is not None:
default_feedforward_cfg = {
"kernel_initializer": tf_utils.clone_initializer(
self._kernel_initializer),
"bias_initializer": tf_utils.clone_initializer(
self._bias_initializer),
"intermediate_size": self._intermediate_size,
"intermediate_activation": self._intermediate_activation,
"dropout": self._dropout_rate,
......@@ -189,6 +195,9 @@ class TransformerScaffold(tf.keras.layers.Layer):
output_shape=(None, self._intermediate_size),
bias_axes="d",
name="intermediate",
kernel_initializer=tf_utils.clone_initializer(
self._kernel_initializer),
bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
**common_kwargs)
policy = tf.keras.mixed_precision.global_policy()
if policy.name == "mixed_bfloat16":
......@@ -203,6 +212,9 @@ class TransformerScaffold(tf.keras.layers.Layer):
output_shape=(None, hidden_size),
bias_axes="d",
name="output",
kernel_initializer=tf_utils.clone_initializer(
self._kernel_initializer),
bias_initializer=tf_utils.clone_initializer(self._bias_initializer),
**common_kwargs)
self._output_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
......
......@@ -18,6 +18,7 @@ from absl import logging
import tensorflow as tf
from official.modeling import tf_utils
from official.nlp.modeling.layers import relative_attention
......@@ -148,7 +149,7 @@ class TransformerXLBlock(tf.keras.layers.Layer):
value_dim=self._head_size,
dropout=self._attention_dropout_rate,
use_bias=False,
kernel_initializer=self._kernel_initializer,
kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
name="rel_attn")
self._attention_dropout = tf.keras.layers.Dropout(
rate=self._attention_dropout_rate)
......@@ -161,7 +162,7 @@ class TransformerXLBlock(tf.keras.layers.Layer):
"abc,cd->abd",
output_shape=(None, self._inner_size),
bias_axes="d",
kernel_initializer=self._kernel_initializer,
kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer),
name="inner")
self._inner_activation_layer = tf.keras.layers.Activation(
......@@ -173,7 +174,7 @@ class TransformerXLBlock(tf.keras.layers.Layer):
output_shape=(None, hidden_size),
bias_axes="d",
name="output",
kernel_initializer=self._kernel_initializer)
kernel_initializer=tf_utils.clone_initializer(self._kernel_initializer))
self._output_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
self._output_layer_norm = tf.keras.layers.LayerNormalization(
name="output_layer_norm",
......@@ -398,17 +399,17 @@ class TransformerXL(tf.keras.layers.Layer):
"content_attention_bias",
shape=attention_bias_shape,
dtype=tf.float32,
initializer=self._initializer)
initializer=tf_utils.clone_initializer(self._initializer))
self.positional_attention_bias = self.add_weight(
"positional_attention_bias",
shape=attention_bias_shape,
dtype=tf.float32,
initializer=self._initializer)
initializer=tf_utils.clone_initializer(self._initializer))
self.segment_attention_bias = self.add_weight(
"segment_attention_bias",
shape=attention_bias_shape,
dtype=tf.float32,
initializer=self._initializer)
initializer=tf_utils.clone_initializer(self._initializer))
self.transformer_xl_layers = []
for i in range(self._num_layers):
......
......@@ -22,6 +22,7 @@ from absl import logging
import gin
import tensorflow as tf
from official.modeling import tf_utils
from official.nlp.modeling import layers
from official.nlp.modeling import networks
......@@ -102,7 +103,7 @@ class BertPretrainer(tf.keras.Model):
masked_lm = layers.MaskedLM(
embedding_table=embedding_table,
activation=activation,
initializer=initializer,
initializer=tf_utils.clone_initializer(initializer),
output=output,
name='cls/predictions')
lm_outputs = masked_lm(
......@@ -111,7 +112,7 @@ class BertPretrainer(tf.keras.Model):
classification = networks.Classification(
input_width=cls_output.shape[-1],
num_classes=num_classes,
initializer=initializer,
initializer=tf_utils.clone_initializer(initializer),
output=output,
name='classification')
sentence_outputs = classification(cls_output)
......
......@@ -96,21 +96,22 @@ class ElectraPretrainer(tf.keras.Model):
self.masked_lm = layers.MaskedLM(
embedding_table=generator_network.get_embedding_table(),
activation=mlm_activation,
initializer=mlm_initializer,
initializer=tf_utils.clone_initializer(mlm_initializer),
output=output_type,
name='generator_masked_lm')
self.classification = layers.ClassificationHead(
inner_dim=generator_network.get_config()['hidden_size'],
num_classes=num_classes,
initializer=mlm_initializer,
initializer=tf_utils.clone_initializer(mlm_initializer),
name='generator_classification_head')
self.discriminator_projection = tf.keras.layers.Dense(
units=discriminator_network.get_config()['hidden_size'],
activation=mlm_activation,
kernel_initializer=mlm_initializer,
kernel_initializer=tf_utils.clone_initializer(mlm_initializer),
name='discriminator_projection_head')
self.discriminator_head = tf.keras.layers.Dense(
units=1, kernel_initializer=mlm_initializer)
units=1,
kernel_initializer=tf_utils.clone_initializer(mlm_initializer))
def call(self, inputs):
"""ELECTRA forward pass.
......
......@@ -55,6 +55,7 @@ class Module(tf.Module):
initializer: Initializer,
dtype: tf.DType = tf.float32,
**kwargs):
initializer = tf_utils.clone_initializer(initializer)
return tf.Variable(initializer(shape, dtype=dtype, **kwargs), name=name)
def read_variable(self,
......@@ -588,7 +589,8 @@ class MultiHeadAttention(Module):
init_std_rescaling = tf.math.sqrt(tf.cast(self.d_kv, dtype=self.dtype))
query_w_init = (
lambda *args, **kwargs: ( # pylint: disable=g-long-lambda
weight_initializer(*args, **kwargs) / init_std_rescaling))
tf_utils.clone_initializer(weight_initializer)(
*args, **kwargs) / init_std_rescaling))
self.q = Linear3D(
self.d_model,
self.d_kv,
......
......@@ -18,6 +18,7 @@ import collections
import tensorflow as tf
from official.modeling import activations
from official.modeling import tf_utils
from official.nlp.modeling import layers
......@@ -92,13 +93,13 @@ class AlbertEncoder(tf.keras.Model):
embedding_layer = layers.OnDeviceEmbedding(
vocab_size=vocab_size,
embedding_width=embedding_width,
initializer=initializer,
initializer=tf_utils.clone_initializer(initializer),
name='word_embeddings')
word_embeddings = embedding_layer(word_ids)
# Always uses dynamic slicing for simplicity.
position_embedding_layer = layers.PositionEmbedding(
initializer=initializer,
initializer=tf_utils.clone_initializer(initializer),
max_length=max_sequence_length,
name='position_embedding')
position_embeddings = position_embedding_layer(word_embeddings)
......@@ -107,7 +108,7 @@ class AlbertEncoder(tf.keras.Model):
layers.OnDeviceEmbedding(
vocab_size=type_vocab_size,
embedding_width=embedding_width,
initializer=initializer,
initializer=tf_utils.clone_initializer(initializer),
use_one_hot=True,
name='type_embeddings')(type_ids))
......@@ -127,7 +128,7 @@ class AlbertEncoder(tf.keras.Model):
'...x,xy->...y',
output_shape=hidden_size,
bias_axes='y',
kernel_initializer=initializer,
kernel_initializer=tf_utils.clone_initializer(initializer),
name='embedding_projection')(
embeddings)
......@@ -139,7 +140,7 @@ class AlbertEncoder(tf.keras.Model):
inner_activation=activation,
output_dropout=dropout_rate,
attention_dropout=attention_dropout_rate,
kernel_initializer=initializer,
kernel_initializer=tf_utils.clone_initializer(initializer),
name='transformer')
encoder_outputs = []
for _ in range(num_layers):
......@@ -153,7 +154,7 @@ class AlbertEncoder(tf.keras.Model):
cls_output = tf.keras.layers.Dense(
units=hidden_size,
activation='tanh',
kernel_initializer=initializer,
kernel_initializer=tf_utils.clone_initializer(initializer),
name='pooler_transform')(
first_token_tensor)
if dict_outputs:
......
......@@ -19,6 +19,7 @@ from typing import Any, Callable, Optional, Union
from absl import logging
import tensorflow as tf
from official.modeling import tf_utils
from official.nlp.modeling import layers
......@@ -122,20 +123,20 @@ class BertEncoderV2(tf.keras.layers.Layer):
self._embedding_layer = layers.OnDeviceEmbedding(
vocab_size=vocab_size,
embedding_width=embedding_width,
initializer=initializer,
initializer=tf_utils.clone_initializer(initializer),
name='word_embeddings')
else:
self._embedding_layer = embedding_layer
self._position_embedding_layer = layers.PositionEmbedding(
initializer=initializer,
initializer=tf_utils.clone_initializer(initializer),
max_length=max_sequence_length,
name='position_embedding')
self._type_embedding_layer = layers.OnDeviceEmbedding(
vocab_size=type_vocab_size,
embedding_width=embedding_width,
initializer=initializer,
initializer=tf_utils.clone_initializer(initializer),
use_one_hot=True,
name='type_embeddings')
......@@ -153,7 +154,7 @@ class BertEncoderV2(tf.keras.layers.Layer):
'...x,xy->...y',
output_shape=hidden_size,
bias_axes='y',
kernel_initializer=initializer,
kernel_initializer=tf_utils.clone_initializer(initializer),
name='embedding_projection')
self._transformer_layers = []
......@@ -168,14 +169,14 @@ class BertEncoderV2(tf.keras.layers.Layer):
attention_dropout=attention_dropout,
norm_first=norm_first,
output_range=output_range if i == num_layers - 1 else None,
kernel_initializer=initializer,
kernel_initializer=tf_utils.clone_initializer(initializer),
name='transformer/layer_%d' % i)
self._transformer_layers.append(layer)
self._pooler_layer = tf.keras.layers.Dense(
units=hidden_size,
activation='tanh',
kernel_initializer=initializer,
kernel_initializer=tf_utils.clone_initializer(initializer),
name='pooler_transform')
self._config = {
......@@ -409,7 +410,7 @@ class BertEncoder(tf.keras.Model):
embedding_layer_inst = layers.OnDeviceEmbedding(
vocab_size=vocab_size,
embedding_width=embedding_width,
initializer=initializer,
initializer=tf_utils.clone_initializer(initializer),
name='word_embeddings')
else:
embedding_layer_inst = embedding_layer
......@@ -417,14 +418,14 @@ class BertEncoder(tf.keras.Model):
# Always uses dynamic slicing for simplicity.
position_embedding_layer = layers.PositionEmbedding(
initializer=initializer,
initializer=tf_utils.clone_initializer(initializer),
max_length=max_sequence_length,
name='position_embedding')
position_embeddings = position_embedding_layer(word_embeddings)
type_embedding_layer = layers.OnDeviceEmbedding(
vocab_size=type_vocab_size,
embedding_width=embedding_width,
initializer=initializer,
initializer=tf_utils.clone_initializer(initializer),
use_one_hot=True,
name='type_embeddings')
type_embeddings = type_embedding_layer(type_ids)
......@@ -445,7 +446,7 @@ class BertEncoder(tf.keras.Model):
'...x,xy->...y',
output_shape=hidden_size,
bias_axes='y',
kernel_initializer=initializer,
kernel_initializer=tf_utils.clone_initializer(initializer),
name='embedding_projection')
embeddings = embedding_projection(embeddings)
else:
......@@ -468,7 +469,7 @@ class BertEncoder(tf.keras.Model):
attention_dropout=attention_dropout,
norm_first=norm_first,
output_range=transformer_output_range,
kernel_initializer=initializer,
kernel_initializer=tf_utils.clone_initializer(initializer),
name='transformer/layer_%d' % i)
transformer_layers.append(layer)
data = layer([data, attention_mask])
......@@ -482,7 +483,7 @@ class BertEncoder(tf.keras.Model):
pooler_layer = tf.keras.layers.Dense(
units=hidden_size,
activation='tanh',
kernel_initializer=initializer,
kernel_initializer=tf_utils.clone_initializer(initializer),
name='pooler_transform')
cls_output = pooler_layer(first_token_tensor)
......
......@@ -21,6 +21,7 @@ from absl import logging
import gin
import tensorflow as tf
from official.modeling import tf_utils
from official.nlp.modeling import layers
......@@ -153,14 +154,14 @@ class EncoderScaffold(tf.keras.Model):
embedding_layer = layers.OnDeviceEmbedding(
vocab_size=embedding_cfg['vocab_size'],
embedding_width=embedding_cfg['hidden_size'],
initializer=embedding_cfg['initializer'],
initializer=tf_utils.clone_initializer(embedding_cfg['initializer']),
name='word_embeddings')
word_embeddings = embedding_layer(word_ids)
# Always uses dynamic slicing for simplicity.
position_embedding_layer = layers.PositionEmbedding(
initializer=embedding_cfg['initializer'],
initializer=tf_utils.clone_initializer(embedding_cfg['initializer']),
max_length=embedding_cfg['max_seq_length'],
name='position_embedding')
position_embeddings = position_embedding_layer(word_embeddings)
......@@ -168,7 +169,7 @@ class EncoderScaffold(tf.keras.Model):
type_embedding_layer = layers.OnDeviceEmbedding(
vocab_size=embedding_cfg['type_vocab_size'],
embedding_width=embedding_cfg['hidden_size'],
initializer=embedding_cfg['initializer'],
initializer=tf_utils.clone_initializer(embedding_cfg['initializer']),
use_one_hot=True,
name='type_embeddings')
type_embeddings = type_embedding_layer(type_ids)
......@@ -243,6 +244,8 @@ class EncoderScaffold(tf.keras.Model):
# like this will create a SliceOpLambda layer. This is better than a Lambda
# layer with Python code, because that is fundamentally less portable.
first_token_tensor = last_layer_output[:, 0, :]
pooler_layer_initializer = tf.keras.initializers.get(
pooler_layer_initializer)
pooler_layer = tf.keras.layers.Dense(
units=pooled_output_dim,
activation='tanh',
......@@ -303,7 +306,8 @@ class EncoderScaffold(tf.keras.Model):
config_dict = {
'num_hidden_instances': self._num_hidden_instances,
'pooled_output_dim': self._pooled_output_dim,
'pooler_layer_initializer': self._pooler_layer_initializer,
'pooler_layer_initializer': tf.keras.initializers.serialize(
self._pooler_layer_initializer),
'embedding_cls': self._embedding_network,
'embedding_cfg': self._embedding_cfg,
'layer_norm_before_pooling': self._layer_norm_before_pooling,
......
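
The `EncoderScaffold` hunk above converts `pooler_layer_initializer` into an `Initializer` instance via `tf.keras.initializers.get(...)` and serializes it back to a config dict in `get_config()`, keeping the layer's config JSON-serializable. A standalone round-trip sketch of those two public Keras APIs (it assumes nothing about `EncoderScaffold` internals):

```python
import tensorflow as tf

# String or dict specs become Initializer instances.
init = tf.keras.initializers.get('glorot_uniform')
assert isinstance(init, tf.keras.initializers.GlorotUniform)

# Instances serialize to a plain config dict that can live inside get_config()...
config = tf.keras.initializers.serialize(init)

# ...and deserialize back to an equivalent instance when the model is rebuilt.
restored = tf.keras.initializers.deserialize(config)
assert isinstance(restored, tf.keras.initializers.GlorotUniform)
```
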
......@@ -20,6 +20,7 @@ from absl import logging
import numpy as np
import tensorflow as tf
from official.modeling import tf_utils
from official.nlp.modeling import layers
_Initializer = Union[str, tf.keras.initializers.Initializer]
......@@ -265,20 +266,20 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
self._embedding_layer = layers.OnDeviceEmbedding(
vocab_size=vocab_size,
embedding_width=embedding_width,
initializer=initializer,
initializer=tf_utils.clone_initializer(initializer),
name='word_embeddings')
else:
self._embedding_layer = embedding_layer
self._position_embedding_layer = layers.PositionEmbedding(
initializer=initializer,
initializer=tf_utils.clone_initializer(initializer),
max_length=max_sequence_length,
name='position_embedding')
self._type_embedding_layer = layers.OnDeviceEmbedding(
vocab_size=type_vocab_size,
embedding_width=embedding_width,
initializer=initializer,
initializer=tf_utils.clone_initializer(initializer),
use_one_hot=True,
name='type_embeddings')
......@@ -296,7 +297,7 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
'...x,xy->...y',
output_shape=hidden_size,
bias_axes='y',
kernel_initializer=initializer,
kernel_initializer=tf_utils.clone_initializer(initializer),
name='embedding_projection')
self._transformer_layers = []
......@@ -316,7 +317,7 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
attention_dropout=attention_dropout,
norm_first=norm_first,
output_range=output_range if i == num_layers - 1 else None,
kernel_initializer=initializer,
kernel_initializer=tf_utils.clone_initializer(initializer),
share_rezero=share_rezero,
name='transformer/layer_%d' % i)
self._transformer_layers.append(layer)
......@@ -324,7 +325,7 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
self._pooler_layer = tf.keras.layers.Dense(
units=hidden_size,
activation='tanh',
kernel_initializer=initializer,
kernel_initializer=tf_utils.clone_initializer(initializer),
name='pooler_transform')
if isinstance(pool_stride, int):
# TODO(b/197133196): Pooling layer can be shared.
......@@ -342,9 +343,6 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
# TODO(b/203665205): unpool_length should be implemented.
if unpool_length != 0:
raise ValueError('unpool_length is not supported by truncated_avg now.')
# Compute the attention masks and pooling transforms.
self._pooling_transforms = _create_truncated_avg_transforms(
max_sequence_length, pool_strides)
else:
raise ValueError('pool_type not supported.')
......@@ -358,6 +356,7 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
name='att_input_pool_layer')
self._att_input_pool_layers.append(att_input_pool_layer)
self._max_sequence_length = max_sequence_length
self._pool_strides = pool_strides # This is a list here.
self._unpool_length = unpool_length
self._pool_type = pool_type
......@@ -488,8 +487,13 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
axes=[1, 2])
encoder_outputs.append(x)
elif self._pool_type == _TRUNCATED_AVG:
# Compute the attention masks and pooling transforms.
# Note we do not compute this in __init__ due to inference converter issue
# b/215659399.
pooling_transforms = _create_truncated_avg_transforms(
self._max_sequence_length, self._pool_strides)
attention_masks = _create_truncated_avg_masks(mask, self._pool_strides,
self._pooling_transforms)
pooling_transforms)
for i, layer in enumerate(self._transformer_layers):
attention_mask = attention_masks[i]
# Bypass no pooling cases.
......@@ -500,7 +504,7 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
'BFD,FT->BTD',
tf.cast(x[:, self._unpool_length:, :], _get_policy_dtype()
), # extra casting for faster mixed computation.
self._pooling_transforms[i])
pooling_transforms[i])
query_inputs = tf.concat(
values=(tf.cast(
x[:, :self._unpool_length, :],
......
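
In the `FunnelTransformerEncoder` hunk above, the truncated-average pooling transforms are no longer cached in `__init__`; they are rebuilt inside `call()` from the stored `max_sequence_length` and `pool_strides` (per the in-code note about inference converter issue b/215659399). Below is a hedged sketch of that construct-in-call pattern using a hypothetical layer; the averaging matrix is simplified and is not the real `_create_truncated_avg_transforms`.

```python
import tensorflow as tf


class TruncatedAvgPoolSketch(tf.keras.layers.Layer):
  """Hypothetical layer showing the pattern only: derive constant transforms
  in call() from stored hyperparameters instead of caching them in __init__."""

  def __init__(self, seq_len: int, stride: int, **kwargs):
    super().__init__(**kwargs)
    self._seq_len = seq_len
    self._stride = stride

  def call(self, inputs):
    # Rebuild the (seq_len, pooled_len) averaging matrix on every call;
    # it depends only on hyperparameters, so no tensor is stored at init time.
    pooled_len = self._seq_len // self._stride
    indices = tf.range(self._seq_len) // self._stride          # [seq_len]
    transform = tf.one_hot(indices, depth=pooled_len) / float(self._stride)
    return tf.einsum('bfd,fp->bpd', inputs, transform)


# Usage: pools a [batch, 8, hidden] sequence down to [batch, 4, hidden].
layer = TruncatedAvgPoolSketch(seq_len=8, stride=2)
out = layer(tf.random.normal([2, 8, 16]))
```
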
......@@ -97,13 +97,13 @@ class PackedSequenceEmbedding(tf.keras.Model):
embedding_layer = layers.OnDeviceEmbedding(
vocab_size=vocab_size,
embedding_width=embedding_width,
initializer=initializer,
initializer=tf_utils.clone_initializer(initializer),
name='word_embeddings')
word_embeddings = embedding_layer(word_ids)
# Always uses dynamic slicing for simplicity.
position_embedding_layer = PositionEmbeddingWithSubSeqMask(
initializer=initializer,
initializer=tf_utils.clone_initializer(initializer),
use_dynamic_slicing=True,
max_sequence_length=max_seq_length,
name='position_embedding')
......@@ -114,7 +114,7 @@ class PackedSequenceEmbedding(tf.keras.Model):
layers.OnDeviceEmbedding(
vocab_size=type_vocab_size,
embedding_width=embedding_width,
initializer=initializer,
initializer=tf_utils.clone_initializer(initializer),
use_one_hot=True,
name='type_embeddings')(type_ids))
......@@ -132,7 +132,7 @@ class PackedSequenceEmbedding(tf.keras.Model):
'...x,xy->...y',
output_shape=hidden_size,
bias_axes=None,
kernel_initializer=initializer,
kernel_initializer=tf_utils.clone_initializer(initializer),
name='embedding_projection')(
embeddings)
......
......@@ -17,6 +17,8 @@
import collections
import tensorflow as tf
from official.modeling import tf_utils
def _apply_paragraph_mask(logits, paragraph_mask):
"""Applies a position mask to calculated logits."""
......@@ -156,12 +158,12 @@ class XLNetSpanLabeling(tf.keras.layers.Layer):
self._end_n_top = end_n_top
self.start_logits_dense = tf.keras.layers.Dense(
units=1,
kernel_initializer=initializer,
kernel_initializer=tf_utils.clone_initializer(initializer),
name='predictions/transform/start_logits')
self.end_logits_inner_dense = tf.keras.layers.Dense(
units=input_width,
kernel_initializer=initializer,
kernel_initializer=tf_utils.clone_initializer(initializer),
activation=activation,
name='predictions/transform/end_logits/inner')
self.end_logits_layer_norm = tf.keras.layers.LayerNormalization(
......@@ -169,18 +171,18 @@ class XLNetSpanLabeling(tf.keras.layers.Layer):
name='predictions/transform/end_logits/layernorm')
self.end_logits_output_dense = tf.keras.layers.Dense(
units=1,
kernel_initializer=initializer,
kernel_initializer=tf_utils.clone_initializer(initializer),
name='predictions/transform/end_logits/output')
self.answer_logits_inner = tf.keras.layers.Dense(
units=input_width,
kernel_initializer=initializer,
kernel_initializer=tf_utils.clone_initializer(initializer),
activation=activation,
name='predictions/transform/answer_logits/inner')
self.answer_logits_dropout = tf.keras.layers.Dropout(rate=dropout_rate)
self.answer_logits_output = tf.keras.layers.Dense(
units=1,
kernel_initializer=initializer,
kernel_initializer=tf_utils.clone_initializer(initializer),
use_bias=False,
name='predictions/transform/answer_logits/output')
......