Commit 060c10f0 authored by Matthew Yeung's avatar Matthew Yeung Committed by Francisco Massa
Browse files

allow user to define residual settings (#965)

* allow user to define residual settings

* 4spaces

* linting errors

* backward compatible, and added test
parent a0a93ff8
......@@ -71,6 +71,13 @@ class Tester(unittest.TestCase):
f = 2 ** sum(i)
self.assertEqual(out.shape, (1, 2048, 7 * f, 7 * f))
def test_mobilenetv2_residual_setting(self):
model = models.__dict__["mobilenet_v2"](inverted_residual_setting=[[1, 16, 1, 1], [6, 24, 2, 2]])
model.eval()
x = torch.rand(1, 3, 224, 224)
out = model(x)
self.assertEqual(out.shape[-1], 1000)
for model_name in get_available_classification_models():
# for-loop bodies don't define scopes, so we have to save the variables
......
......@@ -50,21 +50,28 @@ class InvertedResidual(nn.Module):
class MobileNetV2(nn.Module):
def __init__(self, num_classes=1000, width_mult=1.0):
def __init__(self, num_classes=1000, width_mult=1.0, inverted_residual_setting=None):
super(MobileNetV2, self).__init__()
block = InvertedResidual
input_channel = 32
last_channel = 1280
inverted_residual_setting = [
# t, c, n, s
[1, 16, 1, 1],
[6, 24, 2, 2],
[6, 32, 3, 2],
[6, 64, 4, 2],
[6, 96, 3, 1],
[6, 160, 3, 2],
[6, 320, 1, 1],
]
if inverted_residual_setting is None:
inverted_residual_setting = [
# t, c, n, s
[1, 16, 1, 1],
[6, 24, 2, 2],
[6, 32, 3, 2],
[6, 64, 4, 2],
[6, 96, 3, 1],
[6, 160, 3, 2],
[6, 320, 1, 1],
]
# only check the first element, assuming user knows t,c,n,s are required
if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4:
raise ValueError("inverted_residual_setting should be non-empty "
"or a 4-element list, got {}".format(inverted_residual_setting))
# building first layer
input_channel = int(input_channel * width_mult)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment