ModelZoo / ResNet50_tensorflow

Commit b92025a9, authored Aug 18, 2021 by anivegesana

    Merge branch 'master' of https://github.com/tensorflow/models
    into detection_generator_pr_2

Parents: 1b425791, 37536370
Changes: 108 files in total; this page shows 20 changed files with 532 additions and 125 deletions (+532, -125).
Files in this page of the diff:

  official/core/actions.py                                                 +69   -3
  official/core/actions_test.py                                            +35   -9
  official/core/base_trainer.py                                             +1   -5
  official/core/base_trainer_test.py                                        +8  -56
  official/core/registry.py                                                 +7   -3
  official/core/train_lib.py                                                +2   -3
  official/core/train_lib_test.py                                          +89   -1
  official/modeling/fast_training/experimental/tf2_utils_2x_wide.py       +186   -0
  official/modeling/fast_training/experimental/tf2_utils_2x_wide_test.py  +101   -0
  official/modeling/fast_training/progressive/policies.py                  +11   -2
  official/modeling/fast_training/progressive/train.py                      +1   -1
  official/modeling/fast_training/progressive/train_lib.py                  +1   -1
  official/modeling/fast_training/progressive/train_lib_test.py             +3   -3
  official/modeling/fast_training/progressive/trainer.py                    +4   -4
  official/modeling/fast_training/progressive/trainer_test.py               +2   -2
  official/modeling/fast_training/progressive/utils.py                      +0   -0
  official/modeling/multitask/base_model.py                                 +0  -15
  official/modeling/multitask/base_trainer.py                               +0  -15
  official/modeling/multitask/evaluator.py                                   +8   -1
  official/modeling/multitask/task_sampler.py                               +4   -1
official/core/actions.py (+69, -3):

@@ -16,6 +16,7 @@
 import os
 from typing import List

+from absl import logging
 import gin
 import orbit

@@ -119,6 +120,58 @@ class EMACheckpointing:
     self._optimizer.swap_weights()


+class RecoveryAction:
+  """Train action to recover from loss blowup.
+
+  Checks the loss value by the given threshold. If applicable, recover the
+  model by reading the checkpoint on disk.
+  """
+
+  def __init__(self, checkpoint_manager: tf.train.CheckpointManager):
+    self.checkpoint_manager = checkpoint_manager
+
+  def __call__(self, _):
+    """Recovers the training by triggering checkpoint restoration."""
+    # Loads the previous good checkpoint.
+    checkpoint_path = self.checkpoint_manager.restore_or_initialize()
+    logging.warning('Recovering the model from checkpoint: %s.',
+                    checkpoint_path)
+
+
+class RecoveryCondition:
+  """Recovery Condition."""
+
+  def __init__(self,
+               global_step: tf.Variable,
+               loss_upper_bound: float,
+               recovery_begin_steps: int = 0,
+               recovery_max_trials: int = 3):
+    self.recover_counter = 0
+    self.recovery_begin_steps = recovery_begin_steps
+    self.recovery_max_trials = recovery_max_trials
+    self.loss_upper_bound = loss_upper_bound
+    self.global_step = global_step
+
+  def __call__(self, outputs: orbit.runner.Output):
+    loss_value = outputs['training_loss']
+    if tf.math.is_nan(loss_value):
+      self.recover_counter += 1
+      if self.recover_counter > self.recovery_max_trials:
+        raise RuntimeError(
+            'The loss value is NaN after training loop and it happens %d times.'
+            % self.recover_counter)
+      return True
+    if (self.global_step >= self.recovery_begin_steps and
+        loss_value > self.loss_upper_bound):
+      self.recover_counter += 1
+      if self.recover_counter > self.recovery_max_trials:
+        raise RuntimeError(
+            f'The loss value is {loss_value}, which is larger than the bound '
+            f'{self.loss_upper_bound}, happens {self.recover_counter} times.')
+      return True
+    return False
+
+
 @gin.configurable
 def get_eval_actions(params: config_definitions.ExperimentConfig,

@@ -140,9 +193,10 @@ def get_eval_actions(
 @gin.configurable
 def get_train_actions(params: config_definitions.ExperimentConfig,
                       trainer: base_trainer.Trainer,
-                      model_dir: str) -> List[orbit.Action]:
+                      model_dir: str,
+                      checkpoint_manager: tf.train.CheckpointManager) -> List[orbit.Action]:
   """Gets train actions for TFM trainer."""
   train_actions = []
   # Adds pruning callback actions.

@@ -153,4 +207,16 @@ def get_train_actions(params: config_definitions.ExperimentConfig,
           model=trainer.model,
           optimizer=trainer.optimizer))

+  if params.trainer.recovery_max_trials >= 0:
+    recovery_condition = RecoveryCondition(
+        global_step=trainer.global_step,
+        loss_upper_bound=params.trainer.loss_upper_bound,
+        recovery_begin_steps=params.trainer.recovery_begin_steps,
+        recovery_max_trials=params.trainer.recovery_max_trials,
+    )
+    recover_action = orbit.actions.ConditionalAction(
+        condition=recovery_condition,
+        action=RecoveryAction(checkpoint_manager),
+    )
+    train_actions.append(recover_action)
+
   return train_actions
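Taken together, the new classes compose into an Orbit train action: `RecoveryCondition` flags a bad training step and `RecoveryAction` rolls the model back to the last good checkpoint. A minimal sketch of the condition's behavior on its own, using a plain outputs dict as the trainer would produce (the threshold values here are illustrative, not from this commit):

import orbit
import tensorflow as tf

from official.core import actions

global_step = orbit.utils.create_global_step()
condition = actions.RecoveryCondition(
    global_step, loss_upper_bound=1.0, recovery_begin_steps=100,
    recovery_max_trials=3)

# Below the bound (or before recovery_begin_steps): no recovery triggers.
assert not condition({'training_loss': tf.constant(0.3)})

# Past recovery_begin_steps, a blown-up loss returns True, which makes the
# ConditionalAction invoke RecoveryAction and restore the checkpoint; after
# recovery_max_trials consecutive blowups it raises RuntimeError instead.
global_step.assign(200)
assert condition({'training_loss': tf.constant(7.5)})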
official/core/actions_test.py (+35, -9):

@@ -17,6 +17,8 @@
 import os

 from absl.testing import parameterized
+import numpy as np
+import orbit
 import tensorflow as tf

 from tensorflow.python.distribute import combinations

@@ -35,17 +37,14 @@ class TestModel(tf.Module):
     return self.value


-def all_strategy_combinations():
-  return combinations.combine(
-      distribution=[
-          strategy_combinations.cloud_tpu_strategy,
-          strategy_combinations.one_device_strategy_gpu,
-      ],)
-
-
 class ActionsTest(tf.test.TestCase, parameterized.TestCase):

-  @combinations.generate(all_strategy_combinations())
+  @combinations.generate(
+      combinations.combine(
+          distribution=[
+              strategy_combinations.cloud_tpu_strategy,
+              strategy_combinations.one_device_strategy_gpu,
+          ],))
   def test_ema_checkpointing(self, distribution):
     with distribution.scope():
       directory = self.create_tempdir()

@@ -76,6 +75,33 @@ class ActionsTest(tf.test.TestCase, parameterized.TestCase):
       # Checks model.value is 0 after swapping.
       self.assertEqual(model(), 0)

+  @combinations.generate(
+      combinations.combine(
+          distribution=[
+              strategy_combinations.default_strategy,
+              strategy_combinations.cloud_tpu_strategy,
+              strategy_combinations.one_device_strategy_gpu,
+          ],))
+  def test_recovery_condition(self, distribution):
+    with distribution.scope():
+      global_step = orbit.utils.create_global_step()
+      recover_condition = actions.RecoveryCondition(
+          global_step, loss_upper_bound=0.5, recovery_max_trials=2)
+      outputs = {'training_loss': 0.6}
+      self.assertTrue(recover_condition(outputs))
+      self.assertTrue(recover_condition(outputs))
+      with self.assertRaises(RuntimeError):
+        recover_condition(outputs)
+
+      global_step = orbit.utils.create_global_step()
+      recover_condition = actions.RecoveryCondition(
+          global_step, loss_upper_bound=0.5, recovery_max_trials=2)
+      outputs = {'training_loss': tf.constant([np.nan], tf.float32)}
+      self.assertTrue(recover_condition(outputs))
+      self.assertTrue(recover_condition(outputs))
+      with self.assertRaises(RuntimeError):
+        recover_condition(outputs)
+

 if __name__ == '__main__':
   tf.test.main()
official/core/base_trainer.py (+1, -5):

@@ -370,6 +370,7 @@ class Trainer(_AsyncTrainer):
     """Accesses the training checkpoint."""
     return self._checkpoint

+  # TODO(yejiayu): Remove this once all deps are fixed.
   def add_recovery(self, params: TrainerConfig,
                    checkpoint_manager: tf.train.CheckpointManager):
     if params.recovery_max_trials >= 0:

@@ -382,11 +383,6 @@ class Trainer(_AsyncTrainer):
   def train_loop_end(self):
     """See base class."""
     self.join()
-    # Checks if the model numeric status is stable and conducts the checkpoint
-    # recovery accordingly.
-    if self._recovery:
-      self._recovery.maybe_recover(self.train_loss.result().numpy(),
-                                   self.global_step.numpy())
     logs = {}
     for metric in self.train_metrics + [self.train_loss]:
       logs[metric.name] = metric.result()
official/core/base_trainer_test.py (+8, -56):

@@ -14,12 +14,12 @@
 """Tests for tensorflow_models.core.trainers.trainer."""
 # pylint: disable=g-direct-tensorflow-import
+import gc
 import multiprocessing
 import os
 import sys

 from absl.testing import parameterized
-import numpy as np
 import orbit
 import portpicker
 import tensorflow as tf

@@ -165,6 +165,13 @@ class TrainerTest(tf.test.TestCase, parameterized.TestCase):
         }
     })))

+  def tearDown(self):
+    gc.collect()
+    # This will only contain uncollectable garbage, i.e. reference cycles
+    # involving objects with __del__ defined.
+    self.assertEmpty(gc.garbage)
+    super().tearDown()
+
   def create_test_trainer(self, config, model_dir=None, task=None):
     task = task or mock_task.MockTask(config.task, logging_dir=model_dir)
     ckpt_exporter = train_lib.maybe_create_best_ckpt_exporter(config, model_dir)

@@ -337,61 +344,6 @@ class TrainerTest(tf.test.TestCase, parameterized.TestCase):
     self.assertTrue(
         tf.io.gfile.exists(os.path.join(model_dir, 'best_ckpt', 'info.json')))

-  def test_recovery(self):
-    config = cfg.ExperimentConfig(
-        trainer=cfg.TrainerConfig(
-            loss_upper_bound=0.5,
-            recovery_max_trials=2,
-            optimizer_config=cfg.OptimizationConfig({
-                'optimizer': {'type': 'sgd'},
-                'learning_rate': {'type': 'constant'}
-            })))
-    model_dir = self.get_temp_dir()
-    trainer = self.create_test_trainer(config, model_dir=model_dir)
-    checkpoint_manager = tf.train.CheckpointManager(
-        trainer.checkpoint, self.get_temp_dir(), max_to_keep=2)
-    checkpoint_manager.save()
-    trainer.add_recovery(config.trainer, checkpoint_manager=checkpoint_manager)
-    before_weights = trainer.model.get_weights()
-    _ = trainer.train(tf.convert_to_tensor(1, dtype=tf.int32))
-    # The training loss is 1.0 and upper_bound is 0.5, so the recovery happens.
-    after_weights = trainer.model.get_weights()
-    for left, right in zip(before_weights, after_weights):
-      self.assertAllEqual(left, right)
-
-    # Let the loss be NaN and max_trials = 0 to see RuntimeError.
-    config = cfg.ExperimentConfig(
-        trainer=cfg.TrainerConfig(
-            recovery_max_trials=0,
-            optimizer_config=cfg.OptimizationConfig({
-                'optimizer': {'type': 'sgd'},
-                'learning_rate': {'type': 'constant'}
-            })))
-    task = mock_task.MockTask(config.task, logging_dir=model_dir)
-
-    def build_losses(labels, model_outputs, aux_losses=None):
-      del labels, model_outputs
-      return tf.constant([np.nan], tf.float32) + aux_losses
-
-    task.build_losses = build_losses
-    trainer = trainer_lib.Trainer(
-        config,
-        task,
-        model=task.build_model(),
-        optimizer=task.create_optimizer(config.trainer.optimizer_config,
-                                        config.runtime))
-    trainer.add_recovery(config.trainer, checkpoint_manager=checkpoint_manager)
-    with self.assertRaises(RuntimeError):
-      _ = trainer.train(tf.convert_to_tensor(2, dtype=tf.int32))
-
   def test_model_with_compiled_loss(self):
     task = mock_task.MockTask()
     model = task.build_model()
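The new `tearDown` is a leak check rather than a test of trainer behavior. A self-contained sketch of the mechanism it relies on (using `gc.DEBUG_SAVEALL` only to make collected cycles visible; on modern Python, cycles with `__del__` are collectable per PEP 442, so `assertEmpty(gc.garbage)` mainly guards against uncollectable objects, e.g. from extension modules):

import gc

class Node:
  def __init__(self):
    self.ref = self  # reference cycle: refcount alone can't free this

gc.collect()
gc.set_debug(gc.DEBUG_SAVEALL)  # keep collected objects for inspection
Node()                          # immediately unreachable, but cyclic
unreachable = gc.collect()      # the cycle collector finds it
print(unreachable, len(gc.garbage))  # the Node now sits in gc.garbage
gc.set_debug(0)
gc.garbage.clear()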
official/core/registry.py (+7, -3):

@@ -87,11 +87,15 @@ def lookup(registered_collection, reg_key):
     for h_idx, entry_name in enumerate(hierarchy):
       if entry_name not in collection:
         raise LookupError(
-            "collection path {} at position {} never registered.".format(
-                entry_name, h_idx))
+            f"collection path {entry_name} at position {h_idx} is never "
+            f"registered. Please make sure the {entry_name} and its library is "
+            "imported and linked to the trainer binary.")
       collection = collection[entry_name]
     return collection
   else:
     if reg_key not in registered_collection:
-      raise LookupError("registration key {} never registered.".format(reg_key))
+      raise LookupError(
+          f"registration key {reg_key} is never "
+          f"registered. Please make sure the {reg_key} and its library is "
+          "imported and linked to the trainer binary.")
     return registered_collection[reg_key]
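For context, a hedged sketch of how this `lookup` pairs with the module's `register(collection, key)` decorator (the collection and key names here are illustrative, not from this commit):

from official.core import registry

_MODELS = {}

# A '/'-separated key builds the nested collection path that lookup()
# walks entry by entry above.
@registry.register(_MODELS, 'vision/resnet')
class ResNet:
  pass

model_cls = registry.lookup(_MODELS, 'vision/resnet')  # -> ResNet

# An unregistered key now produces the more actionable LookupError added
# in this commit, reminding you to import the module that performs the
# registration (a common failure mode in modular trainer binaries).
try:
  registry.lookup(_MODELS, 'vision/unknown')
except LookupError as e:
  print(e)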
official/core/train_lib.py (+2, -3):

@@ -87,8 +87,6 @@ def run_experiment(
         step_counter=trainer.global_step,
         checkpoint_interval=params.trainer.checkpoint_interval,
         init_fn=trainer.initialize)
-    # Adds recovery handling.
-    trainer.add_recovery(params.trainer, checkpoint_manager=checkpoint_manager)
   else:
     checkpoint_manager = None

@@ -105,7 +103,8 @@ def run_experiment(
       (save_summary) else None,
       summary_interval=params.trainer.summary_interval if
       (save_summary) else None,
-      train_actions=actions.get_train_actions(params, trainer, model_dir),
+      train_actions=actions.get_train_actions(
+          params, trainer, model_dir, checkpoint_manager=checkpoint_manager),
       eval_actions=actions.get_eval_actions(params, trainer, model_dir))

   logging.info('Starts to execute mode: %s', mode)
official/core/train_lib_test.py (+89, -1):

@@ -19,6 +19,7 @@ import os
 from absl import flags
 from absl.testing import flagsaver
 from absl.testing import parameterized
+import numpy as np
 import tensorflow as tf

 from tensorflow.python.distribute import combinations

@@ -30,6 +31,7 @@ from official.common import registry_imports
 from official.core import task_factory
 from official.core import train_lib
 from official.core import train_utils
+from official.utils.testing import mock_task

 FLAGS = flags.FLAGS

@@ -114,7 +116,93 @@ class TrainTest(tf.test.TestCase, parameterized.TestCase):
         params=params,
         model_dir=model_dir,
         run_post_eval=run_post_eval)

-    print(logs)
+  @combinations.generate(
+      combinations.combine(
+          distribution_strategy=[
+              strategy_combinations.default_strategy,
+              strategy_combinations.cloud_tpu_strategy,
+              strategy_combinations.one_device_strategy_gpu,
+          ],
+          flag_mode=['train', 'train_and_eval'],
+      ))
+  def test_recovery_nan_error(self, distribution_strategy, flag_mode):
+    model_dir = self.get_temp_dir()
+    flags_dict = dict(
+        experiment='mock',
+        mode=flag_mode,
+        model_dir=model_dir,
+        params_override=json.dumps(self._test_config))
+    with flagsaver.flagsaver(**flags_dict):
+      params = train_utils.parse_configuration(flags.FLAGS)
+      train_utils.serialize_config(params, model_dir)
+      with distribution_strategy.scope():
+        # task = task_factory.get_task(params.task, logging_dir=model_dir)
+        task = mock_task.MockTask(params.task, logging_dir=model_dir)
+
+        # Set the loss to NaN to trigger RuntimeError.
+        def build_losses(labels, model_outputs, aux_losses=None):
+          del labels, model_outputs
+          return tf.constant([np.nan], tf.float32) + aux_losses
+
+        task.build_losses = build_losses
+
+      with self.assertRaises(RuntimeError):
+        train_lib.run_experiment(
+            distribution_strategy=distribution_strategy,
+            task=task,
+            mode=flag_mode,
+            params=params,
+            model_dir=model_dir)
+
+  @combinations.generate(
+      combinations.combine(
+          distribution_strategy=[
+              strategy_combinations.default_strategy,
+              strategy_combinations.cloud_tpu_strategy,
+              strategy_combinations.one_device_strategy_gpu,
+          ],
+          flag_mode=['train'],
+      ))
+  def test_recovery(self, distribution_strategy, flag_mode):
+    loss_threshold = 1.0
+    model_dir = self.get_temp_dir()
+    flags_dict = dict(
+        experiment='mock',
+        mode=flag_mode,
+        model_dir=model_dir,
+        params_override=json.dumps(self._test_config))
+    with flagsaver.flagsaver(**flags_dict):
+      params = train_utils.parse_configuration(flags.FLAGS)
+      params.trainer.loss_upper_bound = loss_threshold
+      params.trainer.recovery_max_trials = 1
+      train_utils.serialize_config(params, model_dir)
+      with distribution_strategy.scope():
+        task = task_factory.get_task(params.task, logging_dir=model_dir)
+
+      # Saves a checkpoint for reference.
+      model = task.build_model()
+      checkpoint = tf.train.Checkpoint(model=model)
+      checkpoint_manager = tf.train.CheckpointManager(
+          checkpoint, self.get_temp_dir(), max_to_keep=2)
+      checkpoint_manager.save()
+      before_weights = model.get_weights()
+
+      def build_losses(labels, model_outputs, aux_losses=None):
+        del labels, model_outputs
+        return tf.constant([loss_threshold], tf.float32) + aux_losses
+
+      task.build_losses = build_losses
+
+      model, _ = train_lib.run_experiment(
+          distribution_strategy=distribution_strategy,
+          task=task,
+          mode=flag_mode,
+          params=params,
+          model_dir=model_dir)
+      after_weights = model.get_weights()
+      for left, right in zip(before_weights, after_weights):
+        self.assertAllEqual(left, right)
+

   def test_parse_configuration(self):
     model_dir = self.get_temp_dir()
official/modeling/fast_training/experimental/tf2_utils_2x_wide.py (new file, +186, -0):

# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Stacking model horizontally."""

from absl import logging
import numpy as np
import tensorflow as tf


def expand_vector(v: np.ndarray) -> np.ndarray:
  """Expands a vector with batch dimensions.

  Equivalent to expand_1_axis(v, epsilon=0.0, axis=-1).

  Args:
    v: A vector with shape [..., a].

  Returns:
    A vector with shape [..., 2 * a].
  """
  return np.repeat(v, 2, axis=-1)


def expand_1_axis(w: np.ndarray, epsilon: float, axis: int) -> np.ndarray:
  """Expands either the first dimension or the last dimension of w.

  If `axis = 0`, the following constraint will be satisfied:
    matmul(x, w) ==
        matmul(expand_vector(x), expand_1_axis(w, epsilon=0.1, axis=0))

  If `axis = -1`, the following constraint will be satisfied if `epsilon = 0.0`:
    expand_vector(matmul(x, w)) ==
        2 * matmul(x, expand_1_axis(w, epsilon=0.0, axis=-1))

  Args:
    w: Numpy array of shape [a_0, a_1, ..., a_i-1, a_i].
    epsilon: Symmetric noise added to the expanded tensor.
    axis: Must be either 0 or -1.

  Returns:
    Expanded numpy array.
  """
  assert axis in (0, -1), (
      "Only support expanding the first or the last dimension. "
      "Got: {}".format(axis))

  rank = len(w.shape)

  d_w = np.random.normal(np.zeros_like(w), np.fabs(w) * epsilon, w.shape)
  d_w = np.repeat(d_w, 2, axis=axis)

  sign_flip = np.array([1, -1])
  for _ in range(rank - 1):
    sign_flip = np.expand_dims(sign_flip, axis=-1 if axis == 0 else 0)
  sign_flip = np.tile(sign_flip,
                      [w.shape[0]] + [1] * (rank - 2) + [w.shape[-1]])

  d_w *= sign_flip
  w_expand = (np.repeat(w, 2, axis=axis) + d_w) / 2
  return w_expand


def expand_2_axes(w: np.ndarray, epsilon: float) -> np.ndarray:
  """Expands the first dimension and the last dimension of w.

  The following constraint will be satisfied:
    expand_vector(matmul(x, w)) == matmul(expand_vector(x), expand_2_axes(w))

  Args:
    w: Numpy array of shape [a_0, a_1, ..., a_i-1, a_i].
    epsilon: Symmetric noise added to the expanded tensor.

  Returns:
    Expanded numpy array.
  """
  rank = len(w.shape)

  d_w = np.random.normal(np.zeros_like(w), np.fabs(w) * epsilon, w.shape)
  d_w = np.repeat(np.repeat(d_w, 2, axis=0), 2, axis=-1)

  sign_flip = np.array([1, -1])
  for _ in range(rank - 1):
    sign_flip = np.expand_dims(sign_flip, axis=-1)
  sign_flip = np.tile(sign_flip,
                      [w.shape[0]] + [1] * (rank - 2) + [w.shape[-1] * 2])
  d_w *= sign_flip

  w_expand = (np.repeat(np.repeat(w, 2, axis=0), 2, axis=-1) + d_w) / 2
  return w_expand


def var_to_var(var_from: tf.Variable, var_to: tf.Variable, epsilon: float):
  """Expands a variable to another variable.

  Assume the shape of `var_from` is (a, b, ..., y, z), the shape of `var_to`
  can be (a, ..., z * 2), (a * 2, ..., z * 2), or (a * 2, ..., z).

  If the shape of `var_to` is (a, ..., 2 * z):
    For any x, tf.matmul(x, var_to) ~= expand_vector(tf.matmul(x, var_from)) / 2
    Note that there will be noise added to the left-hand side if epsilon != 0.

  If the shape of `var_to` is (2 * a, ..., z):
    For any x, tf.matmul(expand_vector(x), var_to) == tf.matmul(x, var_from)

  If the shape of `var_to` is (2 * a, ..., 2 * z):
    For any x, tf.matmul(expand_vector(x), var_to) ==
        expand_vector(tf.matmul(expand_vector(x), var_from))

  Args:
    var_from: input variable to expand.
    var_to: output variable.
    epsilon: the noise ratio that will be added when splitting `var_from`.
  """
  shape_from = var_from.shape
  shape_to = var_to.shape

  if shape_from == shape_to:
    var_to.assign(var_from)

  elif len(shape_from) == 1 and len(shape_to) == 1:
    var_to.assign(expand_vector(var_from.numpy()))

  elif shape_from[0] * 2 == shape_to[0] and shape_from[-1] == shape_to[-1]:
    var_to.assign(expand_1_axis(var_from.numpy(), epsilon=epsilon, axis=0))

  elif shape_from[0] == shape_to[0] and shape_from[-1] * 2 == shape_to[-1]:
    var_to.assign(expand_1_axis(var_from.numpy(), epsilon=epsilon, axis=-1))

  elif shape_from[0] * 2 == shape_to[0] and shape_from[-1] * 2 == shape_to[-1]:
    var_to.assign(expand_2_axes(var_from.numpy(), epsilon=epsilon))

  else:
    raise ValueError("Shape not supported, {}, {}".format(shape_from, shape_to))


def model_to_model_2x_wide(model_from: tf.Module,
                           model_to: tf.Module,
                           epsilon: float = 0.1):
  """Expands a model to a wider version.

  Also makes sure that the output of the model is not changed after expanding.
  For example:
    model_narrow = tf.keras.Sequential()
    model_narrow.add(tf.keras.Input(shape=(3,)))
    model_narrow.add(tf.keras.layers.Dense(4))
    model_narrow.add(tf.keras.layers.Dense(1))

    model_wide = tf.keras.Sequential()
    model_wide.add(tf.keras.Input(shape=(6,)))
    model_wide.add(tf.keras.layers.Dense(8))
    model_wide.add(tf.keras.layers.Dense(1))

    model_to_model_2x_wide(model_narrow, model_wide)
    assert model_narrow([[1, 2, 3]]) == model_wide([[1, 1, 2, 2, 3, 3]])

  We assume that `model_from` and `model_to` have the same architecture and
  only their widths differ.

  Args:
    model_from: input model to expand.
    model_to: output model whose variables will be assigned expanded values
      according to `model_from`.
    epsilon: the noise ratio that will be added when splitting `var_from`.
  """
  for w_from, w_to in zip(model_from.trainable_variables,
                          model_to.trainable_variables):
    logging.info("expanding %s %s to %s %s",
                 w_from.name, w_from.shape, w_to.name, w_to.shape)
    var_to_var(w_from, w_to, epsilon=epsilon)
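The trick in both expand functions is the `sign_flip` mask: the sampled noise `d_w` is duplicated along the expanded axis with opposite signs, so each new pair of rows or columns carries `(w + d)/2` and `(w - d)/2`, and the pair sums back to the original contribution exactly. A quick numeric check of the `axis=0` invariant stated in the docstring (function-preserving under input duplication):

import numpy as np
from official.modeling.fast_training.experimental import tf2_utils_2x_wide

x = np.random.rand(3)     # input of shape [3]
w = np.random.rand(3, 5)  # weight of shape [3, 5]
w2 = tf2_utils_2x_wide.expand_1_axis(w, epsilon=0.1, axis=0)  # shape [6, 5]

x2 = tf2_utils_2x_wide.expand_vector(x)  # duplicate entries, shape [6]
# Each duplicated input entry sees (w + d)/2 and (w - d)/2; the noise
# cancels pairwise, so the wider matmul reproduces the original exactly.
np.testing.assert_allclose(x @ w, x2 @ w2, rtol=1e-10)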
official/modeling/fast_training/experimental/tf2_utils_2x_wide_test.py (new file, +101, -0):

# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for tf2_utils_2x_wide."""

import numpy as np
import tensorflow as tf

from official.modeling.fast_training.experimental import tf2_utils_2x_wide


class Tf2Utils2XWideTest(tf.test.TestCase):

  def test_expand_vector(self):
    x = np.array([1, 2])
    self.assertAllClose(tf2_utils_2x_wide.expand_vector(x),
                        np.array([1, 1, 2, 2]))

  def test_expand_matrix(self):
    x = np.array([[1, 2], [3, 4]])
    x = tf2_utils_2x_wide.expand_2_axes(x, epsilon=0.1)
    self.assertAllClose(x[0, :] + x[1, :], np.array([1, 1, 2, 2]))
    self.assertAllClose(x[2, :] + x[3, :], np.array([3, 3, 4, 4]))

  def test_expand_matrix_axis_0(self):
    x = np.array([[1, 2], [3, 4]])
    x = tf2_utils_2x_wide.expand_1_axis(x, axis=0, epsilon=0.1)
    self.assertAllClose(x[0, :] + x[1, :], np.array([1, 2]))
    self.assertAllClose(x[2, :] + x[3, :], np.array([3, 4]))

  def test_expand_matrix_axis_1(self):
    x = np.array([[1, 2], [3, 4]])
    x = tf2_utils_2x_wide.expand_1_axis(x, axis=-1, epsilon=0.1)
    self.assertAllClose(x[:, 0] + x[:, 1], np.array([1, 3]))
    self.assertAllClose(x[:, 2] + x[:, 3], np.array([2, 4]))

  def test_expand_3d_tensor(self):
    x0 = np.array([10, 11])
    x1 = np.array([10, 10, 11, 11])
    w0 = np.random.rand(2, 2)
    w1 = tf2_utils_2x_wide.expand_2_axes(w0, epsilon=0.1)
    o0 = np.matmul(x0, w0)
    o1 = np.matmul(x1, w1)
    self.assertAllClose(np.repeat(o0, 2, axis=-1), o1)

  def test_expand_3d_tensor_axis_0(self):
    x0 = np.array([10, 11])
    x1 = np.array([10, 10, 11, 11])
    w0 = np.random.rand(2, 2)
    w1 = tf2_utils_2x_wide.expand_1_axis(w0, axis=0, epsilon=0.1)
    o0 = np.matmul(x0, w0)
    o1 = np.matmul(x1, w1)
    self.assertAllClose(o0, o1)

  def test_expand_3d_tensor_axis_2(self):
    x = np.array([10, 11])
    w0 = np.random.rand(2, 2)
    w1 = tf2_utils_2x_wide.expand_1_axis(w0, axis=-1, epsilon=0.1)
    o0 = np.matmul(x, w0)
    o1 = np.matmul(x, w1)
    self.assertAllClose(o0, np.sum(o1.reshape(2, 2), axis=-1))

  def test_end_to_end(self):
    """Covers expand_vector, expand_2_axes, and expand_1_axis."""
    model_narrow = tf.keras.Sequential()
    model_narrow.add(tf.keras.Input(shape=(3,)))
    model_narrow.add(tf.keras.layers.Dense(4))
    model_narrow.add(tf.keras.layers.Dense(4))
    model_narrow.add(tf.keras.layers.Dense(1))

    model_wide = tf.keras.Sequential()
    model_wide.add(tf.keras.Input(shape=(6,)))
    model_wide.add(tf.keras.layers.Dense(8))
    model_wide.add(tf.keras.layers.Dense(8))
    model_wide.add(tf.keras.layers.Dense(1))

    x0 = np.array([[1, 2, 3]])
    x1 = np.array([[1, 1, 2, 2, 3, 3]])

    # Call the models once to build variables first.
    _, _ = model_narrow(x0), model_wide(x1)
    tf2_utils_2x_wide.model_to_model_2x_wide(
        model_narrow, model_wide, epsilon=0.2)
    self.assertAllClose(model_narrow(x0), model_wide(x1),
                        rtol=1e-05, atol=1e-05)


if __name__ == "__main__":
  tf.test.main()
official/modeling/progressive/policies.py → official/modeling/fast_training/progressive/policies.py (+11, -2):

@@ -19,13 +19,20 @@ abstract methods to handle each training stage.
 """

 import abc
+import dataclasses
 from typing import Any, Mapping

 from absl import logging
-import dataclasses
 import six
 import tensorflow as tf

+from tensorflow.python.eager import monitoring
+from official.modeling.fast_training.progressive import utils
 from official.modeling.hyperparams import base_config
-from official.modeling.progressive import utils
+
+_progressive_policy_creation_counter = monitoring.Counter(
+    '/tensorflow/training/fast_training/progressive_policy_creation',
+    'Counter for the number of ProgressivePolicy creations.')


 @dataclasses.dataclass

@@ -69,6 +76,8 @@ class ProgressivePolicy:
         optimizer=self.get_optimizer(stage_id),
         model=self.get_model(stage_id, old_model=None))

+    _progressive_policy_creation_counter.get_cell().increase_by(1)
+
   def compute_stage_id(self, global_step: int) -> int:
     for stage_id in range(self.num_stages()):
       global_step -= self.num_steps(stage_id)
official/modeling/progressive/train.py → official/modeling/fast_training/progressive/train.py (+1, -1):

@@ -26,7 +26,7 @@ from official.common import flags as tfm_flags
 from official.core import task_factory
 from official.core import train_utils
 from official.modeling import performance
-from official.modeling.progressive import train_lib
+from official.modeling.fast_training.progressive import train_lib

 FLAGS = flags.FLAGS
official/modeling/progressive/train_lib.py → official/modeling/fast_training/progressive/train_lib.py (+1, -1):

@@ -29,7 +29,7 @@ import tensorflow as tf
 from official.core import base_task
 from official.core import config_definitions
 from official.core import train_lib as base_train_lib
-from official.modeling.progressive import trainer as prog_trainer_lib
+from official.modeling.fast_training.progressive import trainer as prog_trainer_lib


 def run_experiment(distribution_strategy: tf.distribute.Strategy,
official/modeling/progressive/train_lib_test.py → official/modeling/fast_training/progressive/train_lib_test.py (+3, -3):

@@ -31,9 +31,9 @@ from official.core import config_definitions as cfg
 from official.core import task_factory
 from official.modeling import optimization
 from official.modeling.hyperparams import params_dict
-from official.modeling.progressive import policies
-from official.modeling.progressive import train_lib
-from official.modeling.progressive import trainer as prog_trainer_lib
+from official.modeling.fast_training.progressive import policies
+from official.modeling.fast_training.progressive import train_lib
+from official.modeling.fast_training.progressive import trainer as prog_trainer_lib
 from official.utils.testing import mock_task

 FLAGS = flags.FLAGS
official/modeling/progressive/trainer.py → official/modeling/fast_training/progressive/trainer.py (+4, -4):

@@ -18,21 +18,21 @@ The trainer implements the Orbit `StandardTrainable` and
 `StandardEvaluable` interfaces. Trainers inside this project should be
 interchangeable and independent of model architectures and tasks.
 """

+import dataclasses
 import os
 from typing import Any, Optional

 # Import libraries
 from absl import logging
-import dataclasses
 import gin
 import orbit
 import tensorflow as tf

 from official.core import base_task
 from official.core import base_trainer as trainer_lib
 from official.core import config_definitions
-from official.modeling.progressive import policies
-from official.modeling.progressive import utils
+from official.modeling.fast_training.progressive import policies
+from official.modeling.fast_training.progressive import utils

 ExperimentConfig = config_definitions.ExperimentConfig
official/modeling/progressive/trainer_test.py → official/modeling/fast_training/progressive/trainer_test.py (+2, -2):

@@ -24,8 +24,8 @@ from tensorflow.python.distribute import combinations
 from tensorflow.python.distribute import strategy_combinations
 from official.core import config_definitions as cfg
 from official.modeling import optimization
-from official.modeling.progressive import policies
-from official.modeling.progressive import trainer as trainer_lib
+from official.modeling.fast_training.progressive import policies
+from official.modeling.fast_training.progressive import trainer as trainer_lib
 from official.nlp.configs import bert
 from official.utils.testing import mock_task
official/modeling/progressive/utils.py → official/modeling/fast_training/progressive/utils.py

File moved without changes.
official/modeling/multitask/base_model.py (+0, -15): removes a second, stale license header that had been stacked under the current one.

@@ -12,21 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-# Lint as: python3
-# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
 """Abstraction of multi-task model."""
 from typing import Text, Dict
official/modeling/multitask/base_trainer.py (+0, -15): the same duplicated license header is removed here.

@@ -12,21 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-# Lint as: python3
-# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
 """Multitask base trainer implementation.

 The trainer derives from the Orbit `StandardTrainer` class.
official/modeling/multitask/evaluator.py (+8, -1):

@@ -54,8 +54,15 @@ class MultiTaskEvaluator(orbit.AbstractEvaluator):
     self._model = model
     self._global_step = global_step or orbit.utils.create_global_step()
     self._checkpoint_exporter = checkpoint_exporter
+    if hasattr(self.model, "checkpoint_items"):
+      checkpoint_items = self.model.checkpoint_items
+    else:
+      checkpoint_items = {}
     self._checkpoint = tf.train.Checkpoint(
-        global_step=self.global_step, model=self.model)
+        model=self.model,
+        global_step=self.global_step,
+        **checkpoint_items)

     self._validation_losses = None
     self._validation_metrics = None
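The motivation for splicing in `**checkpoint_items` is that a model can expose extra trackables (for example, a sub-encoder or an EMA copy of the weights) that must be saved and restored alongside the model itself. A hedged sketch of the shape such a property might take (the class and attribute names here are illustrative, not from this commit):

import tensorflow as tf

class MyMultiTaskModel(tf.keras.Model):

  def __init__(self, encoder: tf.keras.Model, **kwargs):
    super().__init__(**kwargs)
    self.encoder = encoder

  @property
  def checkpoint_items(self):
    # Extra trackables to save/restore with the model; the evaluator
    # above passes these to tf.train.Checkpoint via **kwargs.
    return {'encoder': self.encoder}

encoder = tf.keras.Sequential([tf.keras.layers.Dense(4)])
model = MyMultiTaskModel(encoder)
ckpt = tf.train.Checkpoint(model=model, **model.checkpoint_items)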
official/modeling/multitask/task_sampler.py (+4, -1):

@@ -78,7 +78,10 @@ class ProportionalTaskSampler(TaskSampler):

 class AnnealingTaskSampler(TaskSampler):
-  """Sample tasks according to task weights as well as training progress."""
+  """Sample tasks according to task weights as well as training progress.
+
+  See http://proceedings.mlr.press/v97/stickland19a/stickland19a.pdf
+  """

   def __init__(self,
                task_weights: Dict[Text, Union[float, int]],
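The newly cited paper (Stickland & Murray) describes annealed sampling: task-selection probabilities start proportional to the task weights and are flattened toward uniform as training progresses, so large tasks dominate early while all tasks keep being visited late. A rough sketch of that idea only; the exact exponent and schedule used by `AnnealingTaskSampler` are not shown in this hunk:

import numpy as np

def annealed_probs(task_weights, step, total_steps, alpha_min=0.1):
  # alpha = 1 reproduces the proportional distribution; as alpha decays
  # toward alpha_min, the distribution flattens toward uniform.
  alpha = 1.0 - (1.0 - alpha_min) * (step / float(total_steps))
  w = np.asarray(list(task_weights.values()), dtype=np.float64) ** alpha
  return dict(zip(task_weights.keys(), w / w.sum()))

print(annealed_probs({'qqp': 400.0, 'rte': 2.5}, step=0, total_steps=100))
print(annealed_probs({'qqp': 400.0, 'rte': 2.5}, step=100, total_steps=100))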