"torchvision/git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "d84fee6d2a64b1a253de89f40e0d3995e6b83578"
Commit 682d36b5 authored by Ran Chen, committed by A. Unique TensorFlower

Save to tmp directory on non-chief workers in model_training_utils

In a multi-worker setup, saving is done on every worker. If they all save to the same location, e.g. GCS, there will be conflicts. With this change we save to a temporary directory on non-chief workers.

Note that saving may involve synchronization that requires all workers to participate, so we cannot save on only one worker.

PiperOrigin-RevId: 300141152
parent ebc28058
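
The core of the change is small enough to sketch outside the diff. The sketch below condenses the new `_save_checkpoint` logic; the function name `save_on_all_workers` and the single-process usage at the end are illustrative additions, not code from this commit.

import os
import tempfile

import tensorflow as tf


def save_on_all_workers(strategy, checkpoint, model_dir, prefix):
  """Every worker saves; only the chief's files end up in model_dir."""
  if (not strategy) or strategy.extended.should_checkpoint:
    # Chief worker (or no distribution strategy): write the real checkpoint.
    checkpoint.save(os.path.join(model_dir, prefix))
  else:
    # Non-chief workers still call save() so that any synchronization
    # triggered by reading variables can complete, but their files go to a
    # scratch directory that is deleted immediately afterwards.
    tmp_dir = tempfile.mkdtemp()
    checkpoint.save(os.path.join(tmp_dir, 'ckpt'))
    tf.io.gfile.rmtree(tmp_dir)


# Single-process usage: with strategy=None the chief branch is taken.
ckpt = tf.train.Checkpoint(step=tf.Variable(0))
save_on_all_workers(None, ckpt, tempfile.mkdtemp(), 'ckpt_step_0')
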
@@ -20,6 +20,7 @@ from __future__ import print_function
 
 import json
 import os
+import tempfile
 
 from absl import logging
 import tensorflow as tf
@@ -30,12 +31,29 @@ _SUMMARY_TXT = 'training_summary.txt'
 _MIN_SUMMARY_STEPS = 10
 
 
-def _save_checkpoint(checkpoint, model_dir, checkpoint_prefix):
+def _should_export_checkpoint(strategy):
+  return (not strategy) or strategy.extended.should_checkpoint
+
+
+def _should_export_summary(strategy):
+  return (not strategy) or strategy.extended.should_save_summary
+
+
+def _save_checkpoint(strategy, checkpoint, model_dir, checkpoint_prefix):
   """Saves model to with provided checkpoint prefix."""
 
-  checkpoint_path = os.path.join(model_dir, checkpoint_prefix)
-  saved_path = checkpoint.save(checkpoint_path)
-  logging.info('Saving model as TF checkpoint: %s', saved_path)
+  if _should_export_checkpoint(strategy):
+    checkpoint_path = os.path.join(model_dir, checkpoint_prefix)
+    saved_path = checkpoint.save(checkpoint_path)
+    logging.info('Saving model as TF checkpoint: %s', saved_path)
+  else:
+    # In multi worker training we need every worker to save checkpoint, because
+    # variables can trigger synchronization on read and synchronization needs
+    # all workers to participate. To avoid workers overriding each other we save
+    # to a temporary directory on non-chief workers.
+    tmp_dir = tempfile.mkdtemp()
+    checkpoint.save(os.path.join(tmp_dir, 'ckpt'))
+    tf.io.gfile.rmtree(tmp_dir)
   return
@@ -242,7 +260,13 @@ def run_customized_training_loop(
     ]
 
     # Create summary writers
-    summary_dir = os.path.join(model_dir, 'summaries')
+    if _should_export_summary(strategy):
+      summary_dir = os.path.join(model_dir, 'summaries')
+    else:
+      # In multi worker training we need every worker to write summary, because
+      # variables can trigger synchronization on read and synchronization needs
+      # all workers to participate.
+      summary_dir = tempfile.mkdtemp()
     eval_summary_writer = tf.summary.create_file_writer(
         os.path.join(summary_dir, 'eval'))
     if steps_per_loop >= _MIN_SUMMARY_STEPS:
@@ -418,11 +442,11 @@ def run_customized_training_loop(
         # To avoid repeated model saving, we do not save after the last
         # step of training.
         if current_step < total_training_steps:
-          _save_checkpoint(checkpoint, model_dir,
+          _save_checkpoint(strategy, checkpoint, model_dir,
                            checkpoint_name.format(step=current_step))
           if sub_model_export_name:
             _save_checkpoint(
-                sub_model_checkpoint, model_dir,
+                strategy, sub_model_checkpoint, model_dir,
                 '%s_step_%d.ckpt' % (sub_model_export_name, current_step))
         if eval_input_fn:
           logging.info('Running evaluation after step: %s.', current_step)
@@ -432,10 +456,10 @@ def run_customized_training_loop(
           for metric in eval_metrics + model.metrics:
             metric.reset_states()
 
-    _save_checkpoint(checkpoint, model_dir,
+    _save_checkpoint(strategy, checkpoint, model_dir,
                      checkpoint_name.format(step=current_step))
     if sub_model_export_name:
-      _save_checkpoint(sub_model_checkpoint, model_dir,
+      _save_checkpoint(strategy, sub_model_checkpoint, model_dir,
                        '%s.ckpt' % sub_model_export_name)
 
     if eval_input_fn:
@@ -455,4 +479,7 @@ def run_customized_training_loop(
 
     write_txt_summary(training_summary, summary_dir)
 
+    if not _should_export_summary(strategy):
+      tf.io.gfile.rmtree(summary_dir)
+
     return model
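
Summaries are handled with the same shape as checkpoints: non-chief workers still create summary writers (writing a summary can read variables and trigger the same cross-worker synchronization), but the writers point at a scratch directory that is removed just before `return model`. A small standalone illustration of that redirect, independent of the training loop and using illustrative names:

import os
import tempfile

import tensorflow as tf

# Pretend we are a non-chief worker: summaries go to a scratch directory.
summary_dir = tempfile.mkdtemp()
writer = tf.summary.create_file_writer(os.path.join(summary_dir, 'eval'))
with writer.as_default():
  tf.summary.scalar('loss', 0.5, step=1)
writer.flush()
# Mirrors the cleanup added just before `return model` in the diff above.
tf.io.gfile.rmtree(summary_dir)
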
@@ -21,6 +21,7 @@ from __future__ import print_function
 
 import os
 
 from absl.testing import parameterized
+from absl.testing.absltest import mock
 import numpy as np
 import tensorflow as tf
@@ -208,6 +209,27 @@ class ModelTrainingUtilsTest(tf.test.TestCase, parameterized.TestCase):
           check_eventfile_for_keyword('mean_input',
                                       os.path.join(model_dir, 'summaries/eval')))
 
+  @combinations.generate(
+      combinations.combine(
+          distribution=[
+              strategy_combinations.one_device_strategy_gpu,
+          ],
+          mode='eager',
+      ))
+  def test_train_check_artifacts_non_chief(self, distribution):
+    # We shouldn't export artifacts on non-chief workers. Since there's no easy
+    # way to test with real MultiWorkerMirroredStrategy, we patch the strategy
+    # to make it as if it's MultiWorkerMirroredStrategy on non-chief workers.
+    extended = distribution.extended
+    with mock.patch.object(extended.__class__, 'should_checkpoint',
+                           new_callable=mock.PropertyMock, return_value=False), \
+         mock.patch.object(extended.__class__, 'should_save_summary',
+                           new_callable=mock.PropertyMock, return_value=False):
+      model_dir = self.get_temp_dir()
+      self.run_training(
+          distribution, model_dir, steps_per_loop=10, run_eagerly=False)
+      self.assertEmpty(tf.io.gfile.listdir(model_dir))
+
 
 if __name__ == '__main__':
   assert tf.version.VERSION.startswith('2.')
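
The test simulates a non-chief worker by patching `should_checkpoint` and `should_save_summary`. Because these are read-only properties on the strategy's `extended` object, the patch has to target the class and use `mock.PropertyMock`. A minimal standalone sketch of that patching pattern, with a hypothetical `FakeExtended` class instead of a real strategy:

from unittest import mock


class FakeExtended:
  """Stands in for a strategy's `extended` object (hypothetical)."""

  @property
  def should_checkpoint(self):
    return True


ext = FakeExtended()
with mock.patch.object(FakeExtended, 'should_checkpoint',
                       new_callable=mock.PropertyMock, return_value=False):
  assert ext.should_checkpoint is False  # patched: behaves like a non-chief worker
assert ext.should_checkpoint is True     # restored once the context exits
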