Commit d47fc5bb authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 416322339
parent bef8f110
......@@ -485,6 +485,45 @@ MNMultiMAX_SEG_BLOCK_SPECS = {
]
}
# A smaller MNV3Small, with reduced filters for the last few layers
MNV3SmallReducedFilters = {
'spec_name':
'MobilenetV3SmallReducedFilters',
'block_spec_schema': [
'block_fn', 'kernel_size', 'strides', 'filters', 'activation',
'se_ratio', 'expand_ratio', 'use_normalization', 'use_bias', 'is_output'
],
'block_specs': [
('convbn', 3, 2, 16, 'hard_swish', None, None, True, False, False),
('invertedbottleneck', 3, 2, 16, 'relu', 0.25, 1, None, False, True),
('invertedbottleneck', 3, 2, 24, 'relu', None, 72. / 16, None, False,
False),
('invertedbottleneck', 3, 1, 24, 'relu', None, 88. / 24, None, False,
True),
('invertedbottleneck', 5, 2, 40, 'hard_swish', 0.25, 4, None, False,
False),
('invertedbottleneck', 5, 1, 40, 'hard_swish', 0.25, 6, None, False,
False),
('invertedbottleneck', 5, 1, 40, 'hard_swish', 0.25, 6, None, False,
False),
('invertedbottleneck', 5, 1, 48, 'hard_swish', 0.25, 3, None, False,
False),
('invertedbottleneck', 5, 1, 48, 'hard_swish', 0.25, 3, None, False,
True),
# Layers below are different from MobileNetV3Small and have
# half as many filters
('invertedbottleneck', 5, 2, 48, 'hard_swish', 0.25, 3, None, False,
False),
('invertedbottleneck', 5, 1, 48, 'hard_swish', 0.25, 6, None, False,
False),
('invertedbottleneck', 5, 1, 48, 'hard_swish', 0.25, 6, None, False,
True),
('convbn', 1, 1, 288, 'hard_swish', None, None, True, False, False),
('gpooling', None, None, None, None, None, None, None, None, False),
('convbn', 1, 1, 1024, 'hard_swish', None, None, False, True, False),
]
}
SUPPORTED_SPECS_MAP = {
'MobileNetV1': MNV1_BLOCK_SPECS,
'MobileNetV2': MNV2_BLOCK_SPECS,
......@@ -495,6 +534,7 @@ SUPPORTED_SPECS_MAP = {
'MobileNetMultiAVG': MNMultiAVG_BLOCK_SPECS,
'MobileNetMultiAVGSeg': MNMultiAVG_SEG_BLOCK_SPECS,
'MobileNetMultiMAXSeg': MNMultiMAX_SEG_BLOCK_SPECS,
'MobileNetV3SmallReducedFilters': MNV3SmallReducedFilters,
}
......
......@@ -38,6 +38,7 @@ class MobileNetTest(parameterized.TestCase, tf.test.TestCase):
'MobileNetMultiMAX',
'MobileNetMultiAVGSeg',
'MobileNetMultiMAXSeg',
'MobileNetV3SmallReducedFilters',
)
def test_serialize_deserialize(self, model_id):
# Create a network object that sets all of its config options.
......@@ -84,6 +85,7 @@ class MobileNetTest(parameterized.TestCase, tf.test.TestCase):
'MobileNetMultiMAX',
'MobileNetMultiAVGSeg',
'MobileNetMultiMAXSeg',
'MobileNetV3SmallReducedFilters',
],
))
def test_input_specs(self, input_dim, model_id):
......@@ -107,6 +109,7 @@ class MobileNetTest(parameterized.TestCase, tf.test.TestCase):
'MobileNetMultiAVG',
'MobileNetMultiMAX',
'MobileNetMultiAVGSeg',
'MobileNetV3SmallReducedFilters',
],
[32, 224],
))
......@@ -127,6 +130,7 @@ class MobileNetTest(parameterized.TestCase, tf.test.TestCase):
'MobileNetMultiAVG': [32, 64, 160, 192],
'MobileNetMultiAVGSeg': [32, 64, 160, 96],
'MobileNetMultiMAXSeg': [32, 64, 128, 96],
'MobileNetV3SmallReducedFilters': [16, 24, 48, 48],
}
network = mobilenet.MobileNet(model_id=model_id,
......@@ -152,6 +156,7 @@ class MobileNetTest(parameterized.TestCase, tf.test.TestCase):
'MobileNetMultiMAX',
'MobileNetMultiAVGSeg',
'MobileNetMultiMAXSeg',
'MobileNetV3SmallReducedFilters',
],
[32, 224],
))
......@@ -172,6 +177,7 @@ class MobileNetTest(parameterized.TestCase, tf.test.TestCase):
'MobileNetMultiAVG': [64, 192, 640, 768],
'MobileNetMultiAVGSeg': [64, 192, 640, 384],
'MobileNetMultiMAXSeg': [96, 128, 384, 320],
'MobileNetV3SmallReducedFilters': [16, 88, 144, 288],
}
network = mobilenet.MobileNet(model_id=model_id,
filter_size_scale=1.0,
......@@ -202,6 +208,7 @@ class MobileNetTest(parameterized.TestCase, tf.test.TestCase):
'MobileNetMultiMAX',
'MobileNetMultiAVGSeg',
'MobileNetMultiMAXSeg',
'MobileNetV3SmallReducedFilters',
],
[1.0, 0.75],
))
......@@ -227,6 +234,8 @@ class MobileNetTest(parameterized.TestCase, tf.test.TestCase):
('MobileNetMultiAVGSeg', 0.75): 1395272,
('MobileNetMultiMAXSeg', 1.0): 1929088,
('MobileNetMultiMAXSeg', 0.75): 1216544,
('MobileNetV3SmallReducedFilters', 1.0): 694880,
('MobileNetV3SmallReducedFilters', 0.75): 505960,
}
input_size = 224
......@@ -250,6 +259,7 @@ class MobileNetTest(parameterized.TestCase, tf.test.TestCase):
'MobileNetMultiMAX',
'MobileNetMultiAVGSeg',
'MobileNetMultiMAXSeg',
'MobileNetV3SmallReducedFilters',
],
[8, 16, 32],
))
......@@ -269,6 +279,7 @@ class MobileNetTest(parameterized.TestCase, tf.test.TestCase):
'MobileNetMultiAVG': 192,
'MobileNetMultiAVGSeg': 448,
'MobileNetMultiMAXSeg': 448,
'MobileNetV3SmallReducedFilters': 48,
}
network = mobilenet.MobileNet(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment