"docs/en_US/vscode:/vscode.git/clone" did not exist on "58d5c2faf0303751e432a4f99af19ac25e3065fb"
Commit 3836689f authored by Ningxin Zheng, committed by GitHub

issue 4540 (#4594)

parent 21abc280
@@ -171,10 +171,14 @@ class AutoMaskInference:
        # apply the input mask
        for tid, in_tensor in enumerate(self.dummy_input):
            if isinstance(in_tensor, torch.Tensor) and self.in_masks[tid] is not None:
-                in_tensor.data = in_tensor.data * \
-                    self.in_masks[tid] + \
-                    (1-self.in_masks[tid]) * self.in_constants[tid]
+                # in_tensor.data = in_tensor.data * \
+                #     self.in_masks[tid] + \
+                #     (1-self.in_masks[tid]) * self.in_constants[tid]
+                # issue-4540: when two tensors are multiplied, the constant part makes
+                # the propagation weaker and leads to shape misalignment. We currently
+                # do not support constant folding, so we just remove the constant here.
+                in_tensor.data = in_tensor.data * \
+                    self.in_masks[tid]

    def __apply_weight_mask(self):
        """
...
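For intuition: mask inference feeds masked dummy inputs through each node and treats positions that stay constant (zero) as prunable. If the masked positions instead carry non-zero in_constants values, the elementwise product of two such tensors is non-zero exactly where both masks are zero, so the dead-channel pattern does not survive the multiplication. A minimal standalone sketch of the effect (illustrative values only, not NNI's actual inference code):

import torch

# Two dummy activations with channel 0 pruned (mask = 0 in channel 0).
a = torch.randn(1, 2, 2, 2)
b = torch.randn(1, 2, 2, 2)
mask = torch.ones(1, 2, 2, 2)
mask[:, 0] = 0                          # channel 0 is masked out
const = torch.full((1, 2, 2, 2), 0.5)   # hypothetical in_constants value

# Old behaviour: masked positions are filled with a constant.
a_old = a * mask + (1 - mask) * const
b_old = b * mask + (1 - mask) * const
print((a_old * b_old)[:, 0].abs().max())  # 0.25 -> channel 0 looks "alive"

# New behaviour: masked positions stay zero.
a_new = a * mask
b_new = b * mask
print((a_new * b_new)[:, 0].abs().max())  # 0.0 -> channel 0 correctly dead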
@@ -163,7 +163,13 @@ class ChannelDependency(Dependency):
        parent_layers = []
        # find the node that contains aten::add
        # or aten::cat operations
-        if node.op_type in ADD_TYPES:
+        if node.op_type in ADD_TYPES or node.op_type in MUL_TYPES:
+            # Refer to issue 4540 for more details. Multiplication by itself
+            # does not introduce a channel dependency, because misaligned
+            # channels can propagate to each other. However, when one of the
+            # input tensors comes from a skip connection (residual), the channel
+            # propagation may fail (that input is also used by another layer and
+            # cannot be pruned); in this case we need to fix the conflict manually.
            parent_layers = self._get_parent_layers(node)
        elif node.op_type == CAT_TYPE:
            # To determine if this cat operation will introduce channel
...
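The pattern this comment targets is the squeeze-and-excitation style x * scale(x), where the tensor feeding the multiplication is also consumed by another branch and therefore cannot be repruned to match. A hedged sketch of how the dependency can be inspected, assuming ChannelDependency is importable from nni.compression.pytorch.utils.shape_dependency as in NNI 2.x (the SEBlock model here is illustrative, not from the repository):

import torch
from nni.compression.pytorch.utils.shape_dependency import ChannelDependency

class SEBlock(torch.nn.Module):
    """Minimal squeeze-excitation pattern: a residual tensor times a scale."""
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 8, 3)
        self.pool = torch.nn.AdaptiveAvgPool2d(1)
        self.fc = torch.nn.Conv2d(8, 8, 1)

    def forward(self, x):
        x = self.conv(x)
        scale = self.fc(self.pool(x))
        # x feeds both the fc branch and the multiplication, so pruning
        # conv's output channels must stay aligned with fc's output channels.
        return x * scale

model = SEBlock()
dummy = torch.rand(1, 3, 32, 32)
dep = ChannelDependency(model, dummy)
# After this change, conv and fc should land in one dependency group.
print(dep.dependency_sets)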
@@ -512,6 +512,46 @@ class SpeedupTestCase(TestCase):
        print("Fine-grained speeduped model")
        print(model)

+    def test_multiplication_speedup(self):
+        """
+        Model from issue 4540.
+        """
+        class Net(torch.nn.Module):
+            def __init__(self):
+                super(Net, self).__init__()
+                self.avgpool = torch.nn.AdaptiveAvgPool2d(1)
+                self.input = torch.nn.Conv2d(3, 8, 3)
+                self.bn = torch.nn.BatchNorm2d(8)
+                self.fc1 = torch.nn.Conv2d(8, 16, 1)
+                self.fc2 = torch.nn.Conv2d(16, 8, 1)
+                self.activation = torch.nn.ReLU()
+                self.scale_activation = torch.nn.Hardsigmoid()
+                self.out = torch.nn.Conv2d(8, 12, 1)
+
+            def forward(self, input):
+                input = self.activation(self.bn(self.input(input)))
+                scale = self.avgpool(input)
+                out1 = self.activation(self.fc1(scale))
+                out1 = self.scale_activation(self.fc2(out1))
+                return self.out(out1 * input)
+
+        model = Net().to(device)
+        model.eval()
+        im = torch.ones(1, 3, 512, 512).to(device)
+        model(im)
+        cfg_list = []
+        for name, module in model.named_modules():
+            if isinstance(module, torch.nn.Conv2d):
+                cfg_list.append({'op_types': ['Conv2d'], 'sparsity': 0.3, 'op_names': [name]})
+
+        pruner = L1FilterPruner(model, cfg_list)
+        pruner.compress()
+        pruner.export_model(MODEL_FILE, MASK_FILE)
+        pruner._unwrap_model()
+        ms = ModelSpeedup(model, im, MASK_FILE)
+        ms.speedup_model()
+
    def tearDown(self):
        if os.path.exists(MODEL_FILE):
            os.remove(MODEL_FILE)
...
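One possible sanity check to append after ms.speedup_model() (a hypothetical extension, not part of this commit): the compacted model should still run end to end with unchanged batch and spatial dimensions, while the channel count shrinks according to the masks:

out = model(im)
# Spatial geometry is preserved (512 -> 510 from the 3x3 conv, then 1x1 convs).
assert out.shape[0] == 1 and out.shape[2:] == torch.Size([510, 510])
# The 12 output channels of self.out are reduced by roughly the 0.3 sparsity.
assert out.shape[1] < 12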