"git@developer.sourcefind.cn:norm/vllm.git" did not exist on "cfc15a1031ef0197a1b291d2ed93717a9bdad268"
Commit d5cd9b0a authored by Brandon Jiang's avatar Brandon Jiang Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 475915184
parent d2902b5d
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
import os import os
import pprint import pprint
from typing import Optional
# Import libraries # Import libraries
from absl import logging from absl import logging
...@@ -79,12 +79,15 @@ class BaseBenchmark( # pylint: disable=undefined-variable ...@@ -79,12 +79,15 @@ class BaseBenchmark( # pylint: disable=undefined-variable
def __init__(self, def __init__(self,
output_dir=None, output_dir=None,
tpu=None): tpu=None,
tensorflow_models_path: Optional[str] = None):
"""Initialize class. """Initialize class.
Args: Args:
output_dir: Base directory to store all output for the test. output_dir: Base directory to store all output for the test.
tpu: (optional) TPU name to use in a TPU benchmark. tpu: (optional) TPU name to use in a TPU benchmark.
tensorflow_models_path: Full path to tensorflow models directory. Needed
to locate config files.
""" """
if os.getenv('BENCHMARK_OUTPUT_DIR'): if os.getenv('BENCHMARK_OUTPUT_DIR'):
...@@ -101,6 +104,13 @@ class BaseBenchmark( # pylint: disable=undefined-variable ...@@ -101,6 +104,13 @@ class BaseBenchmark( # pylint: disable=undefined-variable
else: else:
self._resolved_tpu = None self._resolved_tpu = None
if os.getenv('TENSORFLOW_MODELS_PATH'):
self._tensorflow_models_path = os.getenv('TENSORFLOW_MODELS_PATH')
elif tensorflow_models_path:
self._tensorflow_models_path = tensorflow_models_path
else:
self._tensorflow_models_path = None
def _get_model_dir(self, folder_name): def _get_model_dir(self, folder_name):
"""Returns directory to store info, e.g. saved model and event log.""" """Returns directory to store info, e.g. saved model and event log."""
return os.path.join(self.output_dir, folder_name) return os.path.join(self.output_dir, folder_name)
...@@ -118,16 +128,18 @@ class BaseBenchmark( # pylint: disable=undefined-variable ...@@ -118,16 +128,18 @@ class BaseBenchmark( # pylint: disable=undefined-variable
gin_file): gin_file):
with gin.unlock_config(): with gin.unlock_config():
gin.parse_config_files_and_bindings( gin.parse_config_files_and_bindings([
[config_utils.get_config_path(g) for g in gin_file], None) config_utils.get_config_path(
g, base_dir=self._tensorflow_models_path) for g in gin_file
], None)
params = exp_factory.get_exp_config(experiment_type) params = exp_factory.get_exp_config(experiment_type)
for config_file in config_files: for config_file in config_files:
file_path = config_utils.get_config_path(config_file) file_path = config_utils.get_config_path(
config_file, base_dir=self._tensorflow_models_path)
params = hyperparams.override_params_dict( params = hyperparams.override_params_dict(
params, file_path, is_strict=True) params, file_path, is_strict=True)
if params_override: if params_override:
params = hyperparams.override_params_dict( params = hyperparams.override_params_dict(
params, params_override, is_strict=True) params, params_override, is_strict=True)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment