Commit 92700f62 authored by syiming's avatar syiming
Browse files

Adding kwargs for rpn feature extractor. Temp: force running keras model using tf1

parent 90106a40
...@@ -72,6 +72,7 @@ if tf_version.is_tf1(): ...@@ -72,6 +72,7 @@ if tf_version.is_tf1():
from object_detection.models import faster_rcnn_resnet_v1_feature_extractor as frcnn_resnet_v1 from object_detection.models import faster_rcnn_resnet_v1_feature_extractor as frcnn_resnet_v1
from object_detection.models import ssd_resnet_v1_fpn_feature_extractor as ssd_resnet_v1_fpn from object_detection.models import ssd_resnet_v1_fpn_feature_extractor as ssd_resnet_v1_fpn
from object_detection.models import ssd_resnet_v1_ppn_feature_extractor as ssd_resnet_v1_ppn from object_detection.models import ssd_resnet_v1_ppn_feature_extractor as ssd_resnet_v1_ppn
from object_detection.models import faster_rcnn_resnet_v1_fpn_keras_feature_extractor as frcnn_resnet_fpn_keras
from object_detection.models.embedded_ssd_mobilenet_v1_feature_extractor import EmbeddedSSDMobileNetV1FeatureExtractor from object_detection.models.embedded_ssd_mobilenet_v1_feature_extractor import EmbeddedSSDMobileNetV1FeatureExtractor
from object_detection.models.ssd_inception_v2_feature_extractor import SSDInceptionV2FeatureExtractor from object_detection.models.ssd_inception_v2_feature_extractor import SSDInceptionV2FeatureExtractor
from object_detection.models.ssd_mobilenet_v2_fpn_feature_extractor import SSDMobileNetV2FpnFeatureExtractor from object_detection.models.ssd_mobilenet_v2_fpn_feature_extractor import SSDMobileNetV2FpnFeatureExtractor
...@@ -229,9 +230,19 @@ if tf_version.is_tf1(): ...@@ -229,9 +230,19 @@ if tf_version.is_tf1():
frcnn_resnet_v1.FasterRCNNResnet152FeatureExtractor, frcnn_resnet_v1.FasterRCNNResnet152FeatureExtractor,
} }
FASTER_RCNN_KERAS_FEATURE_EXTRACTOR_CLASS_MAP = {
'faster_rcnn_resnet50_fpn_keras':
frcnn_resnet_fpn_keras.FasterRCNNResnet50FpnKerasFeatureExtractor,
'faster_rcnn_resnet101_fpn_keras':
frcnn_resnet_fpn_keras.FasterRCNNResnet101FpnKerasFeatureExtractor,
'faster_rcnn_resnet152_fpn_keras':
frcnn_resnet_fpn_keras.FasterRCNNResnet152FpnKerasFeatureExtractor,
}
FEATURE_EXTRACTOR_MAPS = [ FEATURE_EXTRACTOR_MAPS = [
SSD_FEATURE_EXTRACTOR_CLASS_MAP, SSD_FEATURE_EXTRACTOR_CLASS_MAP,
FASTER_RCNN_FEATURE_EXTRACTOR_CLASS_MAP FASTER_RCNN_FEATURE_EXTRACTOR_CLASS_MAP,
FASTER_RCNN_KERAS_FEATURE_EXTRACTOR_CLASS_MAP
] ]
...@@ -480,6 +491,7 @@ def _build_faster_rcnn_feature_extractor( ...@@ -480,6 +491,7 @@ def _build_faster_rcnn_feature_extractor(
first_stage_features_stride = ( first_stage_features_stride = (
feature_extractor_config.first_stage_features_stride) feature_extractor_config.first_stage_features_stride)
batch_norm_trainable = feature_extractor_config.batch_norm_trainable batch_norm_trainable = feature_extractor_config.batch_norm_trainable
print(feature_type)
if feature_type not in FASTER_RCNN_FEATURE_EXTRACTOR_CLASS_MAP: if feature_type not in FASTER_RCNN_FEATURE_EXTRACTOR_CLASS_MAP:
raise ValueError('Unknown Faster R-CNN feature_extractor: {}'.format( raise ValueError('Unknown Faster R-CNN feature_extractor: {}'.format(
...@@ -524,9 +536,31 @@ def _build_faster_rcnn_keras_feature_extractor( ...@@ -524,9 +536,31 @@ def _build_faster_rcnn_keras_feature_extractor(
feature_type)) feature_type))
feature_extractor_class = FASTER_RCNN_KERAS_FEATURE_EXTRACTOR_CLASS_MAP[ feature_extractor_class = FASTER_RCNN_KERAS_FEATURE_EXTRACTOR_CLASS_MAP[
feature_type] feature_type]
kwargs = {}
if feature_extractor_config.HasField('conv_hyperparams'):
kwargs.update({
'conv_hyperparams':
hyperparams_builder.KerasLayerHyperparams(
feature_extractor_config.conv_hyperparams),
'override_base_feature_extractor_hyperparams':
feature_extractor_config.override_base_feature_extractor_hyperparams
})
if feature_extractor_config.HasField('fpn'):
kwargs.update({
'fpn_min_level':
feature_extractor_config.fpn.min_level,
'fpn_max_level':
feature_extractor_config.fpn.max_level,
'additional_layer_depth':
feature_extractor_config.fpn.additional_layer_depth,
})
return feature_extractor_class( return feature_extractor_class(
is_training, first_stage_features_stride, is_training, first_stage_features_stride,
batch_norm_trainable) batch_norm_trainable, **kwargs)
def _build_faster_rcnn_model(frcnn_config, is_training, add_summaries): def _build_faster_rcnn_model(frcnn_config, is_training, add_summaries):
...@@ -553,7 +587,7 @@ def _build_faster_rcnn_model(frcnn_config, is_training, add_summaries): ...@@ -553,7 +587,7 @@ def _build_faster_rcnn_model(frcnn_config, is_training, add_summaries):
_check_feature_extractor_exists(frcnn_config.feature_extractor.type) _check_feature_extractor_exists(frcnn_config.feature_extractor.type)
is_keras = tf_version.is_tf2() is_keras = tf_version.is_tf2()
if is_keras: if is_keras or True:
feature_extractor = _build_faster_rcnn_keras_feature_extractor( feature_extractor = _build_faster_rcnn_keras_feature_extractor(
frcnn_config.feature_extractor, is_training, frcnn_config.feature_extractor, is_training,
inplace_batchnorm_update=frcnn_config.inplace_batchnorm_update) inplace_batchnorm_update=frcnn_config.inplace_batchnorm_update)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment