ModelZoo / ResNet50_tensorflow · Commits

Commit 3b0d58e2
Authored Jan 07, 2021 by Hongkun Yu; committed by A. Unique TensorFlower on Jan 07, 2021

Add multitask evaluation

PiperOrigin-RevId: 350705651

Parent: 72284a6c
Showing 7 changed files with 728 additions and 0 deletions (+728 / -0).
official/modeling/multitask/__init__.py         +14   -0
official/modeling/multitask/base_model.py       +60   -0
official/modeling/multitask/configs.py          +53   -0
official/modeling/multitask/evaluator.py        +171  -0
official/modeling/multitask/evaluator_test.py   +136  -0
official/modeling/multitask/multitask.py        +171  -0
official/modeling/multitask/train_lib.py        +123  -0
official/modeling/multitask/__init__.py (new file, mode 100644)

# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
official/modeling/multitask/base_model.py (new file, mode 100644)

# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
# ==============================================================================
"""Abstraction of multi-task model."""
from typing import Text, Dict

import tensorflow as tf


class MultiTaskBaseModel(tf.Module):
  """Base class that holds multi-task model computation."""

  def __init__(self, **kwargs):
    super().__init__(**kwargs)
    self._sub_tasks = self._instantiate_sub_tasks()

  def _instantiate_sub_tasks(self) -> Dict[Text, tf.keras.Model]:
    """Abstract function that sets up the computation for each sub-task.

    Returns:
      A map from task name (as string) to a tf.keras.Model object that
      represents the sub-task in the multi-task pool.
    """
    raise NotImplementedError("_instantiate_sub_tasks() is not implemented.")

  @property
  def sub_tasks(self):
    """Fetch a map of task name (string) to task model (tf.keras.Model)."""
    return self._sub_tasks

  def initialize(self):
    """Optional function that loads a pre-trained checkpoint."""
    return
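For reference, a minimal sketch of a concrete subclass follows. The two Keras heads and the task names "classification" and "regression" are hypothetical placeholders for illustration, not part of this commit.

import tensorflow as tf

from official.modeling.multitask import base_model


class TwoHeadModel(base_model.MultiTaskBaseModel):
  """Hypothetical example: one Keras model per sub-task, keyed by task name."""

  def _instantiate_sub_tasks(self):
    # The returned dict backs the `sub_tasks` property; keys must match the
    # task names used by the MultiTask object that consumes this model.
    return {
        "classification": tf.keras.Sequential(
            [tf.keras.layers.Dense(10, activation="softmax")]),
        "regression": tf.keras.Sequential([tf.keras.layers.Dense(1)]),
    }


model = TwoHeadModel()
print(sorted(model.sub_tasks.keys()))  # ['classification', 'regression']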
official/modeling/multitask/configs.py (new file, mode 100644)

# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Configuration definitions for multi-task training."""
from typing import Optional, Tuple

import dataclasses

from official.core import config_definitions as cfg
from official.modeling.hyperparams import base_config


@dataclasses.dataclass
class TaskRoutine(base_config.Config):
  task_name: str = ""
  task_config: cfg.TaskConfig = None
  mixing_steps: int = 1
  eval_steps: Optional[int] = None
  task_weight: Optional[float] = None


@dataclasses.dataclass
class MultiTaskConfig(base_config.Config):
  init_checkpoint: str = ""
  model: base_config.Config = None
  task_routines: Tuple[TaskRoutine, ...] = ()


@dataclasses.dataclass
class MultiEvalExperimentConfig(base_config.Config):
  """An experiment config for single-task training and multi-task evaluation.

  Attributes:
    task: the single-stream training task.
    eval_tasks: individual evaluation tasks.
    trainer: the trainer configuration.
    runtime: the runtime configuration.
  """
  task: cfg.TaskConfig = cfg.TaskConfig()
  eval_tasks: MultiTaskConfig = MultiTaskConfig()
  trainer: cfg.TrainerConfig = cfg.TrainerConfig()
  runtime: cfg.RuntimeConfig = cfg.RuntimeConfig()
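As a rough illustration of how these dataclasses compose, the sketch below builds a MultiEvalExperimentConfig with two evaluation routines. The task names, eval step counts, and bare cfg.TaskConfig() placeholders are assumptions for illustration; a real experiment would substitute registered task configs.

from official.core import config_definitions as cfg
from official.modeling.multitask import configs

# Hypothetical experiment: one training task, two evaluation routines.
experiment = configs.MultiEvalExperimentConfig(
    task=cfg.TaskConfig(),
    eval_tasks=configs.MultiTaskConfig(
        task_routines=(
            configs.TaskRoutine(
                task_name="foo", task_config=cfg.TaskConfig(), eval_steps=10),
            configs.TaskRoutine(
                task_name="bar", task_config=cfg.TaskConfig(), eval_steps=20),
        )))

print([r.task_name for r in experiment.eval_tasks.task_routines])  # ['foo', 'bar']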
official/modeling/multitask/evaluator.py (new file, mode 100644)

# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Multitask Evaluator implementation.

The evaluator implements the Orbit `AbstractEvaluator` interface.
"""
from typing import Optional, Union

import gin
import orbit
import tensorflow as tf

from official.modeling.multitask import base_model
from official.modeling.multitask import multitask


@gin.configurable
class MultiTaskEvaluator(orbit.AbstractEvaluator):
  """Implements the common evaluator shared for TensorFlow models."""

  def __init__(self,
               task: multitask.MultiTask,
               model: Union[tf.keras.Model, base_model.MultiTaskBaseModel],
               global_step: Optional[tf.Variable] = None):
    """Initializes the common evaluator for TensorFlow models.

    Args:
      task: A multitask.MultiTask instance.
      model: tf.keras.Model instance.
      global_step: the global step variable.
    """
    # Gets the current distribution strategy. If not inside any strategy scope,
    # it gets a single-replica no-op strategy.
    self._strategy = tf.distribute.get_strategy()
    self._task = task
    self._model = model
    self._global_step = global_step or orbit.utils.create_global_step()
    # TODO(hongkuny): Define a more robust way to handle the training/eval
    # checkpoint loading.
    if hasattr(self.model, "checkpoint_items"):
      # Each evaluation task can have different models and load a subset of
      # components from the training checkpoint. This is assuming the
      # checkpoint items are able to load the weights of the evaluation model.
      checkpoint_items = self.model.checkpoint_items
    else:
      # This is assuming the evaluation model is exactly the training model.
      checkpoint_items = dict(model=self.model)
    self._checkpoint = tf.train.Checkpoint(
        global_step=self.global_step, **checkpoint_items)
    self._validation_losses = None
    self._validation_metrics = None

    # Builds per-task datasets.
    self.eval_datasets = {}
    for name, task in self.task.tasks.items():
      self.eval_datasets[name] = orbit.utils.make_distributed_dataset(
          self.strategy, task.build_inputs,
          task.task_config.validation_data)

    # Builds per-task validation loops.
    def get_function(task_name, task):

      task_metrics = self.validation_metrics[task_name]
      task_loss = self.validation_losses[task_name]
      if isinstance(self.model, base_model.MultiTaskBaseModel):
        model = self.model.sub_tasks[task_name]
      else:
        model = self.model

      def step_fn(inputs):
        logs = task.validation_step(inputs, model=model, metrics=task_metrics)
        task_loss.update_state(logs[task.loss])
        return logs

      @tf.function
      def eval_step_fn(iterator):
        distributed_outputs = self.strategy.run(step_fn, args=(next(iterator),))
        return tf.nest.map_structure(self.strategy.experimental_local_results,
                                     distributed_outputs)

      return orbit.utils.create_loop_fn(eval_step_fn)

    self.task_fns = {
        name: get_function(name, task)
        for name, task in self.task.tasks.items()
    }

  @property
  def strategy(self):
    return self._strategy

  @property
  def task(self):
    return self._task

  @property
  def model(self):
    return self._model

  @property
  def global_step(self):
    return self._global_step

  @property
  def validation_losses(self):
    """Accesses the validation loss metric object."""
    if self._validation_losses is None:
      # Builds the per-task metrics and losses.
      self._validation_losses = {}
      for name in self.task.tasks:
        self._validation_losses[name] = tf.keras.metrics.Mean(
            "validation_loss", dtype=tf.float32)
    return self._validation_losses

  @property
  def validation_metrics(self):
    """Accesses all validation metric objects."""
    if self._validation_metrics is None:
      # Builds the per-task metrics and losses.
      self._validation_metrics = {}
      for name, task in self.task.tasks.items():
        self._validation_metrics[name] = task.build_metrics(training=False)
    return self._validation_metrics

  @property
  def checkpoint(self):
    """Accesses the training checkpoint."""
    return self._checkpoint

  def evaluate(self, num_steps: tf.Tensor):
    """Performs evaluation for each `EvalTask`."""
    for metric in self.validation_losses.values():
      metric.reset_states()
    for metrics in self.validation_metrics.values():
      for metric in metrics:
        metric.reset_states()
    results = {}
    eval_iters = tf.nest.map_structure(iter, self.eval_datasets)

    for name, task_eval_loop in self.task_fns.items():
      outputs = None
      eval_iter = eval_iters[name]
      task = self.task.tasks[name]
      task_eval_steps = self.task.task_eval_steps(name) or num_steps
      outputs = task_eval_loop(
          eval_iter,
          task_eval_steps,
          state=outputs,
          reduce_fn=task.aggregate_logs)
      task_metrics = self.validation_metrics[name]
      task_loss = self.validation_losses[name]
      logs = {}
      for metric in task_metrics + [task_loss]:
        logs[metric.name] = metric.result()
      if outputs:
        metrics = task.reduce_aggregated_logs(outputs)
        logs.update(metrics)
      results[name] = logs
    return results
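Usage follows the Orbit evaluator pattern: group the evaluation tasks in a MultiTask, wrap them with MultiTaskEvaluator, optionally restore a training checkpoint through evaluator.checkpoint, and call evaluate. A minimal sketch is below; `my_eval_tasks` (concrete base_task.Task objects), `eval_model`, and `train_model_dir` are assumed placeholders, not objects defined by this commit.

import tensorflow as tf

from official.modeling.multitask import evaluator as evaluator_lib
from official.modeling.multitask import multitask

# `my_eval_tasks`, `eval_model`, and `train_model_dir` are hypothetical
# placeholders standing in for real tasks, a trained model, and a checkpoint
# directory.
eval_group = multitask.MultiTask(tasks=my_eval_tasks)
test_evaluator = evaluator_lib.MultiTaskEvaluator(
    task=eval_group, model=eval_model)

# Restore the latest training checkpoint into the evaluator, if one exists.
latest = tf.train.latest_checkpoint(train_model_dir)
if latest:
  test_evaluator.checkpoint.restore(latest).expect_partial()

# Run up to 100 steps per task; per-task eval_steps, when configured on the
# MultiTask, take precedence over this default.
results = test_evaluator.evaluate(tf.convert_to_tensor(100, dtype=tf.int32))
for task_name, logs in results.items():
  print(task_name, logs)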
official/modeling/multitask/evaluator_test.py (new file, mode 100644)

# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for multitask.evaluator."""
from absl.testing import parameterized
import numpy as np
import tensorflow as tf

from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from official.core import base_task
from official.core import config_definitions as cfg
from official.modeling.multitask import evaluator
from official.modeling.multitask import multitask


def all_strategy_combinations():
  return combinations.combine(
      distribution=[
          strategy_combinations.default_strategy,
          strategy_combinations.cloud_tpu_strategy,
          strategy_combinations.one_device_strategy_gpu,
      ],
      mode="eager",
  )


class MockModel(tf.keras.Model):

  def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)
    self.dense = tf.keras.layers.Dense(1)

  def call(self, inputs):
    print(inputs, type(inputs))
    if "y" in inputs:
      self.add_loss(tf.zeros((1,), dtype=tf.float32))
    else:
      self.add_loss(tf.ones((1,), dtype=tf.float32))
    return self.dense(inputs["x"])


class MockTask(base_task.Task):
  """Mock task object for testing."""

  def build_metrics(self, training: bool = True):
    del training
    return [tf.keras.metrics.Accuracy(name="acc")]

  def build_inputs(self, params):

    def generate_data(_):
      x = tf.zeros(shape=(2,), dtype=tf.float32)
      label = tf.zeros([1], dtype=tf.int32)
      if self.name == "bar":
        return dict(x=x, y=x), label
      else:
        return dict(x=x), label

    dataset = tf.data.Dataset.range(1)
    dataset = dataset.repeat()
    dataset = dataset.map(
        generate_data, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    return dataset.prefetch(buffer_size=1).batch(2, drop_remainder=True)

  def validation_step(self, inputs, model: tf.keras.Model, metrics=None):
    logs = super().validation_step(inputs, model, metrics)
    logs["counter"] = tf.ones((1,), dtype=tf.float32)
    return logs

  def aggregate_logs(self, state, step_outputs):
    if state is None:
      state = {}
    for key, value in step_outputs.items():
      if key not in state:
        state[key] = []
      state[key].append(
          np.concatenate([np.expand_dims(v.numpy(), axis=0) for v in value]))
    return state

  def reduce_aggregated_logs(self, aggregated_logs):
    for k, v in aggregated_logs.items():
      aggregated_logs[k] = np.sum(np.stack(v, axis=0))
    return aggregated_logs


class EvaluatorTest(tf.test.TestCase, parameterized.TestCase):

  @combinations.generate(all_strategy_combinations())
  def test_multitask_evaluator(self, distribution):
    with distribution.scope():
      tasks = [
          MockTask(params=cfg.TaskConfig(), name="bar"),
          MockTask(params=cfg.TaskConfig(), name="foo")
      ]
      test_multitask = multitask.MultiTask(tasks=tasks)
      model = MockModel()
      test_evaluator = evaluator.MultiTaskEvaluator(
          task=test_multitask, model=model)
    results = test_evaluator.evaluate(tf.convert_to_tensor(1, dtype=tf.int32))
    self.assertContainsSubset(["validation_loss", "acc"],
                              results["bar"].keys())
    self.assertContainsSubset(["validation_loss", "acc"],
                              results["foo"].keys())
    self.assertEqual(results["bar"]["validation_loss"], 0.0)
    self.assertEqual(results["foo"]["validation_loss"], 1.0)

  @combinations.generate(all_strategy_combinations())
  def test_multitask_evaluator_numpy_metrics(self, distribution):
    with distribution.scope():
      tasks = [
          MockTask(params=cfg.TaskConfig(), name="bar"),
          MockTask(params=cfg.TaskConfig(), name="foo")
      ]
      test_multitask = multitask.MultiTask(tasks=tasks)
      model = MockModel()
      test_evaluator = evaluator.MultiTaskEvaluator(
          task=test_multitask, model=model)
    results = test_evaluator.evaluate(tf.convert_to_tensor(5, dtype=tf.int32))
    self.assertEqual(results["bar"]["counter"],
                     5. * distribution.num_replicas_in_sync)
    self.assertEqual(results["foo"]["counter"],
                     5. * distribution.num_replicas_in_sync)


if __name__ == "__main__":
  tf.test.main()
official/modeling/multitask/multitask.py (new file, mode 100644)

# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Experimental MultiTask base class for multi-task training/evaluation."""
import abc
from typing import Dict, List, Optional, Text, Union

import tensorflow as tf

from official.core import base_task
from official.core import config_definitions
from official.core import task_factory
from official.modeling import optimization
from official.modeling import performance
from official.modeling.multitask import configs

TrainerConfig = config_definitions.TrainerConfig
RuntimeConfig = config_definitions.RuntimeConfig


class MultiTask(tf.Module, metaclass=abc.ABCMeta):
  """A multi-task class to manage multiple tasks."""

  def __init__(self,
               tasks: Union[Dict[Text, base_task.Task], List[base_task.Task]],
               task_mixing_steps: Optional[Dict[str, int]] = None,
               task_weights: Optional[Dict[str, float]] = None,
               task_eval_steps: Optional[Dict[str, int]] = None,
               name: Optional[str] = None):
    """MultiTask initialization.

    Args:
      tasks: a list or a flat dict of Task.
      task_mixing_steps: a dict of (task, mixing steps).
      task_weights: a dict of (task, loss weight).
      task_eval_steps: a dict of (task, eval steps).
      name: the instance name of a MultiTask object.
    """
    super().__init__(name=name)
    if isinstance(tasks, list):
      self._tasks = {}
      for task in tasks:
        if task.name in self._tasks:
          raise ValueError("Duplicated tasks found, task.name is %s" %
                           task.name)
        self._tasks[task.name] = task
    elif isinstance(tasks, dict):
      self._tasks = tasks
    else:
      raise ValueError("The tasks argument has an invalid type: %s" %
                       type(tasks))
    self._task_eval_steps = task_eval_steps or {}
    self._task_eval_steps = dict([
        (name, self._task_eval_steps.get(name, None)) for name in self.tasks
    ])
    self._task_mixing_steps = task_mixing_steps or {}
    self._task_mixing_steps = dict([
        (name, self._task_mixing_steps.get(name, 1)) for name in self.tasks
    ])
    self._task_weights = task_weights or {}
    self._task_weights = dict([
        (name, self._task_weights.get(name, None)) for name in self.tasks
    ])

  @classmethod
  def from_config(cls, config: configs.MultiTaskConfig, logging_dir=None):
    tasks = {}
    task_eval_steps = {}
    task_mixing_steps = {}
    task_weights = {}
    for task_routine in config.task_routines:
      task_name = task_routine.task_name
      tasks[task_name] = task_factory.get_task(
          task_routine.task_config, logging_dir=logging_dir)
      task_eval_steps[task_name] = task_routine.eval_steps
      task_mixing_steps[task_name] = task_routine.mixing_steps
      task_weights[task_name] = task_routine.task_weight
    return cls(
        tasks,
        task_mixing_steps=task_mixing_steps,
        task_eval_steps=task_eval_steps,
        task_weights=task_weights)

  @property
  def tasks(self):
    return self._tasks

  def task_eval_steps(self, task_name):
    return self._task_eval_steps[task_name]

  def task_mixing_steps(self, task_name):
    return self._task_mixing_steps[task_name]

  def task_weight(self, task_name):
    return self._task_weights[task_name]

  @classmethod
  def create_optimizer(cls,
                       trainer_config: TrainerConfig,
                       runtime_config: Optional[RuntimeConfig] = None):
    """Creates a TF optimizer from configurations.

    Args:
      trainer_config: the parameters of the trainer.
      runtime_config: the parameters of the runtime.

    Returns:
      A tf.optimizers.Optimizer object.
    """
    opt_factory = optimization.OptimizerFactory(
        trainer_config.optimizer_config)
    optimizer = opt_factory.build_optimizer(opt_factory.build_learning_rate())
    # Configuring optimizer when loss_scale is set in runtime config. This
    # helps avoid overflow/underflow for float16 computations.
    if runtime_config and runtime_config.loss_scale:
      optimizer = performance.configure_optimizer(
          optimizer,
          use_float16=runtime_config.mixed_precision_dtype == "float16",
          loss_scale=runtime_config.loss_scale)
    return optimizer

  def joint_train_step(self, task_inputs, multi_task_model, optimizer,
                       task_metrics):
    """The joint train step.

    Args:
      task_inputs: a dictionary of task names and per-task features.
      multi_task_model: a MultiTaskBaseModel instance.
      optimizer: a tf.optimizers.Optimizer.
      task_metrics: a dictionary of task names and per-task metrics.

    Returns:
      A dictionary of losses, including per-task losses and their weighted sum.
    """
    losses = {}
    with tf.GradientTape() as tape:
      total_loss = 0.0
      for name, model in multi_task_model.sub_tasks.items():
        inputs = task_inputs[name]
        if isinstance(inputs, tuple) and len(inputs) == 2:
          features, labels = inputs
        elif isinstance(inputs, dict):
          features, labels = inputs, inputs
        else:
          raise ValueError("The iterator output is neither a tuple nor a "
                           "dictionary. It is not implemented to support "
                           "such outputs.")
        outputs = model(features, training=True)
        task_loss = self.tasks[name].build_losses(labels, outputs)
        task_weight = self.task_weight(name)
        total_loss += task_weight * task_loss
        losses[name] = task_loss
        self.tasks[name].process_metrics(task_metrics[name], labels, outputs)

      # Scales loss as the default gradients allreduce performs sum inside
      # the optimizer.
      scaled_loss = total_loss / tf.distribute.get_strategy(
      ).num_replicas_in_sync
    tvars = multi_task_model.trainable_variables
    grads = tape.gradient(scaled_loss, tvars)
    optimizer.apply_gradients(list(zip(grads, tvars)))
    losses["total_loss"] = total_loss
    return losses
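A small sketch of the MultiTask container API follows. The _StubTask class, the task names, and the weights are hypothetical stand-ins for real registered tasks; they exist only to show how per-task weights, eval steps, and mixing steps default when unspecified.

import tensorflow as tf

from official.core import base_task
from official.core import config_definitions as cfg
from official.modeling.multitask import multitask


class _StubTask(base_task.Task):
  """Hypothetical placeholder task used only to illustrate the MultiTask API."""

  def build_model(self):
    return tf.keras.Sequential([tf.keras.layers.Dense(1)])

  def build_inputs(self, params):
    return tf.data.Dataset.from_tensors(tf.zeros([1, 2])).repeat()


task_group = multitask.MultiTask(
    tasks=[_StubTask(params=cfg.TaskConfig(), name="foo"),
           _StubTask(params=cfg.TaskConfig(), name="bar")],
    task_weights={"foo": 1.0, "bar": 0.5},
    task_eval_steps={"foo": 100})

print(task_group.task_weight("bar"))        # 0.5
print(task_group.task_eval_steps("bar"))    # None (default when unspecified)
print(task_group.task_mixing_steps("foo"))  # 1 (default mixing steps)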
official/modeling/multitask/train_lib.py (new file, mode 100644)

# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Multitask training driver library."""
# pytype: disable=attribute-error
import os

from absl import logging
import orbit
import tensorflow as tf

from official.core import base_task
from official.core import base_trainer as core_lib
from official.modeling.multitask import configs
from official.modeling.multitask import evaluator as evaluator_lib
from official.modeling.multitask import multitask


def run_experiment_wtih_multitask_eval(
    *,
    distribution_strategy: tf.distribute.Strategy,
    train_task: base_task.Task,
    eval_tasks: multitask.MultiTask,
    mode: str,
    params: configs.MultiEvalExperimentConfig,
    model_dir: str) -> tf.keras.Model:
  """Runs train/eval configured by the experiment params.

  Args:
    distribution_strategy: A distribution strategy.
    train_task: A base_task.Task instance.
    eval_tasks: A multitask.MultiTask with evaluation tasks.
    mode: A 'str', specifying the mode. Can be 'train', 'eval',
      'train_and_eval' or 'continuous_eval'.
    params: MultiEvalExperimentConfig instance.
    model_dir: A 'str', a path to store model checkpoints and summaries.

  Returns:
    model: `tf.keras.Model` instance.
  """
  is_training = 'train' in mode
  is_eval = 'eval' in mode
  with distribution_strategy.scope():
    optimizer = train_task.create_optimizer(params.trainer, params.runtime)
    model = train_task.build_model()
    if is_training:
      trainer = core_lib.Trainer(
          config=params,
          task=train_task,
          model=model,
          optimizer=optimizer,
          train=True,
          evaluate=False)
    else:
      trainer = None
    if is_eval:
      evaluator = evaluator_lib.MultiTaskEvaluator(
          task=eval_tasks,
          model=model,
          global_step=trainer.global_step if is_training else None)
    else:
      evaluator = None

  if trainer:
    checkpoint = trainer.checkpoint
    global_step = trainer.global_step
  else:
    checkpoint = evaluator.checkpoint
    global_step = evaluator.global_step
  checkpoint_manager = tf.train.CheckpointManager(
      checkpoint,
      directory=model_dir,
      max_to_keep=params.trainer.max_to_keep,
      step_counter=global_step,
      checkpoint_interval=params.trainer.checkpoint_interval,
      init_fn=trainer.initialize if trainer else None)

  controller = orbit.Controller(
      strategy=distribution_strategy,
      trainer=trainer,
      evaluator=evaluator,
      global_step=global_step,
      steps_per_loop=params.trainer.steps_per_loop,
      checkpoint_manager=checkpoint_manager,
      summary_dir=os.path.join(model_dir, 'train'),
      eval_summary_dir=os.path.join(model_dir, 'validation'),
      summary_interval=params.trainer.summary_interval)

  logging.info('Starts to execute mode: %s', mode)
  with distribution_strategy.scope():
    if mode == 'train':
      controller.train(steps=params.trainer.train_steps)
    elif mode == 'train_and_eval':
      controller.train_and_evaluate(
          train_steps=params.trainer.train_steps,
          eval_steps=params.trainer.validation_steps,
          eval_interval=params.trainer.validation_interval)
    elif mode == 'eval':
      controller.evaluate(steps=params.trainer.validation_steps)
    elif mode == 'continuous_eval':

      def timeout_fn():
        if evaluator.global_step.numpy() >= params.trainer.train_steps:
          return True
        return False

      controller.evaluate_continuously(
          steps=params.trainer.validation_steps,
          timeout=params.trainer.continuous_eval_timeout,
          timeout_fn=timeout_fn)
    else:
      raise NotImplementedError('The mode is not implemented: %s' % mode)
  return model
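A hedged sketch of driving the library (function name as committed). The strategy choice, the `train_task`, `eval_task_group`, and `experiment_params` objects, and the '/tmp/multitask_model' directory are illustrative assumptions; real experiments build these from registered task and experiment configs.

import tensorflow as tf

from official.modeling.multitask import train_lib

# `train_task`, `eval_task_group`, and `experiment_params` are hypothetical
# placeholders: a base_task.Task, a multitask.MultiTask, and a
# configs.MultiEvalExperimentConfig built elsewhere.
strategy = tf.distribute.MirroredStrategy()
model = train_lib.run_experiment_wtih_multitask_eval(
    distribution_strategy=strategy,
    train_task=train_task,
    eval_tasks=eval_task_group,
    mode='train_and_eval',
    params=experiment_params,
    model_dir='/tmp/multitask_model')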