"vscode:/vscode.git/clone" did not exist on "9e18e262057cda27f2339403475190a12a5d91cb"
Commit d5cd9b0a authored by Brandon Jiang's avatar Brandon Jiang Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 475915184
parent d2902b5d
......@@ -16,7 +16,7 @@
import os
import pprint
from typing import Optional
# Import libraries
from absl import logging
......@@ -79,12 +79,15 @@ class BaseBenchmark( # pylint: disable=undefined-variable
def __init__(self,
output_dir=None,
tpu=None):
tpu=None,
tensorflow_models_path: Optional[str] = None):
"""Initialize class.
Args:
output_dir: Base directory to store all output for the test.
tpu: (optional) TPU name to use in a TPU benchmark.
tensorflow_models_path: Full path to tensorflow models directory. Needed
to locate config files.
"""
if os.getenv('BENCHMARK_OUTPUT_DIR'):
......@@ -101,6 +104,13 @@ class BaseBenchmark( # pylint: disable=undefined-variable
else:
self._resolved_tpu = None
if os.getenv('TENSORFLOW_MODELS_PATH'):
self._tensorflow_models_path = os.getenv('TENSORFLOW_MODELS_PATH')
elif tensorflow_models_path:
self._tensorflow_models_path = tensorflow_models_path
else:
self._tensorflow_models_path = None
def _get_model_dir(self, folder_name):
"""Returns directory to store info, e.g. saved model and event log."""
return os.path.join(self.output_dir, folder_name)
......@@ -118,16 +128,18 @@ class BaseBenchmark( # pylint: disable=undefined-variable
gin_file):
with gin.unlock_config():
gin.parse_config_files_and_bindings(
[config_utils.get_config_path(g) for g in gin_file], None)
gin.parse_config_files_and_bindings([
config_utils.get_config_path(
g, base_dir=self._tensorflow_models_path) for g in gin_file
], None)
params = exp_factory.get_exp_config(experiment_type)
for config_file in config_files:
file_path = config_utils.get_config_path(config_file)
file_path = config_utils.get_config_path(
config_file, base_dir=self._tensorflow_models_path)
params = hyperparams.override_params_dict(
params, file_path, is_strict=True)
if params_override:
params = hyperparams.override_params_dict(
params, params_override, is_strict=True)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment