Commit 9ff8aa21 authored by Sai Ganesh Bandiatmakuri's avatar Sai Ganesh Bandiatmakuri Committed by A. Unique TensorFlower
Browse files

Add ability to take TPU address and output dir from environment variables.

PiperOrigin-RevId: 300858086
parent 225eda71
...@@ -48,15 +48,26 @@ class PerfZeroBenchmark(tf.test.Benchmark): ...@@ -48,15 +48,26 @@ class PerfZeroBenchmark(tf.test.Benchmark):
flag_methods: Set of flag methods to run during setup. flag_methods: Set of flag methods to run during setup.
tpu: (optional) TPU name to use in a TPU benchmark. tpu: (optional) TPU name to use in a TPU benchmark.
""" """
if not output_dir: if os.getenv('BENCHMARK_OUTPUT_DIR'):
output_dir = '/tmp' self.output_dir = os.getenv('BENCHMARK_OUTPUT_DIR')
elif output_dir:
self.output_dir = output_dir self.output_dir = output_dir
else:
self.output_dir = '/tmp'
self.default_flags = default_flags or {} self.default_flags = default_flags or {}
self.flag_methods = flag_methods or {} self.flag_methods = flag_methods or {}
if tpu:
if os.getenv('BENCHMARK_TPU'):
resolved_tpu = os.getenv('BENCHMARK_TPU')
elif tpu:
resolved_tpu = tpu
else:
resolved_tpu = None
if resolved_tpu:
# TPU models are expected to accept a --tpu=name flag. PerfZero creates # TPU models are expected to accept a --tpu=name flag. PerfZero creates
# the TPU at runtime and passes the TPU's name to this flag. # the TPU at runtime and passes the TPU's name to this flag.
self.default_flags['tpu'] = tpu self.default_flags['tpu'] = resolved_tpu
def _get_model_dir(self, folder_name): def _get_model_dir(self, folder_name):
"""Returns directory to store info, e.g. saved model and event log.""" """Returns directory to store info, e.g. saved model and event log."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment