Unverified commit c5a4978d authored by Jing Li, committed by GitHub

Merged commit includes the following changes: (#7263)

* Merged commit includes the following changes:
258867180  by jingli<jingli@google.com>:

    Add new folders for upcoming reorg in model garden.
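
    A quick sketch of what those folder additions involve: a directory becomes an importable Python package once it contains an `__init__.py`, which is what the `Create/Update __init__.py` commits later in this merge do. All names and paths below are illustrative, not the actual model-garden layout:

    ```python
    import importlib
    import pathlib
    import sys
    import tempfile

    # Create a throwaway package tree with empty __init__.py markers.
    root = pathlib.Path(tempfile.mkdtemp())
    pkg = root / 'demo_garden' / 'new_subproject'
    pkg.mkdir(parents=True)
    (pkg / '__init__.py').touch()          # an empty file is enough
    (pkg.parent / '__init__.py').touch()   # the parent needs one too

    # The new folder is now importable as a package.
    sys.path.insert(0, str(root))
    print(importlib.import_module('demo_garden.new_subproject'))
    ```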

--
258893811  by hongkuny<hongkuny@google.com>:

    Adds summaries for metrics, allowing metrics defined inside the keras.Model (`model.metrics`) to be reported as well.
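
    A minimal sketch of the reporting pattern this change introduces (the full version is in the diff below): every tracked metric, including those the model itself owns via `model.metrics`, is written to the same summary writer. `_float_metric_value` is the helper used in the diff; the function and argument names are otherwise illustrative:

    ```python
    import tensorflow as tf

    def _float_metric_value(metric):
      """Gets the value of a float-valued Keras metric as a Python float."""
      return metric.result().numpy().astype(float)

    # Report every tracked metric, including metrics the model itself owns
    # via `model.metrics`, to one summary writer.
    def report_metrics(summary_writer, metrics, model, step):
      with summary_writer.as_default():
        for metric in metrics + model.metrics:
          tf.summary.scalar(metric.name, _float_metric_value(metric), step=step)
        summary_writer.flush()
    ```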

--
258893048  by isaprykin<isaprykin@google.com>:

    Remove the `cloning` argument to `compile()`.

    As of change 258652546, Keras models are distributed by cloning in graph mode and without cloning in eager mode.
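
    For illustration, a hedged before/after sketch of the `compile()` call; the model, optimizer, and loss are placeholders, and the removed keyword appears only in a comment:

    ```python
    import tensorflow as tf

    # Placeholder model for the sketch.
    model = tf.keras.Sequential([tf.keras.layers.Dense(10)])

    # Before this change, callers opted in or out of cloning explicitly, e.g.
    #   model.compile(..., cloning=clone_model_under_dist_strat)
    # After it, cloning is chosen automatically (clone in graph mode, no clone
    # in eager mode), so the argument is simply dropped:
    model.compile(
        optimizer='sgd',
        loss='sparse_categorical_crossentropy',
        metrics=['sparse_categorical_accuracy'],
        run_eagerly=False)
    ```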

--
258881002  by hongkuny<hongkuny@google.com>:

    Fix lint.

--
258874998  by hongkuny<hongkuny@google.com>:

    Internal

--
258872662  by hongkuny<hongkuny@google.com>:

    Fix doc

--

PiperOrigin-RevId: 258867180

* Create __init__.py

* Update __init__.py

* Update __init__.py

* Update __init__.py
parent 2569fa9a
@@ -215,12 +215,13 @@ def run_customized_training_loop(
     train_loss_metric = tf.keras.metrics.Mean(
         'training_loss', dtype=tf.float32)
-    eval_metric = metric_fn() if metric_fn else None
+    eval_metrics = [metric_fn()] if metric_fn else []
     # If evaluation is required, make a copy of metric as it will be used by
     # both train and evaluation.
-    train_metric = (
-        eval_metric.__class__.from_config(eval_metric.get_config())
-        if eval_metric else None)
+    train_metrics = [
+        metric.__class__.from_config(metric.get_config())
+        for metric in eval_metrics
+    ]

     # Create summary writers
     eval_summary_writer = tf.summary.create_file_writer(
@@ -246,8 +247,8 @@ def run_customized_training_loop(
       optimizer.apply_gradients(zip(grads, tvars))
       # For reporting, the metric takes the mean of losses.
       train_loss_metric.update_state(loss)
-      if train_metric:
-        train_metric.update_state(labels, model_outputs)
+      for metric in train_metrics:
+        metric.update_state(labels, model_outputs)

     @tf.function
     def train_steps(iterator, steps):
@@ -257,6 +258,7 @@ def run_customized_training_loop(
         iterator: the distributed iterator of training datasets.
         steps: a tf.int32 integer tensor to specify the number of steps to run
           inside the host training loop.
+
       Raises:
         ValueError: Any of the arguments or tensor shapes are invalid.
       """
@@ -272,6 +274,7 @@ def run_customized_training_loop(
       Args:
         iterator: the distributed iterator of training datasets.
+
      Raises:
        ValueError: Any of the arguments or tensor shapes are invalid.
      """
@@ -285,7 +288,8 @@ def run_customized_training_loop(
         inputs, labels = inputs
         model_outputs = model(inputs, training=False)
-        eval_metric.update_state(labels, model_outputs)
+        for metric in eval_metrics:
+          metric.update_state(labels, model_outputs)

       strategy.experimental_run_v2(_test_step_fn, args=(next(iterator),))
@@ -297,12 +301,14 @@ def run_customized_training_loop(
       """Runs validation steps and aggregates metrics."""
       for _ in range(eval_steps):
         test_step(test_iterator)
-      eval_metric_value = _float_metric_value(eval_metric)
-      logging.info('Step: [%d] Validation metric = %f', current_training_step,
-                   eval_metric_value)
       with eval_summary_writer.as_default():
-        tf.summary.scalar(
-            eval_metric.name, eval_metric_value, step=current_training_step)
+        for metric in eval_metrics + model.metrics:
+          metric_value = _float_metric_value(metric)
+          logging.info('Step: [%d] Validation %s = %f', current_training_step,
+                       metric.name, metric_value)
+          tf.summary.scalar(
+              metric.name, metric_value, step=current_training_step)
         eval_summary_writer.flush()

     def _run_callbacks_on_batch_begin(batch):
@@ -336,8 +342,8 @@ def run_customized_training_loop(
       # Training loss/metric take the average over steps inside the micro
       # training loop. We reset their values before each round.
       train_loss_metric.reset_states()
-      if train_metric:
-        train_metric.reset_states()
+      for metric in train_metrics + model.metrics:
+        metric.reset_states()

       _run_callbacks_on_batch_begin(current_step)
       # Runs several steps in the host while loop.
@@ -358,21 +364,17 @@ def run_customized_training_loop(
       # Updates training logging.
       training_status = 'Train Step: %d/%d / loss = %s' % (
           current_step, total_training_steps, train_loss)
-      if train_metric:
-        train_metric_value = _float_metric_value(train_metric)
-        training_status += ' training metric = %f' % train_metric_value
-      else:
-        train_metric_value = None
-      logging.info(training_status)
       if train_summary_writer:
         with train_summary_writer.as_default():
           tf.summary.scalar(
               train_loss_metric.name, train_loss, step=current_step)
-          if train_metric_value:
-            tf.summary.scalar(
-                train_metric.name, train_metric_value, step=current_step)
+          for metric in train_metrics + model.metrics:
+            metric_value = _float_metric_value(metric)
+            training_status += ' %s = %f' % (metric.name, metric_value)
+            tf.summary.scalar(metric.name, metric_value, step=current_step)
           train_summary_writer.flush()
+      logging.info(training_status)

       # Saves model checkpoints and runs validation steps at every epoch end.
       if current_step % steps_per_epoch == 0:
@@ -387,7 +389,8 @@ def run_customized_training_loop(
           _run_evaluation(current_step,
                           _get_input_iterator(eval_input_fn, strategy))
           # Re-initialize evaluation metric.
-          eval_metric.reset_states()
+          for metric in eval_metrics + model.metrics:
+            metric.reset_states()

         _save_checkpoint(checkpoint, model_dir,
                          checkpoint_name.format(step=current_step))
@@ -401,10 +404,11 @@ def run_customized_training_loop(
         'total_training_steps': total_training_steps,
         'train_loss': _float_metric_value(train_loss_metric),
     }
-    if eval_metric:
+    if eval_metrics:
       # TODO(hongkuny): Cleans up summary reporting in text.
       training_summary['last_train_metrics'] = _float_metric_value(
-          train_metric)
-      training_summary['eval_metrics'] = _float_metric_value(eval_metric)
+          train_metrics[0])
+      training_summary['eval_metrics'] = _float_metric_value(eval_metrics[0])

     _write_txt_summary(training_summary, model_dir)
@@ -205,8 +205,7 @@ def run(flags_obj):
       optimizer=optimizer,
       metrics=(['sparse_categorical_accuracy']
               if flags_obj.report_accuracy_metrics else None),
-      run_eagerly=flags_obj.run_eagerly,
-      cloning=flags_obj.clone_model_in_keras_dist_strat)
+      run_eagerly=flags_obj.run_eagerly)

   callbacks = keras_common.get_callbacks(
       learning_rate_schedule, imagenet_main.NUM_IMAGES['train'])
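
Standalone, the metric-copying idiom the training-loop diff relies on: each training metric is rebuilt from the corresponding eval metric's serialized config, so the two copies share configuration but not state. The concrete metric below is an illustrative choice:

```python
import tensorflow as tf

eval_metrics = [tf.keras.metrics.SparseCategoricalAccuracy('accuracy')]

# Rebuild each metric from its config: same class and settings, fresh
# (independent) internal state variables.
train_metrics = [
    metric.__class__.from_config(metric.get_config())
    for metric in eval_metrics
]

assert train_metrics[0] is not eval_metrics[0]
assert train_metrics[0].name == eval_metrics[0].name
```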