Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
ee80adbf
Commit
ee80adbf
authored
May 28, 2020
by
Ruomei Yan
Browse files
Create an example for clustering mobilenet_v1 in resnet_imagenet_main.py
parent
669b0f18
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
125 additions
and
20 deletions
+125
-20
official/benchmark/models/resnet_imagenet_main.py
official/benchmark/models/resnet_imagenet_main.py
+117
-20
official/vision/image_classification/resnet/common.py
official/vision/image_classification/resnet/common.py
+8
-0
No files found.
official/benchmark/models/resnet_imagenet_main.py
View file @
ee80adbf
...
...
@@ -26,6 +26,7 @@ from absl import flags
from
absl
import
logging
import
tensorflow
as
tf
from
tensorflow_model_optimization.python.core.clustering.keras
import
cluster
import
tensorflow_model_optimization
as
tfmot
from
official.modeling
import
performance
from
official.utils.flags
import
core
as
flags_core
...
...
@@ -38,6 +39,58 @@ from official.vision.image_classification.resnet import imagenet_preprocessing
from
official.vision.image_classification.resnet
import
resnet_model
def
selective_layers_to_cluster
(
model
):
last_3conv2d_layers_to_cluster
=
[
layer
.
name
for
layer
in
model
.
layers
if
isinstance
(
layer
,
tf
.
keras
.
layers
.
Conv2D
)
and
not
isinstance
(
layer
,
tf
.
keras
.
layers
.
DepthwiseConv2D
)
]
last_3conv2d_layers_to_cluster
=
last_3conv2d_layers_to_cluster
[
-
3
:]
return
last_3conv2d_layers_to_cluster
def
selective_clustering_clone_wrapper
(
clustering_params1
,
clustering_params2
,
model
):
def
apply_clustering_to_conv2d_but_depthwise
(
layer
):
layers_list
=
selective_layers_to_cluster
(
model
)
if
layer
.
name
in
layers_list
:
if
layer
.
name
!=
layers_list
[
-
1
]:
print
(
"Wrapped layer "
+
layer
.
name
+
" with "
+
str
(
clustering_params1
[
"number_of_clusters"
])
+
" clusters."
)
return
cluster
.
cluster_weights
(
layer
,
**
clustering_params1
)
else
:
print
(
"Wrapped layer "
+
layer
.
name
+
" with number of clusters equals to "
+
str
(
clustering_params2
[
"number_of_clusters"
])
+
" clusters."
)
return
cluster
.
cluster_weights
(
layer
,
**
clustering_params2
)
return
layer
return
apply_clustering_to_conv2d_but_depthwise
def
cluster_model_selectively
(
model
,
selective_layers_to_cluster
,
clustering_params1
,
clustering_params2
):
result_layer_model
=
tf
.
keras
.
models
.
clone_model
(
model
,
clone_function
=
selective_clustering_clone_wrapper
(
clustering_params1
,
clustering_params2
,
model
),
)
return
result_layer_model
def
get_selectively_clustered_model
(
model
,
clustering_params1
,
clustering_params2
):
clustered_model
=
cluster_model_selectively
(
model
,
selective_layers_to_cluster
,
clustering_params1
,
clustering_params2
)
return
clustered_model
def
run
(
flags_obj
):
"""Run ResNet ImageNet training and eval loop using native Keras APIs.
...
...
@@ -53,7 +106,6 @@ def run(flags_obj):
"""
keras_utils
.
set_session_config
(
enable_xla
=
flags_obj
.
enable_xla
)
# Execute flag override logic for better model performance
if
flags_obj
.
tf_gpu_thread_mode
:
keras_utils
.
set_gpu_thread_mode_and_count
(
...
...
@@ -117,7 +169,7 @@ def run(flags_obj):
# This use_keras_image_data_format flags indicates whether image preprocessor
# output format should be same as the keras backend image data format or just
# channel-last format.
use_keras_image_data_format
=
(
flags_obj
.
model
==
'mobilenet'
)
use_keras_image_data_format
=
(
flags_obj
.
model
==
'mobilenet'
or
'mobilenet_pretrained'
)
train_input_dataset
=
input_fn
(
is_training
=
True
,
data_dir
=
flags_obj
.
data_dir
,
...
...
@@ -149,8 +201,8 @@ def run(flags_obj):
boundaries
=
list
(
p
[
1
]
for
p
in
common
.
LR_SCHEDULE
[
1
:]),
multipliers
=
list
(
p
[
0
]
for
p
in
common
.
LR_SCHEDULE
),
compute_lr_on_cpu
=
True
)
steps_per_epoch
=
(
imagenet_preprocessing
.
NUM_IMAGES
[
'train'
]
//
flags_obj
.
batch_size
)
steps_per_epoch
=
(
imagenet_preprocessing
.
NUM_IMAGES
[
'train'
]
//
flags_obj
.
batch_size
)
with
strategy_scope
:
if
flags_obj
.
optimizer
==
'resnet50_default'
:
...
...
@@ -165,6 +217,9 @@ def run(flags_obj):
decay_rate
=
flags_obj
.
lr_decay_factor
,
staircase
=
True
),
momentum
=
0.9
)
elif
flags_obj
.
optimizer
==
'mobilenet_fine_tune'
:
optimizer
=
tf
.
keras
.
optimizers
.
SGD
(
learning_rate
=
1e-5
,
momentum
=
0.9
)
if
flags_obj
.
fp16_implementation
==
'graph_rewrite'
:
# Note: when flags_obj.fp16_implementation == "graph_rewrite", dtype as
# determined by flags_core.get_tf_dtype(flags_obj) would be 'float32'
...
...
@@ -187,6 +242,20 @@ def run(flags_obj):
weights
=
None
,
classes
=
imagenet_preprocessing
.
NUM_CLASSES
,
layers
=
tf
.
keras
.
layers
)
elif
flags_obj
.
model
==
'mobilenet_pretrained'
:
shape
=
(
3
,
224
,
224
)
model
=
tf
.
keras
.
applications
.
mobilenet
.
MobileNet
(
input_shape
=
shape
,
alpha
=
1.0
,
depth_multiplier
=
1
,
dropout
=
1e-7
,
include_top
=
True
,
weights
=
'imagenet'
,
input_tensor
=
tf
.
keras
.
layers
.
Input
(
shape
),
pooling
=
None
,
classes
=
1000
,
layers
=
tf
.
keras
.
layers
)
if
flags_obj
.
pretrained_filepath
:
model
.
load_weights
(
flags_obj
.
pretrained_filepath
)
...
...
@@ -205,15 +274,31 @@ def run(flags_obj):
}
model
=
tfmot
.
sparsity
.
keras
.
prune_low_magnitude
(
model
,
**
pruning_params
)
elif
flags_obj
.
pruning_method
:
raise
NotImplementedError
(
'Only polynomial_decay is currently supported.'
)
if
flags_obj
.
clustering_method
==
'selective_clustering'
:
if
dtype
!=
tf
.
float32
:
raise
NotImplementedError
(
'Clustering is currently only supported on dtype=tf.float32.'
)
clustering_params1
=
{
'number_of_clusters'
:
flags_obj
.
number_of_clusters
,
'cluster_centroids_init'
:
'linear'
}
clustering_params2
=
{
'number_of_clusters'
:
32
,
'cluster_centroids_init'
:
'linear'
}
model
=
get_selectively_clustered_model
(
model
,
clustering_params1
,
clustering_params2
)
elif
flags_obj
.
clustering_method
:
raise
NotImplementedError
(
'Only
polynomial_decay is currently suppor
ted.'
)
'Only
selective_clustering is implemen
ted.'
)
model
.
compile
(
loss
=
'sparse_categorical_crossentropy'
,
optimizer
=
optimizer
,
metrics
=
([
'sparse_categorical_accuracy'
]
if
flags_obj
.
report_accuracy_metrics
else
None
),
run_eagerly
=
flags_obj
.
run_eagerly
)
model
.
compile
(
loss
=
'sparse_categorical_crossentropy'
,
optimizer
=
optimizer
,
metrics
=
([
'sparse_categorical_accuracy'
]
if
flags_obj
.
report_accuracy_metrics
else
None
),
run_eagerly
=
flags_obj
.
run_eagerly
)
train_epochs
=
flags_obj
.
train_epochs
...
...
@@ -222,13 +307,13 @@ def run(flags_obj):
enable_checkpoint_and_export
=
flags_obj
.
enable_checkpoint_and_export
,
model_dir
=
flags_obj
.
model_dir
)
#
i
f mutliple epochs, ignore the train_steps flag.
#
I
f mutliple epochs, ignore the train_steps flag.
if
train_epochs
<=
1
and
flags_obj
.
train_steps
:
steps_per_epoch
=
min
(
flags_obj
.
train_steps
,
steps_per_epoch
)
train_epochs
=
1
num_eval_steps
=
(
imagenet_preprocessing
.
NUM_IMAGES
[
'validation'
]
//
flags_obj
.
batch_size
)
num_eval_steps
=
(
imagenet_preprocessing
.
NUM_IMAGES
[
'validation'
]
//
flags_obj
.
batch_size
)
validation_data
=
eval_input_dataset
if
flags_obj
.
skip_eval
:
...
...
@@ -242,9 +327,10 @@ def run(flags_obj):
num_eval_steps
=
None
validation_data
=
None
# if not strategy and flags_obj.explicit_gpu_placement:
if
not
strategy
and
flags_obj
.
explicit_gpu_placement
:
# TODO(b/135607227): Add device scope automatically in Keras training loop
# when not using distrib
i
tion strategy.
# when not using distrib
u
tion strategy.
no_dist_strat_device
=
tf
.
device
(
'/device:GPU:0'
)
no_dist_strat_device
.
__enter__
()
...
...
@@ -265,6 +351,10 @@ def run(flags_obj):
if
flags_obj
.
pruning_method
:
model
=
tfmot
.
sparsity
.
keras
.
strip_pruning
(
model
)
if
flags_obj
.
clustering_method
:
model
=
cluster
.
strip_clustering
(
model
)
if
flags_obj
.
enable_checkpoint_and_export
:
if
dtype
==
tf
.
bfloat16
:
logging
.
warning
(
'Keras model.save does not support bfloat16 dtype.'
)
...
...
@@ -276,16 +366,23 @@ def run(flags_obj):
if
not
strategy
and
flags_obj
.
explicit_gpu_placement
:
no_dist_strat_device
.
__exit__
()
if
flags_obj
.
save_files_to
:
keras_file
=
os
.
path
.
join
(
flags_obj
.
save_files_to
,
'clustered.h5'
)
else
:
keras_file
=
'./clustered.h5'
print
(
'Saving clustered and stripped model to: '
,
keras_file
)
tf
.
keras
.
models
.
save_model
(
model
,
keras_file
)
stats
=
common
.
build_stats
(
history
,
eval_output
,
callbacks
)
return
stats
def
define_imagenet_keras_flags
():
common
.
define_keras_flags
(
model
=
True
,
optimizer
=
True
,
pretrained_filepath
=
True
)
common
.
define_keras_flags
(
model
=
True
,
optimizer
=
True
,
pretrained_filepath
=
True
)
common
.
define_pruning_flags
()
common
.
define_clustering_flags
()
flags_core
.
set_defaults
()
flags
.
adopt_module_key_flags
(
common
)
...
...
@@ -299,4 +396,4 @@ def main(_):
if
__name__
==
'__main__'
:
logging
.
set_verbosity
(
logging
.
INFO
)
define_imagenet_keras_flags
()
app
.
run
(
main
)
app
.
run
(
main
)
\ No newline at end of file
official/vision/image_classification/resnet/common.py
View file @
ee80adbf
...
...
@@ -352,6 +352,14 @@ def define_pruning_flags():
flags
.
DEFINE_integer
(
'pruning_end_step'
,
100000
,
'End step for pruning.'
)
flags
.
DEFINE_integer
(
'pruning_frequency'
,
100
,
'Frequency for pruning.'
)
def
define_clustering_flags
():
"""Define flags for clustering methods."""
flags
.
DEFINE_string
(
'clustering_method'
,
None
,
'None (no clustering) or selective_clustering.'
)
flags
.
DEFINE_integer
(
'number_of_clusters'
,
256
,
'Number of clusters used in each layer.'
)
flags
.
DEFINE_string
(
'save_files_to'
,
None
,
'The path to save Keras models and tflite models.'
)
def
get_synth_input_fn
(
height
,
width
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment