Commit e649274e authored by syiming's avatar syiming
Browse files

add seperated class for resnet 50 101 152

parent 33a4c064
...@@ -40,12 +40,12 @@ class FasterRCNNResnetV1FPNKerasFeatureExtractor( ...@@ -40,12 +40,12 @@ class FasterRCNNResnetV1FPNKerasFeatureExtractor(
def __init__(self, def __init__(self,
is_training, is_training,
resnet_v1_base_model,
resnet_v1_base_model_name,
first_stage_features_stride, first_stage_features_stride,
conv_hyperparams, conv_hyperparams,
min_depth, min_depth,
depth_multiplier, depth_multiplier,
resnet_v1_base_model,
resnet_v1_base_model_name,
batch_norm_trainable=False, batch_norm_trainable=False,
weight_decay=0.0, weight_decay=0.0,
fpn_min_level=3, fpn_min_level=3,
...@@ -289,3 +289,143 @@ class FasterRCNNResnetV1FPNKerasFeatureExtractor( ...@@ -289,3 +289,143 @@ class FasterRCNNResnetV1FPNKerasFeatureExtractor(
# return feature_maps # return feature_maps
def get_box_classifier_feature_extractor_model(self, name=None): def get_box_classifier_feature_extractor_model(self, name=None):
class FasterRCNNResnet50FPNKerasFeatureExtractor(
FasterRCNNResnetV1FPNKerasFeatureExtractor):
"""Faster RCNN with Resnet50 FPN feature extractor implementation."""
def __init__(self,
is_training,
first_stage_features_stride=16,
conv_hyperparams=None,
min_depth=16,
depth_multiplier=1,
batch_norm_trainable=False,
weight_decay=0.0,
fpn_min_level=3,
fpn_max_level=7,
additional_layer_depth=256,
override_base_feature_extractor_hyperparams=False):
"""Constructor.
Args:
is_training: See base class.
first_stage_features_stride: See base class.
conv_hyperparams: See base class.
min_depth: See base class.
depth_multiplier: See base class.
batch_norm_trainable: See base class.
weight_decay: See base class.
fpn_min_level: See base class.
fpn_max_level: See base class.
additional_layer_depth: See base class.
override_base_feature_extractor_hyperparams: See base class.
"""
super(FasterRCNNResnet50KerasFeatureExtractor, self).__init__(
is_training=is_training,
first_stage_features_stride=first_stage_features_stride,
conv_hyperparams=conv_hyperparameters,
min_depth=min_depth,
depth_multiplier=depth_multiplier,
resnet_v1_base_model=resnet_v1.resnet_v1_50,
resnet_v1_base_model_name='resnet_v1_50',
batch_norm_trainable=batch_norm_trainable,
weight_decay=weight_decay,
fpn_min_level=fpn_min_level,
fpn_max_level=fpn_max_level,
additional_layer_depth=additional_layer_depth,
override_base_feature_extractor_hyperparams=override_base_feature_extractor_hyperparams)
class FasterRCNNResnet101FPNKerasFeatureExtractor(
FasterRCNNResnetV1FPNKerasFeatureExtractor):
"""Faster RCNN with Resnet101 FPN feature extractor implementation."""
def __init__(self,
is_training,
first_stage_features_stride=16,
conv_hyperparams=None,
min_depth=16,
depth_multiplier=1,
batch_norm_trainable=False,
weight_decay=0.0,
fpn_min_level=3,
fpn_max_level=7,
additional_layer_depth=256,
override_base_feature_extractor_hyperparams=False):
"""Constructor.
Args:
is_training: See base class.
first_stage_features_stride: See base class.
conv_hyperparams: See base class.
min_depth: See base class.
depth_multiplier: See base class.
batch_norm_trainable: See base class.
weight_decay: See base class.
fpn_min_level: See base class.
fpn_max_level: See base class.
additional_layer_depth: See base class.
override_base_feature_extractor_hyperparams: See base class.
"""
super(FasterRCNNResnet50KerasFeatureExtractor, self).__init__(
is_training=is_training,
first_stage_features_stride=first_stage_features_stride,
conv_hyperparams=conv_hyperparameters,
min_depth=min_depth,
depth_multiplier=depth_multiplier,
resnet_v1_base_model=resnet_v1.resnet_v1_101,
resnet_v1_base_model_name='resnet_v1_101',
batch_norm_trainable=batch_norm_trainable,
weight_decay=weight_decay,
fpn_min_level=fpn_min_level,
fpn_max_level=fpn_max_level,
additional_layer_depth=additional_layer_depth,
override_base_feature_extractor_hyperparams=override_base_feature_extractor_hyperparams)
class FasterRCNNResnet152FPNKerasFeatureExtractor(
FasterRCNNResnetV1FPNKerasFeatureExtractor):
"""Faster RCNN with Resnet152 FPN feature extractor implementation."""
def __init__(self,
is_training,
first_stage_features_stride=16,
conv_hyperparams=None,
min_depth=16,
depth_multiplier=1,
batch_norm_trainable=False,
weight_decay=0.0,
fpn_min_level=3,
fpn_max_level=7,
additional_layer_depth=256,
override_base_feature_extractor_hyperparams=False):
"""Constructor.
Args:
is_training: See base class.
first_stage_features_stride: See base class.
conv_hyperparams: See base class.
min_depth: See base class.
depth_multiplier: See base class.
batch_norm_trainable: See base class.
weight_decay: See base class.
fpn_min_level: See base class.
fpn_max_level: See base class.
additional_layer_depth: See base class.
override_base_feature_extractor_hyperparams: See base class.
"""
super(FasterRCNNResnet50KerasFeatureExtractor, self).__init__(
is_training=is_training,
first_stage_features_stride=first_stage_features_stride,
conv_hyperparams=conv_hyperparameters,
min_depth=min_depth,
depth_multiplier=depth_multiplier,
resnet_v1_base_model=resnet_v1.resnet_v1_152,
resnet_v1_base_model_name='resnet_v1_152',
batch_norm_trainable=batch_norm_trainable,
weight_decay=weight_decay,
fpn_min_level=fpn_min_level,
fpn_max_level=fpn_max_level,
additional_layer_depth=additional_layer_depth,
override_base_feature_extractor_hyperparams=override_base_feature_extractor_hyperparams)
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment