Commit 840a493a authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 311773503
parent bce4604a
......@@ -70,12 +70,12 @@ class BertPretrainer(tf.keras.Model):
'initializer': initializer,
'output': output,
}
self.encoder = network
# We want to use the inputs of the passed network as the inputs to this
# Model. To do this, we need to keep a copy of the network inputs for use
# when we construct the Model object at the end of init. (We keep a copy
# because we'll be adding another tensor to the copy later.)
network_inputs = network.inputs
network_inputs = self.encoder.inputs
inputs = copy.copy(network_inputs)
# Because we have a copy of inputs to create this Model object, we can
......@@ -83,8 +83,13 @@ class BertPretrainer(tf.keras.Model):
# Note that, because of how deferred construction happens, we can't use
# the copy of the list here - by the time the network is invoked, the list
# object contains the additional input added below.
sequence_output, cls_output = network(network_inputs)
sequence_output, cls_output = self.encoder(network_inputs)
# The encoder network may get outputs from all layers.
if isinstance(sequence_output, list):
sequence_output = sequence_output[-1]
if isinstance(cls_output, list):
cls_output = cls_output[-1]
sequence_output_length = sequence_output.shape.as_list()[1]
if sequence_output_length < num_token_predictions:
raise ValueError(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment