ModelZoo / ResNet50_tensorflow / Commits / 356c98bd

Commit 356c98bd, authored Aug 07, 2020 by Kaushik Shivakumar
Parents: d31aba8a, b9785623

    Merge remote-tracking branch 'upstream/master' into detr-push-3

Showing 20 changed files with 288 additions and 242 deletions (+288 / -242)
official/vision/detection/evaluation/coco_utils.py                               +1   -1
official/vision/detection/modeling/architecture/factory.py                       +6   -2
official/vision/detection/modeling/architecture/keras_utils.py                   +1   -0
official/vision/detection/modeling/losses.py                                     +1   -2
official/vision/detection/modeling/retinanet_model.py                            +1   -4
official/vision/image_classification/configs/examples/resnet/imagenet/gpu.yaml   +0   -2
official/vision/image_classification/configs/examples/resnet/imagenet/tpu.yaml   +0   -2
official/vision/image_classification/learning_rate.py                            +15  -59
official/vision/image_classification/learning_rate_test.py                       +0   -38
official/vision/image_classification/optimizer_factory.py                        +14  -17
official/vision/image_classification/optimizer_factory_test.py                   +0   -1
official/vision/image_classification/resnet/resnet_config.py                     +9   -14
official/vision/image_classification/resnet/resnet_ctl_imagenet_main.py          +1   -0
official/vision/image_classification/resnet/resnet_runnable.py                   +11  -5
orbit/__init__.py                                                                +0   -1
orbit/controller.py                                                              +14  -12
orbit/controller_test.py                                                         +134 -4
orbit/runner.py                                                                  +5   -4
orbit/standard_runner.py                                                         +38  -43
orbit/standard_runner_test.py                                                    +37  -31
Note: too many changes to show; to preserve performance, only 1000 of 1000+ files are displayed.
official/vision/detection/evaluation/coco_utils.py  (view file @ 356c98bd)

...
@@ -237,7 +237,7 @@ def convert_groundtruths_to_coco_dataset(groundtruths, label_map=None):
               (boxes[j, k, 3] - boxes[j, k, 1]) *
               (boxes[j, k, 2] - boxes[j, k, 0]))
       if 'masks' in groundtruths:
-        mask = Image.open(six.StringIO(groundtruths['masks'][i][j, k]))
+        mask = Image.open(six.BytesIO(groundtruths['masks'][i][j, k]))
         width, height = mask.size
         np_mask = (np.array(mask.getdata()).reshape(height, width).astype(np.uint8))
...
...
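The fix swaps six.StringIO for six.BytesIO: on Python 3, PIL's Image.open needs a binary file-like object, and StringIO rejects the PNG-encoded mask bytes. A minimal sketch of the decoding path, assuming a PNG byte string (the file name is hypothetical):

import io
import numpy as np
from PIL import Image

png_bytes = open('mask.png', 'rb').read()    # hypothetical PNG-encoded mask
mask = Image.open(io.BytesIO(png_bytes))     # six.BytesIO is io.BytesIO on Python 3
width, height = mask.size
np_mask = np.array(mask.getdata()).reshape(height, width).astype(np.uint8)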
official/vision/detection/modeling/architecture/factory.py  (view file @ 356c98bd)

...
@@ -77,11 +77,13 @@ def multilevel_features_generator(params):
 def retinanet_head_generator(params):
   """Generator function for RetinaNet head architecture."""
   head_params = params.retinanet_head
+  anchors_per_location = params.anchor.num_scales * len(
+      params.anchor.aspect_ratios)
   return heads.RetinanetHead(
       params.architecture.min_level,
       params.architecture.max_level,
       params.architecture.num_classes,
-      head_params.anchors_per_location,
+      anchors_per_location,
       head_params.num_convs,
       head_params.num_filters,
       head_params.use_separable_conv,
...
@@ -91,10 +93,12 @@ def retinanet_head_generator(params):
 def rpn_head_generator(params):
   """Generator function for RPN head architecture."""
   head_params = params.rpn_head
+  anchors_per_location = params.anchor.num_scales * len(
+      params.anchor.aspect_ratios)
   return heads.RpnHead(
       params.architecture.min_level,
       params.architecture.max_level,
-      head_params.anchors_per_location,
+      anchors_per_location,
       head_params.num_convs,
       head_params.num_filters,
       head_params.use_separable_conv,
...
...
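Both heads now derive the anchor count from the shared anchor config instead of a separate per-head parameter, so the two values can no longer drift apart. A worked example of the arithmetic, with typical RetinaNet values assumed:

num_scales = 3                        # anchor octave scales (assumed)
aspect_ratios = [0.5, 1.0, 2.0]       # anchor aspect ratios (assumed)
anchors_per_location = num_scales * len(aspect_ratios)
assert anchors_per_location == 9      # 9 anchors per feature-map location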
official/vision/detection/modeling/architecture/keras_utils.py  (view file @ 356c98bd)

...
@@ -23,6 +23,7 @@ from tensorflow.python.keras import backend
 try:
   from tensorflow.python.keras.engine import keras_tensor  # pylint: disable=g-import-not-at-top,unused-import
+  keras_tensor.disable_keras_tensors()
 except ImportError:
   keras_tensor = None
...
...
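The hunk extends a guarded import of a private TensorFlow module: when the module is importable, the new KerasTensor behavior is switched off; otherwise the symbol falls back to None. A hedged sketch of the same pattern, with a hypothetical module name:

try:
    # Private API that may not exist in every installed version (hypothetical name).
    from some_library.internal import experimental_feature
    experimental_feature.disable()    # opt out when the toggle is available
except ImportError:
    experimental_feature = None       # downstream code checks for None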
official/vision/detection/modeling/losses.py  (view file @ 356c98bd)

...
@@ -449,7 +449,7 @@ class RetinanetBoxLoss(object):
       num_positives: number of positive examples in the minibatch.
     Returns:
-      an integar tensor representing total box regression loss.
+      an integer tensor representing total box regression loss.
     """
     # Sums all positives in a batch for normalization and avoids zero
     # num_positives_sum, which would lead to inf loss during training
...
@@ -457,7 +457,6 @@ class RetinanetBoxLoss(object):
     box_losses = []
     for level in box_outputs.keys():
-      # Onehot encoding for classification labels.
       box_targets_l = labels[level]
       box_losses.append(self.box_loss(box_outputs[level], box_targets_l,
                                       num_positives_sum))
...
...
official/vision/detection/modeling/retinanet_model.py  (view file @ 356c98bd)

...
@@ -59,11 +59,8 @@ class RetinanetModel(base_model.Model):
     self._transpose_input = params.train.transpose_input
     assert not self._transpose_input, 'Transpose input is not supported.'

     # Input layer.
-    input_shape = (
-        params.retinanet_parser.output_size +
-        [params.retinanet_parser.num_channels])
     self._input_layer = tf.keras.layers.Input(
-        shape=input_shape,
+        shape=(None, None, params.retinanet_parser.num_channels),
         name='',
         dtype=tf.bfloat16 if self._use_bfloat16 else tf.float32)

   def build_outputs(self, inputs, mode):
...
...
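Leaving height and width as None makes the input layer shape-agnostic, so one built graph serves any image resolution with the configured channel count. A minimal self-contained sketch (layer sizes assumed for illustration):

import tensorflow as tf

inputs = tf.keras.layers.Input(shape=(None, None, 3), dtype=tf.float32)
outputs = tf.keras.layers.Conv2D(8, 3, padding='same')(inputs)
model = tf.keras.Model(inputs, outputs)

print(model(tf.zeros([1, 640, 640, 3])).shape)  # (1, 640, 640, 8)
print(model(tf.zeros([1, 512, 768, 3])).shape)  # (1, 512, 768, 8) -- same graph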
official/vision/image_classification/configs/examples/resnet/imagenet/gpu.yaml  (view file @ 356c98bd)

...
@@ -40,8 +40,6 @@ model:
     momentum: 0.9
     decay: 0.9
     epsilon: 0.001
-  learning_rate:
-    name: 'piecewise_constant_with_warmup'
   loss:
     label_smoothing: 0.1
 train:
...
...
official/vision/image_classification/configs/examples/resnet/imagenet/tpu.yaml  (view file @ 356c98bd)

...
@@ -43,8 +43,6 @@ model:
     epsilon: 0.001
     moving_average_decay: 0.
     lookahead: False
-  learning_rate:
-    name: 'piecewise_constant_with_warmup'
   loss:
     label_smoothing: 0.1
 train:
...
...
official/vision/image_classification/learning_rate.py  (view file @ 356c98bd)

...
@@ -18,7 +18,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

-from typing import Any, List, Mapping
+from typing import Any, Mapping, Optional

 import numpy as np
 import tensorflow as tf
...
@@ -32,23 +32,33 @@ class WarmupDecaySchedule(tf.keras.optimizers.schedules.LearningRateSchedule):

   def __init__(self,
                lr_schedule: tf.keras.optimizers.schedules.LearningRateSchedule,
-               warmup_steps: int):
+               warmup_steps: int,
+               warmup_lr: Optional[float] = None):
     """Add warmup decay to a learning rate schedule.

     Args:
       lr_schedule: base learning rate scheduler
       warmup_steps: number of warmup steps
+      warmup_lr: an optional field for the final warmup learning rate. This
+        should be provided if the base `lr_schedule` does not contain this
+        field.
     """
     super(WarmupDecaySchedule, self).__init__()
     self._lr_schedule = lr_schedule
     self._warmup_steps = warmup_steps
+    self._warmup_lr = warmup_lr

   def __call__(self, step: int):
     lr = self._lr_schedule(step)
     if self._warmup_steps:
-      initial_learning_rate = tf.convert_to_tensor(
-          self._lr_schedule.initial_learning_rate,
-          name="initial_learning_rate")
+      if self._warmup_lr is not None:
+        initial_learning_rate = tf.convert_to_tensor(
+            self._warmup_lr, name="initial_learning_rate")
+      else:
+        initial_learning_rate = tf.convert_to_tensor(
+            self._lr_schedule.initial_learning_rate,
+            name="initial_learning_rate")
       dtype = initial_learning_rate.dtype
       global_step_recomp = tf.cast(step, dtype)
       warmup_steps = tf.cast(self._warmup_steps, dtype)
...
@@ -62,65 +72,11 @@ class WarmupDecaySchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
     config = self._lr_schedule.get_config()
     config.update({
         "warmup_steps": self._warmup_steps,
+        "warmup_lr": self._warmup_lr,
     })
     return config

-# TODO(b/149030439) - refactor this with
-# tf.keras.optimizers.schedules.PiecewiseConstantDecay + WarmupDecaySchedule.
-class PiecewiseConstantDecayWithWarmup(
-    tf.keras.optimizers.schedules.LearningRateSchedule):
-  """Piecewise constant decay with warmup schedule."""
-
-  def __init__(self, batch_size: int, epoch_size: int, warmup_epochs: int,
-               boundaries: List[int], multipliers: List[float]):
-    """Piecewise constant decay with warmup.
-
-    Args:
-      batch_size: The training batch size used in the experiment.
-      epoch_size: The size of an epoch, or the number of examples in an epoch.
-      warmup_epochs: The number of warmup epochs to apply.
-      boundaries: The list of floats with strictly increasing entries.
-      multipliers: The list of multipliers/learning rates to use for the
-        piecewise portion. The length must be 1 less than that of boundaries.
-    """
-    super(PiecewiseConstantDecayWithWarmup, self).__init__()
-    if len(boundaries) != len(multipliers) - 1:
-      raise ValueError("The length of boundaries must be 1 less than the "
-                       "length of multipliers")
-
-    base_lr_batch_size = 256
-    steps_per_epoch = epoch_size // batch_size
-
-    self._rescaled_lr = BASE_LEARNING_RATE * batch_size / base_lr_batch_size
-    self._step_boundaries = [float(steps_per_epoch) * x for x in boundaries]
-    self._lr_values = [self._rescaled_lr * m for m in multipliers]
-    self._warmup_steps = warmup_epochs * steps_per_epoch
-
-  def __call__(self, step: int):
-    """Compute learning rate at given step."""
-
-    def warmup_lr():
-      return self._rescaled_lr * (
-          step / tf.cast(self._warmup_steps, tf.float32))
-
-    def piecewise_lr():
-      return tf.compat.v1.train.piecewise_constant(
-          tf.cast(step, tf.float32), self._step_boundaries, self._lr_values)
-
-    return tf.cond(step < self._warmup_steps, warmup_lr, piecewise_lr)
-
-  def get_config(self) -> Mapping[str, Any]:
-    return {
-        "rescaled_lr": self._rescaled_lr,
-        "step_boundaries": self._step_boundaries,
-        "lr_values": self._lr_values,
-        "warmup_steps": self._warmup_steps,
-    }

 class CosineDecayWithWarmup(tf.keras.optimizers.schedules.LearningRateSchedule):
   """Class to generate learning rate tensor."""
...
...
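WarmupDecaySchedule can now take an explicit warmup target for base schedules that don't expose an initial_learning_rate attribute; during warmup the rate ramps linearly from zero toward that target before handing off to the base schedule. A hedged sketch of the ramp arithmetic (values assumed, and the linear-ramp formula inferred from the truncated convert_to_tensor/cast code above):

warmup_lr = 0.1       # target rate at the end of warmup (assumed)
warmup_steps = 5
for step in range(warmup_steps + 1):
    print(step, warmup_lr * step / warmup_steps)   # 0.0, 0.02, ..., 0.1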
official/vision/image_classification/learning_rate_test.py  (view file @ 356c98bd)

...
@@ -46,44 +46,6 @@ class LearningRateTests(tf.test.TestCase):
       self.assertAllClose(self.evaluate(lr(step)),
                           step / warmup_steps * initial_lr)

-  def test_piecewise_constant_decay_with_warmup(self):
-    """Basic computational test for piecewise constant decay with warmup."""
-    boundaries = [1, 2, 3]
-    warmup_epochs = boundaries[0]
-    learning_rate_multipliers = [1.0, 0.1, 0.001]
-    expected_keys = [
-        'rescaled_lr', 'step_boundaries', 'lr_values', 'warmup_steps',
-    ]
-    expected_lrs = [0.0, 0.1, 0.1]
-
-    lr = learning_rate.PiecewiseConstantDecayWithWarmup(
-        batch_size=256,
-        epoch_size=256,
-        warmup_epochs=warmup_epochs,
-        boundaries=boundaries[1:],
-        multipliers=learning_rate_multipliers)
-
-    step = 0
-    config = lr.get_config()
-    self.assertAllInSet(list(config.keys()), expected_keys)
-    for boundary, expected_lr in zip(boundaries, expected_lrs):
-      for _ in range(step, boundary):
-        self.assertAllClose(self.evaluate(lr(step)), expected_lr)
-        step += 1
-
-  def test_piecewise_constant_decay_invalid_boundaries(self):
-    with self.assertRaisesRegex(ValueError,
-                                'The length of boundaries must be 1 less '):
-      learning_rate.PiecewiseConstantDecayWithWarmup(
-          batch_size=256,
-          epoch_size=256,
-          warmup_epochs=1,
-          boundaries=[1, 2],
-          multipliers=[1, 2])

   def test_cosine_decay_with_warmup(self):
     """Basic computational test for cosine decay with warmup."""
     expected_lrs = [0.0, 0.1, 0.05, 0.0]
...
...
official/vision/image_classification/optimizer_factory.py  (view file @ 356c98bd)

...
@@ -370,29 +370,26 @@ def build_learning_rate(params: base_configs.LearningRateConfig,
         decay_steps=decay_steps,
         decay_rate=decay_rate,
         staircase=params.staircase)
-  elif decay_type == 'piecewise_constant_with_warmup':
-    logging.info('Using Piecewise constant decay with warmup. '
-                 'Parameters: batch_size: %d, epoch_size: %d, '
-                 'warmup_epochs: %d, boundaries: %s, multipliers: %s',
-                 batch_size, params.examples_per_epoch, params.warmup_epochs,
-                 params.boundaries, params.multipliers)
-    lr = learning_rate.PiecewiseConstantDecayWithWarmup(
-        batch_size=batch_size,
-        epoch_size=params.examples_per_epoch,
-        warmup_epochs=params.warmup_epochs,
-        boundaries=params.boundaries,
-        multipliers=params.multipliers)
+  elif decay_type == 'stepwise':
+    steps_per_epoch = params.examples_per_epoch // batch_size
+    boundaries = [boundary * steps_per_epoch for boundary in params.boundaries]
+    multipliers = [batch_size * multiplier for multiplier in params.multipliers]
+    logging.info('Using stepwise learning rate. Parameters: '
+                 'boundaries: %s, values: %s', boundaries, multipliers)
+    lr = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
+        boundaries=boundaries, values=multipliers)
   elif decay_type == 'cosine_with_warmup':
     lr = learning_rate.CosineDecayWithWarmup(
         batch_size=batch_size,
        total_steps=train_epochs * train_steps,
        warmup_steps=warmup_steps)
   if warmup_steps > 0:
-    if decay_type not in ['piecewise_constant_with_warmup',
-                          'cosine_with_warmup']:
+    if decay_type not in ['cosine_with_warmup']:
       logging.info('Applying %d warmup steps to the learning rate',
                    warmup_steps)
-      lr = learning_rate.WarmupDecaySchedule(lr, warmup_steps)
+      lr = learning_rate.WarmupDecaySchedule(lr, warmup_steps,
+                                             warmup_lr=base_lr)
   return lr
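Under the new 'stepwise' branch, epoch boundaries are converted to step boundaries and the per-example multipliers are scaled by batch size before being handed to the stock PiecewiseConstantDecay. A hedged worked example of that arithmetic (batch size and config values assumed):

batch_size = 1024
examples_per_epoch = 1281167                  # ImageNet train split
boundaries_epochs = [30, 60, 80]
multipliers = [0.1 / 256, 0.01 / 256, 0.001 / 256, 0.0001 / 256]

steps_per_epoch = examples_per_epoch // batch_size             # 1251
boundaries = [b * steps_per_epoch for b in boundaries_epochs]  # [37530, 75060, 100080]
values = [batch_size * m for m in multipliers]                 # [0.4, 0.04, 0.004, 0.0004]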
official/vision/image_classification/optimizer_factory_test.py  (view file @ 356c98bd)

...
@@ -93,7 +93,6 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
   @parameterized.named_parameters(
       ('exponential', 'exponential'),
-      ('piecewise_constant_with_warmup', 'piecewise_constant_with_warmup'),
       ('cosine_with_warmup', 'cosine_with_warmup'))
   def test_learning_rate_with_decay_and_warmup(self, lr_decay_type):
     """Basic smoke test for syntax."""
...
...
official/vision/image_classification/resnet/resnet_config.py  (view file @ 356c98bd)

...
@@ -18,22 +18,12 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

 from typing import Any, Mapping

 import dataclasses

 from official.modeling.hyperparams import base_config
 from official.vision.image_classification.configs import base_configs

-_RESNET_LR_SCHEDULE = [  # (multiplier, epoch to start) tuples
-    (1.0, 5), (0.1, 30), (0.01, 60), (0.001, 80)
-]
-_RESNET_LR_BOUNDARIES = list(p[1] for p in _RESNET_LR_SCHEDULE[1:])
-_RESNET_LR_MULTIPLIERS = list(p[0] for p in _RESNET_LR_SCHEDULE)
-_RESNET_LR_WARMUP_EPOCHS = _RESNET_LR_SCHEDULE[0][1]

 @dataclasses.dataclass
 class ResNetModelConfig(base_configs.ModelConfig):
   """Configuration for the ResNet model."""
...
@@ -56,8 +46,13 @@ class ResNetModelConfig(base_configs.ModelConfig):
       moving_average_decay=None)
   learning_rate: base_configs.LearningRateConfig = (
       base_configs.LearningRateConfig(
-          name='piecewise_constant_with_warmup',
+          name='stepwise',
           initial_lr=0.1,
           examples_per_epoch=1281167,
-          warmup_epochs=_RESNET_LR_WARMUP_EPOCHS,
-          boundaries=_RESNET_LR_BOUNDARIES,
-          multipliers=_RESNET_LR_MULTIPLIERS))
+          boundaries=[30, 60, 80],
+          warmup_epochs=5,
+          scale_by_batch_size=1. / 256.,
+          multipliers=[0.1 / 256, 0.01 / 256, 0.001 / 256, 0.0001 / 256]))
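The inlined 'stepwise' values are meant to reproduce the old derived schedule: boundaries at epochs [30, 60, 80], 5 warmup epochs, and per-256-example rates [0.1, 0.01, 0.001, 0.0001]. A hedged equivalence check (assumes the removed BASE_LEARNING_RATE constant was 0.1):

batch_size = 256
old_rescaled_lr = 0.1 * batch_size / 256                 # old rescaling rule
old_values = [old_rescaled_lr * m for m in [1.0, 0.1, 0.01, 0.001]]
new_values = [batch_size * m
              for m in [0.1 / 256, 0.01 / 256, 0.001 / 256, 0.0001 / 256]]
for old, new in zip(old_values, new_values):
    assert abs(old - new) < 1e-9                         # same effective LRs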
official/vision/image_classification/resnet/resnet_ctl_imagenet_main.py  (view file @ 356c98bd)

...
@@ -167,6 +167,7 @@ def run(flags_obj):
       steps_per_loop=steps_per_loop,
       checkpoint_manager=checkpoint_manager,
       summary_interval=summary_interval,
+      summary_dir=flags_obj.model_dir,
       eval_summary_dir=os.path.join(flags_obj.model_dir, 'eval'))

   time_callback.on_train_begin()
...
...
official/vision/image_classification/resnet/resnet_runnable.py  (view file @ 356c98bd)

...
@@ -107,9 +107,12 @@ class ResnetRunnable(orbit.StandardTrainer, orbit.StandardEvaluator):
         .datasets_num_private_threads,
         dtype=self.dtype,
         drop_remainder=True)
-    orbit.StandardTrainer.__init__(self, train_dataset,
-                                   flags_obj.use_tf_while_loop,
-                                   flags_obj.use_tf_function)
+    orbit.StandardTrainer.__init__(
+        self,
+        train_dataset,
+        options=orbit.StandardTrainerOptions(
+            use_tf_while_loop=flags_obj.use_tf_while_loop,
+            use_tf_function=flags_obj.use_tf_function))
     if not flags_obj.skip_eval:
       eval_dataset = orbit.utils.make_distributed_dataset(self.strategy,
...
@@ -119,8 +122,11 @@ class ResnetRunnable(orbit.StandardTrainer, orbit.StandardEvaluator):
           batch_size=self.batch_size,
           parse_record_fn=imagenet_preprocessing.parse_record,
           dtype=self.dtype)
-      orbit.StandardEvaluator.__init__(self, eval_dataset,
-                                       flags_obj.use_tf_function)
+      orbit.StandardEvaluator.__init__(
+          self,
+          eval_dataset,
+          options=orbit.StandardEvaluatorOptions(
+              use_tf_function=flags_obj.use_tf_function))

   def train_loop_begin(self):
     """See base class."""
...
...
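Call sites now pass a frozen options dataclass instead of positional booleans. As a hedged aside, an empty options object reproduces the old defaults, and a fully eager debugging configuration becomes one explicit expression:

defaults = orbit.StandardTrainerOptions()      # use_tf_while_loop=True, use_tf_function=True
eager_debug = orbit.StandardTrainerOptions(    # run the loop body eagerly (assumed use case)
    use_tf_while_loop=False, use_tf_function=False)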
orbit/__init__.py  (view file @ 356c98bd)

 # Lint as: python3
 # Copyright 2020 The Orbit Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
...
...
orbit/controller.py  (view file @ 356c98bd)

 # Lint as: python3
 # Copyright 2020 The Orbit Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
...
@@ -16,8 +15,9 @@
 """Lightweight utilities to train TF2 models."""

 import time
-from typing import Callable, Optional, Text, Union
+from typing import Callable, Dict, Optional, Text, Union

 from absl import logging
+import numpy as np

 from orbit import runner
 from orbit import utils
...
@@ -71,9 +71,11 @@ class Controller:
         `trainer.train` function will always be enabled. If set, the value
         should be divisible by steps_per_loop.
       summary_dir: The directory to restore and write checkpoints and summaries.
-        If None, it will be set to `checkpoint_manager.directory`.
+        For example, you can set it to `checkpoint_manager.directory`.
+        If None, it will not write training summaries.
       eval_summary_dir: The directory to write eval summaries. If None, it will
-        be set to `summary_dir`.
+        be set to `summary_dir`. If both `summary_dir` and `eval_summary_dir`
+        are None, it will not write evaluation summaries.

     Raises:
       ValueError: If both `trainer` and `evaluator` are None.
...
@@ -108,9 +110,6 @@ class Controller:
     self.global_step = global_step
     self.checkpoint_manager = checkpoint_manager

-    if summary_dir is None and checkpoint_manager:
-      summary_dir = checkpoint_manager.directory
-
     if self.trainer is not None:
       self.step_timer = None
       self.steps_per_loop = steps_per_loop
...
@@ -118,7 +117,6 @@ class Controller:
       self.summary_manager = utils.SummaryManager(
           summary_dir, tf.summary.scalar, global_step=self.global_step)

-    eval_summary_writer = None
     if self.evaluator is not None:
       eval_summary_dir = eval_summary_dir or summary_dir
       if eval_summary_dir == summary_dir and self.trainer is not None:
...
@@ -177,7 +175,7 @@ class Controller:
     if checkpoint_at_completion:
       self.save_checkpoint()

-  def evaluate(self, steps: int = None):
+  def evaluate(self, steps: int = None) -> Optional[Dict[Text, np.number]]:
     """Runs evaluation.

     This method calls the `evaluate` method on the Evaluator object for `steps`
...
@@ -186,10 +184,12 @@ class Controller:
     Args:
       steps: The number of steps to evaluate for.

+    Returns:
+      The evaluation results as a dictionary of numpy values.
+
     Raises:
-      ValueError: If no checkpoint found in `self.checkpoint_manager.directory`.
+      ValueError: If `evaluator` is not provided.
     """
     if self.evaluator is None:
       raise ValueError("`evaluator` must be provided to call `evaluate()` "
...
@@ -204,7 +204,7 @@ class Controller:
     else:
       logging.info("Evaluating at train step: %s", current_step)

-    with self.eval_summary_manager.summary_writer.as_default():
+    with self.eval_summary_manager.summary_writer().as_default():
       eval_outputs = self.evaluator.evaluate(steps)

     if eval_outputs:
...
@@ -217,6 +217,8 @@ class Controller:
       self.eval_summary_manager.write_summaries(eval_outputs)
       self.eval_summary_manager.flush()

+    return eval_outputs
+
   def restore_checkpoint(self, checkpoint_path: Text = None):
     """Restore or initialize the model.
...
@@ -334,7 +336,7 @@ class Controller:
       current_step += num_steps
       num_steps = tf.convert_to_tensor(num_steps, dtype=tf.int32)

-      with self.summary_manager.summary_writer.as_default():
+      with self.summary_manager.summary_writer().as_default():
         # Create a lambda that returns true when summaries should be written.
         should_record = False  # Allows static optimization in no-summary cases.
         if self.summary_interval:
...
...
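With this change, summaries become opt-in: passing summary_dir=None disables training summaries instead of silently defaulting to the checkpoint directory, and evaluate() now returns its results. A hedged usage sketch (directory paths and the trainer/evaluator objects are assumed):

ctrl = controller.Controller(
    trainer=trainer,                              # assumed AbstractTrainer
    evaluator=evaluator,                          # assumed AbstractEvaluator
    global_step=global_step,
    checkpoint_manager=checkpoint_manager,
    steps_per_loop=100,
    summary_dir="/tmp/model/summaries/train",     # None now means: no train summaries
    eval_summary_dir="/tmp/model/summaries/eval")
eval_results = ctrl.evaluate(steps=10)            # dictionary of numpy values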
orbit/controller_test.py  (view file @ 356c98bd)

 # Lint as: python3
 # Copyright 2020 The Orbit Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
...
@@ -158,6 +157,57 @@ class TestEvaluator(standard_runner.StandardEvaluator):
     }

+class TestEvaluatorWithNestedSummary(standard_runner.StandardEvaluator):
+  """Implements the training and evaluation APIs for the test model."""
+
+  def __init__(self):
+    self.strategy = tf.distribute.get_strategy()
+    self.model = create_model()
+    dataset = self.strategy.experimental_distribute_datasets_from_function(
+        dataset_fn)
+    dataset2 = self.strategy.experimental_distribute_datasets_from_function(
+        dataset_fn)
+    self.loss = tf.keras.metrics.Mean("loss", dtype=tf.float32)
+    self.accuracy = tf.keras.metrics.CategoricalAccuracy(
+        "accuracy", dtype=tf.float32)
+    self.loss2 = tf.keras.metrics.Mean("loss", dtype=tf.float32)
+    self.accuracy2 = tf.keras.metrics.CategoricalAccuracy(
+        "accuracy", dtype=tf.float32)
+    standard_runner.StandardEvaluator.__init__(
+        self, eval_dataset={"dataset": dataset, "dataset2": dataset2})
+
+  def eval_step(self, iterator):
+    def _replicated_step(loss, accuracy, inputs):
+      """Replicated evaluation step."""
+      inputs, targets = inputs
+      outputs = self.model(inputs)
+      loss.update_state(tf.keras.losses.MSE(targets, outputs))
+      accuracy.update_state(targets, outputs)
+
+    self.strategy.run(
+        lambda inputs: _replicated_step(self.loss, self.accuracy, inputs),
+        args=(next(iterator["dataset"]),))
+    self.strategy.run(
+        lambda inputs: _replicated_step(self.loss2, self.accuracy2, inputs),
+        args=(next(iterator["dataset2"]),))
+
+  def eval_end(self):
+    return {
+        "dataset": {
+            "loss": self.loss.result(),
+            "accuracy": self.accuracy.result()
+        },
+        "dataset2": {
+            "loss": self.loss2.result(),
+            "accuracy": self.accuracy2.result()
+        },
+    }

 class TestTrainerWithSummaries(standard_runner.StandardTrainer):
   """A Trainer model with summaries for testing purposes."""
...
@@ -171,7 +221,10 @@ class TestTrainerWithSummaries(standard_runner.StandardTrainer):
         self.strategy.experimental_distribute_datasets_from_function(
             dataset_fn))
-    standard_runner.StandardTrainer.__init__(
-        self, train_dataset, use_tpu_summary_optimization=True)
+    standard_runner.StandardTrainer.__init__(
+        self, train_dataset,
+        options=standard_runner.StandardTrainerOptions(
+            use_tpu_summary_optimization=True))

   def build_train_dataset(self):
     return self.strategy.experimental_distribute_datasets_from_function(
...
@@ -241,6 +294,56 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase):
         train_steps=10, eval_steps=2, eval_interval=6)
     self.assertEqual(test_runner.global_step, 10)

+  def test_has_checkpoint_no_summaries(self):
+    test_runner = TestRunner()
+    # Has checkpoint, but no summary directories.
+    checkpoint = tf.train.Checkpoint(model=test_runner.model)
+    checkpoint_manager = tf.train.CheckpointManager(
+        checkpoint,
+        self.model_dir,
+        max_to_keep=None,
+        step_counter=test_runner.global_step)
+    test_controller = controller.Controller(
+        trainer=test_runner,
+        evaluator=test_runner,
+        global_step=test_runner.global_step,
+        checkpoint_manager=checkpoint_manager,
+        steps_per_loop=2)
+    test_controller.train_and_evaluate(
+        train_steps=10, eval_steps=2, eval_interval=6)
+    self.assertEqual(test_runner.global_step, 10)
+
+    # No summaries are saved.
+    self.assertEmpty(tf.io.gfile.glob(
+        os.path.join(checkpoint_manager.directory, "events.*")))
+
+  def test_has_checkpoint_eval_summary_only(self):
+    test_runner = TestRunner()
+    # Has checkpoint, but no summary directories.
+    checkpoint = tf.train.Checkpoint(model=test_runner.model)
+    checkpoint_manager = tf.train.CheckpointManager(
+        checkpoint,
+        self.model_dir,
+        max_to_keep=None,
+        step_counter=test_runner.global_step)
+    test_controller = controller.Controller(
+        trainer=test_runner,
+        evaluator=test_runner,
+        global_step=test_runner.global_step,
+        checkpoint_manager=checkpoint_manager,
+        eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"),
+        steps_per_loop=2)
+    test_controller.train_and_evaluate(
+        train_steps=10, eval_steps=2, eval_interval=6)
+    self.assertEqual(test_runner.global_step, 10)
+
+    # Training summaries are not saved.
+    self.assertEmpty(tf.io.gfile.glob(
+        os.path.join(checkpoint_manager.directory, "events.*")))
+    # Evaluation summaries are saved.
+    self.assertNotEmpty(tf.io.gfile.glob(
+        os.path.join(self.model_dir, "summaries/eval/events.*")))

   @parameterized.named_parameters(("return_numpy", True),
                                   ("return_tensor", False))
   def test_train_and_evaluate(self, return_numpy):
...
@@ -329,7 +432,7 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase):
         checkpoint_manager=checkpoint_manager,
         summary_dir=os.path.join(self.model_dir, "summaries/train"),
         eval_summary_dir=os.path.join(self.model_dir, "summaries/eval"))
-    test_controller.evaluate(steps=2)
+    eval_results = test_controller.evaluate(steps=2)

     # Only eval summaries are written
     self.assertFalse(
...
@@ -339,6 +442,7 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase):
     self.assertNotEmpty(
         summaries_with_matching_keyword(
             "eval_loss", os.path.join(self.model_dir, "summaries/eval")))
+    self.assertIn("eval_loss", eval_results)

     # Tests continuous eval with timeout and timeout_fn.
     done_file = os.path.join(self.model_dir, "summaries/eval/Done")
...
@@ -558,7 +662,8 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase):
         evaluator=test_runner,
         global_step=test_runner.global_step,
         steps_per_loop=10,
-        checkpoint_manager=checkpoint_manager)
+        checkpoint_manager=checkpoint_manager,
+        summary_dir=self.model_dir)
     test_controller.train_and_evaluate(
         train_steps=10, eval_steps=2, eval_interval=5)
...
@@ -569,6 +674,31 @@ class ControllerTest(tf.test.TestCase, parameterized.TestCase):
     self.assertLen(
         summaries_with_matching_keyword("eval_loss", self.model_dir), 2)

+  def test_evaluate_with_nested_summaries(self):
+    test_evaluator = TestEvaluatorWithNestedSummary()
+    test_controller = controller.Controller(
+        evaluator=test_evaluator,
+        global_step=tf.Variable(0, dtype=tf.int64),
+        eval_summary_dir=self.model_dir)
+    test_controller.evaluate(steps=5)
+
+    self.assertNotEmpty(
+        tf.io.gfile.listdir(os.path.join(self.model_dir, "dataset")))
+    self.assertNotEmpty(
+        summaries_with_matching_keyword(
+            "loss", os.path.join(self.model_dir, "dataset")))
+    self.assertNotEmpty(
+        summaries_with_matching_keyword(
+            "accuracy", os.path.join(self.model_dir, "dataset")))
+
+    self.assertNotEmpty(
+        tf.io.gfile.listdir(os.path.join(self.model_dir, "dataset2")))
+    self.assertNotEmpty(
+        summaries_with_matching_keyword(
+            "loss", os.path.join(self.model_dir, "dataset2")))
+    self.assertNotEmpty(
+        summaries_with_matching_keyword(
+            "accuracy", os.path.join(self.model_dir, "dataset2")))

 if __name__ == "__main__":
   tf.test.main()
orbit/runner.py  (view file @ 356c98bd)

 # Lint as: python3
 # Copyright 2020 The Orbit Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
...
@@ -35,7 +34,7 @@ class AbstractTrainer(tf.Module, metaclass=abc.ABCMeta):
     large in Eager mode. It is usually encouraged to create a host training loop
     (e.g. using a `tf.range` wrapping `strategy.run` inside a
     `tf.function`) in the TPU case. For the cases that don't require host
-    training loop to acheive peak performance, users can just implement a simple
+    training loop to achieve peak performance, users can just implement a simple
     python loop to drive each step.

     Args:
...
@@ -45,7 +44,8 @@ class AbstractTrainer(tf.Module, metaclass=abc.ABCMeta):
     Returns:
       The function may return a dictionary of `Tensors` or numpy arrays, which
-      will be written to logs and as TensorBoard summaries.
+      will be written to logs and as TensorBoard summaries. It can also be a
+      nested dictionary, yielding a hierarchy of summary directories.
     """
     pass
...
@@ -67,6 +67,7 @@ class AbstractEvaluator(tf.Module, metaclass=abc.ABCMeta):
     Returns:
       The function may return a dictionary of `Tensors` or numpy arrays, which
-      will be written to logs and as TensorBoard summaries.
+      will be written to logs and as TensorBoard summaries. It can also be a
+      nested dictionary, yielding a hierarchy of summary directories.
     """
     pass
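The nested-dictionary contract added to both docstrings is exercised by TestEvaluatorWithNestedSummary in controller_test.py above; a condensed sketch of such a return value:

def eval_end(self):
    # Each top-level key becomes a summary subdirectory, so "loss"/"accuracy"
    # land under <eval_summary_dir>/dataset and <eval_summary_dir>/dataset2.
    return {
        "dataset": {"loss": self.loss.result(), "accuracy": self.accuracy.result()},
        "dataset2": {"loss": self.loss2.result(), "accuracy": self.accuracy2.result()},
    }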
orbit/standard_runner.py  (view file @ 356c98bd)

 # Lint as: python3
 # Copyright 2020 The Orbit Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
...
@@ -24,20 +23,22 @@ import tensorflow as tf

 @dataclasses.dataclass(frozen=True)
-class TrainerOverrides:
-  """Advanced overrides for Orbit trainers.
+class StandardTrainerOptions:
+  """Advanced options for `orbit.StandardTrainer`.

   Attributes:
-    use_tf_while_loop: A boolean indicates whether to wrap the train step with
-      a `tf.while_loop`.
-    use_tf_function: A boolean indicates whether a `tf.function` will be used.
-      If False, training will run on pure eager mode.
-    use_tpu_summary_optimization: A boolean indicates whether to enable the
-      performance optimization for summaries in TPUs. In TPUs, writing
-      summaries with outside compilation inside train step is slow. If True,
-      it creates two `tf.function` with two XLA programs: one with summaries
-      and one without, and run the program with summaries (slow one) only if
-      necessary.
+    use_tf_while_loop: A boolean indicating whether to run the training loop
+      using a `tf.while_loop`. If `True`, `use_tf_function` must also be
+      `True`.
+    use_tf_function: A boolean indicating whether to apply `tf.function` to the
+      training loop. This will only affect the body of the loop (involving
+      `train_step`); `train_loop_begin` and `train_loop_end` will always be
+      run in eager mode.
+    use_tpu_summary_optimization: A boolean indicating whether to enable a
+      performance optimization for summaries in TPUs. Writing summaries
+      conditionally with outside compilation on TPUs can be extremely slow. If
+      `True`, this optimization creates two `tf.function`s with two XLA
+      programs (one with summary calls, and one without). The program with
+      summaries runs only for one step when summaries should be recorded.
   """
   use_tf_while_loop: bool = True
   use_tf_function: bool = True
...
@@ -47,39 +48,29 @@ class StandardTrainerOptions:
 class StandardTrainer(runner.AbstractTrainer, metaclass=abc.ABCMeta):
   """Implements the standard functionality of AbstractTrainer APIs."""

-  def __init__(self,
-               train_dataset,
-               use_tf_while_loop=True,
-               use_tf_function=True,
-               use_tpu_summary_optimization=False):
+  def __init__(self,
+               train_dataset,
+               options: StandardTrainerOptions = None):
     """Construct a `StandardTrainer` object.

     Args:
       train_dataset: A tf.nest-compatible structure of tf.data.Dataset or
         DistributedDataset.
-      use_tf_while_loop: A boolean indicates whether to wrap the train step
-        with a `tf.while_loop`.
-      use_tf_function: A boolean indicates whether a `tf.function` will be
-        used. If False, training will run on pure eager mode.
-      use_tpu_summary_optimization: A boolean indicates whether to enable the
-        performance optimization for summaries in TPUs. In TPUs, writing
-        summaries with outside compilation inside train step is slow. If True,
-        it creates two `tf.function` with two XLA programs: one with summaries
-        and one without, and run the program with summaries (slow one) only if
-        necessary.
+      options: An `orbit.StandardTrainerOptions` instance.
     """
-    if use_tf_while_loop and not use_tf_function:
+    options = options or StandardTrainerOptions()
+    if options.use_tf_while_loop and not options.use_tf_function:
       raise ValueError("`use_tf_while_loop=True` and `use_tf_function=False` "
                        "is not supported")
-    if use_tpu_summary_optimization and not use_tf_while_loop:
+    if options.use_tpu_summary_optimization and not options.use_tf_while_loop:
       raise ValueError("`use_tpu_summary_optimization=True` and "
                        "`use_tf_while_loop=False` is not supported")
-    self._use_tf_while_loop = use_tf_while_loop
-    self._use_tf_function = use_tf_function
+    self._use_tf_while_loop = options.use_tf_while_loop
+    self._use_tf_function = options.use_tf_function
+    self._use_tpu_summary_optimization = options.use_tpu_summary_optimization
     self._train_dataset = train_dataset
     self._train_iter = None
     self._train_loop_fn = None
-    self._use_tpu_summary_optimization = use_tpu_summary_optimization

   def train(self, num_steps: Optional[tf.Tensor]) -> Optional[Dict[Text, tf.Tensor]]:
...
@@ -144,7 +135,8 @@ class StandardTrainer(runner.AbstractTrainer, metaclass=abc.ABCMeta):
     Returns:
       The function may return a dictionary of `Tensors`, which will be
-      written to logs and as TensorBoard summaries.
+      written to logs and as TensorBoard summaries. It can also be a
+      nested dictionary, yielding a hierarchy of summary directories.
     """
     pass
...
@@ -168,12 +160,14 @@ class StandardTrainer(runner.AbstractTrainer, metaclass=abc.ABCMeta):

 @dataclasses.dataclass(frozen=True)
-class EvaluatorOverrides:
-  """Advanced overrides for Orbit evaluators.
+class StandardEvaluatorOptions:
+  """Advanced options for the `orbit.StandardEvaluator`.

   Attributes:
-    use_tf_function: A boolean indicates whether a `tf.function` will be used.
-      If False, training will run on pure eager mode.
+    use_tf_function: A boolean indicating whether to apply `tf.function` to
+      the training loop. This will only affect the body of the loop (involving
+      `train_step`); `train_loop_begin` and `train_loop_end` will always be
+      run in eager mode.
   """
   use_tf_function: bool = True
...
@@ -181,16 +175,16 @@ class StandardEvaluatorOptions:
 class StandardEvaluator(runner.AbstractEvaluator, metaclass=abc.ABCMeta):
   """Implements the standard functionality of AbstractEvaluator APIs."""

-  def __init__(self, eval_dataset, use_tf_function=True):
+  def __init__(self, eval_dataset,
+               options: StandardEvaluatorOptions = None):
     """Construct a `StandardEvaluator` object.

     Args:
       eval_dataset: A tf.nest-compatible structure of tf.data.Dataset or
         DistributedDataset.
-      use_tf_function: A boolean indicates whether a `tf.function` will be
-        used. If False, evaluation will run on pure eager mode.
+      options: An `orbit.StandardEvaluatorOptions` instance.
     """
-    self._eval_use_tf_function = use_tf_function
+    options = options or StandardEvaluatorOptions()
+    self._eval_use_tf_function = options.use_tf_function
     self._eval_dataset = eval_dataset
     self._eval_loop_fn = None
...
@@ -261,7 +255,8 @@ class StandardEvaluator(runner.AbstractEvaluator, metaclass=abc.ABCMeta):
     Returns:
       The function may return a dictionary of `Tensors`, which will be
-      written to logs and as TensorBoard summaries.
+      written to logs and as TensorBoard summaries. It can also be a
+      nested dictionary, yielding a hierarchy of summary directories.
     """
     pass
...
...
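The frozen dataclasses make the invalid flag combinations explicit at construction time. A hedged sketch of the check (the ValueError text is taken from the diff above):

opts = standard_runner.StandardTrainerOptions(
    use_tf_while_loop=True, use_tf_function=False)
# Passing these options to a StandardTrainer subclass raises:
#   ValueError: `use_tf_while_loop=True` and `use_tf_function=False` is not supported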
orbit/standard_runner_test.py  (view file @ 356c98bd)

 # Lint as: python3
 # Copyright 2020 The Orbit Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
...
@@ -14,9 +13,9 @@
 # limitations under the License.
 # ==============================================================================
 """Tests for orbit.standard_runner."""

-# pylint: disable=g-bad-import-order
 from orbit import standard_runner
+from orbit import utils

 import tensorflow as tf
...
@@ -34,46 +33,49 @@ def dataset_fn(input_context=None):
   return dataset

-class TestRunner(standard_runner.StandardTrainer,
-                 standard_runner.StandardEvaluator):
-  """Implements the training and evaluation APIs for tests."""
+class TestTrainer(standard_runner.StandardTrainer):
+  """A StandardTrainer subclass for tests."""

-  def __init__(self):
+  def __init__(self, options=None):
     self.strategy = tf.distribute.get_strategy()
-    self.global_step = tf.Variable(
-        0,
-        trainable=False,
-        dtype=tf.int64,
-        name='global_step',
-        aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA)
-    standard_runner.StandardTrainer.__init__(self, train_dataset=None)
-    standard_runner.StandardEvaluator.__init__(self, eval_dataset=None)
+    self.global_step = utils.create_global_step()
+    distribute = self.strategy.experimental_distribute_datasets_from_function
+    dataset = distribute(dataset_fn)
+    super().__init__(train_dataset=dataset, options=options)

   def train_loop_begin(self):
-    self.train_dataset = (
-        self.strategy.experimental_distribute_datasets_from_function(
-            dataset_fn))
     self.global_step.assign(0)

   def train_step(self, iterator):
-    def _replicated_step(_):
+    def replica_step(_):
       self.global_step.assign_add(1)

-    self.strategy.run(_replicated_step, args=(next(iterator),))
+    self.strategy.run(replica_step, args=(next(iterator),))

   def train_loop_end(self):
     return self.global_step.numpy()

+class TestEvaluator(standard_runner.StandardEvaluator):
+  """A StandardEvaluator subclass for tests."""
+
+  def __init__(self, options=None):
+    self.strategy = tf.distribute.get_strategy()
+    self.global_step = utils.create_global_step()
+    distribute = self.strategy.experimental_distribute_datasets_from_function
+    dataset = distribute(dataset_fn)
+    super().__init__(eval_dataset=dataset, options=options)

   def eval_begin(self):
-    self.eval_dataset = (
-        self.strategy.experimental_distribute_datasets_from_function(
-            dataset_fn))
     self.global_step.assign(0)

   def eval_step(self, iterator):
-    def _replicated_step(_):
+    def replica_step(_):
       self.global_step.assign_add(1)

-    self.strategy.run(_replicated_step, args=(next(iterator),))
+    self.strategy.run(replica_step, args=(next(iterator),))

   def eval_end(self):
     return self.global_step.numpy()
...
@@ -81,15 +83,19 @@ class StandardRunnerTest(tf.test.TestCase):

-  def test_train(self):
-    test_runner = TestRunner()
-    self.assertEqual(
-        test_runner.train(tf.convert_to_tensor(10, dtype=tf.int32)), 10)
+  def test_default_trainer(self):
+    trainer = TestTrainer()
+    self.assertEqual(trainer.train(tf.constant(10)), 10)
+
+  def test_trainer_with_tpu_summary_optimization(self):
+    options = standard_runner.StandardTrainerOptions(
+        use_tpu_summary_optimization=True)
+    trainer = TestTrainer(options)
+    self.assertEqual(trainer.train(tf.constant(10)), 10)

-  def test_eval(self):
-    test_runner = TestRunner()
-    self.assertEqual(
-        test_runner.evaluate(tf.convert_to_tensor(10, dtype=tf.int32)), 10)
+  def test_default_evaluator(self):
+    evaluator = TestEvaluator()
+    self.assertEqual(evaluator.evaluate(tf.constant(10)), 10)

 if __name__ == '__main__':
...
...