Commit 164a2657 authored by ekka's avatar ekka Committed by Francisco Massa
Browse files

Remove `input_size` argument from ShuffleNetV2 (#886)

* remove 'input_size' parameter from shufflenetv2

* Update shufflenetv2.py
parent d5347856
...@@ -85,7 +85,7 @@ class InvertedResidual(nn.Module): ...@@ -85,7 +85,7 @@ class InvertedResidual(nn.Module):
class ShuffleNetV2(nn.Module): class ShuffleNetV2(nn.Module):
def __init__(self, num_classes=1000, input_size=224, width_mult=1): def __init__(self, num_classes=1000, width_mult=1):
super(ShuffleNetV2, self).__init__() super(ShuffleNetV2, self).__init__()
try: try:
...@@ -145,8 +145,8 @@ class ShuffleNetV2(nn.Module): ...@@ -145,8 +145,8 @@ class ShuffleNetV2(nn.Module):
return stages[str(mult)] return stages[str(mult)]
def shufflenetv2(pretrained=False, num_classes=1000, input_size=224, width_mult=1, **kwargs): def shufflenetv2(pretrained=False, num_classes=1000, width_mult=1, **kwargs):
model = ShuffleNetV2(num_classes=num_classes, input_size=input_size, width_mult=width_mult) model = ShuffleNetV2(num_classes=num_classes, width_mult=width_mult)
if pretrained: if pretrained:
# change width_mult to float # change width_mult to float
...@@ -164,17 +164,17 @@ def shufflenetv2(pretrained=False, num_classes=1000, input_size=224, width_mult= ...@@ -164,17 +164,17 @@ def shufflenetv2(pretrained=False, num_classes=1000, input_size=224, width_mult=
return model return model
def shufflenetv2_x0_5(pretrained=False, num_classes=1000, input_size=224, **kwargs): def shufflenetv2_x0_5(pretrained=False, num_classes=1000, **kwargs):
return shufflenetv2(pretrained, num_classes, input_size, 0.5) return shufflenetv2(pretrained, num_classes, 0.5)
def shufflenetv2_x1_0(pretrained=False, num_classes=1000, input_size=224, **kwargs): def shufflenetv2_x1_0(pretrained=False, num_classes=1000, **kwargs):
return shufflenetv2(pretrained, num_classes, input_size, 1) return shufflenetv2(pretrained, num_classes, 1)
def shufflenetv2_x1_5(pretrained=False, num_classes=1000, input_size=224, **kwargs): def shufflenetv2_x1_5(pretrained=False, num_classes=1000, **kwargs):
return shufflenetv2(pretrained, num_classes, input_size, 1.5) return shufflenetv2(pretrained, num_classes, 1.5)
def shufflenetv2_x2_0(pretrained=False, num_classes=1000, input_size=224, **kwargs): def shufflenetv2_x2_0(pretrained=False, num_classes=1000, **kwargs):
return shufflenetv2(pretrained, num_classes, input_size, 2) return shufflenetv2(pretrained, num_classes, 2)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment