Commit 71a6b899 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 432311587
parent b136d845
......@@ -115,7 +115,8 @@ class MaskedLM(tf.keras.layers.Layer):
flat_offsets = tf.reshape(
tf.range(0, batch_size, dtype=tf.int32) * seq_length, [-1, 1])
flat_positions = tf.reshape(positions + flat_offsets, [-1])
flat_positions = tf.reshape(
positions + tf.cast(flat_offsets, positions.dtype), [-1])
flat_sequence_tensor = tf.reshape(sequence_tensor,
[batch_size * seq_length, width])
output_tensor = tf.gather(flat_sequence_tensor, flat_positions)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment