Unverified Commit 92f6754e authored by colorjam's avatar colorjam Committed by GitHub
Browse files

[Model Compression] Update api of iterative pruners (#3507)

parent 26f47727
...@@ -31,7 +31,7 @@ def fix_mask_conflict(masks, model=None, dummy_input=None, traced=None): ...@@ -31,7 +31,7 @@ def fix_mask_conflict(masks, model=None, dummy_input=None, traced=None):
# if the input is the path of the mask_file # if the input is the path of the mask_file
assert os.path.exists(masks) assert os.path.exists(masks)
masks = torch.load(masks) masks = torch.load(masks)
assert len(masks) > 0, 'Mask tensor cannot be empty' assert len(masks) > 0, 'Mask tensor cannot be empty'
# if the user uses the model and dummy_input to trace the model, we # if the user uses the model and dummy_input to trace the model, we
# should get the traced model handly, so that, we only trace the # should get the traced model handly, so that, we only trace the
# model once, GroupMaskConflict and ChannelMaskConflict will reuse # model once, GroupMaskConflict and ChannelMaskConflict will reuse
...@@ -181,10 +181,8 @@ class GroupMaskConflict(MaskFix): ...@@ -181,10 +181,8 @@ class GroupMaskConflict(MaskFix):
w_mask = self.masks[layername]['weight'] w_mask = self.masks[layername]['weight']
shape = w_mask.size() shape = w_mask.size()
count = np.prod(shape[1:]) count = np.prod(shape[1:])
all_ones = (w_mask.flatten(1).sum(-1) == all_ones = (w_mask.flatten(1).sum(-1) == count).nonzero().squeeze(1).tolist()
count).nonzero().squeeze(1).tolist() all_zeros = (w_mask.flatten(1).sum(-1) == 0).nonzero().squeeze(1).tolist()
all_zeros = (w_mask.flatten(1).sum(-1) ==
0).nonzero().squeeze(1).tolist()
if len(all_ones) + len(all_zeros) < w_mask.size(0): if len(all_ones) + len(all_zeros) < w_mask.size(0):
# In fine-grained pruning, skip this layer # In fine-grained pruning, skip this layer
_logger.info('Layers %s using fine-grained pruning', layername) _logger.info('Layers %s using fine-grained pruning', layername)
...@@ -198,7 +196,7 @@ class GroupMaskConflict(MaskFix): ...@@ -198,7 +196,7 @@ class GroupMaskConflict(MaskFix):
group_masked = [] group_masked = []
for i in range(group): for i in range(group):
_start = step * i _start = step * i
_end = step * (i+1) _end = step * (i + 1)
_tmp_list = list( _tmp_list = list(
filter(lambda x: _start <= x and x < _end, all_zeros)) filter(lambda x: _start <= x and x < _end, all_zeros))
group_masked.append(_tmp_list) group_masked.append(_tmp_list)
...@@ -286,7 +284,7 @@ class ChannelMaskConflict(MaskFix): ...@@ -286,7 +284,7 @@ class ChannelMaskConflict(MaskFix):
0, 2, 3) if self.conv_prune_dim == 0 else (1, 2, 3) 0, 2, 3) if self.conv_prune_dim == 0 else (1, 2, 3)
channel_mask = (mask.abs().sum(tmp_sum_idx) != 0).int() channel_mask = (mask.abs().sum(tmp_sum_idx) != 0).int()
channel_masks.append(channel_mask) channel_masks.append(channel_mask)
if (channel_mask.sum() * (mask.numel() / mask.shape[1-self.conv_prune_dim])).item() != (mask > 0).sum().item(): if (channel_mask.sum() * (mask.numel() / mask.shape[1 - self.conv_prune_dim])).item() != (mask > 0).sum().item():
fine_grained = True fine_grained = True
else: else:
raise RuntimeError( raise RuntimeError(
......
...@@ -61,9 +61,8 @@ class CompressorTestCase(TestCase): ...@@ -61,9 +61,8 @@ class CompressorTestCase(TestCase):
def test_torch_level_pruner(self): def test_torch_level_pruner(self):
model = TorchModel() model = TorchModel()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
configure_list = [{'sparsity': 0.8, 'op_types': ['default']}] configure_list = [{'sparsity': 0.8, 'op_types': ['default']}]
torch_pruner.LevelPruner(model, configure_list, optimizer).compress() torch_pruner.LevelPruner(model, configure_list).compress()
def test_torch_naive_quantizer(self): def test_torch_naive_quantizer(self):
model = TorchModel() model = TorchModel()
...@@ -93,7 +92,7 @@ class CompressorTestCase(TestCase): ...@@ -93,7 +92,7 @@ class CompressorTestCase(TestCase):
model = TorchModel() model = TorchModel()
config_list = [{'sparsity': 0.6, 'op_types': ['Conv2d']}, {'sparsity': 0.2, 'op_types': ['Conv2d']}] config_list = [{'sparsity': 0.6, 'op_types': ['Conv2d']}, {'sparsity': 0.2, 'op_types': ['Conv2d']}]
pruner = torch_pruner.FPGMPruner(model, config_list, torch.optim.SGD(model.parameters(), lr=0.01)) pruner = torch_pruner.FPGMPruner(model, config_list)
model.conv2.module.weight.data = torch.tensor(w).float() model.conv2.module.weight.data = torch.tensor(w).float()
masks = pruner.calc_mask(model.conv2) masks = pruner.calc_mask(model.conv2)
...@@ -152,7 +151,7 @@ class CompressorTestCase(TestCase): ...@@ -152,7 +151,7 @@ class CompressorTestCase(TestCase):
config_list = [{'sparsity': 0.2, 'op_types': ['BatchNorm2d']}] config_list = [{'sparsity': 0.2, 'op_types': ['BatchNorm2d']}]
model.bn1.weight.data = torch.tensor(w).float() model.bn1.weight.data = torch.tensor(w).float()
model.bn2.weight.data = torch.tensor(-w).float() model.bn2.weight.data = torch.tensor(-w).float()
pruner = torch_pruner.SlimPruner(model, config_list) pruner = torch_pruner.SlimPruner(model, config_list, optimizer=None, trainer=None, criterion=None)
mask1 = pruner.calc_mask(model.bn1) mask1 = pruner.calc_mask(model.bn1)
mask2 = pruner.calc_mask(model.bn2) mask2 = pruner.calc_mask(model.bn2)
...@@ -165,7 +164,7 @@ class CompressorTestCase(TestCase): ...@@ -165,7 +164,7 @@ class CompressorTestCase(TestCase):
config_list = [{'sparsity': 0.6, 'op_types': ['BatchNorm2d']}] config_list = [{'sparsity': 0.6, 'op_types': ['BatchNorm2d']}]
model.bn1.weight.data = torch.tensor(w).float() model.bn1.weight.data = torch.tensor(w).float()
model.bn2.weight.data = torch.tensor(w).float() model.bn2.weight.data = torch.tensor(w).float()
pruner = torch_pruner.SlimPruner(model, config_list) pruner = torch_pruner.SlimPruner(model, config_list, optimizer=None, trainer=None, criterion=None)
mask1 = pruner.calc_mask(model.bn1) mask1 = pruner.calc_mask(model.bn1)
mask2 = pruner.calc_mask(model.bn2) mask2 = pruner.calc_mask(model.bn2)
...@@ -202,8 +201,8 @@ class CompressorTestCase(TestCase): ...@@ -202,8 +201,8 @@ class CompressorTestCase(TestCase):
model = TorchModel() model = TorchModel()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5) optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
pruner = torch_pruner.TaylorFOWeightFilterPruner(model, config_list, optimizer, statistics_batch_num=1) pruner = torch_pruner.TaylorFOWeightFilterPruner(model, config_list, optimizer, trainer=None, criterion=None, sparsity_training_epochs=1)
x = torch.rand((1, 1, 28, 28), requires_grad=True) x = torch.rand((1, 1, 28, 28), requires_grad=True)
model.conv1.module.weight.data = torch.tensor(w1).float() model.conv1.module.weight.data = torch.tensor(w1).float()
model.conv2.module.weight.data = torch.tensor(w2).float() model.conv2.module.weight.data = torch.tensor(w2).float()
...@@ -345,7 +344,7 @@ class CompressorTestCase(TestCase): ...@@ -345,7 +344,7 @@ class CompressorTestCase(TestCase):
], ],
[ [
{'sparsity': 0.2 }, {'sparsity': 0.2 },
{'sparsity': 0.6, 'op_names': 'abc' } {'sparsity': 0.6, 'op_names': 'abc'}
] ]
] ]
model = TorchModel() model = TorchModel()
...@@ -353,7 +352,13 @@ class CompressorTestCase(TestCase): ...@@ -353,7 +352,13 @@ class CompressorTestCase(TestCase):
for pruner_class in pruner_classes: for pruner_class in pruner_classes:
for config_list in bad_configs: for config_list in bad_configs:
try: try:
pruner_class(model, config_list, optimizer) kwargs = {}
if pruner_class in (torch_pruner.SlimPruner, torch_pruner.AGPPruner, torch_pruner.ActivationMeanRankFilterPruner, torch_pruner.ActivationAPoZRankFilterPruner):
kwargs = {'optimizer': None, 'trainer': None, 'criterion': None}
print('kwargs', kwargs)
pruner_class(model, config_list, **kwargs)
print(config_list) print(config_list)
assert False, 'Validation error should be raised for bad configuration' assert False, 'Validation error should be raised for bad configuration'
except schema.SchemaError: except schema.SchemaError:
......
...@@ -46,6 +46,24 @@ def generate_random_sparsity_v2(model): ...@@ -46,6 +46,24 @@ def generate_random_sparsity_v2(model):
'sparsity': sparsity}) 'sparsity': sparsity})
return cfg_list return cfg_list
def train(model, criterion, optimizer, callback=None):
    """Run a single optimisation step on one synthetic batch.

    The batch is generated here: two random inputs of shape (3, 224, 224)
    with the fixed labels ``[0, 1]``, placed on the device of the model's
    first parameter.

    Parameters
    ----------
    model : torch.nn.Module
        Network to update; it is switched to training mode.
    criterion : callable
        Called as ``criterion(output, target)``; its result must support
        ``.backward()``.
    optimizer : torch.optim.Optimizer
        Its gradients are cleared before the forward pass and ``step()`` is
        called once at the end.
    callback : callable, optional
        Zero-argument hook invoked after the backward pass and before the
        optimizer step. Skipped when falsy.

    Returns
    -------
    None
    """
    model.train()
    device = next(model.parameters()).device
    data = torch.randn(2, 3, 224, 224).to(device)
    target = torch.tensor([0, 1]).long().to(device)

    optimizer.zero_grad()
    output = model(data)
    loss = criterion(output, target)
    loss.backward()
    # callback should be inserted between loss.backward() and optimizer.step()
    if callback:
        callback()
    optimizer.step()
def trainer(model, optimizer, criterion, epoch, callback=None):
    """Adapter exposing ``train`` under a (model, optimizer, criterion, epoch) signature.

    ``epoch`` is accepted but not used; the remaining arguments are forwarded
    to ``train`` (note the different positional order there), and whatever
    ``train`` returns is returned unchanged.
    """
    return train(model, criterion, optimizer, callback=callback)
class DependencyawareTest(TestCase): class DependencyawareTest(TestCase):
@unittest.skipIf(torch.__version__ < "1.3.0", "not supported") @unittest.skipIf(torch.__version__ < "1.3.0", "not supported")
...@@ -55,6 +73,7 @@ class DependencyawareTest(TestCase): ...@@ -55,6 +73,7 @@ class DependencyawareTest(TestCase):
sparsity = 0.7 sparsity = 0.7
cfg_list = [{'op_types': ['Conv2d'], 'sparsity':sparsity}] cfg_list = [{'op_types': ['Conv2d'], 'sparsity':sparsity}]
dummy_input = torch.ones(1, 3, 224, 224) dummy_input = torch.ones(1, 3, 224, 224)
for model_name in model_zoo: for model_name in model_zoo:
for pruner in pruners: for pruner in pruners:
print('Testing on ', pruner) print('Testing on ', pruner)
...@@ -72,16 +91,12 @@ class DependencyawareTest(TestCase): ...@@ -72,16 +91,12 @@ class DependencyawareTest(TestCase):
momentum=0.9, momentum=0.9,
weight_decay=4e-5) weight_decay=4e-5)
criterion = torch.nn.CrossEntropyLoss() criterion = torch.nn.CrossEntropyLoss()
tmp_pruner = pruner( if pruner == TaylorFOWeightFilterPruner:
net, cfg_list, optimizer, dependency_aware=True, dummy_input=dummy_input) tmp_pruner = pruner(
# train one single batch so that the the pruner can collect the net, cfg_list, optimizer, trainer=trainer, criterion=criterion, dependency_aware=True, dummy_input=dummy_input)
# statistic else:
optimizer.zero_grad() tmp_pruner = pruner(
out = net(dummy_input) net, cfg_list, dependency_aware=True, dummy_input=dummy_input)
batchsize = dummy_input.size(0)
loss = criterion(out, torch.zeros(batchsize, dtype=torch.int64))
loss.backward()
optimizer.step()
tmp_pruner.compress() tmp_pruner.compress()
tmp_pruner.export_model(MODEL_FILE, MASK_FILE) tmp_pruner.export_model(MODEL_FILE, MASK_FILE)
...@@ -91,7 +106,7 @@ class DependencyawareTest(TestCase): ...@@ -91,7 +106,7 @@ class DependencyawareTest(TestCase):
ms.speedup_model() ms.speedup_model()
for name, module in net.named_modules(): for name, module in net.named_modules():
if isinstance(module, nn.Conv2d): if isinstance(module, nn.Conv2d):
expected = int(ori_filters[name] * (1-sparsity)) expected = int(ori_filters[name] * (1 - sparsity))
filter_diff = abs(expected - module.out_channels) filter_diff = abs(expected - module.out_channels)
errmsg = '%s Ori: %d, Expected: %d, Real: %d' % ( errmsg = '%s Ori: %d, Expected: %d, Real: %d' % (
name, ori_filters[name], expected, module.out_channels) name, ori_filters[name], expected, module.out_channels)
...@@ -124,16 +139,13 @@ class DependencyawareTest(TestCase): ...@@ -124,16 +139,13 @@ class DependencyawareTest(TestCase):
momentum=0.9, momentum=0.9,
weight_decay=4e-5) weight_decay=4e-5)
criterion = torch.nn.CrossEntropyLoss() criterion = torch.nn.CrossEntropyLoss()
tmp_pruner = pruner(
net, cfg_list, optimizer, dependency_aware=True, dummy_input=dummy_input) if pruner in (TaylorFOWeightFilterPruner, ActivationMeanRankFilterPruner, ActivationAPoZRankFilterPruner):
# train one single batch so that the the pruner can collect the tmp_pruner = pruner(
# statistic net, cfg_list, optimizer, trainer=trainer, criterion=criterion, dependency_aware=True, dummy_input=dummy_input)
optimizer.zero_grad() else:
out = net(dummy_input) tmp_pruner = pruner(
batchsize = dummy_input.size(0) net, cfg_list, dependency_aware=True, dummy_input=dummy_input)
loss = criterion(out, torch.zeros(batchsize, dtype=torch.int64))
loss.backward()
optimizer.step()
tmp_pruner.compress() tmp_pruner.compress()
tmp_pruner.export_model(MODEL_FILE, MASK_FILE) tmp_pruner.export_model(MODEL_FILE, MASK_FILE)
......
...@@ -17,7 +17,7 @@ from unittest import TestCase, main ...@@ -17,7 +17,7 @@ from unittest import TestCase, main
from nni.compression.pytorch import ModelSpeedup, apply_compression_results from nni.compression.pytorch import ModelSpeedup, apply_compression_results
from nni.algorithms.compression.pytorch.pruning import L1FilterPruner from nni.algorithms.compression.pytorch.pruning import L1FilterPruner
from nni.algorithms.compression.pytorch.pruning.weight_masker import WeightMasker from nni.algorithms.compression.pytorch.pruning.weight_masker import WeightMasker
from nni.algorithms.compression.pytorch.pruning.one_shot import _StructuredFilterPruner from nni.algorithms.compression.pytorch.pruning.dependency_aware_pruner import DependencyAwarePruner
torch.manual_seed(0) torch.manual_seed(0)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
...@@ -205,7 +205,7 @@ class L1ChannelMasker(WeightMasker): ...@@ -205,7 +205,7 @@ class L1ChannelMasker(WeightMasker):
return {'weight_mask': mask_weight.detach(), 'bias_mask': mask_bias} return {'weight_mask': mask_weight.detach(), 'bias_mask': mask_bias}
class L1ChannelPruner(_StructuredFilterPruner): class L1ChannelPruner(DependencyAwarePruner):
def __init__(self, model, config_list, optimizer=None, dependency_aware=False, dummy_input=None): def __init__(self, model, config_list, optimizer=None, dependency_aware=False, dummy_input=None):
super().__init__(model, config_list, pruning_algorithm='l1', optimizer=optimizer, super().__init__(model, config_list, pruning_algorithm='l1', optimizer=optimizer,
dependency_aware=dependency_aware, dummy_input=dummy_input) dependency_aware=dependency_aware, dummy_input=dummy_input)
......
...@@ -42,13 +42,10 @@ prune_config = { ...@@ -42,13 +42,10 @@ prune_config = {
'agp': { 'agp': {
'pruner_class': AGPPruner, 'pruner_class': AGPPruner,
'config_list': [{ 'config_list': [{
'initial_sparsity': 0., 'sparsity': 0.8,
'final_sparsity': 0.8,
'start_epoch': 0,
'end_epoch': 10,
'frequency': 1,
'op_types': ['Conv2d'] 'op_types': ['Conv2d']
}], }],
'trainer': lambda model, optimizer, criterion, epoch: model,
'validators': [] 'validators': []
}, },
'slim': { 'slim': {
...@@ -57,6 +54,7 @@ prune_config = { ...@@ -57,6 +54,7 @@ prune_config = {
'sparsity': 0.7, 'sparsity': 0.7,
'op_types': ['BatchNorm2d'] 'op_types': ['BatchNorm2d']
}], }],
'trainer': lambda model, optimizer, criterion, epoch: model,
'validators': [ 'validators': [
lambda model: validate_sparsity(model.bn1, 0.7, model.bias) lambda model: validate_sparsity(model.bn1, 0.7, model.bias)
] ]
...@@ -97,6 +95,7 @@ prune_config = { ...@@ -97,6 +95,7 @@ prune_config = {
'sparsity': 0.5, 'sparsity': 0.5,
'op_types': ['Conv2d'], 'op_types': ['Conv2d'],
}], }],
'trainer': lambda model, optimizer, criterion, epoch: model,
'validators': [ 'validators': [
lambda model: validate_sparsity(model.conv1, 0.5, model.bias) lambda model: validate_sparsity(model.conv1, 0.5, model.bias)
] ]
...@@ -107,6 +106,7 @@ prune_config = { ...@@ -107,6 +106,7 @@ prune_config = {
'sparsity': 0.5, 'sparsity': 0.5,
'op_types': ['Conv2d'], 'op_types': ['Conv2d'],
}], }],
'trainer': lambda model, optimizer, criterion, epoch: model,
'validators': [ 'validators': [
lambda model: validate_sparsity(model.conv1, 0.5, model.bias) lambda model: validate_sparsity(model.conv1, 0.5, model.bias)
] ]
...@@ -117,6 +117,7 @@ prune_config = { ...@@ -117,6 +117,7 @@ prune_config = {
'sparsity': 0.5, 'sparsity': 0.5,
'op_types': ['Conv2d'], 'op_types': ['Conv2d'],
}], }],
'trainer': lambda model, optimizer, criterion, epoch: model,
'validators': [ 'validators': [
lambda model: validate_sparsity(model.conv1, 0.5, model.bias) lambda model: validate_sparsity(model.conv1, 0.5, model.bias)
] ]
...@@ -127,7 +128,7 @@ prune_config = { ...@@ -127,7 +128,7 @@ prune_config = {
'sparsity': 0.5, 'sparsity': 0.5,
'op_types': ['Conv2d'] 'op_types': ['Conv2d']
}], }],
'short_term_fine_tuner': lambda model:model, 'short_term_fine_tuner': lambda model: model,
'evaluator':lambda model: 0.9, 'evaluator':lambda model: 0.9,
'validators': [] 'validators': []
}, },
...@@ -146,7 +147,7 @@ prune_config = { ...@@ -146,7 +147,7 @@ prune_config = {
'sparsity': 0.5, 'sparsity': 0.5,
'op_types': ['Conv2d'], 'op_types': ['Conv2d'],
}], }],
'trainer': lambda model, optimizer, criterion, epoch, callback : model, 'trainer': lambda model, optimizer, criterion, epoch : model,
'validators': [ 'validators': [
lambda model: validate_sparsity(model.conv1, 0.5, model.bias) lambda model: validate_sparsity(model.conv1, 0.5, model.bias)
] ]
...@@ -158,7 +159,7 @@ prune_config = { ...@@ -158,7 +159,7 @@ prune_config = {
'op_types': ['Conv2d'], 'op_types': ['Conv2d'],
}], }],
'base_algo': 'l1', 'base_algo': 'l1',
'trainer': lambda model, optimizer, criterion, epoch, callback : model, 'trainer': lambda model, optimizer, criterion, epoch : model,
'evaluator': lambda model: 0.9, 'evaluator': lambda model: 0.9,
'dummy_input': torch.randn([64, 1, 28, 28]), 'dummy_input': torch.randn([64, 1, 28, 28]),
'validators': [] 'validators': []
...@@ -170,7 +171,7 @@ prune_config = { ...@@ -170,7 +171,7 @@ prune_config = {
'op_types': ['Conv2d'], 'op_types': ['Conv2d'],
}], }],
'base_algo': 'l2', 'base_algo': 'l2',
'trainer': lambda model, optimizer, criterion, epoch, callback : model, 'trainer': lambda model, optimizer, criterion, epoch : model,
'evaluator': lambda model: 0.9, 'evaluator': lambda model: 0.9,
'dummy_input': torch.randn([64, 1, 28, 28]), 'dummy_input': torch.randn([64, 1, 28, 28]),
'validators': [] 'validators': []
...@@ -182,7 +183,7 @@ prune_config = { ...@@ -182,7 +183,7 @@ prune_config = {
'op_types': ['Conv2d'], 'op_types': ['Conv2d'],
}], }],
'base_algo': 'fpgm', 'base_algo': 'fpgm',
'trainer': lambda model, optimizer, criterion, epoch, callback : model, 'trainer': lambda model, optimizer, criterion, epoch : model,
'evaluator': lambda model: 0.9, 'evaluator': lambda model: 0.9,
'dummy_input': torch.randn([64, 1, 28, 28]), 'dummy_input': torch.randn([64, 1, 28, 28]),
'validators': [] 'validators': []
...@@ -206,88 +207,87 @@ class Model(nn.Module): ...@@ -206,88 +207,87 @@ class Model(nn.Module):
def forward(self, x): def forward(self, x):
return self.fc(self.pool(self.bn1(self.conv1(x))).view(x.size(0), -1)) return self.fc(self.pool(self.bn1(self.conv1(x))).view(x.size(0), -1))
class SimpleDataset:
    """Synthetic indexable dataset with a fixed reported length.

    Every lookup yields a freshly drawn random tensor of shape (3, 32, 32)
    paired with the constant float label ``1.``; the index is ignored.
    """

    def __getitem__(self, index):
        # The index is not consulted: each access draws a new random sample.
        return torch.randn(3, 32, 32), 1.

    def __len__(self):
        return 1000
def train(model, train_loader, criterion, optimizer):
    """Run two optimisation steps on one fixed synthetic batch.

    The batch is generated here: two random inputs of shape (1, 28, 28) with
    the fixed labels ``[0, 1]``, placed on the device of the model's first
    parameter. The same batch is reused for both steps.

    Parameters
    ----------
    model : torch.nn.Module
        Network to update; it is switched to training mode.
    train_loader : object
        Accepted for signature compatibility only; this function never reads
        from it.
    criterion : callable
        Called as ``criterion(out, y)``; its result must support
        ``.backward()``.
    optimizer : torch.optim.Optimizer
        Cleared and stepped once per iteration.

    Returns
    -------
    None
    """
    model.train()
    device = next(model.parameters()).device
    x = torch.randn(2, 1, 28, 28).to(device)
    y = torch.tensor([0, 1]).long().to(device)
    for _ in range(2):
        out = model(x)
        loss = criterion(out, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
def pruners_test(pruner_names=['level', 'agp', 'slim', 'fpgm', 'l1', 'l2', 'taylorfo', 'mean_activation', 'apoz', 'netadapt', 'simulatedannealing', 'admm', 'autocompress_l1', 'autocompress_l2', 'autocompress_fpgm',], bias=True): def pruners_test(pruner_names=['level', 'agp', 'slim', 'fpgm', 'l1', 'l2', 'taylorfo', 'mean_activation', 'apoz', 'netadapt', 'simulatedannealing', 'admm', 'autocompress_l1', 'autocompress_l2', 'autocompress_fpgm',], bias=True):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dummy_input = torch.randn(2, 1, 28, 28).to(device)
criterion = torch.nn.CrossEntropyLoss()
train_loader = torch.utils.data.DataLoader(SimpleDataset(), batch_size=16, shuffle=False, drop_last=True)
def trainer(model, optimizer, criterion, epoch):
return train(model, train_loader, criterion, optimizer)
for pruner_name in pruner_names: for pruner_name in pruner_names:
print('testing {}...'.format(pruner_name)) print('testing {}...'.format(pruner_name))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Model(bias=bias).to(device) model = Model(bias=bias).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01) optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
config_list = prune_config[pruner_name]['config_list'] config_list = prune_config[pruner_name]['config_list']
x = torch.randn(2, 1, 28, 28).to(device)
y = torch.tensor([0, 1]).long().to(device)
out = model(x)
loss = F.cross_entropy(out, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
if pruner_name == 'netadapt': if pruner_name == 'netadapt':
pruner = prune_config[pruner_name]['pruner_class'](model, config_list, short_term_fine_tuner=prune_config[pruner_name]['short_term_fine_tuner'], evaluator=prune_config[pruner_name]['evaluator']) pruner = prune_config[pruner_name]['pruner_class'](model, config_list, short_term_fine_tuner=prune_config[pruner_name]['short_term_fine_tuner'], evaluator=prune_config[pruner_name]['evaluator'])
elif pruner_name == 'simulatedannealing': elif pruner_name == 'simulatedannealing':
pruner = prune_config[pruner_name]['pruner_class'](model, config_list, evaluator=prune_config[pruner_name]['evaluator']) pruner = prune_config[pruner_name]['pruner_class'](model, config_list, evaluator=prune_config[pruner_name]['evaluator'])
elif pruner_name in ('agp', 'slim', 'taylorfo', 'apoz', 'mean_activation'):
pruner = prune_config[pruner_name]['pruner_class'](model, config_list, trainer=trainer, optimizer=optimizer, criterion=criterion)
elif pruner_name == 'admm': elif pruner_name == 'admm':
pruner = prune_config[pruner_name]['pruner_class'](model, config_list, trainer=prune_config[pruner_name]['trainer']) pruner = prune_config[pruner_name]['pruner_class'](model, config_list, trainer=trainer)
elif pruner_name.startswith('autocompress'): elif pruner_name.startswith('autocompress'):
pruner = prune_config[pruner_name]['pruner_class'](model, config_list, trainer=prune_config[pruner_name]['trainer'], evaluator=prune_config[pruner_name]['evaluator'], dummy_input=x, base_algo=prune_config[pruner_name]['base_algo']) pruner = prune_config[pruner_name]['pruner_class'](model, config_list, trainer=prune_config[pruner_name]['trainer'], evaluator=prune_config[pruner_name]['evaluator'], criterion=torch.nn.CrossEntropyLoss(), dummy_input=dummy_input, base_algo=prune_config[pruner_name]['base_algo'])
else: else:
pruner = prune_config[pruner_name]['pruner_class'](model, config_list, optimizer) pruner = prune_config[pruner_name]['pruner_class'](model, config_list)
pruner.compress()
x = torch.randn(2, 1, 28, 28).to(device)
y = torch.tensor([0, 1]).long().to(device)
out = model(x)
loss = F.cross_entropy(out, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
if pruner_name == 'taylorfo':
# taylorfo algorithm calculate contributions at first iteration(step), and do pruning
# when iteration >= statistics_batch_num (default 1)
optimizer.step()
pruner.compress()
pruner.export_model('./model_tmp.pth', './mask_tmp.pth', './onnx_tmp.pth', input_shape=(2,1,28,28), device=device) pruner.export_model('./model_tmp.pth', './mask_tmp.pth', './onnx_tmp.pth', input_shape=(2,1,28,28), device=device)
for v in prune_config[pruner_name]['validators']: for v in prune_config[pruner_name]['validators']:
v(model) v(model)
filePaths = ['./model_tmp.pth', './mask_tmp.pth', './onnx_tmp.pth', './search_history.csv', './search_result.json'] filePaths = ['./model_tmp.pth', './mask_tmp.pth', './onnx_tmp.pth', './search_history.csv', './search_result.json']
for f in filePaths: for f in filePaths:
if os.path.exists(f): if os.path.exists(f):
os.remove(f) os.remove(f)
def _test_agp(pruning_algorithm):
model = Model()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
config_list = prune_config['agp']['config_list']
pruner = AGPPruner(model, config_list, optimizer, pruning_algorithm=pruning_algorithm) def _test_agp(pruning_algorithm):
pruner.compress() train_loader = torch.utils.data.DataLoader(SimpleDataset(), batch_size=16, shuffle=False, drop_last=True)
model = Model()
x = torch.randn(2, 1, 28, 28) optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
y = torch.tensor([0, 1]).long()
for epoch in range(config_list[0]['start_epoch'], config_list[0]['end_epoch']+1): def trainer(model, optimizer, criterion, epoch):
pruner.update_epoch(epoch) return train(model, train_loader, criterion, optimizer)
out = model(x)
loss = F.cross_entropy(out, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
target_sparsity = pruner.compute_target_sparsity(config_list[0]) config_list = prune_config['agp']['config_list']
actual_sparsity = (model.conv1.weight_mask == 0).sum().item() / model.conv1.weight_mask.numel() pruner = AGPPruner(model, config_list, optimizer=optimizer, trainer=trainer, criterion=torch.nn.CrossEntropyLoss(), pruning_algorithm=pruning_algorithm)
# set abs_tol = 0.2, considering the sparsity error for channel pruning when number of channels is small. pruner.compress()
assert math.isclose(actual_sparsity, target_sparsity, abs_tol=0.2)
class SimpleDataset: target_sparsity = pruner.compute_target_sparsity(config_list[0])
def __getitem__(self, index): actual_sparsity = (model.conv1.weight_mask == 0).sum().item() / model.conv1.weight_mask.numel()
return torch.randn(3, 32, 32), 1. # set abs_tol = 0.2, considering the sparsity error for channel pruning when number of channels is small.
assert math.isclose(actual_sparsity, target_sparsity, abs_tol=0.2)
def __len__(self):
return 1000
class PrunerTestCase(TestCase): class PrunerTestCase(TestCase):
def test_pruners(self): def test_pruners(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment