Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
ee3cc115
Commit
ee3cc115
authored
Jun 15, 2020
by
A. Unique TensorFlower
Browse files
Internal change.
PiperOrigin-RevId: 316409253
parent
b3ef7ae9
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
38 additions
and
5 deletions
+38
-5
official/core/base_task.py
official/core/base_task.py
+38
-5
No files found.
official/core/base_task.py
View file @
ee3cc115
...
@@ -250,11 +250,44 @@ _REGISTERED_TASK_CLS = {}
...
@@ -250,11 +250,44 @@ _REGISTERED_TASK_CLS = {}
# TODO(b/158268740): Move these outside the base class file.
# TODO(b/158268740): Move these outside the base class file.
def
register_task_cls
(
task_config
:
cfg
.
TaskConfig
)
->
Task
:
# TODO(b/158741360): Add type annotations once pytype checks across modules.
"""R
egister
ExperimentConfig factory method."""
def
r
egister
_task_cls
(
task_config_cls
):
return
registry
.
register
(
_REGISTERED_TASK_CLS
,
t
ask
_c
onfig
)
"""Decorates a factory of Tasks for lookup by a subclass of T
ask
C
onfig
.
This decorator supports registration of tasks as follows:
def
get_task_cls
(
task_config
:
cfg
.
TaskConfig
)
->
Task
:
```
task_cls
=
registry
.
lookup
(
_REGISTERED_TASK_CLS
,
task_config
)
@dataclasses.dataclass
class MyTaskConfig(TaskConfig):
# Add fields here.
pass
@register_task_cls(MyTaskConfig)
class MyTask(Task):
# Inherits def __init__(self, task_config).
pass
my_task_config = MyTaskConfig()
my_task = get_task(my_task_config) # Returns MyTask(my_task_config).
```
Besisdes a class itself, other callables that create a Task from a TaskConfig
can be decorated by the result of this function, as long as there is at most
one registration for each config class.
Args:
task_config_cls: a subclass of TaskConfig (*not* an instance of TaskConfig).
Each task_config_cls can only be used for a single registration.
Returns:
A callable for use as class decorator that registers the decorated class
for creation from an instance of task_config_cls.
"""
return
registry
.
register
(
_REGISTERED_TASK_CLS
,
task_config_cls
)
# The user-visible get_task() is defined after classes have been registered.
# TODO(b/158741360): Add type annotations once pytype checks across modules.
def
get_task_cls
(
task_config_cls
):
task_cls
=
registry
.
lookup
(
_REGISTERED_TASK_CLS
,
task_config_cls
)
return
task_cls
return
task_cls
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment