Commit 5ceb3acf authored by David Chen's avatar David Chen Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 268057019
parent e91c41c2
......@@ -180,13 +180,13 @@ class TransformerTask(object):
if not params["static_batch"]:
raise ValueError("TPU requires static batch for input data.")
else:
print("Running transformer with num_gpus =", num_gpus)
logging.info("Running transformer with num_gpus =", num_gpus)
if self.distribution_strategy:
print("For training, using distribution strategy: ",
logging.info("For training, using distribution strategy: ",
self.distribution_strategy)
else:
print("Not using any distribution strategy.")
logging.info("Not using any distribution strategy.")
@property
def use_tpu(self):
......@@ -289,7 +289,8 @@ class TransformerTask(object):
else flags_obj.steps_between_evals)
current_iteration = current_step // flags_obj.steps_between_evals
print("Start train iteration at global step:{}".format(current_step))
logging.info(
"Start train iteration at global step:{}".format(current_step))
history = None
if params["use_ctl"]:
if not self.use_tpu:
......@@ -324,7 +325,7 @@ class TransformerTask(object):
current_step += train_steps_per_eval
logging.info("Train history: {}".format(history.history))
print("End train iteration at global step:{}".format(current_step))
logging.info("End train iteration at global step:{}".format(current_step))
if (flags_obj.bleu_source and flags_obj.bleu_ref):
uncased_score, cased_score = self.eval()
......@@ -401,7 +402,7 @@ class TransformerTask(object):
else:
model.load_weights(init_weight_path)
else:
print("Weights not loaded from path:{}".format(init_weight_path))
logging.info("Weights not loaded from path:{}".format(init_weight_path))
def _create_optimizer(self):
"""Creates optimizer."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment