Commit 4ec2ee97 authored by Chen Chen's avatar Chen Chen Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 312313738
parent 1e72c1f2
...@@ -134,13 +134,16 @@ def get_transformer_encoder(bert_config, ...@@ -134,13 +134,16 @@ def get_transformer_encoder(bert_config,
intermediate_activation=tf_utils.get_activation(bert_config.hidden_act), intermediate_activation=tf_utils.get_activation(bert_config.hidden_act),
dropout_rate=bert_config.hidden_dropout_prob, dropout_rate=bert_config.hidden_dropout_prob,
attention_dropout_rate=bert_config.attention_probs_dropout_prob, attention_dropout_rate=bert_config.attention_probs_dropout_prob,
kernel_initializer=tf.keras.initializers.TruncatedNormal(
stddev=bert_config.initializer_range),
) )
kwargs = dict( kwargs = dict(
embedding_cfg=embedding_cfg, embedding_cfg=embedding_cfg,
hidden_cfg=hidden_cfg, hidden_cfg=hidden_cfg,
num_hidden_instances=bert_config.num_hidden_layers, num_hidden_instances=bert_config.num_hidden_layers,
pooled_output_dim=bert_config.hidden_size, pooled_output_dim=bert_config.hidden_size,
) pooler_layer_initializer=tf.keras.initializers.TruncatedNormal(
stddev=bert_config.initializer_range))
# Relies on gin configuration to define the Transformer encoder arguments. # Relies on gin configuration to define the Transformer encoder arguments.
return transformer_encoder_cls(**kwargs) return transformer_encoder_cls(**kwargs)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment