"...resnet50_tensorflow.git" did not exist on "b668f5948f3aa6d472ff74d81dea5c2770d4795d"
Commit 7a9c72ad authored by Shixin Luo's avatar Shixin Luo
Browse files

fix the last layer of mobilenet v2 and add testing for complete classification...

fix the last layer of mobilenet v2 and add testing for complete classification with mobilenet backbone
parent 890fc1c9
...@@ -232,7 +232,7 @@ def image_classification_imagenet_mobilenet() -> cfg.ExperimentConfig: ...@@ -232,7 +232,7 @@ def image_classification_imagenet_mobilenet() -> cfg.ExperimentConfig:
mobilenet=backbones.MobileNet( mobilenet=backbones.MobileNet(
model_id='MobileNetV2', width_multiplier=1.0)), model_id='MobileNetV2', width_multiplier=1.0)),
norm_activation=common.NormActivation( norm_activation=common.NormActivation(
norm_momentum=0.9997, norm_epsilon=1e-3)), norm_momentum=0.997, norm_epsilon=1e-3)),
losses=Losses(l2_weight_decay=2e-5, label_smoothing=0.1), losses=Losses(l2_weight_decay=2e-5, label_smoothing=0.1),
train_data=DataConfig( train_data=DataConfig(
input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'train*'), input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'train*'),
......
...@@ -197,7 +197,7 @@ MNV2_BLOCK_SPECS = { ...@@ -197,7 +197,7 @@ MNV2_BLOCK_SPECS = {
('invertedbottleneck', 3, 1, 160, 6.), ('invertedbottleneck', 3, 1, 160, 6.),
('invertedbottleneck', 3, 1, 160, 6.), ('invertedbottleneck', 3, 1, 160, 6.),
('invertedbottleneck', 3, 1, 320, 6.), ('invertedbottleneck', 3, 1, 320, 6.),
('convbn', 1, 2, 1280, None), ('convbn', 1, 1, 1280, None),
] ]
} }
......
...@@ -77,6 +77,52 @@ class ClassificationNetworkTest(parameterized.TestCase, tf.test.TestCase): ...@@ -77,6 +77,52 @@ class ClassificationNetworkTest(parameterized.TestCase, tf.test.TestCase):
logits = model(inputs) logits = model(inputs)
self.assertAllEqual([2, num_classes], logits.numpy().shape) self.assertAllEqual([2, num_classes], logits.numpy().shape)
@combinations.generate(
combinations.combine(
mobilenet_model_id=[
'MobileNetV1',
'MobileNetV2',
'MobileNetV3Large',
'MobileNetV3Small',
'MobileNetV3EdgeTPU'
],
width_multiplier=[1.0, 0.75],
))
def test_mobilenet_network_creation(
self, mobilenet_model_id, width_multiplier):
"""Test for creation of a MobileNet classifier."""
mobilenet_params = {
('MobileNetV1', 1.0): 4254889,
('MobileNetV1', 0.75): 2602745,
('MobileNetV2', 1.0): 3540265,
('MobileNetV2', 0.75): 2664345,
('MobileNetV3Large', 1.0): 5508713,
('MobileNetV3Large', 0.75): 4013897,
('MobileNetV3Small', 1.0): 2555993,
('MobileNetV3Small', 0.75): 2052577,
('MobileNetV3EdgeTPU', 1.0): 4131593,
('MobileNetV3EdgeTPU', 0.75): 3019569,
}
inputs = np.random.rand(2, 224, 224, 3)
tf.keras.backend.set_image_data_format('channels_last')
backbone = backbones.MobileNet(
model_id=mobilenet_model_id, width_multiplier=width_multiplier)
num_classes = 1001
model = classification_model.ClassificationModel(
backbone=backbone,
num_classes=num_classes,
dropout_rate=0.2,
)
self.assertEqual(model.count_params(), mobilenet_params[
(mobilenet_model_id, width_multiplier)])
logits = model(inputs)
self.assertAllEqual([2, num_classes], logits.numpy().shape)
@combinations.generate( @combinations.generate(
combinations.combine( combinations.combine(
strategy=[ strategy=[
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment