Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
992a864b
Commit
992a864b
authored
Aug 01, 2021
by
Hongkun Yu
Committed by
A. Unique TensorFlower
Aug 01, 2021
Browse files
Internal change
PiperOrigin-RevId: 388106758
parent
e0ad9ff2
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
12 additions
and
1 deletion
+12
-1
official/modeling/multitask/interleaving_trainer.py
official/modeling/multitask/interleaving_trainer.py
+11
-1
official/modeling/multitask/interleaving_trainer_test.py
official/modeling/multitask/interleaving_trainer_test.py
+1
-0
No files found.
official/modeling/multitask/interleaving_trainer.py
View file @
992a864b
...
@@ -34,7 +34,7 @@ class MultiTaskInterleavingTrainer(base_trainer.MultiTaskBaseTrainer):
...
@@ -34,7 +34,7 @@ class MultiTaskInterleavingTrainer(base_trainer.MultiTaskBaseTrainer):
optimizer
:
tf
.
optimizers
.
Optimizer
,
optimizer
:
tf
.
optimizers
.
Optimizer
,
task_sampler
:
sampler
.
TaskSampler
,
task_sampler
:
sampler
.
TaskSampler
,
trainer_options
=
None
):
trainer_options
=
None
):
super
(
MultiTaskInterleavingTrainer
,
self
).
__init__
(
super
().
__init__
(
multi_task
=
multi_task
,
multi_task
=
multi_task
,
multi_task_model
=
multi_task_model
,
multi_task_model
=
multi_task_model
,
optimizer
=
optimizer
,
optimizer
=
optimizer
,
...
@@ -90,3 +90,13 @@ class MultiTaskInterleavingTrainer(base_trainer.MultiTaskBaseTrainer):
...
@@ -90,3 +90,13 @@ class MultiTaskInterleavingTrainer(base_trainer.MultiTaskBaseTrainer):
self
.
_task_train_step_map
[
name
],
args
=
(
next
(
iterator_map
[
name
]),))
self
.
_task_train_step_map
[
name
],
args
=
(
next
(
iterator_map
[
name
]),))
self
.
global_step
.
assign_add
(
1
)
self
.
global_step
.
assign_add
(
1
)
self
.
task_step_counter
(
name
).
assign_add
(
1
)
self
.
task_step_counter
(
name
).
assign_add
(
1
)
def
train_loop_end
(
self
):
"""Record loss and metric values per task."""
result
=
super
().
train_loop_end
()
# Interleaving training does not have a good semantic for `total_loss`. In
# fact, it is always zero. To avoid confusion, we filter the `total_loss`
# from the result logs.
if
'total_loss'
in
result
:
result
.
pop
(
'total_loss'
)
return
result
official/modeling/multitask/interleaving_trainer_test.py
View file @
992a864b
...
@@ -60,6 +60,7 @@ class InterleavingTrainerTest(tf.test.TestCase, parameterized.TestCase):
...
@@ -60,6 +60,7 @@ class InterleavingTrainerTest(tf.test.TestCase, parameterized.TestCase):
results
[
"bar"
].
keys
())
results
[
"bar"
].
keys
())
self
.
assertContainsSubset
([
"training_loss"
,
"foo_acc"
],
self
.
assertContainsSubset
([
"training_loss"
,
"foo_acc"
],
results
[
"foo"
].
keys
())
results
[
"foo"
].
keys
())
self
.
assertNotIn
(
"total_loss"
,
results
)
@
combinations
.
generate
(
all_strategy_combinations
())
@
combinations
.
generate
(
all_strategy_combinations
())
def
test_trainer_with_configs
(
self
,
distribution
):
def
test_trainer_with_configs
(
self
,
distribution
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment