Commit 35492c3d authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 325713616
parent c3fe0550
......@@ -75,7 +75,11 @@ class Task(tf.Module):
if not ckpt_dir_or_file:
return
ckpt = tf.train.Checkpoint(**model.checkpoint_items)
if hasattr(model, "checkpoint_items"):
checkpoint_items = model.checkpoint_items
else:
checkpoint_items = dict(model=model)
ckpt = tf.train.Checkpoint(**checkpoint_items)
status = ckpt.read(ckpt_dir_or_file)
status.expect_partial().assert_existing_objects_matched()
logging.info("Finished loading pretrained checkpoint from %s",
......
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
......@@ -15,11 +14,6 @@
# ==============================================================================
"""Base configurations to standardize experiments."""
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
import copy
import functools
from typing import Any, List, Mapping, Optional, Type
......@@ -220,9 +214,12 @@ class Config(params_dict.ParamsDict):
}
def replace(self, **kwargs):
"""Like `override`, but returns a copy with the current config unchanged."""
params = self.__class__(self)
params.override(kwargs, is_strict=True)
"""Overrides/returns a unlocked copy with the current config unchanged."""
# pylint: disable=protected-access
params = copy.deepcopy(self)
params._locked = False
params._override(kwargs, is_strict=True)
# pylint: enable=protected-access
return params
@classmethod
......
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
......@@ -106,6 +105,22 @@ class BaseConfigTest(parameterized.TestCase, tf.test.TestCase):
self.assertEqual(config.g[0].a, 4)
self.assertEqual(config.g[0].b, 'new text 3')
def test_replace(self):
config = DumpConfig2()
new_config = config.replace(e={'a': 2})
self.assertEqual(new_config.e.a, 2)
self.assertIsInstance(new_config.e, DumpConfig1)
config = DumpConfig2(e=DumpConfig2())
new_config = config.replace(e={'c': 4})
self.assertEqual(new_config.e.c, 4)
self.assertIsInstance(new_config.e, DumpConfig2)
config = DumpConfig3()
new_config = config.replace(g=[{'a': 4, 'b': 'new text 3'}])
self.assertIsInstance(new_config.g[0], DumpConfig1)
self.assertEqual(new_config.g[0].a, 4)
@parameterized.parameters(
('_locked', "The key '_locked' is internally reserved."),
('_restrictions', "The key '_restrictions' is internally reserved."),
......@@ -294,6 +309,11 @@ class BaseConfigTest(parameterized.TestCase, tf.test.TestCase):
]),
"['s', 1, 1.0, True, None, {}, [], (), {8: 9, (2,): (3, [4], {6: 7})}]")
def test_with_restrictions(self):
restrictions = ['e.a<c']
config = DumpConfig2(restrictions=restrictions)
config.validate()
if __name__ == '__main__':
tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment