Commit 8a835dbb authored by A. Unique TensorFlower's avatar A. Unique TensorFlower Committed by TF Object Detection Team
Browse files

Modified the validate_tf_v2_checkpoint_restore_map function when the input is...

Modified the validate_tf_v2_checkpoint_restore_map function when the input is a recurred dictionary.

PiperOrigin-RevId: 327564315
parent ef76912d
...@@ -282,7 +282,8 @@ def validate_tf_v2_checkpoint_restore_map(checkpoint_restore_map): ...@@ -282,7 +282,8 @@ def validate_tf_v2_checkpoint_restore_map(checkpoint_restore_map):
"""Ensure that given dict is a valid TF v2 style restore map. """Ensure that given dict is a valid TF v2 style restore map.
Args: Args:
checkpoint_restore_map: A dict mapping strings to tf.keras.Model objects. checkpoint_restore_map: A nested dict mapping strings to
tf.keras.Model objects.
Raises: Raises:
ValueError: If they keys in checkpoint_restore_map are not strings or if ValueError: If they keys in checkpoint_restore_map are not strings or if
...@@ -294,8 +295,12 @@ def validate_tf_v2_checkpoint_restore_map(checkpoint_restore_map): ...@@ -294,8 +295,12 @@ def validate_tf_v2_checkpoint_restore_map(checkpoint_restore_map):
if not (isinstance(key, str) and if not (isinstance(key, str) and
(isinstance(value, tf.Module) (isinstance(value, tf.Module)
or isinstance(value, tf.train.Checkpoint))): or isinstance(value, tf.train.Checkpoint))):
raise TypeError(RESTORE_MAP_ERROR_TEMPLATE.format( if isinstance(key, str) and isinstance(value, dict):
key.__class__.__name__, value.__class__.__name__)) validate_tf_v2_checkpoint_restore_map(value)
else:
raise TypeError(
RESTORE_MAP_ERROR_TEMPLATE.format(key.__class__.__name__,
value.__class__.__name__))
def is_object_based_checkpoint(checkpoint_path): def is_object_based_checkpoint(checkpoint_path):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment