"docs/en/advanced_guides/customize_runtime.md" did not exist on "583c4accbbf8e37b15638820b7b781f4475c6bde"
Commit c5ae4110 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower Committed by saberkun
Browse files

Internal change

PiperOrigin-RevId: 398593113
parent 6ca5ac92
......@@ -22,8 +22,10 @@ from tensorflow.python.distribute import combinations
from official.common import registry_imports # pylint: disable=unused-import
from official.core import exp_factory
from official.vision.beta.dataloaders import tfexample_utils
from official.vision.beta.serving import detection as detection_serving
from official.vision.beta.serving import export_tflite_lib
from official.vision.beta.serving import image_classification as image_classification_serving
from official.vision.beta.serving import semantic_segmentation as semantic_segmentation_serving
class ExportTfliteLibTest(tf.test.TestCase, parameterized.TestCase):
......@@ -51,7 +53,8 @@ class ExportTfliteLibTest(tf.test.TestCase, parameterized.TestCase):
experiment=['mobilenet_imagenet'],
quant_type=[None, 'default', 'fp16', 'int8'],
input_image_size=[[224, 224]]))
def test_export_tflite(self, experiment, quant_type, input_image_size):
def test_export_tflite_image_classification(self, experiment, quant_type,
input_image_size):
params = exp_factory.get_exp_config(experiment)
params.task.validation_data.input_path = self._test_tfrecord_file
params.task.train_data.input_path = self._test_tfrecord_file
......@@ -71,6 +74,53 @@ class ExportTfliteLibTest(tf.test.TestCase, parameterized.TestCase):
self.assertIsInstance(tflite_model, bytes)
@combinations.generate(
combinations.combine(
experiment=['retinanet_mobile_coco'],
quant_type=[None, 'default', 'fp16'],
input_image_size=[[256, 256]]))
def test_export_tflite_detection(self, experiment, quant_type,
input_image_size):
params = exp_factory.get_exp_config(experiment)
temp_dir = self.get_temp_dir()
module = detection_serving.DetectionModule(
params=params, batch_size=1, input_image_size=input_image_size)
self._export_from_module(
module=module,
input_type='tflite',
saved_model_dir=os.path.join(temp_dir, 'saved_model'))
tflite_model = export_tflite_lib.convert_tflite_model(
saved_model_dir=os.path.join(temp_dir, 'saved_model'),
quant_type=quant_type,
params=params,
calibration_steps=5)
self.assertIsInstance(tflite_model, bytes)
@combinations.generate(
combinations.combine(
experiment=['seg_deeplabv3_pascal'],
quant_type=[None, 'default', 'fp16'],
input_image_size=[[512, 512]]))
def test_export_tflite_semantic_segmentation(self, experiment, quant_type,
input_image_size):
params = exp_factory.get_exp_config(experiment)
temp_dir = self.get_temp_dir()
module = semantic_segmentation_serving.SegmentationModule(
params=params, batch_size=1, input_image_size=input_image_size)
self._export_from_module(
module=module,
input_type='tflite',
saved_model_dir=os.path.join(temp_dir, 'saved_model'))
tflite_model = export_tflite_lib.convert_tflite_model(
saved_model_dir=os.path.join(temp_dir, 'saved_model'),
quant_type=quant_type,
params=params,
calibration_steps=5)
self.assertIsInstance(tflite_model, bytes)
if __name__ == '__main__':
tf.test.main()
......@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""RetinaNet task definition."""
"""MaskRCNN task definition."""
import os
from typing import Any, Optional, List, Tuple, Mapping
......@@ -28,6 +28,7 @@ from official.vision.beta.dataloaders import tf_example_decoder
from official.vision.beta.dataloaders import tf_example_label_map_decoder
from official.vision.beta.evaluation import coco_evaluator
from official.vision.beta.evaluation import coco_utils
from official.vision.beta.evaluation import wod_detection_evaluator
from official.vision.beta.losses import maskrcnn_losses
from official.vision.beta.modeling import factory
......@@ -247,23 +248,8 @@ class MaskRCNNTask(base_task.Task):
}
return losses
def build_metrics(self, training: bool = True):
"""Build detection metrics."""
metrics = []
if training:
metric_names = [
'total_loss',
'rpn_score_loss',
'rpn_box_loss',
'frcnn_cls_loss',
'frcnn_box_loss',
'mask_loss',
'model_loss'
]
for name in metric_names:
metrics.append(tf.keras.metrics.Mean(name, dtype=tf.float32))
else:
def _build_coco_metrics(self):
"""Build COCO metrics evaluator."""
if (not self._task_config.model.include_mask
) or self._task_config.annotation_file:
self.coco_metric = coco_evaluator.COCOEvaluator(
......@@ -295,6 +281,28 @@ class MaskRCNNTask(base_task.Task):
include_mask=self._task_config.model.include_mask,
per_category_metrics=self._task_config.per_category_metrics)
def build_metrics(self, training: bool = True):
"""Build detection metrics."""
metrics = []
if training:
metric_names = [
'total_loss',
'rpn_score_loss',
'rpn_box_loss',
'frcnn_cls_loss',
'frcnn_box_loss',
'mask_loss',
'model_loss'
]
for name in metric_names:
metrics.append(tf.keras.metrics.Mean(name, dtype=tf.float32))
else:
if self._task_config.use_coco_metrics:
self._build_coco_metrics()
if self._task_config.use_wod_metrics:
self.wod_metric = wod_detection_evaluator.WOD2dDetectionEvaluator()
return metrics
def train_step(self,
......@@ -376,6 +384,7 @@ class MaskRCNNTask(base_task.Task):
training=False)
logs = {self.loss: 0}
if self._task_config.use_coco_metrics:
coco_model_outputs = {
'detection_boxes': outputs['detection_boxes'],
'detection_scores': outputs['detection_scores'],
......@@ -388,19 +397,45 @@ class MaskRCNNTask(base_task.Task):
coco_model_outputs.update({
'detection_masks': outputs['detection_masks'],
})
logs.update({
self.coco_metric.name: (labels['groundtruths'], coco_model_outputs)
})
logs.update(
{self.coco_metric.name: (labels['groundtruths'], coco_model_outputs)})
if self.task_config.use_wod_metrics:
wod_model_outputs = {
'detection_boxes': outputs['detection_boxes'],
'detection_scores': outputs['detection_scores'],
'detection_classes': outputs['detection_classes'],
'num_detections': outputs['num_detections'],
'source_id': labels['groundtruths']['source_id'],
'image_info': labels['image_info']
}
logs.update(
{self.wod_metric.name: (labels['groundtruths'], wod_model_outputs)})
return logs
def aggregate_logs(self, state=None, step_outputs=None):
if self._task_config.use_coco_metrics:
if state is None:
self.coco_metric.reset_states()
state = self.coco_metric
self.coco_metric.update_state(
step_outputs[self.coco_metric.name][0],
step_outputs[self.coco_metric.name][1])
if self._task_config.use_wod_metrics:
if state is None:
self.wod_metric.reset_states()
self.wod_metric.update_state(
step_outputs[self.wod_metric.name][0],
step_outputs[self.wod_metric.name][1])
if state is None:
# Create an arbitrary state to indicate it's not the first step in the
# following calls to this function.
state = True
return state
def reduce_aggregated_logs(self, aggregated_logs, global_step=None):
return self.coco_metric.result()
logs = {}
if self._task_config.use_coco_metrics:
logs.update(self.coco_metric.result())
if self._task_config.use_wod_metrics:
logs.update(self.wod_metric.result())
return logs
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment