Unverified commit 88925314, authored by J-shang and committed by GitHub
Browse files

[Bugbash] fix bug in compression (#4259)

parent a55a5559
import sys
from tqdm import tqdm
import torch
......@@ -5,7 +6,8 @@ from torchvision import datasets, transforms
from nni.algorithms.compression.v2.pytorch.pruning import AGPPruner
from examples.model_compress.models.cifar10.vgg import VGG
sys.path.append('../../models')
from cifar10.vgg import VGG
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
......
import sys
from tqdm import tqdm
import torch
......@@ -7,7 +8,8 @@ from nni.algorithms.compression.v2.pytorch.pruning import L1NormPruner
from nni.algorithms.compression.v2.pytorch.pruning.tools import AGPTaskGenerator
from nni.algorithms.compression.v2.pytorch.pruning.basic_scheduler import PruningScheduler
from examples.model_compress.models.cifar10.vgg import VGG
sys.path.append('../../models')
from cifar10.vgg import VGG
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
......
import sys
from tqdm import tqdm
import torch
......@@ -6,7 +7,8 @@ from torchvision import datasets, transforms
from nni.algorithms.compression.v2.pytorch.pruning import L1NormPruner
from nni.compression.pytorch.speedup import ModelSpeedup
from examples.model_compress.models.cifar10.vgg import VGG
sys.path.append('../../models')
from cifar10.vgg import VGG
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
......@@ -72,7 +74,7 @@ if __name__ == '__main__':
evaluator(model)
pruner._unwrap_model()
ModelSpeedup(model, dummy_input=torch.rand(10, 3, 32, 32).to(device), masks_file='simple_masks.pth').speedup_model()
ModelSpeedup(model, dummy_input=torch.rand(10, 3, 32, 32).to(device), masks_file=masks).speedup_model()
print('\nThe accuracy after speed up:')
evaluator(model)
......
......@@ -384,6 +384,7 @@ class ADMMPruner(IterativePruner):
for i, wrapper in enumerate(self.get_modules_wrapper()):
z = wrapper.module.weight.data + self.U[i]
self.Z[i] = self._projection(z, wrapper.config['sparsity'], wrapper)
torch.cuda.empty_cache()
self.U[i] = self.U[i] + wrapper.module.weight.data - self.Z[i]
# apply prune
......
......@@ -110,6 +110,8 @@ def replace_prelu(prelu, masks):
in_mask = in_masks[0]
weight_mask = weight_mask['weight']
if weight_mask.size(0) == 1:
return prelu
pruned_in, remained_in = convert_to_coarse_mask(in_mask, 1)
pruned_out, remained_out = convert_to_coarse_mask(output_mask, 1)
n_remained_in = weight_mask.size(0) - pruned_in.size(0)
......@@ -221,8 +223,9 @@ def replace_batchnorm1d(norm, masks):
affine=norm.affine,
track_running_stats=norm.track_running_stats)
# assign weights
new_norm.weight.data = torch.index_select(norm.weight.data, 0, remained_in)
new_norm.bias.data = torch.index_select(norm.bias.data, 0, remained_in)
if norm.affine:
new_norm.weight.data = torch.index_select(norm.weight.data, 0, remained_in)
new_norm.bias.data = torch.index_select(norm.bias.data, 0, remained_in)
new_norm.running_mean.data = torch.index_select(
norm.running_mean.data, 0, remained_in)
......@@ -264,8 +267,9 @@ def replace_batchnorm2d(norm, masks):
affine=norm.affine,
track_running_stats=norm.track_running_stats)
# assign weights
new_norm.weight.data = torch.index_select(norm.weight.data, 0, remained_in)
new_norm.bias.data = torch.index_select(norm.bias.data, 0, remained_in)
if norm.affine:
new_norm.weight.data = torch.index_select(norm.weight.data, 0, remained_in)
new_norm.bias.data = torch.index_select(norm.bias.data, 0, remained_in)
new_norm.running_mean.data = torch.index_select(
norm.running_mean.data, 0, remained_in)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment