"git@developer.sourcefind.cn:modelzoo/resnet50_tensorflow.git" did not exist on "2c83637db7a1aacab8b2a50dfdda14db7b2f48de"
Unverified commit e3e17f47, authored by J-shang, committed by GitHub

[Model Compression] Add Unit Test (#4125)

parent 9a68cdb2
import argparse
import logging
from pathlib import Path

import torch
from torchvision import transforms, datasets

from nni.algorithms.compression.v2.pytorch import pruning
from nni.compression.pytorch import ModelSpeedup
from examples.model_compress.models.cifar10.vgg import VGG

logging.getLogger().setLevel(logging.DEBUG)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = VGG().to(device)

normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
train_loader = torch.utils.data.DataLoader(
    datasets.CIFAR10('./data', train=True, transform=transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(32, 4),
        transforms.ToTensor(),
        normalize,
    ]), download=True),
    batch_size=128, shuffle=True)

test_loader = torch.utils.data.DataLoader(
    datasets.CIFAR10('./data', train=False, transform=transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])),
    batch_size=200, shuffle=False)

criterion = torch.nn.CrossEntropyLoss()


def trainer(model, optimizer, criterion, epoch=None):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))


def evaluator(model):
    model.eval()
    criterion = torch.nn.NLLLoss()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += criterion(output, target).item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
    test_loss /= len(test_loader.dataset)
    acc = 100 * correct / len(test_loader.dataset)
    print('Test Loss: {} Accuracy: {}%\n'.format(
        test_loss, acc))
    return acc


optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
finetune_optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)


def main(args):
    if args.pre_train:
        for i in range(1):
            trainer(model, finetune_optimizer, criterion, epoch=i)

    config_list = [{
        'op_types': ['Conv2d'],
        'sparsity_per_layer': 0.8
    }]

    kwargs = {
        'model': model,
        'config_list': config_list,
    }

    if args.pruner == 'level':
        pruner = pruning.LevelPruner(**kwargs)
    else:
        kwargs['mode'] = args.mode
        if kwargs['mode'] == 'dependency_aware':
            kwargs['dummy_input'] = torch.rand(10, 3, 32, 32).to(device)
        if args.pruner == 'l1norm':
            pruner = pruning.L1NormPruner(**kwargs)
        elif args.pruner == 'l2norm':
            pruner = pruning.L2NormPruner(**kwargs)
        elif args.pruner == 'fpgm':
            pruner = pruning.FPGMPruner(**kwargs)
        else:
            kwargs['trainer'] = trainer
            kwargs['optimizer'] = optimizer
            kwargs['criterion'] = criterion
            if args.pruner == 'slim':
                kwargs['config_list'] = [{
                    'op_types': ['BatchNorm2d'],
                    'total_sparsity': 0.8,
                    'max_sparsity_per_layer': 0.9
                }]
                kwargs['training_epochs'] = 1
                pruner = pruning.SlimPruner(**kwargs)
            elif args.pruner == 'mean_activation':
                pruner = pruning.ActivationMeanRankPruner(**kwargs)
            elif args.pruner == 'apoz':
                pruner = pruning.ActivationAPoZRankPruner(**kwargs)
            elif args.pruner == 'taylorfo':
                pruner = pruning.TaylorFOWeightPruner(**kwargs)

    pruned_model, masks = pruner.compress()
    pruner.show_pruned_weights()

    if args.speed_up:
        # Convert the v2 mask format ({'weight_mask': ..., 'bias_mask': ...}) into the
        # {'weight': ..., 'bias': ...} layout that ModelSpeedup expects, then save it.
        tmp_masks = {}
        for name, mask in masks.items():
            tmp_masks[name] = {}
            tmp_masks[name]['weight'] = mask.get('weight_mask')
            if 'bias_mask' in mask:
                tmp_masks[name]['bias'] = mask.get('bias_mask')
        torch.save(tmp_masks, Path('./temp_masks.pth'))
        pruner._unwrap_model()
        # Apply the masks by physically shrinking the pruned layers.
        ModelSpeedup(model, torch.rand(10, 3, 32, 32).to(device), Path('./temp_masks.pth')).speedup_model()

    if args.finetune:
        for i in range(1):
            trainer(pruned_model, finetune_optimizer, criterion, epoch=i)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Example for model compression')
    parser.add_argument('--pruner', type=str, default='l1norm',
                        choices=['level', 'l1norm', 'l2norm', 'slim',
                                 'fpgm', 'mean_activation', 'apoz', 'taylorfo'],
                        help='pruner to use')
    parser.add_argument('--mode', type=str, default='normal',
                        choices=['normal', 'dependency_aware', 'global'])
    parser.add_argument('--pre-train', action='store_true', default=False,
                        help='Whether to pre-train the model')
    parser.add_argument('--speed-up', action='store_true', default=False,
                        help='Whether to speed-up the pruned model')
    parser.add_argument('--finetune', action='store_true', default=False,
                        help='Whether to finetune the pruned model')

    args = parser.parse_args()
    main(args)
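For readers who only need the core call sequence, the sketch below condenses what main() above does for the `--pruner l1norm --mode dependency_aware` case. It reuses the names defined in this script (model, device, pruning) and is illustrative rather than part of the commit.

# Illustrative condensation of main() for the dependency-aware L1 norm case.
config_list = [{'op_types': ['Conv2d'], 'sparsity_per_layer': 0.8}]
pruner = pruning.L1NormPruner(model=model, config_list=config_list,
                              mode='dependency_aware',
                              dummy_input=torch.rand(10, 3, 32, 32).to(device))
pruned_model, masks = pruner.compress()
pruner.show_pruned_weights()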
import functools

from tqdm import tqdm
import torch
from torchvision import datasets, transforms

from nni.algorithms.compression.v2.pytorch.pruning import L1NormPruner
from nni.algorithms.compression.v2.pytorch.pruning.tools import AGPTaskGenerator
from nni.algorithms.compression.v2.pytorch.pruning.basic_scheduler import PruningScheduler
from examples.model_compress.models.cifar10.vgg import VGG

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
train_loader = torch.utils.data.DataLoader(
    datasets.CIFAR10('./data', train=True, transform=transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(32, 4),
        transforms.ToTensor(),
        normalize,
    ]), download=True),
    batch_size=128, shuffle=True)

test_loader = torch.utils.data.DataLoader(
    datasets.CIFAR10('./data', train=False, transform=transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])),
    batch_size=128, shuffle=False)

criterion = torch.nn.CrossEntropyLoss()


def trainer(model, optimizer, criterion, epoch):
    model.train()
    for data, target in tqdm(iterable=train_loader, desc='Epoch {}'.format(epoch)):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()


def finetuner(model):
    model.train()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
    criterion = torch.nn.CrossEntropyLoss()
    for data, target in tqdm(iterable=train_loader, desc='Epoch PFs'):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()


def evaluator(model):
    model.eval()
    correct = 0
    with torch.no_grad():
        for data, target in tqdm(iterable=test_loader, desc='Test'):
            data, target = data.to(device), target.to(device)
            output = model(data)
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
    acc = 100 * correct / len(test_loader.dataset)
    print('Accuracy: {}%\n'.format(acc))
    return acc


if __name__ == '__main__':
    model = VGG().to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
    criterion = torch.nn.CrossEntropyLoss()

    # pre-train the model
    for i in range(5):
        trainer(model, optimizer, criterion, i)

    config_list = [{'op_types': ['Conv2d'], 'sparsity': 0.8}]

    # Make sure to initialize the task generator first, because the model passed to the
    # generator should be an unwrapped model. If you want to initialize the pruner first,
    # you can use the following code:
    # pruner = L1NormPruner(model, config_list)
    # pruner._unwrap_model()
    # task_generator = AGPTaskGenerator(10, model, config_list, log_dir='.', keep_intermediate_result=True)
    # pruner._wrap_model()

    # You can specify log_dir; all intermediate results and the best result will be saved under this folder.
    # If you don't want to keep intermediate results, set `keep_intermediate_result=False`.
    task_generator = AGPTaskGenerator(10, model, config_list, log_dir='.', keep_intermediate_result=True)
    pruner = L1NormPruner(model, config_list)

    dummy_input = torch.rand(10, 3, 32, 32).to(device)

    # If you just want to keep the final result as the best result, pass evaluator as None;
    # otherwise the result with the highest score (given by evaluator) is kept as the best result.
    # scheduler = PruningScheduler(pruner, task_generator, finetuner=finetuner, speed_up=True, dummy_input=dummy_input, evaluator=evaluator)
    scheduler = PruningScheduler(pruner, task_generator, finetuner=finetuner, speed_up=True, dummy_input=dummy_input, evaluator=None)

    scheduler.compress()
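For orientation, the scheduler repeatedly asks the task generator for a pruning task, runs it, and reports the result back. The loop below is a simplified sketch of that interaction built only from the TaskGenerator interface visible elsewhere in this commit (next() and receive_task_result(), plus the TaskResult construction used in test_task_generator.py); the real PruningScheduler also handles mask generation, speed-up, and logging.

from nni.algorithms.compression.v2.pytorch.base import TaskResult

# Simplified sketch of the scheduler/task-generator loop; not the actual PruningScheduler source.
task = task_generator.next()
while task is not None:
    # ... prune according to task.config_list, optionally finetune and speed up ...
    score = evaluator(model) if evaluator is not None else 0.0
    task_generator.receive_task_result(TaskResult(task.task_id, model, {}, {}, score))
    task = task_generator.next()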
from tqdm import tqdm
import torch
from torchvision import datasets, transforms

from nni.algorithms.compression.v2.pytorch.pruning import L1NormPruner
from nni.compression.pytorch.speedup import ModelSpeedup
from examples.model_compress.models.cifar10.vgg import VGG

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
train_loader = torch.utils.data.DataLoader(
    datasets.CIFAR10('./data', train=True, transform=transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(32, 4),
        transforms.ToTensor(),
        normalize,
    ]), download=True),
    batch_size=128, shuffle=True)

test_loader = torch.utils.data.DataLoader(
    datasets.CIFAR10('./data', train=False, transform=transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])),
    batch_size=128, shuffle=False)

criterion = torch.nn.CrossEntropyLoss()


def trainer(model, optimizer, criterion, epoch):
    model.train()
    for data, target in tqdm(iterable=train_loader, desc='Epoch {}'.format(epoch)):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()


def evaluator(model):
    model.eval()
    correct = 0
    with torch.no_grad():
        for data, target in tqdm(iterable=test_loader, desc='Test'):
            data, target = data.to(device), target.to(device)
            output = model(data)
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
    acc = 100 * correct / len(test_loader.dataset)
    print('Accuracy: {}%\n'.format(acc))
    return acc


if __name__ == '__main__':
    model = VGG().to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
    criterion = torch.nn.CrossEntropyLoss()

    print('\nPre-train the model:')
    for i in range(5):
        trainer(model, optimizer, criterion, i)
    evaluator(model)

    config_list = [{'op_types': ['Conv2d'], 'sparsity': 0.8}]
    pruner = L1NormPruner(model, config_list)
    _, masks = pruner.compress()

    print('\nThe accuracy with masks:')
    evaluator(model)

    pruner._unwrap_model()
    ModelSpeedup(model, dummy_input=torch.rand(10, 3, 32, 32).to(device), masks_file='simple_masks.pth').speedup_model()

    print('\nThe accuracy after speed up:')
    evaluator(model)

    print('\nFinetune the model after speed up:')
    for i in range(5):
        trainer(model, optimizer, criterion, i)
    evaluator(model)
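Note that the speed-up call above reads masks from 'simple_masks.pth', a file this excerpt never writes. One way to produce it, assuming the masks returned by pruner.compress() use the 'weight_mask'/'bias_mask' keys converted in the first example of this commit and that ModelSpeedup expects 'weight'/'bias' keys, is sketched below; treat the key names as assumptions rather than a documented contract.

# Hypothetical glue code: persist the masks from pruner.compress() in the layout
# that ModelSpeedup loads from 'simple_masks.pth' (key names are assumptions).
speedup_masks = {}
for name, mask in masks.items():
    speedup_masks[name] = {'weight': mask.get('weight_mask')}
    if 'bias_mask' in mask:
        speedup_masks[name]['bias'] = mask['bias_mask']
torch.save(speedup_masks, 'simple_masks.pth')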
@@ -3,13 +3,13 @@
import collections
import logging
-from typing import List, Dict, Optional, OrderedDict, Tuple, Any
+from typing import List, Dict, Optional, Tuple, Any

import torch
from torch.nn import Module

from nni.common.graph_utils import TorchModuleGraph
-from nni.compression.pytorch.utils import get_module_by_name
+from nni.algorithms.compression.v2.pytorch.utils import get_module_by_name

_logger = logging.getLogger(__name__)

@@ -149,7 +149,7 @@ class Compressor:
            return None
        return ret

-    def get_modules_wrapper(self) -> OrderedDict[str, Module]:
+    def get_modules_wrapper(self) -> Dict[str, Module]:
        """
        Returns
        -------
...

@@ -5,12 +5,12 @@ import gc
import logging
import os
from pathlib import Path
-from typing import List, Dict, Tuple, Literal, Optional
+from typing import List, Dict, Tuple, Optional

import json_tricks
import torch
+from torch import Tensor
from torch.nn import Module
-from torch.tensor import Tensor

_logger = logging.getLogger(__name__)

@@ -37,7 +37,7 @@ class Task:
        self.masks_path = masks_path
        self.config_list_path = config_list_path

-        self.status: Literal['Pending', 'Running', 'Finished'] = 'Pending'
+        self.status = 'Pending'
        self.score: Optional[float] = None
        self.state = {}
...

from .basic_pruner import *
+from .basic_scheduler import PruningScheduler
+from .tools import AGPTaskGenerator, LinearTaskGenerator, LotteryTicketTaskGenerator, SimulatedAnnealingTaskGenerator

@@ -13,8 +13,7 @@ from torch.nn import Module
from torch.optim import Optimizer

from nni.algorithms.compression.v2.pytorch.base.pruner import Pruner
-from nni.algorithms.compression.v2.pytorch.utils.config_validation import CompressorSchema
-from nni.algorithms.compression.v2.pytorch.utils.pruning import config_list_canonical
+from nni.algorithms.compression.v2.pytorch.utils import CompressorSchema, config_list_canonical

from .tools import (
    DataCollector,

@@ -43,7 +42,7 @@ from .tools import (
_logger = logging.getLogger(__name__)

__all__ = ['LevelPruner', 'L1NormPruner', 'L2NormPruner', 'FPGMPruner', 'SlimPruner', 'ActivationPruner',
-           'ActivationAPoZRankPruner', 'ActivationMeanRankPruner', 'TaylorFOWeightPruner']
+           'ActivationAPoZRankPruner', 'ActivationMeanRankPruner', 'TaylorFOWeightPruner', 'ADMMPruner']

NORMAL_SCHEMA = {
    Or('sparsity', 'sparsity_per_layer'): And(float, lambda n: 0 <= n < 1),

@@ -688,7 +687,7 @@ class ADMMPruner(BasicPruner):
        Supported keys:
            - sparsity : This is to specify the sparsity for each layer in this config to be compressed.
            - sparsity_per_layer : Equals to sparsity.
-            - rho : Penalty parameters in ADMM algorithm.
+            - rho : Penalty parameters in ADMM algorithm. Default: 1e-4.
            - op_types : Operation types to prune.
            - op_names : Operation names to prune.
            - op_partial_names: An auxiliary field collecting matched op_names in model, then this will convert to op_names.

@@ -744,7 +743,7 @@ class ADMMPruner(BasicPruner):
        def patched_criterion(output: Tensor, target: Tensor):
            penalty = torch.tensor(0.0).to(output.device)
            for name, wrapper in self.get_modules_wrapper().items():
-                rho = wrapper.config['rho']
+                rho = wrapper.config.get('rho', 1e-4)
                penalty += (rho / 2) * torch.sqrt(torch.norm(wrapper.module.weight - self.Z[name] + self.U[name]))
            return origin_criterion(output, target) + penalty
        return patched_criterion
...
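For context on the hunk above: reading the patched criterion literally, for each wrapped module with weight W, ADMM auxiliary variable Z, and dual variable U, the training loss becomes (this is a restatement of the code shown above, not a formula quoted from the NNI documentation)

$$\mathcal{L}(\hat{y}, y) = \mathcal{L}_{\text{orig}}(\hat{y}, y) + \sum_{\text{layers}} \frac{\rho}{2}\,\sqrt{\lVert W - Z + U \rVert_2}$$

with rho now defaulting to 1e-4 when the config omits it. The classical ADMM augmented-Lagrangian penalty is usually written with a squared norm, $\frac{\rho}{2}\lVert W - Z + U\rVert_2^2$, so the square root here should be read as an implementation detail of this pruner rather than the textbook form.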
@@ -452,7 +452,7 @@ class TaskGenerator:
    This class used to generate config list for pruner in each iteration.
    """
    def __init__(self, origin_model: Module, origin_masks: Dict[str, Dict[str, Tensor]] = {},
-                 origin_config_list: List[Dict] = [], log_dir: str = '.', keep_intermidiate_result: bool = False):
+                 origin_config_list: List[Dict] = [], log_dir: str = '.', keep_intermediate_result: bool = False):
        """
        Parameters
        ----------

@@ -465,16 +465,16 @@ class TaskGenerator:
            This means the sparsity provided by the origin_masks should also be recorded in the origin_config_list.
        log_dir
            The log directory use to saving the task generator log.
-        keep_intermidiate_result
+        keep_intermediate_result
            If keeping the intermediate result, including intermediate model and masks during each iteration.
        """
        assert isinstance(origin_model, Module), 'Only support pytorch module.'
        self._log_dir_root = Path(log_dir, datetime.now().strftime('%Y-%m-%d-%H-%M-%S-%f')).absolute()
        self._log_dir_root.mkdir(parents=True, exist_ok=True)
-        self._keep_intermidiate_result = keep_intermidiate_result
-        self._intermidiate_result_dir = Path(self._log_dir_root, 'intermidiate_result')
-        self._intermidiate_result_dir.mkdir(parents=True, exist_ok=True)
+        self._keep_intermediate_result = keep_intermediate_result
+        self._intermediate_result_dir = Path(self._log_dir_root, 'intermediate_result')
+        self._intermediate_result_dir.mkdir(parents=True, exist_ok=True)
        # save origin data in {log_dir}/origin
        self._origin_model_path = Path(self._log_dir_root, 'origin', 'model.pth')

@@ -506,16 +506,15 @@ class TaskGenerator:
    def update_best_result(self, task_result: TaskResult):
        score = task_result.score
-        if score is not None:
-            task_id = task_result.task_id
-            task = self._tasks[task_id]
-            task.score = score
-            if self._best_score is None or score > self._best_score:
-                self._best_score = score
-                self._best_task_id = task_id
-                with Path(task.config_list_path).open('r') as fr:
-                    best_config_list = json_tricks.load(fr)
-                self._save_data('best_result', task_result.compact_model, task_result.compact_model_masks, best_config_list)
+        task_id = task_result.task_id
+        task = self._tasks[task_id]
+        task.score = score
+        if self._best_score is None or score > self._best_score:
+            self._best_score = score
+            self._best_task_id = task_id
+            with Path(task.config_list_path).open('r') as fr:
+                best_config_list = json_tricks.load(fr)
+            self._save_data('best_result', task_result.compact_model, task_result.compact_model_masks, best_config_list)

    def init_pending_tasks(self) -> List[Task]:
        raise NotImplementedError()

@@ -540,7 +539,7 @@
        self._pending_tasks.extend(self.generate_tasks(task_result))
        self._dump_tasks_info()
-        if not self._keep_intermidiate_result:
+        if not self._keep_intermediate_result:
            self._tasks[task_id].clean_up()

    def next(self) -> Optional[Task]:
...

@@ -103,8 +103,10 @@ class Conv2dDependencyAwareAllocator(SparsityAllocator):
    def _get_dependency(self):
        graph = self.pruner.generate_graph(dummy_input=self.dummy_input)
-        self.channel_depen = ChannelDependency(traced_model=graph.trace).dependency_sets
-        self.group_depen = GroupDependency(traced_model=graph.trace).dependency_sets
+        self.pruner._unwrap_model()
+        self.channel_depen = ChannelDependency(model=self.pruner.bound_model, dummy_input=self.dummy_input, traced_model=graph.trace).dependency_sets
+        self.group_depen = GroupDependency(model=self.pruner.bound_model, dummy_input=self.dummy_input, traced_model=graph.trace).dependency_sets
+        self.pruner._wrap_model()

    def generate_sparsity(self, metrics: Dict) -> Dict[str, Dict[str, Tensor]]:
        self._get_dependency()
...

@@ -13,7 +13,7 @@ import torch
from torch.nn import Module

from nni.algorithms.compression.v2.pytorch.base import Task, TaskResult
-from nni.algorithms.compression.v2.pytorch.utils.pruning import (
+from nni.algorithms.compression.v2.pytorch.utils import (
    config_list_canonical,
    compute_sparsity,
    get_model_weights_numel

@@ -25,7 +25,7 @@ _logger = logging.getLogger(__name__)
class FunctionBasedTaskGenerator(TaskGenerator):
    def __init__(self, total_iteration: int, origin_model: Module, origin_config_list: List[Dict],
-                 origin_masks: Dict[str, Dict[str, Tensor]] = {}, log_dir: str = '.', keep_intermidiate_result: bool = False):
+                 origin_masks: Dict[str, Dict[str, Tensor]] = {}, log_dir: str = '.', keep_intermediate_result: bool = False):
        """
        Parameters
        ----------

@@ -40,7 +40,7 @@ class FunctionBasedTaskGenerator(TaskGenerator):
            The pre masks on the origin model. This mask maybe user-defined or maybe generate by previous pruning.
        log_dir
            The log directory use to saving the task generator log.
-        keep_intermidiate_result
+        keep_intermediate_result
            If keeping the intermediate result, including intermediate model and masks during each iteration.
        """
        self.current_iteration = 0

@@ -48,7 +48,7 @@ class FunctionBasedTaskGenerator(TaskGenerator):
        self.total_iteration = total_iteration
        super().__init__(origin_model, origin_config_list=self.target_sparsity, origin_masks=origin_masks,
-                         log_dir=log_dir, keep_intermidiate_result=keep_intermidiate_result)
+                         log_dir=log_dir, keep_intermediate_result=keep_intermediate_result)

    def init_pending_tasks(self) -> List[Task]:
        origin_model = torch.load(self._origin_model_path)

@@ -62,9 +62,9 @@ class FunctionBasedTaskGenerator(TaskGenerator):
        compact_model = task_result.compact_model
        compact_model_masks = task_result.compact_model_masks
-        # save intermidiate result
-        model_path = Path(self._intermidiate_result_dir, '{}_compact_model.pth'.format(task_result.task_id))
-        masks_path = Path(self._intermidiate_result_dir, '{}_compact_model_masks.pth'.format(task_result.task_id))
+        # save intermediate result
+        model_path = Path(self._intermediate_result_dir, '{}_compact_model.pth'.format(task_result.task_id))
+        masks_path = Path(self._intermediate_result_dir, '{}_compact_model_masks.pth'.format(task_result.task_id))
        torch.save(compact_model, model_path)
        torch.save(compact_model_masks, masks_path)

@@ -81,7 +81,7 @@ class FunctionBasedTaskGenerator(TaskGenerator):
        task_id = self._task_id_candidate
        new_config_list = self.generate_config_list(self.target_sparsity, self.current_iteration, compact2origin_sparsity)
-        config_list_path = Path(self._intermidiate_result_dir, '{}_config_list.json'.format(task_id))
+        config_list_path = Path(self._intermediate_result_dir, '{}_config_list.json'.format(task_id))
        with Path(config_list_path).open('w') as f:
            json_tricks.dump(new_config_list, f, indent=4)

@@ -124,9 +124,9 @@ class LinearTaskGenerator(FunctionBasedTaskGenerator):
class LotteryTicketTaskGenerator(FunctionBasedTaskGenerator):
    def __init__(self, total_iteration: int, origin_model: Module, origin_config_list: List[Dict],
-                 origin_masks: Dict[str, Dict[str, Tensor]] = {}, log_dir: str = '.', keep_intermidiate_result: bool = False):
+                 origin_masks: Dict[str, Dict[str, Tensor]] = {}, log_dir: str = '.', keep_intermediate_result: bool = False):
        super().__init__(total_iteration, origin_model, origin_config_list, origin_masks=origin_masks, log_dir=log_dir,
-                         keep_intermidiate_result=keep_intermidiate_result)
+                         keep_intermediate_result=keep_intermediate_result)
        self.current_iteration = 1

    def generate_config_list(self, target_sparsity: List[Dict], iteration: int, compact2origin_sparsity: List[Dict]) -> List[Dict]:

@@ -147,7 +147,7 @@ class LotteryTicketTaskGenerator(FunctionBasedTaskGenerator):
class SimulatedAnnealingTaskGenerator(TaskGenerator):
    def __init__(self, origin_model: Module, origin_config_list: List[Dict], origin_masks: Dict[str, Dict[str, Tensor]] = {},
                 start_temperature: float = 100, stop_temperature: float = 20, cool_down_rate: float = 0.9,
-                 perturbation_magnitude: float = 0.35, log_dir: str = '.', keep_intermidiate_result: bool = False):
+                 perturbation_magnitude: float = 0.35, log_dir: str = '.', keep_intermediate_result: bool = False):
        """
        Parameters
        ----------

@@ -168,7 +168,7 @@ class SimulatedAnnealingTaskGenerator(TaskGenerator):
            Initial perturbation magnitude to the sparsities. The magnitude decreases with current temperature.
        log_dir
            The log directory use to saving the task generator log.
-        keep_intermidiate_result
+        keep_intermediate_result
            If keeping the intermediate result, including intermediate model and masks during each iteration.
        """
        self.start_temperature = start_temperature

@@ -186,7 +186,7 @@ class SimulatedAnnealingTaskGenerator(TaskGenerator):
        self._current_score = None
        super().__init__(origin_model, origin_masks=origin_masks, origin_config_list=origin_config_list,
-                         log_dir=log_dir, keep_intermidiate_result=keep_intermidiate_result)
+                         log_dir=log_dir, keep_intermediate_result=keep_intermediate_result)

    def _adjust_target_sparsity(self):
        """

@@ -288,8 +288,8 @@ class SimulatedAnnealingTaskGenerator(TaskGenerator):
        origin_model = torch.load(self._origin_model_path)
        origin_masks = torch.load(self._origin_masks_path)
-        self.temp_model_path = Path(self._intermidiate_result_dir, 'origin_compact_model.pth')
-        self.temp_masks_path = Path(self._intermidiate_result_dir, 'origin_compact_model_masks.pth')
+        self.temp_model_path = Path(self._intermediate_result_dir, 'origin_compact_model.pth')
+        self.temp_masks_path = Path(self._intermediate_result_dir, 'origin_compact_model_masks.pth')
        torch.save(origin_model, self.temp_model_path)
        torch.save(origin_masks, self.temp_masks_path)

@@ -319,7 +319,7 @@ class SimulatedAnnealingTaskGenerator(TaskGenerator):
        task_id = self._task_id_candidate
        new_config_list = self._recover_real_sparsity(deepcopy(self._temp_config_list))
-        config_list_path = Path(self._intermidiate_result_dir, '{}_config_list.json'.format(task_id))
+        config_list_path = Path(self._intermediate_result_dir, '{}_config_list.json'.format(task_id))
        with Path(config_list_path).open('w') as f:
            json_tricks.dump(new_config_list, f, indent=4)
...
from .config_validation import CompressorSchema
from .pruning import (
    config_list_canonical,
    unfold_config_list,
    dedupe_config_list,
    compute_sparsity_compact2origin,
    compute_sparsity_mask2compact,
    compute_sparsity,
    get_model_weights_numel,
    get_module_by_name
)
@@ -224,3 +224,32 @@ def get_model_weights_numel(model: Module, config_list: List[Dict], masks: Dict[
        else:
            model_weights_numel[module_name] = module.weight.data.numel()
    return model_weights_numel, masked_rate
+
+
+# FIXME: to avoid circular import, copy this function in this place
+def get_module_by_name(model, module_name):
+    """
+    Get a module specified by its module name
+    Parameters
+    ----------
+    model : pytorch model
+        the pytorch model from which to get its module
+    module_name : str
+        the name of the required module
+    Returns
+    -------
+    module, module
+        the parent module of the required module, the required module
+    """
+    name_list = module_name.split(".")
+    for name in name_list[:-1]:
+        if hasattr(model, name):
+            model = getattr(model, name)
+        else:
+            return None, None
+    if hasattr(model, name_list[-1]):
+        leaf_module = getattr(model, name_list[-1])
+        return model, leaf_module
+    else:
+        return None, None
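A quick usage sketch for the helper added above (illustrative only; the Sequential model and the module names are hypothetical, not part of the commit):

import torch
from nni.algorithms.compression.v2.pytorch.utils import get_module_by_name

# Hypothetical two-layer model, just to show the (parent, leaf) return convention.
model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU())
parent, leaf = get_module_by_name(model, '0')  # parent is `model`, leaf is the Conv2d
assert parent is model and isinstance(leaf, torch.nn.Conv2d)
assert get_module_by_name(model, 'does.not.exist') == (None, None)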
@@ -6,6 +6,7 @@ import logging
import torch
import numpy as np

from nni.compression.pytorch.compressor import PrunerModuleWrapper
+from nni.algorithms.compression.v2.pytorch.base import PrunerModuleWrapper as PrunerModuleWrapper_v2
from .utils import get_module_by_name

@@ -390,7 +391,7 @@ class GroupDependency(Dependency):
        """
        node_name = node_group.name
        _, leaf_module = get_module_by_name(self.model, node_name)
-        if isinstance(leaf_module, PrunerModuleWrapper):
+        if isinstance(leaf_module, (PrunerModuleWrapper, PrunerModuleWrapper_v2)):
            leaf_module = leaf_module.module
        assert isinstance(
            leaf_module, (torch.nn.Conv2d, torch.nn.ConvTranspose2d))
...
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import unittest

import torch
import torch.nn.functional as F

from nni.algorithms.compression.v2.pytorch.pruning import (
    LevelPruner,
    L1NormPruner,
    L2NormPruner,
    SlimPruner,
    FPGMPruner,
    ActivationAPoZRankPruner,
    ActivationMeanRankPruner,
    TaylorFOWeightPruner,
    ADMMPruner
)
from nni.algorithms.compression.v2.pytorch.utils import compute_sparsity_mask2compact


class TorchModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(1, 5, 5, 1)
        self.bn1 = torch.nn.BatchNorm2d(5)
        self.conv2 = torch.nn.Conv2d(5, 10, 5, 1)
        self.bn2 = torch.nn.BatchNorm2d(10)
        self.fc1 = torch.nn.Linear(4 * 4 * 10, 100)
        self.fc2 = torch.nn.Linear(100, 10)

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4 * 4 * 10)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


def trainer(model, optimizer, criterion):
    model.train()
    input = torch.rand(10, 1, 28, 28)
    label = torch.Tensor(list(range(10))).type(torch.LongTensor)
    optimizer.zero_grad()
    output = model(input)
    loss = criterion(output, label)
    loss.backward()
    optimizer.step()


def get_optimizer(model):
    return torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)


criterion = torch.nn.CrossEntropyLoss()


class PrunerTestCase(unittest.TestCase):
    def test_level_pruner(self):
        model = TorchModel()
        config_list = [{'op_types': ['Conv2d'], 'sparsity': 0.8}]
        pruner = LevelPruner(model=model, config_list=config_list)
        pruned_model, masks = pruner.compress()
        pruner._unwrap_model()
        sparsity_list = compute_sparsity_mask2compact(pruned_model, masks, config_list)
        assert 0.79 < sparsity_list[0]['total_sparsity'] < 0.81

    def test_l1_norm_pruner(self):
        model = TorchModel()
        config_list = [{'op_types': ['Conv2d'], 'sparsity': 0.8}]
        pruner = L1NormPruner(model=model, config_list=config_list, mode='dependency_aware',
                              dummy_input=torch.rand(10, 1, 28, 28))
        pruned_model, masks = pruner.compress()
        pruner._unwrap_model()
        sparsity_list = compute_sparsity_mask2compact(pruned_model, masks, config_list)
        assert 0.79 < sparsity_list[0]['total_sparsity'] < 0.81

    def test_l2_norm_pruner(self):
        model = TorchModel()
        config_list = [{'op_types': ['Conv2d'], 'sparsity': 0.8}]
        pruner = L2NormPruner(model=model, config_list=config_list, mode='dependency_aware',
                              dummy_input=torch.rand(10, 1, 28, 28))
        pruned_model, masks = pruner.compress()
        pruner._unwrap_model()
        sparsity_list = compute_sparsity_mask2compact(pruned_model, masks, config_list)
        assert 0.79 < sparsity_list[0]['total_sparsity'] < 0.81

    def test_fpgm_pruner(self):
        model = TorchModel()
        config_list = [{'op_types': ['Conv2d'], 'sparsity': 0.8}]
        pruner = FPGMPruner(model=model, config_list=config_list, mode='dependency_aware',
                            dummy_input=torch.rand(10, 1, 28, 28))
        pruned_model, masks = pruner.compress()
        pruner._unwrap_model()
        sparsity_list = compute_sparsity_mask2compact(pruned_model, masks, config_list)
        assert 0.79 < sparsity_list[0]['total_sparsity'] < 0.81

    def test_slim_pruner(self):
        model = TorchModel()
        config_list = [{'op_types': ['BatchNorm2d'], 'total_sparsity': 0.8}]
        pruner = SlimPruner(model=model, config_list=config_list, trainer=trainer, optimizer=get_optimizer(model),
                            criterion=criterion, training_epochs=1, scale=0.001, mode='global')
        pruned_model, masks = pruner.compress()
        pruner._unwrap_model()
        sparsity_list = compute_sparsity_mask2compact(pruned_model, masks, config_list)
        assert 0.79 < sparsity_list[0]['total_sparsity'] < 0.81

    def test_activation_apoz_rank_pruner(self):
        model = TorchModel()
        config_list = [{'op_types': ['Conv2d'], 'sparsity': 0.8}]
        pruner = ActivationAPoZRankPruner(model=model, config_list=config_list, trainer=trainer,
                                          optimizer=get_optimizer(model), criterion=criterion, training_batches=1,
                                          activation='relu', mode='dependency_aware',
                                          dummy_input=torch.rand(10, 1, 28, 28))
        pruned_model, masks = pruner.compress()
        pruner._unwrap_model()
        sparsity_list = compute_sparsity_mask2compact(pruned_model, masks, config_list)
        assert 0.79 < sparsity_list[0]['total_sparsity'] < 0.81

    def test_activation_mean_rank_pruner(self):
        model = TorchModel()
        config_list = [{'op_types': ['Conv2d'], 'sparsity': 0.8}]
        pruner = ActivationMeanRankPruner(model=model, config_list=config_list, trainer=trainer,
                                          optimizer=get_optimizer(model), criterion=criterion, training_batches=1,
                                          activation='relu', mode='dependency_aware',
                                          dummy_input=torch.rand(10, 1, 28, 28))
        pruned_model, masks = pruner.compress()
        pruner._unwrap_model()
        sparsity_list = compute_sparsity_mask2compact(pruned_model, masks, config_list)
        assert 0.79 < sparsity_list[0]['total_sparsity'] < 0.81

    def test_taylor_fo_pruner(self):
        model = TorchModel()
        config_list = [{'op_types': ['Conv2d'], 'sparsity': 0.8}]
        pruner = TaylorFOWeightPruner(model=model, config_list=config_list, trainer=trainer,
                                      optimizer=get_optimizer(model), criterion=criterion, training_batches=1,
                                      mode='dependency_aware', dummy_input=torch.rand(10, 1, 28, 28))
        pruned_model, masks = pruner.compress()
        pruner._unwrap_model()
        sparsity_list = compute_sparsity_mask2compact(pruned_model, masks, config_list)
        assert 0.79 < sparsity_list[0]['total_sparsity'] < 0.81

    def test_admm_pruner(self):
        model = TorchModel()
        config_list = [{'op_types': ['Conv2d'], 'sparsity': 0.8, 'rho': 1e-3}]
        pruner = ADMMPruner(model=model, config_list=config_list, trainer=trainer, optimizer=get_optimizer(model),
                            criterion=criterion, iterations=2, training_epochs=1)
        pruned_model, masks = pruner.compress()
        pruner._unwrap_model()
        sparsity_list = compute_sparsity_mask2compact(pruned_model, masks, config_list)
        assert 0.79 < sparsity_list[0]['total_sparsity'] < 0.81


if __name__ == '__main__':
    unittest.main()
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import unittest

import torch
import torch.nn.functional as F

from nni.algorithms.compression.v2.pytorch.base import Pruner
from nni.algorithms.compression.v2.pytorch.pruning.tools import (
    WeightDataCollector,
    WeightTrainerBasedDataCollector,
    SingleHookTrainerBasedDataCollector
)
from nni.algorithms.compression.v2.pytorch.pruning.tools import (
    NormMetricsCalculator,
    MultiDataNormMetricsCalculator,
    DistMetricsCalculator,
    APoZRankMetricsCalculator,
    MeanRankMetricsCalculator
)
from nni.algorithms.compression.v2.pytorch.pruning.tools import (
    NormalSparsityAllocator,
    GlobalSparsityAllocator
)
from nni.algorithms.compression.v2.pytorch.pruning.tools.base import HookCollectorInfo
from nni.algorithms.compression.v2.pytorch.utils import get_module_by_name


class TorchModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(1, 5, 5, 1)
        self.bn1 = torch.nn.BatchNorm2d(5)
        self.conv2 = torch.nn.Conv2d(5, 10, 5, 1)
        self.bn2 = torch.nn.BatchNorm2d(10)
        self.fc1 = torch.nn.Linear(4 * 4 * 10, 100)
        self.fc2 = torch.nn.Linear(100, 10)

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4 * 4 * 10)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


def trainer(model, optimizer, criterion):
    model.train()
    input = torch.rand(10, 1, 28, 28)
    label = torch.Tensor(list(range(10))).type(torch.LongTensor)
    optimizer.zero_grad()
    output = model(input)
    loss = criterion(output, label)
    loss.backward()
    optimizer.step()


def get_optimizer(model):
    return torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)


criterion = torch.nn.CrossEntropyLoss()


class PruningToolsTestCase(unittest.TestCase):
    def test_data_collector(self):
        model = TorchModel()
        w1 = torch.rand(5, 1, 5, 5)
        w2 = torch.rand(10, 5, 5, 5)
        model.conv1.weight.data = w1
        model.conv2.weight.data = w2
        config_list = [{'op_types': ['Conv2d']}]
        pruner = Pruner(model, config_list)

        # Test WeightDataCollector
        data_collector = WeightDataCollector(pruner)
        data = data_collector.collect()
        assert all(torch.equal(get_module_by_name(model, module_name)[1].module.weight.data, data[module_name]) for module_name in ['conv1', 'conv2'])

        # Test WeightTrainerBasedDataCollector
        def opt_after():
            model.conv1.module.weight.data = torch.ones(5, 1, 5, 5)
            model.conv2.module.weight.data = torch.ones(10, 5, 5, 5)

        data_collector = WeightTrainerBasedDataCollector(pruner, trainer, get_optimizer(model), criterion, 1, opt_after_tasks=[opt_after])
        data = data_collector.collect()
        assert all(torch.equal(get_module_by_name(model, module_name)[1].module.weight.data, data[module_name]) for module_name in ['conv1', 'conv2'])
        assert all(t.numel() == (t == 1).type_as(t).sum().item() for t in data.values())

        # Test SingleHookTrainerBasedDataCollector
        def _collector(buffer, weight_tensor):
            def collect_taylor(grad):
                if len(buffer) < 2:
                    buffer.append(grad.clone().detach())
            return collect_taylor

        hook_targets = {'conv1': model.conv1.module.weight, 'conv2': model.conv2.module.weight}
        collector_info = HookCollectorInfo(hook_targets, 'tensor', _collector)
        data_collector = SingleHookTrainerBasedDataCollector(pruner, trainer, get_optimizer(model), criterion, 2, collector_infos=[collector_info])
        data = data_collector.collect()
        assert all(len(t) == 2 for t in data.values())

    def test_metrics_calculator(self):
        # Test NormMetricsCalculator
        metrics_calculator = NormMetricsCalculator(dim=0, p=2)
        data = {
            '1': torch.ones(3, 3, 3),
            '2': torch.ones(4, 4) * 2
        }
        result = {
            '1': torch.ones(3) * 3,
            '2': torch.ones(4) * 4
        }
        metrics = metrics_calculator.calculate_metrics(data)
        assert all(torch.equal(result[k], v) for k, v in metrics.items())

        # Test DistMetricsCalculator
        metrics_calculator = DistMetricsCalculator(dim=0, p=2)
        data = {
            '1': torch.tensor([[1, 2], [4, 6]], dtype=torch.float32),
            '2': torch.tensor([[0, 0], [1, 1]], dtype=torch.float32)
        }
        result = {
            '1': torch.tensor([5, 5], dtype=torch.float32),
            '2': torch.sqrt(torch.tensor([2, 2], dtype=torch.float32))
        }
        metrics = metrics_calculator.calculate_metrics(data)
        assert all(torch.equal(result[k], v) for k, v in metrics.items())

        # Test MultiDataNormMetricsCalculator
        metrics_calculator = MultiDataNormMetricsCalculator(dim=0, p=1)
        data = {
            '1': [torch.ones(3, 3, 3), torch.ones(3, 3, 3) * 2],
            '2': [torch.ones(4, 4), torch.ones(4, 4) * 2]
        }
        result = {
            '1': torch.ones(3) * 27,
            '2': torch.ones(4) * 12
        }
        metrics = metrics_calculator.calculate_metrics(data)
        assert all(torch.equal(result[k], v) for k, v in metrics.items())

        # Test APoZRankMetricsCalculator
        metrics_calculator = APoZRankMetricsCalculator(dim=1)
        data = {
            '1': [torch.tensor([[1, 0], [0, 1]], dtype=torch.float32), torch.tensor([[0, 1], [1, 0]], dtype=torch.float32)],
            '2': [torch.tensor([[1, 0, 1], [0, 1, 0]], dtype=torch.float32), torch.tensor([[0, 0, 1], [0, 0, 0]], dtype=torch.float32)]
        }
        result = {
            '1': torch.tensor([0.5, 0.5], dtype=torch.float32),
            '2': torch.tensor([0.25, 0.25, 0.5], dtype=torch.float32)
        }
        metrics = metrics_calculator.calculate_metrics(data)
        assert all(torch.equal(result[k], v) for k, v in metrics.items())

        # Test MeanRankMetricsCalculator
        metrics_calculator = MeanRankMetricsCalculator(dim=1)
        data = {
            '1': [torch.tensor([[1, 0], [0, 1]], dtype=torch.float32), torch.tensor([[0, 1], [1, 0]], dtype=torch.float32)],
            '2': [torch.tensor([[1, 0, 1], [0, 1, 0]], dtype=torch.float32), torch.tensor([[0, 0, 1], [0, 0, 0]], dtype=torch.float32)]
        }
        result = {
            '1': torch.tensor([0.5, 0.5], dtype=torch.float32),
            '2': torch.tensor([0.25, 0.25, 0.5], dtype=torch.float32)
        }
        metrics = metrics_calculator.calculate_metrics(data)
        assert all(torch.equal(result[k], v) for k, v in metrics.items())

    def test_sparsity_allocator(self):
        # Test NormalSparsityAllocator
        model = TorchModel()
        config_list = [{'op_types': ['Conv2d'], 'total_sparsity': 0.8}]
        pruner = Pruner(model, config_list)
        metrics = {
            'conv1': torch.rand(5, 1, 5, 5),
            'conv2': torch.rand(10, 5, 5, 5)
        }
        sparsity_allocator = NormalSparsityAllocator(pruner)
        masks = sparsity_allocator.generate_sparsity(metrics)
        assert all(v['weight'].sum() / v['weight'].numel() == 0.2 for k, v in masks.items())

        # Test GlobalSparsityAllocator
        model = TorchModel()
        config_list = [{'op_types': ['Conv2d'], 'total_sparsity': 0.8}]
        pruner = Pruner(model, config_list)
        sparsity_allocator = GlobalSparsityAllocator(pruner)
        masks = sparsity_allocator.generate_sparsity(metrics)
        total_elements, total_masked_elements = 0, 0
        for t in masks.values():
            total_elements += t['weight'].numel()
            total_masked_elements += t['weight'].sum().item()
        assert total_masked_elements / total_elements == 0.2


if __name__ == '__main__':
    unittest.main()
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import unittest

import torch
import torch.nn.functional as F

from nni.algorithms.compression.v2.pytorch.pruning import PruningScheduler, L1NormPruner, AGPTaskGenerator


class TorchModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(1, 5, 5, 1)
        self.bn1 = torch.nn.BatchNorm2d(5)
        self.conv2 = torch.nn.Conv2d(5, 10, 5, 1)
        self.bn2 = torch.nn.BatchNorm2d(10)
        self.fc1 = torch.nn.Linear(4 * 4 * 10, 100)
        self.fc2 = torch.nn.Linear(100, 10)

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4 * 4 * 10)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


class PruningSchedulerTestCase(unittest.TestCase):
    def test_pruning_scheduler(self):
        model = TorchModel()
        config_list = [{'op_types': ['Conv2d'], 'sparsity': 0.8}]
        task_generator = AGPTaskGenerator(1, model, config_list)
        pruner = L1NormPruner(model, config_list)
        scheduler = PruningScheduler(pruner, task_generator)
        scheduler.compress()


if __name__ == '__main__':
    unittest.main()
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from typing import List

import unittest

import torch
import torch.nn.functional as F

from nni.algorithms.compression.v2.pytorch.base import Task, TaskResult
from nni.algorithms.compression.v2.pytorch.pruning import (
    AGPTaskGenerator,
    LinearTaskGenerator,
    LotteryTicketTaskGenerator,
    SimulatedAnnealingTaskGenerator
)


class TorchModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(1, 5, 5, 1)
        self.bn1 = torch.nn.BatchNorm2d(5)
        self.conv2 = torch.nn.Conv2d(5, 10, 5, 1)
        self.bn2 = torch.nn.BatchNorm2d(10)
        self.fc1 = torch.nn.Linear(4 * 4 * 10, 100)
        self.fc2 = torch.nn.Linear(100, 10)

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4 * 4 * 10)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


def run_task_generator(task_generator_type):
    model = TorchModel()
    config_list = [{'op_types': ['Conv2d'], 'sparsity': 0.8}]
    if task_generator_type == 'agp':
        task_generator = AGPTaskGenerator(5, model, config_list)
    elif task_generator_type == 'linear':
        task_generator = LinearTaskGenerator(5, model, config_list)
    elif task_generator_type == 'lottery_ticket':
        task_generator = LotteryTicketTaskGenerator(5, model, config_list)
    elif task_generator_type == 'simulated_annealing':
        task_generator = SimulatedAnnealingTaskGenerator(model, config_list)

    count = run_task_generator_(task_generator)

    if task_generator_type == 'agp':
        assert count == 6
    elif task_generator_type == 'linear':
        assert count == 6
    elif task_generator_type == 'lottery_ticket':
        assert count == 6
    elif task_generator_type == 'simulated_annealing':
        assert count == 17


def run_task_generator_(task_generator):
    task = task_generator.next()
    factor = 0.9
    count = 0
    while task is not None:
        factor = factor ** 2
        count += 1
        task_result = TaskResult(task.task_id, TorchModel(), {}, {}, 1 - factor)
        task_generator.receive_task_result(task_result)
        task = task_generator.next()
    return count


class TaskGenerator(unittest.TestCase):
    def test_agp_task_generator(self):
        run_task_generator('agp')

    def test_linear_task_generator(self):
        run_task_generator('linear')

    def test_lottery_ticket_task_generator(self):
        run_task_generator('lottery_ticket')

    def test_simulated_annealing_task_generator(self):
        run_task_generator('simulated_annealing')


if __name__ == '__main__':
    unittest.main()