Unverified commit cdb65dac authored by J-shang, committed by GitHub

[Model Compression] add scheduler high level api (#4236)

parent abb4dfdb
from tqdm import tqdm
import torch
from torchvision import datasets, transforms
from nni.algorithms.compression.v2.pytorch.pruning import AGPPruner
from examples.model_compress.models.cifar10.vgg import VGG
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
train_loader = torch.utils.data.DataLoader(
    datasets.CIFAR10('./data', train=True, transform=transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(32, 4),
        transforms.ToTensor(),
        normalize,
    ]), download=True),
    batch_size=128, shuffle=True)
test_loader = torch.utils.data.DataLoader(
    datasets.CIFAR10('./data', train=False, transform=transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])),
    batch_size=128, shuffle=False)
criterion = torch.nn.CrossEntropyLoss()
def trainer(model, optimizer, criterion, epoch):
    model.train()
    for data, target in tqdm(iterable=train_loader, desc='Epoch {}'.format(epoch)):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
def finetuner(model):
    model.train()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
    criterion = torch.nn.CrossEntropyLoss()
    for data, target in tqdm(iterable=train_loader, desc='Epoch PFs'):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
def evaluator(model):
    model.eval()
    correct = 0
    with torch.no_grad():
        for data, target in tqdm(iterable=test_loader, desc='Test'):
            data, target = data.to(device), target.to(device)
            output = model(data)
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
    acc = 100 * correct / len(test_loader.dataset)
    print('Accuracy: {}%\n'.format(acc))
    return acc
if __name__ == '__main__':
    model = VGG().to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
    criterion = torch.nn.CrossEntropyLoss()

    # pre-train the model
    for i in range(5):
        trainer(model, optimizer, criterion, i)

    config_list = [{'op_types': ['Conv2d'], 'sparsity': 0.8}]
    dummy_input = torch.rand(10, 3, 32, 32).to(device)

    # If you just want to keep the final result as the best result, you can pass None as evaluator.
    # Otherwise, the result with the highest score (given by evaluator) will be the best result.
    # pruner = AGPPruner(model, config_list, 'l1', 10, finetuner=finetuner, speed_up=True, dummy_input=dummy_input, evaluator=evaluator)
    pruner = AGPPruner(model, config_list, 'l1', 10, finetuner=finetuner, speed_up=True, dummy_input=dummy_input, evaluator=None)
    pruner.compress()
    _, model, masks, _, _ = pruner.get_best_result()
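As documented in TaskGenerator below, get_best_result() returns (best task id, compact model, masks on the compact model, score, config list). A minimal follow-up sketch, reusing only names already defined in this example, to sanity-check the returned compact model:

    # `model` is now the best compact model; `evaluator` is the test-set accuracy function above.
    print('Accuracy of the best compact model:')
    evaluator(model.to(device))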
+import functools
 from tqdm import tqdm
 import torch
@@ -77,20 +76,13 @@ if __name__ == '__main__':
     for i in range(5):
         trainer(model, optimizer, criterion, i)
-    config_list = [{'op_types': ['Conv2d'], 'sparsity': 0.8}]
+    # No need to pass model and config_list to the pruner during initialization when using a scheduler.
+    pruner = L1NormPruner(None, None)
+    # Make sure to initialize the task generator first, because the model passed to the generator should be an unwrapped model.
+    # If you want to initialize the pruner first, you can use the following code:
+    # pruner = L1NormPruner(model, config_list)
+    # pruner._unwrap_model()
+    # task_generator = AGPTaskGenerator(10, model, config_list, log_dir='.', keep_intermediate_result=True)
+    # pruner._wrap_model()
     # You can specify log_dir; all intermediate results and the best result will be saved under this folder.
     # If you don't want to keep intermediate results, you can set `keep_intermediate_result=False`.
+    config_list = [{'op_types': ['Conv2d'], 'sparsity': 0.8}]
     task_generator = AGPTaskGenerator(10, model, config_list, log_dir='.', keep_intermediate_result=True)
-    pruner = L1NormPruner(model, config_list)
     dummy_input = torch.rand(10, 3, 32, 32).to(device)
...
@@ -77,6 +77,8 @@ if __name__ == '__main__':
     print('\nThe accuracy after speed up:')
     evaluator(model)
+    # Need a new optimizer because the modules in the model are replaced during speedup.
+    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
     print('\nFinetune the model after speed up:')
     for i in range(5):
         trainer(model, optimizer, criterion, i)
...
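The elided part of this example wires the pruner and task generator into a PruningScheduler. A minimal sketch of that step, assuming the PruningScheduler signature used by the iterative pruners in iterative_pruner.py below, and that compress()/get_best_result() behave as in the AGPPruner example at the top:

    scheduler = PruningScheduler(pruner, task_generator, finetuner=finetuner, speed_up=True,
                                 dummy_input=dummy_input, evaluator=evaluator, reset_weight=False)
    scheduler.compress()
    _, model, masks, _, _ = scheduler.get_best_result()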
...@@ -37,7 +37,7 @@ class Compressor: ...@@ -37,7 +37,7 @@ class Compressor:
The abstract base pytorch compressor. The abstract base pytorch compressor.
""" """
def __init__(self, model: Module, config_list: List[Dict]): def __init__(self, model: Optional[Module], config_list: Optional[List[Dict]]):
""" """
Parameters Parameters
---------- ----------
...@@ -46,9 +46,11 @@ class Compressor: ...@@ -46,9 +46,11 @@ class Compressor:
config_list config_list
The config list used by compressor, usually specifies the 'op_types' or 'op_names' that want to compress. The config list used by compressor, usually specifies the 'op_types' or 'op_names' that want to compress.
""" """
assert isinstance(model, Module)
self.is_wrapped = False self.is_wrapped = False
self.reset(model=model, config_list=config_list) if model is not None:
self.reset(model=model, config_list=config_list)
else:
_logger.warning('This compressor is not set model and config_list, waiting for reset() or pass this to scheduler.')
def reset(self, model: Module, config_list: List[Dict]): def reset(self, model: Module, config_list: List[Dict]):
""" """
......
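With this change a basic pruner can be constructed unbound and attached to a model later, which is what the scheduler-based pruners below rely on. A small sketch of the two paths, using only the reset() signature shown above ('model' and 'config_list' are placeholders):

    pruner = L1NormPruner(None, None)  # unbound; logs the warning above
    pruner.reset(model=model, config_list=config_list)  # now equivalent to L1NormPruner(model, config_list)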
 from .basic_pruner import *
 from .basic_scheduler import PruningScheduler
-from .tools import AGPTaskGenerator, LinearTaskGenerator, LotteryTicketTaskGenerator, SimulatedAnnealingTaskGenerator
+from .iterative_pruner import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
from typing import Dict, List, Callable, Optional
from torch import Tensor
from torch.nn import Module
from .basic_pruner import (
    LevelPruner,
    L1NormPruner,
    L2NormPruner,
    FPGMPruner,
    SlimPruner,
    ActivationAPoZRankPruner,
    ActivationMeanRankPruner,
    TaylorFOWeightPruner,
    ADMMPruner
)
from .basic_scheduler import PruningScheduler
from .tools.task_generator import (
    LinearTaskGenerator,
    AGPTaskGenerator,
    LotteryTicketTaskGenerator,
    SimulatedAnnealingTaskGenerator
)
_logger = logging.getLogger(__name__)
__all__ = ['LinearPruner', 'AGPPruner', 'LotteryTicketPruner', 'SimulatedAnnealingPruner']
PRUNER_DICT = {
    'level': LevelPruner,
    'l1': L1NormPruner,
    'l2': L2NormPruner,
    'fpgm': FPGMPruner,
    'slim': SlimPruner,
    'apoz': ActivationAPoZRankPruner,
    'mean_activation': ActivationMeanRankPruner,
    'taylorfo': TaylorFOWeightPruner,
    'admm': ADMMPruner
}
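PRUNER_DICT is the dispatch table that maps the pruning_algorithm string accepted by the classes below to a concrete basic pruner; each pruner is created unbound with (None, None) and bound to a model later by the scheduler. For example:

    # 'l1' resolves to L1NormPruner, constructed without a model, exactly as in the __init__ methods below.
    pruner = PRUNER_DICT['l1'](None, None)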
class IterativePruner(PruningScheduler):
    def _wrap_model(self):
        """
        Deprecated function.
        """
        _logger.warning('Nothing will happen when calling this function. '
                        'This pruner is an iterative pruner and does not directly wrap the model.')

    def _unwrap_model(self):
        """
        Deprecated function.
        """
        _logger.warning('Nothing will happen when calling this function. '
                        'This pruner is an iterative pruner and does not directly wrap the model.')

    def export_model(self, *args, **kwargs):
        """
        Deprecated function.
        """
        _logger.warning('Nothing will happen when calling this function. '
                        'The best result (and intermediate results, if kept) during iteration is under `log_dir` (default: `.`).')
class LinearPruner(IterativePruner):
    """
    Parameters
    ----------
    model : Module
        The original unwrapped pytorch model to be pruned.
    config_list : List[Dict]
        The original config list provided by the user. Note that this config_list configures the original model directly.
        This means the sparsity provided by the origin_masks should also be recorded in the origin_config_list.
    pruning_algorithm : str
        Supported pruning algorithms: ['level', 'l1', 'l2', 'fpgm', 'slim', 'apoz', 'mean_activation', 'taylorfo', 'admm'].
        This iterative pruner will use the corresponding pruner to prune the model in each iteration.
    total_iteration : int
        The total number of iterations.
    log_dir : str
        The log directory used to save the results; you can find the best result under this folder.
    keep_intermediate_result : bool
        Whether to keep the intermediate results, including the intermediate model and masks of each iteration.
    finetuner : Optional[Callable[[Module], None]]
        The finetuner handles all finetuning logic; it takes a pytorch module as input and will be called in each iteration.
    speed_up : bool
        If set to True, speed up the model in each iteration.
    dummy_input : Optional[torch.Tensor]
        If `speed_up` is True, `dummy_input` is required to trace the model during speedup.
    evaluator : Optional[Callable[[Module], float]]
        Evaluate the pruned model and return a score.
        If evaluator is None, the best result refers to the latest result.
    pruning_params : dict
        If the pruner corresponding to the chosen pruning_algorithm has extra parameters, pass them in as a dict.
    """

    def __init__(self, model: Module, config_list: List[Dict], pruning_algorithm: str,
                 total_iteration: int, log_dir: str = '.', keep_intermediate_result: bool = False,
                 finetuner: Optional[Callable[[Module], None]] = None, speed_up: bool = False, dummy_input: Optional[Tensor] = None,
                 evaluator: Optional[Callable[[Module], float]] = None, pruning_params: dict = {}):
        task_generator = LinearTaskGenerator(total_iteration=total_iteration,
                                             origin_model=model,
                                             origin_config_list=config_list,
                                             log_dir=log_dir,
                                             keep_intermediate_result=keep_intermediate_result)
        pruner = PRUNER_DICT[pruning_algorithm](None, None, **pruning_params)
        super().__init__(pruner, task_generator, finetuner=finetuner, speed_up=speed_up, dummy_input=dummy_input,
                         evaluator=evaluator, reset_weight=False)
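A minimal usage sketch, an assumed workflow mirroring the AGPPruner example at the top of this commit (model, config_list, and finetuner are placeholders from that example): linearly increase sparsity toward the target over 10 iterations with L1 norm pruning.

    pruner = LinearPruner(model, config_list, 'l1', 10, log_dir='.', keep_intermediate_result=True,
                          finetuner=finetuner)
    pruner.compress()
    _, best_model, best_masks, _, _ = pruner.get_best_result()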
class AGPPruner(IterativePruner):
    """
    Parameters
    ----------
    model : Module
        The original unwrapped pytorch model to be pruned.
    config_list : List[Dict]
        The original config list provided by the user. Note that this config_list configures the original model directly.
        This means the sparsity provided by the origin_masks should also be recorded in the origin_config_list.
    pruning_algorithm : str
        Supported pruning algorithms: ['level', 'l1', 'l2', 'fpgm', 'slim', 'apoz', 'mean_activation', 'taylorfo', 'admm'].
        This iterative pruner will use the corresponding pruner to prune the model in each iteration.
    total_iteration : int
        The total number of iterations.
    log_dir : str
        The log directory used to save the results; you can find the best result under this folder.
    keep_intermediate_result : bool
        Whether to keep the intermediate results, including the intermediate model and masks of each iteration.
    finetuner : Optional[Callable[[Module], None]]
        The finetuner handles all finetuning logic; it takes a pytorch module as input and will be called in each iteration.
    speed_up : bool
        If set to True, speed up the model in each iteration.
    dummy_input : Optional[torch.Tensor]
        If `speed_up` is True, `dummy_input` is required to trace the model during speedup.
    evaluator : Optional[Callable[[Module], float]]
        Evaluate the pruned model and return a score.
        If evaluator is None, the best result refers to the latest result.
    pruning_params : dict
        If the pruner corresponding to the chosen pruning_algorithm has extra parameters, pass them in as a dict.
    """

    def __init__(self, model: Module, config_list: List[Dict], pruning_algorithm: str,
                 total_iteration: int, log_dir: str = '.', keep_intermediate_result: bool = False,
                 finetuner: Optional[Callable[[Module], None]] = None, speed_up: bool = False, dummy_input: Optional[Tensor] = None,
                 evaluator: Optional[Callable[[Module], float]] = None, pruning_params: dict = {}):
        task_generator = AGPTaskGenerator(total_iteration=total_iteration,
                                          origin_model=model,
                                          origin_config_list=config_list,
                                          log_dir=log_dir,
                                          keep_intermediate_result=keep_intermediate_result)
        pruner = PRUNER_DICT[pruning_algorithm](None, None, **pruning_params)
        super().__init__(pruner, task_generator, finetuner=finetuner, speed_up=speed_up, dummy_input=dummy_input,
                         evaluator=evaluator, reset_weight=False)
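For reference, AGP differs from linear pruning in how the per-iteration sparsity target grows: it follows the cubic schedule from Zhu & Gupta, "To prune, or not to prune" (2017). A hedged sketch of that schedule (the exact formula used by AGPTaskGenerator is not shown in this commit):

    def agp_sparsity(final_sparsity, current_iteration, total_iteration, initial_sparsity=0.0):
        # Cubic ramp: prune aggressively in early iterations, fine-grained pruning later.
        t = min(current_iteration / total_iteration, 1.0)
        return final_sparsity + (initial_sparsity - final_sparsity) * (1 - t) ** 3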
class LotteryTicketPruner(IterativePruner):
    """
    Parameters
    ----------
    model : Module
        The original unwrapped pytorch model to be pruned.
    config_list : List[Dict]
        The original config list provided by the user. Note that this config_list configures the original model directly.
        This means the sparsity provided by the origin_masks should also be recorded in the origin_config_list.
    pruning_algorithm : str
        Supported pruning algorithms: ['level', 'l1', 'l2', 'fpgm', 'slim', 'apoz', 'mean_activation', 'taylorfo', 'admm'].
        This iterative pruner will use the corresponding pruner to prune the model in each iteration.
    total_iteration : int
        The total number of iterations.
    log_dir : str
        The log directory used to save the results; you can find the best result under this folder.
    keep_intermediate_result : bool
        Whether to keep the intermediate results, including the intermediate model and masks of each iteration.
    finetuner : Optional[Callable[[Module], None]]
        The finetuner handles all finetuning logic; it takes a pytorch module as input and will be called in each iteration.
    speed_up : bool
        If set to True, speed up the model in each iteration.
    dummy_input : Optional[torch.Tensor]
        If `speed_up` is True, `dummy_input` is required to trace the model during speedup.
    evaluator : Optional[Callable[[Module], float]]
        Evaluate the pruned model and return a score.
        If evaluator is None, the best result refers to the latest result.
    reset_weight : bool
        If set to True, the model weights will be reset to the original model weights at the end of each iteration step.
    pruning_params : dict
        If the pruner corresponding to the chosen pruning_algorithm has extra parameters, pass them in as a dict.
    """

    def __init__(self, model: Module, config_list: List[Dict], pruning_algorithm: str,
                 total_iteration: int, log_dir: str = '.', keep_intermediate_result: bool = False,
                 finetuner: Optional[Callable[[Module], None]] = None, speed_up: bool = False, dummy_input: Optional[Tensor] = None,
                 evaluator: Optional[Callable[[Module], float]] = None, reset_weight: bool = True,
                 pruning_params: dict = {}):
        task_generator = LotteryTicketTaskGenerator(total_iteration=total_iteration,
                                                    origin_model=model,
                                                    origin_config_list=config_list,
                                                    log_dir=log_dir,
                                                    keep_intermediate_result=keep_intermediate_result)
        pruner = PRUNER_DICT[pruning_algorithm](None, None, **pruning_params)
        super().__init__(pruner, task_generator, finetuner=finetuner, speed_up=speed_up, dummy_input=dummy_input,
                         evaluator=evaluator, reset_weight=reset_weight)
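A minimal usage sketch highlighting reset_weight, which distinguishes this pruner: with the default reset_weight=True, the weights rewind to the original initialization at the end of each iteration, in the lottery-ticket style (names reused from the example at the top; this is an assumed workflow, not from the commit):

    pruner = LotteryTicketPruner(model, config_list, 'l1', 10, finetuner=finetuner, reset_weight=True)
    pruner.compress()
    _, best_model, best_masks, _, _ = pruner.get_best_result()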
class SimulatedAnnealingPruner(IterativePruner):
    """
    Parameters
    ----------
    model : Module
        The original unwrapped pytorch model to be pruned.
    config_list : List[Dict]
        The original config list provided by the user. Note that this config_list configures the original model directly.
        This means the sparsity provided by the origin_masks should also be recorded in the origin_config_list.
    pruning_algorithm : str
        Supported pruning algorithms: ['level', 'l1', 'l2', 'fpgm', 'slim', 'apoz', 'mean_activation', 'taylorfo', 'admm'].
        This iterative pruner will use the corresponding pruner to prune the model in each iteration.
    evaluator : Callable[[Module], float]
        Evaluate the pruned model and return a score.
    start_temperature : float
        Start temperature of the simulated annealing process.
    stop_temperature : float
        Stop temperature of the simulated annealing process.
    cool_down_rate : float
        Cool down rate of the temperature.
    perturbation_magnitude : float
        Initial perturbation magnitude to the sparsities. The magnitude decreases with the current temperature.
    log_dir : str
        The log directory used to save the results; you can find the best result under this folder.
    keep_intermediate_result : bool
        Whether to keep the intermediate results, including the intermediate model and masks of each iteration.
    finetuner : Optional[Callable[[Module], None]]
        The finetuner handles all finetuning logic; it takes a pytorch module as input and will be called in each iteration.
    speed_up : bool
        If set to True, speed up the model in each iteration.
    dummy_input : Optional[torch.Tensor]
        If `speed_up` is True, `dummy_input` is required to trace the model during speedup.
    pruning_params : dict
        If the pruner corresponding to the chosen pruning_algorithm has extra parameters, pass them in as a dict.
    """

    def __init__(self, model: Module, config_list: List[Dict], pruning_algorithm: str, evaluator: Callable[[Module], float],
                 start_temperature: float = 100, stop_temperature: float = 20, cool_down_rate: float = 0.9,
                 perturbation_magnitude: float = 0.35, log_dir: str = '.', keep_intermediate_result: bool = False,
                 finetuner: Optional[Callable[[Module], None]] = None, speed_up: bool = False, dummy_input: Optional[Tensor] = None,
                 pruning_params: dict = {}):
        task_generator = SimulatedAnnealingTaskGenerator(origin_model=model,
                                                         origin_config_list=config_list,
                                                         start_temperature=start_temperature,
                                                         stop_temperature=stop_temperature,
                                                         cool_down_rate=cool_down_rate,
                                                         perturbation_magnitude=perturbation_magnitude,
                                                         log_dir=log_dir,
                                                         keep_intermediate_result=keep_intermediate_result)
        pruner = PRUNER_DICT[pruning_algorithm](None, None, **pruning_params)
        super().__init__(pruner, task_generator, finetuner=finetuner, speed_up=speed_up, dummy_input=dummy_input,
                         evaluator=evaluator, reset_weight=False)
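Unlike the pruners above, this one searches over sparsity allocations, so the evaluator is a required positional argument rather than an optional keyword. A minimal sketch (names reused from the example at the top; hyperparameters are the defaults shown in the signature above):

    pruner = SimulatedAnnealingPruner(model, config_list, 'level', evaluator,
                                      start_temperature=100, stop_temperature=20, cool_down_rate=0.9)
    pruner.compress()
    _, best_model, best_masks, _, _ = pruner.get_best_result()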
@@ -566,9 +566,9 @@ class TaskGenerator:
         return best task id, best compact model, masks on the compact model, score, config list used in this task.
         """
         if self._best_task_id is not None:
-            compact_model = torch.load(Path(self._log_dir_root, 'best_result', 'best_model.pth'))
-            compact_model_masks = torch.load(Path(self._log_dir_root, 'best_result', 'best_masks.pth'))
-            with Path(self._log_dir_root, 'best_result', 'best_config_list.json').open('r') as f:
+            compact_model = torch.load(Path(self._log_dir_root, 'best_result', 'model.pth'))
+            compact_model_masks = torch.load(Path(self._log_dir_root, 'best_result', 'masks.pth'))
+            with Path(self._log_dir_root, 'best_result', 'config_list.json').open('r') as f:
                 config_list = json_tricks.load(f)
             return self._best_task_id, compact_model, compact_model_masks, self._best_score, config_list
         return None
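Per the docstring above, get_best_result() returns a 5-tuple, or None when no best task has been recorded yet; a small sketch of consuming it safely:

    result = task_generator.get_best_result()
    if result is not None:
        best_task_id, compact_model, compact_model_masks, score, config_list = result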
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import random
import unittest
import torch
import torch.nn.functional as F
from nni.algorithms.compression.v2.pytorch.pruning import (
    LinearPruner,
    AGPPruner,
    LotteryTicketPruner,
    SimulatedAnnealingPruner
)
from nni.algorithms.compression.v2.pytorch.utils import compute_sparsity_mask2compact
class TorchModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(1, 5, 5, 1)
        self.bn1 = torch.nn.BatchNorm2d(5)
        self.conv2 = torch.nn.Conv2d(5, 10, 5, 1)
        self.bn2 = torch.nn.BatchNorm2d(10)
        self.fc1 = torch.nn.Linear(4 * 4 * 10, 100)
        self.fc2 = torch.nn.Linear(100, 10)

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4 * 4 * 10)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)
def evaluator(model):
    return random.random()
class IterativePrunerTestCase(unittest.TestCase):
    def test_linear_pruner(self):
        model = TorchModel()
        config_list = [{'op_types': ['Conv2d'], 'sparsity': 0.8}]
        pruner = LinearPruner(model, config_list, 'level', 3, log_dir='../../logs')
        pruner.compress()
        _, pruned_model, masks, _, _ = pruner.get_best_result()
        sparsity_list = compute_sparsity_mask2compact(pruned_model, masks, config_list)
        assert 0.79 < sparsity_list[0]['total_sparsity'] < 0.81

    def test_agp_pruner(self):
        model = TorchModel()
        config_list = [{'op_types': ['Conv2d'], 'sparsity': 0.8}]
        pruner = AGPPruner(model, config_list, 'level', 3, log_dir='../../logs')
        pruner.compress()
        _, pruned_model, masks, _, _ = pruner.get_best_result()
        sparsity_list = compute_sparsity_mask2compact(pruned_model, masks, config_list)
        assert 0.79 < sparsity_list[0]['total_sparsity'] < 0.81

    def test_lottery_ticket_pruner(self):
        model = TorchModel()
        config_list = [{'op_types': ['Conv2d'], 'sparsity': 0.8}]
        pruner = LotteryTicketPruner(model, config_list, 'level', 3, log_dir='../../logs')
        pruner.compress()
        _, pruned_model, masks, _, _ = pruner.get_best_result()
        sparsity_list = compute_sparsity_mask2compact(pruned_model, masks, config_list)
        assert 0.79 < sparsity_list[0]['total_sparsity'] < 0.81

    def test_simulated_annealing_pruner(self):
        model = TorchModel()
        config_list = [{'op_types': ['Conv2d'], 'sparsity': 0.8}]
        pruner = SimulatedAnnealingPruner(model, config_list, 'level', evaluator, start_temperature=30, log_dir='../../logs')
        pruner.compress()
        _, pruned_model, masks, _, _ = pruner.get_best_result()
        sparsity_list = compute_sparsity_mask2compact(pruned_model, masks, config_list)
        assert 0.79 < sparsity_list[0]['total_sparsity'] < 0.81


if __name__ == '__main__':
    unittest.main()
@@ -6,7 +6,8 @@ import unittest
 import torch
 import torch.nn.functional as F
-from nni.algorithms.compression.v2.pytorch.pruning import PruningScheduler, L1NormPruner, AGPTaskGenerator
+from nni.algorithms.compression.v2.pytorch.pruning import PruningScheduler, L1NormPruner
+from nni.algorithms.compression.v2.pytorch.pruning.tools import AGPTaskGenerator

 class TorchModel(torch.nn.Module):
...
@@ -8,7 +8,7 @@ import torch
 import torch.nn.functional as F
 from nni.algorithms.compression.v2.pytorch.base import Task, TaskResult
-from nni.algorithms.compression.v2.pytorch.pruning import (
+from nni.algorithms.compression.v2.pytorch.pruning.tools import (
     AGPTaskGenerator,
     LinearTaskGenerator,
     LotteryTicketTaskGenerator,
...