"git@developer.sourcefind.cn:change/sglang.git" did not exist on "39efad4fbc4a708deb87fd421d91a8eb37c59246"
Commit b53e5dc0 authored by Yu-hui Chen's avatar Yu-hui Chen Committed by TF Object Detection Team
Browse files

Updated the config_util to support the hyperparameter search

over new parameters in the CenterNet model.

PiperOrigin-RevId: 381331623
parent 4e886c07
...@@ -639,6 +639,12 @@ def _maybe_update_config_with_key_value(configs, key, value): ...@@ -639,6 +639,12 @@ def _maybe_update_config_with_key_value(configs, key, value):
_update_rescore_instances(configs["model"], value) _update_rescore_instances(configs["model"], value)
elif field_name == "unmatched_keypoint_score": elif field_name == "unmatched_keypoint_score":
_update_unmatched_keypoint_score(configs["model"], value) _update_unmatched_keypoint_score(configs["model"], value)
elif field_name == "score_distance_multiplier":
_update_score_distance_multiplier(configs["model"], value)
elif field_name == "std_dev_multiplier":
_update_std_dev_multiplier(configs["model"], value)
elif field_name == "rescoring_threshold":
_update_rescoring_threshold(configs["model"], value)
else: else:
return False return False
return True return True
...@@ -1135,10 +1141,12 @@ def _update_candidate_search_scale(model_config, search_scale): ...@@ -1135,10 +1141,12 @@ def _update_candidate_search_scale(model_config, search_scale):
def _update_candidate_ranking_mode(model_config, mode): def _update_candidate_ranking_mode(model_config, mode):
"""Updates how keypoints are snapped to candidates in CenterNet.""" """Updates how keypoints are snapped to candidates in CenterNet."""
if mode not in ("min_distance", "score_distance_ratio"): if mode not in ("min_distance", "score_distance_ratio",
"score_scaled_distance_ratio", "gaussian_weighted"):
raise ValueError("Attempting to set the keypoint candidate ranking mode " raise ValueError("Attempting to set the keypoint candidate ranking mode "
"to {}, but the only options are 'min_distance' and " "to {}, but the only options are 'min_distance', "
"'score_distance_ratio'.".format(mode)) "'score_distance_ratio', 'score_scaled_distance_ratio', "
"'gaussian_weighted'.".format(mode))
meta_architecture = model_config.WhichOneof("model") meta_architecture = model_config.WhichOneof("model")
if meta_architecture == "center_net": if meta_architecture == "center_net":
if len(model_config.center_net.keypoint_estimation_task) == 1: if len(model_config.center_net.keypoint_estimation_task) == 1:
...@@ -1214,3 +1222,50 @@ def _update_unmatched_keypoint_score(model_config, score): ...@@ -1214,3 +1222,50 @@ def _update_unmatched_keypoint_score(model_config, score):
"unmatched_keypoint_score since there are multiple " "unmatched_keypoint_score since there are multiple "
"keypoint estimation tasks") "keypoint estimation tasks")
def _update_score_distance_multiplier(model_config, score_distance_multiplier):
"""Updates the keypoint candidate selection metric. See CenterNet proto."""
meta_architecture = model_config.WhichOneof("model")
if meta_architecture == "center_net":
if len(model_config.center_net.keypoint_estimation_task) == 1:
kpt_estimation_task = model_config.center_net.keypoint_estimation_task[0]
kpt_estimation_task.score_distance_multiplier = score_distance_multiplier
else:
tf.logging.warning("Ignoring config override key for "
"score_distance_multiplier since there are multiple "
"keypoint estimation tasks")
else:
raise ValueError(
"Unsupported meta_architecture type: %s" % meta_architecture)
def _update_std_dev_multiplier(model_config, std_dev_multiplier):
"""Updates the keypoint candidate selection metric. See CenterNet proto."""
meta_architecture = model_config.WhichOneof("model")
if meta_architecture == "center_net":
if len(model_config.center_net.keypoint_estimation_task) == 1:
kpt_estimation_task = model_config.center_net.keypoint_estimation_task[0]
kpt_estimation_task.std_dev_multiplier = std_dev_multiplier
else:
tf.logging.warning("Ignoring config override key for "
"std_dev_multiplier since there are multiple "
"keypoint estimation tasks")
else:
raise ValueError(
"Unsupported meta_architecture type: %s" % meta_architecture)
def _update_rescoring_threshold(model_config, rescoring_threshold):
"""Updates the keypoint candidate selection metric. See CenterNet proto."""
meta_architecture = model_config.WhichOneof("model")
if meta_architecture == "center_net":
if len(model_config.center_net.keypoint_estimation_task) == 1:
kpt_estimation_task = model_config.center_net.keypoint_estimation_task[0]
kpt_estimation_task.rescoring_threshold = rescoring_threshold
else:
tf.logging.warning("Ignoring config override key for "
"rescoring_threshold since there are multiple "
"keypoint estimation tasks")
else:
raise ValueError(
"Unsupported meta_architecture type: %s" % meta_architecture)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment