"examples/mxnet/vscode:/vscode.git/clone" did not exist on "12d706300cba4d9ec25cfa1075ab4d2703dd89f0"
Commit f7852565 authored by Sergey Mironov

Make BertPretrainer accept embedding_table explicitly

parent 31da2245
@@ -212,6 +212,7 @@ def pretrain_model(bert_config,
           stddev=bert_config.initializer_range)
   pretrainer_model = models.BertPretrainer(
       network=transformer_encoder,
+      embedding_table=transformer_encoder.get_embedding_table(),
       num_classes=2,  # The next sentence prediction label has two classes.
       num_token_predictions=max_predictions_per_seq,
       initializer=initializer,
...
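For context, a minimal sketch of the updated call site, assuming the `official.nlp.modeling` API of this commit's era; the sizes and hyperparameters below are illustrative stand-ins for values the real code derives from `bert_config`:

```python
import tensorflow as tf
from official.nlp.modeling import models, networks

# Small, illustrative encoder; real training uses bert_config values.
transformer_encoder = networks.TransformerEncoder(
    vocab_size=30522, num_layers=2)

pretrainer_model = models.BertPretrainer(
    network=transformer_encoder,
    # New in this commit: the table is handed over explicitly rather
    # than fetched inside BertPretrainer.
    embedding_table=transformer_encoder.get_embedding_table(),
    num_classes=2,  # next sentence prediction has two classes
    num_token_predictions=20,
    initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02),
    output='predictions')
```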
...@@ -39,14 +39,15 @@ class BertPretrainer(tf.keras.Model): ...@@ -39,14 +39,15 @@ class BertPretrainer(tf.keras.Model):
Arguments: Arguments:
network: A transformer network. This network should output a sequence output network: A transformer network. This network should output a sequence output
and a classification output. Furthermore, it should expose its embedding and a classification output.
table via a "get_embedding_table" method.
num_classes: Number of classes to predict from the classification network. num_classes: Number of classes to predict from the classification network.
num_token_predictions: Number of tokens to predict from the masked LM. num_token_predictions: Number of tokens to predict from the masked LM.
activation: The activation (if any) to use in the masked LM and activation: The activation (if any) to use in the masked LM and
classification networks. If None, no activation will be used. classification networks. If None, no activation will be used.
initializer: The initializer (if any) to use in the masked LM and initializer: The initializer (if any) to use in the masked LM and
classification networks. Defaults to a Glorot uniform initializer. classification networks. Defaults to a Glorot uniform initializer.
embedding_table: Embedding table of a network. If None, the
"network.get_embedding_table()" is used.
output: The output style for this network. Can be either 'logits' or output: The output style for this network. Can be either 'logits' or
'predictions'. 'predictions'.
""" """
@@ -58,6 +59,7 @@ class BertPretrainer(tf.keras.Model):
                activation=None,
                initializer='glorot_uniform',
                output='logits',
+               embedding_table=None,
                **kwargs):
     self._self_setattr_tracking = False
     self._config = {
@@ -100,6 +102,7 @@ class BertPretrainer(tf.keras.Model):
         num_predictions=num_token_predictions,
         input_width=sequence_output.shape[-1],
         source_network=network,
+        embedding_table=embedding_table,
         activation=activation,
         initializer=initializer,
         output=output,
...
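Since the table is now threaded through BertPretrainer rather than looked up on the network, the masked-LM head only needs a `[vocab_size, hidden_size]` tensor. A hedged sketch of what that permits, reusing the `transformer_encoder` from the sketch above; the standalone variable is purely illustrative:

```python
# Illustrative only: any [vocab_size, hidden_size] variable can now be
# tied to the masked-LM output projection, not just the table the
# encoder exposes via get_embedding_table().
custom_table = tf.Variable(
    tf.random.truncated_normal([30522, 768], stddev=0.02),
    name='custom_word_embeddings')

pretrainer = models.BertPretrainer(
    network=transformer_encoder,  # encoder from the earlier sketch
    embedding_table=custom_table,
    num_classes=2,
    num_token_predictions=20)
```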
@@ -37,6 +37,8 @@ class MaskedLM(network.Network):
     num_predictions: The number of predictions to make per sequence.
     source_network: The network with the embedding layer to use for the
       embedding layer.
+    embedding_table: The embedding table of the source network. If None, the
+      `source_network.get_embedding_table()` method is used.
     activation: The activation, if any, for the dense layer in this network.
     initializer: The initializer for the dense layer in this network. Defaults to
       a Glorot uniform initializer.
@@ -48,12 +50,16 @@ class MaskedLM(network.Network):
                input_width,
                num_predictions,
                source_network,
+               embedding_table=None,
                activation=None,
                initializer='glorot_uniform',
                output='logits',
                **kwargs):
-    embedding_table = source_network.get_embedding_table()
+    if embedding_table is None:
+      embedding_table = source_network.get_embedding_table()
     vocab_size, hidden_size = embedding_table.shape
     sequence_data = tf.keras.layers.Input(
...
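The behavioural core of the patch is the None-default fallback in MaskedLM: an explicitly supplied table wins, and omitting the argument preserves the old behaviour. A self-contained sketch of the same idiom; `SimpleMaskedHead` is a hypothetical stand-in, not the library class:

```python
import tensorflow as tf

class SimpleMaskedHead(tf.keras.layers.Layer):
  """Hypothetical stand-in showing the embedding_table fallback idiom."""

  def __init__(self, source_network, embedding_table=None, **kwargs):
    super().__init__(**kwargs)
    # An explicit table takes precedence; None falls back to the
    # network's accessor, which keeps existing callers working.
    if embedding_table is None:
      embedding_table = source_network.get_embedding_table()
    self.embedding_table = embedding_table

  def call(self, hidden_states):
    # Tie the output projection to the (possibly shared) embedding
    # table: [batch, hidden] x [vocab, hidden]^T -> [batch, vocab].
    return tf.matmul(hidden_states, self.embedding_table, transpose_b=True)
```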