Commit 1ec978a2 authored by Fan Yang's avatar Fan Yang Committed by A. Unique TensorFlower
Browse files

Add FLOPs computation into run_experiment.

PiperOrigin-RevId: 385674527
parent 70b343d0
......@@ -138,6 +138,11 @@ def run_experiment(
logging.info('Number of trainable params in model: %f Millions.',
num_params / 10.**6)
flops = train_utils.try_count_flops(trainer.model)
if flops is not None:
logging.info('FLOPs (multi-adds) in model: %f Billions.',
flops / 10.**9 / 2)
if run_post_eval:
with distribution_strategy.scope():
return trainer.model, trainer.evaluate(
......
......@@ -17,7 +17,7 @@ import copy
import json
import os
import pprint
from typing import Any, Callable, Dict, List, Optional
from typing import Any, Callable, Dict, List, Optional, Union
from absl import logging
import dataclasses
......@@ -25,6 +25,9 @@ import gin
import orbit
import tensorflow as tf
# pylint: disable=g-direct-tensorflow-import
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2_as_graph
# pylint: enable=g-direct-tensorflow-import
from official.core import base_task
from official.core import base_trainer
from official.core import config_definitions
......@@ -393,3 +396,48 @@ def try_count_params(model: tf.keras.Model):
'train step already reached before this run.')
return None
return None
def try_count_flops(model: Union[tf.Module, tf.keras.Model],
inputs_kwargs: Optional[Dict[str, Any]] = None):
"""Counts and returns model FLOPs.
Args:
model: A model instance.
inputs_kwargs: An optional dictionary of argument pairs specifying inputs'
shape specifications to getting corresponding concrete function.
Returns:
The model's FLOPs.
"""
if hasattr(model, 'inputs'):
try:
# Get input shape and set batch size to 1.
if model.inputs:
inputs = [
tf.TensorSpec([1] + input.shape[1:], input.dtype)
for input in model.inputs
]
concrete_func = tf.function(model).get_concrete_function(inputs)
# If model.inputs is invalid, try to use the input to get concrete
# function for model.call (subclass model).
else:
concrete_func = tf.function(model.call).get_concrete_function(
**inputs_kwargs)
frozen_func, _ = convert_variables_to_constants_v2_as_graph(concrete_func)
# Calculate FLOPs.
run_meta = tf.compat.v1.RunMetadata()
opts = tf.compat.v1.profiler.ProfileOptionBuilder.float_operation()
opts['output'] = 'none'
flops = tf.compat.v1.profiler.profile(
graph=frozen_func.graph, run_meta=run_meta, options=opts)
return flops.total_float_ops
except Exception as e: # pylint: disable=broad-except
logging.info(
'Failed to count model FLOPs with error %s, because the build() '
'methods in keras layers were not called. This is probably because '
'the model was not feed any input, e.g., the max train step already '
'reached before this run.', e)
return None
return None
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment