"vscode:/vscode.git/clone" did not exist on "8a18e73e18cb6bb846b5d5a11e9a7ff91caedda8"
Commit 9b219f04 authored by Tianqi Liu, committed by A. Unique TensorFlower

Internal cleanup.

PiperOrigin-RevId: 314822016
parent 22f76623
@@ -54,6 +54,12 @@ flags.DEFINE_string(
     'to be used for training and evaluation.')
 flags.DEFINE_string('predict_checkpoint_path', None,
                     'Path to the checkpoint for predictions.')
+flags.DEFINE_integer(
+    'num_eval_per_epoch', 1,
+    'Number of evaluations per epoch. The purpose of this flag is to provide '
+    'more granular evaluation scores and checkpoints. For example, if original '
+    'data has N samples and num_eval_per_epoch is n, then each epoch will be '
+    'evaluated every N/n samples.')
 flags.DEFINE_integer('train_batch_size', 32, 'Batch size for training.')
 flags.DEFINE_integer('eval_batch_size', 32, 'Batch size for evaluation.')
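The flag's help text describes the schedule in terms of N and n; a quick standalone sketch of that arithmetic (all numbers hypothetical, not from this commit):

```python
# Hypothetical numbers illustrating the flag's N/n description.
N = 100_000            # training samples in the original data
n = 4                  # num_eval_per_epoch
train_batch_size = 32  # matches the default train_batch_size flag

samples_between_evals = N // n                                   # 25,000 samples
steps_between_evals = samples_between_evals // train_batch_size  # 781 steps

print(samples_between_evals, steps_between_evals)  # 25000 781
```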
@@ -92,8 +98,11 @@ def get_regression_loss_fn():
   return regression_loss_fn

-def get_dataset_fn(input_file_pattern, max_seq_length, global_batch_size,
-                   is_training, label_type=tf.int64):
+def get_dataset_fn(input_file_pattern,
+                   max_seq_length,
+                   global_batch_size,
+                   is_training,
+                   label_type=tf.int64):
   """Gets a closure to create a dataset."""

   def _dataset_fn(ctx=None):
@@ -151,17 +160,21 @@ def run_bert_classifier(strategy,
         use_graph_rewrite=common_flags.use_graph_rewrite())
     return classifier_model, core_model

-  loss_fn = (get_regression_loss_fn() if is_regression
-             else get_loss_fn(num_classes))
+  loss_fn = (
+      get_regression_loss_fn() if is_regression else get_loss_fn(num_classes))

   # Defines evaluation metrics function, which will create metrics in the
   # correct device and strategy scope.
   if is_regression:
-    metric_fn = functools.partial(tf.keras.metrics.MeanSquaredError,
-                                  'mean_squared_error', dtype=tf.float32)
+    metric_fn = functools.partial(
+        tf.keras.metrics.MeanSquaredError,
+        'mean_squared_error',
+        dtype=tf.float32)
   else:
-    metric_fn = functools.partial(tf.keras.metrics.SparseCategoricalAccuracy,
-                                  'accuracy', dtype=tf.float32)
+    metric_fn = functools.partial(
+        tf.keras.metrics.SparseCategoricalAccuracy,
+        'accuracy',
+        dtype=tf.float32)

   # Start training using Keras compile/fit API.
   logging.info('Training using TF 2.x Keras compile/fit API with '
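As the comment in this hunk says, `metric_fn` is kept as a zero-argument callable so the metric's variables are created later, on the right device and under the distribution strategy scope. A minimal sketch of that pattern using only standard TensorFlow APIs (the `MirroredStrategy` here is illustrative, not necessarily what this script uses):

```python
import functools

import tensorflow as tf

# Bind the constructor arguments now, but build no metric object yet.
metric_fn = functools.partial(
    tf.keras.metrics.SparseCategoricalAccuracy, 'accuracy', dtype=tf.float32)

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
  # Variables backing the metric are created here, inside the strategy scope,
  # so they are placed and replicated like the model's variables.
  metric = metric_fn()
```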
@@ -349,8 +362,9 @@ def run_bert(strategy,
   keras_utils.set_session_config(FLAGS.enable_xla)
   performance.set_mixed_precision_policy(common_flags.dtype())

-  epochs = FLAGS.num_train_epochs
-  train_data_size = input_meta_data['train_data_size']
+  epochs = FLAGS.num_train_epochs * FLAGS.num_eval_per_epoch
+  train_data_size = (
+      input_meta_data['train_data_size'] // FLAGS.num_eval_per_epoch)
   steps_per_epoch = int(train_data_size / FLAGS.train_batch_size)
   warmup_steps = int(epochs * train_data_size * 0.1 / FLAGS.train_batch_size)
   eval_steps = int(
...
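The two changes above cancel out in total training work: `epochs` grows by a factor of `num_eval_per_epoch` while `train_data_size` (and hence `steps_per_epoch`) shrinks by the same factor, so each real epoch becomes several shorter Keras epochs, each ending in an evaluation and checkpoint. A quick check with hypothetical numbers:

```python
# Hypothetical inputs; the arithmetic mirrors the diff above.
num_train_epochs = 3
num_eval_per_epoch = 4
orig_train_data_size = 100_000
train_batch_size = 32

epochs = num_train_epochs * num_eval_per_epoch                # 12
train_data_size = orig_train_data_size // num_eval_per_epoch  # 25,000
steps_per_epoch = int(train_data_size / train_batch_size)     # 781

# Old scheme: 3 * int(100000 / 32) = 9375 steps total.
# New scheme: nearly identical, minus a little int() rounding.
print(epochs * steps_per_epoch)  # 9372
```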
@@ -127,15 +127,14 @@ class XnliProcessor(DataProcessor):
     """See base class."""
     lines = []
     for language in self.languages:
+      # Skips the header.
       lines.extend(
           self._read_tsv(
               os.path.join(data_dir, "multinli",
-                           "multinli.train.%s.tsv" % language)))
+                           "multinli.train.%s.tsv" % language))[1:])
     examples = []
     for (i, line) in enumerate(lines):
-      if i == 0:
-        continue
       guid = "train-%d" % i
       text_a = self.process_text_fn(line[0])
       text_b = self.process_text_fn(line[1])
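Beyond style, this hunk changes behavior when several languages are loaded: the old `if i == 0: continue` dropped only the first row of the concatenated `lines` list, i.e. the header of the first language's TSV, while the new `[1:]` slice drops the header of every per-language file. A standalone sketch of the difference (file contents hypothetical):

```python
# Two per-language TSVs, each starting with a header row (hypothetical data).
files = {
    'en': [['header'], ['en_row1'], ['en_row2']],
    'de': [['header'], ['de_row1']],
}

# Old behavior: concatenate everything, then skip only global index 0.
old_lines = []
for language in ('en', 'de'):
  old_lines.extend(files[language])
old_rows = [line for i, line in enumerate(old_lines) if i != 0]
# de's header survives: [['en_row1'], ['en_row2'], ['header'], ['de_row1']]

# New behavior: drop each file's header before concatenating.
new_rows = []
for language in ('en', 'de'):
  new_rows.extend(files[language][1:])
# [['en_row1'], ['en_row2'], ['de_row1']]

print(old_rows)
print(new_rows)
```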
@@ -825,7 +824,8 @@ def generate_tf_record_from_data_file(processor,
     eval_data_output_path: Output to which processed tf record for evaluation
       will be saved.
     test_data_output_path: Output to which processed tf record for testing
-      will be saved. Must be a pattern template with {} if processor is XNLI.
+      will be saved. Must be a pattern template with {} if processor has
+      language specific test data.
     max_seq_length: Maximum sequence length of the to be generated
       training/eval data.
...
@@ -99,7 +99,8 @@ flags.DEFINE_string(
 flags.DEFINE_string(
     "test_data_output_path", None,
     "The path in which generated test input data will be written as tf"
-    " records. If None, do not generate test data.")
+    " records. If None, do not generate test data. Must be a pattern template"
+    " as test_{}.tfrecords if processor has language specific test data.")
 flags.DEFINE_string("meta_data_file_path", None,
                     "The path in which input meta data will be written.")
...
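The "pattern template" wording in both of these hunks means the output path is expected to contain a `{}` placeholder that gets formatted once per language. A minimal sketch of how such a template would be expanded (paths and language list are illustrative):

```python
test_data_output_path = '/tmp/xnli/test_{}.tfrecords'  # must contain '{}'

for language in ('en', 'de', 'zh'):
  # e.g. /tmp/xnli/test_en.tfrecords
  print(test_data_output_path.format(language))
```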