Commit 6b58625d authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Remove TimeDistributed.

PiperOrigin-RevId: 288554071
parent f3250a1d
...@@ -57,14 +57,12 @@ class SpanLabeling(network.Network): ...@@ -57,14 +57,12 @@ class SpanLabeling(network.Network):
sequence_data = tf.keras.layers.Input( sequence_data = tf.keras.layers.Input(
shape=(None, input_width), name='sequence_data', dtype=tf.float32) shape=(None, input_width), name='sequence_data', dtype=tf.float32)
time_distributed_dense = tf.keras.layers.TimeDistributed( intermediate_logits = tf.keras.layers.Dense(
tf.keras.layers.Dense( 2, # This layer predicts start location and end location.
2, # This layer predicts start location and end location. activation=activation,
activation=activation, kernel_initializer=initializer,
kernel_initializer=initializer, name='predictions/transform/logits')(
name='predictions/transform/logits')) sequence_data)
intermediate_logits = time_distributed_dense(sequence_data)
self.start_logits, self.end_logits = ( self.start_logits, self.end_logits = (
tf.keras.layers.Lambda(self._split_output_tensor)(intermediate_logits)) tf.keras.layers.Lambda(self._split_output_tensor)(intermediate_logits))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment