Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
9c1887a8
"vscode:/vscode.git/clone" did not exist on "4b4a90f233ff807994f8de78b1f9b1687b6328a4"
Commit
9c1887a8
authored
Apr 21, 2020
by
Pengchong Jin
Committed by
A. Unique TensorFlower
Apr 21, 2020
Browse files
ParamsDict update.
PiperOrigin-RevId: 307639416
parent
5e539a3d
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
37 additions
and
3 deletions
+37
-3
official/modeling/hyperparams/params_dict.py
official/modeling/hyperparams/params_dict.py
+19
-1
official/modeling/hyperparams/params_dict_test.py
official/modeling/hyperparams/params_dict_test.py
+18
-2
No files found.
official/modeling/hyperparams/params_dict.py
View file @
9c1887a8
...
@@ -125,6 +125,25 @@ class ParamsDict(object):
...
@@ -125,6 +125,25 @@ class ParamsDict(object):
"""Accesses through built-in dictionary get method."""
"""Accesses through built-in dictionary get method."""
return
self
.
__dict__
.
get
(
key
,
value
)
return
self
.
__dict__
.
get
(
key
,
value
)
def
__delattr__
(
self
,
k
):
"""Deletes the key and removes its values.
Args:
k: the key string.
Raises:
AttributeError: if k is reserverd or not defined in the ParamsDict.
ValueError: if the ParamsDict instance has been locked.
"""
if
k
in
ParamsDict
.
RESERVED_ATTR
:
raise
AttributeError
(
'The key `{}` is reserved. No change is allowes. '
.
format
(
k
))
if
k
not
in
self
.
__dict__
.
keys
():
raise
AttributeError
(
'The key `{}` does not exist. '
.
format
(
k
))
if
self
.
_locked
:
raise
ValueError
(
'The ParamsDict has been locked. No change is allowed.'
)
del
self
.
__dict__
[
k
]
def
override
(
self
,
override_params
,
is_strict
=
True
):
def
override
(
self
,
override_params
,
is_strict
=
True
):
"""Override the ParamsDict with a set of given params.
"""Override the ParamsDict with a set of given params.
...
@@ -286,7 +305,6 @@ def read_yaml_to_params_dict(file_path):
...
@@ -286,7 +305,6 @@ def read_yaml_to_params_dict(file_path):
def
save_params_dict_to_yaml
(
params
,
file_path
):
def
save_params_dict_to_yaml
(
params
,
file_path
):
"""Saves the input ParamsDict to a YAML file."""
"""Saves the input ParamsDict to a YAML file."""
with
tf
.
io
.
gfile
.
GFile
(
file_path
,
'w'
)
as
f
:
with
tf
.
io
.
gfile
.
GFile
(
file_path
,
'w'
)
as
f
:
def
_my_list_rep
(
dumper
,
data
):
def
_my_list_rep
(
dumper
,
data
):
# u'tag:yaml.org,2002:seq' is the YAML internal tag for sequence.
# u'tag:yaml.org,2002:seq' is the YAML internal tag for sequence.
return
dumper
.
represent_sequence
(
return
dumper
.
represent_sequence
(
...
...
official/modeling/hyperparams/params_dict_test.py
View file @
9c1887a8
...
@@ -13,7 +13,7 @@
...
@@ -13,7 +13,7 @@
# limitations under the License.
# limitations under the License.
# ==============================================================================
# ==============================================================================
"""Tests for
official.modeling.hyperparams.
params_dict.py."""
"""Tests for params_dict.py."""
import
os
import
os
...
@@ -45,12 +45,14 @@ class ParamsDictTest(tf.test.TestCase):
...
@@ -45,12 +45,14 @@ class ParamsDictTest(tf.test.TestCase):
self
.
assertEqual
(
params
.
b
,
2
)
self
.
assertEqual
(
params
.
b
,
2
)
def
test_lock
(
self
):
def
test_lock
(
self
):
params
=
params_dict
.
ParamsDict
({
'a'
:
1
,
'b'
:
2
})
params
=
params_dict
.
ParamsDict
({
'a'
:
1
,
'b'
:
2
,
'c'
:
3
})
params
.
lock
()
params
.
lock
()
with
self
.
assertRaises
(
ValueError
):
with
self
.
assertRaises
(
ValueError
):
params
.
a
=
10
params
.
a
=
10
with
self
.
assertRaises
(
ValueError
):
with
self
.
assertRaises
(
ValueError
):
params
.
override
({
'b'
:
20
})
params
.
override
({
'b'
:
20
})
with
self
.
assertRaises
(
ValueError
):
del
params
.
c
def
test_setattr
(
self
):
def
test_setattr
(
self
):
params
=
params_dict
.
ParamsDict
()
params
=
params_dict
.
ParamsDict
()
...
@@ -69,6 +71,20 @@ class ParamsDictTest(tf.test.TestCase):
...
@@ -69,6 +71,20 @@ class ParamsDictTest(tf.test.TestCase):
self
.
assertEqual
(
params
.
b
,
2
)
self
.
assertEqual
(
params
.
b
,
2
)
self
.
assertEqual
(
params
.
c
,
None
)
self
.
assertEqual
(
params
.
c
,
None
)
def
test_delattr
(
self
):
params
=
params_dict
.
ParamsDict
()
params
.
override
(
{
'a'
:
'aa'
,
'b'
:
2
,
'c'
:
None
,
'd'
:
{
'd1'
:
1
,
'd2'
:
10
}},
is_strict
=
False
)
del
params
.
c
self
.
assertEqual
(
params
.
a
,
'aa'
)
self
.
assertEqual
(
params
.
b
,
2
)
with
self
.
assertRaises
(
AttributeError
):
_
=
params
.
c
del
params
.
d
with
self
.
assertRaises
(
AttributeError
):
_
=
params
.
d
.
d1
def
test_contains
(
self
):
def
test_contains
(
self
):
params
=
params_dict
.
ParamsDict
()
params
=
params_dict
.
ParamsDict
()
params
.
override
(
params
.
override
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment