Commit 6c54d37d authored by Jialu Liu's avatar Jialu Liu Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 404247019
parent 5ed72458
......@@ -51,9 +51,7 @@ class TeamsPretrainerConfig(base_config.Config):
@gin.configurable
def get_encoder(bert_config,
embedding_network=None,
hidden_layers=layers.Transformer):
def get_encoder(bert_config, embedding_network=None, hidden_layers=None):
"""Gets a 'EncoderScaffold' object.
Args:
......@@ -85,7 +83,9 @@ def get_encoder(bert_config,
stddev=bert_config.initializer_range),
)
if embedding_network is None:
embedding_network = networks.PackedSequenceEmbedding(**embedding_cfg)
embedding_network = networks.PackedSequenceEmbedding
if hidden_layers is None:
hidden_layers = layers.Transformer
kwargs = dict(
embedding_cfg=embedding_cfg,
embedding_cls=embedding_network,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment