Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
0edeca54
Commit
0edeca54
authored
Aug 05, 2020
by
Dan Holtmann-Rice
Committed by
A. Unique TensorFlower
Aug 05, 2020
Browse files
Internal change
PiperOrigin-RevId: 325088513
parent
29d45e88
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
86 additions
and
75 deletions
+86
-75
official/vision/image_classification/resnet/resnet_runnable.py
...ial/vision/image_classification/resnet/resnet_runnable.py
+11
-5
orbit/controller_test.py
orbit/controller_test.py
+4
-1
orbit/standard_runner.py
orbit/standard_runner.py
+34
-40
orbit/standard_runner_test.py
orbit/standard_runner_test.py
+37
-29
No files found.
official/vision/image_classification/resnet/resnet_runnable.py
View file @
0edeca54
...
...
@@ -107,9 +107,12 @@ class ResnetRunnable(orbit.StandardTrainer, orbit.StandardEvaluator):
.
datasets_num_private_threads
,
dtype
=
self
.
dtype
,
drop_remainder
=
True
)
orbit
.
StandardTrainer
.
__init__
(
self
,
train_dataset
,
flags_obj
.
use_tf_while_loop
,
flags_obj
.
use_tf_function
)
orbit
.
StandardTrainer
.
__init__
(
self
,
train_dataset
,
options
=
orbit
.
StandardTrainerOptions
(
use_tf_while_loop
=
flags_obj
.
use_tf_while_loop
,
use_tf_function
=
flags_obj
.
use_tf_function
))
if
not
flags_obj
.
skip_eval
:
eval_dataset
=
orbit
.
utils
.
make_distributed_dataset
(
self
.
strategy
,
...
...
@@ -119,8 +122,11 @@ class ResnetRunnable(orbit.StandardTrainer, orbit.StandardEvaluator):
batch_size
=
self
.
batch_size
,
parse_record_fn
=
imagenet_preprocessing
.
parse_record
,
dtype
=
self
.
dtype
)
orbit
.
StandardEvaluator
.
__init__
(
self
,
eval_dataset
,
flags_obj
.
use_tf_function
)
orbit
.
StandardEvaluator
.
__init__
(
self
,
eval_dataset
,
options
=
orbit
.
StandardEvaluatorOptions
(
use_tf_function
=
flags_obj
.
use_tf_function
))
def
train_loop_begin
(
self
):
"""See base class."""
...
...
orbit/controller_test.py
View file @
0edeca54
...
...
@@ -221,7 +221,10 @@ class TestTrainerWithSummaries(standard_runner.StandardTrainer):
self
.
strategy
.
experimental_distribute_datasets_from_function
(
dataset_fn
)
)
standard_runner
.
StandardTrainer
.
__init__
(
self
,
train_dataset
,
use_tpu_summary_optimization
=
True
)
self
,
train_dataset
,
options
=
standard_runner
.
StandardTrainerOptions
(
use_tpu_summary_optimization
=
True
))
def
build_train_dataset
(
self
):
return
self
.
strategy
.
experimental_distribute_datasets_from_function
(
...
...
orbit/standard_runner.py
View file @
0edeca54
...
...
@@ -23,20 +23,22 @@ import tensorflow as tf
@
dataclasses
.
dataclass
(
frozen
=
True
)
class
TrainerO
verride
s
:
"""Advanced o
verride
s for
O
rbit
t
rainer
s
.
class
Standard
TrainerO
ption
s
:
"""Advanced o
ption
s for
`o
rbit
.StandardT
rainer
`
.
Attributes:
use_tf_while_loop: A boolean indicates whether to wrap the train step with
a `tf.while_loop`.
use_tf_function: A boolean indicates whether a `tf.function` will be used.
If False, training will run on pure eager mode.
use_tpu_summary_optimization: A boolean indicates whether to enable the
performance optimization for summaries in TPUs. In TPUs, writing
summaries with outside compilation inside train step is slow. If True,
it creates two `tf.function` with two XLA programs: one with summaries
and one without, and run the program with summaries (slow one) only if
necessary.
use_tf_while_loop: A boolean indicating whether to run the training loop
using a `tf.while_loop`. If `True`, `use_tf_function` must also be `True`.
use_tf_function: A boolean indicating whether to apply `tf.function` to the
training loop. This will only affect the body of the loop (involving
`train_step`); `train_loop_begin` and `train_loop_end` will always be run
in eager mode.
use_tpu_summary_optimization: A boolean indicating whether to enable a
performance optimization for summaries in TPUs. Writing summaries
conditionally with outside compilation on TPUs can be extremely slow. If
`True`, this optimization creates two `tf.function`s with two XLA programs
(one with summary calls, and one without). The program with summaries runs
only for one step when summaries should be recorded.
"""
use_tf_while_loop
:
bool
=
True
use_tf_function
:
bool
=
True
...
...
@@ -46,39 +48,29 @@ class TrainerOverrides:
class
StandardTrainer
(
runner
.
AbstractTrainer
,
metaclass
=
abc
.
ABCMeta
):
"""Implements the standard functionality of AbstractTrainer APIs."""
def
__init__
(
self
,
train_dataset
,
use_tf_while_loop
=
True
,
use_tf_function
=
True
,
use_tpu_summary_optimization
=
False
):
def
__init__
(
self
,
train_dataset
,
options
:
StandardTrainerOptions
=
None
):
"""Construct a `StandardTrainer` object.
Args:
train_dataset: A tf.nest-compatible structure of tf.data.Dataset or
DistributedDataset.
use_tf_while_loop: A boolean indicates whether to wrap the train step with
a `tf.while_loop`.
use_tf_function: A boolean indicates whether a `tf.function` will be used.
If False, training will run on pure eager mode.
use_tpu_summary_optimization: A boolean indicates whether to enable the
performance optimization for summaries in TPUs. In TPUs, writing
summaries with outside compilation inside train step is slow. If True,
it creates two `tf.function` with two XLA programs: one with summaries
and one without, and run the program with summaries (slow one) only if
necessary.
options: An `orbit.StandardTrainerOptions` instance.
"""
if
use_tf_while_loop
and
not
use_tf_function
:
options
=
options
or
StandardTrainerOptions
()
if
options
.
use_tf_while_loop
and
not
options
.
use_tf_function
:
raise
ValueError
(
"`use_tf_while_loop=True` and `use_tf_function=False` "
"is not supported"
)
if
use_tpu_summary_optimization
and
not
use_tf_while_loop
:
if
options
.
use_tpu_summary_optimization
and
not
options
.
use_tf_while_loop
:
raise
ValueError
(
"`use_tpu_summary_optimization=True` and "
"`use_tf_while_loop=False` is not supported"
)
self
.
_use_tf_while_loop
=
use_tf_while_loop
self
.
_use_tf_function
=
use_tf_function
self
.
_use_tf_while_loop
=
options
.
use_tf_while_loop
self
.
_use_tf_function
=
options
.
use_tf_function
self
.
_use_tpu_summary_optimization
=
options
.
use_tpu_summary_optimization
self
.
_train_dataset
=
train_dataset
self
.
_train_iter
=
None
self
.
_train_loop_fn
=
None
self
.
_use_tpu_summary_optimization
=
use_tpu_summary_optimization
def
train
(
self
,
num_steps
:
Optional
[
tf
.
Tensor
])
->
Optional
[
Dict
[
Text
,
tf
.
Tensor
]]:
...
...
@@ -168,12 +160,14 @@ class StandardTrainer(runner.AbstractTrainer, metaclass=abc.ABCMeta):
@
dataclasses
.
dataclass
(
frozen
=
True
)
class
EvaluatorO
verride
s
:
"""Advanced o
verrides for Orbit e
valuator
s
.
class
Standard
EvaluatorO
ption
s
:
"""Advanced o
ptions for the `orbit.StandardE
valuator
`
.
Attributes:
use_tf_function: A boolean indicates whether a `tf.function` will be used.
If False, training will run on pure eager mode.
use_tf_function: A boolean indicating whether to apply `tf.function` to the
training loop. This will only affect the body of the loop (involving
`train_step`); `train_loop_begin` and `train_loop_end` will always be run
in eager mode.
"""
use_tf_function
:
bool
=
True
...
...
@@ -181,16 +175,16 @@ class EvaluatorOverrides:
class
StandardEvaluator
(
runner
.
AbstractEvaluator
,
metaclass
=
abc
.
ABCMeta
):
"""Implements the standard functionality of AbstractEvaluator APIs."""
def
__init__
(
self
,
eval_dataset
,
use_tf_function
=
Tru
e
):
def
__init__
(
self
,
eval_dataset
,
options
:
StandardEvaluatorOptions
=
Non
e
):
"""Construct a `StandardEvaluator` object.
Args:
eval_dataset: A tf.nest-compatible structure of tf.data.Dataset or
DistributedDataset.
use_tf_function: A boolean indicates whether a `tf.function` will be used.
If False, evaluation will run on pure eager mode.
options: An `orbit.StandardEvaluatorOptions` instance.
"""
self
.
_eval_use_tf_function
=
use_tf_function
options
=
options
or
StandardEvaluatorOptions
()
self
.
_eval_use_tf_function
=
options
.
use_tf_function
self
.
_eval_dataset
=
eval_dataset
self
.
_eval_loop_fn
=
None
...
...
orbit/standard_runner_test.py
View file @
0edeca54
...
...
@@ -15,6 +15,7 @@
"""Tests for orbit.standard_runner."""
from
orbit
import
standard_runner
from
orbit
import
utils
import
tensorflow
as
tf
...
...
@@ -32,46 +33,49 @@ def dataset_fn(input_context=None):
return
dataset
class
TestRunner
(
standard_runner
.
StandardTrainer
,
standard_runner
.
StandardEvaluator
):
"""Implements the training and evaluation APIs for tests."""
class
TestTrainer
(
standard_runner
.
StandardTrainer
):
"""A StandardTrainer subclass for tests."""
def
__init__
(
self
):
def
__init__
(
self
,
options
=
None
):
self
.
strategy
=
tf
.
distribute
.
get_strategy
()
self
.
global_step
=
tf
.
Variable
(
0
,
trainable
=
False
,
dtype
=
tf
.
int64
,
name
=
'global_step'
,
aggregation
=
tf
.
VariableAggregation
.
ONLY_FIRST_REPLICA
)
standard_runner
.
StandardTrainer
.
__init__
(
self
,
train_dataset
=
None
)
standard_runner
.
StandardEvaluator
.
__init__
(
self
,
eval_dataset
=
None
)
self
.
global_step
=
utils
.
create_global_step
()
distribute
=
self
.
strategy
.
experimental_distribute_datasets_from_function
dataset
=
distribute
(
dataset_fn
)
super
().
__init__
(
train_dataset
=
dataset
,
options
=
options
)
def
train_loop_begin
(
self
):
self
.
train_dataset
=
(
self
.
strategy
.
experimental_distribute_datasets_from_function
(
dataset_fn
)
)
self
.
global_step
.
assign
(
0
)
def
train_step
(
self
,
iterator
):
def
_
replica
ted
_step
(
_
):
def
replica_step
(
_
):
self
.
global_step
.
assign_add
(
1
)
self
.
strategy
.
run
(
_
replica
ted
_step
,
args
=
(
next
(
iterator
),))
self
.
strategy
.
run
(
replica_step
,
args
=
(
next
(
iterator
),))
def
train_loop_end
(
self
):
return
self
.
global_step
.
numpy
()
class
TestEvaluator
(
standard_runner
.
StandardEvaluator
):
"""A StandardEvaluator subclass for tests."""
def
__init__
(
self
,
options
=
None
):
self
.
strategy
=
tf
.
distribute
.
get_strategy
()
self
.
global_step
=
utils
.
create_global_step
()
distribute
=
self
.
strategy
.
experimental_distribute_datasets_from_function
dataset
=
distribute
(
dataset_fn
)
super
().
__init__
(
eval_dataset
=
dataset
,
options
=
options
)
def
eval_begin
(
self
):
self
.
eval_dataset
=
self
.
strategy
.
experimental_distribute_datasets_from_function
(
dataset_fn
)
self
.
global_step
.
assign
(
0
)
def
eval_step
(
self
,
iterator
):
def
_
replica
ted
_step
(
_
):
def
replica_step
(
_
):
self
.
global_step
.
assign_add
(
1
)
self
.
strategy
.
run
(
_
replica
ted
_step
,
args
=
(
next
(
iterator
),))
self
.
strategy
.
run
(
replica_step
,
args
=
(
next
(
iterator
),))
def
eval_end
(
self
):
return
self
.
global_step
.
numpy
()
...
...
@@ -79,15 +83,19 @@ class TestRunner(standard_runner.StandardTrainer,
class
StandardRunnerTest
(
tf
.
test
.
TestCase
):
def
test_train
(
self
):
test_runner
=
TestRunner
()
self
.
assertEqual
(
test_runner
.
train
(
tf
.
convert_to_tensor
(
10
,
dtype
=
tf
.
int32
)),
10
)
def
test_default_trainer
(
self
):
trainer
=
TestTrainer
()
self
.
assertEqual
(
trainer
.
train
(
tf
.
constant
(
10
)),
10
)
def
test_trainer_with_tpu_summary_optimization
(
self
):
options
=
standard_runner
.
StandardTrainerOptions
(
use_tpu_summary_optimization
=
True
)
trainer
=
TestTrainer
(
options
)
self
.
assertEqual
(
trainer
.
train
(
tf
.
constant
(
10
)),
10
)
def
test_eval
(
self
):
test_runner
=
TestRunner
()
self
.
assertEqual
(
test_runner
.
evaluate
(
tf
.
convert_to_tensor
(
10
,
dtype
=
tf
.
int32
)),
10
)
def
test_default_evaluator
(
self
):
evaluator
=
TestEvaluator
()
self
.
assertEqual
(
evaluator
.
evaluate
(
tf
.
constant
(
10
)),
10
)
if
__name__
==
'__main__'
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment