"tutorials/vscode:/vscode.git/clone" did not exist on "59ec7a9c938ce59a36877a8396851b65d510fa4c"
Commit 198af9e8 authored by Menglong Zhu's avatar Menglong Zhu Committed by dreamdragon
Browse files

Refactor ssd_meta_arch so that the target assigner instance is passed into the...

Refactor ssd_meta_arch so that the target assigner instance is passed into the SSDMetaArch constructor rather than constructed inside.

PiperOrigin-RevId: 207787053
parent 465ebbd4
...@@ -41,10 +41,7 @@ class LSTMMetaArch(ssd_meta_arch.SSDMetaArch): ...@@ -41,10 +41,7 @@ class LSTMMetaArch(ssd_meta_arch.SSDMetaArch):
box_predictor, box_predictor,
box_coder, box_coder,
feature_extractor, feature_extractor,
matcher,
region_similarity_calculator,
encode_background_as_zeros, encode_background_as_zeros,
negative_class_weight,
image_resizer_fn, image_resizer_fn,
non_max_suppression_fn, non_max_suppression_fn,
score_conversion_fn, score_conversion_fn,
...@@ -55,14 +52,26 @@ class LSTMMetaArch(ssd_meta_arch.SSDMetaArch): ...@@ -55,14 +52,26 @@ class LSTMMetaArch(ssd_meta_arch.SSDMetaArch):
normalize_loss_by_num_matches, normalize_loss_by_num_matches,
hard_example_miner, hard_example_miner,
unroll_length, unroll_length,
target_assigner_instance,
add_summaries=True): add_summaries=True):
super(LSTMMetaArch, self).__init__( super(LSTMMetaArch, self).__init__(
is_training, anchor_generator, box_predictor, box_coder, is_training=is_training,
feature_extractor, matcher, region_similarity_calculator, anchor_generator=anchor_generator,
encode_background_as_zeros, negative_class_weight, image_resizer_fn, box_predictor=box_predictor,
non_max_suppression_fn, score_conversion_fn, classification_loss, box_coder=box_coder,
localization_loss, classification_loss_weight, localization_loss_weight, feature_extractor=feature_extractor,
normalize_loss_by_num_matches, hard_example_miner, add_summaries) encode_background_as_zeros=encode_background_as_zeros,
image_resizer_fn=image_resizer_fn,
non_max_suppression_fn=non_max_suppression_fn,
score_conversion_fn=score_conversion_fn,
classification_loss=classification_loss,
localization_loss=localization_loss,
classification_loss_weight=classification_loss_weight,
localization_loss_weight=localization_loss_weight,
normalize_loss_by_num_matches=normalize_loss_by_num_matches,
hard_example_miner=hard_example_miner,
target_assigner_instance=target_assigner_instance,
add_summaries=add_summaries)
self._unroll_length = unroll_length self._unroll_length = unroll_length
@property @property
......
...@@ -26,6 +26,7 @@ from google3.third_party.tensorflow_models.object_detection.builders import matc ...@@ -26,6 +26,7 @@ from google3.third_party.tensorflow_models.object_detection.builders import matc
from google3.third_party.tensorflow_models.object_detection.builders import model_builder from google3.third_party.tensorflow_models.object_detection.builders import model_builder
from google3.third_party.tensorflow_models.object_detection.builders import post_processing_builder from google3.third_party.tensorflow_models.object_detection.builders import post_processing_builder
from google3.third_party.tensorflow_models.object_detection.builders import region_similarity_calculator_builder as sim_calc from google3.third_party.tensorflow_models.object_detection.builders import region_similarity_calculator_builder as sim_calc
from google3.third_party.tensorflow_models.object_detection.core import target_assigner
model_builder.SSD_FEATURE_EXTRACTOR_CLASS_MAP.update({ model_builder.SSD_FEATURE_EXTRACTOR_CLASS_MAP.update({
'lstm_mobilenet_v1': LSTMMobileNetV1FeatureExtractor, 'lstm_mobilenet_v1': LSTMMobileNetV1FeatureExtractor,
...@@ -140,12 +141,29 @@ def _build_lstm_model(ssd_config, lstm_config, is_training): ...@@ -140,12 +141,29 @@ def _build_lstm_model(ssd_config, lstm_config, is_training):
if unroll_length is None: if unroll_length is None:
raise ValueError('No unroll length found in the config file') raise ValueError('No unroll length found in the config file')
target_assigner_instance = target_assigner.TargetAssigner(
region_similarity_calculator,
matcher,
box_coder,
negative_class_weight=negative_class_weight)
lstm_model = lstm_meta_arch.LSTMMetaArch( lstm_model = lstm_meta_arch.LSTMMetaArch(
is_training, anchor_generator, ssd_box_predictor, box_coder, is_training=is_training,
feature_extractor, matcher, region_similarity_calculator, anchor_generator=anchor_generator,
encode_background_as_zeros, negative_class_weight, image_resizer_fn, box_predictor=ssd_box_predictor,
non_max_suppression_fn, score_conversion_fn, classification_loss, box_coder=box_coder,
localization_loss, classification_weight, localization_weight, feature_extractor=feature_extractor,
normalize_loss_by_num_matches, miner, unroll_length) encode_background_as_zeros=encode_background_as_zeros,
image_resizer_fn=image_resizer_fn,
non_max_suppression_fn=non_max_suppression_fn,
score_conversion_fn=score_conversion_fn,
classification_loss=classification_loss,
localization_loss=localization_loss,
classification_loss_weight=classification_weight,
localization_loss_weight=localization_weight,
normalize_loss_by_num_matches=normalize_loss_by_num_matches,
hard_example_miner=miner,
unroll_length=unroll_length,
target_assigner_instance=target_assigner_instance)
return lstm_model return lstm_model
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment