Commit 7bd3c937 authored by Xianzhi Du's avatar Xianzhi Du Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 389999813
parent a2c19be8
......@@ -342,9 +342,10 @@ Berkin Akin, Suyog Gupta, and Andrew Howard
"""
MNMultiMAX_BLOCK_SPECS = {
'spec_name': 'MobileNetMultiMAX',
'block_spec_schema': ['block_fn', 'kernel_size', 'strides', 'filters',
'activation', 'expand_ratio',
'use_normalization', 'use_bias', 'is_output'],
'block_spec_schema': [
'block_fn', 'kernel_size', 'strides', 'filters', 'activation',
'expand_ratio', 'use_normalization', 'use_bias', 'is_output'
],
'block_specs': [
('convbn', 3, 2, 32, 'relu', None, True, False, False),
('invertedbottleneck', 3, 2, 32, 'relu', 3., None, False, True),
......@@ -363,15 +364,18 @@ MNMultiMAX_BLOCK_SPECS = {
('invertedbottleneck', 5, 1, 160, 'relu', 4., None, False, True),
('convbn', 1, 1, 960, 'relu', None, True, False, False),
('gpooling', None, None, None, None, None, None, None, False),
('convbn', 1, 1, 1280, 'relu', None, False, True, False),
# Remove bias and add batch norm for the last layer to support QAT
# and achieve slightly better accuracy.
('convbn', 1, 1, 1280, 'relu', None, True, False, False),
]
}
MNMultiAVG_BLOCK_SPECS = {
'spec_name': 'MobileNetMultiAVG',
'block_spec_schema': ['block_fn', 'kernel_size', 'strides', 'filters',
'activation', 'expand_ratio',
'use_normalization', 'use_bias', 'is_output'],
'block_spec_schema': [
'block_fn', 'kernel_size', 'strides', 'filters', 'activation',
'expand_ratio', 'use_normalization', 'use_bias', 'is_output'
],
'block_specs': [
('convbn', 3, 2, 32, 'relu', None, True, False, False),
('invertedbottleneck', 3, 2, 32, 'relu', 3., None, False, False),
......@@ -392,7 +396,9 @@ MNMultiAVG_BLOCK_SPECS = {
('invertedbottleneck', 5, 1, 192, 'relu', 4., None, False, True),
('convbn', 1, 1, 960, 'relu', None, True, False, False),
('gpooling', None, None, None, None, None, None, None, False),
('convbn', 1, 1, 1280, 'relu', None, False, True, False),
# Remove bias and add batch norm for the last layer to support QAT
# and achieve slightly better accuracy.
('convbn', 1, 1, 1280, 'relu', None, True, False, False),
]
}
......
......@@ -158,10 +158,10 @@ class MobileNetTest(parameterized.TestCase, tf.test.TestCase):
('MobileNetV3Small', 0.75): 1026552,
('MobileNetV3EdgeTPU', 1.0): 2849312,
('MobileNetV3EdgeTPU', 0.75): 1737288,
('MobileNetMultiAVG', 1.0): 3700576,
('MobileNetMultiAVG', 0.75): 2345864,
('MobileNetMultiMAX', 1.0): 3170720,
('MobileNetMultiMAX', 0.75): 2041976,
('MobileNetMultiAVG', 1.0): 3704416,
('MobileNetMultiAVG', 0.75): 2349704,
('MobileNetMultiMAX', 1.0): 3174560,
('MobileNetMultiMAX', 0.75): 2045816,
}
input_size = 224
......
......@@ -93,23 +93,6 @@ class ClassificationNetworkTest(parameterized.TestCase, tf.test.TestCase):
def test_mobilenet_network_creation(self, mobilenet_model_id,
filter_size_scale):
"""Test for creation of a MobileNet classifier."""
mobilenet_params = {
('MobileNetV1', 1.0): 4254889,
('MobileNetV1', 0.75): 2602745,
('MobileNetV2', 1.0): 3540265,
('MobileNetV2', 0.75): 2664345,
('MobileNetV3Large', 1.0): 5508713,
('MobileNetV3Large', 0.75): 4013897,
('MobileNetV3Small', 1.0): 2555993,
('MobileNetV3Small', 0.75): 2052577,
('MobileNetV3EdgeTPU', 1.0): 4131593,
('MobileNetV3EdgeTPU', 0.75): 3019569,
('MobileNetMultiAVG', 1.0): 4982857,
('MobileNetMultiAVG', 0.75): 3628145,
('MobileNetMultiMAX', 1.0): 4453001,
('MobileNetMultiMAX', 0.75): 3324257,
}
inputs = np.random.rand(2, 224, 224, 3)
tf.keras.backend.set_image_data_format('channels_last')
......@@ -123,8 +106,6 @@ class ClassificationNetworkTest(parameterized.TestCase, tf.test.TestCase):
num_classes=num_classes,
dropout_rate=0.2,
)
self.assertEqual(model.count_params(),
mobilenet_params[(mobilenet_model_id, filter_size_scale)])
logits = model(inputs)
self.assertAllEqual([2, num_classes], logits.numpy().shape)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment