"megatron/git@developer.sourcefind.cn:OpenDAS/megatron-lm.git" did not exist on "21d2b0fe695750f57742f950d36d3ebbf4ab4992"
Commit 9cbe7fab authored by A. Unique TensorFlower

Merge pull request #8403 from stagedml:bert-pretrain-embedding-table

PiperOrigin-RevId: 308649588
parents 7cc0970b f7852565
@@ -212,6 +212,7 @@ def pretrain_model(bert_config,
       stddev=bert_config.initializer_range)
   pretrainer_model = models.BertPretrainer(
       network=transformer_encoder,
+      embedding_table=transformer_encoder.get_embedding_table(),
       num_classes=2,  # The next sentence prediction label has two classes.
       num_token_predictions=max_predictions_per_seq,
       initializer=initializer,
......
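For context, a minimal sketch of the call site above after this change, assuming the `official.nlp.modeling` package layout; the encoder constructor arguments and prediction count are illustrative values, not part of the diff:

from official.nlp.modeling import models, networks

# Illustrative encoder; only vocab_size is required, other args use defaults.
transformer_encoder = networks.TransformerEncoder(vocab_size=30522)
pretrainer_model = models.BertPretrainer(
    network=transformer_encoder,
    # New in this change: the embedding table is passed explicitly rather
    # than fetched implicitly inside BertPretrainer.
    embedding_table=transformer_encoder.get_embedding_table(),
    num_classes=2,  # The next sentence prediction label has two classes.
    num_token_predictions=20,  # Illustrative stand-in for max_predictions_per_seq.
    initializer='glorot_uniform')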
@@ -39,10 +39,11 @@ class BertPretrainer(tf.keras.Model):
   Arguments:
     network: A transformer network. This network should output a sequence output
-      and a classification output. Furthermore, it should expose its embedding
-      table via a "get_embedding_table" method.
+      and a classification output.
     num_classes: Number of classes to predict from the classification network.
     num_token_predictions: Number of tokens to predict from the masked LM.
+    embedding_table: Embedding table of a network. If None,
+      "network.get_embedding_table()" is used.
     activation: The activation (if any) to use in the masked LM and
       classification networks. If None, no activation will be used.
     initializer: The initializer (if any) to use in the masked LM and
@@ -55,6 +56,7 @@ class BertPretrainer(tf.keras.Model):
                network,
                num_classes,
                num_token_predictions,
+               embedding_table=None,
                activation=None,
                initializer='glorot_uniform',
                output='logits',
@@ -100,6 +102,7 @@ class BertPretrainer(tf.keras.Model):
         num_predictions=num_token_predictions,
         input_width=sequence_output.shape[-1],
         source_network=network,
+        embedding_table=embedding_table,
         activation=activation,
         initializer=initializer,
         output=output,
......
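Because `embedding_table` defaults to None and falls back to `network.get_embedding_table()`, existing callers keep working unchanged, while networks that do not implement `get_embedding_table()` can now be used by supplying the table directly. A hedged sketch of the two equivalent call styles, reusing the `transformer_encoder` from the sketch above:

# Explicit: the caller supplies the table, so the network itself no longer
# needs to expose a get_embedding_table() method.
pretrainer = models.BertPretrainer(
    network=transformer_encoder,
    embedding_table=transformer_encoder.get_embedding_table(),
    num_classes=2,
    num_token_predictions=20)

# Implicit: embedding_table is None, so BertPretrainer calls
# network.get_embedding_table() itself, matching the old behavior.
pretrainer = models.BertPretrainer(
    network=transformer_encoder,
    num_classes=2,
    num_token_predictions=20)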
@@ -37,6 +37,8 @@ class MaskedLM(network.Network):
     num_predictions: The number of predictions to make per sequence.
     source_network: The network with the embedding layer to use for the
       embedding layer.
+    embedding_table: The embedding table of a source network. If None, the
+      `source_network.get_embedding_table()` method is used.
     activation: The activation, if any, for the dense layer in this network.
     initializer: The initializer for the dense layer in this network. Defaults to
       a Glorot uniform initializer.
@@ -48,12 +50,14 @@ class MaskedLM(network.Network):
                input_width,
                num_predictions,
                source_network,
+               embedding_table=None,
                activation=None,
                initializer='glorot_uniform',
                output='logits',
                **kwargs):
-    embedding_table = source_network.get_embedding_table()
+    if embedding_table is None:
+      embedding_table = source_network.get_embedding_table()
     vocab_size, hidden_size = embedding_table.shape
     sequence_data = tf.keras.layers.Input(
......
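The reason MaskedLM needs the table at all is BERT-style weight tying: the masked-LM output projection reuses the input embedding matrix, which is why the constructor unpacks `vocab_size, hidden_size = embedding_table.shape`. A simplified, hedged sketch of that idea, not the file's literal code:

import tensorflow as tf

def tied_lm_logits(hidden, embedding_table):
  # hidden: [batch, num_predictions, hidden_size]
  # embedding_table: [vocab_size, hidden_size]. Multiplying by its
  # transpose maps hidden states to vocabulary logits without learning
  # a second vocab-sized weight matrix.
  return tf.matmul(hidden, embedding_table, transpose_b=True)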