Unverified Commit e3e17f47 authored by J-shang, committed by GitHub

[Model Compression] Add Unit Test (#4125)

parent 9a68cdb2
import argparse
import logging
from pathlib import Path
import torch
from torchvision import transforms, datasets
from nni.algorithms.compression.v2.pytorch import pruning
from nni.compression.pytorch import ModelSpeedup
from examples.model_compress.models.cifar10.vgg import VGG
logging.getLogger().setLevel(logging.DEBUG)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = VGG().to(device)
normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
train_loader = torch.utils.data.DataLoader(
datasets.CIFAR10('./data', train=True, transform=transforms.Compose([
transforms.RandomHorizontalFlip(),
transforms.RandomCrop(32, 4),
transforms.ToTensor(),
normalize,
]), download=True),
batch_size=128, shuffle=True)
test_loader = torch.utils.data.DataLoader(
datasets.CIFAR10('./data', train=False, transform=transforms.Compose([
transforms.ToTensor(),
normalize,
])),
batch_size=200, shuffle=False)
criterion = torch.nn.CrossEntropyLoss()
def trainer(model, optimizer, criterion, epoch=None):
model.train()
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
loss = criterion(output, target)
loss.backward()
optimizer.step()
if batch_idx % 100 == 0:
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
epoch, batch_idx * len(data), len(train_loader.dataset),
100. * batch_idx / len(train_loader), loss.item()))
def evaluator(model):
model.eval()
criterion = torch.nn.NLLLoss()
test_loss = 0
correct = 0
with torch.no_grad():
for data, target in test_loader:
data, target = data.to(device), target.to(device)
output = model(data)
test_loss += criterion(output, target).item()
pred = output.argmax(dim=1, keepdim=True)
correct += pred.eq(target.view_as(pred)).sum().item()
test_loss /= len(test_loader.dataset)
acc = 100 * correct / len(test_loader.dataset)
print('Test Loss: {} Accuracy: {}%\n'.format(
test_loss, acc))
return acc
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
finetune_optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
def main(args):
if args.pre_train:
for i in range(1):
trainer(model, finetune_optimizer, criterion, epoch=i)
config_list = [{
'op_types': ['Conv2d'],
'sparsity_per_layer': 0.8
}]
kwargs = {
'model': model,
'config_list': config_list,
}
if args.pruner == 'level':
pruner = pruning.LevelPruner(**kwargs)
else:
kwargs['mode'] = args.mode
if kwargs['mode'] == 'dependency_aware':
kwargs['dummy_input'] = torch.rand(10, 3, 32, 32).to(device)
if args.pruner == 'l1norm':
pruner = pruning.L1NormPruner(**kwargs)
elif args.pruner == 'l2norm':
pruner = pruning.L2NormPruner(**kwargs)
elif args.pruner == 'fpgm':
pruner = pruning.FPGMPruner(**kwargs)
else:
kwargs['trainer'] = trainer
kwargs['optimizer'] = optimizer
kwargs['criterion'] = criterion
if args.pruner == 'slim':
kwargs['config_list'] = [{
'op_types': ['BatchNorm2d'],
'total_sparsity': 0.8,
'max_sparsity_per_layer': 0.9
}]
kwargs['training_epochs'] = 1
pruner = pruning.SlimPruner(**kwargs)
elif args.pruner == 'mean_activation':
pruner = pruning.ActivationMeanRankPruner(**kwargs)
elif args.pruner == 'apoz':
pruner = pruning.ActivationAPoZRankPruner(**kwargs)
elif args.pruner == 'taylorfo':
pruner = pruning.TaylorFOWeightPruner(**kwargs)
pruned_model, masks = pruner.compress()
pruner.show_pruned_weights()
if args.speed_up:
tmp_masks = {}
for name, mask in masks.items():
tmp_masks[name] = {}
tmp_masks[name]['weight'] = mask.get('weight_mask')
if 'bias_mask' in mask:
tmp_masks[name]['bias'] = mask.get('bias_mask')
torch.save(tmp_masks, Path('./temp_masks.pth'))
pruner._unwrap_model()
ModelSpeedup(model, torch.rand(10, 3, 32, 32).to(device), Path('./temp_masks.pth')).speedup_model()
if args.finetune:
for i in range(1):
trainer(pruned_model, finetune_optimizer, criterion, epoch=i)
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Example for model compression')
parser.add_argument('--pruner', type=str, default='l1norm',
choices=['level', 'l1norm', 'l2norm', 'slim',
'fpgm', 'mean_activation', 'apoz', 'taylorfo'],
help='pruner to use')
parser.add_argument('--mode', type=str, default='normal',
choices=['normal', 'dependency_aware', 'global'])
parser.add_argument('--pre-train', action='store_true', default=False,
help='Whether to pre-train the model')
parser.add_argument('--speed-up', action='store_true', default=False,
help='Whether to speed-up the pruned model')
parser.add_argument('--finetune', action='store_true', default=False,
help='Whether to finetune the pruned model')
args = parser.parse_args()
main(args)
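# Example invocations of this script (the file name `basic_pruners_torch.py` below is only
# illustrative, it is not taken from this commit):
#   python basic_pruners_torch.py --pruner l1norm --mode dependency_aware --speed-up
#   python basic_pruners_torch.py --pruner slim --pre-train --finetune
# The `slim`, `mean_activation`, `apoz` and `taylorfo` choices are trainer-based pruners, so
# they run `trainer` to collect data before generating masks.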
import functools
from tqdm import tqdm
import torch
from torchvision import datasets, transforms
from nni.algorithms.compression.v2.pytorch.pruning import L1NormPruner
from nni.algorithms.compression.v2.pytorch.pruning.tools import AGPTaskGenerator
from nni.algorithms.compression.v2.pytorch.pruning.basic_scheduler import PruningScheduler
from examples.model_compress.models.cifar10.vgg import VGG
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
train_loader = torch.utils.data.DataLoader(
datasets.CIFAR10('./data', train=True, transform=transforms.Compose([
transforms.RandomHorizontalFlip(),
transforms.RandomCrop(32, 4),
transforms.ToTensor(),
normalize,
]), download=True),
batch_size=128, shuffle=True)
test_loader = torch.utils.data.DataLoader(
datasets.CIFAR10('./data', train=False, transform=transforms.Compose([
transforms.ToTensor(),
normalize,
])),
batch_size=128, shuffle=False)
criterion = torch.nn.CrossEntropyLoss()
def trainer(model, optimizer, criterion, epoch):
model.train()
for data, target in tqdm(iterable=train_loader, desc='Epoch {}'.format(epoch)):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
loss = criterion(output, target)
loss.backward()
optimizer.step()
def finetuner(model):
model.train()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
criterion = torch.nn.CrossEntropyLoss()
for data, target in tqdm(iterable=train_loader, desc='Finetune'):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
loss = criterion(output, target)
loss.backward()
optimizer.step()
def evaluator(model):
model.eval()
correct = 0
with torch.no_grad():
for data, target in tqdm(iterable=test_loader, desc='Test'):
data, target = data.to(device), target.to(device)
output = model(data)
pred = output.argmax(dim=1, keepdim=True)
correct += pred.eq(target.view_as(pred)).sum().item()
acc = 100 * correct / len(test_loader.dataset)
print('Accuracy: {}%\n'.format(acc))
return acc
if __name__ == '__main__':
model = VGG().to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
criterion = torch.nn.CrossEntropyLoss()
# pre-train the model
for i in range(5):
trainer(model, optimizer, criterion, i)
config_list = [{'op_types': ['Conv2d'], 'sparsity': 0.8}]
# Make sure to initialize the task generator first, because the model passed to the generator should be an unwrapped model.
# If you want to initialize the pruner first, you can use the following code instead:
# pruner = L1NormPruner(model, config_list)
# pruner._unwrap_model()
# task_generator = AGPTaskGenerator(10, model, config_list, log_dir='.', keep_intermediate_result=True)
# pruner._wrap_model()
# You can specify `log_dir`; all intermediate results and the best result will be saved under this folder.
# If you don't want to keep intermediate results, set `keep_intermediate_result=False`.
task_generator = AGPTaskGenerator(10, model, config_list, log_dir='.', keep_intermediate_result=True)
pruner = L1NormPruner(model, config_list)
dummy_input = torch.rand(10, 3, 32, 32).to(device)
# If you just want to keep the final result as the best result, pass evaluator as None.
# Otherwise, the result with the highest score (given by the evaluator) will be kept as the best result.
# scheduler = PruningScheduler(pruner, task_generator, finetuner=finetuner, speed_up=True, dummy_input=dummy_input, evaluator=evaluator)
scheduler = PruningScheduler(pruner, task_generator, finetuner=finetuner, speed_up=True, dummy_input=dummy_input, evaluator=None)
scheduler.compress()
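# Note (based on the TaskGenerator changes later in this commit): the task generator creates a
# timestamped folder under `log_dir`, and with `keep_intermediate_result=True` the per-iteration
# compact models, masks and config lists are kept in its `intermediate_result` sub-directory,
# alongside the best result.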
from tqdm import tqdm
import torch
from torchvision import datasets, transforms
from nni.algorithms.compression.v2.pytorch.pruning import L1NormPruner
from nni.compression.pytorch.speedup import ModelSpeedup
from examples.model_compress.models.cifar10.vgg import VGG
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
train_loader = torch.utils.data.DataLoader(
datasets.CIFAR10('./data', train=True, transform=transforms.Compose([
transforms.RandomHorizontalFlip(),
transforms.RandomCrop(32, 4),
transforms.ToTensor(),
normalize,
]), download=True),
batch_size=128, shuffle=True)
test_loader = torch.utils.data.DataLoader(
datasets.CIFAR10('./data', train=False, transform=transforms.Compose([
transforms.ToTensor(),
normalize,
])),
batch_size=128, shuffle=False)
criterion = torch.nn.CrossEntropyLoss()
def trainer(model, optimizer, criterion, epoch):
model.train()
for data, target in tqdm(iterable=train_loader, desc='Epoch {}'.format(epoch)):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
loss = criterion(output, target)
loss.backward()
optimizer.step()
def evaluator(model):
model.eval()
correct = 0
with torch.no_grad():
for data, target in tqdm(iterable=test_loader, desc='Test'):
data, target = data.to(device), target.to(device)
output = model(data)
pred = output.argmax(dim=1, keepdim=True)
correct += pred.eq(target.view_as(pred)).sum().item()
acc = 100 * correct / len(test_loader.dataset)
print('Accuracy: {}%\n'.format(acc))
return acc
if __name__ == '__main__':
model = VGG().to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
criterion = torch.nn.CrossEntropyLoss()
print('\nPre-train the model:')
for i in range(5):
trainer(model, optimizer, criterion, i)
evaluator(model)
config_list = [{'op_types': ['Conv2d'], 'sparsity': 0.8}]
pruner = L1NormPruner(model, config_list)
_, masks = pruner.compress()
print('\nThe accuracy with masks:')
evaluator(model)
pruner._unwrap_model()
torch.save(masks, 'simple_masks.pth')
ModelSpeedup(model, dummy_input=torch.rand(10, 3, 32, 32).to(device), masks_file='simple_masks.pth').speedup_model()
print('\nThe accuracy after speed up:')
evaluator(model)
print('\nFinetune the model after speed up:')
for i in range(5):
trainer(model, optimizer, criterion, i)
evaluator(model)
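# Note: ModelSpeedup physically removes the masked filters, so after `speedup_model()` the
# Conv2d layers in `model` have fewer channels; the finetuning above therefore runs on the
# compacted (smaller and faster) model instead of a masked full-size model.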
......@@ -3,13 +3,13 @@
import collections
import logging
from typing import List, Dict, Optional, OrderedDict, Tuple, Any
from typing import List, Dict, Optional, Tuple, Any
import torch
from torch.nn import Module
from nni.common.graph_utils import TorchModuleGraph
from nni.compression.pytorch.utils import get_module_by_name
from nni.algorithms.compression.v2.pytorch.utils import get_module_by_name
_logger = logging.getLogger(__name__)
......@@ -149,7 +149,7 @@ class Compressor:
return None
return ret
def get_modules_wrapper(self) -> OrderedDict[str, Module]:
def get_modules_wrapper(self) -> Dict[str, Module]:
"""
Returns
-------
......
......@@ -5,12 +5,12 @@ import gc
import logging
import os
from pathlib import Path
from typing import List, Dict, Tuple, Literal, Optional
from typing import List, Dict, Tuple, Optional
import json_tricks
import torch
from torch import Tensor
from torch.nn import Module
from torch.tensor import Tensor
_logger = logging.getLogger(__name__)
......@@ -37,7 +37,7 @@ class Task:
self.masks_path = masks_path
self.config_list_path = config_list_path
self.status: Literal['Pending', 'Running', 'Finished'] = 'Pending'
self.status = 'Pending'
self.score: Optional[float] = None
self.state = {}
......
from .basic_pruner import *
from .basic_scheduler import PruningScheduler
from .tools import AGPTaskGenerator, LinearTaskGenerator, LotteryTicketTaskGenerator, SimulatedAnnealingTaskGenerator
......@@ -13,8 +13,7 @@ from torch.nn import Module
from torch.optim import Optimizer
from nni.algorithms.compression.v2.pytorch.base.pruner import Pruner
from nni.algorithms.compression.v2.pytorch.utils.config_validation import CompressorSchema
from nni.algorithms.compression.v2.pytorch.utils.pruning import config_list_canonical
from nni.algorithms.compression.v2.pytorch.utils import CompressorSchema, config_list_canonical
from .tools import (
DataCollector,
......@@ -43,7 +42,7 @@ from .tools import (
_logger = logging.getLogger(__name__)
__all__ = ['LevelPruner', 'L1NormPruner', 'L2NormPruner', 'FPGMPruner', 'SlimPruner', 'ActivationPruner',
'ActivationAPoZRankPruner', 'ActivationMeanRankPruner', 'TaylorFOWeightPruner']
'ActivationAPoZRankPruner', 'ActivationMeanRankPruner', 'TaylorFOWeightPruner', 'ADMMPruner']
NORMAL_SCHEMA = {
Or('sparsity', 'sparsity_per_layer'): And(float, lambda n: 0 <= n < 1),
......@@ -688,7 +687,7 @@ class ADMMPruner(BasicPruner):
Supported keys:
- sparsity : This is to specify the sparsity for each layer in this config to be compressed.
- sparsity_per_layer : Equal to sparsity.
- rho : Penalty parameters in ADMM algorithm.
- rho : Penalty parameters in ADMM algorithm. Default: 1e-4.
- op_types : Operation types to prune.
- op_names : Operation names to prune.
- op_partial_names : An auxiliary field that collects the matched op_names in the model; it will be converted to op_names.
......@@ -744,7 +743,7 @@ class ADMMPruner(BasicPruner):
def patched_criterion(output: Tensor, target: Tensor):
penalty = torch.tensor(0.0).to(output.device)
for name, wrapper in self.get_modules_wrapper().items():
rho = wrapper.config['rho']
rho = wrapper.config.get('rho', 1e-4)
penalty += (rho / 2) * torch.sqrt(torch.norm(wrapper.module.weight - self.Z[name] + self.U[name]))
return origin_criterion(output, target) + penalty
return patched_criterion
......
......@@ -452,7 +452,7 @@ class TaskGenerator:
This class is used to generate the config list for the pruner in each iteration.
"""
def __init__(self, origin_model: Module, origin_masks: Dict[str, Dict[str, Tensor]] = {},
origin_config_list: List[Dict] = [], log_dir: str = '.', keep_intermidiate_result: bool = False):
origin_config_list: List[Dict] = [], log_dir: str = '.', keep_intermediate_result: bool = False):
"""
Parameters
----------
......@@ -465,16 +465,16 @@ class TaskGenerator:
This means the sparsity provided by the origin_masks should also be recorded in the origin_config_list.
log_dir
The log directory used to save the task generator log.
keep_intermidiate_result
keep_intermediate_result
Whether to keep the intermediate results, including the intermediate model and masks of each iteration.
"""
assert isinstance(origin_model, Module), 'Only support pytorch module.'
self._log_dir_root = Path(log_dir, datetime.now().strftime('%Y-%m-%d-%H-%M-%S-%f')).absolute()
self._log_dir_root.mkdir(parents=True, exist_ok=True)
self._keep_intermidiate_result = keep_intermidiate_result
self._intermidiate_result_dir = Path(self._log_dir_root, 'intermidiate_result')
self._intermidiate_result_dir.mkdir(parents=True, exist_ok=True)
self._keep_intermediate_result = keep_intermediate_result
self._intermediate_result_dir = Path(self._log_dir_root, 'intermediate_result')
self._intermediate_result_dir.mkdir(parents=True, exist_ok=True)
# save origin data in {log_dir}/origin
self._origin_model_path = Path(self._log_dir_root, 'origin', 'model.pth')
......@@ -506,7 +506,6 @@ class TaskGenerator:
def update_best_result(self, task_result: TaskResult):
score = task_result.score
if score is not None:
task_id = task_result.task_id
task = self._tasks[task_id]
task.score = score
......@@ -540,7 +539,7 @@ class TaskGenerator:
self._pending_tasks.extend(self.generate_tasks(task_result))
self._dump_tasks_info()
if not self._keep_intermidiate_result:
if not self._keep_intermediate_result:
self._tasks[task_id].clean_up()
def next(self) -> Optional[Task]:
......
......@@ -103,8 +103,10 @@ class Conv2dDependencyAwareAllocator(SparsityAllocator):
def _get_dependency(self):
graph = self.pruner.generate_graph(dummy_input=self.dummy_input)
self.channel_depen = ChannelDependency(traced_model=graph.trace).dependency_sets
self.group_depen = GroupDependency(traced_model=graph.trace).dependency_sets
self.pruner._unwrap_model()
self.channel_depen = ChannelDependency(model=self.pruner.bound_model, dummy_input=self.dummy_input, traced_model=graph.trace).dependency_sets
self.group_depen = GroupDependency(model=self.pruner.bound_model, dummy_input=self.dummy_input, traced_model=graph.trace).dependency_sets
self.pruner._wrap_model()
def generate_sparsity(self, metrics: Dict) -> Dict[str, Dict[str, Tensor]]:
self._get_dependency()
......
......@@ -13,7 +13,7 @@ import torch
from torch.nn import Module
from nni.algorithms.compression.v2.pytorch.base import Task, TaskResult
from nni.algorithms.compression.v2.pytorch.utils.pruning import (
from nni.algorithms.compression.v2.pytorch.utils import (
config_list_canonical,
compute_sparsity,
get_model_weights_numel
......@@ -25,7 +25,7 @@ _logger = logging.getLogger(__name__)
class FunctionBasedTaskGenerator(TaskGenerator):
def __init__(self, total_iteration: int, origin_model: Module, origin_config_list: List[Dict],
origin_masks: Dict[str, Dict[str, Tensor]] = {}, log_dir: str = '.', keep_intermidiate_result: bool = False):
origin_masks: Dict[str, Dict[str, Tensor]] = {}, log_dir: str = '.', keep_intermediate_result: bool = False):
"""
Parameters
----------
......@@ -40,7 +40,7 @@ class FunctionBasedTaskGenerator(TaskGenerator):
The pre-masks on the origin model. These masks may be user-defined or generated by a previous pruning.
log_dir
The log directory used to save the task generator log.
keep_intermidiate_result
keep_intermediate_result
Whether to keep the intermediate results, including the intermediate model and masks of each iteration.
"""
self.current_iteration = 0
......@@ -48,7 +48,7 @@ class FunctionBasedTaskGenerator(TaskGenerator):
self.total_iteration = total_iteration
super().__init__(origin_model, origin_config_list=self.target_sparsity, origin_masks=origin_masks,
log_dir=log_dir, keep_intermidiate_result=keep_intermidiate_result)
log_dir=log_dir, keep_intermediate_result=keep_intermediate_result)
def init_pending_tasks(self) -> List[Task]:
origin_model = torch.load(self._origin_model_path)
......@@ -62,9 +62,9 @@ class FunctionBasedTaskGenerator(TaskGenerator):
compact_model = task_result.compact_model
compact_model_masks = task_result.compact_model_masks
# save intermidiate result
model_path = Path(self._intermidiate_result_dir, '{}_compact_model.pth'.format(task_result.task_id))
masks_path = Path(self._intermidiate_result_dir, '{}_compact_model_masks.pth'.format(task_result.task_id))
# save intermediate result
model_path = Path(self._intermediate_result_dir, '{}_compact_model.pth'.format(task_result.task_id))
masks_path = Path(self._intermediate_result_dir, '{}_compact_model_masks.pth'.format(task_result.task_id))
torch.save(compact_model, model_path)
torch.save(compact_model_masks, masks_path)
......@@ -81,7 +81,7 @@ class FunctionBasedTaskGenerator(TaskGenerator):
task_id = self._task_id_candidate
new_config_list = self.generate_config_list(self.target_sparsity, self.current_iteration, compact2origin_sparsity)
config_list_path = Path(self._intermidiate_result_dir, '{}_config_list.json'.format(task_id))
config_list_path = Path(self._intermediate_result_dir, '{}_config_list.json'.format(task_id))
with Path(config_list_path).open('w') as f:
json_tricks.dump(new_config_list, f, indent=4)
......@@ -124,9 +124,9 @@ class LinearTaskGenerator(FunctionBasedTaskGenerator):
class LotteryTicketTaskGenerator(FunctionBasedTaskGenerator):
def __init__(self, total_iteration: int, origin_model: Module, origin_config_list: List[Dict],
origin_masks: Dict[str, Dict[str, Tensor]] = {}, log_dir: str = '.', keep_intermidiate_result: bool = False):
origin_masks: Dict[str, Dict[str, Tensor]] = {}, log_dir: str = '.', keep_intermediate_result: bool = False):
super().__init__(total_iteration, origin_model, origin_config_list, origin_masks=origin_masks, log_dir=log_dir,
keep_intermidiate_result=keep_intermidiate_result)
keep_intermediate_result=keep_intermediate_result)
self.current_iteration = 1
def generate_config_list(self, target_sparsity: List[Dict], iteration: int, compact2origin_sparsity: List[Dict]) -> List[Dict]:
......@@ -147,7 +147,7 @@ class LotteryTicketTaskGenerator(FunctionBasedTaskGenerator):
class SimulatedAnnealingTaskGenerator(TaskGenerator):
def __init__(self, origin_model: Module, origin_config_list: List[Dict], origin_masks: Dict[str, Dict[str, Tensor]] = {},
start_temperature: float = 100, stop_temperature: float = 20, cool_down_rate: float = 0.9,
perturbation_magnitude: float = 0.35, log_dir: str = '.', keep_intermidiate_result: bool = False):
perturbation_magnitude: float = 0.35, log_dir: str = '.', keep_intermediate_result: bool = False):
"""
Parameters
----------
......@@ -168,7 +168,7 @@ class SimulatedAnnealingTaskGenerator(TaskGenerator):
Initial perturbation magnitude to the sparsities. The magnitude decreases with current temperature.
log_dir
The log directory used to save the task generator log.
keep_intermidiate_result
keep_intermediate_result
Whether to keep the intermediate results, including the intermediate model and masks of each iteration.
"""
self.start_temperature = start_temperature
......@@ -186,7 +186,7 @@ class SimulatedAnnealingTaskGenerator(TaskGenerator):
self._current_score = None
super().__init__(origin_model, origin_masks=origin_masks, origin_config_list=origin_config_list,
log_dir=log_dir, keep_intermidiate_result=keep_intermidiate_result)
log_dir=log_dir, keep_intermediate_result=keep_intermediate_result)
def _adjust_target_sparsity(self):
"""
......@@ -288,8 +288,8 @@ class SimulatedAnnealingTaskGenerator(TaskGenerator):
origin_model = torch.load(self._origin_model_path)
origin_masks = torch.load(self._origin_masks_path)
self.temp_model_path = Path(self._intermidiate_result_dir, 'origin_compact_model.pth')
self.temp_masks_path = Path(self._intermidiate_result_dir, 'origin_compact_model_masks.pth')
self.temp_model_path = Path(self._intermediate_result_dir, 'origin_compact_model.pth')
self.temp_masks_path = Path(self._intermediate_result_dir, 'origin_compact_model_masks.pth')
torch.save(origin_model, self.temp_model_path)
torch.save(origin_masks, self.temp_masks_path)
......@@ -319,7 +319,7 @@ class SimulatedAnnealingTaskGenerator(TaskGenerator):
task_id = self._task_id_candidate
new_config_list = self._recover_real_sparsity(deepcopy(self._temp_config_list))
config_list_path = Path(self._intermidiate_result_dir, '{}_config_list.json'.format(task_id))
config_list_path = Path(self._intermediate_result_dir, '{}_config_list.json'.format(task_id))
with Path(config_list_path).open('w') as f:
json_tricks.dump(new_config_list, f, indent=4)
......
from .config_validation import CompressorSchema
from .pruning import (
config_list_canonical,
unfold_config_list,
dedupe_config_list,
compute_sparsity_compact2origin,
compute_sparsity_mask2compact,
compute_sparsity,
get_model_weights_numel,
get_module_by_name
)
......@@ -224,3 +224,32 @@ def get_model_weights_numel(model: Module, config_list: List[Dict], masks: Dict[
else:
model_weights_numel[module_name] = module.weight.data.numel()
return model_weights_numel, masked_rate
# FIXME: to avoid a circular import, this function is copied here.
def get_module_by_name(model, module_name):
"""
Get a module specified by its module name
Parameters
----------
model : pytorch model
the pytorch model from which to get its module
module_name : str
the name of the required module
Returns
-------
module, module
the parent module of the required module, the required module
"""
name_list = module_name.split(".")
for name in name_list[:-1]:
if hasattr(model, name):
model = getattr(model, name)
else:
return None, None
if hasattr(model, name_list[-1]):
leaf_module = getattr(model, name_list[-1])
return model, leaf_module
else:
return None, None
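# Illustrative usage (the module name below is hypothetical, not from this file):
#   parent, conv = get_module_by_name(model, 'features.0')
#   # -> (model.features, model.features[0]), or (None, None) if the name does not exist.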
......@@ -6,6 +6,7 @@ import logging
import torch
import numpy as np
from nni.compression.pytorch.compressor import PrunerModuleWrapper
from nni.algorithms.compression.v2.pytorch.base import PrunerModuleWrapper as PrunerModuleWrapper_v2
from .utils import get_module_by_name
......@@ -390,7 +391,7 @@ class GroupDependency(Dependency):
"""
node_name = node_group.name
_, leaf_module = get_module_by_name(self.model, node_name)
if isinstance(leaf_module, PrunerModuleWrapper):
if isinstance(leaf_module, (PrunerModuleWrapper, PrunerModuleWrapper_v2)):
leaf_module = leaf_module.module
assert isinstance(
leaf_module, (torch.nn.Conv2d, torch.nn.ConvTranspose2d))
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import unittest
import torch
import torch.nn.functional as F
from nni.algorithms.compression.v2.pytorch.pruning import (
LevelPruner,
L1NormPruner,
L2NormPruner,
SlimPruner,
FPGMPruner,
ActivationAPoZRankPruner,
ActivationMeanRankPruner,
TaylorFOWeightPruner,
ADMMPruner
)
from nni.algorithms.compression.v2.pytorch.utils import compute_sparsity_mask2compact
class TorchModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv1 = torch.nn.Conv2d(1, 5, 5, 1)
self.bn1 = torch.nn.BatchNorm2d(5)
self.conv2 = torch.nn.Conv2d(5, 10, 5, 1)
self.bn2 = torch.nn.BatchNorm2d(10)
self.fc1 = torch.nn.Linear(4 * 4 * 10, 100)
self.fc2 = torch.nn.Linear(100, 10)
def forward(self, x):
x = F.relu(self.bn1(self.conv1(x)))
x = F.max_pool2d(x, 2, 2)
x = F.relu(self.bn2(self.conv2(x)))
x = F.max_pool2d(x, 2, 2)
x = x.view(-1, 4 * 4 * 10)
x = F.relu(self.fc1(x))
x = self.fc2(x)
return F.log_softmax(x, dim=1)
def trainer(model, optimizer, criterion):
model.train()
input = torch.rand(10, 1, 28, 28)
label = torch.Tensor(list(range(10))).type(torch.LongTensor)
optimizer.zero_grad()
output = model(input)
loss = criterion(output, label)
loss.backward()
optimizer.step()
def get_optimizer(model):
return torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
criterion = torch.nn.CrossEntropyLoss()
class PrunerTestCase(unittest.TestCase):
def test_level_pruner(self):
model = TorchModel()
config_list = [{'op_types': ['Conv2d'], 'sparsity': 0.8}]
pruner = LevelPruner(model=model, config_list=config_list)
pruned_model, masks = pruner.compress()
pruner._unwrap_model()
sparsity_list = compute_sparsity_mask2compact(pruned_model, masks, config_list)
assert 0.79 < sparsity_list[0]['total_sparsity'] < 0.81
def test_l1_norm_pruner(self):
model = TorchModel()
config_list = [{'op_types': ['Conv2d'], 'sparsity': 0.8}]
pruner = L1NormPruner(model=model, config_list=config_list, mode='dependency_aware',
dummy_input=torch.rand(10, 1, 28, 28))
pruned_model, masks = pruner.compress()
pruner._unwrap_model()
sparsity_list = compute_sparsity_mask2compact(pruned_model, masks, config_list)
assert 0.79 < sparsity_list[0]['total_sparsity'] < 0.81
def test_l2_norm_pruner(self):
model = TorchModel()
config_list = [{'op_types': ['Conv2d'], 'sparsity': 0.8}]
pruner = L2NormPruner(model=model, config_list=config_list, mode='dependency_aware',
dummy_input=torch.rand(10, 1, 28, 28))
pruned_model, masks = pruner.compress()
pruner._unwrap_model()
sparsity_list = compute_sparsity_mask2compact(pruned_model, masks, config_list)
assert 0.79 < sparsity_list[0]['total_sparsity'] < 0.81
def test_fpgm_pruner(self):
model = TorchModel()
config_list = [{'op_types': ['Conv2d'], 'sparsity': 0.8}]
pruner = FPGMPruner(model=model, config_list=config_list, mode='dependency_aware',
dummy_input=torch.rand(10, 1, 28, 28))
pruned_model, masks = pruner.compress()
pruner._unwrap_model()
sparsity_list = compute_sparsity_mask2compact(pruned_model, masks, config_list)
assert 0.79 < sparsity_list[0]['total_sparsity'] < 0.81
def test_slim_pruner(self):
model = TorchModel()
config_list = [{'op_types': ['BatchNorm2d'], 'total_sparsity': 0.8}]
pruner = SlimPruner(model=model, config_list=config_list, trainer=trainer, optimizer=get_optimizer(model),
criterion=criterion, training_epochs=1, scale=0.001, mode='global')
pruned_model, masks = pruner.compress()
pruner._unwrap_model()
sparsity_list = compute_sparsity_mask2compact(pruned_model, masks, config_list)
assert 0.79 < sparsity_list[0]['total_sparsity'] < 0.81
def test_activation_apoz_rank_pruner(self):
model = TorchModel()
config_list = [{'op_types': ['Conv2d'], 'sparsity': 0.8}]
pruner = ActivationAPoZRankPruner(model=model, config_list=config_list, trainer=trainer,
optimizer=get_optimizer(model), criterion=criterion, training_batches=1,
activation='relu', mode='dependency_aware',
dummy_input=torch.rand(10, 1, 28, 28))
pruned_model, masks = pruner.compress()
pruner._unwrap_model()
sparsity_list = compute_sparsity_mask2compact(pruned_model, masks, config_list)
assert 0.79 < sparsity_list[0]['total_sparsity'] < 0.81
def test_activation_mean_rank_pruner(self):
model = TorchModel()
config_list = [{'op_types': ['Conv2d'], 'sparsity': 0.8}]
pruner = ActivationMeanRankPruner(model=model, config_list=config_list, trainer=trainer,
optimizer=get_optimizer(model), criterion=criterion, training_batches=1,
activation='relu', mode='dependency_aware',
dummy_input=torch.rand(10, 1, 28, 28))
pruned_model, masks = pruner.compress()
pruner._unwrap_model()
sparsity_list = compute_sparsity_mask2compact(pruned_model, masks, config_list)
assert 0.79 < sparsity_list[0]['total_sparsity'] < 0.81
def test_taylor_fo_pruner(self):
model = TorchModel()
config_list = [{'op_types': ['Conv2d'], 'sparsity': 0.8}]
pruner = TaylorFOWeightPruner(model=model, config_list=config_list, trainer=trainer,
optimizer=get_optimizer(model), criterion=criterion, training_batches=1,
mode='dependency_aware', dummy_input=torch.rand(10, 1, 28, 28))
pruned_model, masks = pruner.compress()
pruner._unwrap_model()
sparsity_list = compute_sparsity_mask2compact(pruned_model, masks, config_list)
assert 0.79 < sparsity_list[0]['total_sparsity'] < 0.81
def test_admm_pruner(self):
model = TorchModel()
config_list = [{'op_types': ['Conv2d'], 'sparsity': 0.8, 'rho': 1e-3}]
pruner = ADMMPruner(model=model, config_list=config_list, trainer=trainer, optimizer=get_optimizer(model),
criterion=criterion, iterations=2, training_epochs=1)
pruned_model, masks = pruner.compress()
pruner._unwrap_model()
sparsity_list = compute_sparsity_mask2compact(pruned_model, masks, config_list)
assert 0.79 < sparsity_list[0]['total_sparsity'] < 0.81
if __name__ == '__main__':
unittest.main()
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import unittest
import torch
import torch.nn.functional as F
from nni.algorithms.compression.v2.pytorch.base import Pruner
from nni.algorithms.compression.v2.pytorch.pruning.tools import (
WeightDataCollector,
WeightTrainerBasedDataCollector,
SingleHookTrainerBasedDataCollector
)
from nni.algorithms.compression.v2.pytorch.pruning.tools import (
NormMetricsCalculator,
MultiDataNormMetricsCalculator,
DistMetricsCalculator,
APoZRankMetricsCalculator,
MeanRankMetricsCalculator
)
from nni.algorithms.compression.v2.pytorch.pruning.tools import (
NormalSparsityAllocator,
GlobalSparsityAllocator
)
from nni.algorithms.compression.v2.pytorch.pruning.tools.base import HookCollectorInfo
from nni.algorithms.compression.v2.pytorch.utils import get_module_by_name
class TorchModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv1 = torch.nn.Conv2d(1, 5, 5, 1)
self.bn1 = torch.nn.BatchNorm2d(5)
self.conv2 = torch.nn.Conv2d(5, 10, 5, 1)
self.bn2 = torch.nn.BatchNorm2d(10)
self.fc1 = torch.nn.Linear(4 * 4 * 10, 100)
self.fc2 = torch.nn.Linear(100, 10)
def forward(self, x):
x = F.relu(self.bn1(self.conv1(x)))
x = F.max_pool2d(x, 2, 2)
x = F.relu(self.bn2(self.conv2(x)))
x = F.max_pool2d(x, 2, 2)
x = x.view(-1, 4 * 4 * 10)
x = F.relu(self.fc1(x))
x = self.fc2(x)
return F.log_softmax(x, dim=1)
def trainer(model, optimizer, criterion):
model.train()
input = torch.rand(10, 1, 28, 28)
label = torch.Tensor(list(range(10))).type(torch.LongTensor)
optimizer.zero_grad()
output = model(input)
loss = criterion(output, label)
loss.backward()
optimizer.step()
def get_optimizer(model):
return torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
criterion = torch.nn.CrossEntropyLoss()
class PruningToolsTestCase(unittest.TestCase):
def test_data_collector(self):
model = TorchModel()
w1 = torch.rand(5, 1, 5, 5)
w2 = torch.rand(10, 5, 5, 5)
model.conv1.weight.data = w1
model.conv2.weight.data = w2
config_list = [{'op_types': ['Conv2d']}]
pruner = Pruner(model, config_list)
# Test WeightDataCollector
data_collector = WeightDataCollector(pruner)
data = data_collector.collect()
assert all(torch.equal(get_module_by_name(model, module_name)[1].module.weight.data, data[module_name]) for module_name in ['conv1', 'conv2'])
# Test WeightTrainerBasedDataCollector
def opt_after():
model.conv1.module.weight.data = torch.ones(5, 1, 5, 5)
model.conv2.module.weight.data = torch.ones(10, 5, 5, 5)
data_collector = WeightTrainerBasedDataCollector(pruner, trainer, get_optimizer(model), criterion, 1, opt_after_tasks=[opt_after])
data = data_collector.collect()
assert all(torch.equal(get_module_by_name(model, module_name)[1].module.weight.data, data[module_name]) for module_name in ['conv1', 'conv2'])
assert all(t.numel() == (t == 1).type_as(t).sum().item() for t in data.values())
# Test SingleHookTrainerBasedDataCollector
def _collector(buffer, weight_tensor):
def collect_taylor(grad):
if len(buffer) < 2:
buffer.append(grad.clone().detach())
return collect_taylor
hook_targets = {'conv1': model.conv1.module.weight, 'conv2': model.conv2.module.weight}
collector_info = HookCollectorInfo(hook_targets, 'tensor', _collector)
data_collector = SingleHookTrainerBasedDataCollector(pruner, trainer, get_optimizer(model), criterion, 2, collector_infos=[collector_info])
data = data_collector.collect()
assert all(len(t) == 2 for t in data.values())
def test_metrics_calculator(self):
# Test NormMetricsCalculator
metrics_calculator = NormMetricsCalculator(dim=0, p=2)
data = {
'1': torch.ones(3, 3, 3),
'2': torch.ones(4, 4) * 2
}
result = {
'1': torch.ones(3) * 3,
'2': torch.ones(4) * 4
}
metrics = metrics_calculator.calculate_metrics(data)
assert all(torch.equal(result[k], v) for k, v in metrics.items())
# Test DistMetricsCalculator
metrics_calculator = DistMetricsCalculator(dim=0, p=2)
data = {
'1': torch.tensor([[1, 2], [4, 6]], dtype=torch.float32),
'2': torch.tensor([[0, 0], [1, 1]], dtype=torch.float32)
}
result = {
'1': torch.tensor([5, 5], dtype=torch.float32),
'2': torch.sqrt(torch.tensor([2, 2], dtype=torch.float32))
}
metrics = metrics_calculator.calculate_metrics(data)
assert all(torch.equal(result[k], v) for k, v in metrics.items())
# Test MultiDataNormMetricsCalculator
metrics_calculator = MultiDataNormMetricsCalculator(dim=0, p=1)
data = {
'1': [torch.ones(3, 3, 3), torch.ones(3, 3, 3) * 2],
'2': [torch.ones(4, 4), torch.ones(4, 4) * 2]
}
result = {
'1': torch.ones(3) * 27,
'2': torch.ones(4) * 12
}
metrics = metrics_calculator.calculate_metrics(data)
assert all(torch.equal(result[k], v) for k, v in metrics.items())
# Test APoZRankMetricsCalculator
metrics_calculator = APoZRankMetricsCalculator(dim=1)
data = {
'1': [torch.tensor([[1, 0], [0, 1]], dtype=torch.float32), torch.tensor([[0, 1], [1, 0]], dtype=torch.float32)],
'2': [torch.tensor([[1, 0, 1], [0, 1, 0]], dtype=torch.float32), torch.tensor([[0, 0, 1], [0, 0, 0]], dtype=torch.float32)]
}
result = {
'1': torch.tensor([0.5, 0.5], dtype=torch.float32),
'2': torch.tensor([0.25, 0.25, 0.5], dtype=torch.float32)
}
metrics = metrics_calculator.calculate_metrics(data)
assert all(torch.equal(result[k], v) for k, v in metrics.items())
# Test MeanRankMetricsCalculator
metrics_calculator = MeanRankMetricsCalculator(dim=1)
data = {
'1': [torch.tensor([[1, 0], [0, 1]], dtype=torch.float32), torch.tensor([[0, 1], [1, 0]], dtype=torch.float32)],
'2': [torch.tensor([[1, 0, 1], [0, 1, 0]], dtype=torch.float32), torch.tensor([[0, 0, 1], [0, 0, 0]], dtype=torch.float32)]
}
result = {
'1': torch.tensor([0.5, 0.5], dtype=torch.float32),
'2': torch.tensor([0.25, 0.25, 0.5], dtype=torch.float32)
}
metrics = metrics_calculator.calculate_metrics(data)
assert all(torch.equal(result[k], v) for k, v in metrics.items())
def test_sparsity_allocator(self):
# Test NormalSparsityAllocator
model = TorchModel()
config_list = [{'op_types': ['Conv2d'], 'total_sparsity': 0.8}]
pruner = Pruner(model, config_list)
metrics = {
'conv1': torch.rand(5, 1, 5, 5),
'conv2': torch.rand(10, 5, 5, 5)
}
sparsity_allocator = NormalSparsityAllocator(pruner)
masks = sparsity_allocator.generate_sparsity(metrics)
assert all(v['weight'].sum() / v['weight'].numel() == 0.2 for k, v in masks.items())
# Test GlobalSparsityAllocator
model = TorchModel()
config_list = [{'op_types': ['Conv2d'], 'total_sparsity': 0.8}]
pruner = Pruner(model, config_list)
sparsity_allocator = GlobalSparsityAllocator(pruner)
masks = sparsity_allocator.generate_sparsity(metrics)
total_elements, total_masked_elements = 0, 0
for t in masks.values():
total_elements += t['weight'].numel()
total_masked_elements += t['weight'].sum().item()
assert total_masked_elements / total_elements == 0.2
if __name__ == '__main__':
unittest.main()
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import unittest
import torch
import torch.nn.functional as F
from nni.algorithms.compression.v2.pytorch.pruning import PruningScheduler, L1NormPruner, AGPTaskGenerator
class TorchModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv1 = torch.nn.Conv2d(1, 5, 5, 1)
self.bn1 = torch.nn.BatchNorm2d(5)
self.conv2 = torch.nn.Conv2d(5, 10, 5, 1)
self.bn2 = torch.nn.BatchNorm2d(10)
self.fc1 = torch.nn.Linear(4 * 4 * 10, 100)
self.fc2 = torch.nn.Linear(100, 10)
def forward(self, x):
x = F.relu(self.bn1(self.conv1(x)))
x = F.max_pool2d(x, 2, 2)
x = F.relu(self.bn2(self.conv2(x)))
x = F.max_pool2d(x, 2, 2)
x = x.view(-1, 4 * 4 * 10)
x = F.relu(self.fc1(x))
x = self.fc2(x)
return F.log_softmax(x, dim=1)
class PruningSchedulerTestCase(unittest.TestCase):
def test_pruning_scheduler(self):
model = TorchModel()
config_list = [{'op_types': ['Conv2d'], 'sparsity': 0.8}]
task_generator = AGPTaskGenerator(1, model, config_list)
pruner = L1NormPruner(model, config_list)
scheduler = PruningScheduler(pruner, task_generator)
scheduler.compress()
if __name__ == '__main__':
unittest.main()
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from typing import List
import unittest
import torch
import torch.nn.functional as F
from nni.algorithms.compression.v2.pytorch.base import Task, TaskResult
from nni.algorithms.compression.v2.pytorch.pruning import (
AGPTaskGenerator,
LinearTaskGenerator,
LotteryTicketTaskGenerator,
SimulatedAnnealingTaskGenerator
)
class TorchModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv1 = torch.nn.Conv2d(1, 5, 5, 1)
self.bn1 = torch.nn.BatchNorm2d(5)
self.conv2 = torch.nn.Conv2d(5, 10, 5, 1)
self.bn2 = torch.nn.BatchNorm2d(10)
self.fc1 = torch.nn.Linear(4 * 4 * 10, 100)
self.fc2 = torch.nn.Linear(100, 10)
def forward(self, x):
x = F.relu(self.bn1(self.conv1(x)))
x = F.max_pool2d(x, 2, 2)
x = F.relu(self.bn2(self.conv2(x)))
x = F.max_pool2d(x, 2, 2)
x = x.view(-1, 4 * 4 * 10)
x = F.relu(self.fc1(x))
x = self.fc2(x)
return F.log_softmax(x, dim=1)
def run_task_generator(task_generator_type):
model = TorchModel()
config_list = [{'op_types': ['Conv2d'], 'sparsity': 0.8}]
if task_generator_type == 'agp':
task_generator = AGPTaskGenerator(5, model, config_list)
elif task_generator_type == 'linear':
task_generator = LinearTaskGenerator(5, model, config_list)
elif task_generator_type == 'lottery_ticket':
task_generator = LotteryTicketTaskGenerator(5, model, config_list)
elif task_generator_type == 'simulated_annealing':
task_generator = SimulatedAnnealingTaskGenerator(model, config_list)
count = run_task_generator_(task_generator)
if task_generator_type == 'agp':
assert count == 6
elif task_generator_type == 'linear':
assert count == 6
elif task_generator_type == 'lottery_ticket':
assert count == 6
elif task_generator_type == 'simulated_annealing':
assert count == 17
def run_task_generator_(task_generator):
task = task_generator.next()
factor = 0.9
count = 0
while task is not None:
factor = factor ** 2
count += 1
task_result = TaskResult(task.task_id, TorchModel(), {}, {}, 1 - factor)
task_generator.receive_task_result(task_result)
task = task_generator.next()
return count
class TaskGeneratorTestCase(unittest.TestCase):
def test_agp_task_generator(self):
run_task_generator('agp')
def test_linear_task_generator(self):
run_task_generator('linear')
def test_lottery_ticket_task_generator(self):
run_task_generator('lottery_ticket')
def test_simulated_annealing_task_generator(self):
run_task_generator('simulated_annealing')
if __name__ == '__main__':
unittest.main()