"docs/source/api/vscode:/vscode.git/clone" did not exist on "85bd0cdecacb82737c5d521d8546caaa3d926658"
Commit 2db7ab2a authored by A. Unique TensorFlower

Add an option to save model flops and params on export.

PiperOrigin-RevId: 407461964
parent ff93e945
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utils for testing."""
import tensorflow as tf
class FakeKerasModel(tf.keras.Model):
"""Fake keras model for testing."""
def __init__(self):
super().__init__()
self.dense = tf.keras.layers.Dense(4, activation=tf.nn.relu)
self.dense2 = tf.keras.layers.Dense(4, activation=tf.nn.relu)
def call(self, inputs):
return self.dense2(self.dense(inputs))
class _Dense(tf.Module):
"""A dense layer."""
def __init__(self, input_dim, output_size, name=None):
super().__init__(name=name)
with self.name_scope:
self.w = tf.Variable(
tf.random.normal([input_dim, output_size]), name='w')
self.b = tf.Variable(tf.zeros([output_size]), name='b')
@tf.Module.with_name_scope
def __call__(self, x):
y = tf.matmul(x, self.w) + self.b
return tf.nn.relu(y)
class FakeModule(tf.Module):
"""Fake model using tf.Module for testing."""
def __init__(self, input_size, name=None):
super().__init__(name=name)
with self.name_scope:
self.dense = _Dense(input_size, 4, name='dense')
self.dense2 = _Dense(4, 4, name='dense_1')
@tf.Module.with_name_scope
def __call__(self, x):
return self.dense2(self.dense(x))
@@ -380,6 +380,23 @@ def remove_ckpts(model_dir):
    tf.io.gfile.remove(file_to_remove)


def write_model_params(model: Union[tf.Module, tf.keras.Model],
                       output_path: str) -> None:
  """Writes the model parameters and shapes to a file.

  Args:
    model: A model instance.
    output_path: Output file path.
  """
  with tf.io.gfile.GFile(output_path, 'w') as f:
    total_params = 0
    for var in model.variables:
      shape = tf.shape(var)
      total_params += tf.math.reduce_prod(shape).numpy()
      f.write(f'{var.name} {shape.numpy().tolist()}\n')
    f.write(f'\nTotal params: {total_params}\n')


def try_count_params(
    model: Union[tf.Module, tf.keras.Model],
    trainable_only: bool = False):
@@ -412,13 +429,15 @@ def try_count_params(

def try_count_flops(model: Union[tf.Module, tf.keras.Model],
                    inputs_kwargs: Optional[Dict[str, Any]] = None,
                    output_path: Optional[str] = None):
  """Counts and returns model FLOPs.

  Args:
    model: A model instance.
    inputs_kwargs: An optional dictionary of argument pairs specifying inputs'
      shape specifications used to get the corresponding concrete function.
    output_path: A file path to write the profiling results to.

  Returns:
    The model's FLOPs.
@@ -442,7 +461,10 @@ def try_count_flops(model: Union[tf.Module, tf.keras.Model],
    # Calculate FLOPs.
    run_meta = tf.compat.v1.RunMetadata()
    opts = tf.compat.v1.profiler.ProfileOptionBuilder.float_operation()
    if output_path is not None:
      opts['output'] = f'file:outfile={output_path}'
    else:
      opts['output'] = 'none'
    flops = tf.compat.v1.profiler.profile(
        graph=frozen_func.graph, run_meta=run_meta, options=opts)
    return flops.total_float_ops
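For context, a minimal usage sketch of the two helpers above, assuming the TF Model Garden `official` package is importable; the toy model and the /tmp output paths are illustrative and not part of the commit:

import tensorflow as tf

from official.core import train_utils

# A tiny Keras model with known variables (illustrative only).
model = tf.keras.Sequential([
    tf.keras.layers.Dense(4, activation='relu', input_shape=(3,)),
    tf.keras.layers.Dense(4, activation='relu'),
])
model(tf.zeros([1, 3]))  # Forward pass so all variables are created.

# Writes one "<variable name> <shape>" line per variable, then a total count.
train_utils.write_model_params(model, '/tmp/model_params.txt')

# Returns total float ops; with output_path set, the profiler report is
# also written to that file instead of being discarded.
flops = train_utils.try_count_flops(model, output_path='/tmp/model_flops.txt')
print('FLOPs:', flops)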
@@ -14,8 +14,12 @@
"""Tests for official.core.train_utils."""

import os

import numpy as np
import tensorflow as tf

from official.core import test_utils
from official.core import train_utils
@@ -51,6 +55,44 @@ class TrainUtilsTest(tf.test.TestCase):
    self.assertEqual(d['a']['i']['x'], 123)
    self.assertEqual(d['b'], 456)

  def test_write_model_params_keras_model(self):
    inputs = np.zeros([2, 3])
    model = test_utils.FakeKerasModel()
    model(inputs)  # Must do forward pass to build the model.

    filepath = os.path.join(self.create_tempdir(), 'model_params.txt')
    train_utils.write_model_params(model, filepath)

    actual = tf.io.gfile.GFile(filepath, 'r').read().splitlines()
    expected = [
        'fake_keras_model/dense/kernel:0 [3, 4]',
        'fake_keras_model/dense/bias:0 [4]',
        'fake_keras_model/dense_1/kernel:0 [4, 4]',
        'fake_keras_model/dense_1/bias:0 [4]',
        '',
        'Total params: 36',
    ]
    self.assertEqual(actual, expected)

  def test_write_model_params_module(self):
    inputs = np.zeros([2, 3], dtype=np.float32)
    model = test_utils.FakeModule(3, name='fake_module')
    model(inputs)  # Must do forward pass to build the model.

    filepath = os.path.join(self.create_tempdir(), 'model_params.txt')
    train_utils.write_model_params(model, filepath)

    actual = tf.io.gfile.GFile(filepath, 'r').read().splitlines()
    expected = [
        'fake_module/dense/b:0 [4]',
        'fake_module/dense/w:0 [3, 4]',
        'fake_module/dense_1/b:0 [4]',
        'fake_module/dense_1/w:0 [4, 4]',
        '',
        'Total params: 36',
    ]
    self.assertEqual(actual, expected)


if __name__ == '__main__':
  tf.test.main()
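The expected total of 36 parameters follows from the two dense layers in both fakes: the first maps 3 features to 4 units (3 * 4 weights + 4 biases = 16) and the second maps 4 to 4 (4 * 4 + 4 = 20), so 16 + 20 = 36.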
@@ -18,6 +18,7 @@ r"""Vision models export utility function for serving/inference."""
import os
from typing import Optional, List

from absl import logging
import tensorflow as tf

from official.core import config_definitions as cfg
@@ -41,7 +42,8 @@ def export_inference_graph(
    export_module: Optional[export_base.ExportModule] = None,
    export_checkpoint_subdir: Optional[str] = None,
    export_saved_model_subdir: Optional[str] = None,
    save_options: Optional[tf.saved_model.SaveOptions] = None,
    log_model_flops_and_params: bool = False):
  """Exports inference graph for the model specified in the exp config.

  Saved model is stored at export_dir/saved_model, checkpoint is saved
@@ -63,6 +65,8 @@ def export_inference_graph(
    export_saved_model_subdir: Optional subdirectory under export_dir
      to store saved model.
    save_options: `SaveOptions` for `tf.saved_model.save`.
    log_model_flops_and_params: If True, writes model FLOPs to model_flops.txt
      and model parameters to model_params.txt.
  """

  if export_checkpoint_subdir:
@@ -123,3 +127,32 @@ def export_inference_graph(
    ckpt = tf.train.Checkpoint(model=export_module.model)
    ckpt.save(os.path.join(output_checkpoint_directory, 'ckpt'))
  train_utils.serialize_config(params, export_dir)

  if log_model_flops_and_params:
    inputs_kwargs = None
    if isinstance(params.task, configs.retinanet.RetinaNetTask):
      # We need to create an inputs_kwargs argument to specify the input
      # shapes for a subclass model that overrides model.call to take
      # multiple inputs, e.g., the RetinaNet model.
      inputs_kwargs = {
          'images':
              tf.TensorSpec([1] + input_image_size + [num_channels],
                            tf.float32),
          'image_shape':
              tf.TensorSpec([1, 2], tf.float32)
      }
      dummy_inputs = {
          k: tf.ones(v.shape.as_list(), tf.float32)
          for k, v in inputs_kwargs.items()
      }
      # Must do forward pass to build the model.
      export_module.model(**dummy_inputs)
    else:
      logging.info(
          'Logging model flops and params is not implemented for %s task.',
          type(params.task))
      return
    train_utils.try_count_flops(export_module.model, inputs_kwargs,
                                os.path.join(export_dir, 'model_flops.txt'))
    train_utils.write_model_params(
        export_module.model, os.path.join(export_dir, 'model_params.txt'))
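As a usage sketch of the new option, mirroring the test below (which mocks the actual SavedModel export); the config is the RetinaNet ResNet-FPN COCO experiment from the repo, while the checkpoint path and export directory here are illustrative:

import os

from official.vision.beta import configs
from official.vision.beta.serving import export_saved_model_lib

export_dir = '/tmp/retinanet_export'  # illustrative output directory
export_saved_model_lib.export_inference_graph(
    input_type='image_tensor',
    batch_size=1,
    input_image_size=[128, 128],
    params=configs.retinanet.retinanet_resnetfpn_coco(),
    checkpoint_path='/path/to/ckpt',  # hypothetical checkpoint to restore
    export_dir=export_dir,
    log_model_flops_and_params=True)
# After export, the profiling artifacts sit alongside the SavedModel:
#   /tmp/retinanet_export/model_flops.txt
#   /tmp/retinanet_export/model_params.txt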
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for official.core.export_saved_model_lib."""
import os
from unittest import mock
import tensorflow as tf
from official.core import export_base
from official.vision.beta import configs
from official.vision.beta.serving import export_saved_model_lib
class WriteModelFlopsAndParamsTest(tf.test.TestCase):
@mock.patch.object(export_base, 'export', autospec=True, spec_set=True)
def test_retinanet_task(self, unused_export):
tempdir = self.create_tempdir()
export_saved_model_lib.export_inference_graph(
input_type='image_tensor',
batch_size=1,
input_image_size=[128, 128],
params=configs.retinanet.retinanet_resnetfpn_coco(),
checkpoint_path=os.path.join(tempdir, 'unused-ckpt'),
export_dir=tempdir,
log_model_flops_and_params=True)
self.assertTrue(
tf.io.gfile.exists(os.path.join(tempdir, 'model_params.txt')))
self.assertTrue(
tf.io.gfile.exists(os.path.join(tempdir, 'model_flops.txt')))
if __name__ == '__main__':
tf.test.main()