Commit 6d458bcc authored by Scott Zhu, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 446219773
parent 1fd7aaaf
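The hunks below all apply one pattern: wherever a Keras initializer object is handed to a newly constructed layer or variable, it is first wrapped in tf_utils.clone_initializer(...), so each layer gets its own initializer instance instead of sharing one object that can hand identical initial values to every same-shaped variable. A minimal sketch of what the utility is assumed to do (modeled on official/modeling/tf_utils.py, not the verbatim source):

import tensorflow as tf

def clone_initializer(initializer):
  # Assumed behavior of tf_utils.clone_initializer: rebuild a fresh instance
  # of the same class from its config so layers do not share one object.
  if isinstance(initializer, tf.keras.initializers.Initializer):
    return initializer.__class__.from_config(initializer.get_config())
  # Strings such as 'glorot_uniform' and config dicts are plain specs; the
  # receiving layer builds its own Initializer from them, so pass them through.
  return initializer

With that helper, passing clone_initializer(initializer) at every call site keeps each embedding, attention, and pooler layer on an independent copy while preserving the configured distribution.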
......@@ -55,6 +55,7 @@ class Module(tf.Module):
initializer: Initializer,
dtype: tf.DType = tf.float32,
**kwargs):
initializer = tf_utils.clone_initializer(initializer)
return tf.Variable(initializer(shape, dtype=dtype, **kwargs), name=name)
def read_variable(self,
......@@ -588,7 +589,8 @@ class MultiHeadAttention(Module):
init_std_rescaling = tf.math.sqrt(tf.cast(self.d_kv, dtype=self.dtype))
query_w_init = (
lambda *args, **kwargs: ( # pylint: disable=g-long-lambda
weight_initializer(*args, **kwargs) / init_std_rescaling))
tf_utils.clone_initializer(weight_initializer)(
*args, **kwargs) / init_std_rescaling))
self.q = Linear3D(
self.d_model,
self.d_kv,
......
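In the MultiHeadAttention hunk above, the query kernel initializer is additionally rescaled by 1/sqrt(d_kv): the lambda clones the configured weight_initializer on every call and divides the drawn values by sqrt(d_kv). A small standalone illustration of the same pattern, reusing the clone sketch above (base_init and the literal 64 are illustrative, not taken from the diff):

import tensorflow as tf

d_kv = 64  # illustrative head width; the real value comes from the model config
base_init = tf.keras.initializers.TruncatedNormal(stddev=1.0)
rescale = tf.math.sqrt(tf.cast(d_kv, tf.float32))

# Clone per call, then shrink the drawn values so the query projection starts
# at a smaller scale than the key/value projections.
query_w_init = lambda shape, dtype=tf.float32, **kwargs: (
    base_init.__class__.from_config(base_init.get_config())(
        shape, dtype=dtype, **kwargs) / rescale)

kernel = query_w_init((d_kv, d_kv))
print(kernel.shape)  # (64, 64)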
......@@ -18,6 +18,7 @@ import collections
import tensorflow as tf
from official.modeling import activations
from official.modeling import tf_utils
from official.nlp.modeling import layers
......@@ -92,13 +93,13 @@ class AlbertEncoder(tf.keras.Model):
embedding_layer = layers.OnDeviceEmbedding(
vocab_size=vocab_size,
embedding_width=embedding_width,
initializer=initializer,
initializer=tf_utils.clone_initializer(initializer),
name='word_embeddings')
word_embeddings = embedding_layer(word_ids)
# Always uses dynamic slicing for simplicity.
position_embedding_layer = layers.PositionEmbedding(
initializer=initializer,
initializer=tf_utils.clone_initializer(initializer),
max_length=max_sequence_length,
name='position_embedding')
position_embeddings = position_embedding_layer(word_embeddings)
......@@ -107,7 +108,7 @@ class AlbertEncoder(tf.keras.Model):
layers.OnDeviceEmbedding(
vocab_size=type_vocab_size,
embedding_width=embedding_width,
initializer=initializer,
initializer=tf_utils.clone_initializer(initializer),
use_one_hot=True,
name='type_embeddings')(type_ids))
......@@ -127,7 +128,7 @@ class AlbertEncoder(tf.keras.Model):
'...x,xy->...y',
output_shape=hidden_size,
bias_axes='y',
kernel_initializer=initializer,
kernel_initializer=tf_utils.clone_initializer(initializer),
name='embedding_projection')(
embeddings)
......@@ -139,7 +140,7 @@ class AlbertEncoder(tf.keras.Model):
inner_activation=activation,
output_dropout=dropout_rate,
attention_dropout=attention_dropout_rate,
kernel_initializer=initializer,
kernel_initializer=tf_utils.clone_initializer(initializer),
name='transformer')
encoder_outputs = []
for _ in range(num_layers):
......@@ -153,7 +154,7 @@ class AlbertEncoder(tf.keras.Model):
cls_output = tf.keras.layers.Dense(
units=hidden_size,
activation='tanh',
kernel_initializer=initializer,
kernel_initializer=tf_utils.clone_initializer(initializer),
name='pooler_transform')(
first_token_tensor)
if dict_outputs:
......
......@@ -19,6 +19,7 @@ from typing import Any, Callable, Optional, Union
from absl import logging
import tensorflow as tf
from official.modeling import tf_utils
from official.nlp.modeling import layers
......@@ -122,20 +123,20 @@ class BertEncoderV2(tf.keras.layers.Layer):
self._embedding_layer = layers.OnDeviceEmbedding(
vocab_size=vocab_size,
embedding_width=embedding_width,
initializer=initializer,
initializer=tf_utils.clone_initializer(initializer),
name='word_embeddings')
else:
self._embedding_layer = embedding_layer
self._position_embedding_layer = layers.PositionEmbedding(
initializer=initializer,
initializer=tf_utils.clone_initializer(initializer),
max_length=max_sequence_length,
name='position_embedding')
self._type_embedding_layer = layers.OnDeviceEmbedding(
vocab_size=type_vocab_size,
embedding_width=embedding_width,
initializer=initializer,
initializer=tf_utils.clone_initializer(initializer),
use_one_hot=True,
name='type_embeddings')
......@@ -153,7 +154,7 @@ class BertEncoderV2(tf.keras.layers.Layer):
'...x,xy->...y',
output_shape=hidden_size,
bias_axes='y',
kernel_initializer=initializer,
kernel_initializer=tf_utils.clone_initializer(initializer),
name='embedding_projection')
self._transformer_layers = []
......@@ -168,14 +169,14 @@ class BertEncoderV2(tf.keras.layers.Layer):
attention_dropout=attention_dropout,
norm_first=norm_first,
output_range=output_range if i == num_layers - 1 else None,
kernel_initializer=initializer,
kernel_initializer=tf_utils.clone_initializer(initializer),
name='transformer/layer_%d' % i)
self._transformer_layers.append(layer)
self._pooler_layer = tf.keras.layers.Dense(
units=hidden_size,
activation='tanh',
kernel_initializer=initializer,
kernel_initializer=tf_utils.clone_initializer(initializer),
name='pooler_transform')
self._config = {
......@@ -409,7 +410,7 @@ class BertEncoder(tf.keras.Model):
embedding_layer_inst = layers.OnDeviceEmbedding(
vocab_size=vocab_size,
embedding_width=embedding_width,
initializer=initializer,
initializer=tf_utils.clone_initializer(initializer),
name='word_embeddings')
else:
embedding_layer_inst = embedding_layer
......@@ -417,14 +418,14 @@ class BertEncoder(tf.keras.Model):
# Always uses dynamic slicing for simplicity.
position_embedding_layer = layers.PositionEmbedding(
initializer=initializer,
initializer=tf_utils.clone_initializer(initializer),
max_length=max_sequence_length,
name='position_embedding')
position_embeddings = position_embedding_layer(word_embeddings)
type_embedding_layer = layers.OnDeviceEmbedding(
vocab_size=type_vocab_size,
embedding_width=embedding_width,
initializer=initializer,
initializer=tf_utils.clone_initializer(initializer),
use_one_hot=True,
name='type_embeddings')
type_embeddings = type_embedding_layer(type_ids)
......@@ -445,7 +446,7 @@ class BertEncoder(tf.keras.Model):
'...x,xy->...y',
output_shape=hidden_size,
bias_axes='y',
kernel_initializer=initializer,
kernel_initializer=tf_utils.clone_initializer(initializer),
name='embedding_projection')
embeddings = embedding_projection(embeddings)
else:
......@@ -468,7 +469,7 @@ class BertEncoder(tf.keras.Model):
attention_dropout=attention_dropout,
norm_first=norm_first,
output_range=transformer_output_range,
kernel_initializer=initializer,
kernel_initializer=tf_utils.clone_initializer(initializer),
name='transformer/layer_%d' % i)
transformer_layers.append(layer)
data = layer([data, attention_mask])
......@@ -482,7 +483,7 @@ class BertEncoder(tf.keras.Model):
pooler_layer = tf.keras.layers.Dense(
units=hidden_size,
activation='tanh',
kernel_initializer=initializer,
kernel_initializer=tf_utils.clone_initializer(initializer),
name='pooler_transform')
cls_output = pooler_layer(first_token_tensor)
......
......@@ -21,6 +21,7 @@ from absl import logging
import gin
import tensorflow as tf
from official.modeling import tf_utils
from official.nlp.modeling import layers
......@@ -153,14 +154,14 @@ class EncoderScaffold(tf.keras.Model):
embedding_layer = layers.OnDeviceEmbedding(
vocab_size=embedding_cfg['vocab_size'],
embedding_width=embedding_cfg['hidden_size'],
initializer=embedding_cfg['initializer'],
initializer=tf_utils.clone_initializer(embedding_cfg['initializer']),
name='word_embeddings')
word_embeddings = embedding_layer(word_ids)
# Always uses dynamic slicing for simplicity.
position_embedding_layer = layers.PositionEmbedding(
initializer=embedding_cfg['initializer'],
initializer=tf_utils.clone_initializer(embedding_cfg['initializer']),
max_length=embedding_cfg['max_seq_length'],
name='position_embedding')
position_embeddings = position_embedding_layer(word_embeddings)
......@@ -168,7 +169,7 @@ class EncoderScaffold(tf.keras.Model):
type_embedding_layer = layers.OnDeviceEmbedding(
vocab_size=embedding_cfg['type_vocab_size'],
embedding_width=embedding_cfg['hidden_size'],
initializer=embedding_cfg['initializer'],
initializer=tf_utils.clone_initializer(embedding_cfg['initializer']),
use_one_hot=True,
name='type_embeddings')
type_embeddings = type_embedding_layer(type_ids)
......@@ -243,6 +244,8 @@ class EncoderScaffold(tf.keras.Model):
# like this will create a SliceOpLambda layer. This is better than a Lambda
# layer with Python code, because that is fundamentally less portable.
first_token_tensor = last_layer_output[:, 0, :]
pooler_layer_initializer = tf.keras.initializers.get(
pooler_layer_initializer)
pooler_layer = tf.keras.layers.Dense(
units=pooled_output_dim,
activation='tanh',
......@@ -303,7 +306,8 @@ class EncoderScaffold(tf.keras.Model):
config_dict = {
'num_hidden_instances': self._num_hidden_instances,
'pooled_output_dim': self._pooled_output_dim,
'pooler_layer_initializer': self._pooler_layer_initializer,
'pooler_layer_initializer': tf.keras.initializers.serialize(
self._pooler_layer_initializer),
'embedding_cls': self._embedding_network,
'embedding_cfg': self._embedding_cfg,
'layer_norm_before_pooling': self._layer_norm_before_pooling,
......
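The EncoderScaffold hunks also normalize how the pooler initializer is stored: tf.keras.initializers.get(...) turns a string or config into an Initializer instance at construction time, and get_config() now serializes it back with tf.keras.initializers.serialize(...) so the model can be rebuilt from its config. A short round trip with the public Keras API (variable names are illustrative):

import tensorflow as tf

# Accept either a string/dict spec or an Initializer instance, as the scaffold does.
pooler_init = tf.keras.initializers.get('truncated_normal')

# get_config() must return plain serializable data, so store the config dict...
init_config = tf.keras.initializers.serialize(pooler_init)

# ...and rebuild an equivalent initializer when the model is reconstructed.
restored = tf.keras.initializers.deserialize(init_config)
print(type(restored).__name__)  # TruncatedNormal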
......@@ -20,6 +20,7 @@ from absl import logging
import numpy as np
import tensorflow as tf
from official.modeling import tf_utils
from official.nlp.modeling import layers
_Initializer = Union[str, tf.keras.initializers.Initializer]
......@@ -265,20 +266,20 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
self._embedding_layer = layers.OnDeviceEmbedding(
vocab_size=vocab_size,
embedding_width=embedding_width,
initializer=initializer,
initializer=tf_utils.clone_initializer(initializer),
name='word_embeddings')
else:
self._embedding_layer = embedding_layer
self._position_embedding_layer = layers.PositionEmbedding(
initializer=initializer,
initializer=tf_utils.clone_initializer(initializer),
max_length=max_sequence_length,
name='position_embedding')
self._type_embedding_layer = layers.OnDeviceEmbedding(
vocab_size=type_vocab_size,
embedding_width=embedding_width,
initializer=initializer,
initializer=tf_utils.clone_initializer(initializer),
use_one_hot=True,
name='type_embeddings')
......@@ -296,7 +297,7 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
'...x,xy->...y',
output_shape=hidden_size,
bias_axes='y',
kernel_initializer=initializer,
kernel_initializer=tf_utils.clone_initializer(initializer),
name='embedding_projection')
self._transformer_layers = []
......@@ -316,7 +317,7 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
attention_dropout=attention_dropout,
norm_first=norm_first,
output_range=output_range if i == num_layers - 1 else None,
kernel_initializer=initializer,
kernel_initializer=tf_utils.clone_initializer(initializer),
share_rezero=share_rezero,
name='transformer/layer_%d' % i)
self._transformer_layers.append(layer)
......@@ -324,7 +325,7 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
self._pooler_layer = tf.keras.layers.Dense(
units=hidden_size,
activation='tanh',
kernel_initializer=initializer,
kernel_initializer=tf_utils.clone_initializer(initializer),
name='pooler_transform')
if isinstance(pool_stride, int):
# TODO(b/197133196): Pooling layer can be shared.
......
......@@ -97,13 +97,13 @@ class PackedSequenceEmbedding(tf.keras.Model):
embedding_layer = layers.OnDeviceEmbedding(
vocab_size=vocab_size,
embedding_width=embedding_width,
initializer=initializer,
initializer=tf_utils.clone_initializer(initializer),
name='word_embeddings')
word_embeddings = embedding_layer(word_ids)
# Always uses dynamic slicing for simplicity.
position_embedding_layer = PositionEmbeddingWithSubSeqMask(
initializer=initializer,
initializer=tf_utils.clone_initializer(initializer),
use_dynamic_slicing=True,
max_sequence_length=max_seq_length,
name='position_embedding')
......@@ -114,7 +114,7 @@ class PackedSequenceEmbedding(tf.keras.Model):
layers.OnDeviceEmbedding(
vocab_size=type_vocab_size,
embedding_width=embedding_width,
initializer=initializer,
initializer=tf_utils.clone_initializer(initializer),
use_one_hot=True,
name='type_embeddings')(type_ids))
......@@ -132,7 +132,7 @@ class PackedSequenceEmbedding(tf.keras.Model):
'...x,xy->...y',
output_shape=hidden_size,
bias_axes=None,
kernel_initializer=initializer,
kernel_initializer=tf_utils.clone_initializer(initializer),
name='embedding_projection')(
embeddings)
......
......@@ -17,6 +17,8 @@
import collections
import tensorflow as tf
from official.modeling import tf_utils
def _apply_paragraph_mask(logits, paragraph_mask):
"""Applies a position mask to calculated logits."""
......@@ -156,12 +158,12 @@ class XLNetSpanLabeling(tf.keras.layers.Layer):
self._end_n_top = end_n_top
self.start_logits_dense = tf.keras.layers.Dense(
units=1,
kernel_initializer=initializer,
kernel_initializer=tf_utils.clone_initializer(initializer),
name='predictions/transform/start_logits')
self.end_logits_inner_dense = tf.keras.layers.Dense(
units=input_width,
kernel_initializer=initializer,
kernel_initializer=tf_utils.clone_initializer(initializer),
activation=activation,
name='predictions/transform/end_logits/inner')
self.end_logits_layer_norm = tf.keras.layers.LayerNormalization(
......@@ -169,18 +171,18 @@ class XLNetSpanLabeling(tf.keras.layers.Layer):
name='predictions/transform/end_logits/layernorm')
self.end_logits_output_dense = tf.keras.layers.Dense(
units=1,
kernel_initializer=initializer,
kernel_initializer=tf_utils.clone_initializer(initializer),
name='predictions/transform/end_logits/output')
self.answer_logits_inner = tf.keras.layers.Dense(
units=input_width,
kernel_initializer=initializer,
kernel_initializer=tf_utils.clone_initializer(initializer),
activation=activation,
name='predictions/transform/answer_logits/inner')
self.answer_logits_dropout = tf.keras.layers.Dropout(rate=dropout_rate)
self.answer_logits_output = tf.keras.layers.Dense(
units=1,
kernel_initializer=initializer,
kernel_initializer=tf_utils.clone_initializer(initializer),
use_bias=False,
name='predictions/transform/answer_logits/output')
......
......@@ -18,6 +18,7 @@ from absl import logging
import tensorflow as tf
from official.modeling import tf_utils
from official.nlp.modeling import layers
from official.nlp.modeling.layers import transformer_xl
......@@ -507,7 +508,7 @@ class XLNetBase(tf.keras.layers.Layer):
self._embedding_layer = layers.OnDeviceEmbedding(
vocab_size=self._vocab_size,
embedding_width=embedding_width,
initializer=self._initializer,
initializer=tf_utils.clone_initializer(self._initializer),
dtype=tf.float32,
name="word_embedding")
self._dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
......@@ -666,7 +667,7 @@ class XLNetBase(tf.keras.layers.Layer):
shape=[self._num_layers, 2, self._num_attention_heads,
self._head_size],
dtype=tf.float32,
initializer=self._initializer)
initializer=tf_utils.clone_initializer(self._initializer))
segment_embedding = self._segment_embedding
segment_matrix = _compute_segment_matrix(
......