Commit 7acb972a authored by Yuqi Li's avatar Yuqi Li Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 413820673
parent 78d99a22
...@@ -420,7 +420,8 @@ MNMultiAVG_BLOCK_SPECS = { ...@@ -420,7 +420,8 @@ MNMultiAVG_BLOCK_SPECS = {
# Similar to MobileNetMultiAVG and used for segmentation task. # Similar to MobileNetMultiAVG and used for segmentation task.
# Reduced the filters by a factor of 2 in the last block. # Reduced the filters by a factor of 2 in the last block.
MNMultiAVG_SEG_BLOCK_SPECS = { MNMultiAVG_SEG_BLOCK_SPECS = {
'spec_name': 'MobileNetMultiAVGSeg', 'spec_name':
'MobileNetMultiAVGSeg',
'block_spec_schema': [ 'block_spec_schema': [
'block_fn', 'kernel_size', 'strides', 'filters', 'activation', 'block_fn', 'kernel_size', 'strides', 'filters', 'activation',
'expand_ratio', 'use_normalization', 'use_bias', 'is_output' 'expand_ratio', 'use_normalization', 'use_bias', 'is_output'
...@@ -443,7 +444,40 @@ MNMultiAVG_SEG_BLOCK_SPECS = { ...@@ -443,7 +444,40 @@ MNMultiAVG_SEG_BLOCK_SPECS = {
('invertedbottleneck', 5, 1, 96, 'relu', 2., True, False, False), ('invertedbottleneck', 5, 1, 96, 'relu', 2., True, False, False),
('invertedbottleneck', 5, 1, 96, 'relu', 4., True, False, False), ('invertedbottleneck', 5, 1, 96, 'relu', 4., True, False, False),
('invertedbottleneck', 5, 1, 96, 'relu', 4., True, False, True), ('invertedbottleneck', 5, 1, 96, 'relu', 4., True, False, True),
('convbn', 1, 1, 480, 'relu', None, True, False, False), ('convbn', 1, 1, 448, 'relu', None, True, False, True),
('gpooling', None, None, None, None, None, None, None, False),
# Remove bias and add batch norm for the last layer to support QAT
# and achieve slightly better accuracy.
('convbn', 1, 1, 1280, 'relu', None, True, False, False),
]
}
# Similar to MobileNetMultiMax and used for segmentation task.
# Reduced the filters by a factor of 2 in the last block.
MNMultiMAX_SEG_BLOCK_SPECS = {
'spec_name':
'MobileNetMultiMAXSeg',
'block_spec_schema': [
'block_fn', 'kernel_size', 'strides', 'filters', 'activation',
'expand_ratio', 'use_normalization', 'use_bias', 'is_output'
],
'block_specs': [
('convbn', 3, 2, 32, 'relu', None, True, False, False),
('invertedbottleneck', 3, 2, 32, 'relu', 3., True, False, True),
('invertedbottleneck', 5, 2, 64, 'relu', 6., True, False, False),
('invertedbottleneck', 3, 1, 64, 'relu', 2., True, False, False),
('invertedbottleneck', 3, 1, 64, 'relu', 2., True, False, True),
('invertedbottleneck', 5, 2, 128, 'relu', 6., True, False, False),
('invertedbottleneck', 3, 1, 128, 'relu', 4., True, False, False),
('invertedbottleneck', 3, 1, 128, 'relu', 3., True, False, False),
('invertedbottleneck', 3, 1, 128, 'relu', 3., True, False, False),
('invertedbottleneck', 3, 1, 128, 'relu', 6., True, False, False),
('invertedbottleneck', 3, 1, 128, 'relu', 3., True, False, True),
('invertedbottleneck', 3, 2, 160, 'relu', 6., True, False, False),
('invertedbottleneck', 5, 1, 96, 'relu', 2., True, False, False),
('invertedbottleneck', 3, 1, 96, 'relu', 4., True, False, False),
('invertedbottleneck', 5, 1, 96, 'relu', 320.0 / 96, True, False, True),
('convbn', 1, 1, 448, 'relu', None, True, False, True),
('gpooling', None, None, None, None, None, None, None, False), ('gpooling', None, None, None, None, None, None, None, False),
# Remove bias and add batch norm for the last layer to support QAT # Remove bias and add batch norm for the last layer to support QAT
# and achieve slightly better accuracy. # and achieve slightly better accuracy.
...@@ -460,6 +494,7 @@ SUPPORTED_SPECS_MAP = { ...@@ -460,6 +494,7 @@ SUPPORTED_SPECS_MAP = {
'MobileNetMultiMAX': MNMultiMAX_BLOCK_SPECS, 'MobileNetMultiMAX': MNMultiMAX_BLOCK_SPECS,
'MobileNetMultiAVG': MNMultiAVG_BLOCK_SPECS, 'MobileNetMultiAVG': MNMultiAVG_BLOCK_SPECS,
'MobileNetMultiAVGSeg': MNMultiAVG_SEG_BLOCK_SPECS, 'MobileNetMultiAVGSeg': MNMultiAVG_SEG_BLOCK_SPECS,
'MobileNetMultiMAXSeg': MNMultiMAX_SEG_BLOCK_SPECS,
} }
......
...@@ -37,6 +37,7 @@ class MobileNetTest(parameterized.TestCase, tf.test.TestCase): ...@@ -37,6 +37,7 @@ class MobileNetTest(parameterized.TestCase, tf.test.TestCase):
'MobileNetMultiAVG', 'MobileNetMultiAVG',
'MobileNetMultiMAX', 'MobileNetMultiMAX',
'MobileNetMultiAVGSeg', 'MobileNetMultiAVGSeg',
'MobileNetMultiMAXSeg',
) )
def test_serialize_deserialize(self, model_id): def test_serialize_deserialize(self, model_id):
# Create a network object that sets all of its config options. # Create a network object that sets all of its config options.
...@@ -82,6 +83,7 @@ class MobileNetTest(parameterized.TestCase, tf.test.TestCase): ...@@ -82,6 +83,7 @@ class MobileNetTest(parameterized.TestCase, tf.test.TestCase):
'MobileNetMultiAVG', 'MobileNetMultiAVG',
'MobileNetMultiMAX', 'MobileNetMultiMAX',
'MobileNetMultiAVGSeg', 'MobileNetMultiAVGSeg',
'MobileNetMultiMAXSeg',
], ],
)) ))
def test_input_specs(self, input_dim, model_id): def test_input_specs(self, input_dim, model_id):
...@@ -124,6 +126,7 @@ class MobileNetTest(parameterized.TestCase, tf.test.TestCase): ...@@ -124,6 +126,7 @@ class MobileNetTest(parameterized.TestCase, tf.test.TestCase):
'MobileNetMultiMAX': [32, 64, 128, 160], 'MobileNetMultiMAX': [32, 64, 128, 160],
'MobileNetMultiAVG': [32, 64, 160, 192], 'MobileNetMultiAVG': [32, 64, 160, 192],
'MobileNetMultiAVGSeg': [32, 64, 160, 96], 'MobileNetMultiAVGSeg': [32, 64, 160, 96],
'MobileNetMultiMAXSeg': [32, 64, 128, 96],
} }
network = mobilenet.MobileNet(model_id=model_id, network = mobilenet.MobileNet(model_id=model_id,
...@@ -148,6 +151,7 @@ class MobileNetTest(parameterized.TestCase, tf.test.TestCase): ...@@ -148,6 +151,7 @@ class MobileNetTest(parameterized.TestCase, tf.test.TestCase):
'MobileNetMultiAVG', 'MobileNetMultiAVG',
'MobileNetMultiMAX', 'MobileNetMultiMAX',
'MobileNetMultiAVGSeg', 'MobileNetMultiAVGSeg',
'MobileNetMultiMAXSeg',
], ],
[32, 224], [32, 224],
)) ))
...@@ -167,6 +171,7 @@ class MobileNetTest(parameterized.TestCase, tf.test.TestCase): ...@@ -167,6 +171,7 @@ class MobileNetTest(parameterized.TestCase, tf.test.TestCase):
'MobileNetMultiMAX': [96, 128, 384, 640], 'MobileNetMultiMAX': [96, 128, 384, 640],
'MobileNetMultiAVG': [64, 192, 640, 768], 'MobileNetMultiAVG': [64, 192, 640, 768],
'MobileNetMultiAVGSeg': [64, 192, 640, 384], 'MobileNetMultiAVGSeg': [64, 192, 640, 384],
'MobileNetMultiMAXSeg': [96, 128, 384, 320],
} }
network = mobilenet.MobileNet(model_id=model_id, network = mobilenet.MobileNet(model_id=model_id,
filter_size_scale=1.0, filter_size_scale=1.0,
...@@ -196,6 +201,7 @@ class MobileNetTest(parameterized.TestCase, tf.test.TestCase): ...@@ -196,6 +201,7 @@ class MobileNetTest(parameterized.TestCase, tf.test.TestCase):
'MobileNetMultiMAX', 'MobileNetMultiMAX',
'MobileNetMultiMAX', 'MobileNetMultiMAX',
'MobileNetMultiAVGSeg', 'MobileNetMultiAVGSeg',
'MobileNetMultiMAXSeg',
], ],
[1.0, 0.75], [1.0, 0.75],
)) ))
...@@ -217,8 +223,10 @@ class MobileNetTest(parameterized.TestCase, tf.test.TestCase): ...@@ -217,8 +223,10 @@ class MobileNetTest(parameterized.TestCase, tf.test.TestCase):
('MobileNetMultiAVG', 0.75): 2349704, ('MobileNetMultiAVG', 0.75): 2349704,
('MobileNetMultiMAX', 1.0): 3174560, ('MobileNetMultiMAX', 1.0): 3174560,
('MobileNetMultiMAX', 0.75): 2045816, ('MobileNetMultiMAX', 0.75): 2045816,
('MobileNetMultiAVGSeg', 1.0): 2284000, ('MobileNetMultiAVGSeg', 1.0): 2239840,
('MobileNetMultiAVGSeg', 0.75): 1427816, ('MobileNetMultiAVGSeg', 0.75): 1395272,
('MobileNetMultiMAXSeg', 1.0): 1929088,
('MobileNetMultiMAXSeg', 0.75): 1216544,
} }
input_size = 224 input_size = 224
...@@ -241,6 +249,7 @@ class MobileNetTest(parameterized.TestCase, tf.test.TestCase): ...@@ -241,6 +249,7 @@ class MobileNetTest(parameterized.TestCase, tf.test.TestCase):
'MobileNetMultiAVG', 'MobileNetMultiAVG',
'MobileNetMultiMAX', 'MobileNetMultiMAX',
'MobileNetMultiAVGSeg', 'MobileNetMultiAVGSeg',
'MobileNetMultiMAXSeg',
], ],
[8, 16, 32], [8, 16, 32],
)) ))
...@@ -258,7 +267,8 @@ class MobileNetTest(parameterized.TestCase, tf.test.TestCase): ...@@ -258,7 +267,8 @@ class MobileNetTest(parameterized.TestCase, tf.test.TestCase):
'MobileNetV3EdgeTPU': 192, 'MobileNetV3EdgeTPU': 192,
'MobileNetMultiMAX': 160, 'MobileNetMultiMAX': 160,
'MobileNetMultiAVG': 192, 'MobileNetMultiAVG': 192,
'MobileNetMultiAVGSeg': 96, 'MobileNetMultiAVGSeg': 448,
'MobileNetMultiMAXSeg': 448,
} }
network = mobilenet.MobileNet( network = mobilenet.MobileNet(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment