Commit 4e434726 authored by Chen Chen's avatar Chen Chen Committed by A. Unique TensorFlower
Browse files

Internal Change

PiperOrigin-RevId: 307233807
parent c3b4ffc5
...@@ -60,7 +60,8 @@ def create_pretrain_dataset(input_patterns, ...@@ -60,7 +60,8 @@ def create_pretrain_dataset(input_patterns,
batch_size, batch_size,
is_training=True, is_training=True,
input_pipeline_context=None, input_pipeline_context=None,
use_next_sentence_label=True): use_next_sentence_label=True,
use_position_id=False):
"""Creates input dataset from (tf)records files for pretraining.""" """Creates input dataset from (tf)records files for pretraining."""
name_to_features = { name_to_features = {
'input_ids': 'input_ids':
...@@ -79,7 +80,9 @@ def create_pretrain_dataset(input_patterns, ...@@ -79,7 +80,9 @@ def create_pretrain_dataset(input_patterns,
if use_next_sentence_label: if use_next_sentence_label:
name_to_features['next_sentence_labels'] = tf.io.FixedLenFeature([1], name_to_features['next_sentence_labels'] = tf.io.FixedLenFeature([1],
tf.int64) tf.int64)
if use_position_id:
name_to_features['position_ids'] = tf.io.FixedLenFeature([seq_length],
tf.int64)
for input_pattern in input_patterns: for input_pattern in input_patterns:
if not tf.io.gfile.glob(input_pattern): if not tf.io.gfile.glob(input_pattern):
raise ValueError('%s does not match any files.' % input_pattern) raise ValueError('%s does not match any files.' % input_pattern)
...@@ -123,6 +126,8 @@ def create_pretrain_dataset(input_patterns, ...@@ -123,6 +126,8 @@ def create_pretrain_dataset(input_patterns,
} }
if use_next_sentence_label: if use_next_sentence_label:
x['next_sentence_labels'] = record['next_sentence_labels'] x['next_sentence_labels'] = record['next_sentence_labels']
if use_position_id:
x['position_ids'] = record['position_ids']
y = record['masked_lm_weights'] y = record['masked_lm_weights']
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment