Commit 840a493a authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 311773503
parent bce4604a
...@@ -70,12 +70,12 @@ class BertPretrainer(tf.keras.Model): ...@@ -70,12 +70,12 @@ class BertPretrainer(tf.keras.Model):
'initializer': initializer, 'initializer': initializer,
'output': output, 'output': output,
} }
self.encoder = network
# We want to use the inputs of the passed network as the inputs to this # We want to use the inputs of the passed network as the inputs to this
# Model. To do this, we need to keep a copy of the network inputs for use # Model. To do this, we need to keep a copy of the network inputs for use
# when we construct the Model object at the end of init. (We keep a copy # when we construct the Model object at the end of init. (We keep a copy
# because we'll be adding another tensor to the copy later.) # because we'll be adding another tensor to the copy later.)
network_inputs = network.inputs network_inputs = self.encoder.inputs
inputs = copy.copy(network_inputs) inputs = copy.copy(network_inputs)
# Because we have a copy of inputs to create this Model object, we can # Because we have a copy of inputs to create this Model object, we can
...@@ -83,8 +83,13 @@ class BertPretrainer(tf.keras.Model): ...@@ -83,8 +83,13 @@ class BertPretrainer(tf.keras.Model):
# Note that, because of how deferred construction happens, we can't use # Note that, because of how deferred construction happens, we can't use
# the copy of the list here - by the time the network is invoked, the list # the copy of the list here - by the time the network is invoked, the list
# object contains the additional input added below. # object contains the additional input added below.
sequence_output, cls_output = network(network_inputs) sequence_output, cls_output = self.encoder(network_inputs)
# The encoder network may get outputs from all layers.
if isinstance(sequence_output, list):
sequence_output = sequence_output[-1]
if isinstance(cls_output, list):
cls_output = cls_output[-1]
sequence_output_length = sequence_output.shape.as_list()[1] sequence_output_length = sequence_output.shape.as_list()[1]
if sequence_output_length < num_token_predictions: if sequence_output_length < num_token_predictions:
raise ValueError( raise ValueError(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment