Commit 88c864f7 authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Remove unnecessary test_input_fn.

PiperOrigin-RevId: 276394582
parent a4789d12
......@@ -184,7 +184,6 @@ def main(unused_argv):
eval_fn=eval_fn,
metric_fn=get_metric_fn,
train_input_fn=train_input_fn,
test_input_fn=test_input_fn,
init_checkpoint=FLAGS.init_checkpoint,
init_from_transformerxl=FLAGS.init_from_transformerxl,
total_training_steps=total_training_steps,
......
......@@ -135,7 +135,6 @@ def main(unused_argv):
eval_fn=None,
metric_fn=None,
train_input_fn=train_input_fn,
test_input_fn=None,
init_checkpoint=FLAGS.init_checkpoint,
init_from_transformerxl=FLAGS.init_from_transformerxl,
total_training_steps=total_training_steps,
......
......@@ -281,7 +281,6 @@ def main(unused_argv):
eval_fn=eval_fn,
metric_fn=None,
train_input_fn=train_input_fn,
test_input_fn=test_input_fn,
init_checkpoint=FLAGS.init_checkpoint,
init_from_transformerxl=FLAGS.init_from_transformerxl,
total_training_steps=total_training_steps,
......
......@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""XLNet classification finetuning runner in tf2.0."""
"""XLNet training utils."""
from __future__ import absolute_import
from __future__ import division
......@@ -61,7 +61,6 @@ def train(
eval_fn: Optional[Callable[[tf.keras.Model, int, tf.summary.SummaryWriter],
Any]] = None,
metric_fn: Optional[Callable[[], tf.keras.metrics.Metric]] = None,
test_input_fn: Optional[Callable] = None,
init_checkpoint: Optional[Text] = None,
init_from_transformerxl: Optional[bool] = False,
model_dir: Optional[Text] = None,
......@@ -86,8 +85,6 @@ def train(
metric_fn: A metrics function returns a Keras Metric object to record
evaluation result using evaluation dataset or with training dataset
after every epoch.
test_input_fn: Function that returns an evaluation dataset. If none, evaluation
is skipped.
init_checkpoint: Optional checkpoint to load to `sub_model` returned by
`model_fn`.
init_from_transformerxl: Whether to load to `transformerxl_model` of
......@@ -124,7 +121,7 @@ def train(
tf.io.gfile.mkdir(summary_dir)
train_summary_writer = None
eval_summary_writer = None
if test_input_fn:
if eval_fn:
eval_summary_writer = tf.summary.create_file_writer(
os.path.join(summary_dir, "eval"))
if steps_per_loop >= _MIN_SUMMARY_STEPS:
......@@ -288,7 +285,7 @@ def train(
_save_checkpoint(checkpoint, model_dir,
checkpoint_name.format(step=current_step))
if test_input_fn and current_step % save_steps == 0:
if eval_fn and current_step % save_steps == 0:
logging.info("Running evaluation after step: %s.", current_step)
......@@ -296,7 +293,7 @@ def train(
if model_dir:
_save_checkpoint(checkpoint, model_dir,
checkpoint_name.format(step=current_step))
if test_input_fn:
if eval_fn:
logging.info("Running final evaluation after training is complete.")
eval_metric = eval_fn(model, current_step, eval_summary_writer)
......@@ -306,7 +303,7 @@ def train(
}
if train_metric:
training_summary["last_train_metrics"] = _float_metric_value(train_metric)
if test_input_fn:
if eval_fn:
# eval_metric is supposed to be a float.
training_summary["eval_metrics"] = eval_metric
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment