Commit 967ef26c authored by ekka's avatar ekka Committed by Francisco Massa
Browse files

Remove dependency from functool in ShuffleNetsV2 (#916)

* Remove dependency from functool in ShuffleNetsV2

This PR removes the dependence of the ShuffleNetV2 code from `functool`.

* flake fix
parent ccd1b27d
import functools
import torch import torch
import torch.nn as nn import torch.nn as nn
from .utils import load_state_dict_from_url from .utils import load_state_dict_from_url
...@@ -45,26 +43,23 @@ class InvertedResidual(nn.Module): ...@@ -45,26 +43,23 @@ class InvertedResidual(nn.Module):
branch_features = oup // 2 branch_features = oup // 2
assert (self.stride != 1) or (inp == branch_features << 1) assert (self.stride != 1) or (inp == branch_features << 1)
pw_conv11 = functools.partial(nn.Conv2d, kernel_size=1, stride=1, padding=0, bias=False)
dw_conv33 = functools.partial(self.depthwise_conv,
kernel_size=3, stride=self.stride, padding=1)
if self.stride > 1: if self.stride > 1:
self.branch1 = nn.Sequential( self.branch1 = nn.Sequential(
dw_conv33(inp, inp), self.depthwise_conv(inp, inp, kernel_size=3, stride=self.stride, padding=1),
nn.BatchNorm2d(inp), nn.BatchNorm2d(inp),
pw_conv11(inp, branch_features), nn.Conv2d(inp, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
nn.BatchNorm2d(branch_features), nn.BatchNorm2d(branch_features),
nn.ReLU(inplace=True), nn.ReLU(inplace=True),
) )
self.branch2 = nn.Sequential( self.branch2 = nn.Sequential(
pw_conv11(inp if (self.stride > 1) else branch_features, branch_features), nn.Conv2d(inp if (self.stride > 1) else branch_features,
branch_features, kernel_size=1, stride=1, padding=0, bias=False),
nn.BatchNorm2d(branch_features), nn.BatchNorm2d(branch_features),
nn.ReLU(inplace=True), nn.ReLU(inplace=True),
dw_conv33(branch_features, branch_features), self.depthwise_conv(branch_features, branch_features, kernel_size=3, stride=self.stride, padding=1),
nn.BatchNorm2d(branch_features), nn.BatchNorm2d(branch_features),
pw_conv11(branch_features, branch_features), nn.Conv2d(branch_features, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
nn.BatchNorm2d(branch_features), nn.BatchNorm2d(branch_features),
nn.ReLU(inplace=True), nn.ReLU(inplace=True),
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment