Commit 6d458bcc authored by Scott Zhu, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 446219773
parent 1fd7aaaf
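
The diff below applies one pattern throughout the NLP modeling code: wherever a single user-supplied initializer object was previously handed to several layers (word/position/type embeddings, embedding projection, transformer layers, pooler, prediction heads), it is now wrapped in tf_utils.clone_initializer so each layer receives its own copy. A minimal sketch of what such a clone helper looks like, assuming the official.modeling.tf_utils implementation follows the usual from_config pattern (the exact code may differ):

    import tensorflow as tf

    def clone_initializer(initializer):
      # Keras Initializer instances may be reused across many layers; re-creating
      # one from its config gives each layer a fresh copy with the same settings.
      if isinstance(initializer, tf.keras.initializers.Initializer):
        return initializer.__class__.from_config(initializer.get_config())
      # Strings such as 'glorot_uniform' and other serialized configs are resolved
      # into new instances by the receiving layer, so pass them through unchanged.
      return initializer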
@@ -55,6 +55,7 @@ class Module(tf.Module):
                       initializer: Initializer,
                       dtype: tf.DType = tf.float32,
                       **kwargs):
+    initializer = tf_utils.clone_initializer(initializer)
     return tf.Variable(initializer(shape, dtype=dtype, **kwargs), name=name)
   def read_variable(self,
@@ -588,7 +589,8 @@ class MultiHeadAttention(Module):
     init_std_rescaling = tf.math.sqrt(tf.cast(self.d_kv, dtype=self.dtype))
     query_w_init = (
         lambda *args, **kwargs: (  # pylint: disable=g-long-lambda
-            weight_initializer(*args, **kwargs) / init_std_rescaling))
+            tf_utils.clone_initializer(weight_initializer)(
+                *args, **kwargs) / init_std_rescaling))
     self.q = Linear3D(
         self.d_model,
         self.d_kv,
...
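
In the MultiHeadAttention hunk above, the clone now happens inside the lambda, so every call builds the rescaled query initializer from a fresh copy of weight_initializer rather than reusing the shared instance. A small, self-contained illustration of that wrapping pattern (d_kv and the initializer here are illustrative values, not the module's real configuration):

    import tensorflow as tf
    from official.modeling import tf_utils

    d_kv = 64  # illustrative key/value width
    weight_initializer = tf.keras.initializers.TruncatedNormal(stddev=1.0)
    init_std_rescaling = tf.math.sqrt(tf.cast(d_kv, tf.float32))

    # Clone on each call so the query projection does not share the original
    # initializer object with the other projections.
    query_w_init = (
        lambda *args, **kwargs: (
            tf_utils.clone_initializer(weight_initializer)(*args, **kwargs) /
            init_std_rescaling))

    kernel_init_value = query_w_init([d_kv, d_kv], dtype=tf.float32)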
@@ -18,6 +18,7 @@ import collections
 import tensorflow as tf
 from official.modeling import activations
+from official.modeling import tf_utils
 from official.nlp.modeling import layers
@@ -92,13 +93,13 @@ class AlbertEncoder(tf.keras.Model):
     embedding_layer = layers.OnDeviceEmbedding(
         vocab_size=vocab_size,
         embedding_width=embedding_width,
-        initializer=initializer,
+        initializer=tf_utils.clone_initializer(initializer),
         name='word_embeddings')
     word_embeddings = embedding_layer(word_ids)
     # Always uses dynamic slicing for simplicity.
     position_embedding_layer = layers.PositionEmbedding(
-        initializer=initializer,
+        initializer=tf_utils.clone_initializer(initializer),
         max_length=max_sequence_length,
         name='position_embedding')
     position_embeddings = position_embedding_layer(word_embeddings)
@@ -107,7 +108,7 @@ class AlbertEncoder(tf.keras.Model):
         layers.OnDeviceEmbedding(
             vocab_size=type_vocab_size,
             embedding_width=embedding_width,
-            initializer=initializer,
+            initializer=tf_utils.clone_initializer(initializer),
             use_one_hot=True,
             name='type_embeddings')(type_ids))
@@ -127,7 +128,7 @@ class AlbertEncoder(tf.keras.Model):
         '...x,xy->...y',
         output_shape=hidden_size,
         bias_axes='y',
-        kernel_initializer=initializer,
+        kernel_initializer=tf_utils.clone_initializer(initializer),
         name='embedding_projection')(
             embeddings)
@@ -139,7 +140,7 @@ class AlbertEncoder(tf.keras.Model):
         inner_activation=activation,
         output_dropout=dropout_rate,
         attention_dropout=attention_dropout_rate,
-        kernel_initializer=initializer,
+        kernel_initializer=tf_utils.clone_initializer(initializer),
         name='transformer')
     encoder_outputs = []
     for _ in range(num_layers):
@@ -153,7 +154,7 @@ class AlbertEncoder(tf.keras.Model):
     cls_output = tf.keras.layers.Dense(
         units=hidden_size,
         activation='tanh',
-        kernel_initializer=initializer,
+        kernel_initializer=tf_utils.clone_initializer(initializer),
         name='pooler_transform')(
             first_token_tensor)
     if dict_outputs:
...
@@ -19,6 +19,7 @@ from typing import Any, Callable, Optional, Union
 from absl import logging
 import tensorflow as tf
+from official.modeling import tf_utils
 from official.nlp.modeling import layers
@@ -122,20 +123,20 @@ class BertEncoderV2(tf.keras.layers.Layer):
       self._embedding_layer = layers.OnDeviceEmbedding(
           vocab_size=vocab_size,
           embedding_width=embedding_width,
-          initializer=initializer,
+          initializer=tf_utils.clone_initializer(initializer),
           name='word_embeddings')
     else:
       self._embedding_layer = embedding_layer
     self._position_embedding_layer = layers.PositionEmbedding(
-        initializer=initializer,
+        initializer=tf_utils.clone_initializer(initializer),
         max_length=max_sequence_length,
         name='position_embedding')
     self._type_embedding_layer = layers.OnDeviceEmbedding(
         vocab_size=type_vocab_size,
         embedding_width=embedding_width,
-        initializer=initializer,
+        initializer=tf_utils.clone_initializer(initializer),
         use_one_hot=True,
         name='type_embeddings')
@@ -153,7 +154,7 @@ class BertEncoderV2(tf.keras.layers.Layer):
         '...x,xy->...y',
         output_shape=hidden_size,
         bias_axes='y',
-        kernel_initializer=initializer,
+        kernel_initializer=tf_utils.clone_initializer(initializer),
         name='embedding_projection')
     self._transformer_layers = []
@@ -168,14 +169,14 @@ class BertEncoderV2(tf.keras.layers.Layer):
           attention_dropout=attention_dropout,
           norm_first=norm_first,
           output_range=output_range if i == num_layers - 1 else None,
-          kernel_initializer=initializer,
+          kernel_initializer=tf_utils.clone_initializer(initializer),
           name='transformer/layer_%d' % i)
       self._transformer_layers.append(layer)
     self._pooler_layer = tf.keras.layers.Dense(
         units=hidden_size,
         activation='tanh',
-        kernel_initializer=initializer,
+        kernel_initializer=tf_utils.clone_initializer(initializer),
         name='pooler_transform')
     self._config = {
@@ -409,7 +410,7 @@ class BertEncoder(tf.keras.Model):
       embedding_layer_inst = layers.OnDeviceEmbedding(
           vocab_size=vocab_size,
           embedding_width=embedding_width,
-          initializer=initializer,
+          initializer=tf_utils.clone_initializer(initializer),
           name='word_embeddings')
     else:
       embedding_layer_inst = embedding_layer
@@ -417,14 +418,14 @@ class BertEncoder(tf.keras.Model):
     # Always uses dynamic slicing for simplicity.
     position_embedding_layer = layers.PositionEmbedding(
-        initializer=initializer,
+        initializer=tf_utils.clone_initializer(initializer),
         max_length=max_sequence_length,
         name='position_embedding')
     position_embeddings = position_embedding_layer(word_embeddings)
     type_embedding_layer = layers.OnDeviceEmbedding(
         vocab_size=type_vocab_size,
         embedding_width=embedding_width,
-        initializer=initializer,
+        initializer=tf_utils.clone_initializer(initializer),
         use_one_hot=True,
         name='type_embeddings')
     type_embeddings = type_embedding_layer(type_ids)
@@ -445,7 +446,7 @@ class BertEncoder(tf.keras.Model):
         '...x,xy->...y',
         output_shape=hidden_size,
         bias_axes='y',
-        kernel_initializer=initializer,
+        kernel_initializer=tf_utils.clone_initializer(initializer),
         name='embedding_projection')
     embeddings = embedding_projection(embeddings)
   else:
@@ -468,7 +469,7 @@ class BertEncoder(tf.keras.Model):
         attention_dropout=attention_dropout,
         norm_first=norm_first,
         output_range=transformer_output_range,
-        kernel_initializer=initializer,
+        kernel_initializer=tf_utils.clone_initializer(initializer),
         name='transformer/layer_%d' % i)
     transformer_layers.append(layer)
     data = layer([data, attention_mask])
@@ -482,7 +483,7 @@ class BertEncoder(tf.keras.Model):
     pooler_layer = tf.keras.layers.Dense(
         units=hidden_size,
         activation='tanh',
-        kernel_initializer=initializer,
+        kernel_initializer=tf_utils.clone_initializer(initializer),
         name='pooler_transform')
     cls_output = pooler_layer(first_token_tensor)
...
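
The BertEncoderV2 and BertEncoder hunks above repeat the same substitution: every sublayer that used to receive the shared initializer object now receives tf_utils.clone_initializer(initializer). A hypothetical builder showing the resulting call pattern (the layer names and sizes are illustrative, not taken from the file):

    import tensorflow as tf
    from official.modeling import tf_utils

    def build_output_heads(hidden_size, initializer):
      # Each head gets its own clone of the caller-supplied initializer instead
      # of holding a reference to one shared instance.
      pooler = tf.keras.layers.Dense(
          units=hidden_size,
          activation='tanh',
          kernel_initializer=tf_utils.clone_initializer(initializer),
          name='pooler_transform')
      classifier = tf.keras.layers.Dense(
          units=2,
          kernel_initializer=tf_utils.clone_initializer(initializer),
          name='classifier')
      return pooler, classifier

    pooler, classifier = build_output_heads(
        hidden_size=768,
        initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02))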
@@ -21,6 +21,7 @@ from absl import logging
 import gin
 import tensorflow as tf
+from official.modeling import tf_utils
 from official.nlp.modeling import layers
@@ -153,14 +154,14 @@ class EncoderScaffold(tf.keras.Model):
       embedding_layer = layers.OnDeviceEmbedding(
           vocab_size=embedding_cfg['vocab_size'],
           embedding_width=embedding_cfg['hidden_size'],
-          initializer=embedding_cfg['initializer'],
+          initializer=tf_utils.clone_initializer(embedding_cfg['initializer']),
           name='word_embeddings')
       word_embeddings = embedding_layer(word_ids)
       # Always uses dynamic slicing for simplicity.
       position_embedding_layer = layers.PositionEmbedding(
-          initializer=embedding_cfg['initializer'],
+          initializer=tf_utils.clone_initializer(embedding_cfg['initializer']),
          max_length=embedding_cfg['max_seq_length'],
          name='position_embedding')
       position_embeddings = position_embedding_layer(word_embeddings)
@@ -168,7 +169,7 @@ class EncoderScaffold(tf.keras.Model):
       type_embedding_layer = layers.OnDeviceEmbedding(
           vocab_size=embedding_cfg['type_vocab_size'],
           embedding_width=embedding_cfg['hidden_size'],
-          initializer=embedding_cfg['initializer'],
+          initializer=tf_utils.clone_initializer(embedding_cfg['initializer']),
           use_one_hot=True,
           name='type_embeddings')
       type_embeddings = type_embedding_layer(type_ids)
@@ -243,6 +244,8 @@ class EncoderScaffold(tf.keras.Model):
     # like this will create a SliceOpLambda layer. This is better than a Lambda
     # layer with Python code, because that is fundamentally less portable.
     first_token_tensor = last_layer_output[:, 0, :]
+    pooler_layer_initializer = tf.keras.initializers.get(
+        pooler_layer_initializer)
     pooler_layer = tf.keras.layers.Dense(
         units=pooled_output_dim,
         activation='tanh',
@@ -303,7 +306,8 @@ class EncoderScaffold(tf.keras.Model):
     config_dict = {
         'num_hidden_instances': self._num_hidden_instances,
         'pooled_output_dim': self._pooled_output_dim,
-        'pooler_layer_initializer': self._pooler_layer_initializer,
+        'pooler_layer_initializer': tf.keras.initializers.serialize(
+            self._pooler_layer_initializer),
         'embedding_cls': self._embedding_network,
         'embedding_cfg': self._embedding_cfg,
         'layer_norm_before_pooling': self._layer_norm_before_pooling,
...
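
The EncoderScaffold hunks above make two related changes: the pooler initializer is normalized through tf.keras.initializers.get before it is used, and get_config stores its serialized form instead of the raw object, which keeps the config JSON-friendly. A brief round-trip sketch of those standard Keras helpers (the 'truncated_normal' value is illustrative):

    import tensorflow as tf

    # `get` accepts a string, a config dict, or an Initializer instance and
    # always returns an Initializer instance.
    pooler_layer_initializer = tf.keras.initializers.get('truncated_normal')

    # `serialize` converts the instance into a config suitable for get_config().
    serialized = tf.keras.initializers.serialize(pooler_layer_initializer)

    # `deserialize` (or another `get`) restores an equivalent instance later.
    restored = tf.keras.initializers.deserialize(serialized)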
@@ -20,6 +20,7 @@ from absl import logging
 import numpy as np
 import tensorflow as tf
+from official.modeling import tf_utils
 from official.nlp.modeling import layers
 _Initializer = Union[str, tf.keras.initializers.Initializer]
@@ -265,20 +266,20 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
       self._embedding_layer = layers.OnDeviceEmbedding(
           vocab_size=vocab_size,
           embedding_width=embedding_width,
-          initializer=initializer,
+          initializer=tf_utils.clone_initializer(initializer),
           name='word_embeddings')
     else:
       self._embedding_layer = embedding_layer
     self._position_embedding_layer = layers.PositionEmbedding(
-        initializer=initializer,
+        initializer=tf_utils.clone_initializer(initializer),
         max_length=max_sequence_length,
         name='position_embedding')
     self._type_embedding_layer = layers.OnDeviceEmbedding(
         vocab_size=type_vocab_size,
         embedding_width=embedding_width,
-        initializer=initializer,
+        initializer=tf_utils.clone_initializer(initializer),
         use_one_hot=True,
         name='type_embeddings')
@@ -296,7 +297,7 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
         '...x,xy->...y',
         output_shape=hidden_size,
         bias_axes='y',
-        kernel_initializer=initializer,
+        kernel_initializer=tf_utils.clone_initializer(initializer),
         name='embedding_projection')
     self._transformer_layers = []
@@ -316,7 +317,7 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
         attention_dropout=attention_dropout,
         norm_first=norm_first,
         output_range=output_range if i == num_layers - 1 else None,
-        kernel_initializer=initializer,
+        kernel_initializer=tf_utils.clone_initializer(initializer),
         share_rezero=share_rezero,
         name='transformer/layer_%d' % i)
     self._transformer_layers.append(layer)
@@ -324,7 +325,7 @@ class FunnelTransformerEncoder(tf.keras.layers.Layer):
     self._pooler_layer = tf.keras.layers.Dense(
         units=hidden_size,
         activation='tanh',
-        kernel_initializer=initializer,
+        kernel_initializer=tf_utils.clone_initializer(initializer),
         name='pooler_transform')
     if isinstance(pool_stride, int):
       # TODO(b/197133196): Pooling layer can be shared.
...
@@ -97,13 +97,13 @@ class PackedSequenceEmbedding(tf.keras.Model):
     embedding_layer = layers.OnDeviceEmbedding(
         vocab_size=vocab_size,
         embedding_width=embedding_width,
-        initializer=initializer,
+        initializer=tf_utils.clone_initializer(initializer),
         name='word_embeddings')
     word_embeddings = embedding_layer(word_ids)
     # Always uses dynamic slicing for simplicity.
     position_embedding_layer = PositionEmbeddingWithSubSeqMask(
-        initializer=initializer,
+        initializer=tf_utils.clone_initializer(initializer),
         use_dynamic_slicing=True,
         max_sequence_length=max_seq_length,
         name='position_embedding')
@@ -114,7 +114,7 @@ class PackedSequenceEmbedding(tf.keras.Model):
         layers.OnDeviceEmbedding(
             vocab_size=type_vocab_size,
             embedding_width=embedding_width,
-            initializer=initializer,
+            initializer=tf_utils.clone_initializer(initializer),
             use_one_hot=True,
             name='type_embeddings')(type_ids))
@@ -132,7 +132,7 @@ class PackedSequenceEmbedding(tf.keras.Model):
         '...x,xy->...y',
         output_shape=hidden_size,
         bias_axes=None,
-        kernel_initializer=initializer,
+        kernel_initializer=tf_utils.clone_initializer(initializer),
         name='embedding_projection')(
             embeddings)
...
@@ -17,6 +17,8 @@
 import collections
 import tensorflow as tf
+from official.modeling import tf_utils
 def _apply_paragraph_mask(logits, paragraph_mask):
   """Applies a position mask to calculated logits."""
@@ -156,12 +158,12 @@ class XLNetSpanLabeling(tf.keras.layers.Layer):
     self._end_n_top = end_n_top
     self.start_logits_dense = tf.keras.layers.Dense(
         units=1,
-        kernel_initializer=initializer,
+        kernel_initializer=tf_utils.clone_initializer(initializer),
         name='predictions/transform/start_logits')
     self.end_logits_inner_dense = tf.keras.layers.Dense(
         units=input_width,
-        kernel_initializer=initializer,
+        kernel_initializer=tf_utils.clone_initializer(initializer),
         activation=activation,
         name='predictions/transform/end_logits/inner')
     self.end_logits_layer_norm = tf.keras.layers.LayerNormalization(
@@ -169,18 +171,18 @@ class XLNetSpanLabeling(tf.keras.layers.Layer):
         name='predictions/transform/end_logits/layernorm')
     self.end_logits_output_dense = tf.keras.layers.Dense(
         units=1,
-        kernel_initializer=initializer,
+        kernel_initializer=tf_utils.clone_initializer(initializer),
         name='predictions/transform/end_logits/output')
     self.answer_logits_inner = tf.keras.layers.Dense(
         units=input_width,
-        kernel_initializer=initializer,
+        kernel_initializer=tf_utils.clone_initializer(initializer),
         activation=activation,
         name='predictions/transform/answer_logits/inner')
     self.answer_logits_dropout = tf.keras.layers.Dropout(rate=dropout_rate)
     self.answer_logits_output = tf.keras.layers.Dense(
         units=1,
-        kernel_initializer=initializer,
+        kernel_initializer=tf_utils.clone_initializer(initializer),
         use_bias=False,
         name='predictions/transform/answer_logits/output')
...
@@ -18,6 +18,7 @@ from absl import logging
 import tensorflow as tf
+from official.modeling import tf_utils
 from official.nlp.modeling import layers
 from official.nlp.modeling.layers import transformer_xl
@@ -507,7 +508,7 @@ class XLNetBase(tf.keras.layers.Layer):
     self._embedding_layer = layers.OnDeviceEmbedding(
         vocab_size=self._vocab_size,
         embedding_width=embedding_width,
-        initializer=self._initializer,
+        initializer=tf_utils.clone_initializer(self._initializer),
         dtype=tf.float32,
         name="word_embedding")
     self._dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
@@ -666,7 +667,7 @@ class XLNetBase(tf.keras.layers.Layer):
         shape=[self._num_layers, 2, self._num_attention_heads,
                self._head_size],
         dtype=tf.float32,
-        initializer=self._initializer)
+        initializer=tf_utils.clone_initializer(self._initializer))
     segment_embedding = self._segment_embedding
     segment_matrix = _compute_segment_matrix(
...