"git@developer.sourcefind.cn:modelzoo/resnet50_tensorflow.git" did not exist on "df89d3e02a41e34581e8065cf5868a9570fa3010"
Commit 4412001f authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Merge pull request #8604 from Ruomei:toupstream/clusteringexample

PiperOrigin-RevId: 329802437
parents faea89d9 a87bb185
...@@ -37,15 +37,16 @@ from official.vision.image_classification.resnet import imagenet_preprocessing ...@@ -37,15 +37,16 @@ from official.vision.image_classification.resnet import imagenet_preprocessing
from official.vision.image_classification.resnet import resnet_model from official.vision.image_classification.resnet import resnet_model
def cluster_last_three_conv2d_layers(model): def _cluster_last_three_conv2d_layers(model):
import tensorflow_model_optimization as tfmot """Helper method to cluster last three conv2d layers."""
last_three_conv2d_layers = [ import tensorflow_model_optimization as tfmot # pylint: disable=g-import-not-at-top
last_three_conv2d_layers = [
layer for layer in model.layers layer for layer in model.layers
if isinstance(layer, tf.keras.layers.Conv2D) if isinstance(layer, tf.keras.layers.Conv2D)
][-3:] ][-3:]
cluster_weights = tfmot.clustering.keras.cluster_weights cluster_weights = tfmot.clustering.keras.cluster_weights
CentroidInitialization = tfmot.clustering.keras.CentroidInitialization centroid_initialization = tfmot.clustering.keras.CentroidInitialization
def cluster_fn(layer): def cluster_fn(layer):
if layer not in last_three_conv2d_layers: if layer not in last_three_conv2d_layers:
...@@ -54,12 +55,12 @@ def cluster_last_three_conv2d_layers(model): ...@@ -54,12 +55,12 @@ def cluster_last_three_conv2d_layers(model):
if layer == last_three_conv2d_layers[0] or \ if layer == last_three_conv2d_layers[0] or \
layer == last_three_conv2d_layers[1]: layer == last_three_conv2d_layers[1]:
clustered = cluster_weights(layer, number_of_clusters=256, \ clustered = cluster_weights(layer, number_of_clusters=256, \
cluster_centroids_init=CentroidInitialization.LINEAR) cluster_centroids_init=centroid_initialization.LINEAR)
print("Clustered {} with 256 clusters".format(layer.name)) print('Clustered {} with 256 clusters'.format(layer.name))
else: else:
clustered = cluster_weights(layer, number_of_clusters=32, \ clustered = cluster_weights(layer, number_of_clusters=32, \
cluster_centroids_init=CentroidInitialization.LINEAR) cluster_centroids_init=centroid_initialization.LINEAR)
print("Clustered {} with 32 clusters".format(layer.name)) print('Clustered {} with 32 clusters'.format(layer.name))
return clustered return clustered
return tf.keras.models.clone_model(model, clone_function=cluster_fn) return tf.keras.models.clone_model(model, clone_function=cluster_fn)
...@@ -228,7 +229,7 @@ def run(flags_obj): ...@@ -228,7 +229,7 @@ def run(flags_obj):
model.load_weights(flags_obj.pretrained_filepath) model.load_weights(flags_obj.pretrained_filepath)
if flags_obj.pruning_method == 'polynomial_decay': if flags_obj.pruning_method == 'polynomial_decay':
import tensorflow_model_optimization as tfmot import tensorflow_model_optimization as tfmot # pylint: disable=g-import-not-at-top
if dtype != tf.float32: if dtype != tf.float32:
raise NotImplementedError( raise NotImplementedError(
'Pruning is currently only supported on dtype=tf.float32.') 'Pruning is currently only supported on dtype=tf.float32.')
...@@ -246,12 +247,12 @@ def run(flags_obj): ...@@ -246,12 +247,12 @@ def run(flags_obj):
raise NotImplementedError('Only polynomial_decay is currently supported.') raise NotImplementedError('Only polynomial_decay is currently supported.')
if flags_obj.clustering_method == 'selective_clustering': if flags_obj.clustering_method == 'selective_clustering':
import tensorflow_model_optimization as tfmot import tensorflow_model_optimization as tfmot # pylint: disable=g-import-not-at-top
if dtype != tf.float32 or \ if dtype != tf.float32 or \
flags_obj.fp16_implementation == 'graph_rewrite': flags_obj.fp16_implementation == 'graph_rewrite':
raise NotImplementedError( raise NotImplementedError(
'Clustering is currently only supported on dtype=tf.float32.') 'Clustering is currently only supported on dtype=tf.float32.')
model = cluster_last_three_conv2d_layers(model) model = _cluster_last_three_conv2d_layers(model)
elif flags_obj.clustering_method: elif flags_obj.clustering_method:
raise NotImplementedError( raise NotImplementedError(
'Only selective_clustering is implemented.') 'Only selective_clustering is implemented.')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment