Commit 91495d13 authored by Chen Chen's avatar Chen Chen Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 305530702
parent f41682df
......@@ -121,7 +121,8 @@ def get_transformer_encoder(bert_config,
attention_dropout_rate=bert_config.attention_probs_dropout_prob,
)
kwargs = dict(embedding_cfg=embedding_cfg, hidden_cfg=hidden_cfg,
num_hidden_instances=bert_config.num_hidden_layers,)
num_hidden_instances=bert_config.num_hidden_layers,
num_output_classes=bert_config.hidden_size)
# Relies on gin configuration to define the Transformer encoder arguments.
return transformer_encoder_cls(**kwargs)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment