Commit b7f0d1ae authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 413710486
parent 79923554
...@@ -130,7 +130,9 @@ def export_inference_graph( ...@@ -130,7 +130,9 @@ def export_inference_graph(
if log_model_flops_and_params: if log_model_flops_and_params:
inputs_kwargs = None inputs_kwargs = None
if isinstance(params.task, configs.retinanet.RetinaNetTask): if isinstance(
params.task,
(configs.retinanet.RetinaNetTask, configs.maskrcnn.MaskRCNNTask)):
# We need to create inputs_kwargs argument to specify the input shapes for # We need to create inputs_kwargs argument to specify the input shapes for
# subclass model that overrides model.call to take multiple inputs, # subclass model that overrides model.call to take multiple inputs,
# e.g., RetinaNet model. # e.g., RetinaNet model.
......
...@@ -26,26 +26,43 @@ from official.vision.beta.serving import export_saved_model_lib ...@@ -26,26 +26,43 @@ from official.vision.beta.serving import export_saved_model_lib
class WriteModelFlopsAndParamsTest(tf.test.TestCase): class WriteModelFlopsAndParamsTest(tf.test.TestCase):
@mock.patch.object(export_base, 'export', autospec=True, spec_set=True) def setUp(self):
def test_retinanet_task(self, unused_export): super().setUp()
tempdir = self.create_tempdir() self.tempdir = self.create_tempdir()
params = configs.retinanet.retinanet_resnetfpn_coco() self.enter_context(
print(params.task.model.backbone) mock.patch.object(export_base, 'export', autospec=True, spec_set=True))
params.task.model.backbone.resnet.model_id = 18
params.task.model.num_classes = 2 def _export_model_with_log_model_flops_and_params(self, params):
params.task.model.max_level = 6
export_saved_model_lib.export_inference_graph( export_saved_model_lib.export_inference_graph(
input_type='image_tensor', input_type='image_tensor',
batch_size=1, batch_size=1,
input_image_size=[64, 64], input_image_size=[64, 64],
params=params, params=params,
checkpoint_path=os.path.join(tempdir, 'unused-ckpt'), checkpoint_path=os.path.join(self.tempdir, 'unused-ckpt'),
export_dir=tempdir, export_dir=self.tempdir,
log_model_flops_and_params=True) log_model_flops_and_params=True)
def assertModelAnalysisFilesExist(self):
self.assertTrue( self.assertTrue(
tf.io.gfile.exists(os.path.join(tempdir, 'model_params.txt'))) tf.io.gfile.exists(os.path.join(self.tempdir, 'model_params.txt')))
self.assertTrue( self.assertTrue(
tf.io.gfile.exists(os.path.join(tempdir, 'model_flops.txt'))) tf.io.gfile.exists(os.path.join(self.tempdir, 'model_flops.txt')))
def test_retinanet_task(self):
params = configs.retinanet.retinanet_resnetfpn_coco()
params.task.model.backbone.resnet.model_id = 18
params.task.model.num_classes = 2
params.task.model.max_level = 6
self._export_model_with_log_model_flops_and_params(params)
self.assertModelAnalysisFilesExist()
def test_maskrcnn_task(self):
params = configs.maskrcnn.maskrcnn_resnetfpn_coco()
params.task.model.backbone.resnet.model_id = 18
params.task.model.num_classes = 2
params.task.model.max_level = 6
self._export_model_with_log_model_flops_and_params(params)
self.assertModelAnalysisFilesExist()
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment