Commit faea89d9 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Merge pull request #8604 from Ruomei:toupstream/clusteringexample

PiperOrigin-RevId: 329801211
parents ec5af928 a87bb185
...@@ -42,6 +42,7 @@ MOBILENET_V1_MAX_TOP_1_ACCURACY = 0.68 ...@@ -42,6 +42,7 @@ MOBILENET_V1_MAX_TOP_1_ACCURACY = 0.68
MODEL_OPTIMIZATION_TOP_1_ACCURACY = { MODEL_OPTIMIZATION_TOP_1_ACCURACY = {
'RESNET50_FINETUNE_PRUNING': (0.76, 0.77), 'RESNET50_FINETUNE_PRUNING': (0.76, 0.77),
'MOBILENET_V1_FINETUNE_PRUNING': (0.67, 0.68), 'MOBILENET_V1_FINETUNE_PRUNING': (0.67, 0.68),
'MOBILENET_V1_FINETUNE_CLUSTERING': (0.68, 0.70)
} }
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
...@@ -995,7 +996,7 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark): ...@@ -995,7 +996,7 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_fp16_tweaked') FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu_fp16_tweaked')
FLAGS.batch_size = 256 * 8 # 8 GPUs FLAGS.batch_size = 256 * 8 # 8 GPUs
FLAGS.tf_gpu_thread_mode = 'gpu_private' FLAGS.tf_gpu_thread_mode = 'gpu_private'
FLAGS.dataset_num_private_threads = 40 FLAGS.datasets_num_private_threads = 40
self._run_and_report_benchmark() self._run_and_report_benchmark()
def benchmark_8_gpu_fp16_dynamic_tweaked(self): def benchmark_8_gpu_fp16_dynamic_tweaked(self):
...@@ -1011,7 +1012,7 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark): ...@@ -1011,7 +1012,7 @@ class Resnet50KerasBenchmarkBase(keras_benchmark.KerasBenchmark):
FLAGS.batch_size = 256 * 8 # 8 GPUs FLAGS.batch_size = 256 * 8 # 8 GPUs
FLAGS.loss_scale = 'dynamic' FLAGS.loss_scale = 'dynamic'
FLAGS.tf_gpu_thread_mode = 'gpu_private' FLAGS.tf_gpu_thread_mode = 'gpu_private'
FLAGS.dataset_num_private_threads = 40 FLAGS.datasets_num_private_threads = 40
self._run_and_report_benchmark() self._run_and_report_benchmark()
def benchmark_xla_8_gpu_fp16(self): def benchmark_xla_8_gpu_fp16(self):
...@@ -1774,5 +1775,120 @@ class Resnet50KerasPruningBenchmarkReal(KerasPruningBenchmarkRealBase): ...@@ -1774,5 +1775,120 @@ class Resnet50KerasPruningBenchmarkReal(KerasPruningBenchmarkRealBase):
default_flags=default_flags, **kwargs) default_flags=default_flags, **kwargs)
class KerasClusteringAccuracyBase(keras_benchmark.KerasBenchmark):
  """Benchmark accuracy tests for clustering method."""

  def __init__(self,
               output_dir=None,
               root_data_dir=None,
               default_flags=None,
               **kwargs):
    """An accuracy benchmark class for clustering method.

    Args:
      output_dir: directory where to output e.g. log files
      root_data_dir: directory under which to look for dataset
      default_flags: default flags
      **kwargs: arbitrary named arguments. This is needed to make the
                constructor forward compatible in case PerfZero provides more
                named arguments before updating the constructor.
    """
    if default_flags is None:
      default_flags = {}
    # Use setdefault so that flags explicitly provided by a subclass (or
    # PerfZero) take precedence over the baseline clustering configuration.
    default_flags.setdefault('clustering_method', 'selective_clustering')
    default_flags.setdefault('data_dir',
                             os.path.join(root_data_dir, 'imagenet'))
    default_flags.setdefault('model', 'mobilenet_pretrained')
    default_flags.setdefault('optimizer', 'mobilenet_fine_tune')

    flag_methods = [resnet_imagenet_main.define_imagenet_keras_flags]

    super(KerasClusteringAccuracyBase, self).__init__(
        output_dir=output_dir,
        flag_methods=flag_methods,
        default_flags=default_flags,
        **kwargs)

  def benchmark_8_gpu(self):
    """Test Keras model with eager, dist_strat and 8 GPUs."""
    self._setup()
    FLAGS.num_gpus = 8
    FLAGS.batch_size = 32 * 8
    FLAGS.train_epochs = 1
    FLAGS.model_dir = self._get_model_dir('benchmark_8_gpu')
    FLAGS.dtype = 'fp32'
    FLAGS.enable_eager = True
    self._run_and_report_benchmark()

  @benchmark_wrappers.enable_runtime_flags
  def _run_and_report_benchmark(self,
                                top_1_min=MODEL_OPTIMIZATION_TOP_1_ACCURACY[
                                    'MOBILENET_V1_FINETUNE_CLUSTERING'][0],
                                top_1_max=MODEL_OPTIMIZATION_TOP_1_ACCURACY[
                                    'MOBILENET_V1_FINETUNE_CLUSTERING'][1]):
    """Runs the benchmark and reports accuracy/timing metrics.

    Args:
      top_1_min: minimum acceptable top-1 accuracy.
      top_1_max: maximum expected top-1 accuracy.
    """
    start_time_sec = time.time()
    stats = resnet_imagenet_main.run(flags.FLAGS)
    wall_time_sec = time.time() - start_time_sec

    super(KerasClusteringAccuracyBase, self)._report_benchmark(
        stats,
        wall_time_sec,
        top_1_min=top_1_min,
        top_1_max=top_1_max,
        total_batch_size=FLAGS.batch_size,
        log_steps=100)
class MobilenetV1KerasClusteringAccuracy(KerasClusteringAccuracyBase):
  """Benchmark accuracy tests for MobilenetV1 with clustering method."""

  def __init__(self, root_data_dir=None, **kwargs):
    """Configures the MobilenetV1 fine-tune accuracy benchmark.

    Args:
      root_data_dir: directory under which to look for dataset
      **kwargs: arbitrary named arguments forwarded to the base class.
    """
    default_flags = {
        'model': 'mobilenet_pretrained',
        'optimizer': 'mobilenet_fine_tune',
    }
    super(MobilenetV1KerasClusteringAccuracy, self).__init__(
        root_data_dir=root_data_dir,
        default_flags=default_flags,
        **kwargs)

  def _run_and_report_benchmark(self):
    """Runs the benchmark against the MobilenetV1 clustering accuracy range."""
    # Look the bounds up once instead of repeating the long subscript with
    # backslash line continuations.
    top_1_min, top_1_max = MODEL_OPTIMIZATION_TOP_1_ACCURACY[
        'MOBILENET_V1_FINETUNE_CLUSTERING']
    super(MobilenetV1KerasClusteringAccuracy, self)._run_and_report_benchmark(
        top_1_min=top_1_min,
        top_1_max=top_1_max)
class KerasClusteringBenchmarkRealBase(Resnet50KerasBenchmarkBase):
  """Clustering method benchmarks."""

  def __init__(self, root_data_dir=None, default_flags=None, **kwargs):
    """Sets up performance benchmarks for the selective-clustering method.

    Args:
      root_data_dir: directory under which to look for dataset
      default_flags: default flags, extended with clustering settings
      **kwargs: arbitrary named arguments forwarded to the base class.
    """
    if default_flags is None:
      default_flags = {}
    # Performance-only run: evaluation and accuracy reporting are disabled,
    # and the step count is capped so the benchmark finishes quickly.
    clustering_overrides = {
        'skip_eval': True,
        'report_accuracy_metrics': False,
        'data_dir': os.path.join(root_data_dir, 'imagenet'),
        'clustering_method': 'selective_clustering',
        'train_steps': 110,
        'log_steps': 10,
    }
    default_flags.update(clustering_overrides)
    super(KerasClusteringBenchmarkRealBase, self).__init__(
        default_flags=default_flags, **kwargs)
class MobilenetV1KerasClusteringBenchmarkReal(KerasClusteringBenchmarkRealBase):
  """Clustering method benchmarks for MobilenetV1."""

  def __init__(self, **kwargs):
    """Selects the pretrained MobilenetV1 model and fine-tune optimizer."""
    mobilenet_flags = dict(
        model='mobilenet_pretrained',
        optimizer='mobilenet_fine_tune',
    )
    super(MobilenetV1KerasClusteringBenchmarkReal, self).__init__(
        default_flags=mobilenet_flags, **kwargs)
if __name__ == '__main__': if __name__ == '__main__':
tf.test.main() tf.test.main()
...@@ -26,7 +26,6 @@ from absl import flags ...@@ -26,7 +26,6 @@ from absl import flags
from absl import logging from absl import logging
import tensorflow as tf import tensorflow as tf
import tensorflow_model_optimization as tfmot
from official.modeling import performance from official.modeling import performance
from official.utils.flags import core as flags_core from official.utils.flags import core as flags_core
from official.utils.misc import distribution_utils from official.utils.misc import distribution_utils
...@@ -38,6 +37,34 @@ from official.vision.image_classification.resnet import imagenet_preprocessing ...@@ -38,6 +37,34 @@ from official.vision.image_classification.resnet import imagenet_preprocessing
from official.vision.image_classification.resnet import resnet_model from official.vision.image_classification.resnet import resnet_model
def cluster_last_three_conv2d_layers(model,
                                     number_of_clusters_first_two=256,
                                     number_of_clusters_last=32):
  """Applies weight clustering to the last three Conv2D layers of a model.

  The first two of the selected layers are clustered with
  `number_of_clusters_first_two` centroids and the final one with
  `number_of_clusters_last` centroids, all using linear centroid
  initialization. Every other layer is returned unchanged.

  Args:
    model: a Keras model containing Conv2D layers.
    number_of_clusters_first_two: cluster count for the first two of the
      selected Conv2D layers. Defaults to 256.
    number_of_clusters_last: cluster count for the last selected Conv2D
      layer. Defaults to 32.

  Returns:
    A cloned model with clustering wrappers on the selected layers.
  """
  # Imported locally so the module can be used without tfmot installed when
  # clustering is not requested.
  import tensorflow_model_optimization as tfmot

  last_three_conv2d_layers = [
      layer for layer in model.layers
      if isinstance(layer, tf.keras.layers.Conv2D)
  ][-3:]

  cluster_weights = tfmot.clustering.keras.cluster_weights
  centroid_init = tfmot.clustering.keras.CentroidInitialization.LINEAR

  def cluster_fn(layer):
    # Leave every layer other than the selected Conv2D layers untouched.
    if layer not in last_three_conv2d_layers:
      return layer
    # The final selected layer gets fewer clusters than the two before it.
    if layer is last_three_conv2d_layers[-1]:
      number_of_clusters = number_of_clusters_last
    else:
      number_of_clusters = number_of_clusters_first_two
    clustered = cluster_weights(
        layer,
        number_of_clusters=number_of_clusters,
        cluster_centroids_init=centroid_init)
    # Use absl logging (already imported by this module) instead of print.
    logging.info('Clustered %s with %d clusters', layer.name,
                 number_of_clusters)
    return clustered

  return tf.keras.models.clone_model(model, clone_function=cluster_fn)
def run(flags_obj): def run(flags_obj):
"""Run ResNet ImageNet training and eval loop using native Keras APIs. """Run ResNet ImageNet training and eval loop using native Keras APIs.
...@@ -53,7 +80,6 @@ def run(flags_obj): ...@@ -53,7 +80,6 @@ def run(flags_obj):
""" """
keras_utils.set_session_config( keras_utils.set_session_config(
enable_xla=flags_obj.enable_xla) enable_xla=flags_obj.enable_xla)
# Execute flag override logic for better model performance # Execute flag override logic for better model performance
if flags_obj.tf_gpu_thread_mode: if flags_obj.tf_gpu_thread_mode:
keras_utils.set_gpu_thread_mode_and_count( keras_utils.set_gpu_thread_mode_and_count(
...@@ -117,7 +143,8 @@ def run(flags_obj): ...@@ -117,7 +143,8 @@ def run(flags_obj):
# This use_keras_image_data_format flags indicates whether image preprocessor # This use_keras_image_data_format flags indicates whether image preprocessor
# output format should be same as the keras backend image data format or just # output format should be same as the keras backend image data format or just
# channel-last format. # channel-last format.
use_keras_image_data_format = (flags_obj.model == 'mobilenet') use_keras_image_data_format = \
(flags_obj.model == 'mobilenet' or 'mobilenet_pretrained')
train_input_dataset = input_fn( train_input_dataset = input_fn(
is_training=True, is_training=True,
data_dir=flags_obj.data_dir, data_dir=flags_obj.data_dir,
...@@ -155,9 +182,11 @@ def run(flags_obj): ...@@ -155,9 +182,11 @@ def run(flags_obj):
with strategy_scope: with strategy_scope:
if flags_obj.optimizer == 'resnet50_default': if flags_obj.optimizer == 'resnet50_default':
optimizer = common.get_optimizer(lr_schedule) optimizer = common.get_optimizer(lr_schedule)
elif flags_obj.optimizer == 'mobilenet_default': elif flags_obj.optimizer == 'mobilenet_default' or 'mobilenet_fine_tune':
initial_learning_rate = \ initial_learning_rate = \
flags_obj.initial_learning_rate_per_sample * flags_obj.batch_size flags_obj.initial_learning_rate_per_sample * flags_obj.batch_size
if flags_obj.optimizer == 'mobilenet_fine_tune':
initial_learning_rate = 1e-5
optimizer = tf.keras.optimizers.SGD( optimizer = tf.keras.optimizers.SGD(
learning_rate=tf.keras.optimizers.schedules.ExponentialDecay( learning_rate=tf.keras.optimizers.schedules.ExponentialDecay(
initial_learning_rate, initial_learning_rate,
...@@ -165,6 +194,7 @@ def run(flags_obj): ...@@ -165,6 +194,7 @@ def run(flags_obj):
decay_rate=flags_obj.lr_decay_factor, decay_rate=flags_obj.lr_decay_factor,
staircase=True), staircase=True),
momentum=0.9) momentum=0.9)
if flags_obj.fp16_implementation == 'graph_rewrite': if flags_obj.fp16_implementation == 'graph_rewrite':
# Note: when flags_obj.fp16_implementation == "graph_rewrite", dtype as # Note: when flags_obj.fp16_implementation == "graph_rewrite", dtype as
# determined by flags_core.get_tf_dtype(flags_obj) would be 'float32' # determined by flags_core.get_tf_dtype(flags_obj) would be 'float32'
...@@ -180,17 +210,25 @@ def run(flags_obj): ...@@ -180,17 +210,25 @@ def run(flags_obj):
elif flags_obj.model == 'resnet50_v1.5': elif flags_obj.model == 'resnet50_v1.5':
model = resnet_model.resnet50( model = resnet_model.resnet50(
num_classes=imagenet_preprocessing.NUM_CLASSES) num_classes=imagenet_preprocessing.NUM_CLASSES)
elif flags_obj.model == 'mobilenet': elif flags_obj.model == 'mobilenet' or 'mobilenet_pretrained':
# TODO(kimjaehong): Remove layers attribute when minimum TF version # TODO(kimjaehong): Remove layers attribute when minimum TF version
# support 2.0 layers by default. # support 2.0 layers by default.
if flags_obj.model == 'mobilenet_pretrained':
classes_labels = 1000
initial_weights = 'imagenet'
else:
classes_labels = imagenet_preprocessing.NUM_CLASSES
initial_weights = None
model = tf.keras.applications.mobilenet.MobileNet( model = tf.keras.applications.mobilenet.MobileNet(
weights=None, weights=initial_weights,
classes=imagenet_preprocessing.NUM_CLASSES, classes=classes_labels,
layers=tf.keras.layers) layers=tf.keras.layers)
if flags_obj.pretrained_filepath: if flags_obj.pretrained_filepath:
model.load_weights(flags_obj.pretrained_filepath) model.load_weights(flags_obj.pretrained_filepath)
if flags_obj.pruning_method == 'polynomial_decay': if flags_obj.pruning_method == 'polynomial_decay':
import tensorflow_model_optimization as tfmot
if dtype != tf.float32: if dtype != tf.float32:
raise NotImplementedError( raise NotImplementedError(
'Pruning is currently only supported on dtype=tf.float32.') 'Pruning is currently only supported on dtype=tf.float32.')
...@@ -205,8 +243,18 @@ def run(flags_obj): ...@@ -205,8 +243,18 @@ def run(flags_obj):
} }
model = tfmot.sparsity.keras.prune_low_magnitude(model, **pruning_params) model = tfmot.sparsity.keras.prune_low_magnitude(model, **pruning_params)
elif flags_obj.pruning_method: elif flags_obj.pruning_method:
raise NotImplementedError('Only polynomial_decay is currently supported.')
if flags_obj.clustering_method == 'selective_clustering':
import tensorflow_model_optimization as tfmot
if dtype != tf.float32 or \
flags_obj.fp16_implementation == 'graph_rewrite':
raise NotImplementedError( raise NotImplementedError(
'Only polynomial_decay is currently supported.') 'Clustering is currently only supported on dtype=tf.float32.')
model = cluster_last_three_conv2d_layers(model)
elif flags_obj.clustering_method:
raise NotImplementedError(
'Only selective_clustering is implemented.')
model.compile( model.compile(
loss='sparse_categorical_crossentropy', loss='sparse_categorical_crossentropy',
...@@ -222,7 +270,7 @@ def run(flags_obj): ...@@ -222,7 +270,7 @@ def run(flags_obj):
enable_checkpoint_and_export=flags_obj.enable_checkpoint_and_export, enable_checkpoint_and_export=flags_obj.enable_checkpoint_and_export,
model_dir=flags_obj.model_dir) model_dir=flags_obj.model_dir)
# if multiple epochs, ignore the train_steps flag. # If multiple epochs, ignore the train_steps flag.
if train_epochs <= 1 and flags_obj.train_steps: if train_epochs <= 1 and flags_obj.train_steps:
steps_per_epoch = min(flags_obj.train_steps, steps_per_epoch) steps_per_epoch = min(flags_obj.train_steps, steps_per_epoch)
train_epochs = 1 train_epochs = 1
...@@ -244,7 +292,7 @@ def run(flags_obj): ...@@ -244,7 +292,7 @@ def run(flags_obj):
if not strategy and flags_obj.explicit_gpu_placement: if not strategy and flags_obj.explicit_gpu_placement:
# TODO(b/135607227): Add device scope automatically in Keras training loop # TODO(b/135607227): Add device scope automatically in Keras training loop
# when not using distribition strategy. # when not using distribution strategy.
no_dist_strat_device = tf.device('/device:GPU:0') no_dist_strat_device = tf.device('/device:GPU:0')
no_dist_strat_device.__enter__() no_dist_strat_device.__enter__()
...@@ -265,6 +313,10 @@ def run(flags_obj): ...@@ -265,6 +313,10 @@ def run(flags_obj):
if flags_obj.pruning_method: if flags_obj.pruning_method:
model = tfmot.sparsity.keras.strip_pruning(model) model = tfmot.sparsity.keras.strip_pruning(model)
if flags_obj.clustering_method:
model = tfmot.clustering.keras.strip_clustering(model)
if flags_obj.enable_checkpoint_and_export: if flags_obj.enable_checkpoint_and_export:
if dtype == tf.bfloat16: if dtype == tf.bfloat16:
logging.warning('Keras model.save does not support bfloat16 dtype.') logging.warning('Keras model.save does not support bfloat16 dtype.')
...@@ -286,6 +338,7 @@ def define_imagenet_keras_flags(): ...@@ -286,6 +338,7 @@ def define_imagenet_keras_flags():
optimizer=True, optimizer=True,
pretrained_filepath=True) pretrained_filepath=True)
common.define_pruning_flags() common.define_pruning_flags()
common.define_clustering_flags()
flags_core.set_defaults() flags_core.set_defaults()
flags.adopt_module_key_flags(common) flags.adopt_module_key_flags(common)
......
...@@ -31,7 +31,8 @@ from official.vision.image_classification.resnet import imagenet_preprocessing ...@@ -31,7 +31,8 @@ from official.vision.image_classification.resnet import imagenet_preprocessing
"resnet", "resnet",
# "resnet_polynomial_decay", b/151854314 # "resnet_polynomial_decay", b/151854314
"mobilenet", "mobilenet",
# "mobilenet_polynomial_decay" b/151854314 # "mobilenet_polynomial_decay", b/151854314
"mobilenet_selective_clustering",
) )
class KerasImagenetTest(tf.test.TestCase): class KerasImagenetTest(tf.test.TestCase):
"""Unit tests for Keras Models with ImageNet.""" """Unit tests for Keras Models with ImageNet."""
...@@ -74,6 +75,11 @@ class KerasImagenetTest(tf.test.TestCase): ...@@ -74,6 +75,11 @@ class KerasImagenetTest(tf.test.TestCase):
"-pruning_method", "-pruning_method",
"polynomial_decay", "polynomial_decay",
], ],
"mobilenet_selective_clustering": [
"-model", "mobilenet_pretrained",
"-optimizer", "mobilenet_fine_tune",
"-clustering_method", "selective_clustering",
]
} }
_tempdir = None _tempdir = None
...@@ -167,7 +173,10 @@ class KerasImagenetTest(tf.test.TestCase): ...@@ -167,7 +173,10 @@ class KerasImagenetTest(tf.test.TestCase):
extra_flags = extra_flags + self.get_extra_flags_dict(flags_key) extra_flags = extra_flags + self.get_extra_flags_dict(flags_key)
if "polynomial_decay" in extra_flags: if "polynomial_decay" in extra_flags:
self.skipTest("Pruning with fp16 is not currently supported.") self.skipTest("Pruning with fp16 is currently not supported.")
if "selective_clustering" in extra_flags:
self.skipTest("Clustering with fp16 is currently not supported.")
integration.run_synthetic( integration.run_synthetic(
main=resnet_imagenet_main.run, main=resnet_imagenet_main.run,
...@@ -237,7 +246,10 @@ class KerasImagenetTest(tf.test.TestCase): ...@@ -237,7 +246,10 @@ class KerasImagenetTest(tf.test.TestCase):
extra_flags = extra_flags + self.get_extra_flags_dict(flags_key) extra_flags = extra_flags + self.get_extra_flags_dict(flags_key)
if "polynomial_decay" in extra_flags: if "polynomial_decay" in extra_flags:
self.skipTest("Pruning with fp16 is not currently supported.") self.skipTest("Pruning with fp16 is currently not supported.")
if "selective_clustering" in extra_flags:
self.skipTest("Clustering with fp16 is currently not supported.")
integration.run_synthetic( integration.run_synthetic(
main=resnet_imagenet_main.run, main=resnet_imagenet_main.run,
...@@ -264,7 +276,10 @@ class KerasImagenetTest(tf.test.TestCase): ...@@ -264,7 +276,10 @@ class KerasImagenetTest(tf.test.TestCase):
extra_flags = extra_flags + self.get_extra_flags_dict(flags_key) extra_flags = extra_flags + self.get_extra_flags_dict(flags_key)
if "polynomial_decay" in extra_flags: if "polynomial_decay" in extra_flags:
self.skipTest("Pruning with fp16 is not currently supported.") self.skipTest("Pruning with fp16 is currently not supported.")
if "selective_clustering" in extra_flags:
self.skipTest("Clustering with fp16 is currently not supported.")
integration.run_synthetic( integration.run_synthetic(
main=resnet_imagenet_main.run, main=resnet_imagenet_main.run,
......
...@@ -9,7 +9,7 @@ psutil>=5.4.3 ...@@ -9,7 +9,7 @@ psutil>=5.4.3
py-cpuinfo>=3.3.0 py-cpuinfo>=3.3.0
scipy>=0.19.1 scipy>=0.19.1
tensorflow-hub>=0.6.0 tensorflow-hub>=0.6.0
tensorflow-model-optimization>=0.2.1 tensorflow-model-optimization>=0.4.1
tensorflow-datasets tensorflow-datasets
tensorflow-addons tensorflow-addons
dataclasses dataclasses
......
...@@ -353,6 +353,13 @@ def define_pruning_flags(): ...@@ -353,6 +353,13 @@ def define_pruning_flags():
flags.DEFINE_integer('pruning_frequency', 100, 'Frequency for pruning.') flags.DEFINE_integer('pruning_frequency', 100, 'Frequency for pruning.')
def define_clustering_flags():
  """Define flags for clustering methods."""
  flags.DEFINE_string(
      'clustering_method', None,
      'None (no clustering) or selective_clustering '
      '(cluster last three Conv2D layers of the model).')
def get_synth_input_fn(height, def get_synth_input_fn(height,
width, width,
num_channels, num_channels,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment