Commit d47fc5bb authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 416322339
parent bef8f110
...@@ -485,6 +485,45 @@ MNMultiMAX_SEG_BLOCK_SPECS = { ...@@ -485,6 +485,45 @@ MNMultiMAX_SEG_BLOCK_SPECS = {
] ]
} }
# A smaller MNV3Small, with reduced filters for the last few layers
MNV3SmallReducedFilters = {
'spec_name':
'MobilenetV3SmallReducedFilters',
'block_spec_schema': [
'block_fn', 'kernel_size', 'strides', 'filters', 'activation',
'se_ratio', 'expand_ratio', 'use_normalization', 'use_bias', 'is_output'
],
'block_specs': [
('convbn', 3, 2, 16, 'hard_swish', None, None, True, False, False),
('invertedbottleneck', 3, 2, 16, 'relu', 0.25, 1, None, False, True),
('invertedbottleneck', 3, 2, 24, 'relu', None, 72. / 16, None, False,
False),
('invertedbottleneck', 3, 1, 24, 'relu', None, 88. / 24, None, False,
True),
('invertedbottleneck', 5, 2, 40, 'hard_swish', 0.25, 4, None, False,
False),
('invertedbottleneck', 5, 1, 40, 'hard_swish', 0.25, 6, None, False,
False),
('invertedbottleneck', 5, 1, 40, 'hard_swish', 0.25, 6, None, False,
False),
('invertedbottleneck', 5, 1, 48, 'hard_swish', 0.25, 3, None, False,
False),
('invertedbottleneck', 5, 1, 48, 'hard_swish', 0.25, 3, None, False,
True),
# Layers below are different from MobileNetV3Small and have
# half as many filters
('invertedbottleneck', 5, 2, 48, 'hard_swish', 0.25, 3, None, False,
False),
('invertedbottleneck', 5, 1, 48, 'hard_swish', 0.25, 6, None, False,
False),
('invertedbottleneck', 5, 1, 48, 'hard_swish', 0.25, 6, None, False,
True),
('convbn', 1, 1, 288, 'hard_swish', None, None, True, False, False),
('gpooling', None, None, None, None, None, None, None, None, False),
('convbn', 1, 1, 1024, 'hard_swish', None, None, False, True, False),
]
}
SUPPORTED_SPECS_MAP = { SUPPORTED_SPECS_MAP = {
'MobileNetV1': MNV1_BLOCK_SPECS, 'MobileNetV1': MNV1_BLOCK_SPECS,
'MobileNetV2': MNV2_BLOCK_SPECS, 'MobileNetV2': MNV2_BLOCK_SPECS,
...@@ -495,6 +534,7 @@ SUPPORTED_SPECS_MAP = { ...@@ -495,6 +534,7 @@ SUPPORTED_SPECS_MAP = {
'MobileNetMultiAVG': MNMultiAVG_BLOCK_SPECS, 'MobileNetMultiAVG': MNMultiAVG_BLOCK_SPECS,
'MobileNetMultiAVGSeg': MNMultiAVG_SEG_BLOCK_SPECS, 'MobileNetMultiAVGSeg': MNMultiAVG_SEG_BLOCK_SPECS,
'MobileNetMultiMAXSeg': MNMultiMAX_SEG_BLOCK_SPECS, 'MobileNetMultiMAXSeg': MNMultiMAX_SEG_BLOCK_SPECS,
'MobileNetV3SmallReducedFilters': MNV3SmallReducedFilters,
} }
......
...@@ -38,6 +38,7 @@ class MobileNetTest(parameterized.TestCase, tf.test.TestCase): ...@@ -38,6 +38,7 @@ class MobileNetTest(parameterized.TestCase, tf.test.TestCase):
'MobileNetMultiMAX', 'MobileNetMultiMAX',
'MobileNetMultiAVGSeg', 'MobileNetMultiAVGSeg',
'MobileNetMultiMAXSeg', 'MobileNetMultiMAXSeg',
'MobileNetV3SmallReducedFilters',
) )
def test_serialize_deserialize(self, model_id): def test_serialize_deserialize(self, model_id):
# Create a network object that sets all of its config options. # Create a network object that sets all of its config options.
...@@ -84,6 +85,7 @@ class MobileNetTest(parameterized.TestCase, tf.test.TestCase): ...@@ -84,6 +85,7 @@ class MobileNetTest(parameterized.TestCase, tf.test.TestCase):
'MobileNetMultiMAX', 'MobileNetMultiMAX',
'MobileNetMultiAVGSeg', 'MobileNetMultiAVGSeg',
'MobileNetMultiMAXSeg', 'MobileNetMultiMAXSeg',
'MobileNetV3SmallReducedFilters',
], ],
)) ))
def test_input_specs(self, input_dim, model_id): def test_input_specs(self, input_dim, model_id):
...@@ -107,6 +109,7 @@ class MobileNetTest(parameterized.TestCase, tf.test.TestCase): ...@@ -107,6 +109,7 @@ class MobileNetTest(parameterized.TestCase, tf.test.TestCase):
'MobileNetMultiAVG', 'MobileNetMultiAVG',
'MobileNetMultiMAX', 'MobileNetMultiMAX',
'MobileNetMultiAVGSeg', 'MobileNetMultiAVGSeg',
'MobileNetV3SmallReducedFilters',
], ],
[32, 224], [32, 224],
)) ))
...@@ -127,6 +130,7 @@ class MobileNetTest(parameterized.TestCase, tf.test.TestCase): ...@@ -127,6 +130,7 @@ class MobileNetTest(parameterized.TestCase, tf.test.TestCase):
'MobileNetMultiAVG': [32, 64, 160, 192], 'MobileNetMultiAVG': [32, 64, 160, 192],
'MobileNetMultiAVGSeg': [32, 64, 160, 96], 'MobileNetMultiAVGSeg': [32, 64, 160, 96],
'MobileNetMultiMAXSeg': [32, 64, 128, 96], 'MobileNetMultiMAXSeg': [32, 64, 128, 96],
'MobileNetV3SmallReducedFilters': [16, 24, 48, 48],
} }
network = mobilenet.MobileNet(model_id=model_id, network = mobilenet.MobileNet(model_id=model_id,
...@@ -152,6 +156,7 @@ class MobileNetTest(parameterized.TestCase, tf.test.TestCase): ...@@ -152,6 +156,7 @@ class MobileNetTest(parameterized.TestCase, tf.test.TestCase):
'MobileNetMultiMAX', 'MobileNetMultiMAX',
'MobileNetMultiAVGSeg', 'MobileNetMultiAVGSeg',
'MobileNetMultiMAXSeg', 'MobileNetMultiMAXSeg',
'MobileNetV3SmallReducedFilters',
], ],
[32, 224], [32, 224],
)) ))
...@@ -172,6 +177,7 @@ class MobileNetTest(parameterized.TestCase, tf.test.TestCase): ...@@ -172,6 +177,7 @@ class MobileNetTest(parameterized.TestCase, tf.test.TestCase):
'MobileNetMultiAVG': [64, 192, 640, 768], 'MobileNetMultiAVG': [64, 192, 640, 768],
'MobileNetMultiAVGSeg': [64, 192, 640, 384], 'MobileNetMultiAVGSeg': [64, 192, 640, 384],
'MobileNetMultiMAXSeg': [96, 128, 384, 320], 'MobileNetMultiMAXSeg': [96, 128, 384, 320],
'MobileNetV3SmallReducedFilters': [16, 88, 144, 288],
} }
network = mobilenet.MobileNet(model_id=model_id, network = mobilenet.MobileNet(model_id=model_id,
filter_size_scale=1.0, filter_size_scale=1.0,
...@@ -202,6 +208,7 @@ class MobileNetTest(parameterized.TestCase, tf.test.TestCase): ...@@ -202,6 +208,7 @@ class MobileNetTest(parameterized.TestCase, tf.test.TestCase):
'MobileNetMultiMAX', 'MobileNetMultiMAX',
'MobileNetMultiAVGSeg', 'MobileNetMultiAVGSeg',
'MobileNetMultiMAXSeg', 'MobileNetMultiMAXSeg',
'MobileNetV3SmallReducedFilters',
], ],
[1.0, 0.75], [1.0, 0.75],
)) ))
...@@ -227,6 +234,8 @@ class MobileNetTest(parameterized.TestCase, tf.test.TestCase): ...@@ -227,6 +234,8 @@ class MobileNetTest(parameterized.TestCase, tf.test.TestCase):
('MobileNetMultiAVGSeg', 0.75): 1395272, ('MobileNetMultiAVGSeg', 0.75): 1395272,
('MobileNetMultiMAXSeg', 1.0): 1929088, ('MobileNetMultiMAXSeg', 1.0): 1929088,
('MobileNetMultiMAXSeg', 0.75): 1216544, ('MobileNetMultiMAXSeg', 0.75): 1216544,
('MobileNetV3SmallReducedFilters', 1.0): 694880,
('MobileNetV3SmallReducedFilters', 0.75): 505960,
} }
input_size = 224 input_size = 224
...@@ -250,6 +259,7 @@ class MobileNetTest(parameterized.TestCase, tf.test.TestCase): ...@@ -250,6 +259,7 @@ class MobileNetTest(parameterized.TestCase, tf.test.TestCase):
'MobileNetMultiMAX', 'MobileNetMultiMAX',
'MobileNetMultiAVGSeg', 'MobileNetMultiAVGSeg',
'MobileNetMultiMAXSeg', 'MobileNetMultiMAXSeg',
'MobileNetV3SmallReducedFilters',
], ],
[8, 16, 32], [8, 16, 32],
)) ))
...@@ -269,6 +279,7 @@ class MobileNetTest(parameterized.TestCase, tf.test.TestCase): ...@@ -269,6 +279,7 @@ class MobileNetTest(parameterized.TestCase, tf.test.TestCase):
'MobileNetMultiAVG': 192, 'MobileNetMultiAVG': 192,
'MobileNetMultiAVGSeg': 448, 'MobileNetMultiAVGSeg': 448,
'MobileNetMultiMAXSeg': 448, 'MobileNetMultiMAXSeg': 448,
'MobileNetV3SmallReducedFilters': 48,
} }
network = mobilenet.MobileNet( network = mobilenet.MobileNet(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment