ModelZoo / ResNet50_tensorflow / Commits / 29b4a322

Commit 29b4a322, authored Jul 24, 2021 by Hongkun Yu, committed Jul 24, 2021 by A. Unique TensorFlower

Refactor multitask evaluator: consume a list of tasks and optional dictionary of eval steps.

PiperOrigin-RevId: 386654855
parent d3b705d2
Showing 8 changed files with 60 additions and 59 deletions:
official/core/config_definitions.py            +2  -1
official/modeling/multitask/configs.py         +2  -1
official/modeling/multitask/evaluator.py       +22 -21
official/modeling/multitask/evaluator_test.py  +3  -8
official/modeling/multitask/multitask.py       +3  -9
official/modeling/multitask/train_lib.py       +13 -7
official/modeling/multitask/train_lib_test.py  +11 -10
official/nlp/continuous_finetune_lib.py        +4  -2
official/core/config_definitions.py

@@ -239,9 +239,10 @@ class TrainerConfig(base_config.Config):
 @dataclasses.dataclass
 class TaskConfig(base_config.Config):
   init_checkpoint: str = ""
-  model: base_config.Config = None
+  model: Optional[base_config.Config] = None
   train_data: DataConfig = DataConfig()
   validation_data: DataConfig = DataConfig()
+  name: Optional[str] = None


 @dataclasses.dataclass
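The new optional name field lets a task identify itself on the config, which the MultiTask and task_factory changes below pick up when TaskRoutine.task_name is left empty. A minimal sketch of the field in use (a hedged illustration; the values are arbitrary and keyword construction follows the style used in the tests in this commit):

    from official.core import config_definitions as cfg

    # The task name can now live on the TaskConfig itself instead of only on
    # the surrounding TaskRoutine.
    foo_config = cfg.TaskConfig(name="foo", init_checkpoint="")
    assert foo_config.name == "foo"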
official/modeling/multitask/configs.py

@@ -23,6 +23,7 @@ from official.modeling import hyperparams
 @dataclasses.dataclass
 class TaskRoutine(hyperparams.Config):
+  # TODO(hongkuny): deprecate the task_name once we migrated client code.
   task_name: str = ""
   task_config: cfg.TaskConfig = None
   eval_steps: Optional[int] = None

@@ -76,4 +77,4 @@ class MultiEvalExperimentConfig(cfg.ExperimentConfig):
   Attributes:
     eval_tasks: individual evaluation tasks.
   """
-  eval_tasks: MultiTaskConfig = MultiTaskConfig()
+  eval_tasks: Tuple[TaskRoutine, ...] = ()
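eval_tasks on MultiEvalExperimentConfig is now a flat tuple of TaskRoutine entries rather than a nested MultiTaskConfig, and each routine can carry its own eval_steps. A hedged sketch of the new shape, mirroring the updated train_lib_test below and using bare cfg.TaskConfig placeholders for the per-task configs:

    from official.core import config_definitions as cfg
    from official.modeling.multitask import configs

    experiment = configs.MultiEvalExperimentConfig(
        eval_tasks=(
            # Each routine names a task and, optionally, how many eval steps to run.
            configs.TaskRoutine(
                task_name="foo", task_config=cfg.TaskConfig(), eval_steps=2),
            configs.TaskRoutine(
                task_name="bar", task_config=cfg.TaskConfig(), eval_steps=3),
        ))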
official/modeling/multitask/evaluator.py

@@ -16,14 +16,14 @@
 The evaluator implements the Orbit `AbstractEvaluator` interface.
 """
-from typing import Optional, Union
+from typing import Dict, List, Optional, Union

 import gin
 import orbit
 import tensorflow as tf

+from official.core import base_task
 from official.core import train_utils
 from official.modeling.multitask import base_model
-from official.modeling.multitask import multitask


 @gin.configurable
@@ -32,37 +32,39 @@ class MultiTaskEvaluator(orbit.AbstractEvaluator):

   def __init__(
       self,
-      task: multitask.MultiTask,
+      eval_tasks: List[base_task.Task],
       model: Union[tf.keras.Model, base_model.MultiTaskBaseModel],
       global_step: Optional[tf.Variable] = None,
+      eval_steps: Optional[Dict[str, int]] = None,
       checkpoint_exporter: Optional[train_utils.BestCheckpointExporter] = None):
     """Initialize common trainer for TensorFlow models.

     Args:
-      task: A multitask.MultiTask instance.
+      eval_tasks: A list of tasks to evaluate.
       model: tf.keras.Model instance.
       global_step: the global step variable.
+      eval_steps: a dictionary of steps to run eval keyed by task names.
       checkpoint_exporter: an object that has the `maybe_export_checkpoint`
         interface.
     """
     # Gets the current distribution strategy. If not inside any strategy scope,
     # it gets a single-replica no-op strategy.
     self._strategy = tf.distribute.get_strategy()
-    self._task = task
+    self._tasks = eval_tasks
     self._model = model
     self._global_step = global_step or orbit.utils.create_global_step()
     self._checkpoint_exporter = checkpoint_exporter
     self._checkpoint = tf.train.Checkpoint(
         global_step=self.global_step,
         model=self.model)
     self._validation_losses = None
     self._validation_metrics = None

     # Builds per-task datasets.
     self.eval_datasets = {}
-    for name, task in self.task.tasks.items():
-      self.eval_datasets[name] = orbit.utils.make_distributed_dataset(
+    self.eval_steps = eval_steps or {}
+    for task in self.tasks:
+      self.eval_datasets[task.name] = orbit.utils.make_distributed_dataset(
           self.strategy, task.build_inputs, task.task_config.validation_data)

     # Builds per-task validation loops.

@@ -89,8 +91,7 @@ class MultiTaskEvaluator(orbit.AbstractEvaluator):
       return orbit.utils.create_loop_fn(eval_step_fn)

     self.task_fns = {
-        name: get_function(name, task)
-        for name, task in self.task.tasks.items()
+        task.name: get_function(task.name, task)
+        for task in self.tasks
     }

   @property

@@ -98,8 +99,8 @@ class MultiTaskEvaluator(orbit.AbstractEvaluator):
     return self._strategy

   @property
-  def task(self):
-    return self._task
+  def tasks(self):
+    return self._tasks

   @property
   def model(self):

@@ -115,8 +116,8 @@ class MultiTaskEvaluator(orbit.AbstractEvaluator):
     if self._validation_losses is None:
       # Builds the per-task metrics and losses.
       self._validation_losses = {}
-      for name in self.task.tasks:
-        self._validation_losses[name] = tf.keras.metrics.Mean(
+      for task in self.tasks:
+        self._validation_losses[task.name] = tf.keras.metrics.Mean(
             "validation_loss", dtype=tf.float32)
     return self._validation_losses

@@ -126,8 +127,8 @@ class MultiTaskEvaluator(orbit.AbstractEvaluator):
     if self._validation_metrics is None:
       # Builds the per-task metrics and losses.
       self._validation_metrics = {}
-      for name, task in self.task.tasks.items():
-        self._validation_metrics[name] = task.build_metrics(training=False)
+      for task in self.tasks:
+        self._validation_metrics[task.name] = task.build_metrics(training=False)
     return self._validation_metrics

   @property

@@ -145,12 +146,12 @@ class MultiTaskEvaluator(orbit.AbstractEvaluator):
     results = {}
     eval_iters = tf.nest.map_structure(iter, self.eval_datasets)

-    for name, task_eval_loop in self.task_fns.items():
+    for task in self.tasks:
       outputs = None
+      name = task.name
       eval_iter = eval_iters[name]
-      task = self.task.tasks[name]
-      task_eval_steps = self.task.task_eval_steps(name) or num_steps
-      outputs = task_eval_loop(
+      task_eval_steps = self.eval_steps.get(name, None) or num_steps
+      outputs = self.task_fns[name](
           eval_iter,
           task_eval_steps,
           state=outputs,
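After the refactor the evaluator takes the task list directly, plus an optional per-task step budget, so it no longer needs a MultiTask wrapper. A construction sketch, assuming foo_task and bar_task are already-built base_task.Task instances and model is the shared Keras model (all three names are placeholders):

    import tensorflow as tf
    from official.modeling.multitask import evaluator as evaluator_lib

    multi_evaluator = evaluator_lib.MultiTaskEvaluator(
        eval_tasks=[foo_task, bar_task],   # a plain list of base_task.Task
        model=model,
        eval_steps={"foo": 2, "bar": 3})   # optional, keyed by task name

    # Tasks without an eval_steps entry run num_steps (here 5); tasks with an
    # entry run their configured count instead.
    results = multi_evaluator.evaluate(tf.convert_to_tensor(5, dtype=tf.int32))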
official/modeling/multitask/evaluator_test.py

@@ -22,7 +22,6 @@ from tensorflow.python.distribute import strategy_combinations
 from official.core import base_task
 from official.core import config_definitions as cfg
 from official.modeling.multitask import evaluator
-from official.modeling.multitask import multitask


 def all_strategy_combinations():

@@ -89,9 +88,7 @@ class MockTask(base_task.Task):
             np.concatenate([np.expand_dims(v.numpy(), axis=0) for v in value]))
     return state

   def reduce_aggregated_logs(self,
                              aggregated_logs,
                              global_step=None):
     for k, v in aggregated_logs.items():
       aggregated_logs[k] = np.sum(np.stack(v, axis=0))
     return aggregated_logs

@@ -106,10 +103,9 @@ class EvaluatorTest(tf.test.TestCase, parameterized.TestCase):
         MockTask(params=cfg.TaskConfig(), name="bar"),
         MockTask(params=cfg.TaskConfig(), name="foo")
     ]
-    test_multitask = multitask.MultiTask(tasks=tasks)
     model = MockModel()
     test_evaluator = evaluator.MultiTaskEvaluator(
-        task=test_multitask, model=model)
+        eval_tasks=tasks, model=model)
     results = test_evaluator.evaluate(tf.convert_to_tensor(1, dtype=tf.int32))
     self.assertContainsSubset(["validation_loss", "acc"], results["bar"].keys())
     self.assertContainsSubset(["validation_loss", "acc"], results["foo"].keys())

@@ -123,10 +119,9 @@ class EvaluatorTest(tf.test.TestCase, parameterized.TestCase):
         MockTask(params=cfg.TaskConfig(), name="bar"),
         MockTask(params=cfg.TaskConfig(), name="foo")
     ]
-    test_multitask = multitask.MultiTask(tasks=tasks)
     model = MockModel()
     test_evaluator = evaluator.MultiTaskEvaluator(
-        task=test_multitask, model=model)
+        eval_tasks=tasks, model=model)
     results = test_evaluator.evaluate(tf.convert_to_tensor(5, dtype=tf.int32))
     self.assertEqual(results["bar"]["counter"],
                      5. * distribution.num_replicas_in_sync)
official/modeling/multitask/multitask.py

@@ -59,10 +59,7 @@ class MultiTask(tf.Module, metaclass=abc.ABCMeta):
     else:
       raise ValueError("The tasks argument has an invalid type: %s" %
                        type(tasks))
-    self._task_eval_steps = task_eval_steps or {}
-    self._task_eval_steps = dict([
-        (name, self._task_eval_steps.get(name, None)) for name in self.tasks
-    ])
+    self.task_eval_steps = task_eval_steps or {}
     self._task_weights = task_weights or {}
     self._task_weights = dict([
         (name, self._task_weights.get(name, 1.0)) for name in self.tasks

@@ -74,9 +71,9 @@ class MultiTask(tf.Module, metaclass=abc.ABCMeta):
     task_eval_steps = {}
     task_weights = {}
     for task_routine in config.task_routines:
-      task_name = task_routine.task_name
+      task_name = task_routine.task_name or task_routine.task_config.name
       tasks[task_name] = task_factory.get_task(
-          task_routine.task_config, logging_dir=logging_dir)
+          task_routine.task_config, logging_dir=logging_dir, name=task_name)
       task_eval_steps[task_name] = task_routine.eval_steps
       task_weights[task_name] = task_routine.task_weight
     return cls(

@@ -86,9 +83,6 @@ class MultiTask(tf.Module, metaclass=abc.ABCMeta):
   def tasks(self):
     return self._tasks

-  def task_eval_steps(self, task_name):
-    return self._task_eval_steps[task_name]
-
   def task_weight(self, task_name):
     return self._task_weights[task_name]
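task_eval_steps on MultiTask changes from an accessor method into a plain dictionary attribute, which lets run_experiment in train_lib.py below hand the whole mapping to MultiTaskEvaluator. A rough before/after sketch, assuming task is a constructed MultiTask containing a task named "foo":

    # Before: per-task eval steps were read through a method.
    steps = task.task_eval_steps("foo")

    # After: the mapping is exposed directly and can be passed along wholesale,
    # e.g. MultiTaskEvaluator(..., eval_steps=task.task_eval_steps).
    steps = task.task_eval_steps.get("foo")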
official/modeling/multitask/train_lib.py

@@ -15,7 +15,7 @@
 """Multitask training driver library."""
 # pytype: disable=attribute-error
 import os
-from typing import Optional
+from typing import List, Optional

 from absl import logging
 import orbit
 import tensorflow as tf

@@ -69,9 +69,11 @@ def run_experiment(*, distribution_strategy: tf.distribute.Strategy,
   trainer = TRAINERS[params.trainer.trainer_type](
       **kwargs) if is_training else None
   if is_eval:
+    eval_steps = task.task_eval_steps
     evaluator = evaluator_lib.MultiTaskEvaluator(
-        task=task,
+        eval_tasks=task.tasks.values(),
         model=model,
+        eval_steps=eval_steps,
         global_step=trainer.global_step if is_training else None,
         checkpoint_exporter=train_utils.maybe_create_best_ckpt_exporter(
             params, model_dir))

@@ -137,7 +139,7 @@ def run_experiment_with_multitask_eval(
     *,
     distribution_strategy: tf.distribute.Strategy,
     train_task: base_task.Task,
-    eval_tasks: multitask.MultiTask,
+    eval_tasks: List[base_task.Task],
     mode: str,
     params: configs.MultiEvalExperimentConfig,
     model_dir: str,

@@ -149,7 +151,7 @@ def run_experiment_with_multitask_eval(
   Args:
     distribution_strategy: A distribution distribution_strategy.
     train_task: A base_task.Task instance.
-    eval_tasks: A multitask.MultiTask with evaluation tasks.
+    eval_tasks: A list of evaluation tasks.
     mode: A 'str', specifying the mode. Can be 'train', 'eval', 'train_and_eval'
       or 'continuous_eval'.
     params: MultiEvalExperimentConfig instance.

@@ -173,8 +175,8 @@ def run_experiment_with_multitask_eval(
         config=params,
         task=train_task,
         model=train_task.build_model(),
         optimizer=train_task.create_optimizer(
             params.trainer.optimizer_config,
             params.runtime),
         train=True,
         evaluate=False)
   else:

@@ -182,10 +184,14 @@ def run_experiment_with_multitask_eval(
   model = trainer.model if trainer else train_task.build_model()

   if is_eval:
+    eval_steps = dict([(task_routine.task_config.name, task_routine.eval_steps)
+                       for task_routine in params.eval_tasks])
     evaluator = evaluator_lib.MultiTaskEvaluator(
-        task=eval_tasks,
+        eval_tasks=eval_tasks,
         model=model,
         global_step=trainer.global_step if is_training else None,
+        eval_steps=eval_steps,
         checkpoint_exporter=train_utils.maybe_create_best_ckpt_exporter(
             params, model_dir))
   else:
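Callers of run_experiment_with_multitask_eval now build the evaluation task list themselves from the TaskRoutine tuple, typically through task_factory. A hedged sketch of the new calling pattern, following the updated test below (experiment_config, strategy, and model_dir are placeholders supplied by the surrounding program):

    from official.core import task_factory
    from official.modeling.multitask import train_lib

    # Materialize one Task per TaskRoutine in the experiment config.
    eval_tasks = [
        task_factory.get_task(routine.task_config, name=routine.task_name)
        for routine in experiment_config.eval_tasks
    ]
    train_lib.run_experiment_with_multitask_eval(
        distribution_strategy=strategy,
        train_task=task_factory.get_task(experiment_config.task),
        eval_tasks=eval_tasks,
        mode="train_and_eval",
        params=experiment_config,
        model_dir=model_dir)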
official/modeling/multitask/train_lib_test.py

@@ -65,8 +65,7 @@ class TrainLibTest(tf.test.TestCase, parameterized.TestCase):
         task=configs.MultiTaskConfig(
             task_routines=(
                 configs.TaskRoutine(
                     task_name='foo',
                     task_config=test_utils.FooConfig()),
                 configs.TaskRoutine(
                     task_name='bar', task_config=test_utils.BarConfig()))))
     experiment_config = params_dict.override_params_dict(

@@ -95,18 +94,20 @@ class TrainLibTest(tf.test.TestCase, parameterized.TestCase):
     model_dir = self.get_temp_dir()
     experiment_config = configs.MultiEvalExperimentConfig(
         task=test_utils.FooConfig(),
-        eval_tasks=configs.MultiTaskConfig(
-            task_routines=(
-                configs.TaskRoutine(
-                    task_name='foo',
-                    task_config=test_utils.FooConfig()),
-                configs.TaskRoutine(
-                    task_name='bar',
-                    task_config=test_utils.BarConfig()))))
+        eval_tasks=(configs.TaskRoutine(
+            task_name='foo', task_config=test_utils.FooConfig(), eval_steps=2),
+                    configs.TaskRoutine(
+                        task_name='bar',
+                        task_config=test_utils.BarConfig(),
+                        eval_steps=3)))
     experiment_config = params_dict.override_params_dict(
         experiment_config, self._test_config, is_strict=False)

     with distribution_strategy.scope():
       train_task = task_factory.get_task(experiment_config.task)
-      eval_tasks = multitask.MultiTask.from_config(experiment_config.eval_tasks)
+      eval_tasks = [
+          task_factory.get_task(config.task_config, name=config.task_name)
+          for config in experiment_config.eval_tasks
+      ]
       train_lib.run_experiment_with_multitask_eval(
           distribution_strategy=distribution_strategy,
           train_task=train_task,
official/nlp/continuous_finetune_lib.py

@@ -28,7 +28,6 @@ from official.core import train_lib
 from official.core import train_utils
 from official.modeling import performance
 from official.modeling.multitask import configs
-from official.modeling.multitask import multitask
 from official.modeling.multitask import train_lib as multitask_train_lib

@@ -167,7 +166,10 @@ def run_continuous_finetune(
     with distribution_strategy.scope():
       if isinstance(params, configs.MultiEvalExperimentConfig):
         task = task_factory.get_task(params_replaced.task)
-        eval_tasks = multitask.MultiTask.from_config(params_replaced.eval_tasks)
+        eval_tasks = [
+            task_factory.get_task(config.task_config, name=config.task_name)
+            for config in params.eval_tasks
+        ]
         (_,
          eval_metrics) = multitask_train_lib.run_experiment_with_multitask_eval(
              distribution_strategy=distribution_strategy,
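The same migration applies to other downstream callers: the multitask import disappears for evaluation-only flows, and tasks are materialized straight from the routine configs. A compressed before/after sketch grounded in the hunk above:

    # Before this commit
    from official.modeling.multitask import multitask
    eval_tasks = multitask.MultiTask.from_config(params_replaced.eval_tasks)

    # After this commit
    from official.core import task_factory
    eval_tasks = [
        task_factory.get_task(config.task_config, name=config.task_name)
        for config in params.eval_tasks
    ]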