Commit 56dfd31b authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 413710486
parent 0b6e571a
......@@ -130,7 +130,9 @@ def export_inference_graph(
if log_model_flops_and_params:
inputs_kwargs = None
if isinstance(params.task, configs.retinanet.RetinaNetTask):
if isinstance(
params.task,
(configs.retinanet.RetinaNetTask, configs.maskrcnn.MaskRCNNTask)):
# We need to create inputs_kwargs argument to specify the input shapes for
# subclass model that overrides model.call to take multiple inputs,
# e.g., RetinaNet model.
......
......@@ -26,26 +26,43 @@ from official.vision.beta.serving import export_saved_model_lib
class WriteModelFlopsAndParamsTest(tf.test.TestCase):
@mock.patch.object(export_base, 'export', autospec=True, spec_set=True)
def test_retinanet_task(self, unused_export):
tempdir = self.create_tempdir()
params = configs.retinanet.retinanet_resnetfpn_coco()
print(params.task.model.backbone)
params.task.model.backbone.resnet.model_id = 18
params.task.model.num_classes = 2
params.task.model.max_level = 6
def setUp(self):
super().setUp()
self.tempdir = self.create_tempdir()
self.enter_context(
mock.patch.object(export_base, 'export', autospec=True, spec_set=True))
def _export_model_with_log_model_flops_and_params(self, params):
export_saved_model_lib.export_inference_graph(
input_type='image_tensor',
batch_size=1,
input_image_size=[64, 64],
params=params,
checkpoint_path=os.path.join(tempdir, 'unused-ckpt'),
export_dir=tempdir,
checkpoint_path=os.path.join(self.tempdir, 'unused-ckpt'),
export_dir=self.tempdir,
log_model_flops_and_params=True)
def assertModelAnalysisFilesExist(self):
self.assertTrue(
tf.io.gfile.exists(os.path.join(tempdir, 'model_params.txt')))
tf.io.gfile.exists(os.path.join(self.tempdir, 'model_params.txt')))
self.assertTrue(
tf.io.gfile.exists(os.path.join(tempdir, 'model_flops.txt')))
tf.io.gfile.exists(os.path.join(self.tempdir, 'model_flops.txt')))
def test_retinanet_task(self):
params = configs.retinanet.retinanet_resnetfpn_coco()
params.task.model.backbone.resnet.model_id = 18
params.task.model.num_classes = 2
params.task.model.max_level = 6
self._export_model_with_log_model_flops_and_params(params)
self.assertModelAnalysisFilesExist()
def test_maskrcnn_task(self):
params = configs.maskrcnn.maskrcnn_resnetfpn_coco()
params.task.model.backbone.resnet.model_id = 18
params.task.model.num_classes = 2
params.task.model.max_level = 6
self._export_model_with_log_model_flops_and_params(params)
self.assertModelAnalysisFilesExist()
if __name__ == '__main__':
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment