Commit c1cce439 authored by Mik Vyatskov's avatar Mik Vyatskov Committed by Facebook GitHub Bot
Browse files

Implement helper to read model configs

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/289

Right now configs are written to a dedicated folder after training, one file per model config. This PR introduces a new function that reads model configs back in the same format, for situations where the configs cannot be passed back directly, e.g. when running through torchx.

Reviewed By: wat3rBro

Differential Revision: D37086940

fbshipit-source-id: 3938381bcf48a8069fb4b840fd2c2d052e983c6c
parent 7a47bb3d
......@@ -19,6 +19,9 @@ from .tensorboard_log_util import get_tensorboard_log_dir # noqa: forwarding
logger = logging.getLogger(__name__)
# Subdirectory with model configurations dumped by the training binary.
TRAINED_MODEL_CONFIGS_DIR: str = "trained_model_configs"
def check_version(library, min_version, warning_only=False):
"""Check the version of the library satisfies the provided minimum version.
......@@ -83,7 +86,7 @@ def dump_trained_model_configs(
A map of model name to model config path.
"""
trained_model_configs = {}
trained_model_config_dir = os.path.join(output_dir, "trained_model_configs")
trained_model_config_dir = os.path.join(output_dir, TRAINED_MODEL_CONFIGS_DIR)
PathManager.mkdirs(trained_model_config_dir)
for name, trained_cfg in trained_cfgs.items():
config_file = os.path.join(trained_model_config_dir, "{}.yaml".format(name))
......@@ -95,6 +98,26 @@ def dump_trained_model_configs(
return trained_model_configs
# TODO: Remove once the interface for passing the result of training is figured out.
def read_trained_model_configs(output_dir: str) -> Dict[str, str]:
    """Reads trained model config files from output_dir.

    Lists the ``TRAINED_MODEL_CONFIGS_DIR`` subdirectory of ``output_dir`` and
    maps every ``<model_name>.yaml`` entry to its full path. This is the same
    layout that ``dump_trained_model_configs`` writes (one
    ``<model_name>.yaml`` file per model config).

    Entries that do not carry a ``.yaml`` extension (e.g. sub-directories or
    unrelated files that ended up in the folder) are skipped so that they are
    not reported as model configs.

    Args:
        output_dir: output directory.

    Returns:
        A map of model name to model config path. The map is empty when the
        config subdirectory does not exist.
    """
    trained_model_config_dir = os.path.join(output_dir, TRAINED_MODEL_CONFIGS_DIR)
    if not PathManager.exists(trained_model_config_dir):
        return {}

    trained_model_configs = {}
    for filename in PathManager.ls(trained_model_config_dir):
        # model_name.yaml -> model_name
        model_name, extension = os.path.splitext(filename)
        if extension != ".yaml":
            # Not a dumped model config; ignore it.
            continue
        trained_model_configs[model_name] = os.path.join(
            trained_model_config_dir, filename
        )
    return trained_model_configs
@contextmanager
def mode(net: torch.nn.Module, training: bool) -> Iterator[torch.nn.Module]:
"""Temporarily switch to training/evaluation mode."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment