"test/vscode:/vscode.git/clone" did not exist on "33f762f64441c1d63e3045fa89716786c8dad432"
Commit 39abb95d authored by Le Hou's avatar Le Hou Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 353327814
parent cc9357f9
...@@ -217,6 +217,16 @@ def serialize_config(params: config_definitions.ExperimentConfig, ...@@ -217,6 +217,16 @@ def serialize_config(params: config_definitions.ExperimentConfig,
hyperparams.save_params_dict_to_yaml(params, params_save_path) hyperparams.save_params_dict_to_yaml(params, params_save_path)
def save_gin_config(filename_surfix: str, model_dir: str):
"""Serializes and saves the experiment config."""
gin_save_path = os.path.join(
model_dir, 'operative_config.{}.gin'.format(filename_surfix))
logging.info('Saving gin configurations to %s', gin_save_path)
tf.io.gfile.makedirs(model_dir)
with tf.io.gfile.GFile(gin_save_path, 'w') as f:
f.write(gin.operative_config_str())
def read_global_step_from_checkpoint(ckpt_file_path): def read_global_step_from_checkpoint(ckpt_file_path):
"""Read global step from checkpoint, or get global step from its filename.""" """Read global step from checkpoint, or get global step from its filename."""
global_step = tf.Variable(-1, dtype=tf.int64) global_step = tf.Variable(-1, dtype=tf.int64)
......
...@@ -63,6 +63,8 @@ def main(_): ...@@ -63,6 +63,8 @@ def main(_):
params=params, params=params,
model_dir=model_dir) model_dir=model_dir)
train_utils.save_gin_config(FLAGS.mode, model_dir)
if __name__ == '__main__': if __name__ == '__main__':
tfm_flags.define_flags() tfm_flags.define_flags()
app.run(main) app.run(main)
...@@ -64,6 +64,8 @@ def main(_): ...@@ -64,6 +64,8 @@ def main(_):
params=params, params=params,
model_dir=model_dir) model_dir=model_dir)
train_utils.save_gin_config(FLAGS.mode, model_dir)
if __name__ == '__main__': if __name__ == '__main__':
tfm_flags.define_flags() tfm_flags.define_flags()
app.run(main) app.run(main)
...@@ -41,6 +41,7 @@ def main(_): ...@@ -41,6 +41,7 @@ def main(_):
train_utils.serialize_config(params, model_dir) train_utils.serialize_config(params, model_dir)
continuous_finetune_lib.run_continuous_finetune(FLAGS.mode, params, model_dir, continuous_finetune_lib.run_continuous_finetune(FLAGS.mode, params, model_dir,
FLAGS.pretrain_steps) FLAGS.pretrain_steps)
train_utils.save_gin_config(FLAGS.mode, model_dir)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -67,6 +67,8 @@ def main(_): ...@@ -67,6 +67,8 @@ def main(_):
params=params, params=params,
model_dir=model_dir) model_dir=model_dir)
train_utils.save_gin_config(FLAGS.mode, model_dir)
if __name__ == '__main__': if __name__ == '__main__':
tfm_flags.define_flags() tfm_flags.define_flags()
app.run(main) app.run(main)
...@@ -63,6 +63,8 @@ def main(_): ...@@ -63,6 +63,8 @@ def main(_):
params=params, params=params,
model_dir=model_dir) model_dir=model_dir)
train_utils.save_gin_config(FLAGS.mode, model_dir)
if __name__ == '__main__': if __name__ == '__main__':
tfm_flags.define_flags() tfm_flags.define_flags()
app.run(main) app.run(main)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment