Commit 2cbcddb1 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 345761218
parent 277ea3cf
...@@ -63,6 +63,11 @@ class BertSpanLabeler(tf.keras.Model): ...@@ -63,6 +63,11 @@ class BertSpanLabeler(tf.keras.Model):
else: else:
sequence_output = outputs['sequence_output'] sequence_output = outputs['sequence_output']
# The input network (typically a transformer model) may get outputs from all
# layers. When this case happens, we retrieve the last layer output.
if isinstance(sequence_output, list):
sequence_output = sequence_output[-1]
# This is an instance variable for ease of access to the underlying task # This is an instance variable for ease of access to the underlying task
# network. # network.
span_labeling = networks.SpanLabeling( span_labeling = networks.SpanLabeling(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment