Commit 3462436c in ModelZoo / ResNet50_tensorflow

Apply model optimization pruning for image classification task.

PiperOrigin-RevId: 289573318

Authored Jan 13, 2020 by Jaehong Kim; committed by A. Unique TensorFlower on Jan 13, 2020.
Parent: fae2b55c

Showing 4 changed files with 222 additions and 49 deletions (+222 / -49):
  official/vision/image_classification/common.py                  +53 / -2
  official/vision/image_classification/imagenet_preprocessing.py  +28 / -1
  official/vision/image_classification/resnet_imagenet_main.py    +82 / -21
  official/vision/image_classification/resnet_imagenet_test.py    +59 / -25
official/vision/image_classification/common.py

```diff
@@ -24,6 +24,7 @@ import numpy as np
 import tensorflow as tf
 from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_v2
+import tensorflow_model_optimization as tfmot
 from official.utils.flags import core as flags_core
 from official.utils.misc import keras_utils
```
```diff
@@ -180,7 +181,12 @@ def get_optimizer(learning_rate=0.1):
 # TODO(hongkuny,haoyuzhang): make cifar model use_tensor_lr to clean up code.
-def get_callbacks(steps_per_epoch, learning_rate_schedule_fn=None):
+def get_callbacks(steps_per_epoch,
+                  learning_rate_schedule_fn=None,
+                  pruning_method=None,
+                  enable_checkpoint_and_export=False,
+                  model_dir=None):
   """Returns common callbacks."""
   time_callback = keras_utils.TimeHistory(FLAGS.batch_size, FLAGS.log_steps)
   callbacks = [time_callback]
@@ -205,6 +211,19 @@ def get_callbacks(steps_per_epoch, learning_rate_schedule_fn=None):
         steps_per_epoch)
     callbacks.append(profiler_callback)
 
+  is_pruning_enabled = pruning_method is not None
+  if is_pruning_enabled:
+    callbacks.append(tfmot.sparsity.keras.UpdatePruningStep())
+    if model_dir is not None:
+      callbacks.append(tfmot.sparsity.keras.PruningSummaries(
+          log_dir=model_dir, profile_batch=0))
+
+  if enable_checkpoint_and_export:
+    if model_dir is not None:
+      ckpt_full_path = os.path.join(model_dir, 'model.ckpt-{epoch:04d}')
+      callbacks.append(tf.keras.callbacks.ModelCheckpoint(
+          ckpt_full_path, save_weights_only=True))
+
   return callbacks
```
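For orientation, here is a minimal self-contained sketch (not part of this commit; the toy model, loss, and log directory are assumptions) of how these two tfmot callbacks attach to Keras training:

```python
import tensorflow as tf
import tensorflow_model_optimization as tfmot

# Toy pruned model standing in for the real one.
model = tfmot.sparsity.keras.prune_low_magnitude(
    tf.keras.Sequential([tf.keras.layers.Dense(10, input_shape=(32,))]))
model.compile(optimizer='sgd', loss='mse')

callbacks = [
    # Required for prune_low_magnitude-wrapped models: advances the
    # pruning step counter every batch.
    tfmot.sparsity.keras.UpdatePruningStep(),
    # Writes sparsity summaries for TensorBoard, as get_callbacks does
    # when model_dir is set.
    tfmot.sparsity.keras.PruningSummaries(log_dir='/tmp/logs'),
]
model.fit(tf.zeros([8, 32]), tf.zeros([8, 10]), callbacks=callbacks)
```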
```diff
@@ -254,7 +273,11 @@ def build_stats(history, eval_output, callbacks):
   return stats
 
 
-def define_keras_flags(dynamic_loss_scale=True):
+def define_keras_flags(dynamic_loss_scale=True,
+                       model=False,
+                       optimizer=False,
+                       pretrained_filepath=False):
   """Define flags for Keras models."""
   flags_core.define_base(clean=True, num_gpu=True, run_eagerly=True,
                          train_epochs=True, epochs_between_evals=True,
@@ -334,6 +357,17 @@ def define_keras_flags(dynamic_loss_scale=True):
       'a temporal flag during transition to tf.keras.layers. Do not use this '
       'flag for external usage. this will be removed shortly.')
+  if model:
+    flags.DEFINE_string('model', 'resnet50_v1.5',
+                        'Name of model preset. (mobilenet, resnet50_v1.5)')
+  if optimizer:
+    flags.DEFINE_string('optimizer', 'resnet50_default',
+                        'Name of optimizer preset. '
+                        '(mobilenet_default, resnet50_default)')
+  if pretrained_filepath:
+    flags.DEFINE_string('pretrained_filepath', '', 'Pretrained file path.')
 
 
 def get_synth_data(height, width, num_channels, num_classes, dtype):
   """Creates a set of synthetic random data.
```
```diff
@@ -364,6 +398,23 @@ def get_synth_data(height, width, num_channels, num_classes, dtype):
   return inputs, labels
 
 
+def define_pruning_flags():
+  """Define flags for pruning methods."""
+  flags.DEFINE_string('pruning_method', None,
+                      'Pruning method. '
+                      'None (no pruning) or polynomial_decay.')
+  flags.DEFINE_float('pruning_initial_sparsity', 0.0,
+                     'Initial sparsity for pruning.')
+  flags.DEFINE_float('pruning_final_sparsity', 0.5,
+                     'Final sparsity for pruning.')
+  flags.DEFINE_integer('pruning_begin_step', 0,
+                       'Begin step for pruning.')
+  flags.DEFINE_integer('pruning_end_step', 100000,
+                       'End step for pruning.')
+  flags.DEFINE_integer('pruning_frequency', 100,
+                       'Frequency for pruning.')
+
+
 def get_synth_input_fn(height, width, num_channels, num_classes,
                        dtype=tf.float32, drop_remainder=True):
   """Returns an input function that returns a dataset with random data.
```
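These are plain absl flags, so they surface on the resnet_imagenet_main.py command line as, for example, `--pruning_method=polynomial_decay --pruning_final_sparsity=0.5`. A standalone toy sketch of the same DEFINE_*/FLAGS pattern, reduced to two of the flags above:

```python
from absl import app
from absl import flags

flags.DEFINE_string('pruning_method', None,
                    'None (no pruning) or polynomial_decay.')
flags.DEFINE_float('pruning_final_sparsity', 0.5,
                   'Final sparsity for pruning.')
FLAGS = flags.FLAGS


def main(_):
  # app.run parses sys.argv before calling main, so values are set here.
  print(FLAGS.pruning_method, FLAGS.pruning_final_sparsity)


if __name__ == '__main__':
  app.run(main)
```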
official/vision/image_classification/imagenet_preprocessing.py

```diff
@@ -226,7 +226,8 @@ def parse_record(raw_record, is_training, dtype):
     dtype: data type to use for images/features.
 
   Returns:
-    Tuple with processed image tensor and one-hot-encoded label tensor.
+    Tuple with processed image tensor in a channel-last format and
+    one-hot-encoded label tensor.
   """
   image_buffer, label, bbox = parse_example_proto(raw_record)
```
```diff
@@ -246,6 +247,32 @@ def parse_record(raw_record, is_training, dtype):
   return image, label
 
 
+def get_parse_record_fn(use_keras_image_data_format=False):
+  """Get a function for parsing the records, accounting for image format.
+
+  This is useful for handling different types of Keras models. For instance,
+  the current resnet_model.resnet50 input format is always channel-last,
+  whereas the keras_applications mobilenet input format depends on
+  tf.keras.backend.image_data_format(). We should set
+  use_keras_image_data_format=False for the former and True for the latter.
+
+  Args:
+    use_keras_image_data_format: A boolean denoting whether data format is
+      keras backend image data format. If False, the image format is
+      channel-last. If True, the image format matches
+      tf.keras.backend.image_data_format().
+
+  Returns:
+    Function to use for parsing the records.
+  """
+  def parse_record_fn(raw_record, is_training, dtype):
+    image, label = parse_record(raw_record, is_training, dtype)
+    if use_keras_image_data_format:
+      if tf.keras.backend.image_data_format() == 'channels_first':
+        image = tf.transpose(image, perm=[2, 0, 1])
+    return image, label
+  return parse_record_fn
+
+
 def input_fn(is_training,
              data_dir,
              batch_size,
```
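A toy demonstration (not from this commit) of the layout conversion parse_record_fn performs: the image stays HWC unless the Keras backend is configured channels_first, in which case it is transposed to CHW:

```python
import tensorflow as tf

image = tf.zeros([224, 224, 3])  # channel-last, as parse_record produces
if tf.keras.backend.image_data_format() == 'channels_first':
  # perm=[2, 0, 1] moves the channel axis to the front: [3, 224, 224].
  image = tf.transpose(image, perm=[2, 0, 1])
```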
official/vision/image_classification/resnet_imagenet_main.py

```diff
@@ -25,6 +25,8 @@ from absl import flags
 from absl import logging
 import tensorflow as tf
+import tensorflow_model_optimization as tfmot
 from official.benchmark.models import trivial_model
 from official.utils.flags import core as flags_core
 from official.utils.logs import logger
```
```diff
@@ -44,6 +46,7 @@ def run(flags_obj):
   Raises:
     ValueError: If fp16 is passed as it is not currently supported.
+    NotImplementedError: If some features are not currently supported.
 
   Returns:
     Dictionary of training and eval stats.
```
```diff
@@ -120,12 +123,20 @@ def run(flags_obj):
   # in the dataset, as XLA-GPU doesn't support dynamic shapes.
   drop_remainder = flags_obj.enable_xla
 
+  # Current resnet_model.resnet50 input format is always channel-last.
+  # The keras_applications mobilenet model's input format depends on the
+  # Keras backend image data format. The use_keras_image_data_format flag
+  # indicates whether the image preprocessor output format should match the
+  # Keras backend image data format or just be channel-last.
+  use_keras_image_data_format = (flags_obj.model == 'mobilenet')
   train_input_dataset = input_fn(
       is_training=True,
       data_dir=flags_obj.data_dir,
       batch_size=flags_obj.batch_size,
       num_epochs=flags_obj.train_epochs,
-      parse_record_fn=imagenet_preprocessing.parse_record,
+      parse_record_fn=imagenet_preprocessing.get_parse_record_fn(
+          use_keras_image_data_format=use_keras_image_data_format),
       datasets_num_private_threads=flags_obj.datasets_num_private_threads,
       dtype=dtype,
       drop_remainder=drop_remainder,
```
```diff
@@ -140,7 +151,8 @@ def run(flags_obj):
       data_dir=flags_obj.data_dir,
       batch_size=flags_obj.batch_size,
       num_epochs=flags_obj.train_epochs,
-      parse_record_fn=imagenet_preprocessing.parse_record,
+      parse_record_fn=imagenet_preprocessing.get_parse_record_fn(
+          use_keras_image_data_format=use_keras_image_data_format),
       dtype=dtype,
       drop_remainder=drop_remainder)
```
```diff
@@ -153,9 +165,27 @@ def run(flags_obj):
       boundaries=list(p[1] for p in common.LR_SCHEDULE[1:]),
       multipliers=list(p[0] for p in common.LR_SCHEDULE),
       compute_lr_on_cpu=True)
+  steps_per_epoch = (
+      imagenet_preprocessing.NUM_IMAGES['train'] // flags_obj.batch_size)
+  learning_rate_schedule_fn = None
 
   with strategy_scope:
-    optimizer = common.get_optimizer(lr_schedule)
+    if flags_obj.optimizer == 'resnet50_default':
+      optimizer = common.get_optimizer(lr_schedule)
+      learning_rate_schedule_fn = common.learning_rate_schedule
+    elif flags_obj.optimizer == 'mobilenet_default':
+      lr_decay_factor = 0.94
+      num_epochs_per_decay = 2.5
+      initial_learning_rate_per_sample = 0.000007
+      initial_learning_rate = \
+          initial_learning_rate_per_sample * flags_obj.batch_size
+      optimizer = tf.keras.optimizers.SGD(
+          learning_rate=tf.keras.optimizers.schedules.ExponentialDecay(
+              initial_learning_rate,
+              decay_steps=steps_per_epoch * num_epochs_per_decay,
+              decay_rate=lr_decay_factor,
+              staircase=True),
+          momentum=0.9)
     if flags_obj.fp16_implementation == 'graph_rewrite':
       # Note: when flags_obj.fp16_implementation == "graph_rewrite", dtype as
       # determined by flags_core.get_tf_dtype(flags_obj) would be 'float32'
```
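To make the MobileNet preset concrete (a worked example; the batch size is assumed, not taken from this commit): the initial learning rate scales linearly with batch size, and then decays by a factor of 0.94 every 2.5 epochs' worth of steps.

```python
batch_size = 256                               # assumed for illustration
initial_learning_rate = 0.000007 * batch_size  # = 0.001792
```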
```diff
@@ -169,11 +199,37 @@ def run(flags_obj):
     if flags_obj.use_trivial_model:
       model = trivial_model.trivial_model(
           imagenet_preprocessing.NUM_CLASSES)
-    else:
+    elif flags_obj.model == 'resnet50_v1.5':
       resnet_model.change_keras_layer(flags_obj.use_tf_keras_layers)
       model = resnet_model.resnet50(
           num_classes=imagenet_preprocessing.NUM_CLASSES)
+    elif flags_obj.model == 'mobilenet':
+      # TODO(kimjaehong): Remove layers attribute when minimum TF version
+      # supports 2.0 layers by default.
+      model = tf.keras.applications.mobilenet.MobileNet(
+          weights=None,
+          classes=imagenet_preprocessing.NUM_CLASSES,
+          layers=tf.keras.layers)
+    if flags_obj.pretrained_filepath:
+      model.load_weights(flags_obj.pretrained_filepath)
+
+    if flags_obj.pruning_method == 'polynomial_decay':
+      if dtype != tf.float32:
+        raise NotImplementedError(
+            'Pruning is currently only supported on dtype=tf.float32.')
+      pruning_params = {
+          'pruning_schedule':
+              tfmot.sparsity.keras.PolynomialDecay(
+                  initial_sparsity=flags_obj.pruning_initial_sparsity,
+                  final_sparsity=flags_obj.pruning_final_sparsity,
+                  begin_step=flags_obj.pruning_begin_step,
+                  end_step=flags_obj.pruning_end_step,
+                  frequency=flags_obj.pruning_frequency),
+      }
+      model = tfmot.sparsity.keras.prune_low_magnitude(model, **pruning_params)
+    elif flags_obj.pruning_method:
+      raise NotImplementedError(
+          'Only polynomial_decay is currently supported.')
 
     # TODO(b/138957587): Remove when force_v2_in_keras_compile is no longer
     # a valid arg for this model. Also remove as a valid flag.
     if flags_obj.force_v2_in_keras_compile is not None:
```
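The heart of the change: the Keras model is wrapped with tfmot magnitude pruning before compile. A minimal standalone sketch (the toy model and schedule values are assumptions, not from this commit):

```python
import tensorflow as tf
import tensorflow_model_optimization as tfmot

# Toy dense model standing in for ResNet/MobileNet.
model = tf.keras.Sequential([
    tf.keras.layers.Dense(64, activation='relu', input_shape=(32,)),
    tf.keras.layers.Dense(10),
])
# Sparsity ramps polynomially from 0% to 50% between steps 0 and 1000,
# recomputing the pruning mask every 100 steps.
pruning_params = {
    'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(
        initial_sparsity=0.0, final_sparsity=0.5,
        begin_step=0, end_step=1000, frequency=100),
}
model = tfmot.sparsity.keras.prune_low_magnitude(model, **pruning_params)
model.compile(optimizer='sgd', loss='mse')
```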
```diff
@@ -192,16 +248,14 @@ def run(flags_obj):
                if flags_obj.report_accuracy_metrics else None),
       run_eagerly=flags_obj.run_eagerly)
 
-  steps_per_epoch = (
-      imagenet_preprocessing.NUM_IMAGES['train'] // flags_obj.batch_size)
   train_epochs = flags_obj.train_epochs
 
-  callbacks = common.get_callbacks(steps_per_epoch,
-                                   common.learning_rate_schedule)
-  if flags_obj.enable_checkpoint_and_export:
-    ckpt_full_path = os.path.join(flags_obj.model_dir, 'model.ckpt-{epoch:04d}')
-    callbacks.append(tf.keras.callbacks.ModelCheckpoint(ckpt_full_path,
-                                                        save_weights_only=True))
+  callbacks = common.get_callbacks(
+      steps_per_epoch=steps_per_epoch,
+      learning_rate_schedule_fn=learning_rate_schedule_fn,
+      pruning_method=flags_obj.pruning_method,
+      enable_checkpoint_and_export=flags_obj.enable_checkpoint_and_export,
+      model_dir=flags_obj.model_dir)
 
   # if multiple epochs, ignore the train_steps flag.
   if train_epochs <= 1 and flags_obj.train_steps:
```
```diff
@@ -237,13 +291,6 @@ def run(flags_obj):
                       validation_data=validation_data,
                       validation_freq=flags_obj.epochs_between_evals,
                       verbose=2)
-  if flags_obj.enable_checkpoint_and_export:
-    if dtype == tf.bfloat16:
-      logging.warning("Keras model.save does not support bfloat16 dtype.")
-    else:
-      # Keras model.save assumes a float32 input signature.
-      export_path = os.path.join(flags_obj.model_dir, 'saved_model')
-      model.save(export_path, include_optimizer=False)
 
   eval_output = None
   if not flags_obj.skip_eval:
```
```diff
@@ -251,6 +298,16 @@ def run(flags_obj):
                                  steps=num_eval_steps,
                                  verbose=2)
 
+  if flags_obj.pruning_method:
+    model = tfmot.sparsity.keras.strip_pruning(model)
+  if flags_obj.enable_checkpoint_and_export:
+    if dtype == tf.bfloat16:
+      logging.warning('Keras model.save does not support bfloat16 dtype.')
+    else:
+      # Keras model.save assumes a float32 input signature.
+      export_path = os.path.join(flags_obj.model_dir, 'saved_model')
+      model.save(export_path, include_optimizer=False)
+
   if not strategy and flags_obj.explicit_gpu_placement:
     no_dist_strat_device.__exit__()
```
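Note the ordering: when pruning is enabled, the wrappers are stripped after training and before export, so the SavedModel contains ordinary layers whose weights are already sparse. A hedged sketch of that final step (toy model; the export path is assumed):

```python
import tensorflow as tf
import tensorflow_model_optimization as tfmot

# Toy wrapped model; in this commit it is the trained ResNet/MobileNet.
pruned = tfmot.sparsity.keras.prune_low_magnitude(
    tf.keras.Sequential([tf.keras.layers.Dense(10, input_shape=(32,))]))
# strip_pruning removes the wrappers, leaving plain layers with the sparse
# weights, which is what model.save should serialize.
final_model = tfmot.sparsity.keras.strip_pruning(pruned)
final_model.save('/tmp/saved_model', include_optimizer=False)
```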
```diff
@@ -259,7 +316,11 @@ def run(flags_obj):
 def define_imagenet_keras_flags():
-  common.define_keras_flags()
+  common.define_keras_flags(
+      model=True,
+      optimizer=True,
+      pretrained_filepath=True)
+  common.define_pruning_flags()
   flags_core.set_defaults()
   flags.adopt_module_key_flags(common)
```
official/vision/image_classification/resnet_imagenet_test.py

```diff
@@ -18,6 +18,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+from absl.testing import parameterized
 import tensorflow as tf
 from tensorflow.python.eager import context
```
```diff
@@ -27,14 +28,40 @@ from official.vision.image_classification import imagenet_preprocessing
 from official.vision.image_classification import resnet_imagenet_main
 
 
+@parameterized.parameters(
+    "resnet",
+    "resnet_polynomial_decay",
+    "mobilenet",
+    "mobilenet_polynomial_decay")
 class KerasImagenetTest(tf.test.TestCase):
-  """Unit tests for Keras ResNet with ImageNet."""
+  """Unit tests for Keras Models with ImageNet."""
 
-  _extra_flags = [
+  _default_flags_dict = [
       "-batch_size", "4",
       "-train_steps", "1",
-      "-use_synthetic_data", "true"
+      "-use_synthetic_data", "true",
+      "-data_format", "channels_last",
   ]
+  _extra_flags_dict = {
+      "resnet": [
+          "-model", "resnet50_v1.5",
+          "-optimizer", "resnet50_default",
+      ],
+      "resnet_polynomial_decay": [
+          "-model", "resnet50_v1.5",
+          "-optimizer", "resnet50_default",
+          "-pruning_method", "polynomial_decay",
+          "-use_tf_keras_layers", "true",
+      ],
+      "mobilenet": [
+          "-model", "mobilenet",
+          "-optimizer", "mobilenet_default",
+      ],
+      "mobilenet_polynomial_decay": [
+          "-model", "mobilenet",
+          "-optimizer", "mobilenet_default",
+          "-pruning_method", "polynomial_decay",
+      ],
+  }
   _tempdir = None
 
   @classmethod
```
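Each string passed to @parameterized.parameters is forwarded to every test method as flags_key, which get_extra_flags_dict (added in the next hunk) maps to a preset flag list. A standalone toy sketch of the absl pattern (names are illustrative, not from this commit):

```python
from absl.testing import parameterized


@parameterized.parameters("resnet", "mobilenet")
class PresetTest(parameterized.TestCase):

  def test_key_is_known(self, flags_key):
    # Runs once per parameter: flags_key is "resnet", then "mobilenet".
    self.assertIn(flags_key, ("resnet", "mobilenet"))
```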
```diff
@@ -53,16 +80,18 @@ class KerasImagenetTest(tf.test.TestCase):
       tf.io.gfile.rmtree(self.get_temp_dir())
     tf.compat.v2.keras.mixed_precision.experimental.set_policy(self.policy)
 
-  def test_end_to_end_no_dist_strat(self):
+  def get_extra_flags_dict(self, flags_key):
+    return self._extra_flags_dict[flags_key] + self._default_flags_dict
+
+  def test_end_to_end_no_dist_strat(self, flags_key):
     """Test Keras model with 1 GPU, no distribution strategy."""
     config = keras_utils.get_config_proto_v1()
     tf.compat.v1.enable_eager_execution(config=config)
 
     extra_flags = [
         "-distribution_strategy", "off",
-        "-data_format", "channels_last",
     ]
-    extra_flags = extra_flags + self._extra_flags
+    extra_flags = extra_flags + self.get_extra_flags_dict(flags_key)
 
     integration.run_synthetic(
         main=resnet_imagenet_main.run,
```
```diff
@@ -70,14 +99,13 @@ class KerasImagenetTest(tf.test.TestCase):
         extra_flags=extra_flags
     )
 
-  def test_end_to_end_graph_no_dist_strat(self):
+  def test_end_to_end_graph_no_dist_strat(self, flags_key):
     """Test Keras model in legacy graph mode with 1 GPU, no dist strat."""
     extra_flags = [
         "-enable_eager", "false",
         "-distribution_strategy", "off",
-        "-data_format", "channels_last",
     ]
-    extra_flags = extra_flags + self._extra_flags
+    extra_flags = extra_flags + self.get_extra_flags_dict(flags_key)
 
     integration.run_synthetic(
         main=resnet_imagenet_main.run,
```
```diff
@@ -85,7 +113,7 @@ class KerasImagenetTest(tf.test.TestCase):
         extra_flags=extra_flags
     )
 
-  def test_end_to_end_1_gpu(self):
+  def test_end_to_end_1_gpu(self, flags_key):
     """Test Keras model with 1 GPU."""
     config = keras_utils.get_config_proto_v1()
     tf.compat.v1.enable_eager_execution(config=config)
```
```diff
@@ -98,10 +126,9 @@ class KerasImagenetTest(tf.test.TestCase):
     extra_flags = [
         "-num_gpus", "1",
         "-distribution_strategy", "mirrored",
-        "-data_format", "channels_last",
         "-enable_checkpoint_and_export", "1",
     ]
-    extra_flags = extra_flags + self._extra_flags
+    extra_flags = extra_flags + self.get_extra_flags_dict(flags_key)
 
     integration.run_synthetic(
         main=resnet_imagenet_main.run,
```
```diff
@@ -109,7 +136,7 @@ class KerasImagenetTest(tf.test.TestCase):
         extra_flags=extra_flags
     )
 
-  def test_end_to_end_1_gpu_fp16(self):
+  def test_end_to_end_1_gpu_fp16(self, flags_key):
     """Test Keras model with 1 GPU and fp16."""
     config = keras_utils.get_config_proto_v1()
     tf.compat.v1.enable_eager_execution(config=config)
```
```diff
@@ -123,9 +150,11 @@ class KerasImagenetTest(tf.test.TestCase):
         "-num_gpus", "1",
         "-dtype", "fp16",
         "-distribution_strategy", "mirrored",
-        "-data_format", "channels_last",
     ]
-    extra_flags = extra_flags + self._extra_flags
+    extra_flags = extra_flags + self.get_extra_flags_dict(flags_key)
+    if "polynomial_decay" in extra_flags:
+      self.skipTest("Pruning with fp16 is not currently supported.")
 
     integration.run_synthetic(
         main=resnet_imagenet_main.run,
```
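Since extra_flags is a flat list of strings, the membership test `"polynomial_decay" in extra_flags` is true exactly when one of the *_polynomial_decay presets contributed a pruning_method value, so the fp16 variants of the pruning tests are skipped rather than failed.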
```diff
@@ -133,8 +162,7 @@ class KerasImagenetTest(tf.test.TestCase):
         extra_flags=extra_flags
     )
 
-  def test_end_to_end_2_gpu(self):
+  def test_end_to_end_2_gpu(self, flags_key):
     """Test Keras model with 2 GPUs."""
     config = keras_utils.get_config_proto_v1()
     tf.compat.v1.enable_eager_execution(config=config)
```
```diff
@@ -148,7 +176,7 @@ class KerasImagenetTest(tf.test.TestCase):
         "-num_gpus", "2",
         "-distribution_strategy", "mirrored",
     ]
-    extra_flags = extra_flags + self._extra_flags
+    extra_flags = extra_flags + self.get_extra_flags_dict(flags_key)
 
     integration.run_synthetic(
         main=resnet_imagenet_main.run,
```
```diff
@@ -156,7 +184,7 @@ class KerasImagenetTest(tf.test.TestCase):
         extra_flags=extra_flags
     )
 
-  def test_end_to_end_xla_2_gpu(self):
+  def test_end_to_end_xla_2_gpu(self, flags_key):
     """Test Keras model with XLA and 2 GPUs."""
     config = keras_utils.get_config_proto_v1()
     tf.compat.v1.enable_eager_execution(config=config)
```
```diff
@@ -171,7 +199,7 @@ class KerasImagenetTest(tf.test.TestCase):
         "-enable_xla", "true",
         "-distribution_strategy", "mirrored",
     ]
-    extra_flags = extra_flags + self._extra_flags
+    extra_flags = extra_flags + self.get_extra_flags_dict(flags_key)
 
     integration.run_synthetic(
         main=resnet_imagenet_main.run,
```
```diff
@@ -179,7 +207,7 @@ class KerasImagenetTest(tf.test.TestCase):
         extra_flags=extra_flags
     )
 
-  def test_end_to_end_2_gpu_fp16(self):
+  def test_end_to_end_2_gpu_fp16(self, flags_key):
     """Test Keras model with 2 GPUs and fp16."""
     config = keras_utils.get_config_proto_v1()
     tf.compat.v1.enable_eager_execution(config=config)
```
```diff
@@ -194,7 +222,10 @@ class KerasImagenetTest(tf.test.TestCase):
         "-dtype", "fp16",
         "-distribution_strategy", "mirrored",
     ]
-    extra_flags = extra_flags + self._extra_flags
+    extra_flags = extra_flags + self.get_extra_flags_dict(flags_key)
+    if "polynomial_decay" in extra_flags:
+      self.skipTest("Pruning with fp16 is not currently supported.")
 
     integration.run_synthetic(
         main=resnet_imagenet_main.run,
```
```diff
@@ -202,7 +233,7 @@ class KerasImagenetTest(tf.test.TestCase):
         extra_flags=extra_flags
     )
 
-  def test_end_to_end_xla_2_gpu_fp16(self):
+  def test_end_to_end_xla_2_gpu_fp16(self, flags_key):
     """Test Keras model with XLA, 2 GPUs and fp16."""
     config = keras_utils.get_config_proto_v1()
     tf.compat.v1.enable_eager_execution(config=config)
```
```diff
@@ -218,7 +249,10 @@ class KerasImagenetTest(tf.test.TestCase):
         "-enable_xla", "true",
         "-distribution_strategy", "mirrored",
     ]
-    extra_flags = extra_flags + self._extra_flags
+    extra_flags = extra_flags + self.get_extra_flags_dict(flags_key)
+    if "polynomial_decay" in extra_flags:
+      self.skipTest("Pruning with fp16 is not currently supported.")
 
     integration.run_synthetic(
         main=resnet_imagenet_main.run,
```