Commit 44a28e87 authored by Rino Lee's avatar Rino Lee Committed by A. Unique TensorFlower
Browse files

Add more prunable block-weight suffix patterns to cover various types of Resnet and Mobilenet

PiperOrigin-RevId: 437816100
parent 19a92653
......@@ -28,16 +28,25 @@ from official.vision.tasks import image_classification
class ImageClassificationTask(image_classification.ImageClassificationTask):
"""A task for image classification with pruning."""
_BLOCK_LAYER_SUFFIX_MAP = {
mobilenet.Conv2DBNBlock: ('conv2d/kernel:0',),
nn_blocks.BottleneckBlock: (
'conv2d/kernel:0',
'conv2d_1/kernel:0',
'conv2d_2/kernel:0',
'conv2d_3/kernel:0',
),
nn_blocks.InvertedBottleneckBlock:
('conv2d/kernel:0', 'conv2d_1/kernel:0',
'depthwise_conv2d/depthwise_kernel:0'),
mobilenet.Conv2DBNBlock: ('conv2d/kernel:0',),
nn_blocks.InvertedBottleneckBlock: (
'conv2d/kernel:0',
'conv2d_1/kernel:0',
'conv2d_2/kernel:0',
'conv2d_3/kernel:0',
'depthwise_conv2d/depthwise_kernel:0',
),
nn_blocks.ResidualBlock: (
'conv2d/kernel:0',
'conv2d_1/kernel:0',
'conv2d_2/kernel:0',
),
}
def build_model(self) -> tf.keras.Model:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment