Commit 642e1005 authored by Zhichao Lu's avatar Zhichao Lu Committed by pkulzc
Browse files

Utility to write config files to disk.

PiperOrigin-RevId: 190972737
parent e9f85e83
......@@ -14,10 +14,13 @@
# ==============================================================================
"""Functions for reading and updating configuration files."""
import os
import tensorflow as tf
from google.protobuf import text_format
from tensorflow.python.lib.io import file_io
from object_detection.protos import eval_pb2
from object_detection.protos import input_reader_pb2
from object_detection.protos import model_pb2
......@@ -119,6 +122,24 @@ def create_pipeline_proto_from_configs(configs):
return pipeline_config
def save_pipeline_config(pipeline_config, directory):
"""Saves a pipeline config text file to disk.
Args:
pipeline_config: A pipeline_pb2.TrainEvalPipelineConfig.
directory: The model directory into which the pipeline config file will be
saved.
"""
if not file_io.file_exists(directory):
file_io.recursive_create_dir(directory)
pipeline_config_path = os.path.join(directory, "pipeline.config")
config_text = text_format.MessageToString(pipeline_config)
with tf.gfile.Open(pipeline_config_path, "wb") as f:
tf.logging.info("Writing pipeline config file to %s",
pipeline_config_path)
f.write(config_text)
def get_configs_from_multiple_files(model_config_path="",
train_config_path="",
train_input_config_path="",
......
......@@ -110,6 +110,23 @@ class ConfigUtilTest(tf.test.TestCase):
config_util.create_pipeline_proto_from_configs(configs))
self.assertEqual(pipeline_config, pipeline_config_reconstructed)
def test_save_pipeline_config(self):
"""Tests that the pipeline config is properly saved to disk."""
pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
pipeline_config.model.faster_rcnn.num_classes = 10
pipeline_config.train_config.batch_size = 32
pipeline_config.train_input_reader.label_map_path = "path/to/label_map"
pipeline_config.eval_config.num_examples = 20
pipeline_config.eval_input_reader.queue_capacity = 100
config_util.save_pipeline_config(pipeline_config, self.get_temp_dir())
configs = config_util.get_configs_from_pipeline_file(
os.path.join(self.get_temp_dir(), "pipeline.config"))
pipeline_config_reconstructed = (
config_util.create_pipeline_proto_from_configs(configs))
self.assertEqual(pipeline_config, pipeline_config_reconstructed)
def test_get_configs_from_multiple_files(self):
"""Tests that proto configs can be read from multiple files."""
temp_dir = self.get_temp_dir()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment