Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
2b68aa95
Commit
2b68aa95
authored
Apr 14, 2021
by
Hongkun Yu
Committed by
A. Unique TensorFlower
Apr 14, 2021
Browse files
Internal change
PiperOrigin-RevId: 368535232
parent
8a16208b
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
7 additions
and
6 deletions
+7
-6
official/modeling/multitask/multitask.py
official/modeling/multitask/multitask.py
+7
-6
No files found.
official/modeling/multitask/multitask.py
View file @
2b68aa95
...
@@ -21,6 +21,7 @@ from official.core import base_task
...
@@ -21,6 +21,7 @@ from official.core import base_task
from
official.core
import
config_definitions
from
official.core
import
config_definitions
from
official.core
import
task_factory
from
official.core
import
task_factory
from
official.modeling
import
optimization
from
official.modeling
import
optimization
from
official.modeling.multitask
import
base_model
from
official.modeling.multitask
import
configs
from
official.modeling.multitask
import
configs
OptimizationConfig
=
optimization
.
OptimizationConfig
OptimizationConfig
=
optimization
.
OptimizationConfig
...
@@ -79,9 +80,7 @@ class MultiTask(tf.Module, metaclass=abc.ABCMeta):
...
@@ -79,9 +80,7 @@ class MultiTask(tf.Module, metaclass=abc.ABCMeta):
task_eval_steps
[
task_name
]
=
task_routine
.
eval_steps
task_eval_steps
[
task_name
]
=
task_routine
.
eval_steps
task_weights
[
task_name
]
=
task_routine
.
task_weight
task_weights
[
task_name
]
=
task_routine
.
task_weight
return
cls
(
return
cls
(
tasks
,
tasks
,
task_eval_steps
=
task_eval_steps
,
task_weights
=
task_weights
)
task_eval_steps
=
task_eval_steps
,
task_weights
=
task_weights
)
@
property
@
property
def
tasks
(
self
):
def
tasks
(
self
):
...
@@ -104,15 +103,17 @@ class MultiTask(tf.Module, metaclass=abc.ABCMeta):
...
@@ -104,15 +103,17 @@ class MultiTask(tf.Module, metaclass=abc.ABCMeta):
return
base_task
.
Task
.
create_optimizer
(
return
base_task
.
Task
.
create_optimizer
(
optimizer_config
=
optimizer_config
,
runtime_config
=
runtime_config
)
optimizer_config
=
optimizer_config
,
runtime_config
=
runtime_config
)
def
joint_train_step
(
self
,
task_inputs
,
multi_task_model
,
optimizer
,
def
joint_train_step
(
self
,
task_inputs
,
task_metrics
):
multi_task_model
:
base_model
.
MultiTaskBaseModel
,
optimizer
:
tf
.
keras
.
optimizers
.
Optimizer
,
task_metrics
):
"""The joint train step.
"""The joint train step.
Args:
Args:
task_inputs: a dictionary of task names and per-task features.
task_inputs: a dictionary of task names and per-task features.
multi_task_model: a MultiTaskModel instance.
multi_task_model: a MultiTask
Base
Model instance.
optimizer: a tf.optimizers.Optimizer.
optimizer: a tf.optimizers.Optimizer.
task_metrics: a dictionary of task names and per-task metrics.
task_metrics: a dictionary of task names and per-task metrics.
Returns:
Returns:
A dictionary of losses, inculding per-task losses and their weighted sum.
A dictionary of losses, inculding per-task losses and their weighted sum.
"""
"""
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment