Commit 393b32d2 authored by zhanggezhong's avatar zhanggezhong
Browse files

Update resnet_ctl_imagenet_main.py

parent 66196533
......@@ -151,14 +151,18 @@ def run(flags_obj):
checkpoint_interval = (
steps_per_loop * 5 if flags_obj.enable_checkpoint_and_export else None)
summary_interval = steps_per_loop if flags_obj.enable_tensorboard else None
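# When resuming multi-process training from a checkpoint, use the line below
# to create a separate save path for each process's model weights.
#unique_checkpoint_dir = os.path.join(flags_obj.model_dir, f'worker_{num_index}')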
checkpoint_manager = tf.train.CheckpointManager(
runnable.checkpoint,
directory=flags_obj.model_dir,
#directory=unique_checkpoint_dir,  # per-process alternative; replaces the directory= line above
max_to_keep=10,
step_counter=runnable.global_step,
checkpoint_interval=checkpoint_interval)
# How to load the model weights when resuming multi-process training from a
# checkpoint:
# runnable.checkpoint.restore('/path/checkpoint*')
resnet_controller = orbit.Controller(
strategy=strategy,
trainer=runnable,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment