Unverified Commit 321f39e7 authored by Jirka Borovec's avatar Jirka Borovec Committed by GitHub
Browse files
parent 3855901e
......@@ -167,5 +167,5 @@ class QuantizableGoogLeNet(GoogLeNet):
"""
for m in self.modules():
if type(m) == QuantizableBasicConv2d:
if type(m) is QuantizableBasicConv2d:
m.fuse_model()
......@@ -247,5 +247,5 @@ class QuantizableInception3(inception_module.Inception3):
"""
for m in self.modules():
if type(m) == QuantizableBasicConv2d:
if type(m) is QuantizableBasicConv2d:
m.fuse_model()
......@@ -30,7 +30,7 @@ class QuantizableInvertedResidual(InvertedResidual):
def fuse_model(self) -> None:
    """Fuse each Conv2d with the module that immediately follows it.

    Walks ``self.conv`` by index and, wherever the entry is exactly an
    ``nn.Conv2d`` (exact type via ``is``, not ``isinstance``, so subclasses
    are deliberately excluded), fuses it in place with the next entry
    (expected to be its BatchNorm) via ``fuse_modules``.

    NOTE(review): assumes every plain Conv2d in ``self.conv`` is followed
    by a fusable module — confirm against the block's construction.
    """
    for idx in range(len(self.conv)):
        # Diff artifact resolved: keep the corrected identity check
        # (`type(...) is nn.Conv2d`), dropping the duplicated `==` line.
        if type(self.conv[idx]) is nn.Conv2d:
            fuse_modules(self.conv, [str(idx), str(idx + 1)], inplace=True)
......@@ -54,9 +54,9 @@ class QuantizableMobileNetV2(MobileNetV2):
def fuse_model(self) -> None:
    """Fuse conv/bn/activation groups across the whole network.

    Iterates every submodule: each ``ConvNormActivation`` (matched by exact
    type with ``is`` so subclasses are excluded) has its children "0", "1",
    "2" fused in place; each ``QuantizableInvertedResidual`` delegates to
    its own ``fuse_model``.
    """
    for m in self.modules():
        # Diff artifact resolved: keep the corrected `is` identity checks,
        # dropping the duplicated `==` lines from the scraped diff.
        if type(m) is ConvNormActivation:
            fuse_modules(m, ["0", "1", "2"], inplace=True)
        if type(m) is QuantizableInvertedResidual:
            m.fuse_model()
......
......@@ -99,12 +99,12 @@ class QuantizableMobileNetV3(MobileNetV3):
def fuse_model(self) -> None:
    """Fuse conv/bn(/relu) groups and squeeze-excitation blocks.

    For each submodule that is exactly a ``ConvNormActivation`` (identity
    check with ``is`` so subclasses are excluded), fuses children "0" and
    "1", and also "2" when the block has three children ending in an exact
    ``nn.ReLU``. Each ``QuantizableSqueezeExcitation`` delegates to its own
    ``fuse_model``.
    """
    for m in self.modules():
        # Diff artifact resolved: keep the corrected `is` identity checks,
        # dropping the duplicated `==` lines from the scraped diff.
        if type(m) is ConvNormActivation:
            modules_to_fuse = ["0", "1"]
            # Only fuse the activation when it is a plain ReLU; other
            # activations (e.g. Hardswish) are not fusable here.
            if len(m) == 3 and type(m[2]) is nn.ReLU:
                modules_to_fuse.append("2")
            fuse_modules(m, modules_to_fuse, inplace=True)
        elif type(m) is QuantizableSqueezeExcitation:
            m.fuse_model()
......
......@@ -104,7 +104,7 @@ class QuantizableResNet(ResNet):
fuse_modules(self, ["conv1", "bn1", "relu"], inplace=True)
for m in self.modules():
if type(m) == QuantizableBottleneck or type(m) == QuantizableBasicBlock:
if type(m) is QuantizableBottleneck or type(m) is QuantizableBasicBlock:
m.fuse_model()
......
......@@ -68,7 +68,7 @@ class QuantizableShuffleNetV2(shufflenetv2.ShuffleNetV2):
if name in ["conv1", "conv5"]:
torch.quantization.fuse_modules(m, [["0", "1", "2"]], inplace=True)
for m in self.modules():
if type(m) == QuantizableInvertedResidual:
if type(m) is QuantizableInvertedResidual:
if len(m.branch1._modules.items()) > 0:
torch.quantization.fuse_modules(m.branch1, [["0", "1"], ["2", "3", "4"]], inplace=True)
torch.quantization.fuse_modules(
......
......@@ -9,7 +9,7 @@ def _replace_relu(module: nn.Module) -> None:
# Checking for explicit type instead of instance
# as we only want to replace modules of the exact type
# not inherited classes
if type(mod) == nn.ReLU or type(mod) == nn.ReLU6:
if type(mod) is nn.ReLU or type(mod) is nn.ReLU6:
reassign[name] = nn.ReLU(inplace=False)
for key, value in reassign.items():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment