Unverified Commit cdb65dac authored by J-shang, committed by GitHub

[Model Compression] add scheduler high level api (#4236)

parent abb4dfdb
from tqdm import tqdm

import torch
from torchvision import datasets, transforms

from nni.algorithms.compression.v2.pytorch.pruning import AGPPruner
from examples.model_compress.models.cifar10.vgg import VGG

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))

train_loader = torch.utils.data.DataLoader(
    datasets.CIFAR10('./data', train=True, transform=transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(32, 4),
        transforms.ToTensor(),
        normalize,
    ]), download=True),
    batch_size=128, shuffle=True)

test_loader = torch.utils.data.DataLoader(
    datasets.CIFAR10('./data', train=False, transform=transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])),
    batch_size=128, shuffle=False)

criterion = torch.nn.CrossEntropyLoss()


def trainer(model, optimizer, criterion, epoch):
    model.train()
    for data, target in tqdm(iterable=train_loader, desc='Epoch {}'.format(epoch)):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()


def finetuner(model):
    model.train()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
    criterion = torch.nn.CrossEntropyLoss()
    for data, target in tqdm(iterable=train_loader, desc='Epoch PFs'):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()


def evaluator(model):
    model.eval()
    correct = 0
    with torch.no_grad():
        for data, target in tqdm(iterable=test_loader, desc='Test'):
            data, target = data.to(device), target.to(device)
            output = model(data)
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
    acc = 100 * correct / len(test_loader.dataset)
    print('Accuracy: {}%\n'.format(acc))
    return acc


if __name__ == '__main__':
    model = VGG().to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
    criterion = torch.nn.CrossEntropyLoss()

    # pre-train the model
    for i in range(5):
        trainer(model, optimizer, criterion, i)

    config_list = [{'op_types': ['Conv2d'], 'sparsity': 0.8}]
    dummy_input = torch.rand(10, 3, 32, 32).to(device)

    # If you just want to keep the final result as the best result, pass evaluator as None;
    # otherwise, the result with the highest score (given by evaluator) is taken as the best result.
    # pruner = AGPPruner(model, config_list, 'l1', 10, finetuner=finetuner, speed_up=True, dummy_input=dummy_input, evaluator=evaluator)
    pruner = AGPPruner(model, config_list, 'l1', 10, finetuner=finetuner, speed_up=True, dummy_input=dummy_input, evaluator=None)
    pruner.compress()
    _, model, masks, _, _ = pruner.get_best_result()
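The tuple returned by `get_best_result()` unpacks to the best task id, the compact model, the masks on the compact model, the score, and the config list used in that task (see `TaskGenerator.get_best_result` further down in this commit). A short sketch of consuming the full tuple, assuming the example above has run:

    task_id, best_model, masks, score, best_config_list = pruner.get_best_result()
    print('Best task id: {}, score: {}'.format(task_id, score))  # score may be None when no evaluator was given
    evaluator(best_model)  # re-check accuracy of the compact model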
import functools
from tqdm import tqdm
import torch
@@ -77,20 +76,13 @@ if __name__ == '__main__':
     for i in range(5):
         trainer(model, optimizer, criterion, i)

-    config_list = [{'op_types': ['Conv2d'], 'sparsity': 0.8}]
-    # Make sure to initialize the task generator first, because the model passed to the generator should be an unwrapped model.
-    # If you want to initialize the pruner first, you can use the following code.
-    # pruner = L1NormPruner(model, config_list)
-    # pruner._unwrap_model()
-    # task_generator = AGPTaskGenerator(10, model, config_list, log_dir='.', keep_intermediate_result=True)
-    # pruner._wrap_model()
+    # No need to pass model and config_list to the pruner at initialization when using a scheduler.
+    pruner = L1NormPruner(None, None)
+    # You can specify log_dir; all intermediate results and the best result are saved under this folder.
+    # If you don't want to keep intermediate results, set `keep_intermediate_result=False`.
+    config_list = [{'op_types': ['Conv2d'], 'sparsity': 0.8}]
     task_generator = AGPTaskGenerator(10, model, config_list, log_dir='.', keep_intermediate_result=True)
-    pruner = L1NormPruner(model, config_list)
     dummy_input = torch.rand(10, 3, 32, 32).to(device)
......
@@ -77,6 +77,8 @@ if __name__ == '__main__':
     print('\nThe accuracy after speed up:')
     evaluator(model)

+    # Need a new optimizer because the modules in the model are replaced during speedup.
+    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
     print('\nFinetune the model after speed up:')
     for i in range(5):
         trainer(model, optimizer, criterion, i)
......
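Why the optimizer must be recreated: speedup replaces pruned modules with smaller ones, so an optimizer built earlier still holds references to the old, discarded parameter tensors and its updates would no longer touch the live model. A minimal, NNI-independent sketch of the pitfall:

    import torch

    net = torch.nn.Linear(4, 4)
    opt = torch.optim.SGD(net.parameters(), lr=0.1)
    net = torch.nn.Linear(2, 2)  # stand-in for a module replaced during speedup
    # `opt` still points at the old Linear(4, 4) parameters, so stepping it would
    # silently update the wrong tensors; rebuild it against the new parameters.
    opt = torch.optim.SGD(net.parameters(), lr=0.1)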
@@ -37,7 +37,7 @@ class Compressor:
    The abstract base pytorch compressor.
    """

-    def __init__(self, model: Module, config_list: List[Dict]):
+    def __init__(self, model: Optional[Module], config_list: Optional[List[Dict]]):
        """
        Parameters
        ----------
@@ -46,9 +46,11 @@ class Compressor:
        config_list
            The config list used by the compressor, usually specifying the 'op_types' or 'op_names' to compress.
        """
-        assert isinstance(model, Module)
        self.is_wrapped = False
-        self.reset(model=model, config_list=config_list)
+        if model is not None:
+            self.reset(model=model, config_list=config_list)
+        else:
+            _logger.warning('Model and config_list are not set for this compressor; waiting for reset() or for a scheduler to pass them in.')

    def reset(self, model: Module, config_list: List[Dict]):
        """
......
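With `model` and `config_list` now optional, a compressor can be constructed empty and given a model later. A minimal sketch of the two flows this change enables, assuming `model`, `config_list`, and a `task_generator` are defined as in the scheduler example above:

    # Flow 1: construct empty, then bind a model explicitly via reset().
    pruner = L1NormPruner(None, None)
    pruner.reset(model=model, config_list=config_list)

    # Flow 2: construct empty and hand the pruner to a scheduler, which
    # supplies the model and config list for each pruning task.
    pruner = L1NormPruner(None, None)
    scheduler = PruningScheduler(pruner, task_generator)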
from .basic_pruner import *
from .basic_scheduler import PruningScheduler
-from .tools import AGPTaskGenerator, LinearTaskGenerator, LotteryTicketTaskGenerator, SimulatedAnnealingTaskGenerator
+from .iterative_pruner import *
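After this change, the high-level iterative pruners are exported from the pruning package itself, while the task generators are imported from `tools` (as the updated unit tests at the bottom of this commit do):

    from nni.algorithms.compression.v2.pytorch.pruning import LinearPruner, AGPPruner
    from nni.algorithms.compression.v2.pytorch.pruning.tools import AGPTaskGenerator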
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import logging
from typing import Dict, List, Callable, Optional

from torch import Tensor
from torch.nn import Module

from .basic_pruner import (
    LevelPruner,
    L1NormPruner,
    L2NormPruner,
    FPGMPruner,
    SlimPruner,
    ActivationAPoZRankPruner,
    ActivationMeanRankPruner,
    TaylorFOWeightPruner,
    ADMMPruner
)
from .basic_scheduler import PruningScheduler
from .tools.task_generator import (
    LinearTaskGenerator,
    AGPTaskGenerator,
    LotteryTicketTaskGenerator,
    SimulatedAnnealingTaskGenerator
)

_logger = logging.getLogger(__name__)

__all__ = ['LinearPruner', 'AGPPruner', 'LotteryTicketPruner', 'SimulatedAnnealingPruner']

PRUNER_DICT = {
    'level': LevelPruner,
    'l1': L1NormPruner,
    'l2': L2NormPruner,
    'fpgm': FPGMPruner,
    'slim': SlimPruner,
    'apoz': ActivationAPoZRankPruner,
    'mean_activation': ActivationMeanRankPruner,
    'taylorfo': TaylorFOWeightPruner,
    'admm': ADMMPruner
}
class IterativePruner(PruningScheduler):
    def _wrap_model(self):
        """
        Deprecated function.
        """
        _logger.warning('Nothing will happen when calling this function. '
                        'This pruner is an iterative pruner and does not directly wrap the model.')

    def _unwrap_model(self):
        """
        Deprecated function.
        """
        _logger.warning('Nothing will happen when calling this function. '
                        'This pruner is an iterative pruner and does not directly wrap the model.')

    def export_model(self, *args, **kwargs):
        """
        Deprecated function.
        """
        _logger.warning('Nothing will happen when calling this function. '
                        'The best result (and intermediate results if kept) during iteration is saved under `log_dir` (default: `.`).')
class LinearPruner(IterativePruner):
    """
    Parameters
    ----------
    model : Module
        The origin unwrapped pytorch model to be pruned.
    config_list : List[Dict]
        The origin config list provided by the user. Note that this config_list directly configures the origin model.
        This means the sparsity provided by the origin_masks should also be recorded in the origin_config_list.
    pruning_algorithm : str
        Supported pruning algorithms: ['level', 'l1', 'l2', 'fpgm', 'slim', 'apoz', 'mean_activation', 'taylorfo', 'admm'].
        This iterative pruner will use the corresponding pruner to prune the model in each iteration.
    total_iteration : int
        The total number of iterations.
    log_dir : str
        The log directory used to save the results; you can find the best result under this folder.
    keep_intermediate_result : bool
        If True, keep the intermediate results, including the intermediate model and masks of each iteration.
    finetuner : Optional[Callable[[Module], None]]
        The finetuner handles all finetuning logic; it takes a pytorch module as input and is called in each iteration.
    speed_up : bool
        If True, speed up the model in each iteration.
    dummy_input : Optional[torch.Tensor]
        If `speed_up` is True, `dummy_input` is required to trace the model during speedup.
    evaluator : Optional[Callable[[Module], float]]
        Evaluate the pruned model and give a score.
        If evaluator is None, the best result refers to the latest result.
    pruning_params : dict
        If the pruner corresponding to the chosen pruning_algorithm has extra parameters, pass them in as a dict.
    """

    def __init__(self, model: Module, config_list: List[Dict], pruning_algorithm: str,
                 total_iteration: int, log_dir: str = '.', keep_intermediate_result: bool = False,
                 finetuner: Optional[Callable[[Module], None]] = None, speed_up: bool = False, dummy_input: Optional[Tensor] = None,
                 evaluator: Optional[Callable[[Module], float]] = None, pruning_params: dict = {}):
        task_generator = LinearTaskGenerator(total_iteration=total_iteration,
                                             origin_model=model,
                                             origin_config_list=config_list,
                                             log_dir=log_dir,
                                             keep_intermediate_result=keep_intermediate_result)
        pruner = PRUNER_DICT[pruning_algorithm](None, None, **pruning_params)
        super().__init__(pruner, task_generator, finetuner=finetuner, speed_up=speed_up, dummy_input=dummy_input,
                         evaluator=evaluator, reset_weight=False)
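# Usage sketch for LinearPruner (mirrors the unit test included in this commit;
# assumes `model` and `config_list` are already defined):
#
#     pruner = LinearPruner(model, config_list, 'level', total_iteration=3, log_dir='.')
#     pruner.compress()
#     _, pruned_model, masks, _, _ = pruner.get_best_result()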
class AGPPruner(IterativePruner):
    """
    Parameters
    ----------
    model : Module
        The origin unwrapped pytorch model to be pruned.
    config_list : List[Dict]
        The origin config list provided by the user. Note that this config_list directly configures the origin model.
        This means the sparsity provided by the origin_masks should also be recorded in the origin_config_list.
    pruning_algorithm : str
        Supported pruning algorithms: ['level', 'l1', 'l2', 'fpgm', 'slim', 'apoz', 'mean_activation', 'taylorfo', 'admm'].
        This iterative pruner will use the corresponding pruner to prune the model in each iteration.
    total_iteration : int
        The total number of iterations.
    log_dir : str
        The log directory used to save the results; you can find the best result under this folder.
    keep_intermediate_result : bool
        If True, keep the intermediate results, including the intermediate model and masks of each iteration.
    finetuner : Optional[Callable[[Module], None]]
        The finetuner handles all finetuning logic; it takes a pytorch module as input and is called in each iteration.
    speed_up : bool
        If True, speed up the model in each iteration.
    dummy_input : Optional[torch.Tensor]
        If `speed_up` is True, `dummy_input` is required to trace the model during speedup.
    evaluator : Optional[Callable[[Module], float]]
        Evaluate the pruned model and give a score.
        If evaluator is None, the best result refers to the latest result.
    pruning_params : dict
        If the pruner corresponding to the chosen pruning_algorithm has extra parameters, pass them in as a dict.
    """

    def __init__(self, model: Module, config_list: List[Dict], pruning_algorithm: str,
                 total_iteration: int, log_dir: str = '.', keep_intermediate_result: bool = False,
                 finetuner: Optional[Callable[[Module], None]] = None, speed_up: bool = False, dummy_input: Optional[Tensor] = None,
                 evaluator: Optional[Callable[[Module], float]] = None, pruning_params: dict = {}):
        task_generator = AGPTaskGenerator(total_iteration=total_iteration,
                                          origin_model=model,
                                          origin_config_list=config_list,
                                          log_dir=log_dir,
                                          keep_intermediate_result=keep_intermediate_result)
        pruner = PRUNER_DICT[pruning_algorithm](None, None, **pruning_params)
        super().__init__(pruner, task_generator, finetuner=finetuner, speed_up=speed_up, dummy_input=dummy_input,
                         evaluator=evaluator, reset_weight=False)
class LotteryTicketPruner(IterativePruner):
    """
    Parameters
    ----------
    model : Module
        The origin unwrapped pytorch model to be pruned.
    config_list : List[Dict]
        The origin config list provided by the user. Note that this config_list directly configures the origin model.
        This means the sparsity provided by the origin_masks should also be recorded in the origin_config_list.
    pruning_algorithm : str
        Supported pruning algorithms: ['level', 'l1', 'l2', 'fpgm', 'slim', 'apoz', 'mean_activation', 'taylorfo', 'admm'].
        This iterative pruner will use the corresponding pruner to prune the model in each iteration.
    total_iteration : int
        The total number of iterations.
    log_dir : str
        The log directory used to save the results; you can find the best result under this folder.
    keep_intermediate_result : bool
        If True, keep the intermediate results, including the intermediate model and masks of each iteration.
    finetuner : Optional[Callable[[Module], None]]
        The finetuner handles all finetuning logic; it takes a pytorch module as input and is called in each iteration.
    speed_up : bool
        If True, speed up the model in each iteration.
    dummy_input : Optional[torch.Tensor]
        If `speed_up` is True, `dummy_input` is required to trace the model during speedup.
    evaluator : Optional[Callable[[Module], float]]
        Evaluate the pruned model and give a score.
        If evaluator is None, the best result refers to the latest result.
    reset_weight : bool
        If True, the model weights are reset to the original weights at the end of each iteration step.
    pruning_params : dict
        If the pruner corresponding to the chosen pruning_algorithm has extra parameters, pass them in as a dict.
    """

    def __init__(self, model: Module, config_list: List[Dict], pruning_algorithm: str,
                 total_iteration: int, log_dir: str = '.', keep_intermediate_result: bool = False,
                 finetuner: Optional[Callable[[Module], None]] = None, speed_up: bool = False, dummy_input: Optional[Tensor] = None,
                 evaluator: Optional[Callable[[Module], float]] = None, reset_weight: bool = True,
                 pruning_params: dict = {}):
        task_generator = LotteryTicketTaskGenerator(total_iteration=total_iteration,
                                                    origin_model=model,
                                                    origin_config_list=config_list,
                                                    log_dir=log_dir,
                                                    keep_intermediate_result=keep_intermediate_result)
        pruner = PRUNER_DICT[pruning_algorithm](None, None, **pruning_params)
        super().__init__(pruner, task_generator, finetuner=finetuner, speed_up=speed_up, dummy_input=dummy_input,
                         evaluator=evaluator, reset_weight=reset_weight)
class SimulatedAnnealingPruner(IterativePruner):
    """
    Parameters
    ----------
    model : Module
        The origin unwrapped pytorch model to be pruned.
    config_list : List[Dict]
        The origin config list provided by the user. Note that this config_list directly configures the origin model.
        This means the sparsity provided by the origin_masks should also be recorded in the origin_config_list.
    pruning_algorithm : str
        Supported pruning algorithms: ['level', 'l1', 'l2', 'fpgm', 'slim', 'apoz', 'mean_activation', 'taylorfo', 'admm'].
        This iterative pruner will use the corresponding pruner to prune the model in each iteration.
    evaluator : Callable[[Module], float]
        Evaluate the pruned model and give a score.
    start_temperature : float
        Start temperature of the simulated annealing process.
    stop_temperature : float
        Stop temperature of the simulated annealing process.
    cool_down_rate : float
        Cool down rate of the temperature.
    perturbation_magnitude : float
        Initial perturbation magnitude applied to the sparsities. The magnitude decreases with the current temperature.
    log_dir : str
        The log directory used to save the results; you can find the best result under this folder.
    keep_intermediate_result : bool
        If True, keep the intermediate results, including the intermediate model and masks of each iteration.
    finetuner : Optional[Callable[[Module], None]]
        The finetuner handles all finetuning logic; it takes a pytorch module as input and is called in each iteration.
    speed_up : bool
        If True, speed up the model in each iteration.
    dummy_input : Optional[torch.Tensor]
        If `speed_up` is True, `dummy_input` is required to trace the model during speedup.
    pruning_params : dict
        If the pruner corresponding to the chosen pruning_algorithm has extra parameters, pass them in as a dict.
    """

    def __init__(self, model: Module, config_list: List[Dict], pruning_algorithm: str, evaluator: Callable[[Module], float],
                 start_temperature: float = 100, stop_temperature: float = 20, cool_down_rate: float = 0.9,
                 perturbation_magnitude: float = 0.35, log_dir: str = '.', keep_intermediate_result: bool = False,
                 finetuner: Optional[Callable[[Module], None]] = None, speed_up: bool = False, dummy_input: Optional[Tensor] = None,
                 pruning_params: dict = {}):
        task_generator = SimulatedAnnealingTaskGenerator(origin_model=model,
                                                         origin_config_list=config_list,
                                                         start_temperature=start_temperature,
                                                         stop_temperature=stop_temperature,
                                                         cool_down_rate=cool_down_rate,
                                                         perturbation_magnitude=perturbation_magnitude,
                                                         log_dir=log_dir,
                                                         keep_intermediate_result=keep_intermediate_result)
        pruner = PRUNER_DICT[pruning_algorithm](None, None, **pruning_params)
        super().__init__(pruner, task_generator, finetuner=finetuner, speed_up=speed_up, dummy_input=dummy_input,
                         evaluator=evaluator, reset_weight=False)
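# Usage sketch for SimulatedAnnealingPruner (mirrors the unit test included in
# this commit; note that, unlike the other iterative pruners, `evaluator` is a
# required argument here because it scores each candidate sparsity allocation):
#
#     pruner = SimulatedAnnealingPruner(model, config_list, 'level', evaluator, start_temperature=30)
#     pruner.compress()
#     _, pruned_model, masks, _, _ = pruner.get_best_result()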
@@ -566,9 +566,9 @@ class TaskGenerator:
        return best task id, best compact model, masks on the compact model, score, config list used in this task.
        """
        if self._best_task_id is not None:
-            compact_model = torch.load(Path(self._log_dir_root, 'best_result', 'best_model.pth'))
-            compact_model_masks = torch.load(Path(self._log_dir_root, 'best_result', 'best_masks.pth'))
-            with Path(self._log_dir_root, 'best_result', 'best_config_list.json').open('r') as f:
+            compact_model = torch.load(Path(self._log_dir_root, 'best_result', 'model.pth'))
+            compact_model_masks = torch.load(Path(self._log_dir_root, 'best_result', 'masks.pth'))
+            with Path(self._log_dir_root, 'best_result', 'config_list.json').open('r') as f:
                config_list = json_tricks.load(f)
            return self._best_task_id, compact_model, compact_model_masks, self._best_score, config_list
        return None
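Since the saved files now carry plain names, the best result can also be loaded by hand from the `best_result` folder under the scheduler's log root. A sketch, assuming `log_dir_root` points at that directory:

    from pathlib import Path
    import torch

    best_dir = Path(log_dir_root, 'best_result')
    compact_model = torch.load(best_dir / 'model.pth')
    compact_model_masks = torch.load(best_dir / 'masks.pth')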
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import random
import unittest

import torch
import torch.nn.functional as F

from nni.algorithms.compression.v2.pytorch.pruning import (
    LinearPruner,
    AGPPruner,
    LotteryTicketPruner,
    SimulatedAnnealingPruner
)
from nni.algorithms.compression.v2.pytorch.utils import compute_sparsity_mask2compact


class TorchModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(1, 5, 5, 1)
        self.bn1 = torch.nn.BatchNorm2d(5)
        self.conv2 = torch.nn.Conv2d(5, 10, 5, 1)
        self.bn2 = torch.nn.BatchNorm2d(10)
        self.fc1 = torch.nn.Linear(4 * 4 * 10, 100)
        self.fc2 = torch.nn.Linear(100, 10)

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4 * 4 * 10)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


def evaluator(model):
    return random.random()
class IterativePrunerTestCase(unittest.TestCase):
    def test_linear_pruner(self):
        model = TorchModel()
        config_list = [{'op_types': ['Conv2d'], 'sparsity': 0.8}]
        pruner = LinearPruner(model, config_list, 'level', 3, log_dir='../../logs')
        pruner.compress()
        _, pruned_model, masks, _, _ = pruner.get_best_result()
        sparsity_list = compute_sparsity_mask2compact(pruned_model, masks, config_list)
        assert 0.79 < sparsity_list[0]['total_sparsity'] < 0.81

    def test_agp_pruner(self):
        model = TorchModel()
        config_list = [{'op_types': ['Conv2d'], 'sparsity': 0.8}]
        pruner = AGPPruner(model, config_list, 'level', 3, log_dir='../../logs')
        pruner.compress()
        _, pruned_model, masks, _, _ = pruner.get_best_result()
        sparsity_list = compute_sparsity_mask2compact(pruned_model, masks, config_list)
        assert 0.79 < sparsity_list[0]['total_sparsity'] < 0.81

    def test_lottery_ticket_pruner(self):
        model = TorchModel()
        config_list = [{'op_types': ['Conv2d'], 'sparsity': 0.8}]
        pruner = LotteryTicketPruner(model, config_list, 'level', 3, log_dir='../../logs')
        pruner.compress()
        _, pruned_model, masks, _, _ = pruner.get_best_result()
        sparsity_list = compute_sparsity_mask2compact(pruned_model, masks, config_list)
        assert 0.79 < sparsity_list[0]['total_sparsity'] < 0.81

    def test_simulated_annealing_pruner(self):
        model = TorchModel()
        config_list = [{'op_types': ['Conv2d'], 'sparsity': 0.8}]
        pruner = SimulatedAnnealingPruner(model, config_list, 'level', evaluator, start_temperature=30, log_dir='../../logs')
        pruner.compress()
        _, pruned_model, masks, _, _ = pruner.get_best_result()
        sparsity_list = compute_sparsity_mask2compact(pruned_model, masks, config_list)
        assert 0.79 < sparsity_list[0]['total_sparsity'] < 0.81


if __name__ == '__main__':
    unittest.main()
@@ -6,7 +6,8 @@ import unittest
import torch
import torch.nn.functional as F

-from nni.algorithms.compression.v2.pytorch.pruning import PruningScheduler, L1NormPruner, AGPTaskGenerator
+from nni.algorithms.compression.v2.pytorch.pruning import PruningScheduler, L1NormPruner
+from nni.algorithms.compression.v2.pytorch.pruning.tools import AGPTaskGenerator


class TorchModel(torch.nn.Module):
......
@@ -8,7 +8,7 @@ import torch
import torch.nn.functional as F

from nni.algorithms.compression.v2.pytorch.base import Task, TaskResult
-from nni.algorithms.compression.v2.pytorch.pruning import (
+from nni.algorithms.compression.v2.pytorch.pruning.tools import (
    AGPTaskGenerator,
    LinearTaskGenerator,
    LotteryTicketTaskGenerator,
......