"data/vscode:/vscode.git/clone" did not exist on "ceb3cee72d2b509d8bb455531fc3967a32ffa447"
Commit 51078b5d authored by Fan Yang's avatar Fan Yang Committed by A. Unique TensorFlower
Browse files

Add support for RetinaNet QAT to TFLite model. Also update config files to...

Add support for RetinaNet QAT to TFLite model. Also update config files to reflect the latest numbers.

PiperOrigin-RevId: 470861612
parent 4fe13a8b
...@@ -15,4 +15,5 @@ ...@@ -15,4 +15,5 @@
"""Configs package definition.""" """Configs package definition."""
from official.projects.qat.vision.configs import image_classification from official.projects.qat.vision.configs import image_classification
from official.projects.qat.vision.configs import retinanet
from official.projects.qat.vision.configs import semantic_segmentation from official.projects.qat.vision.configs import semantic_segmentation
# --experiment_type=retinanet_spinenet_mobile_coco_qat # --experiment_type=retinanet_mobile_coco_qat
# COCO mAP: 21.62 (Evaluated on the TFLite after conversion.) # COCO mAP: 23.02 from QAT training and 21.62 from the TFLite after conversion.
# QAT only supports float32 tpu due to fake-quant op. # QAT only supports float32 tpu due to fake-quant op.
runtime: runtime:
distribution_strategy: 'tpu' distribution_strategy: 'tpu'
mixed_precision_dtype: 'float32' mixed_precision_dtype: 'float32'
task: task:
losses: losses:
l2_weight_decay: 3.0e-04 l2_weight_decay: 0.0
model: model:
anchor: anchor:
anchor_size: 3 anchor_size: 3
...@@ -45,24 +45,28 @@ task: ...@@ -45,24 +45,28 @@ task:
aug_scale_min: 0.5 aug_scale_min: 0.5
validation_data: validation_data:
dtype: 'float32' dtype: 'float32'
global_batch_size: 64 global_batch_size: 256
is_training: false is_training: false
drop_remainder: false
quantization: quantization:
pretrained_original_checkpoint: 'gs://**/tf2_mobilenetv2_ssd_jul29/28129005/ckpt-277200' pretrained_original_checkpoint: 'gs://**/coco_mobilenetv2_mobile_tpu/ckpt-277200'
quantize_detection_decoder: true quantize_detection_decoder: true
quantize_detection_head: true quantize_detection_head: true
trainer: trainer:
best_checkpoint_eval_metric: AP
best_checkpoint_export_subdir: best_ckpt
best_checkpoint_metric_comp: higher
optimizer_config: optimizer_config:
learning_rate: learning_rate:
type: 'exponential' type: 'exponential'
exponential: exponential:
decay_rate: 0.96 decay_rate: 0.96
decay_steps: 231 decay_steps: 231
initial_learning_rate: 0.2 initial_learning_rate: 0.5
name: 'ExponentialDecay' name: 'ExponentialDecay'
offset: 0 offset: 0
staircase: true staircase: true
steps_per_loop: 462 steps_per_loop: 462
train_steps: 46200 train_steps: 46200
validation_interval: 462 validation_interval: 462
validation_steps: 625 validation_steps: 20
# --experiment_type=retinanet_spinenet_mobile_coco_qat # --experiment_type=retinanet_mobile_coco_qat
runtime: runtime:
distribution_strategy: 'mirrored' distribution_strategy: 'mirrored'
mixed_precision_dtype: 'float32' mixed_precision_dtype: 'float32'
......
# --experiment_type=retinanet_spinenet_mobile_coco_qat # --experiment_type=retinanet_mobile_coco_qat
# COCO mAP: 24.7 # COCO mAP: 24.7
# QAT only supports float32 tpu due to fake-quant op. # QAT only supports float32 tpu due to fake-quant op.
runtime: runtime:
......
# --experiment_type=retinanet_spinenet_mobile_coco_qat # --experiment_type=retinanet_mobile_coco_qat
# COCO mAP: 23.2 # COCO mAP: 23.2
# QAT only supports float32 tpu due to fake-quant op. # QAT only supports float32 tpu due to fake-quant op.
runtime: runtime:
......
...@@ -28,8 +28,8 @@ class RetinaNetTask(retinanet.RetinaNetTask): ...@@ -28,8 +28,8 @@ class RetinaNetTask(retinanet.RetinaNetTask):
quantization: Optional[common.Quantization] = None quantization: Optional[common.Quantization] = None
@exp_factory.register_config_factory('retinanet_spinenet_mobile_coco_qat') @exp_factory.register_config_factory('retinanet_mobile_coco_qat')
def retinanet_spinenet_mobile_coco() -> cfg.ExperimentConfig: def retinanet_mobile_coco() -> cfg.ExperimentConfig:
"""Generates a config for COCO OD RetinaNet for mobile with QAT.""" """Generates a config for COCO OD RetinaNet for mobile with QAT."""
config = retinanet.retinanet_spinenet_mobile_coco() config = retinanet.retinanet_spinenet_mobile_coco()
task = RetinaNetTask.from_args( task = RetinaNetTask.from_args(
......
...@@ -28,7 +28,7 @@ from official.vision.configs import retinanet as exp_cfg ...@@ -28,7 +28,7 @@ from official.vision.configs import retinanet as exp_cfg
class RetinaNetConfigTest(tf.test.TestCase, parameterized.TestCase): class RetinaNetConfigTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.parameters( @parameterized.parameters(
('retinanet_spinenet_mobile_coco_qat',), ('retinanet_mobile_coco_qat',),
) )
def test_retinanet_configs(self, config_name): def test_retinanet_configs(self, config_name):
config = exp_factory.get_exp_config(config_name) config = exp_factory.get_exp_config(config_name)
......
...@@ -13,10 +13,12 @@ ...@@ -13,10 +13,12 @@
# limitations under the License. # limitations under the License.
"""Export modules for QAT model serving/inference.""" """Export modules for QAT model serving/inference."""
from absl import logging
import tensorflow as tf import tensorflow as tf
from official.projects.qat.vision.modeling import factory as qat_factory from official.projects.qat.vision.modeling import factory as qat_factory
from official.vision import configs
from official.vision.serving import detection
from official.vision.serving import image_classification from official.vision.serving import image_classification
from official.vision.serving import semantic_segmentation from official.vision.serving import semantic_segmentation
...@@ -42,3 +44,25 @@ class SegmentationModule(semantic_segmentation.SegmentationModule): ...@@ -42,3 +44,25 @@ class SegmentationModule(semantic_segmentation.SegmentationModule):
self._input_image_size + [3]) self._input_image_size + [3])
return qat_factory.build_qat_segmentation_model( return qat_factory.build_qat_segmentation_model(
model, self.params.task.quantization, input_specs) model, self.params.task.quantization, input_specs)
class DetectionModule(detection.DetectionModule):
"""Detection Module."""
def _build_model(self):
if self.params.task.model.detection_generator.nms_version != 'tflite':
self.params.task.model.detection_generator.nms_version = 'tflite'
logging.info('Set `nms_version` to `tflite` because only TFLite NMS is '
'supported for QAT detection models.')
model = super()._build_model()
if isinstance(self.params.task.model, configs.retinanet.RetinaNet):
model = qat_factory.build_qat_retinanet(model,
self.params.task.quantization,
self.params.task.model)
else:
raise ValueError('Detection module not implemented for {} model.'.format(
type(self.params.task.model)))
return model
...@@ -34,7 +34,6 @@ imported = tf.saved_model.load(export_dir_path) ...@@ -34,7 +34,6 @@ imported = tf.saved_model.load(export_dir_path)
model_fn = imported.signatures['serving_default'] model_fn = imported.signatures['serving_default']
output = model_fn(input_images) output = model_fn(input_images)
""" """
from absl import app from absl import app
from absl import flags from absl import flags
...@@ -106,6 +105,8 @@ def main(_): ...@@ -106,6 +105,8 @@ def main(_):
if isinstance(params.task, if isinstance(params.task,
configs.image_classification.ImageClassificationTask): configs.image_classification.ImageClassificationTask):
export_module_cls = export_module.ClassificationModule export_module_cls = export_module.ClassificationModule
elif isinstance(params.task, configs.retinanet.RetinaNetTask):
export_module_cls = export_module.DetectionModule
elif isinstance(params.task, elif isinstance(params.task,
configs.semantic_segmentation.SemanticSegmentationTask): configs.semantic_segmentation.SemanticSegmentationTask):
export_module_cls = export_module.SegmentationModule export_module_cls = export_module.SegmentationModule
......
...@@ -36,8 +36,8 @@ class RetinaNetTaskTest(parameterized.TestCase, tf.test.TestCase): ...@@ -36,8 +36,8 @@ class RetinaNetTaskTest(parameterized.TestCase, tf.test.TestCase):
record_file=tfrecord_file, tf_examples=examples) record_file=tfrecord_file, tf_examples=examples)
@parameterized.parameters( @parameterized.parameters(
('retinanet_spinenet_mobile_coco_qat', True), ('retinanet_mobile_coco_qat', True),
('retinanet_spinenet_mobile_coco_qat', False), ('retinanet_mobile_coco_qat', False),
) )
def test_retinanet_task(self, test_config, is_training): def test_retinanet_task(self, test_config, is_training):
"""RetinaNet task test for training and val using toy configs.""" """RetinaNet task test for training and val using toy configs."""
......
...@@ -21,6 +21,7 @@ task: ...@@ -21,6 +21,7 @@ task:
fpn: fpn:
num_filters: 128 num_filters: 128
use_separable_conv: true use_separable_conv: true
use_keras_layer: true
head: head:
num_convs: 4 num_convs: 4
num_filters: 128 num_filters: 128
...@@ -43,8 +44,9 @@ task: ...@@ -43,8 +44,9 @@ task:
aug_scale_min: 0.5 aug_scale_min: 0.5
validation_data: validation_data:
dtype: 'bfloat16' dtype: 'bfloat16'
global_batch_size: 8 global_batch_size: 256
is_training: false is_training: false
drop_remainder: false
trainer: trainer:
optimizer_config: optimizer_config:
learning_rate: learning_rate:
...@@ -59,4 +61,4 @@ trainer: ...@@ -59,4 +61,4 @@ trainer:
steps_per_loop: 462 steps_per_loop: 462
train_steps: 277200 train_steps: 277200
validation_interval: 462 validation_interval: 462
validation_steps: 625 validation_steps: 20
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment