Unverified Commit 154d3ffa authored by Haoyu Zhang's avatar Haoyu Zhang Committed by GitHub
Browse files

Add flag to enable Xprof (#6352)

parent 6dea4846
...@@ -215,6 +215,11 @@ def define_keras_flags(): ...@@ -215,6 +215,11 @@ def define_keras_flags():
help='The number of steps to run for training. If it is larger than ' help='The number of steps to run for training. If it is larger than '
'# batches per epoch, then use # batches per epoch. When this flag is ' '# batches per epoch, then use # batches per epoch. When this flag is '
'set, only one epoch is going to run for training.') 'set, only one epoch is going to run for training.')
flags.DEFINE_boolean(
name='enable_e2e_xprof', default=False,
help='Save end-to-end profiling data to model dir using Xprof. Profiling '
'has an overhead on both computation and memory usage, and can generate '
'gigantic files when profiling a lot of steps.')
def get_synth_input_fn(height, width, num_channels, num_classes, def get_synth_input_fn(height, width, num_channels, num_classes,
......
...@@ -22,6 +22,7 @@ from absl import app as absl_app ...@@ -22,6 +22,7 @@ from absl import app as absl_app
from absl import flags from absl import flags
import tensorflow as tf # pylint: disable=g-bad-import-order import tensorflow as tf # pylint: disable=g-bad-import-order
from tensorflow.python.eager import profiler
from official.resnet import imagenet_main from official.resnet import imagenet_main
from official.resnet.keras import keras_common from official.resnet.keras import keras_common
from official.resnet.keras import resnet_model from official.resnet.keras import resnet_model
...@@ -220,6 +221,8 @@ def run(flags_obj): ...@@ -220,6 +221,8 @@ def run(flags_obj):
callbacks = [time_callback, lr_callback] callbacks = [time_callback, lr_callback]
if flags_obj.enable_tensorboard: if flags_obj.enable_tensorboard:
callbacks.append(tensorboard_callback) callbacks.append(tensorboard_callback)
if flags_obj.enable_e2e_xprof:
profiler.start()
history = model.fit(train_input_dataset, history = model.fit(train_input_dataset,
epochs=train_epochs, epochs=train_epochs,
...@@ -230,6 +233,10 @@ def run(flags_obj): ...@@ -230,6 +233,10 @@ def run(flags_obj):
validation_freq=flags_obj.epochs_between_evals, validation_freq=flags_obj.epochs_between_evals,
verbose=2) verbose=2)
if flags_obj.enable_e2e_xprof:
results = profiler.stop()
profiler.save(flags_obj.model_dir, results)
eval_output = None eval_output = None
if not flags_obj.skip_eval: if not flags_obj.skip_eval:
eval_output = model.evaluate(eval_input_dataset, eval_output = model.evaluate(eval_input_dataset,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment