Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
957f32ad
Commit
957f32ad
authored
Dec 14, 2020
by
Abdullah Rashwan
Committed by
A. Unique TensorFlower
Dec 14, 2020
Browse files
Internal change
PiperOrigin-RevId: 347439073
parent
00498609
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
34 additions
and
30 deletions
+34
-30
official/core/base_task.py
official/core/base_task.py
+31
-0
official/core/base_trainer.py
official/core/base_trainer.py
+0
-27
official/core/base_trainer_test.py
official/core/base_trainer_test.py
+2
-2
official/core/train_utils.py
official/core/train_utils.py
+1
-1
No files found.
official/core/base_task.py
View file @
957f32ad
...
...
@@ -20,6 +20,13 @@ from typing import Optional
from
absl
import
logging
import
tensorflow
as
tf
from
official.core
import
config_definitions
from
official.modeling
import
optimization
from
official.modeling
import
performance
TrainerConfig
=
config_definitions
.
TrainerConfig
RuntimeConfig
=
config_definitions
.
RuntimeConfig
class
Task
(
tf
.
Module
,
metaclass
=
abc
.
ABCMeta
):
"""A single-replica view of training procedure.
...
...
@@ -54,6 +61,30 @@ class Task(tf.Module, metaclass=abc.ABCMeta):
def
logging_dir
(
self
)
->
str
:
return
self
.
_logging_dir
@
classmethod
def
create_optimizer
(
cls
,
trainer_config
:
TrainerConfig
,
runtime_config
:
Optional
[
RuntimeConfig
]
=
None
):
"""Creates an TF optimizer from configurations.
Args:
trainer_config: the parameters of the trainer.
runtime_config: the parameters of the runtime.
Returns:
A tf.optimizers.Optimizer object.
"""
opt_factory
=
optimization
.
OptimizerFactory
(
trainer_config
.
optimizer_config
)
optimizer
=
opt_factory
.
build_optimizer
(
opt_factory
.
build_learning_rate
())
# Configuring optimizer when loss_scale is set in runtime config. This helps
# avoiding overflow/underflow for float16 computations.
if
runtime_config
and
runtime_config
.
loss_scale
:
optimizer
=
performance
.
configure_optimizer
(
optimizer
,
use_float16
=
runtime_config
.
mixed_precision_dtype
==
"float16"
,
loss_scale
=
runtime_config
.
loss_scale
)
return
optimizer
def
initialize
(
self
,
model
:
tf
.
keras
.
Model
):
"""[Optional] A callback function used as CheckpointManager's init_fn.
...
...
official/core/base_trainer.py
View file @
957f32ad
...
...
@@ -19,7 +19,6 @@ The base trainer implements the Orbit `StandardTrainable` and
`StandardEvaluable` interfaces. Trainers inside this project should be
interchangable and independent on model architectures and tasks.
"""
from
typing
import
Optional
from
absl
import
logging
import
gin
...
...
@@ -28,35 +27,9 @@ import tensorflow as tf
from
official.core
import
base_task
from
official.core
import
config_definitions
from
official.modeling
import
optimization
from
official.modeling
import
performance
ExperimentConfig
=
config_definitions
.
ExperimentConfig
TrainerConfig
=
config_definitions
.
TrainerConfig
RuntimeConfig
=
config_definitions
.
RuntimeConfig
def
create_optimizer
(
trainer_config
:
TrainerConfig
,
runtime_config
:
Optional
[
RuntimeConfig
]
=
None
):
"""Creates an TF optimizer from configurations.
Args:
trainer_config: the parameters of the trainer.
runtime_config: the parameters of the runtime.
Returns:
A tf.optimizers.Optimizer object.
"""
opt_factory
=
optimization
.
OptimizerFactory
(
trainer_config
.
optimizer_config
)
optimizer
=
opt_factory
.
build_optimizer
(
opt_factory
.
build_learning_rate
())
# Configuring optimizer when loss_scale is set in runtime config. This helps
# avoiding overflow/underflow for float16 computations.
if
runtime_config
and
runtime_config
.
loss_scale
:
optimizer
=
performance
.
configure_optimizer
(
optimizer
,
use_float16
=
runtime_config
.
mixed_precision_dtype
==
"float16"
,
loss_scale
=
runtime_config
.
loss_scale
)
return
optimizer
class
Recovery
:
...
...
official/core/base_trainer_test.py
View file @
957f32ad
...
...
@@ -61,7 +61,7 @@ class TrainerTest(tf.test.TestCase, parameterized.TestCase):
config
,
task
,
model
=
task
.
build_model
(),
optimizer
=
t
rainer_lib
.
create_optimizer
(
config
.
trainer
,
config
.
runtime
),
optimizer
=
t
ask
.
create_optimizer
(
config
.
trainer
,
config
.
runtime
),
checkpoint_exporter
=
ckpt_exporter
)
return
trainer
...
...
@@ -180,7 +180,7 @@ class TrainerTest(tf.test.TestCase, parameterized.TestCase):
config
,
task
,
model
=
task
.
build_model
(),
optimizer
=
t
rainer_lib
.
create_optimizer
(
config
.
trainer
,
config
.
runtime
))
optimizer
=
t
ask
.
create_optimizer
(
config
.
trainer
,
config
.
runtime
))
trainer
.
add_recovery
(
config
.
trainer
,
checkpoint_manager
=
checkpoint_manager
)
with
self
.
assertRaises
(
RuntimeError
):
_
=
trainer
.
train
(
tf
.
convert_to_tensor
(
2
,
dtype
=
tf
.
int32
))
...
...
official/core/train_utils.py
View file @
957f32ad
...
...
@@ -134,7 +134,7 @@ def create_trainer(params: config_definitions.ExperimentConfig,
"""Create trainer."""
logging
.
info
(
'Running default trainer.'
)
model
=
task
.
build_model
()
optimizer
=
base_trainer
.
create_optimizer
(
params
.
trainer
,
params
.
runtime
)
optimizer
=
task
.
create_optimizer
(
params
.
trainer
,
params
.
runtime
)
return
trainer_cls
(
params
,
task
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment