Commit 20101930 authored by Ruomei Yan's avatar Ruomei Yan
Browse files

Address review comments from mid June

parent 7dfef01d
...@@ -905,7 +905,7 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark): ...@@ -905,7 +905,7 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
FLAGS.enable_eager = True FLAGS.enable_eager = True
FLAGS.distribution_strategy = 'mirrored' FLAGS.distribution_strategy = 'mirrored'
FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu') FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu')
FLAGS.batch_size = 128 * 8 # 8 GPUs FLAGS.batch_size = 128 * 8
self._run_and_report_benchmark() self._run_and_report_benchmark()
def benchmark_8_gpu_amp(self): def benchmark_8_gpu_amp(self):
...@@ -918,7 +918,7 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark): ...@@ -918,7 +918,7 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
FLAGS.fp16_implementation = 'graph_rewrite' FLAGS.fp16_implementation = 'graph_rewrite'
FLAGS.distribution_strategy = 'mirrored' FLAGS.distribution_strategy = 'mirrored'
FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_amp') FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_amp')
FLAGS.batch_size = 256 * 8 # 8 GPUs FLAGS.batch_size = 256 * 8
self._run_and_report_benchmark() self._run_and_report_benchmark()
def benchmark_8_gpu_tweaked(self): def benchmark_8_gpu_tweaked(self):
...@@ -942,7 +942,7 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark): ...@@ -942,7 +942,7 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
FLAGS.enable_xla = True FLAGS.enable_xla = True
FLAGS.distribution_strategy = 'mirrored' FLAGS.distribution_strategy = 'mirrored'
FLAGS.model_dir = self._get_model_dir('benchmark_xla_8_gpu') FLAGS.model_dir = self._get_model_dir('benchmark_xla_8_gpu')
FLAGS.batch_size = 128 * 8 # 8 GPUs FLAGS.batch_size = 128 * 8
self._run_and_report_benchmark() self._run_and_report_benchmark()
def benchmark_xla_8_gpu_amp(self): def benchmark_xla_8_gpu_amp(self):
...@@ -956,7 +956,7 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark): ...@@ -956,7 +956,7 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
FLAGS.enable_xla = True FLAGS.enable_xla = True
FLAGS.distribution_strategy = 'mirrored' FLAGS.distribution_strategy = 'mirrored'
FLAGS.model_dir = self._get_model_dir('benchmark_xla_8_gpu_amp') FLAGS.model_dir = self._get_model_dir('benchmark_xla_8_gpu_amp')
FLAGS.batch_size = 256 * 8 # 8 GPUs FLAGS.batch_size = 256 * 8
self._run_and_report_benchmark() self._run_and_report_benchmark()
def benchmark_xla_8_gpu_tweaked(self): def benchmark_xla_8_gpu_tweaked(self):
...@@ -982,7 +982,7 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark): ...@@ -982,7 +982,7 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
FLAGS.enable_eager = True FLAGS.enable_eager = True
FLAGS.distribution_strategy = 'mirrored' FLAGS.distribution_strategy = 'mirrored'
FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_fp16') FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_fp16')
FLAGS.batch_size = 256 * 8 # 8 GPUs FLAGS.batch_size = 256 * 8
self._run_and_report_benchmark() self._run_and_report_benchmark()
def benchmark_8_gpu_fp16_tweaked(self): def benchmark_8_gpu_fp16_tweaked(self):
...@@ -994,7 +994,7 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark): ...@@ -994,7 +994,7 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
FLAGS.enable_eager = True FLAGS.enable_eager = True
FLAGS.distribution_strategy = 'mirrored' FLAGS.distribution_strategy = 'mirrored'
FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_fp16_tweaked') FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_fp16_tweaked')
FLAGS.batch_size = 256 * 8 # 8 GPUs FLAGS.batch_size = 256 * 8
FLAGS.tf_gpu_thread_mode = 'gpu_private' FLAGS.tf_gpu_thread_mode = 'gpu_private'
FLAGS.datasets_num_private_threads = 40 FLAGS.datasets_num_private_threads = 40
self._run_and_report_benchmark() self._run_and_report_benchmark()
...@@ -1009,7 +1009,7 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark): ...@@ -1009,7 +1009,7 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
FLAGS.distribution_strategy = 'mirrored' FLAGS.distribution_strategy = 'mirrored'
FLAGS.model_dir = self._get_model_dir( FLAGS.model_dir = self._get_model_dir(
'benchmark_8_gpu_fp16_dynamic_tweaked') 'benchmark_8_gpu_fp16_dynamic_tweaked')
FLAGS.batch_size = 256 * 8 # 8 GPUs FLAGS.batch_size = 256 * 8
FLAGS.loss_scale = 'dynamic' FLAGS.loss_scale = 'dynamic'
FLAGS.tf_gpu_thread_mode = 'gpu_private' FLAGS.tf_gpu_thread_mode = 'gpu_private'
FLAGS.datasets_num_private_threads = 40 FLAGS.datasets_num_private_threads = 40
...@@ -1025,7 +1025,7 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark): ...@@ -1025,7 +1025,7 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
FLAGS.enable_xla = True FLAGS.enable_xla = True
FLAGS.distribution_strategy = 'mirrored' FLAGS.distribution_strategy = 'mirrored'
FLAGS.model_dir = self._get_model_dir('benchmark_xla_8_gpu_fp16') FLAGS.model_dir = self._get_model_dir('benchmark_xla_8_gpu_fp16')
FLAGS.batch_size = 256 * 8 # 8 GPUs FLAGS.batch_size = 256 * 8
self._run_and_report_benchmark() self._run_and_report_benchmark()
def benchmark_xla_8_gpu_fp16_tweaked(self): def benchmark_xla_8_gpu_fp16_tweaked(self):
...@@ -1038,7 +1038,7 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark): ...@@ -1038,7 +1038,7 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
FLAGS.enable_xla = True FLAGS.enable_xla = True
FLAGS.distribution_strategy = 'mirrored' FLAGS.distribution_strategy = 'mirrored'
FLAGS.model_dir = self._get_model_dir('benchmark_xla_8_gpu_fp16_tweaked') FLAGS.model_dir = self._get_model_dir('benchmark_xla_8_gpu_fp16_tweaked')
FLAGS.batch_size = 256 * 8 # 8 GPUs FLAGS.batch_size = 256 * 8
FLAGS.tf_gpu_thread_mode = 'gpu_private' FLAGS.tf_gpu_thread_mode = 'gpu_private'
FLAGS.datasets_num_private_threads = 48 FLAGS.datasets_num_private_threads = 48
self._run_and_report_benchmark() self._run_and_report_benchmark()
...@@ -1074,7 +1074,7 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark): ...@@ -1074,7 +1074,7 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
FLAGS.distribution_strategy = 'mirrored' FLAGS.distribution_strategy = 'mirrored'
FLAGS.model_dir = self._get_model_dir( FLAGS.model_dir = self._get_model_dir(
'benchmark_xla_8_gpu_fp16_dynamic_tweaked') 'benchmark_xla_8_gpu_fp16_dynamic_tweaked')
FLAGS.batch_size = 256 * 8 # 8 GPUs FLAGS.batch_size = 256 * 8
FLAGS.loss_scale = 'dynamic' FLAGS.loss_scale = 'dynamic'
FLAGS.tf_gpu_thread_mode = 'gpu_private' FLAGS.tf_gpu_thread_mode = 'gpu_private'
FLAGS.datasets_num_private_threads = 48 FLAGS.datasets_num_private_threads = 48
...@@ -1871,7 +1871,6 @@ class KerasClusteringBenchmarkRealBase(Resnet50KerasBenchmarkBase): ...@@ -1871,7 +1871,6 @@ class KerasClusteringBenchmarkRealBase(Resnet50KerasBenchmarkBase):
'report_accuracy_metrics': False, 'report_accuracy_metrics': False,
'data_dir': os.path.join(root_data_dir, 'imagenet'), 'data_dir': os.path.join(root_data_dir, 'imagenet'),
'clustering_method': 'selective_clustering', 'clustering_method': 'selective_clustering',
'number_of_clusters': 256,
'train_steps': 110, 'train_steps': 110,
'log_steps': 10, 'log_steps': 10,
}) })
......
...@@ -26,7 +26,7 @@ from absl import flags ...@@ -26,7 +26,7 @@ from absl import flags
from absl import logging from absl import logging
import tensorflow as tf import tensorflow as tf
from tensorflow_model_optimization.python.core.clustering.keras import cluster
import tensorflow_model_optimization as tfmot import tensorflow_model_optimization as tfmot
from official.modeling import performance from official.modeling import performance
from official.utils.flags import core as flags_core from official.utils.flags import core as flags_core
...@@ -39,56 +39,38 @@ from official.vision.image_classification.resnet import imagenet_preprocessing ...@@ -39,56 +39,38 @@ from official.vision.image_classification.resnet import imagenet_preprocessing
from official.vision.image_classification.resnet import resnet_model from official.vision.image_classification.resnet import resnet_model
def selective_layers_to_cluster(model): def cluster_last_three_conv2d_layers(model):
last_3conv2d_layers_to_cluster = [ last_three_conv2d_layers = [
layer.name layer for layer in model.layers
for layer in model.layers
if isinstance(layer, tf.keras.layers.Conv2D) and if isinstance(layer, tf.keras.layers.Conv2D) and
not isinstance(layer, tf.keras.layers.DepthwiseConv2D) not isinstance(layer, tf.keras.layers.DepthwiseConv2D)
] ]
last_3conv2d_layers_to_cluster = last_3conv2d_layers_to_cluster[-3:] last_three_conv2d_layers = last_three_conv2d_layers[-3:]
return last_3conv2d_layers_to_cluster
cluster_weights = tfmot.clustering.keras.cluster_weights
CentroidInitialization = tfmot.clustering.keras.CentroidInitialization
def selective_clustering_clone_wrapper(clustering_params1, clustering_params2, clustering_params1 = {
model): 'number_of_clusters': 256,
'cluster_centroids_init': CentroidInitialization.LINEAR
def apply_clustering_to_conv2d_but_depthwise(layer): }
layers_list = selective_layers_to_cluster(model) clustering_params2 = {
if layer.name in layers_list: 'number_of_clusters': 32,
if layer.name != layers_list[-1]: 'cluster_centroids_init': CentroidInitialization.LINEAR
print("Wrapped layer " + layer.name + }
" with " +
str(clustering_params1["number_of_clusters"]) + " clusters.") def cluster_fn(layer):
return cluster.cluster_weights(layer, **clustering_params1) if layer not in last_three_conv2d_layers:
else: return layer
print("Wrapped layer " + layer.name +
" with " + if layer == last_three_conv2d_layers[0] or layer == last_three_conv2d_layers[1]:
str(clustering_params2["number_of_clusters"]) + " clusters.") clustered = cluster_weights(layer, **clustering_params1)
return cluster.cluster_weights(layer, **clustering_params2) print("Clustered {} with {} clusters".format(layer.name, clustering_params1['number_of_clusters']))
return layer else:
clustered = cluster_weights(layer, **clustering_params2)
return apply_clustering_to_conv2d_but_depthwise print("Clustered {} with {} clusters".format(layer.name, clustering_params2['number_of_clusters']))
return clustered
def cluster_model_selectively(model, selective_layers_to_cluster,
clustering_params1, clustering_params2):
result_layer_model = tf.keras.models.clone_model(
model,
clone_function=selective_clustering_clone_wrapper(clustering_params1,
clustering_params2,
model),
)
return result_layer_model
def get_selectively_clustered_model(model, clustering_params1, return tf.keras.models.clone_model(model, clone_function=cluster_fn)
clustering_params2):
clustered_model = cluster_model_selectively(model,
selective_layers_to_cluster,
clustering_params1,
clustering_params2)
return clustered_model
def run(flags_obj): def run(flags_obj):
...@@ -244,12 +226,8 @@ def run(flags_obj): ...@@ -244,12 +226,8 @@ def run(flags_obj):
layers=tf.keras.layers) layers=tf.keras.layers)
elif flags_obj.model == 'mobilenet_pretrained': elif flags_obj.model == 'mobilenet_pretrained':
model = tf.keras.applications.mobilenet.MobileNet( model = tf.keras.applications.mobilenet.MobileNet(
alpha=1.0,
depth_multiplier=1,
dropout=1e-7, dropout=1e-7,
include_top=True,
weights='imagenet', weights='imagenet',
pooling=None,
classes=1000, classes=1000,
layers=tf.keras.layers) layers=tf.keras.layers)
...@@ -277,16 +255,7 @@ def run(flags_obj): ...@@ -277,16 +255,7 @@ def run(flags_obj):
if dtype != tf.float32 or flags_obj.fp16_implementation == 'graph_rewrite': if dtype != tf.float32 or flags_obj.fp16_implementation == 'graph_rewrite':
raise NotImplementedError( raise NotImplementedError(
'Clustering is currently only supported on dtype=tf.float32.') 'Clustering is currently only supported on dtype=tf.float32.')
clustering_params1 = { model = cluster_last_three_conv2d_layers(model)
'number_of_clusters': flags_obj.number_of_clusters,
'cluster_centroids_init': 'linear'
}
clustering_params2 = {
'number_of_clusters': 32,
'cluster_centroids_init': 'linear'
}
model = get_selectively_clustered_model(model, clustering_params1,
clustering_params2)
elif flags_obj.clustering_method: elif flags_obj.clustering_method:
raise NotImplementedError( raise NotImplementedError(
'Only selective_clustering is implemented.') 'Only selective_clustering is implemented.')
...@@ -324,7 +293,6 @@ def run(flags_obj): ...@@ -324,7 +293,6 @@ def run(flags_obj):
num_eval_steps = None num_eval_steps = None
validation_data = None validation_data = None
# if not strategy and flags_obj.explicit_gpu_placement:
if not strategy and flags_obj.explicit_gpu_placement: if not strategy and flags_obj.explicit_gpu_placement:
# TODO(b/135607227): Add device scope automatically in Keras training loop # TODO(b/135607227): Add device scope automatically in Keras training loop
# when not using distribution strategy. # when not using distribution strategy.
...@@ -350,7 +318,7 @@ def run(flags_obj): ...@@ -350,7 +318,7 @@ def run(flags_obj):
model = tfmot.sparsity.keras.strip_pruning(model) model = tfmot.sparsity.keras.strip_pruning(model)
if flags_obj.clustering_method: if flags_obj.clustering_method:
model = cluster.strip_clustering(model) model = tfmot.clustering.keras.strip_clustering(model)
if flags_obj.enable_checkpoint_and_export: if flags_obj.enable_checkpoint_and_export:
if dtype == tf.bfloat16: if dtype == tf.bfloat16:
...@@ -363,14 +331,6 @@ def run(flags_obj): ...@@ -363,14 +331,6 @@ def run(flags_obj):
if not strategy and flags_obj.explicit_gpu_placement: if not strategy and flags_obj.explicit_gpu_placement:
no_dist_strat_device.__exit__() no_dist_strat_device.__exit__()
if flags_obj.clustering_method:
if flags_obj.save_files_to:
keras_file = os.path.join(flags_obj.save_files_to, 'clustered.h5')
else:
keras_file = './clustered.h5'
print('Saving clustered and stripped model to: ', keras_file)
tf.keras.models.save_model(model, keras_file)
stats = common.build_stats(history, eval_output, callbacks) stats = common.build_stats(history, eval_output, callbacks)
return stats return stats
......
...@@ -26,7 +26,7 @@ from official.benchmark.models import resnet_imagenet_main ...@@ -26,7 +26,7 @@ from official.benchmark.models import resnet_imagenet_main
from official.utils.testing import integration from official.utils.testing import integration
from official.vision.image_classification.resnet import imagenet_preprocessing from official.vision.image_classification.resnet import imagenet_preprocessing
# TBC: joint clustering and tuning is not supported yet so only one flag should be selected
@parameterized.parameters( @parameterized.parameters(
"resnet", "resnet",
# "resnet_polynomial_decay", b/151854314 # "resnet_polynomial_decay", b/151854314
......
...@@ -9,7 +9,7 @@ psutil>=5.4.3 ...@@ -9,7 +9,7 @@ psutil>=5.4.3
py-cpuinfo>=3.3.0 py-cpuinfo>=3.3.0
scipy>=0.19.1 scipy>=0.19.1
tensorflow-hub>=0.6.0 tensorflow-hub>=0.6.0
tensorflow-model-optimization>=0.2.1 tensorflow-model-optimization>=0.4.1
tensorflow-datasets tensorflow-datasets
tensorflow-addons tensorflow-addons
dataclasses dataclasses
......
...@@ -355,11 +355,8 @@ def define_pruning_flags(): ...@@ -355,11 +355,8 @@ def define_pruning_flags():
def define_clustering_flags(): def define_clustering_flags():
"""Define flags for clustering methods.""" """Define flags for clustering methods."""
flags.DEFINE_string('clustering_method', None, flags.DEFINE_string('clustering_method', None,
'None (no clustering) or selective_clustering.') 'None (no clustering) or selective_clustering'\
flags.DEFINE_integer('number_of_clusters', 256, '(cluster last three Conv2D layers of the model).')
'Number of clusters used in each layer.')
flags.DEFINE_string('save_files_to', None,
'The path to save Keras models and tflite models.')
def get_synth_input_fn(height, def get_synth_input_fn(height,
width, width,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment