"...git@developer.sourcefind.cn:OpenDAS/mmdetection3d.git" did not exist on "d1b9ae40f336caa7c1fc4fbd866d734765a6b674"
Commit 2a3971dd authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 312209760
parent cc1e2718
......@@ -238,6 +238,8 @@ def run_customized_training_loop(
total_training_steps = steps_per_epoch * epochs
train_iterator = _get_input_iterator(train_input_fn, strategy)
eval_loss_metric = tf.keras.metrics.Mean(
'training_loss', dtype=tf.float32)
with distribution_utils.get_strategy_scope(strategy):
# To correctly place the model weights on accelerators,
......@@ -365,8 +367,14 @@ def run_customized_training_loop(
model_outputs = model(inputs, training=False)
for metric in eval_metrics:
metric.update_state(labels, model_outputs)
return model_outputs, labels
strategy.run(_test_step_fn, args=(next(iterator),))
outputs, labels = strategy.run(_test_step_fn, args=(next(iterator),))
outputs = tf.nest.map_structure(strategy.experimental_local_results,
outputs)
labels = tf.nest.map_structure(strategy.experimental_local_results,
labels)
return outputs, labels
if not run_eagerly:
train_single_step = tf.function(train_single_step)
......@@ -382,12 +390,29 @@ def run_customized_training_loop(
Returns:
A dict of metric names and values.
"""
# The last batch of the evaluation is often smaller than the previous ones.
# Moreover, on some replicas of a distributed run it might even be empty.
# Therefore, unlike the way training_loss is calculated, all the logits and
# labels need to be gathered here so that the evaluation loss can be
# calculated outside the test step.
loss_list, loss_weights = list(), list()
for _ in range(eval_steps):
test_step(test_iterator)
outputs, labels = test_step(test_iterator)
for cur_logits, cur_labels in zip(outputs, labels):
# This is to handle cases when cur_labels is not a single tensor,
# but a dict of tensors.
cur_weight = tf.shape(tf.nest.flatten(cur_labels)[0])[0]
if cur_weight != 0:
loss_list.append(loss_fn(cur_labels, cur_logits).numpy())
loss_weights.append(cur_weight)
# The sample_weights are the actual number of examples in each batch,
# a summation of numbers of examples in each replica if using
# distributed training.
eval_loss_metric.update_state(loss_list, sample_weight=loss_weights)
logs = {}
with eval_summary_writer.as_default():
for metric in eval_metrics + model.metrics:
for metric in [eval_loss_metric] + eval_metrics + model.metrics:
metric_value = _float_metric_value(metric)
logs[metric.name] = metric_value
logging.info('Step: [%d] Validation %s = %f', current_training_step,
......@@ -482,6 +507,7 @@ def run_customized_training_loop(
logs = _run_evaluation(current_step,
_get_input_iterator(eval_input_fn, strategy))
# Re-initialize evaluation metric.
eval_loss_metric.reset_states()
for metric in eval_metrics + model.metrics:
metric.reset_states()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment