Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
e773b9b3
Commit
e773b9b3
authored
Mar 25, 2021
by
A. Unique TensorFlower
Browse files
Open source the second half of multi-task library
PiperOrigin-RevId: 365085378
parent
983837ff
Changes
10
Show whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
1037 additions
and
0 deletions
+1037
-0
official/modeling/multitask/base_trainer.py
official/modeling/multitask/base_trainer.py
+176
-0
official/modeling/multitask/base_trainer_test.py
official/modeling/multitask/base_trainer_test.py
+90
-0
official/modeling/multitask/configs.py
official/modeling/multitask/configs.py
+33
-0
official/modeling/multitask/interleaving_trainer.py
official/modeling/multitask/interleaving_trainer.py
+92
-0
official/modeling/multitask/interleaving_trainer_test.py
official/modeling/multitask/interleaving_trainer_test.py
+101
-0
official/modeling/multitask/task_sampler.py
official/modeling/multitask/task_sampler.py
+121
-0
official/modeling/multitask/task_sampler_test.py
official/modeling/multitask/task_sampler_test.py
+75
-0
official/modeling/multitask/test_utils.py
official/modeling/multitask/test_utils.py
+125
-0
official/modeling/multitask/train_lib.py
official/modeling/multitask/train_lib.py
+104
-0
official/modeling/multitask/train_lib_test.py
official/modeling/multitask/train_lib_test.py
+120
-0
No files found.
official/modeling/multitask/base_trainer.py
0 → 100644
View file @
e773b9b3
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Multitask base trainer implementation.
The trainer derives from the Orbit `StandardTrainer` class.
"""
from
typing
import
Union
import
gin
import
orbit
import
tensorflow
as
tf
from
official.modeling.multitask
import
base_model
from
official.modeling.multitask
import
multitask
@gin.configurable
class MultiTaskBaseTrainer(orbit.StandardTrainer):
  """Multitask base trainer that jointly optimizes all tasks each step.

  Derives from the Orbit `StandardTrainer`; one distributed dataset is built
  per task and every train step performs a single joint update over all tasks.
  """

  def __init__(self,
               multi_task: multitask.MultiTask,
               multi_task_model: Union[tf.keras.Model,
                                       base_model.MultiTaskBaseModel],
               optimizer: tf.optimizers.Optimizer,
               trainer_options=None):
    """Initializes the trainer.

    Args:
      multi_task: A `multitask.MultiTask` holding the tasks to train.
      multi_task_model: Either a single Keras model or a
        `base_model.MultiTaskBaseModel` with per-task sub models.
      optimizer: The optimizer shared across all tasks.
      trainer_options: Optional `orbit.StandardTrainerOptions`.
    """
    self._strategy = tf.distribute.get_strategy()
    self._multi_task = multi_task
    self._multi_task_model = multi_task_model
    self._optimizer = optimizer

    # Loss and metric containers are created lazily by the properties below.
    self._training_losses = None
    self._training_metrics = None
    self._global_step = orbit.utils.create_global_step()

    # Models may expose extra objects to include in the checkpoint.
    if hasattr(self.multi_task_model, "checkpoint_items"):
      checkpoint_items = self.multi_task_model.checkpoint_items
    else:
      checkpoint_items = {}
    self._checkpoint = tf.train.Checkpoint(
        model=self.multi_task_model,
        optimizer=self.optimizer,
        global_step=self.global_step,
        **checkpoint_items)

    # One distributed training dataset per task, keyed by task name.
    train_datasets = {}
    for task_name, task in self.multi_task.tasks.items():
      train_datasets[task_name] = orbit.utils.make_distributed_dataset(
          self.strategy, task.build_inputs, task.task_config.train_data)

    super().__init__(
        train_dataset=train_datasets,
        options=trainer_options or orbit.StandardTrainerOptions())

  def train_loop_begin(self):
    """Clean up states that hold losses and metrics."""
    for _, loss_metric in self.training_losses.items():
      loss_metric.reset_states()

    for _, task_metrics in self.training_metrics.items():
      for metric in task_metrics:
        metric.reset_states()

  def train_loop_end(self):
    """Record loss and metric values per task."""
    result = {}
    for task_name, loss in self.training_losses.items():
      result[task_name] = {loss.name: loss.result()}
    for task_name, task_metrics in self.training_metrics.items():
      result[task_name].update(
          {metric.name: metric.result() for metric in task_metrics})
    # The learning rate schedule is managed by the Keras optimizer internally,
    # which counts backward passes as `iterations`. It therefore does not
    # follow the trainer's logical global step over multiple tasks.
    if callable(self.optimizer.learning_rate):
      result["learning_rate"] = self.optimizer.learning_rate(
          self.optimizer.iterations)
    else:
      result["learning_rate"] = self.optimizer.learning_rate
    return result

  @property
  def checkpoint(self):
    """Accesses the training checkpoint."""
    return self._checkpoint

  @property
  def training_losses(self):
    """Access training loss metric objects for all tasks."""
    if self._training_losses is None:
      # Build the per-task losses lazily. `total_loss` tracks the summed
      # training loss across tasks in the joint training.
      self._training_losses = dict(
          total_loss=tf.keras.metrics.Mean("training_loss", dtype=tf.float32))
      for name in self.multi_task.tasks:
        self._training_losses[name] = tf.keras.metrics.Mean(
            "training_loss", dtype=tf.float32)
    return self._training_losses

  @property
  def training_metrics(self):
    """Access training metric objects for all tasks."""
    if self._training_metrics is None:
      # Build the per-task metrics lazily.
      self._training_metrics = {}
      for name, task in self.multi_task.tasks.items():
        self._training_metrics[name] = task.build_metrics(training=True)
    return self._training_metrics

  @property
  def strategy(self):
    return self._strategy

  @property
  def multi_task(self):
    return self._multi_task

  @property
  def multi_task_model(self):
    return self._multi_task_model

  @property
  def optimizer(self):
    return self._optimizer

  @property
  def global_step(self):
    return self._global_step

  def train_step(self, iterator_map):
    """The default train step calling the multi-task train step.

    Args:
      iterator_map: a dictionary of task names and per-task dataset iterators.
    """

    def step_fn(inputs):
      losses = self.multi_task.joint_train_step(
          inputs,
          multi_task_model=self.multi_task_model,
          optimizer=self.optimizer,
          task_metrics=self.training_metrics)
      for key, loss in losses.items():
        self.training_losses[key].update_state(loss)

    self.strategy.run(
        step_fn, args=(tf.nest.map_structure(next, iterator_map),))
    self.global_step.assign_add(1)
official/modeling/multitask/base_trainer_test.py
0 → 100644
View file @
e773b9b3
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for multitask.base_trainer."""
from
absl.testing
import
parameterized
import
tensorflow
as
tf
from
tensorflow.python.distribute
import
combinations
from
tensorflow.python.distribute
import
strategy_combinations
from
official.modeling.multitask
import
base_trainer
from
official.modeling.multitask
import
configs
from
official.modeling.multitask
import
multitask
from
official.modeling.multitask
import
test_utils
def all_strategy_combinations():
  """Returns the eager-mode distribution strategies used by every test."""
  strategies = [
      strategy_combinations.default_strategy,
      strategy_combinations.cloud_tpu_strategy,
      strategy_combinations.one_device_strategy_gpu,
  ]
  return combinations.combine(distribution=strategies, mode="eager")
class BaseTrainerTest(tf.test.TestCase, parameterized.TestCase):
  """Tests for the joint multi-task base trainer."""

  @combinations.generate(all_strategy_combinations())
  def test_multitask_joint_trainer(self, distribution):
    # Build two mock tasks with equal weights and run a few joint steps.
    with distribution.scope():
      tasks = [
          test_utils.MockFooTask(params=test_utils.FooConfig(), name="foo"),
          test_utils.MockBarTask(params=test_utils.BarConfig(), name="bar"),
      ]
      task_weights = {"foo": 1.0, "bar": 1.0}
      test_multitask = multitask.MultiTask(
          tasks=tasks, task_weights=task_weights)
      test_optimizer = tf.keras.optimizers.SGD(0.1)
      model = test_utils.MockMultiTaskModel()
      test_trainer = base_trainer.MultiTaskBaseTrainer(
          multi_task=test_multitask,
          multi_task_model=model,
          optimizer=test_optimizer)
      results = test_trainer.train(tf.convert_to_tensor(5, dtype=tf.int32))
      # Every task should report its loss plus its own accuracy metric.
      self.assertContainsSubset(["training_loss", "bar_acc"],
                                results["bar"].keys())
      self.assertContainsSubset(["training_loss", "foo_acc"],
                                results["foo"].keys())

  def test_trainer_with_configs(self):
    # Same trainer, but the MultiTask is constructed from a config object.
    config = configs.MultiTaskConfig(
        task_routines=(configs.TaskRoutine(
            task_name="foo",
            task_config=test_utils.FooConfig(),
            task_weight=0.5),
                       configs.TaskRoutine(
                           task_name="bar",
                           task_config=test_utils.BarConfig(),
                           task_weight=0.5)))
    test_multitask = multitask.MultiTask.from_config(config)
    test_optimizer = tf.keras.optimizers.SGD(0.1)
    model = test_utils.MockMultiTaskModel()
    test_trainer = base_trainer.MultiTaskBaseTrainer(
        multi_task=test_multitask,
        multi_task_model=model,
        optimizer=test_optimizer)
    results = test_trainer.train(tf.convert_to_tensor(5, dtype=tf.int32))
    self.assertContainsSubset(["training_loss", "bar_acc"],
                              results["bar"].keys())
    self.assertContainsSubset(["training_loss", "foo_acc"],
                              results["foo"].keys())
    self.assertEqual(test_multitask.task_weight("foo"), 0.5)
    self.assertEqual(test_trainer.global_step.numpy(), 5)
    self.assertIn("learning_rate", results)
if __name__ == "__main__":
  # Run the test cases above via the TensorFlow test runner.
  tf.test.main()
official/modeling/multitask/configs.py
View file @
e773b9b3
...
@@ -36,6 +36,39 @@ class MultiTaskConfig(hyperparams.Config):
...
@@ -36,6 +36,39 @@ class MultiTaskConfig(hyperparams.Config):
task_routines
:
Tuple
[
TaskRoutine
,
...]
=
()
task_routines
:
Tuple
[
TaskRoutine
,
...]
=
()
@dataclasses.dataclass
class ProportionalSampleConfig(hyperparams.Config):
  """Config for sampling tasks proportional to their (exponentiated) weights.

  Attributes:
    alpha: Exponent applied to each task weight before normalization.
  """
  alpha: float = 1.0
@dataclasses.dataclass
class AnnealingSampleConfig(hyperparams.Config):
  """Config for annealed task sampling that shifts over training progress.

  Attributes:
    steps_per_epoch: Number of train steps that make up one epoch.
    total_steps: Total number of train steps for the whole run.
  """
  steps_per_epoch: int = 5
  total_steps: int = 20
@dataclasses.dataclass
class TaskSamplingConfig(hyperparams.OneOfConfig):
  """One-of config selecting the task sampling strategy.

  Attributes:
    type: 'uniform', 'proportional' or 'annealing'.
    uniform: Config for uniform sampling (no parameters).
    proportional: Config for proportional sampling.
    annealing: Config for annealing sampling.
  """
  type: str = ""
  uniform: hyperparams.Config = hyperparams.Config()
  proportional: ProportionalSampleConfig = ProportionalSampleConfig()
  annealing: AnnealingSampleConfig = AnnealingSampleConfig()
@dataclasses.dataclass
class MultiTaskTrainerConfig(cfg.TrainerConfig):
  """Trainer config for multi-task training.

  Attributes:
    trainer_type: Which trainer to use; 'interleaving' or 'joint'.
    task_sampler: Sampling strategy used by the interleaving trainer.
  """
  trainer_type: str = "interleaving"
  task_sampler: TaskSamplingConfig = TaskSamplingConfig(type="proportional")
@dataclasses.dataclass
class MultiTaskExperimentConfig(hyperparams.Config):
  """An experiment config for multi-task training and multi-task evaluation.

  Attributes:
    task: The multi-task definition (task routines and weights).
    trainer: Multi-task trainer configuration.
    runtime: Runtime (device/distribution) configuration.
  """
  task: MultiTaskConfig = MultiTaskConfig()
  trainer: MultiTaskTrainerConfig = MultiTaskTrainerConfig()
  runtime: cfg.RuntimeConfig = cfg.RuntimeConfig()
@
dataclasses
.
dataclass
@
dataclasses
.
dataclass
class
MultiEvalExperimentConfig
(
cfg
.
ExperimentConfig
):
class
MultiEvalExperimentConfig
(
cfg
.
ExperimentConfig
):
"""An experiment config for single-task training and multi-task evaluation.
"""An experiment config for single-task training and multi-task evaluation.
...
...
official/modeling/multitask/interleaving_trainer.py
0 → 100644
View file @
e773b9b3
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Multitask trainer that interleaves each task's train step."""
from
typing
import
Union
import
gin
import
orbit
import
tensorflow
as
tf
from
official.modeling.multitask
import
base_model
from
official.modeling.multitask
import
base_trainer
from
official.modeling.multitask
import
multitask
from
official.modeling.multitask
import
task_sampler
as
sampler
@gin.configurable
class MultiTaskInterleavingTrainer(base_trainer.MultiTaskBaseTrainer):
  """MultiTask trainer that updates one sampled task per train step."""

  def __init__(self,
               multi_task: multitask.MultiTask,
               multi_task_model: Union[tf.keras.Model,
                                       base_model.MultiTaskBaseModel],
               optimizer: tf.optimizers.Optimizer,
               task_sampler: sampler.TaskSampler,
               trainer_options=None):
    """Initializes the interleaving trainer.

    Args:
      multi_task: A `multitask.MultiTask` holding the tasks to train.
      multi_task_model: Either a single Keras model or a
        `base_model.MultiTaskBaseModel` with per-task sub models.
      optimizer: The optimizer shared across all tasks.
      task_sampler: A `sampler.TaskSampler` that provides the cumulative
        distribution used to pick a task each step.
      trainer_options: Optional `orbit.StandardTrainerOptions`.
    """
    super().__init__(
        multi_task=multi_task,
        multi_task_model=multi_task_model,
        optimizer=optimizer,
        trainer_options=trainer_options)
    self._task_sampler = task_sampler

    # Build one replica step function per task.
    def _get_task_step(task_name, task):

      def step_fn(inputs):
        # A MultiTaskBaseModel exposes per-task sub models; a plain Keras
        # model is shared by all tasks.
        if isinstance(self.multi_task_model, base_model.MultiTaskBaseModel):
          task_model = self.multi_task_model.sub_tasks[task_name]
        else:
          task_model = self.multi_task_model
        task_logs = task.train_step(
            inputs,
            model=task_model,
            optimizer=self.optimizer,
            metrics=self.training_metrics[task_name])
        self.training_losses[task_name].update_state(task_logs[task.loss])

      return step_fn

    self._task_train_step_map = {
        name: _get_task_step(name, task)
        for name, task in self.multi_task.tasks.items()
    }

    # TODO(haozhangthu): Add taskwise step counter to train_loop_end for
    # logging on TensorBoard.
    self._task_step_counters = {
        name: orbit.utils.create_global_step()
        for name in self.multi_task.tasks
    }

  def task_step_counter(self, name):
    """Returns the per-task step counter variable for task `name`."""
    return self._task_step_counters[name]

  def train_step(self, iterator_map):
    """Samples one task per step and runs its train step.

    Args:
      iterator_map: a dictionary of task names and per-task dataset iterators.
    """
    # Sample one task to train according to a multinomial distribution;
    # the stateless seed is derived from the global step for determinism.
    rn = tf.random.stateless_uniform(shape=[], seed=(0, self.global_step))
    cumulative = self._task_sampler.task_cumulative_distribution(
        self.global_step)
    # Prepend a [0.0] for indexing convenience.
    cumulative = tf.concat(
        [tf.constant([0.0], dtype=tf.float32), cumulative], axis=0)

    for idx, (name, _) in enumerate(self.multi_task.tasks.items()):
      begin = cumulative[idx]
      end = cumulative[idx + 1]
      if rn >= begin and rn < end:
        self._strategy.run(
            self._task_train_step_map[name], args=(next(iterator_map[name]),))
        self.global_step.assign_add(1)
        self.task_step_counter(name).assign_add(1)
official/modeling/multitask/interleaving_trainer_test.py
0 → 100644
View file @
e773b9b3
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for multitask.interleaving_trainer."""
from
absl.testing
import
parameterized
import
tensorflow
as
tf
from
tensorflow.python.distribute
import
combinations
from
tensorflow.python.distribute
import
strategy_combinations
from
official.modeling.multitask
import
configs
from
official.modeling.multitask
import
interleaving_trainer
from
official.modeling.multitask
import
multitask
from
official.modeling.multitask
import
task_sampler
from
official.modeling.multitask
import
test_utils
def all_strategy_combinations():
  """Returns the eager-mode distribution strategies used by every test."""
  strategies = [
      strategy_combinations.default_strategy,
      strategy_combinations.cloud_tpu_strategy,
      strategy_combinations.one_device_strategy_gpu,
  ]
  return combinations.combine(distribution=strategies, mode="eager")
class InterleavingTrainerTest(tf.test.TestCase, parameterized.TestCase):
  """Tests for the interleaving multi-task trainer."""

  @combinations.generate(all_strategy_combinations())
  def test_multitask_interleaving_trainer(self, distribution):
    # Two mock tasks sampled uniformly; verify per-task results are reported.
    with distribution.scope():
      tasks = [
          test_utils.MockFooTask(params=test_utils.FooConfig(), name="foo"),
          test_utils.MockBarTask(params=test_utils.BarConfig(), name="bar"),
      ]
      test_multitask = multitask.MultiTask(tasks=tasks)
      test_optimizer = tf.keras.optimizers.SGD(0.1)
      model = test_utils.MockMultiTaskModel()
      sampler = task_sampler.UniformTaskSampler(
          task_weights=test_multitask.task_weights)
      test_trainer = interleaving_trainer.MultiTaskInterleavingTrainer(
          multi_task=test_multitask,
          multi_task_model=model,
          optimizer=test_optimizer,
          task_sampler=sampler)
      results = test_trainer.train(tf.convert_to_tensor(5, dtype=tf.int32))
      self.assertContainsSubset(["training_loss", "bar_acc"],
                                results["bar"].keys())
      self.assertContainsSubset(["training_loss", "foo_acc"],
                                results["foo"].keys())

  @combinations.generate(all_strategy_combinations())
  def test_trainer_with_configs(self, distribution):
    # Config-driven MultiTask with an annealing sampler; check step counting.
    config = configs.MultiTaskConfig(
        task_routines=(configs.TaskRoutine(
            task_name="foo",
            task_config=test_utils.FooConfig(),
            task_weight=3.0),
                       configs.TaskRoutine(
                           task_name="bar",
                           task_config=test_utils.BarConfig(),
                           task_weight=1.0)))
    with distribution.scope():
      test_multitask = multitask.MultiTask.from_config(config)
      test_optimizer = tf.keras.optimizers.SGD(0.1)
      model = test_utils.MockMultiTaskModel()
      num_step = 1000
      sampler = task_sampler.AnnealingTaskSampler(
          task_weights=test_multitask.task_weights,
          steps_per_epoch=num_step / 5,
          total_steps=num_step)
      test_trainer = interleaving_trainer.MultiTaskInterleavingTrainer(
          multi_task=test_multitask,
          multi_task_model=model,
          optimizer=test_optimizer,
          task_sampler=sampler)
      results = test_trainer.train(
          tf.convert_to_tensor(num_step, dtype=tf.int32))
      self.assertContainsSubset(["training_loss", "bar_acc"],
                                results["bar"].keys())
      self.assertContainsSubset(["training_loss", "foo_acc"],
                                results["foo"].keys())
      self.assertEqual(test_trainer.global_step.numpy(), num_step)
      # Every global step trained exactly one task, so the per-task
      # counters must sum to the total number of steps.
      bar_sampled_step = test_trainer.task_step_counter("bar").numpy()
      foo_sampled_step = test_trainer.task_step_counter("foo").numpy()
      self.assertEqual(bar_sampled_step + foo_sampled_step, num_step)
if __name__ == "__main__":
  # Run the test cases above via the TensorFlow test runner.
  tf.test.main()
official/modeling/multitask/task_sampler.py
0 → 100644
View file @
e773b9b3
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utils to sample tasks for interleaved optimization."""
import
abc
from
typing
import
Union
,
Dict
,
Text
import
tensorflow
as
tf
from
official.modeling.multitask
import
configs
class TaskSampler(tf.Module, metaclass=abc.ABCMeta):
  """An abstract class defining task sampling API for interleaving trainer."""

  def __init__(self, task_weights: Dict[Text, Union[float, int]]):
    """Stores the static per-task weights used to derive distributions."""
    self._task_weights = task_weights

  @abc.abstractmethod
  def task_cumulative_distribution(self, global_step: tf.Tensor) -> tf.Tensor:
    """Compute cumulative distribution to sample tasks.

    It calculates the cumulative distribution of the multinomial task
    distribution with respect to which to be sampled against.

    Args:
      global_step: A tensor indicating current progress of training.

    Returns:
      A float tensor with one entry per task that represents the cumulative
      sampling distribution (the last entry sums to 1.0).
    """
    pass
class UniformTaskSampler(TaskSampler):
  """Sample all tasks uniformly, ignoring task weights and training progress."""

  def __init__(self, task_weights: Dict[Text, Union[float, int]]):
    super().__init__(task_weights=task_weights)
    # The distribution is constant: 1/N per task, precomputed once.
    num_tasks = len(self._task_weights)
    self._uniform_cumulative = tf.math.cumsum(
        tf.constant([1.0 / num_tasks] * num_tasks, dtype=tf.float32))

  def task_cumulative_distribution(self, global_step: tf.Tensor) -> tf.Tensor:
    # Uniform sampling does not depend on training progress.
    del global_step
    return self._uniform_cumulative
class ProportionalTaskSampler(TaskSampler):
  """Sample tasks proportional to task weights raised to the power `alpha`."""

  def __init__(self,
               task_weights: Dict[Text, Union[float, int]],
               alpha: float = 1.0):
    super().__init__(task_weights=task_weights)
    self._alpha = tf.cast(alpha, dtype=tf.float32)
    # Keep dict insertion order so the cumulative distribution aligns with
    # iteration order over the tasks elsewhere in the trainer.
    ordered_weights = tf.constant(
        [weight for _, weight in self._task_weights.items()],
        dtype=tf.float32)
    task_sizes = tf.math.pow(ordered_weights, self._alpha)
    task_distribution = task_sizes / tf.reduce_sum(task_sizes)
    self._proportional_cumulative = tf.math.cumsum(task_distribution)

  def task_cumulative_distribution(self, global_step: tf.Tensor) -> tf.Tensor:
    # Proportional sampling does not depend on training progress.
    del global_step
    return self._proportional_cumulative
class AnnealingTaskSampler(TaskSampler):
  """Sample tasks according to task weights as well as training progress.

  The weight exponent `alpha` decays from 1.0 toward 0.2 across epochs, so
  sampling moves from weight-proportional toward near-uniform over training.
  """

  def __init__(self,
               task_weights: Dict[Text, Union[float, int]],
               steps_per_epoch: int,
               total_steps: int):
    super().__init__(task_weights=task_weights)
    self._steps_per_epoch = tf.cast(steps_per_epoch, dtype=tf.float32)
    self._total_epochs = tf.cast(
        total_steps / self._steps_per_epoch, dtype=tf.float32)

  def task_cumulative_distribution(self, global_step: tf.Tensor) -> tf.Tensor:
    # Current (0-based) epoch derived from the global step.
    cur_epoch = tf.math.floor(
        tf.cast(global_step, dtype=tf.float32) / self._steps_per_epoch)
    # Anneal the exponent; the 1e-10 guards against division by zero when
    # there is only a single epoch.
    alpha = 1.0 - 0.8 * (cur_epoch - 1) / (self._total_epochs - 1 + 1e-10)
    ordered_weights = [
        weight for _, weight in self._task_weights.items()
    ]
    task_sizes = tf.math.pow(
        tf.constant(ordered_weights, dtype=tf.float32),
        tf.cast(alpha, dtype=tf.float32))
    dynamic_task_distribution = task_sizes / tf.reduce_sum(task_sizes)
    return tf.math.cumsum(dynamic_task_distribution)
def get_task_sampler(config: configs.TaskSamplingConfig,
                     task_weights: Dict[Text, float]) -> TaskSampler:
  """Creates a task sampler from a one-of config and the task weights.

  Args:
    config: A `TaskSamplingConfig` whose `type` selects the strategy.
    task_weights: Mapping from task name to its sampling weight.

  Returns:
    The corresponding `TaskSampler` instance.

  Raises:
    RuntimeError: If `config.type` names an unsupported sampler.
  """
  oneof_config = config.get()
  if config.type == 'uniform':
    return UniformTaskSampler(task_weights=task_weights)
  elif config.type == 'proportional':
    return ProportionalTaskSampler(
        task_weights=task_weights, alpha=oneof_config.alpha)
  elif config.type == 'annealing':
    return AnnealingTaskSampler(
        task_weights=task_weights,
        steps_per_epoch=oneof_config.steps_per_epoch,
        total_steps=oneof_config.total_steps)
  else:
    raise RuntimeError('Task sampler type not supported')
official/modeling/multitask/task_sampler_test.py
0 → 100644
View file @
e773b9b3
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for multitask.task_sampler."""
import
tensorflow
as
tf
from
official.modeling.multitask
import
configs
from
official.modeling.multitask
import
task_sampler
as
sampler
class TaskSamplerTest(tf.test.TestCase):
  """Tests for the uniform, proportional and annealing task samplers."""

  def setUp(self):
    super().setUp()
    self._task_weights = {'A': 1.0, 'B': 2.0, 'C': 3.0}

  def test_uniform_sample_distribution(self):
    uniform_sampler = sampler.get_task_sampler(
        configs.TaskSamplingConfig(type='uniform'), self._task_weights)
    # The uniform distribution must not change with the step.
    for step in range(5):
      cumulative_distribution = uniform_sampler.task_cumulative_distribution(
          tf.constant(step, dtype=tf.int64))
      self.assertAllClose([0.333333, 0.666666, 1.0],
                          cumulative_distribution.numpy())

  def test_proportional_sample_distribution(self):
    prop_sampler = sampler.get_task_sampler(
        configs.TaskSamplingConfig(
            type='proportional',
            proportional=configs.ProportionalSampleConfig(alpha=2.0)),
        self._task_weights)
    # CumulativeOf(Normalize([1.0^2, 2.0^2, 3.0^2]))
    for step in range(5):
      cumulative_distribution = prop_sampler.task_cumulative_distribution(
          tf.constant(step, dtype=tf.int64))
      self.assertAllClose([0.07142857, 0.35714286, 1.0],
                          cumulative_distribution.numpy())

  def test_annealing_sample_distribution(self):
    num_epoch = 3
    step_per_epoch = 6
    annel_sampler = sampler.get_task_sampler(
        configs.TaskSamplingConfig(
            type='annealing',
            annealing=configs.AnnealingSampleConfig(
                steps_per_epoch=step_per_epoch,
                total_steps=step_per_epoch * num_epoch)),
        self._task_weights)
    global_step = tf.Variable(
        0, dtype=tf.int64, name='global_step', trainable=False)
    # Expected per-epoch cumulative distributions as annealing progresses.
    expected_cumulative_epochs = [[0.12056106, 0.4387236, 1.0],
                                  [0.16666667, 0.5, 1.0],
                                  [0.22477472, 0.5654695, 1.0]]
    for epoch in range(num_epoch):
      for _ in range(step_per_epoch):
        cumulative_distribution = annel_sampler.task_cumulative_distribution(
            tf.constant(global_step, dtype=tf.int64))
        global_step.assign_add(1)
        self.assertAllClose(expected_cumulative_epochs[epoch],
                            cumulative_distribution.numpy())
if __name__ == '__main__':
  # Run the test cases above via the TensorFlow test runner.
  tf.test.main()
official/modeling/multitask/test_utils.py
0 → 100644
View file @
e773b9b3
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Testing utils for mock models and tasks."""
from
typing
import
Dict
,
Text
import
tensorflow
as
tf
from
official.core
import
base_task
from
official.core
import
config_definitions
as
cfg
from
official.core
import
task_factory
from
official.modeling.multitask
import
base_model
class MockFooModel(tf.keras.Model):
  """A mock model can consume 'foo' and 'bar' inputs."""

  def __init__(self, shared_layer, *args, **kwargs):
    super().__init__(*args, **kwargs)
    # Shared with other mock models to emulate cross-task weight sharing.
    self._share_layer = shared_layer
    self._foo_specific_layer = tf.keras.layers.Dense(1)

  def call(self, inputs):
    # A zero auxiliary loss so tasks exercise the aux-loss code path.
    self.add_loss(tf.zeros((1,), dtype=tf.float32))
    if "foo" in inputs:
      input_tensor = inputs["foo"]
    else:
      input_tensor = inputs["bar"]
    return self._foo_specific_layer(self._share_layer(input_tensor))
class MockBarModel(tf.keras.Model):
  """A mock model that only consumes 'bar' inputs."""

  def __init__(self, shared_layer, *args, **kwargs):
    super().__init__(*args, **kwargs)
    # Shared with other mock models to emulate cross-task weight sharing.
    self._share_layer = shared_layer
    self._bar_specific_layer = tf.keras.layers.Dense(1)

  def call(self, inputs):
    # A zero auxiliary loss so tasks exercise the aux-loss code path.
    self.add_loss(tf.zeros((2,), dtype=tf.float32))
    return self._bar_specific_layer(self._share_layer(inputs["bar"]))
class MockMultiTaskModel(base_model.MultiTaskBaseModel):
  """A mock multi-task model with 'foo' and 'bar' sub tasks sharing a layer."""

  def __init__(self, *args, **kwargs):
    # The shared layer must exist before the base class instantiates the
    # sub tasks, which reference it.
    self._shared_dense = tf.keras.layers.Dense(1)
    super().__init__(*args, **kwargs)

  def _instantiate_sub_tasks(self) -> Dict[Text, tf.keras.Model]:
    return {
        "foo": MockFooModel(self._shared_dense),
        "bar": MockBarModel(self._shared_dense),
    }
def mock_data(feature_name):
  """Returns an infinite mock dataset of zero features and zero labels.

  Args:
    feature_name: Key under which the feature tensor is emitted.
  """

  def _generate_data(_):
    x = tf.zeros(shape=(2,), dtype=tf.float32)
    label = tf.zeros([1], dtype=tf.int32)
    return {feature_name: x}, label

  dataset = tf.data.Dataset.range(1)
  dataset = dataset.repeat()
  dataset = dataset.map(
      _generate_data, num_parallel_calls=tf.data.experimental.AUTOTUNE)
  return dataset.prefetch(buffer_size=1).batch(2, drop_remainder=True)
class FooConfig(cfg.TaskConfig):
  """Task config placeholder used to register `MockFooTask`."""
  pass
class BarConfig(cfg.TaskConfig):
  """Task config placeholder used to register `MockBarTask`."""
  pass
@task_factory.register_task_cls(FooConfig)
class MockFooTask(base_task.Task):
  """Mock foo task object for testing."""

  def build_metrics(self, training: bool = True):
    # Same metric regardless of mode.
    del training
    return [tf.keras.metrics.Accuracy(name="foo_acc")]

  def build_inputs(self, params):
    return mock_data("foo")

  def build_model(self) -> tf.keras.Model:
    return MockFooModel(shared_layer=tf.keras.layers.Dense(1))

  def build_losses(self, labels, model_outputs, aux_losses=None) -> tf.Tensor:
    """MSE loss plus any auxiliary losses, reduced to a scalar."""
    loss = tf.keras.losses.mean_squared_error(labels, model_outputs)
    if aux_losses:
      loss += tf.add_n(aux_losses)
    return tf.reduce_mean(loss)
@task_factory.register_task_cls(BarConfig)
class MockBarTask(base_task.Task):
  """Mock bar task object for testing."""

  def build_metrics(self, training: bool = True):
    # Same metric regardless of mode.
    del training
    return [tf.keras.metrics.Accuracy(name="bar_acc")]

  def build_inputs(self, params):
    return mock_data("bar")

  def build_losses(self, labels, model_outputs, aux_losses=None) -> tf.Tensor:
    """MSE loss plus any auxiliary losses, reduced to a scalar."""
    loss = tf.keras.losses.mean_squared_error(labels, model_outputs)
    if aux_losses:
      loss += tf.add_n(aux_losses)
    return tf.reduce_mean(loss)
official/modeling/multitask/train_lib.py
View file @
e773b9b3
...
@@ -21,9 +21,113 @@ import tensorflow as tf
...
@@ -21,9 +21,113 @@ import tensorflow as tf
from
official.core
import
base_task
from
official.core
import
base_task
from
official.core
import
base_trainer
as
core_lib
from
official.core
import
base_trainer
as
core_lib
from
official.core
import
train_utils
from
official.core
import
train_utils
from
official.modeling.multitask
import
base_model
from
official.modeling.multitask
import
base_trainer
from
official.modeling.multitask
import
configs
from
official.modeling.multitask
import
configs
from
official.modeling.multitask
import
evaluator
as
evaluator_lib
from
official.modeling.multitask
import
evaluator
as
evaluator_lib
from
official.modeling.multitask
import
interleaving_trainer
from
official.modeling.multitask
import
multitask
from
official.modeling.multitask
import
multitask
from
official.modeling.multitask
import
task_sampler
# Maps the `trainer_type` config value to the trainer class implementing it.
TRAINERS = {
    'interleaving': interleaving_trainer.MultiTaskInterleavingTrainer,
    'joint': base_trainer.MultiTaskBaseTrainer,
}
def run_experiment(*, distribution_strategy: tf.distribute.Strategy,
                   task: multitask.MultiTask,
                   model: base_model.MultiTaskBaseModel, mode: str,
                   params: configs.MultiTaskExperimentConfig,
                   model_dir: str) -> base_model.MultiTaskBaseModel:
  """Runs train/eval configured by the experiment params.

  Args:
    distribution_strategy: A distribution strategy under which the trainer,
      evaluator and optimizer are created.
    task: A MultiTask instance holding the tasks to train/evaluate.
    model: A MultiTaskBaseModel instance.
    mode: A 'str', specifying the mode. Can be 'train', 'eval',
      'train_and_eval' or 'continuous_eval'.
    params: MultiTaskExperimentConfig instance.
    model_dir: A 'str', a path to store model checkpoints and summaries.

  Returns:
    model: `base_model.MultiTaskBaseModel` instance.

  Raises:
    NotImplementedError: If `mode` is not one of the supported modes.
  """
  # 'train' and 'eval' are substrings of the composite modes
  # ('train_and_eval', 'continuous_eval'), so membership tests cover them.
  is_training = 'train' in mode
  is_eval = 'eval' in mode
  with distribution_strategy.scope():
    optimizer = task.create_optimizer(params.trainer.optimizer_config,
                                      params.runtime)
    kwargs = dict(multi_task=task, multi_task_model=model, optimizer=optimizer)
    if params.trainer.trainer_type == 'interleaving':
      # Interleaving trainers additionally need a sampler deciding which task
      # takes the next training step.
      sampler = task_sampler.get_task_sampler(params.trainer.task_sampler,
                                              task.task_weights)
      kwargs.update(dict(task_sampler=sampler))
    trainer = TRAINERS[params.trainer.trainer_type](
        **kwargs) if is_training else None
    if is_eval:
      evaluator = evaluator_lib.MultiTaskEvaluator(
          task=task,
          model=model,
          global_step=trainer.global_step if is_training else None)
    else:
      evaluator = None

  # Every supported mode creates at least one of trainer/evaluator; take the
  # checkpoint and step counter from whichever drives the run.
  if trainer:
    checkpoint = trainer.checkpoint
    global_step = trainer.global_step
  else:
    checkpoint = evaluator.checkpoint
    global_step = evaluator.global_step

  # TODO(hongkuny,haozhangthu): Revisit initialization method.
  checkpoint_manager = tf.train.CheckpointManager(
      checkpoint,
      directory=model_dir,
      max_to_keep=params.trainer.max_to_keep,
      step_counter=global_step,
      checkpoint_interval=params.trainer.checkpoint_interval,
      init_fn=model.initialize)

  controller = orbit.Controller(
      strategy=distribution_strategy,
      trainer=trainer,
      evaluator=evaluator,
      global_step=global_step,
      steps_per_loop=params.trainer.steps_per_loop,
      checkpoint_manager=checkpoint_manager,
      summary_dir=os.path.join(model_dir, 'train'),
      eval_summary_dir=os.path.join(model_dir, 'validation'),
      summary_interval=params.trainer.summary_interval)

  logging.info('Starts to execute mode: %s', mode)
  with distribution_strategy.scope():
    if mode == 'train':
      controller.train(steps=params.trainer.train_steps)
    elif mode == 'train_and_eval':
      controller.train_and_evaluate(
          train_steps=params.trainer.train_steps,
          eval_steps=params.trainer.validation_steps,
          eval_interval=params.trainer.validation_interval)
    elif mode == 'eval':
      controller.evaluate(steps=params.trainer.validation_steps)
    elif mode == 'continuous_eval':

      def timeout_fn():
        # Stop waiting for new checkpoints once training has finished.
        return evaluator.global_step.numpy() >= params.trainer.train_steps

      controller.evaluate_continuously(
          steps=params.trainer.validation_steps,
          timeout=params.trainer.continuous_eval_timeout,
          timeout_fn=timeout_fn)
    else:
      raise NotImplementedError('The mode is not implemented: %s' % mode)
  return model
def
run_experiment_with_multitask_eval
(
def
run_experiment_with_multitask_eval
(
...
...
official/modeling/multitask/train_lib_test.py
0 → 100644
View file @
e773b9b3
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for multitask.train_lib."""
from
absl.testing
import
parameterized
import
tensorflow
as
tf
from
tensorflow.python.distribute
import
combinations
from
tensorflow.python.distribute
import
strategy_combinations
from
official.core
import
task_factory
from
official.modeling.hyperparams
import
params_dict
from
official.modeling.multitask
import
configs
from
official.modeling.multitask
import
multitask
from
official.modeling.multitask
import
test_utils
from
official.modeling.multitask
import
train_lib
class TrainLibTest(tf.test.TestCase, parameterized.TestCase):
  """End-to-end tests for the multitask train_lib entry points."""

  def setUp(self):
    super().setUp()
    # Small step counts so each end-to-end run finishes quickly.
    self._test_config = {
        'trainer': {
            'checkpoint_interval': 10,
            'steps_per_loop': 10,
            'summary_interval': 10,
            'train_steps': 10,
            'validation_steps': 5,
            'validation_interval': 10,
            'continuous_eval_timeout': 1,
            'optimizer_config': {
                'optimizer': {
                    'type': 'sgd',
                },
                'learning_rate': {
                    'type': 'constant'
                }
            }
        },
    }

  @combinations.generate(
      combinations.combine(
          distribution_strategy=[
              strategy_combinations.default_strategy,
              strategy_combinations.cloud_tpu_strategy,
              strategy_combinations.one_device_strategy_gpu,
          ],
          mode='eager',
          flag_mode=['train', 'eval', 'train_and_eval']))
  def test_end_to_end(self, distribution_strategy, flag_mode):
    model_dir = self.get_temp_dir()
    routines = (
        configs.TaskRoutine(
            task_name='foo', task_config=test_utils.FooConfig()),
        configs.TaskRoutine(
            task_name='bar', task_config=test_utils.BarConfig()))
    experiment_config = configs.MultiTaskExperimentConfig(
        task=configs.MultiTaskConfig(task_routines=routines))
    experiment_config = params_dict.override_params_dict(
        experiment_config, self._test_config, is_strict=False)
    with distribution_strategy.scope():
      test_multitask = multitask.MultiTask.from_config(experiment_config.task)
      model = test_utils.MockMultiTaskModel()
    train_lib.run_experiment(
        distribution_strategy=distribution_strategy,
        task=test_multitask,
        model=model,
        mode=flag_mode,
        params=experiment_config,
        model_dir=model_dir)

  @combinations.generate(
      combinations.combine(
          distribution_strategy=[
              strategy_combinations.default_strategy,
              strategy_combinations.cloud_tpu_strategy,
              strategy_combinations.one_device_strategy_gpu,
          ],
          mode='eager',
          flag_mode=['train', 'eval', 'train_and_eval']))
  def test_end_to_end_multi_eval(self, distribution_strategy, flag_mode):
    model_dir = self.get_temp_dir()
    eval_routines = (
        configs.TaskRoutine(
            task_name='foo', task_config=test_utils.FooConfig()),
        configs.TaskRoutine(
            task_name='bar', task_config=test_utils.BarConfig()))
    experiment_config = configs.MultiEvalExperimentConfig(
        task=test_utils.FooConfig(),
        eval_tasks=configs.MultiTaskConfig(task_routines=eval_routines))
    experiment_config = params_dict.override_params_dict(
        experiment_config, self._test_config, is_strict=False)
    with distribution_strategy.scope():
      train_task = task_factory.get_task(experiment_config.task)
      eval_tasks = multitask.MultiTask.from_config(
          experiment_config.eval_tasks)
    train_lib.run_experiment_with_multitask_eval(
        distribution_strategy=distribution_strategy,
        train_task=train_task,
        eval_tasks=eval_tasks,
        mode=flag_mode,
        params=experiment_config,
        model_dir=model_dir)
# Standard TF test entry point: runs all test cases in this module.
if __name__ == '__main__':
  tf.test.main()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment