Commit eb4f8124 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Slice BERT output with a TF Op layer instead of a Lambda layer.

Lambda layers are fundamentally non-portable since they serialize
Python bytecode.

PiperOrigin-RevId: 336910013
parent d6fe2516
......@@ -171,9 +171,11 @@ class BertEncoder(tf.keras.Model):
data = layer([data, attention_mask])
encoder_outputs.append(data)
first_token_tensor = (
tf.keras.layers.Lambda(lambda x: tf.squeeze(x[:, 0:1, :], axis=1))(
encoder_outputs[-1]))
last_enocder_output = encoder_outputs[-1]
# Applying a tf.slice op (through subscript notation) to a Keras tensor
# like this will create a SliceOpLambda layer. This is better than a Lambda
# layer with Python code, because that is fundamentally less portable.
first_token_tensor = last_enocder_output[:, 0, :]
self._pooler_layer = tf.keras.layers.Dense(
units=hidden_size,
activation='tanh',
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment