Unverified Commit 321f39e7 authored by Jirka Borovec's avatar Jirka Borovec Committed by GitHub
Browse files
parent 3855901e
...@@ -167,5 +167,5 @@ class QuantizableGoogLeNet(GoogLeNet): ...@@ -167,5 +167,5 @@ class QuantizableGoogLeNet(GoogLeNet):
""" """
for m in self.modules(): for m in self.modules():
if type(m) == QuantizableBasicConv2d: if type(m) is QuantizableBasicConv2d:
m.fuse_model() m.fuse_model()
...@@ -247,5 +247,5 @@ class QuantizableInception3(inception_module.Inception3): ...@@ -247,5 +247,5 @@ class QuantizableInception3(inception_module.Inception3):
""" """
for m in self.modules(): for m in self.modules():
if type(m) == QuantizableBasicConv2d: if type(m) is QuantizableBasicConv2d:
m.fuse_model() m.fuse_model()
...@@ -30,7 +30,7 @@ class QuantizableInvertedResidual(InvertedResidual): ...@@ -30,7 +30,7 @@ class QuantizableInvertedResidual(InvertedResidual):
def fuse_model(self) -> None: def fuse_model(self) -> None:
for idx in range(len(self.conv)): for idx in range(len(self.conv)):
if type(self.conv[idx]) == nn.Conv2d: if type(self.conv[idx]) is nn.Conv2d:
fuse_modules(self.conv, [str(idx), str(idx + 1)], inplace=True) fuse_modules(self.conv, [str(idx), str(idx + 1)], inplace=True)
...@@ -54,9 +54,9 @@ class QuantizableMobileNetV2(MobileNetV2): ...@@ -54,9 +54,9 @@ class QuantizableMobileNetV2(MobileNetV2):
def fuse_model(self) -> None: def fuse_model(self) -> None:
for m in self.modules(): for m in self.modules():
if type(m) == ConvNormActivation: if type(m) is ConvNormActivation:
fuse_modules(m, ["0", "1", "2"], inplace=True) fuse_modules(m, ["0", "1", "2"], inplace=True)
if type(m) == QuantizableInvertedResidual: if type(m) is QuantizableInvertedResidual:
m.fuse_model() m.fuse_model()
......
...@@ -99,12 +99,12 @@ class QuantizableMobileNetV3(MobileNetV3): ...@@ -99,12 +99,12 @@ class QuantizableMobileNetV3(MobileNetV3):
def fuse_model(self) -> None: def fuse_model(self) -> None:
for m in self.modules(): for m in self.modules():
if type(m) == ConvNormActivation: if type(m) is ConvNormActivation:
modules_to_fuse = ["0", "1"] modules_to_fuse = ["0", "1"]
if len(m) == 3 and type(m[2]) == nn.ReLU: if len(m) == 3 and type(m[2]) is nn.ReLU:
modules_to_fuse.append("2") modules_to_fuse.append("2")
fuse_modules(m, modules_to_fuse, inplace=True) fuse_modules(m, modules_to_fuse, inplace=True)
elif type(m) == QuantizableSqueezeExcitation: elif type(m) is QuantizableSqueezeExcitation:
m.fuse_model() m.fuse_model()
......
...@@ -104,7 +104,7 @@ class QuantizableResNet(ResNet): ...@@ -104,7 +104,7 @@ class QuantizableResNet(ResNet):
fuse_modules(self, ["conv1", "bn1", "relu"], inplace=True) fuse_modules(self, ["conv1", "bn1", "relu"], inplace=True)
for m in self.modules(): for m in self.modules():
if type(m) == QuantizableBottleneck or type(m) == QuantizableBasicBlock: if type(m) is QuantizableBottleneck or type(m) is QuantizableBasicBlock:
m.fuse_model() m.fuse_model()
......
...@@ -68,7 +68,7 @@ class QuantizableShuffleNetV2(shufflenetv2.ShuffleNetV2): ...@@ -68,7 +68,7 @@ class QuantizableShuffleNetV2(shufflenetv2.ShuffleNetV2):
if name in ["conv1", "conv5"]: if name in ["conv1", "conv5"]:
torch.quantization.fuse_modules(m, [["0", "1", "2"]], inplace=True) torch.quantization.fuse_modules(m, [["0", "1", "2"]], inplace=True)
for m in self.modules(): for m in self.modules():
if type(m) == QuantizableInvertedResidual: if type(m) is QuantizableInvertedResidual:
if len(m.branch1._modules.items()) > 0: if len(m.branch1._modules.items()) > 0:
torch.quantization.fuse_modules(m.branch1, [["0", "1"], ["2", "3", "4"]], inplace=True) torch.quantization.fuse_modules(m.branch1, [["0", "1"], ["2", "3", "4"]], inplace=True)
torch.quantization.fuse_modules( torch.quantization.fuse_modules(
......
...@@ -9,7 +9,7 @@ def _replace_relu(module: nn.Module) -> None: ...@@ -9,7 +9,7 @@ def _replace_relu(module: nn.Module) -> None:
# Checking for explicit type instead of instance # Checking for explicit type instead of instance
# as we only want to replace modules of the exact type # as we only want to replace modules of the exact type
# not inherited classes # not inherited classes
if type(mod) == nn.ReLU or type(mod) == nn.ReLU6: if type(mod) is nn.ReLU or type(mod) is nn.ReLU6:
reassign[name] = nn.ReLU(inplace=False) reassign[name] = nn.ReLU(inplace=False)
for key, value in reassign.items(): for key, value in reassign.items():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment