Commit ff8ee1ae authored by Hongkun Yu's avatar Hongkun Yu Committed by saberkun
Browse files

Bring back gather_indexes

PiperOrigin-RevId: 282103499
parent 83f0a576
......@@ -29,6 +29,38 @@ from official.nlp.modeling.networks import bert_pretrainer
from official.nlp.modeling.networks import bert_span_labeler
def gather_indexes(sequence_tensor, positions):
"""Gathers the vectors at the specific positions.
Args:
sequence_tensor: Sequence output of `BertModel` layer of shape
(`batch_size`, `seq_length`, num_hidden) where num_hidden is number of
hidden units of `BertModel` layer.
positions: Positions ids of tokens in sequence to mask for pretraining of
with dimension (batch_size, max_predictions_per_seq) where
`max_predictions_per_seq` is maximum number of tokens to mask out and
predict per each sequence.
Returns:
Masked out sequence tensor of shape (batch_size * max_predictions_per_seq,
num_hidden).
"""
sequence_shape = tf_utils.get_shape_list(
sequence_tensor, name='sequence_output_tensor')
batch_size = sequence_shape[0]
seq_length = sequence_shape[1]
width = sequence_shape[2]
flat_offsets = tf.keras.backend.reshape(
tf.range(0, batch_size, dtype=tf.int32) * seq_length, [-1, 1])
flat_positions = tf.keras.backend.reshape(positions + flat_offsets, [-1])
flat_sequence_tensor = tf.keras.backend.reshape(
sequence_tensor, [batch_size * seq_length, width])
output_tensor = tf.gather(flat_sequence_tensor, flat_positions)
return output_tensor
class BertPretrainLossAndMetricLayer(tf.keras.layers.Layer):
"""Returns layer that computes custom loss and metrics for pretraining."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment