Commit 682d36b5 authored by Ran Chen, committed by A. Unique TensorFlower

Save to tmp directory on non-chief workers in model_training_utils

In a multi-worker setup, saving is done on each worker. If they all save to the same location, e.g. GCS, there will be conflicts. With this change, non-chief workers save to a temporary directory instead.

Note that saving may involve synchronization that requires all workers to participate, so we cannot save on only one worker (a minimal sketch of the pattern follows below).

PiperOrigin-RevId: 300141152
parent ebc28058
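
The pattern this change applies is sketched below as a standalone, hypothetical helper (the name maybe_save_checkpoint and the print call are illustrative only; the commit implements this inside _save_checkpoint with absl logging): every worker calls checkpoint.save() so that any synchronization triggered by reading distributed variables has all participants, but only the chief keeps the result in model_dir; non-chief workers write to a throwaway temporary directory and delete it.

import os
import tempfile

import tensorflow as tf


def maybe_save_checkpoint(strategy, checkpoint, model_dir, prefix):
  """Illustrative sketch: every worker saves, only the chief keeps the files."""
  if (not strategy) or strategy.extended.should_checkpoint:
    # Chief worker (or no distribution strategy): write to the real model_dir.
    saved_path = checkpoint.save(os.path.join(model_dir, prefix))
    print('Saved checkpoint:', saved_path)
  else:
    # Non-chief worker: still call save() so any synchronization on variable
    # reads has all workers participating, then discard the files so workers
    # do not overwrite each other in model_dir.
    tmp_dir = tempfile.mkdtemp()
    checkpoint.save(os.path.join(tmp_dir, 'ckpt'))
    tf.io.gfile.rmtree(tmp_dir)

The same temporary-directory trick is applied to the summary directory, which is removed on non-chief workers once training finishes (see the last hunk of model_training_utils below).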
@@ -20,6 +20,7 @@ from __future__ import print_function
 
 import json
 import os
+import tempfile
 
 from absl import logging
 import tensorflow as tf
@@ -30,12 +31,29 @@ _SUMMARY_TXT = 'training_summary.txt'
 _MIN_SUMMARY_STEPS = 10
 
 
-def _save_checkpoint(checkpoint, model_dir, checkpoint_prefix):
+def _should_export_checkpoint(strategy):
+  return (not strategy) or strategy.extended.should_checkpoint
+
+
+def _should_export_summary(strategy):
+  return (not strategy) or strategy.extended.should_save_summary
+
+
+def _save_checkpoint(strategy, checkpoint, model_dir, checkpoint_prefix):
   """Saves model to with provided checkpoint prefix."""
-  checkpoint_path = os.path.join(model_dir, checkpoint_prefix)
-  saved_path = checkpoint.save(checkpoint_path)
-  logging.info('Saving model as TF checkpoint: %s', saved_path)
+  if _should_export_checkpoint(strategy):
+    checkpoint_path = os.path.join(model_dir, checkpoint_prefix)
+    saved_path = checkpoint.save(checkpoint_path)
+    logging.info('Saving model as TF checkpoint: %s', saved_path)
+  else:
+    # In multi worker training we need every worker to save checkpoint, because
+    # variables can trigger synchronization on read and synchronization needs
+    # all workers to participate. To avoid workers overriding each other we save
+    # to a temporary directory on non-chief workers.
+    tmp_dir = tempfile.mkdtemp()
+    checkpoint.save(os.path.join(tmp_dir, 'ckpt'))
+    tf.io.gfile.rmtree(tmp_dir)
   return
@@ -242,7 +260,13 @@ def run_customized_training_loop(
     ]
 
     # Create summary writers
-    summary_dir = os.path.join(model_dir, 'summaries')
+    if _should_export_summary(strategy):
+      summary_dir = os.path.join(model_dir, 'summaries')
+    else:
+      # In multi worker training we need every worker to write summary, because
+      # variables can trigger synchronization on read and synchronization needs
+      # all workers to participate.
+      summary_dir = tempfile.mkdtemp()
     eval_summary_writer = tf.summary.create_file_writer(
         os.path.join(summary_dir, 'eval'))
     if steps_per_loop >= _MIN_SUMMARY_STEPS:
@@ -418,11 +442,11 @@ def run_customized_training_loop(
         # To avoid repeated model saving, we do not save after the last
         # step of training.
         if current_step < total_training_steps:
-          _save_checkpoint(checkpoint, model_dir,
+          _save_checkpoint(strategy, checkpoint, model_dir,
                            checkpoint_name.format(step=current_step))
           if sub_model_export_name:
             _save_checkpoint(
-                sub_model_checkpoint, model_dir,
+                strategy, sub_model_checkpoint, model_dir,
                 '%s_step_%d.ckpt' % (sub_model_export_name, current_step))
         if eval_input_fn:
           logging.info('Running evaluation after step: %s.', current_step)
@@ -432,10 +456,10 @@ def run_customized_training_loop(
           for metric in eval_metrics + model.metrics:
            metric.reset_states()
 
-    _save_checkpoint(checkpoint, model_dir,
+    _save_checkpoint(strategy, checkpoint, model_dir,
                      checkpoint_name.format(step=current_step))
     if sub_model_export_name:
-      _save_checkpoint(sub_model_checkpoint, model_dir,
+      _save_checkpoint(strategy, sub_model_checkpoint, model_dir,
                        '%s.ckpt' % sub_model_export_name)
 
     if eval_input_fn:
@@ -455,4 +479,7 @@ def run_customized_training_loop(
     write_txt_summary(training_summary, summary_dir)
 
+    if not _should_export_summary(strategy):
+      tf.io.gfile.rmtree(summary_dir)
+
     return model
@@ -21,6 +21,7 @@ from __future__ import print_function
 
 import os
 
 from absl.testing import parameterized
+from absl.testing.absltest import mock
 import numpy as np
 import tensorflow as tf
@@ -208,6 +209,27 @@ class ModelTrainingUtilsTest(tf.test.TestCase, parameterized.TestCase):
         check_eventfile_for_keyword('mean_input',
                                     os.path.join(model_dir, 'summaries/eval')))
 
+  @combinations.generate(
+      combinations.combine(
+          distribution=[
+              strategy_combinations.one_device_strategy_gpu,
+          ],
+          mode='eager',
+      ))
+  def test_train_check_artifacts_non_chief(self, distribution):
+    # We shouldn't export artifacts on non-chief workers. Since there's no easy
+    # way to test with real MultiWorkerMirroredStrategy, we patch the strategy
+    # to make it as if it's MultiWorkerMirroredStrategy on non-chief workers.
+    extended = distribution.extended
+    with mock.patch.object(extended.__class__, 'should_checkpoint',
+                           new_callable=mock.PropertyMock, return_value=False), \
+         mock.patch.object(extended.__class__, 'should_save_summary',
+                           new_callable=mock.PropertyMock, return_value=False):
+      model_dir = self.get_temp_dir()
+      self.run_training(
+          distribution, model_dir, steps_per_loop=10, run_eagerly=False)
+      self.assertEmpty(tf.io.gfile.listdir(model_dir))
+
 
 if __name__ == '__main__':
   assert tf.version.VERSION.startswith('2.')