Unverified Commit ae4012e2 authored by hx89's avatar hx89 Committed by GitHub
Browse files

fix inception (#1954)

parent de140c19
...@@ -138,14 +138,16 @@ class QuantizableInceptionD(inception_module.InceptionD): ...@@ -138,14 +138,16 @@ class QuantizableInceptionD(inception_module.InceptionD):
class QuantizableInceptionE(inception_module.InceptionE): class QuantizableInceptionE(inception_module.InceptionE):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super(QuantizableInceptionE, self).__init__(conv_block=QuantizableBasicConv2d, *args, **kwargs) super(QuantizableInceptionE, self).__init__(conv_block=QuantizableBasicConv2d, *args, **kwargs)
self.myop = nn.quantized.FloatFunctional() self.myop1 = nn.quantized.FloatFunctional()
self.myop2 = nn.quantized.FloatFunctional()
self.myop3 = nn.quantized.FloatFunctional()
def _forward(self, x): def _forward(self, x):
branch1x1 = self.branch1x1(x) branch1x1 = self.branch1x1(x)
branch3x3 = self.branch3x3_1(x) branch3x3 = self.branch3x3_1(x)
branch3x3 = [self.branch3x3_2a(branch3x3), self.branch3x3_2b(branch3x3)] branch3x3 = [self.branch3x3_2a(branch3x3), self.branch3x3_2b(branch3x3)]
branch3x3 = self.myop.cat(branch3x3, 1) branch3x3 = self.myop1.cat(branch3x3, 1)
branch3x3dbl = self.branch3x3dbl_1(x) branch3x3dbl = self.branch3x3dbl_1(x)
branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl) branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
...@@ -153,7 +155,7 @@ class QuantizableInceptionE(inception_module.InceptionE): ...@@ -153,7 +155,7 @@ class QuantizableInceptionE(inception_module.InceptionE):
self.branch3x3dbl_3a(branch3x3dbl), self.branch3x3dbl_3a(branch3x3dbl),
self.branch3x3dbl_3b(branch3x3dbl), self.branch3x3dbl_3b(branch3x3dbl),
] ]
branch3x3dbl = self.myop.cat(branch3x3dbl, 1) branch3x3dbl = self.myop2.cat(branch3x3dbl, 1)
branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1) branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
branch_pool = self.branch_pool(branch_pool) branch_pool = self.branch_pool(branch_pool)
...@@ -163,7 +165,7 @@ class QuantizableInceptionE(inception_module.InceptionE): ...@@ -163,7 +165,7 @@ class QuantizableInceptionE(inception_module.InceptionE):
def forward(self, x): def forward(self, x):
outputs = self._forward(x) outputs = self._forward(x)
return self.myop.cat(outputs, 1) return self.myop3.cat(outputs, 1)
class QuantizableInceptionAux(inception_module.InceptionAux): class QuantizableInceptionAux(inception_module.InceptionAux):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment