Commit 1c89b792 authored by Maxim Neumann, committed by A. Unique TensorFlower

Add a flag to control the number of train examples.

PiperOrigin-RevId: 327838493
parent e0b6ce02
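
In summary, the change threads a new num_samples argument through the BERT input pipeline and exposes it as a --train_data_size flag in the classifier runner; when the flag is set, the dataset is truncated with tf.data.Dataset.take before decoding. A minimal sketch of that mechanism, for illustration only (the toy range dataset below stands in for the TFRecord dataset and is not part of the commit):

import tensorflow as tf

# Stand-in for the TFRecord dataset built in single_file_dataset().
full = tf.data.Dataset.range(10)
# take(n) caps the pipeline at the first n records; this is what the new
# num_samples argument does before mapping, shuffling, and batching.
limited = full.take(4)
print(list(limited.as_numpy_iterator()))  # [0, 1, 2, 3]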
@@ -36,11 +36,13 @@ def decode_record(record, name_to_features):
   return example


-def single_file_dataset(input_file, name_to_features):
+def single_file_dataset(input_file, name_to_features, num_samples=None):
   """Creates a single-file dataset to be passed for BERT custom training."""
   # For training, we want a lot of parallel reading and shuffling.
   # For eval, we want no shuffling and parallel reading doesn't matter.
   d = tf.data.TFRecordDataset(input_file)
+  if num_samples:
+    d = d.take(num_samples)
   d = d.map(
       lambda record: decode_record(record, name_to_features),
       num_parallel_calls=tf.data.experimental.AUTOTUNE)
@@ -156,7 +158,8 @@ def create_classifier_dataset(file_path,
                               is_training=True,
                               input_pipeline_context=None,
                               label_type=tf.int64,
-                              include_sample_weights=False):
+                              include_sample_weights=False,
+                              num_samples=None):
   """Creates input dataset from (tf)records files for train/eval."""
   name_to_features = {
       'input_ids': tf.io.FixedLenFeature([seq_length], tf.int64),
@@ -166,7 +169,8 @@ def create_classifier_dataset(file_path,
   }
   if include_sample_weights:
     name_to_features['weight'] = tf.io.FixedLenFeature([], tf.float32)
-  dataset = single_file_dataset(file_path, name_to_features)
+  dataset = single_file_dataset(file_path, name_to_features,
+                                num_samples=num_samples)
   # The dataset is always sharded by number of hosts.
   # num_input_pipelines is the number of hosts rather than number of cores.
...
@@ -53,6 +53,9 @@ flags.DEFINE_string(
     'input_meta_data_path', None,
     'Path to file that contains meta data about input '
     'to be used for training and evaluation.')
+flags.DEFINE_integer('train_data_size', None, 'Number of training samples '
+                     'to use. If None, uses the full train data. '
+                     '(default: None).')
 flags.DEFINE_string('predict_checkpoint_path', None,
                     'Path to the checkpoint for predictions.')
 flags.DEFINE_integer(
@@ -92,7 +95,8 @@ def get_dataset_fn(input_file_pattern,
                    global_batch_size,
                    is_training,
                    label_type=tf.int64,
-                   include_sample_weights=False):
+                   include_sample_weights=False,
+                   num_samples=None):
   """Gets a closure to create a dataset."""

   def _dataset_fn(ctx=None):
@@ -106,7 +110,8 @@ def get_dataset_fn(input_file_pattern,
         is_training=is_training,
         input_pipeline_context=ctx,
         label_type=label_type,
-        include_sample_weights=include_sample_weights)
+        include_sample_weights=include_sample_weights,
+        num_samples=num_samples)
     return dataset

   return _dataset_fn
@@ -374,6 +379,9 @@ def run_bert(strategy,
   epochs = FLAGS.num_train_epochs * FLAGS.num_eval_per_epoch
   train_data_size = (
       input_meta_data['train_data_size'] // FLAGS.num_eval_per_epoch)
+  if FLAGS.train_data_size:
+    train_data_size = min(train_data_size, FLAGS.train_data_size)
+    logging.info('Updated train_data_size: %s', train_data_size)
   steps_per_epoch = int(train_data_size / FLAGS.train_batch_size)
   warmup_steps = int(epochs * train_data_size * 0.1 / FLAGS.train_batch_size)
   eval_steps = int(
@@ -489,7 +497,8 @@ def custom_main(custom_callbacks=None, custom_metrics=None):
       FLAGS.train_batch_size,
       is_training=True,
       label_type=label_type,
-      include_sample_weights=include_sample_weights)
+      include_sample_weights=include_sample_weights,
+      num_samples=FLAGS.train_data_size)
   run_bert(
       strategy,
       input_meta_data,
...
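
Usage note: with this change, training on a subset of the data can be requested by passing the new flag, e.g. --train_data_size=1000, to the classifier runner. The value is capped at the per-evaluation-cycle train_data_size reported in the input metadata, and steps_per_epoch and warmup_steps are then derived from the reduced size.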