Commit 82a26070 authored by Fan Yang's avatar Fan Yang Committed by A. Unique TensorFlower
Browse files

Allow dynamic batch size for batched NMS only in detection model export.

PiperOrigin-RevId: 464868859
parent 50e86708
...@@ -15,6 +15,8 @@ ...@@ -15,6 +15,8 @@
"""Detection input and model functions for serving/inference.""" """Detection input and model functions for serving/inference."""
from typing import Mapping, Text from typing import Mapping, Text
from absl import logging
import tensorflow as tf import tensorflow as tf
from official.vision import configs from official.vision import configs
...@@ -35,7 +37,12 @@ class DetectionModule(export_base.ExportModule): ...@@ -35,7 +37,12 @@ class DetectionModule(export_base.ExportModule):
def _build_model(self): def _build_model(self):
if self._batch_size is None: if self._batch_size is None:
raise ValueError('batch_size cannot be None for detection models.') # Only batched NMS is supported with dynamic batch size.
self.params.task.model.detection_generator.nms_version = 'batched'
logging.info(
'nms_version is set to `batched` because only batched NMS is '
'supported with dynamic batch size.')
input_specs = tf.keras.layers.InputSpec(shape=[self._batch_size] + input_specs = tf.keras.layers.InputSpec(shape=[self._batch_size] +
self._input_image_size + [3]) self._input_image_size + [3])
......
...@@ -124,10 +124,8 @@ class DetectionExportTest(tf.test.TestCase, parameterized.TestCase): ...@@ -124,10 +124,8 @@ class DetectionExportTest(tf.test.TestCase, parameterized.TestCase):
def test_build_model_fail_with_none_batch_size(self): def test_build_model_fail_with_none_batch_size(self):
params = exp_factory.get_exp_config('retinanet_resnetfpn_coco') params = exp_factory.get_exp_config('retinanet_resnetfpn_coco')
with self.assertRaisesRegex( detection.DetectionModule(
ValueError, 'batch_size cannot be None for detection models.'): params, batch_size=None, input_image_size=[640, 640])
detection.DetectionModule(
params, batch_size=None, input_image_size=[640, 640])
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment