"sgl-kernel/python/vscode:/vscode.git/clone" did not exist on "f226d3da2ae0101cb92764fe2f31f518fe41bb70"
Commit 6c54d37d authored by Jialu Liu's avatar Jialu Liu Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 404247019
parent 5ed72458
...@@ -51,9 +51,7 @@ class TeamsPretrainerConfig(base_config.Config): ...@@ -51,9 +51,7 @@ class TeamsPretrainerConfig(base_config.Config):
@gin.configurable @gin.configurable
def get_encoder(bert_config, def get_encoder(bert_config, embedding_network=None, hidden_layers=None):
embedding_network=None,
hidden_layers=layers.Transformer):
"""Gets a 'EncoderScaffold' object. """Gets a 'EncoderScaffold' object.
Args: Args:
...@@ -85,7 +83,9 @@ def get_encoder(bert_config, ...@@ -85,7 +83,9 @@ def get_encoder(bert_config,
stddev=bert_config.initializer_range), stddev=bert_config.initializer_range),
) )
if embedding_network is None: if embedding_network is None:
embedding_network = networks.PackedSequenceEmbedding(**embedding_cfg) embedding_network = networks.PackedSequenceEmbedding
if hidden_layers is None:
hidden_layers = layers.Transformer
kwargs = dict( kwargs = dict(
embedding_cfg=embedding_cfg, embedding_cfg=embedding_cfg,
embedding_cls=embedding_network, embedding_cls=embedding_network,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment