Commit ee80adbf authored by Ruomei Yan

Create an example for clustering mobilenet_v1 in resnet_imagenet_main.py
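For reference, the new path can be exercised with an invocation along these lines (the flag values and paths are only illustrative, not prescribed by this commit): python3 resnet_imagenet_main.py --model=mobilenet_pretrained --optimizer=mobilenet_fine_tune --clustering_method=selective_clustering --number_of_clusters=256 --save_files_to=/tmp/clustered_mobilenet --data_dir=<path to ImageNet TFRecords>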

parent 669b0f18
@@ -26,6 +26,7 @@ from absl import flags
from absl import logging
import tensorflow as tf
from tensorflow_model_optimization.python.core.clustering.keras import cluster
import tensorflow_model_optimization as tfmot
from official.modeling import performance
from official.utils.flags import core as flags_core
@@ -38,6 +39,58 @@ from official.vision.image_classification.resnet import imagenet_preprocessing
from official.vision.image_classification.resnet import resnet_model
def selective_layers_to_cluster(model):
  """Returns the names of the last three non-depthwise Conv2D layers in `model`."""
  conv2d_layer_names = [
      layer.name
      for layer in model.layers
      # DepthwiseConv2D subclasses Conv2D, so it has to be excluded explicitly.
      if isinstance(layer, tf.keras.layers.Conv2D) and
      not isinstance(layer, tf.keras.layers.DepthwiseConv2D)
  ]
  return conv2d_layer_names[-3:]
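A minimal standalone sketch (not part of the committed file) of what this helper selects; the toy model and layer names are assumptions made for illustration.

import tensorflow as tf

# DepthwiseConv2D is skipped because the helper excludes it explicitly, so
# only the plain Conv2D layers are candidates.
toy = tf.keras.Sequential([
    tf.keras.layers.Conv2D(8, 3, name='conv_a', input_shape=(32, 32, 3)),
    tf.keras.layers.DepthwiseConv2D(3, name='dw_conv'),
    tf.keras.layers.Conv2D(16, 3, name='conv_b'),
    tf.keras.layers.Conv2D(16, 3, name='conv_c'),
    tf.keras.layers.Conv2D(16, 3, name='conv_d'),
])
print(selective_layers_to_cluster(toy))  # ['conv_b', 'conv_c', 'conv_d']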
def selective_clustering_clone_wrapper(clustering_params1, clustering_params2,
                                       model):
  """Builds a clone_function that wraps the selected Conv2D layers for clustering."""
  layers_to_cluster = selective_layers_to_cluster(model)

  def apply_clustering_to_conv2d_except_depthwise(layer):
    if layer.name in layers_to_cluster:
      # The last selected Conv2D layer uses clustering_params2; the others use
      # clustering_params1.
      params = (clustering_params2 if layer.name == layers_to_cluster[-1]
                else clustering_params1)
      print('Wrapped layer ' + layer.name + ' with ' +
            str(params['number_of_clusters']) + ' clusters.')
      return cluster.cluster_weights(layer, **params)
    return layer

  return apply_clustering_to_conv2d_except_depthwise
def cluster_model_selectively(model, clustering_params1, clustering_params2):
  """Clones the model, wrapping only the selected Conv2D layers for clustering."""
  return tf.keras.models.clone_model(
      model,
      clone_function=selective_clustering_clone_wrapper(clustering_params1,
                                                        clustering_params2,
                                                        model),
  )


def get_selectively_clustered_model(model, clustering_params1,
                                    clustering_params2):
  """Returns a copy of `model` with its last three Conv2D layers wrapped for clustering."""
  return cluster_model_selectively(model, clustering_params1,
                                   clustering_params2)
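The helpers above follow the clone_function pattern from the TFMOT clustering API. A standalone sketch of how they compose; the toy model, cluster counts, and the use of the CentroidInitialization enum are illustrative assumptions (the script itself passes 'linear' through its flag-driven dicts).

import tensorflow as tf
import tensorflow_model_optimization as tfmot

init = tfmot.clustering.keras.CentroidInitialization.LINEAR
params_most = {'number_of_clusters': 16, 'cluster_centroids_init': init}
params_last = {'number_of_clusters': 8, 'cluster_centroids_init': init}

inputs = tf.keras.Input(shape=(32, 32, 3))
x = tf.keras.layers.Conv2D(8, 3, name='conv_a')(inputs)
x = tf.keras.layers.Conv2D(8, 3, name='conv_b')(x)
x = tf.keras.layers.Conv2D(8, 3, name='conv_c')(x)
outputs = tf.keras.layers.Conv2D(8, 3, name='conv_d')(x)
toy = tf.keras.Model(inputs, outputs)

# conv_b and conv_c get wrapped with 16 clusters, conv_d with 8; conv_a and
# the input layer are returned unchanged by the clone_function.
clustered = get_selectively_clustered_model(toy, params_most, params_last)
clustered.summary()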
def run(flags_obj):
"""Run ResNet ImageNet training and eval loop using native Keras APIs.
@@ -53,7 +106,6 @@ def run(flags_obj):
"""
keras_utils.set_session_config(
enable_xla=flags_obj.enable_xla)
# Execute flag override logic for better model performance
if flags_obj.tf_gpu_thread_mode:
keras_utils.set_gpu_thread_mode_and_count(
@@ -117,7 +169,7 @@ def run(flags_obj):
  # The use_keras_image_data_format flag indicates whether the image
  # preprocessor output format should match the Keras backend image data
  # format or always be channels-last.
  use_keras_image_data_format = (flags_obj.model in ('mobilenet',
                                                     'mobilenet_pretrained'))
train_input_dataset = input_fn(
is_training=True,
data_dir=flags_obj.data_dir,
@@ -149,8 +201,8 @@ def run(flags_obj):
boundaries=list(p[1] for p in common.LR_SCHEDULE[1:]),
multipliers=list(p[0] for p in common.LR_SCHEDULE),
compute_lr_on_cpu=True)
  steps_per_epoch = (imagenet_preprocessing.NUM_IMAGES['train'] //
                     flags_obj.batch_size)
with strategy_scope:
if flags_obj.optimizer == 'resnet50_default':
@@ -165,6 +217,9 @@ def run(flags_obj):
decay_rate=flags_obj.lr_decay_factor,
staircase=True),
momentum=0.9)
    elif flags_obj.optimizer == 'mobilenet_fine_tune':
      # Fine-tune the pretrained MobileNet with a small constant learning rate.
      optimizer = tf.keras.optimizers.SGD(learning_rate=1e-5, momentum=0.9)
if flags_obj.fp16_implementation == 'graph_rewrite':
# Note: when flags_obj.fp16_implementation == "graph_rewrite", dtype as
# determined by flags_core.get_tf_dtype(flags_obj) would be 'float32'
@@ -187,6 +242,20 @@ def run(flags_obj):
weights=None,
classes=imagenet_preprocessing.NUM_CLASSES,
layers=tf.keras.layers)
    elif flags_obj.model == 'mobilenet_pretrained':
      # MobileNet v1 with ImageNet weights; the input shape is given in
      # channels-first (3, 224, 224) layout.
      shape = (3, 224, 224)
model = tf.keras.applications.mobilenet.MobileNet(
input_shape=shape,
alpha=1.0,
depth_multiplier=1,
dropout=1e-7,
include_top=True,
weights='imagenet',
input_tensor=tf.keras.layers.Input(shape),
pooling=None,
classes=1000,
layers=tf.keras.layers)
if flags_obj.pretrained_filepath:
model.load_weights(flags_obj.pretrained_filepath)
@@ -205,11 +274,27 @@ def run(flags_obj):
}
model = tfmot.sparsity.keras.prune_low_magnitude(model, **pruning_params)
elif flags_obj.pruning_method:
raise NotImplementedError('Only polynomial_decay is currently supported.')
if flags_obj.clustering_method == 'selective_clustering':
if dtype != tf.float32:
raise NotImplementedError(
'Clustering is currently only supported on dtype=tf.float32.')
      # All selected Conv2D layers except the last one use the number of
      # clusters given by the --number_of_clusters flag.
      clustering_params1 = {
          'number_of_clusters': flags_obj.number_of_clusters,
          'cluster_centroids_init': 'linear'
      }
      # The last selected Conv2D layer is clustered with a fixed 32 clusters.
      clustering_params2 = {
          'number_of_clusters': 32,
          'cluster_centroids_init': 'linear'
      }
      model = get_selectively_clustered_model(model, clustering_params1,
                                              clustering_params2)
    elif flags_obj.clustering_method:
      raise NotImplementedError(
          'Only selective_clustering is implemented.')

    model.compile(loss='sparse_categorical_crossentropy',
                  optimizer=optimizer,
                  metrics=(['sparse_categorical_accuracy']
                           if flags_obj.report_accuracy_metrics else None),
@@ -222,13 +307,13 @@ def run(flags_obj):
enable_checkpoint_and_export=flags_obj.enable_checkpoint_and_export,
model_dir=flags_obj.model_dir)
  # If multiple epochs, ignore the train_steps flag.
if train_epochs <= 1 and flags_obj.train_steps:
steps_per_epoch = min(flags_obj.train_steps, steps_per_epoch)
train_epochs = 1
  num_eval_steps = (imagenet_preprocessing.NUM_IMAGES['validation'] //
                    flags_obj.batch_size)
validation_data = eval_input_dataset
if flags_obj.skip_eval:
@@ -242,9 +327,10 @@ def run(flags_obj):
num_eval_steps = None
validation_data = None
  if not strategy and flags_obj.explicit_gpu_placement:
# TODO(b/135607227): Add device scope automatically in Keras training loop
    # when not using distribution strategy.
no_dist_strat_device = tf.device('/device:GPU:0')
no_dist_strat_device.__enter__()
@@ -265,6 +351,10 @@ def run(flags_obj):
if flags_obj.pruning_method:
model = tfmot.sparsity.keras.strip_pruning(model)
if flags_obj.clustering_method:
model = cluster.strip_clustering(model)
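strip_clustering replaces each ClusterWeights wrapper with the underlying layer whose weights keep only the centroid values. A minimal standalone sketch of that effect on an assumed toy Dense layer (not taken from this commit):

import numpy as np
import tensorflow as tf
import tensorflow_model_optimization as tfmot
from tensorflow_model_optimization.python.core.clustering.keras import cluster

toy = tf.keras.Sequential([tf.keras.layers.Dense(32, input_shape=(16,))])
clustered = cluster.cluster_weights(
    toy,
    number_of_clusters=4,
    cluster_centroids_init=tfmot.clustering.keras.CentroidInitialization.LINEAR)
_ = clustered(np.zeros((1, 16), dtype=np.float32))  # run once so the kernel holds clustered values
stripped = cluster.strip_clustering(clustered)      # removes the wrappers
# Expect at most 4 distinct kernel values after stripping.
print(len(np.unique(stripped.layers[0].get_weights()[0])))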
if flags_obj.enable_checkpoint_and_export:
if dtype == tf.bfloat16:
logging.warning('Keras model.save does not support bfloat16 dtype.')
@@ -276,16 +366,23 @@ def run(flags_obj):
if not strategy and flags_obj.explicit_gpu_placement:
no_dist_strat_device.__exit__()
if flags_obj.save_files_to:
keras_file = os.path.join(flags_obj.save_files_to, 'clustered.h5')
else:
keras_file = './clustered.h5'
print('Saving clustered and stripped model to: ', keras_file)
tf.keras.models.save_model(model, keras_file)
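The save_files_to flag help text also mentions TFLite models; this commit only writes the stripped Keras .h5, but a possible follow-up conversion (paths assumed to match the default above) could look like:

import tensorflow as tf

stripped = tf.keras.models.load_model('./clustered.h5')
converter = tf.lite.TFLiteConverter.from_keras_model(stripped)
with open('./clustered.tflite', 'wb') as f:
  f.write(converter.convert())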
stats = common.build_stats(history, eval_output, callbacks)
return stats
def define_imagenet_keras_flags():
  common.define_keras_flags(model=True,
                            optimizer=True,
                            pretrained_filepath=True)
common.define_pruning_flags()
common.define_clustering_flags()
flags_core.set_defaults()
flags.adopt_module_key_flags(common)
@@ -352,6 +352,14 @@ def define_pruning_flags():
flags.DEFINE_integer('pruning_end_step', 100000, 'End step for pruning.')
flags.DEFINE_integer('pruning_frequency', 100, 'Frequency for pruning.')
def define_clustering_flags():
"""Define flags for clustering methods."""
flags.DEFINE_string('clustering_method', None,
'None (no clustering) or selective_clustering.')
flags.DEFINE_integer('number_of_clusters', 256,
'Number of clusters used in each layer.')
flags.DEFINE_string('save_files_to', None,
'The path to save Keras models and tflite models.')
def get_synth_input_fn(height,
width,