"git@developer.sourcefind.cn:modelzoo/resnet50_tensorflow.git" did not exist on "17ba1ca49db57ec2dd8b16b18bbcf68c995bf102"
Commit ee80adbf authored by Ruomei Yan

Create an example for clustering mobilenet_v1 in resnet_imagenet_main.py

parent 669b0f18
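
For orientation, a purely illustrative invocation of the modified script might look like the following; the flag names come from this change, while the paths and values are placeholders rather than settings recorded in the commit:

python resnet_imagenet_main.py \
  --model=mobilenet_pretrained \
  --optimizer=mobilenet_fine_tune \
  --clustering_method=selective_clustering \
  --number_of_clusters=256 \
  --data_dir=/path/to/imagenet \
  --model_dir=/tmp/mobilenet_clustering \
  --save_files_to=/tmp/mobilenet_clustering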
@@ -26,6 +26,7 @@ from absl import flags
from absl import logging
import tensorflow as tf
from tensorflow_model_optimization.python.core.clustering.keras import cluster
import tensorflow_model_optimization as tfmot
from official.modeling import performance
from official.utils.flags import core as flags_core
@@ -38,6 +39,58 @@ from official.vision.image_classification.resnet import imagenet_preprocessing
from official.vision.image_classification.resnet import resnet_model


def selective_layers_to_cluster(model):
  """Returns the names of the last three non-depthwise Conv2D layers."""
  last_3conv2d_layers_to_cluster = [
      layer.name
      for layer in model.layers
      if isinstance(layer, tf.keras.layers.Conv2D) and
      not isinstance(layer, tf.keras.layers.DepthwiseConv2D)
  ]
  last_3conv2d_layers_to_cluster = last_3conv2d_layers_to_cluster[-3:]
  return last_3conv2d_layers_to_cluster


def selective_clustering_clone_wrapper(clustering_params1, clustering_params2,
                                       model):
  """Builds a clone_function that clusters only the selected Conv2D layers."""

  def apply_clustering_to_conv2d_but_depthwise(layer):
    layers_list = selective_layers_to_cluster(model)
    if layer.name in layers_list:
      if layer.name != layers_list[-1]:
        print("Wrapped layer " + layer.name + " with " +
              str(clustering_params1["number_of_clusters"]) + " clusters.")
        return cluster.cluster_weights(layer, **clustering_params1)
      else:
        print("Wrapped layer " + layer.name + " with " +
              str(clustering_params2["number_of_clusters"]) + " clusters.")
        return cluster.cluster_weights(layer, **clustering_params2)
    return layer

  return apply_clustering_to_conv2d_but_depthwise


def cluster_model_selectively(model, selective_layers_to_cluster,
                              clustering_params1, clustering_params2):
  """Clones the model, wrapping the selected layers for weight clustering."""
  result_layer_model = tf.keras.models.clone_model(
      model,
      clone_function=selective_clustering_clone_wrapper(clustering_params1,
                                                        clustering_params2,
                                                        model),
  )
  return result_layer_model


def get_selectively_clustered_model(model, clustering_params1,
                                    clustering_params2):
  clustered_model = cluster_model_selectively(model,
                                              selective_layers_to_cluster,
                                              clustering_params1,
                                              clustering_params2)
  return clustered_model
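
To make the wrapping pattern above concrete, here is a minimal standalone sketch (not part of the commit) of how tf.keras.models.clone_model applies a clone_function to selected layers. The toy model, layer names and cluster count are illustrative assumptions, and the string form of cluster_centroids_init simply mirrors its usage in this diff:

import tensorflow as tf
from tensorflow_model_optimization.python.core.clustering.keras import cluster

# Toy functional model; only the layer named 'conv_b' will be clustered.
inputs = tf.keras.Input(shape=(32, 32, 3))
x = tf.keras.layers.Conv2D(4, 3, name='conv_a')(inputs)
x = tf.keras.layers.Conv2D(4, 3, name='conv_b')(x)
x = tf.keras.layers.Flatten()(x)
outputs = tf.keras.layers.Dense(10, name='head')(x)
base = tf.keras.Model(inputs, outputs)

params = {'number_of_clusters': 8, 'cluster_centroids_init': 'linear'}

def clone_fn(layer):
  # Wrap only the selected layer; every other layer is returned unchanged.
  if layer.name == 'conv_b':
    return cluster.cluster_weights(layer, **params)
  return layer

clustered = tf.keras.models.clone_model(base, clone_function=clone_fn)
clustered.summary()

Cloning with a clone_function leaves the architecture untouched and only substitutes the chosen layers with their clustering wrappers, which is the same mechanism the helper functions above rely on.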


def run(flags_obj):
  """Run ResNet ImageNet training and eval loop using native Keras APIs.
@@ -53,7 +106,6 @@ def run(flags_obj):
""" """
keras_utils.set_session_config( keras_utils.set_session_config(
enable_xla=flags_obj.enable_xla) enable_xla=flags_obj.enable_xla)
# Execute flag override logic for better model performance # Execute flag override logic for better model performance
if flags_obj.tf_gpu_thread_mode: if flags_obj.tf_gpu_thread_mode:
keras_utils.set_gpu_thread_mode_and_count( keras_utils.set_gpu_thread_mode_and_count(
@@ -117,7 +169,7 @@ def run(flags_obj):
  # This use_keras_image_data_format flag indicates whether the image
  # preprocessor output format should be the same as the Keras backend image
  # data format or just channels-last format.
  use_keras_image_data_format = (
      flags_obj.model in ('mobilenet', 'mobilenet_pretrained'))
  train_input_dataset = input_fn(
      is_training=True,
      data_dir=flags_obj.data_dir,
@@ -149,8 +201,8 @@ def run(flags_obj):
        boundaries=list(p[1] for p in common.LR_SCHEDULE[1:]),
        multipliers=list(p[0] for p in common.LR_SCHEDULE),
        compute_lr_on_cpu=True)
  steps_per_epoch = (imagenet_preprocessing.NUM_IMAGES['train'] //
                     flags_obj.batch_size)
  with strategy_scope:
    if flags_obj.optimizer == 'resnet50_default':
@@ -165,6 +217,9 @@ def run(flags_obj):
              decay_rate=flags_obj.lr_decay_factor,
              staircase=True),
          momentum=0.9)
    elif flags_obj.optimizer == 'mobilenet_fine_tune':
      optimizer = tf.keras.optimizers.SGD(learning_rate=1e-5, momentum=0.9)
    if flags_obj.fp16_implementation == 'graph_rewrite':
      # Note: when flags_obj.fp16_implementation == "graph_rewrite", dtype as
      # determined by flags_core.get_tf_dtype(flags_obj) would be 'float32'
@@ -187,6 +242,20 @@ def run(flags_obj):
          weights=None,
          classes=imagenet_preprocessing.NUM_CLASSES,
          layers=tf.keras.layers)
    elif flags_obj.model == 'mobilenet_pretrained':
      shape = (3, 224, 224)
      model = tf.keras.applications.mobilenet.MobileNet(
          input_shape=shape,
          alpha=1.0,
          depth_multiplier=1,
          dropout=1e-7,
          include_top=True,
          weights='imagenet',
          input_tensor=tf.keras.layers.Input(shape),
          pooling=None,
          classes=1000,
          layers=tf.keras.layers)
    if flags_obj.pretrained_filepath:
      model.load_weights(flags_obj.pretrained_filepath)
@@ -205,15 +274,31 @@ def run(flags_obj):
      }
      model = tfmot.sparsity.keras.prune_low_magnitude(model, **pruning_params)
    elif flags_obj.pruning_method:
      raise NotImplementedError('Only polynomial_decay is currently supported.')

    if flags_obj.clustering_method == 'selective_clustering':
      if dtype != tf.float32:
        raise NotImplementedError(
            'Clustering is currently only supported on dtype=tf.float32.')
      clustering_params1 = {
          'number_of_clusters': flags_obj.number_of_clusters,
          'cluster_centroids_init': 'linear'
      }
      clustering_params2 = {
          'number_of_clusters': 32,
          'cluster_centroids_init': 'linear'
      }
      model = get_selectively_clustered_model(model, clustering_params1,
                                              clustering_params2)
    elif flags_obj.clustering_method:
      raise NotImplementedError(
          'Only selective_clustering is implemented.')

    model.compile(loss='sparse_categorical_crossentropy',
                  optimizer=optimizer,
                  metrics=(['sparse_categorical_accuracy']
                           if flags_obj.report_accuracy_metrics else None),
                  run_eagerly=flags_obj.run_eagerly)

  train_epochs = flags_obj.train_epochs
@@ -222,13 +307,13 @@ def run(flags_obj):
      enable_checkpoint_and_export=flags_obj.enable_checkpoint_and_export,
      model_dir=flags_obj.model_dir)
  # If multiple epochs, ignore the train_steps flag.
  if train_epochs <= 1 and flags_obj.train_steps:
    steps_per_epoch = min(flags_obj.train_steps, steps_per_epoch)
    train_epochs = 1
  num_eval_steps = (imagenet_preprocessing.NUM_IMAGES['validation'] //
                    flags_obj.batch_size)
  validation_data = eval_input_dataset
  if flags_obj.skip_eval:
@@ -242,9 +327,10 @@ def run(flags_obj):
    num_eval_steps = None
    validation_data = None

  if not strategy and flags_obj.explicit_gpu_placement:
    # TODO(b/135607227): Add device scope automatically in Keras training loop
    # when not using distribution strategy.
    no_dist_strat_device = tf.device('/device:GPU:0')
    no_dist_strat_device.__enter__()
@@ -265,6 +351,10 @@ def run(flags_obj):
  if flags_obj.pruning_method:
    model = tfmot.sparsity.keras.strip_pruning(model)

  if flags_obj.clustering_method:
    model = cluster.strip_clustering(model)
  if flags_obj.enable_checkpoint_and_export:
    if dtype == tf.bfloat16:
      logging.warning('Keras model.save does not support bfloat16 dtype.')
@@ -276,16 +366,23 @@ def run(flags_obj):
  if not strategy and flags_obj.explicit_gpu_placement:
    no_dist_strat_device.__exit__()

  if flags_obj.save_files_to:
    keras_file = os.path.join(flags_obj.save_files_to, 'clustered.h5')
  else:
    keras_file = './clustered.h5'
  print('Saving clustered and stripped model to: ', keras_file)
  tf.keras.models.save_model(model, keras_file)

  stats = common.build_stats(history, eval_output, callbacks)
  return stats


def define_imagenet_keras_flags():
  common.define_keras_flags(model=True,
                            optimizer=True,
                            pretrained_filepath=True)
  common.define_pruning_flags()
  common.define_clustering_flags()
  flags_core.set_defaults()
  flags.adopt_module_key_flags(common)
@@ -299,4 +396,4 @@ def main(_):


if __name__ == '__main__':
  logging.set_verbosity(logging.INFO)
  define_imagenet_keras_flags()
  app.run(main)
@@ -352,6 +352,14 @@ def define_pruning_flags():
  flags.DEFINE_integer('pruning_end_step', 100000, 'End step for pruning.')
  flags.DEFINE_integer('pruning_frequency', 100, 'Frequency for pruning.')


def define_clustering_flags():
  """Define flags for clustering methods."""
  flags.DEFINE_string('clustering_method', None,
                      'None (no clustering) or selective_clustering.')
  flags.DEFINE_integer('number_of_clusters', 256,
                       'Number of clusters used in each layer.')
  flags.DEFINE_string('save_files_to', None,
                      'The path to save Keras models and tflite models.')
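
The save_files_to help text mentions TFLite models, although the training script in this commit only writes a Keras .h5 file. As a hedged sketch under that assumption, the clustered and stripped model could be converted with the standard TFLite converter; the file paths below are placeholders:

import tensorflow as tf

# Load the clustered-and-stripped Keras model written by the training script.
model = tf.keras.models.load_model('/tmp/mobilenet_clustering/clustered.h5')

# Convert to a TFLite flatbuffer; weight clustering mainly pays off once the
# resulting file is compressed, since repeated centroid values compress well.
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()
with open('/tmp/mobilenet_clustering/clustered.tflite', 'wb') as f:
  f.write(tflite_model)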


def get_synth_input_fn(height,
                       width,