Commit 1c89b792 authored by Maxim Neumann, committed by A. Unique TensorFlower

Add a flag to control the number of train examples.

PiperOrigin-RevId: 327838493
parent e0b6ce02
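For reference, the new flag is exercised from the command line when launching classifier training. A minimal sketch, assuming the BERT classifier entry point of this code base (the script name and all flag values below are illustrative, not part of the commit):

python run_classifier.py \
  --input_meta_data_path=/path/to/meta_data \
  --train_batch_size=32 \
  --train_data_size=1000

Only --train_data_size is introduced by this change; when it is left unset, the full training set described by the input metadata is used.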
@@ -36,11 +36,13 @@ def decode_record(record, name_to_features):
   return example


-def single_file_dataset(input_file, name_to_features):
+def single_file_dataset(input_file, name_to_features, num_samples=None):
   """Creates a single-file dataset to be passed for BERT custom training."""
   # For training, we want a lot of parallel reading and shuffling.
   # For eval, we want no shuffling and parallel reading doesn't matter.
   d = tf.data.TFRecordDataset(input_file)
+  if num_samples:
+    d = d.take(num_samples)
   d = d.map(
       lambda record: decode_record(record, name_to_features),
       num_parallel_calls=tf.data.experimental.AUTOTUNE)
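For context on the added lines: tf.data.Dataset.take truncates a dataset to at most the given number of elements, which is all the new num_samples argument relies on. A minimal, self-contained sketch (the range dataset merely stands in for the TFRecord input; it is not part of the change):

import tensorflow as tf

d = tf.data.Dataset.range(10)        # stands in for tf.data.TFRecordDataset(input_file)
num_samples = 4                      # plays the role of the new num_samples argument
if num_samples:
  d = d.take(num_samples)            # keep at most num_samples examples
print(list(d.as_numpy_iterator()))   # [0, 1, 2, 3]

Because take() is applied before map(), only the retained records are decoded.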
@@ -156,7 +158,8 @@ def create_classifier_dataset(file_path,
                               is_training=True,
                               input_pipeline_context=None,
                               label_type=tf.int64,
-                              include_sample_weights=False):
+                              include_sample_weights=False,
+                              num_samples=None):
   """Creates input dataset from (tf)records files for train/eval."""
   name_to_features = {
       'input_ids': tf.io.FixedLenFeature([seq_length], tf.int64),
@@ -166,7 +169,8 @@ def create_classifier_dataset(file_path,
   }
   if include_sample_weights:
     name_to_features['weight'] = tf.io.FixedLenFeature([], tf.float32)
-  dataset = single_file_dataset(file_path, name_to_features)
+  dataset = single_file_dataset(file_path, name_to_features,
+                                num_samples=num_samples)
   # The dataset is always sharded by number of hosts.
   # num_input_pipelines is the number of hosts rather than number of cores.
@@ -53,6 +53,9 @@ flags.DEFINE_string(
     'input_meta_data_path', None,
     'Path to file that contains meta data about input '
     'to be used for training and evaluation.')
+flags.DEFINE_integer('train_data_size', None, 'Number of training samples '
+                     'to use. If None, uses the full train data. '
+                     '(default: None).')
 flags.DEFINE_string('predict_checkpoint_path', None,
                     'Path to the checkpoint for predictions.')
 flags.DEFINE_integer(
@@ -92,7 +95,8 @@ def get_dataset_fn(input_file_pattern,
                    global_batch_size,
                    is_training,
                    label_type=tf.int64,
-                   include_sample_weights=False):
+                   include_sample_weights=False,
+                   num_samples=None):
   """Gets a closure to create a dataset."""

   def _dataset_fn(ctx=None):
@@ -106,7 +110,8 @@ def get_dataset_fn(input_file_pattern,
         is_training=is_training,
         input_pipeline_context=ctx,
         label_type=label_type,
-        include_sample_weights=include_sample_weights)
+        include_sample_weights=include_sample_weights,
+        num_samples=num_samples)
     return dataset

   return _dataset_fn
@@ -374,6 +379,9 @@ def run_bert(strategy,
   epochs = FLAGS.num_train_epochs * FLAGS.num_eval_per_epoch
   train_data_size = (
       input_meta_data['train_data_size'] // FLAGS.num_eval_per_epoch)
+  if FLAGS.train_data_size:
+    train_data_size = min(train_data_size, FLAGS.train_data_size)
+    logging.info('Updated train_data_size: %s', train_data_size)
   steps_per_epoch = int(train_data_size / FLAGS.train_batch_size)
   warmup_steps = int(epochs * train_data_size * 0.1 / FLAGS.train_batch_size)
   eval_steps = int(
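To make the effect of the cap concrete, a worked example with hypothetical values (none of these numbers come from the commit; they only trace the arithmetic above):

# Hypothetical values, chosen only to illustrate the capping logic above.
num_train_epochs = 3
num_eval_per_epoch = 1
meta_train_data_size = 100000        # input_meta_data['train_data_size']
flag_train_data_size = 1000          # --train_data_size
train_batch_size = 32

epochs = num_train_epochs * num_eval_per_epoch                          # 3
train_data_size = meta_train_data_size // num_eval_per_epoch            # 100000
if flag_train_data_size:
  train_data_size = min(train_data_size, flag_train_data_size)          # capped to 1000
steps_per_epoch = int(train_data_size / train_batch_size)               # 31
warmup_steps = int(epochs * train_data_size * 0.1 / train_batch_size)   # 9

Without the flag, the same run would use 3125 steps per epoch and 937 warmup steps.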
@@ -489,7 +497,8 @@ def custom_main(custom_callbacks=None, custom_metrics=None):
       FLAGS.train_batch_size,
       is_training=True,
       label_type=label_type,
-      include_sample_weights=include_sample_weights)
+      include_sample_weights=include_sample_weights,
+      num_samples=FLAGS.train_data_size)
   run_bert(
       strategy,
       input_meta_data,