Commit 060c10f0 authored by Matthew Yeung's avatar Matthew Yeung Committed by Francisco Massa
Browse files

allow user to define residual settings (#965)

* allow user to define residual settings

* 4spaces

* linting errors

* backward compatible, and added test
parent a0a93ff8
...@@ -71,6 +71,13 @@ class Tester(unittest.TestCase): ...@@ -71,6 +71,13 @@ class Tester(unittest.TestCase):
f = 2 ** sum(i) f = 2 ** sum(i)
self.assertEqual(out.shape, (1, 2048, 7 * f, 7 * f)) self.assertEqual(out.shape, (1, 2048, 7 * f, 7 * f))
def test_mobilenetv2_residual_setting(self):
model = models.__dict__["mobilenet_v2"](inverted_residual_setting=[[1, 16, 1, 1], [6, 24, 2, 2]])
model.eval()
x = torch.rand(1, 3, 224, 224)
out = model(x)
self.assertEqual(out.shape[-1], 1000)
for model_name in get_available_classification_models(): for model_name in get_available_classification_models():
# for-loop bodies don't define scopes, so we have to save the variables # for-loop bodies don't define scopes, so we have to save the variables
......
...@@ -50,11 +50,13 @@ class InvertedResidual(nn.Module): ...@@ -50,11 +50,13 @@ class InvertedResidual(nn.Module):
class MobileNetV2(nn.Module): class MobileNetV2(nn.Module):
def __init__(self, num_classes=1000, width_mult=1.0): def __init__(self, num_classes=1000, width_mult=1.0, inverted_residual_setting=None):
super(MobileNetV2, self).__init__() super(MobileNetV2, self).__init__()
block = InvertedResidual block = InvertedResidual
input_channel = 32 input_channel = 32
last_channel = 1280 last_channel = 1280
if inverted_residual_setting is None:
inverted_residual_setting = [ inverted_residual_setting = [
# t, c, n, s # t, c, n, s
[1, 16, 1, 1], [1, 16, 1, 1],
...@@ -66,6 +68,11 @@ class MobileNetV2(nn.Module): ...@@ -66,6 +68,11 @@ class MobileNetV2(nn.Module):
[6, 320, 1, 1], [6, 320, 1, 1],
] ]
# only check the first element, assuming user knows t,c,n,s are required
if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4:
raise ValueError("inverted_residual_setting should be non-empty "
"or a 4-element list, got {}".format(inverted_residual_setting))
# building first layer # building first layer
input_channel = int(input_channel * width_mult) input_channel = int(input_channel * width_mult)
self.last_channel = int(last_channel * max(1.0, width_mult)) self.last_channel = int(last_channel * max(1.0, width_mult))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment