Commit 1707efda authored by Chen Chen's avatar Chen Chen Committed by A. Unique TensorFlower
Browse files

Internal Change

PiperOrigin-RevId: 330822629
parent b47eb2a2
...@@ -285,5 +285,22 @@ def create_retrieval_dataset(file_path, ...@@ -285,5 +285,22 @@ def create_retrieval_dataset(file_path,
_select_data_from_record, _select_data_from_record,
num_parallel_calls=tf.data.experimental.AUTOTUNE) num_parallel_calls=tf.data.experimental.AUTOTUNE)
dataset = dataset.batch(batch_size, drop_remainder=False) dataset = dataset.batch(batch_size, drop_remainder=False)
def _pad_to_batch(x, y):
cur_size = tf.shape(y)[0]
pad_size = batch_size - cur_size
pad_ids = tf.zeros(shape=[pad_size, seq_length], dtype=tf.int32)
for key in ('input_word_ids', 'input_mask', 'input_type_ids'):
x[key] = tf.concat([x[key], pad_ids], axis=0)
pad_labels = -tf.ones(shape=[pad_size, 1], dtype=tf.int32)
y = tf.concat([y, pad_labels], axis=0)
return x, y
dataset = dataset.map(
_pad_to_batch,
num_parallel_calls=tf.data.experimental.AUTOTUNE)
dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE) dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
return dataset return dataset
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment