Commit c5ae4110 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower Committed by saberkun
Browse files

Internal change

PiperOrigin-RevId: 398593113
parent 6ca5ac92
...@@ -22,8 +22,10 @@ from tensorflow.python.distribute import combinations ...@@ -22,8 +22,10 @@ from tensorflow.python.distribute import combinations
from official.common import registry_imports # pylint: disable=unused-import from official.common import registry_imports # pylint: disable=unused-import
from official.core import exp_factory from official.core import exp_factory
from official.vision.beta.dataloaders import tfexample_utils from official.vision.beta.dataloaders import tfexample_utils
from official.vision.beta.serving import detection as detection_serving
from official.vision.beta.serving import export_tflite_lib from official.vision.beta.serving import export_tflite_lib
from official.vision.beta.serving import image_classification as image_classification_serving from official.vision.beta.serving import image_classification as image_classification_serving
from official.vision.beta.serving import semantic_segmentation as semantic_segmentation_serving
class ExportTfliteLibTest(tf.test.TestCase, parameterized.TestCase): class ExportTfliteLibTest(tf.test.TestCase, parameterized.TestCase):
...@@ -51,7 +53,8 @@ class ExportTfliteLibTest(tf.test.TestCase, parameterized.TestCase): ...@@ -51,7 +53,8 @@ class ExportTfliteLibTest(tf.test.TestCase, parameterized.TestCase):
experiment=['mobilenet_imagenet'], experiment=['mobilenet_imagenet'],
quant_type=[None, 'default', 'fp16', 'int8'], quant_type=[None, 'default', 'fp16', 'int8'],
input_image_size=[[224, 224]])) input_image_size=[[224, 224]]))
def test_export_tflite(self, experiment, quant_type, input_image_size): def test_export_tflite_image_classification(self, experiment, quant_type,
input_image_size):
params = exp_factory.get_exp_config(experiment) params = exp_factory.get_exp_config(experiment)
params.task.validation_data.input_path = self._test_tfrecord_file params.task.validation_data.input_path = self._test_tfrecord_file
params.task.train_data.input_path = self._test_tfrecord_file params.task.train_data.input_path = self._test_tfrecord_file
...@@ -71,6 +74,53 @@ class ExportTfliteLibTest(tf.test.TestCase, parameterized.TestCase): ...@@ -71,6 +74,53 @@ class ExportTfliteLibTest(tf.test.TestCase, parameterized.TestCase):
self.assertIsInstance(tflite_model, bytes) self.assertIsInstance(tflite_model, bytes)
@combinations.generate(
combinations.combine(
experiment=['retinanet_mobile_coco'],
quant_type=[None, 'default', 'fp16'],
input_image_size=[[256, 256]]))
def test_export_tflite_detection(self, experiment, quant_type,
input_image_size):
params = exp_factory.get_exp_config(experiment)
temp_dir = self.get_temp_dir()
module = detection_serving.DetectionModule(
params=params, batch_size=1, input_image_size=input_image_size)
self._export_from_module(
module=module,
input_type='tflite',
saved_model_dir=os.path.join(temp_dir, 'saved_model'))
tflite_model = export_tflite_lib.convert_tflite_model(
saved_model_dir=os.path.join(temp_dir, 'saved_model'),
quant_type=quant_type,
params=params,
calibration_steps=5)
self.assertIsInstance(tflite_model, bytes)
@combinations.generate(
combinations.combine(
experiment=['seg_deeplabv3_pascal'],
quant_type=[None, 'default', 'fp16'],
input_image_size=[[512, 512]]))
def test_export_tflite_semantic_segmentation(self, experiment, quant_type,
input_image_size):
params = exp_factory.get_exp_config(experiment)
temp_dir = self.get_temp_dir()
module = semantic_segmentation_serving.SegmentationModule(
params=params, batch_size=1, input_image_size=input_image_size)
self._export_from_module(
module=module,
input_type='tflite',
saved_model_dir=os.path.join(temp_dir, 'saved_model'))
tflite_model = export_tflite_lib.convert_tflite_model(
saved_model_dir=os.path.join(temp_dir, 'saved_model'),
quant_type=quant_type,
params=params,
calibration_steps=5)
self.assertIsInstance(tflite_model, bytes)
if __name__ == '__main__': if __name__ == '__main__':
tf.test.main() tf.test.main()
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""RetinaNet task definition.""" """MaskRCNN task definition."""
import os import os
from typing import Any, Optional, List, Tuple, Mapping from typing import Any, Optional, List, Tuple, Mapping
...@@ -28,6 +28,7 @@ from official.vision.beta.dataloaders import tf_example_decoder ...@@ -28,6 +28,7 @@ from official.vision.beta.dataloaders import tf_example_decoder
from official.vision.beta.dataloaders import tf_example_label_map_decoder from official.vision.beta.dataloaders import tf_example_label_map_decoder
from official.vision.beta.evaluation import coco_evaluator from official.vision.beta.evaluation import coco_evaluator
from official.vision.beta.evaluation import coco_utils from official.vision.beta.evaluation import coco_utils
from official.vision.beta.evaluation import wod_detection_evaluator
from official.vision.beta.losses import maskrcnn_losses from official.vision.beta.losses import maskrcnn_losses
from official.vision.beta.modeling import factory from official.vision.beta.modeling import factory
...@@ -247,6 +248,39 @@ class MaskRCNNTask(base_task.Task): ...@@ -247,6 +248,39 @@ class MaskRCNNTask(base_task.Task):
} }
return losses return losses
def _build_coco_metrics(self):
"""Build COCO metrics evaluator."""
if (not self._task_config.model.include_mask
) or self._task_config.annotation_file:
self.coco_metric = coco_evaluator.COCOEvaluator(
annotation_file=self._task_config.annotation_file,
include_mask=self._task_config.model.include_mask,
per_category_metrics=self._task_config.per_category_metrics)
else:
# Builds COCO-style annotation file if include_mask is True, and
# annotation_file isn't provided.
annotation_path = os.path.join(self._logging_dir, 'annotation.json')
if tf.io.gfile.exists(annotation_path):
logging.info(
'annotation.json file exists, skipping creating the annotation'
' file.')
else:
if self._task_config.validation_data.num_examples <= 0:
logging.info('validation_data.num_examples needs to be > 0')
if not self._task_config.validation_data.input_path:
logging.info('Can not create annotation file for tfds.')
logging.info(
'Creating coco-style annotation file: %s', annotation_path)
coco_utils.scan_and_generator_annotation_file(
self._task_config.validation_data.input_path,
self._task_config.validation_data.file_type,
self._task_config.validation_data.num_examples,
self.task_config.model.include_mask, annotation_path)
self.coco_metric = coco_evaluator.COCOEvaluator(
annotation_file=annotation_path,
include_mask=self._task_config.model.include_mask,
per_category_metrics=self._task_config.per_category_metrics)
def build_metrics(self, training: bool = True): def build_metrics(self, training: bool = True):
"""Build detection metrics.""" """Build detection metrics."""
metrics = [] metrics = []
...@@ -264,36 +298,10 @@ class MaskRCNNTask(base_task.Task): ...@@ -264,36 +298,10 @@ class MaskRCNNTask(base_task.Task):
metrics.append(tf.keras.metrics.Mean(name, dtype=tf.float32)) metrics.append(tf.keras.metrics.Mean(name, dtype=tf.float32))
else: else:
if (not self._task_config.model.include_mask if self._task_config.use_coco_metrics:
) or self._task_config.annotation_file: self._build_coco_metrics()
self.coco_metric = coco_evaluator.COCOEvaluator( if self._task_config.use_wod_metrics:
annotation_file=self._task_config.annotation_file, self.wod_metric = wod_detection_evaluator.WOD2dDetectionEvaluator()
include_mask=self._task_config.model.include_mask,
per_category_metrics=self._task_config.per_category_metrics)
else:
# Builds COCO-style annotation file if include_mask is True, and
# annotation_file isn't provided.
annotation_path = os.path.join(self._logging_dir, 'annotation.json')
if tf.io.gfile.exists(annotation_path):
logging.info(
'annotation.json file exists, skipping creating the annotation'
' file.')
else:
if self._task_config.validation_data.num_examples <= 0:
logging.info('validation_data.num_examples needs to be > 0')
if not self._task_config.validation_data.input_path:
logging.info('Can not create annotation file for tfds.')
logging.info(
'Creating coco-style annotation file: %s', annotation_path)
coco_utils.scan_and_generator_annotation_file(
self._task_config.validation_data.input_path,
self._task_config.validation_data.file_type,
self._task_config.validation_data.num_examples,
self.task_config.model.include_mask, annotation_path)
self.coco_metric = coco_evaluator.COCOEvaluator(
annotation_file=annotation_path,
include_mask=self._task_config.model.include_mask,
per_category_metrics=self._task_config.per_category_metrics)
return metrics return metrics
...@@ -376,31 +384,58 @@ class MaskRCNNTask(base_task.Task): ...@@ -376,31 +384,58 @@ class MaskRCNNTask(base_task.Task):
training=False) training=False)
logs = {self.loss: 0} logs = {self.loss: 0}
coco_model_outputs = { if self._task_config.use_coco_metrics:
'detection_boxes': outputs['detection_boxes'], coco_model_outputs = {
'detection_scores': outputs['detection_scores'], 'detection_boxes': outputs['detection_boxes'],
'detection_classes': outputs['detection_classes'], 'detection_scores': outputs['detection_scores'],
'num_detections': outputs['num_detections'], 'detection_classes': outputs['detection_classes'],
'source_id': labels['groundtruths']['source_id'], 'num_detections': outputs['num_detections'],
'image_info': labels['image_info'] 'source_id': labels['groundtruths']['source_id'],
} 'image_info': labels['image_info']
if self.task_config.model.include_mask: }
coco_model_outputs.update({ if self.task_config.model.include_mask:
'detection_masks': outputs['detection_masks'], coco_model_outputs.update({
}) 'detection_masks': outputs['detection_masks'],
logs.update({ })
self.coco_metric.name: (labels['groundtruths'], coco_model_outputs) logs.update(
}) {self.coco_metric.name: (labels['groundtruths'], coco_model_outputs)})
if self.task_config.use_wod_metrics:
wod_model_outputs = {
'detection_boxes': outputs['detection_boxes'],
'detection_scores': outputs['detection_scores'],
'detection_classes': outputs['detection_classes'],
'num_detections': outputs['num_detections'],
'source_id': labels['groundtruths']['source_id'],
'image_info': labels['image_info']
}
logs.update(
{self.wod_metric.name: (labels['groundtruths'], wod_model_outputs)})
return logs return logs
def aggregate_logs(self, state=None, step_outputs=None): def aggregate_logs(self, state=None, step_outputs=None):
if self._task_config.use_coco_metrics:
if state is None:
self.coco_metric.reset_states()
self.coco_metric.update_state(
step_outputs[self.coco_metric.name][0],
step_outputs[self.coco_metric.name][1])
if self._task_config.use_wod_metrics:
if state is None:
self.wod_metric.reset_states()
self.wod_metric.update_state(
step_outputs[self.wod_metric.name][0],
step_outputs[self.wod_metric.name][1])
if state is None: if state is None:
self.coco_metric.reset_states() # Create an arbitrary state to indicate it's not the first step in the
state = self.coco_metric # following calls to this function.
self.coco_metric.update_state( state = True
step_outputs[self.coco_metric.name][0],
step_outputs[self.coco_metric.name][1])
return state return state
def reduce_aggregated_logs(self, aggregated_logs, global_step=None): def reduce_aggregated_logs(self, aggregated_logs, global_step=None):
return self.coco_metric.result() logs = {}
if self._task_config.use_coco_metrics:
logs.update(self.coco_metric.result())
if self._task_config.use_wod_metrics:
logs.update(self.wod_metric.result())
return logs
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment