Commit 82a26070 authored by Fan Yang's avatar Fan Yang Committed by A. Unique TensorFlower
Browse files

Allow dynamic batch size for batched NMS only in detection model export.

PiperOrigin-RevId: 464868859
parent 50e86708
......@@ -15,6 +15,8 @@
"""Detection input and model functions for serving/inference."""
from typing import Mapping, Text
from absl import logging
import tensorflow as tf
from official.vision import configs
......@@ -35,7 +37,12 @@ class DetectionModule(export_base.ExportModule):
def _build_model(self):
if self._batch_size is None:
raise ValueError('batch_size cannot be None for detection models.')
# Only batched NMS is supported with dynamic batch size.
self.params.task.model.detection_generator.nms_version = 'batched'
logging.info(
'nms_version is set to `batched` because only batched NMS is '
'supported with dynamic batch size.')
input_specs = tf.keras.layers.InputSpec(shape=[self._batch_size] +
self._input_image_size + [3])
......
......@@ -124,10 +124,8 @@ class DetectionExportTest(tf.test.TestCase, parameterized.TestCase):
def test_build_model_fail_with_none_batch_size(self):
params = exp_factory.get_exp_config('retinanet_resnetfpn_coco')
with self.assertRaisesRegex(
ValueError, 'batch_size cannot be None for detection models.'):
detection.DetectionModule(
params, batch_size=None, input_image_size=[640, 640])
detection.DetectionModule(
params, batch_size=None, input_image_size=[640, 640])
if __name__ == '__main__':
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment