Commit 4412001f authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Merge pull request #8604 from Ruomei:toupstream/clusteringexample

PiperOrigin-RevId: 329802437
parents faea89d9 a87bb185
......@@ -37,15 +37,16 @@ from official.vision.image_classification.resnet import imagenet_preprocessing
from official.vision.image_classification.resnet import resnet_model
def cluster_last_three_conv2d_layers(model):
import tensorflow_model_optimization as tfmot
def _cluster_last_three_conv2d_layers(model):
"""Helper method to cluster last three conv2d layers."""
import tensorflow_model_optimization as tfmot # pylint: disable=g-import-not-at-top
last_three_conv2d_layers = [
layer for layer in model.layers
if isinstance(layer, tf.keras.layers.Conv2D)
][-3:]
cluster_weights = tfmot.clustering.keras.cluster_weights
CentroidInitialization = tfmot.clustering.keras.CentroidInitialization
centroid_initialization = tfmot.clustering.keras.CentroidInitialization
def cluster_fn(layer):
if layer not in last_three_conv2d_layers:
......@@ -54,12 +55,12 @@ def cluster_last_three_conv2d_layers(model):
if layer == last_three_conv2d_layers[0] or \
layer == last_three_conv2d_layers[1]:
clustered = cluster_weights(layer, number_of_clusters=256, \
cluster_centroids_init=CentroidInitialization.LINEAR)
print("Clustered {} with 256 clusters".format(layer.name))
cluster_centroids_init=centroid_initialization.LINEAR)
print('Clustered {} with 256 clusters'.format(layer.name))
else:
clustered = cluster_weights(layer, number_of_clusters=32, \
cluster_centroids_init=CentroidInitialization.LINEAR)
print("Clustered {} with 32 clusters".format(layer.name))
cluster_centroids_init=centroid_initialization.LINEAR)
print('Clustered {} with 32 clusters'.format(layer.name))
return clustered
return tf.keras.models.clone_model(model, clone_function=cluster_fn)
......@@ -228,7 +229,7 @@ def run(flags_obj):
model.load_weights(flags_obj.pretrained_filepath)
if flags_obj.pruning_method == 'polynomial_decay':
import tensorflow_model_optimization as tfmot
import tensorflow_model_optimization as tfmot # pylint: disable=g-import-not-at-top
if dtype != tf.float32:
raise NotImplementedError(
'Pruning is currently only supported on dtype=tf.float32.')
......@@ -246,12 +247,12 @@ def run(flags_obj):
raise NotImplementedError('Only polynomial_decay is currently supported.')
if flags_obj.clustering_method == 'selective_clustering':
import tensorflow_model_optimization as tfmot
import tensorflow_model_optimization as tfmot # pylint: disable=g-import-not-at-top
if dtype != tf.float32 or \
flags_obj.fp16_implementation == 'graph_rewrite':
raise NotImplementedError(
'Clustering is currently only supported on dtype=tf.float32.')
model = cluster_last_three_conv2d_layers(model)
model = _cluster_last_three_conv2d_layers(model)
elif flags_obj.clustering_method:
raise NotImplementedError(
'Only selective_clustering is implemented.')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment