Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
980b27d5
Commit
980b27d5
authored
May 28, 2020
by
Abdullah Rashwan
Committed by
A. Unique TensorFlower
May 28, 2020
Browse files
Internal change
PiperOrigin-RevId: 313662797
parent
abf60128
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
129 additions
and
0 deletions
+129
-0
official/modeling/hyperparams/oneof.py
official/modeling/hyperparams/oneof.py
+62
-0
official/modeling/hyperparams/oneof_test.py
official/modeling/hyperparams/oneof_test.py
+67
-0
No files found.
official/modeling/hyperparams/oneof.py
0 → 100644
View file @
980b27d5
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Config class that supports oneof functionality."""
from
typing
import
Optional
import
dataclasses
from
official.modeling.hyperparams
import
base_config
@
dataclasses
.
dataclass
class
OneOfConfig
(
base_config
.
Config
):
"""Configuration for configs with one of feature.
Attributes:
type: 'str', name of the field to select.
"""
type
:
Optional
[
str
]
=
None
def
as_dict
(
self
):
"""Returns a dict representation of OneOfConfig.
For the nested base_config.Config, a nested dict will be returned.
"""
if
self
.
type
is
None
:
return
{
'type'
:
None
}
elif
self
.
__dict__
[
'type'
]
not
in
self
.
__dict__
:
raise
ValueError
(
'type: {!r} is not a valid key!'
.
format
(
self
.
__dict__
[
'type'
]))
else
:
chosen_type
=
self
.
type
chosen_value
=
self
.
__dict__
[
chosen_type
]
return
{
'type'
:
self
.
type
,
chosen_type
:
self
.
_export_config
(
chosen_value
)
}
def
get
(
self
):
"""Returns selected config based on the value of type.
If type is not set (None), None is returned.
"""
chosen_type
=
self
.
type
if
chosen_type
is
None
:
return
None
if
chosen_type
not
in
self
.
__dict__
:
raise
ValueError
(
'type: {!r} is not a valid key!'
.
format
(
self
.
type
))
return
self
.
__dict__
[
chosen_type
]
official/modeling/hyperparams/oneof_test.py
0 → 100644
View file @
980b27d5
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import
dataclasses
import
tensorflow
as
tf
from
official.modeling.hyperparams
import
base_config
from
official.modeling.hyperparams
import
oneof
@
dataclasses
.
dataclass
class
ResNet
(
base_config
.
Config
):
model_depth
:
int
=
50
@
dataclasses
.
dataclass
class
Backbone
(
oneof
.
OneOfConfig
):
type
:
str
=
'resnet'
resnet
:
ResNet
=
ResNet
()
not_resnet
:
int
=
2
@
dataclasses
.
dataclass
class
OutputLayer
(
oneof
.
OneOfConfig
):
type
:
str
=
'single'
single
:
int
=
1
multi_head
:
int
=
2
@
dataclasses
.
dataclass
class
Network
(
base_config
.
Config
):
backbone
:
Backbone
=
Backbone
()
output_layer
:
OutputLayer
=
OutputLayer
()
class
OneOfTest
(
tf
.
test
.
TestCase
):
def
test_to_dict
(
self
):
network_params
=
{
'backbone'
:
{
'type'
:
'resnet'
,
'resnet'
:
{
'model_depth'
:
50
}
},
'output_layer'
:
{
'type'
:
'single'
,
'single'
:
1000
}
}
network_config
=
Network
(
network_params
)
self
.
assertEqual
(
network_config
.
as_dict
(),
network_params
)
def
test_get_oneof
(
self
):
backbone
=
Backbone
()
self
.
assertIsInstance
(
backbone
.
get
(),
ResNet
)
self
.
assertEqual
(
backbone
.
get
().
as_dict
(),
{
'model_depth'
:
50
})
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment