Commit 9b219f04 authored by Tianqi Liu, committed by A. Unique TensorFlower
Browse files

Internal cleanup.

PiperOrigin-RevId: 314822016
parent 22f76623
......@@ -54,6 +54,12 @@ flags.DEFINE_string(
'to be used for training and evaluation.')
flags.DEFINE_string('predict_checkpoint_path', None,
'Path to the checkpoint for predictions.')
flags.DEFINE_integer(
'num_eval_per_epoch', 1,
'Number of evaluations per epoch. The purpose of this flag is to provide '
'more granular evaluation scores and checkpoints. For example, if original '
'data has N samples and num_eval_per_epoch is n, then each epoch will be '
'evaluated every N/n samples.')
flags.DEFINE_integer('train_batch_size', 32, 'Batch size for training.')
flags.DEFINE_integer('eval_batch_size', 32, 'Batch size for evaluation.')
......@@ -92,8 +98,11 @@ def get_regression_loss_fn():
return regression_loss_fn
def get_dataset_fn(input_file_pattern, max_seq_length, global_batch_size,
is_training, label_type=tf.int64):
def get_dataset_fn(input_file_pattern,
max_seq_length,
global_batch_size,
is_training,
label_type=tf.int64):
"""Gets a closure to create a dataset."""
def _dataset_fn(ctx=None):
......@@ -151,17 +160,21 @@ def run_bert_classifier(strategy,
use_graph_rewrite=common_flags.use_graph_rewrite())
return classifier_model, core_model
loss_fn = (get_regression_loss_fn() if is_regression
else get_loss_fn(num_classes))
loss_fn = (
get_regression_loss_fn() if is_regression else get_loss_fn(num_classes))
# Defines evaluation metrics function, which will create metrics in the
# correct device and strategy scope.
if is_regression:
metric_fn = functools.partial(tf.keras.metrics.MeanSquaredError,
'mean_squared_error', dtype=tf.float32)
metric_fn = functools.partial(
tf.keras.metrics.MeanSquaredError,
'mean_squared_error',
dtype=tf.float32)
else:
metric_fn = functools.partial(tf.keras.metrics.SparseCategoricalAccuracy,
'accuracy', dtype=tf.float32)
metric_fn = functools.partial(
tf.keras.metrics.SparseCategoricalAccuracy,
'accuracy',
dtype=tf.float32)
# Start training using Keras compile/fit API.
logging.info('Training using TF 2.x Keras compile/fit API with '
......@@ -349,8 +362,9 @@ def run_bert(strategy,
keras_utils.set_session_config(FLAGS.enable_xla)
performance.set_mixed_precision_policy(common_flags.dtype())
epochs = FLAGS.num_train_epochs
train_data_size = input_meta_data['train_data_size']
epochs = FLAGS.num_train_epochs * FLAGS.num_eval_per_epoch
train_data_size = (
input_meta_data['train_data_size'] // FLAGS.num_eval_per_epoch)
steps_per_epoch = int(train_data_size / FLAGS.train_batch_size)
warmup_steps = int(epochs * train_data_size * 0.1 / FLAGS.train_batch_size)
eval_steps = int(
......
......@@ -127,15 +127,14 @@ class XnliProcessor(DataProcessor):
"""See base class."""
lines = []
for language in self.languages:
# Skips the header.
lines.extend(
self._read_tsv(
os.path.join(data_dir, "multinli",
"multinli.train.%s.tsv" % language)))
"multinli.train.%s.tsv" % language))[1:])
examples = []
for (i, line) in enumerate(lines):
if i == 0:
continue
guid = "train-%d" % i
text_a = self.process_text_fn(line[0])
text_b = self.process_text_fn(line[1])
......@@ -825,7 +824,8 @@ def generate_tf_record_from_data_file(processor,
eval_data_output_path: Output to which processed tf record for evaluation
will be saved.
test_data_output_path: Output to which processed tf record for testing
will be saved. Must be a pattern template with {} if processor is XNLI.
will be saved. Must be a pattern template with {} if processor has
language specific test data.
max_seq_length: Maximum sequence length of the to be generated
training/eval data.
......
......@@ -99,7 +99,8 @@ flags.DEFINE_string(
flags.DEFINE_string(
"test_data_output_path", None,
"The path in which generated test input data will be written as tf"
" records. If None, do not generate test data.")
" records. If None, do not generate test data. Must be a pattern template"
" as test_{}.tfrecords if processor has language specific test data.")
flags.DEFINE_string("meta_data_file_path", None,
"The path in which input meta data will be written.")
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment