"git@developer.sourcefind.cn:norm/vllm.git" did not exist on "cb7a1c1cbf7c07e072df29844fb7a51a01344392"
Commit 26bbda73 authored by David Chen's avatar David Chen Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 284874717
parent 38e48f91
...@@ -35,19 +35,28 @@ class PerfZeroBenchmark(tf.test.Benchmark): ...@@ -35,19 +35,28 @@ class PerfZeroBenchmark(tf.test.Benchmark):
""" """
local_flags = None local_flags = None
def __init__(self, output_dir=None, default_flags=None, flag_methods=None): def __init__(self,
output_dir=None,
default_flags=None,
flag_methods=None,
tpu=None):
"""Initialize class. """Initialize class.
Args: Args:
output_dir: Base directory to store all output for the test. output_dir: Base directory to store all output for the test.
default_flags: default_flags: Set of flags to pass to model.
flag_methods: flag_methods: Set of flag methods to run during setup.
tpu: (optional) TPU name to use in a TPU benchmark.
""" """
if not output_dir: if not output_dir:
output_dir = '/tmp' output_dir = '/tmp'
self.output_dir = output_dir self.output_dir = output_dir
self.default_flags = default_flags or {} self.default_flags = default_flags or {}
self.flag_methods = flag_methods or {} self.flag_methods = flag_methods or {}
if tpu:
# TPU models are expected to accept a --tpu=name flag. PerfZero creates
# the TPU at runtime and passes the TPU's name to this flag.
self.default_flags['tpu'] = tpu
def _get_model_dir(self, folder_name): def _get_model_dir(self, folder_name):
"""Returns directory to store info, e.g. saved model and event log.""" """Returns directory to store info, e.g. saved model and event log."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment