Unverified Commit 6dfdc546 authored by J-shang, committed by GitHub

[Compression v2] Add optimizer & lr scheduler construct helper (#4332)

parent 7978c25a
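In short, pruners that previously accepted a plain torch.optim.Optimizer instance now expect the optimizer class to be wrapped with trace_parameters before it is instantiated, so that the pruner can later re-construct the optimizer over the wrapped model. A rough sketch of the new calling convention, assuming model, trainer, criterion and config_list are defined as in the examples below:

    import torch
    from nni.algorithms.compression.v2.pytorch.pruning import TaylorFOWeightPruner
    from nni.algorithms.compression.v2.pytorch.utils import trace_parameters

    # previously: pass a plain optimizer instance
    #   optimizer = torch.optim.Adam(model.parameters())
    #   pruner = TaylorFOWeightPruner(model, config_list, trainer, optimizer, criterion, training_batches=20)
    # now: wrap the optimizer class first, then instantiate it
    traced_optimizer = trace_parameters(torch.optim.Adam)(model.parameters())
    pruner = TaylorFOWeightPruner(model, config_list, trainer, traced_optimizer, criterion, training_batches=20)
    masked_model, masks = pruner.compress()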
@@ -155,8 +155,13 @@ Usage
 .. code-block:: python
     from nni.algorithms.compression.v2.pytorch.pruning import SlimPruner
+    from nni.algorithms.compression.v2.pytorch.utils import trace_parameters
+    # make sure you have used nni.algorithms.compression.v2.pytorch.utils.trace_parameters to wrap the optimizer class before initialize
+    traced_optimizer = trace_parameters(torch.optim.Adam)(model.parameters())
     config_list = [{ 'sparsity': 0.8, 'op_types': ['BatchNorm2d'] }]
-    pruner = SlimPruner(model, config_list, trainer, optimizer, criterion, training_epochs=1)
+    pruner = SlimPruner(model, config_list, trainer, traced_optimizer, criterion, training_epochs=1)
     masked_model, masks = pruner.compress()
 For detailed example please refer to :githublink:`examples/model_compress/pruning/v2/slim_pruning_torch.py <examples/model_compress/pruning/v2/slim_pruning_torch.py>`
@@ -187,8 +192,13 @@ Usage
 .. code-block:: python
     from nni.algorithms.compression.v2.pytorch.pruning import ActivationAPoZRankPruner
+    from nni.algorithms.compression.v2.pytorch.utils import trace_parameters
+    # make sure you have used nni.algorithms.compression.v2.pytorch.utils.trace_parameters to wrap the optimizer class before initialize
+    traced_optimizer = trace_parameters(torch.optim.Adam)(model.parameters())
     config_list = [{ 'sparsity': 0.8, 'op_types': ['Conv2d'] }]
-    pruner = ActivationAPoZRankPruner(model, config_list, trainer, optimizer, criterion, training_batches=20)
+    pruner = ActivationAPoZRankPruner(model, config_list, trainer, traced_optimizer, criterion, training_batches=20)
     masked_model, masks = pruner.compress()
 For detailed example please refer to :githublink:`examples/model_compress/pruning/v2/activation_pruning_torch.py <examples/model_compress/pruning/v2/activation_pruning_torch.py>`
@@ -215,8 +225,13 @@ Usage
 .. code-block:: python
     from nni.algorithms.compression.v2.pytorch.pruning import ActivationMeanRankPruner
+    from nni.algorithms.compression.v2.pytorch.utils import trace_parameters
+    # make sure you have used nni.algorithms.compression.v2.pytorch.utils.trace_parameters to wrap the optimizer class before initialize
+    traced_optimizer = trace_parameters(torch.optim.Adam)(model.parameters())
     config_list = [{ 'sparsity': 0.8, 'op_types': ['Conv2d'] }]
-    pruner = ActivationMeanRankPruner(model, config_list, trainer, optimizer, criterion, training_batches=20)
+    pruner = ActivationMeanRankPruner(model, config_list, trainer, traced_optimizer, criterion, training_batches=20)
     masked_model, masks = pruner.compress()
 For detailed example please refer to :githublink:`examples/model_compress/pruning/v2/activation_pruning_torch.py <examples/model_compress/pruning/v2/activation_pruning_torch.py>`
@@ -247,8 +262,13 @@ Usage
 .. code-block:: python
     from nni.algorithms.compression.v2.pytorch.pruning import TaylorFOWeightPruner
+    from nni.algorithms.compression.v2.pytorch.utils import trace_parameters
+    # make sure you have used nni.algorithms.compression.v2.pytorch.utils.trace_parameters to wrap the optimizer class before initialize
+    traced_optimizer = trace_parameters(torch.optim.Adam)(model.parameters())
     config_list = [{ 'sparsity': 0.8, 'op_types': ['Conv2d'] }]
-    pruner = TaylorFOWeightPruner(model, config_list, trainer, optimizer, criterion, training_batches=20)
+    pruner = TaylorFOWeightPruner(model, config_list, trainer, traced_optimizer, criterion, training_batches=20)
     masked_model, masks = pruner.compress()
 For detailed example please refer to :githublink:`examples/model_compress/pruning/v2/taylorfo_pruning_torch.py <examples/model_compress/pruning/v2/taylorfo_pruning_torch.py>`
@@ -280,8 +300,13 @@ Usage
 .. code-block:: python
     from nni.algorithms.compression.v2.pytorch.pruning import ADMMPruner
+    from nni.algorithms.compression.v2.pytorch.utils import trace_parameters
+    # make sure you have used nni.algorithms.compression.v2.pytorch.utils.trace_parameters to wrap the optimizer class before initialize
+    traced_optimizer = trace_parameters(torch.optim.Adam)(model.parameters())
     config_list = [{ 'sparsity': 0.8, 'op_types': ['Conv2d'] }]
-    pruner = ADMMPruner(model, config_list, trainer, optimizer, criterion, iterations=10, training_epochs=1)
+    pruner = ADMMPruner(model, config_list, trainer, traced_optimizer, criterion, iterations=10, training_epochs=1)
     masked_model, masks = pruner.compress()
 For detailed example please refer to :githublink:`examples/model_compress/pruning/v2/admm_pruning_torch.py <examples/model_compress/pruning/v2/admm_pruning_torch.py>`
@@ -316,8 +341,13 @@ Usage
 .. code-block:: python
     from nni.algorithms.compression.v2.pytorch.pruning import MovementPruner
+    from nni.algorithms.compression.v2.pytorch.utils import trace_parameters
+    # make sure you have used nni.algorithms.compression.v2.pytorch.utils.trace_parameters to wrap the optimizer class before initialize
+    traced_optimizer = trace_parameters(torch.optim.Adam)(model.parameters())
     config_list = [{'op_types': ['Linear'], 'op_partial_names': ['bert.encoder'], 'sparsity': 0.9}]
-    pruner = MovementPruner(model, config_list, p_trainer, optimizer, criterion, 10, 3000, 27000)
+    pruner = MovementPruner(model, config_list, trainer, traced_optimizer, criterion, 10, 3000, 27000)
     masked_model, masks = pruner.compress()
 User configuration for Movement Pruner
@@ -496,10 +526,15 @@ Usage
 .. code-block:: python
     from nni.algorithms.compression.v2.pytorch.pruning import AutoCompressPruner
+    from nni.algorithms.compression.v2.pytorch.utils import trace_parameters
+    # make sure you have used nni.algorithms.compression.v2.pytorch.utils.trace_parameters to wrap the optimizer class before initialize
+    traced_optimizer = trace_parameters(torch.optim.Adam)(model.parameters())
     config_list = [{ 'sparsity': 0.8, 'op_types': ['Conv2d'] }]
     admm_params = {
         'trainer': trainer,
-        'optimizer': optimizer,
+        'traced_optimizer': traced_optimizer,
         'criterion': criterion,
         'iterations': 10,
         'training_epochs': 1
...
@@ -8,15 +8,19 @@ Note that pruners use masks to simulate the real pruning. In order to obtain a r
 '''
 import argparse
+import sys
 import torch
 from torchvision import datasets, transforms
 from torch.optim.lr_scheduler import MultiStepLR
 from nni.compression.pytorch import ModelSpeedup
-from examples.model_compress.models.cifar10.vgg import VGG
 from nni.compression.pytorch.utils.counter import count_flops_params
 from nni.algorithms.compression.v2.pytorch.pruning.basic_pruner import ActivationAPoZRankPruner, ActivationMeanRankPruner
+from nni.algorithms.compression.v2.pytorch.utils import trace_parameters
+sys.path.append('../../models')
+from cifar10.vgg import VGG
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
@@ -105,14 +109,16 @@ if __name__ == '__main__':
     # Start to prune and speedup
     print('\n' + '=' * 50 + ' START TO PRUNE THE BEST ACCURACY PRETRAINED MODEL ' + '=' * 50)
     config_list = [{
         'total_sparsity': 0.5,
         'op_types': ['Conv2d'],
     }]
-    optimizer, _ = optimizer_scheduler_generator(model)
+    # make sure you have used nni.algorithms.compression.v2.pytorch.utils.trace_parameters to wrap the optimizer class before initialize
+    traced_optimizer = trace_parameters(torch.optim.SGD)(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
     if 'apoz' in args.pruner:
-        pruner = ActivationAPoZRankPruner(model, config_list, trainer, optimizer, criterion, training_batches=1)
+        pruner = ActivationAPoZRankPruner(model, config_list, trainer, traced_optimizer, criterion, training_batches=1)
     else:
-        pruner = ActivationMeanRankPruner(model, config_list, trainer, optimizer, criterion, training_batches=1)
+        pruner = ActivationMeanRankPruner(model, config_list, trainer, traced_optimizer, criterion, training_batches=1)
     _, masks = pruner.compress()
     pruner.show_pruned_weights()
     pruner._unwrap_model()
...
@@ -8,14 +8,18 @@ Note that pruners use masks to simulate the real pruning. In order to obtain a r
 '''
 import argparse
+import sys
 import torch
 from torchvision import datasets, transforms
 from torch.optim.lr_scheduler import MultiStepLR
-from examples.model_compress.models.cifar10.vgg import VGG
 from nni.compression.pytorch.utils.counter import count_flops_params
 from nni.algorithms.compression.v2.pytorch.pruning.basic_pruner import ADMMPruner
+from nni.algorithms.compression.v2.pytorch.utils import trace_parameters
+sys.path.append('../../models')
+from cifar10.vgg import VGG
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
@@ -107,8 +111,10 @@ if __name__ == '__main__':
         'sparsity': 0.92,
         'op_types': ['Conv2d'],
     }]
-    optimizer, _ = optimizer_scheduler_generator(model)
-    pruner = ADMMPruner(model, config_list, trainer, optimizer, criterion, iterations=2, training_epochs=2)
+    # make sure you have used nni.algorithms.compression.v2.pytorch.utils.trace_parameters to wrap the optimizer class before initialize
+    traced_optimizer = trace_parameters(torch.optim.SGD)(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
+    pruner = ADMMPruner(model, config_list, trainer, traced_optimizer, criterion, iterations=2, training_epochs=2)
     _, masks = pruner.compress()
     pruner.show_pruned_weights()
...
+import sys
 from tqdm import tqdm
 import torch
 from torchvision import datasets, transforms
 from nni.algorithms.compression.v2.pytorch.pruning import AutoCompressPruner
-from examples.model_compress.models.cifar10.vgg import VGG
+from nni.algorithms.compression.v2.pytorch.utils import trace_parameters
+sys.path.append('../../models')
+from cifar10.vgg import VGG
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -74,10 +76,11 @@ if __name__ == '__main__':
     config_list = [{'op_types': ['Conv2d'], 'total_sparsity': 0.8}]
     dummy_input = torch.rand(10, 3, 32, 32).to(device)
-    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
+    # make sure you have used nni.algorithms.compression.v2.pytorch.utils.trace_parameters to wrap the optimizer class before initialize
+    traced_optimizer = trace_parameters(torch.optim.SGD)(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
     admm_params = {
         'trainer': trainer,
-        'optimizer': optimizer,
+        'traced_optimizer': traced_optimizer,
         'criterion': criterion,
         'iterations': 10,
         'training_epochs': 1
...
@@ -8,16 +8,19 @@ Note that pruners use masks to simulate the real pruning. In order to obtain a r
 '''
 import argparse
+import sys
 import torch
 from torchvision import datasets, transforms
 from torch.optim.lr_scheduler import MultiStepLR
 from nni.compression.pytorch import ModelSpeedup
-from examples.model_compress.models.cifar10.vgg import VGG
 from nni.compression.pytorch.utils.counter import count_flops_params
 from nni.algorithms.compression.v2.pytorch.pruning.basic_pruner import FPGMPruner
+sys.path.append('../../models')
+from cifar10.vgg import VGG
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
 g_epoch = 0
...
@@ -16,8 +16,7 @@ from torchvision import datasets, transforms
 from nni.algorithms.compression.v2.pytorch.pruning import (
     LinearPruner,
     AGPPruner,
-    LotteryTicketPruner,
-    SimulatedAnnealingPruner
+    LotteryTicketPruner
 )
 sys.path.append('../../models')
...
@@ -8,15 +8,18 @@ Note that pruners use masks to simulate the real pruning. In order to obtain a r
 '''
 import argparse
+import sys
 import torch
 from torchvision import datasets, transforms
 from torch.optim.lr_scheduler import MultiStepLR
-from examples.model_compress.models.cifar10.vgg import VGG
 from nni.compression.pytorch.utils.counter import count_flops_params
 from nni.algorithms.compression.v2.pytorch.pruning.basic_pruner import LevelPruner
+sys.path.append('../../models')
+from cifar10.vgg import VGG
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
 g_epoch = 0
...
@@ -14,6 +14,7 @@ from transformers import (
 )
 from nni.algorithms.compression.v2.pytorch.pruning import MovementPruner
+from nni.algorithms.compression.v2.pytorch.utils import trace_parameters
 task_to_keys = {
@@ -108,8 +109,10 @@ if __name__ == '__main__':
     config_list = [{'op_types': ['Linear'], 'op_partial_names': ['bert.encoder'], 'sparsity': 0.9}]
     p_trainer = functools.partial(trainer, train_dataloader=train_dataloader)
-    optimizer = Adam(model.parameters(), lr=2e-5)
-    pruner = MovementPruner(model, config_list, p_trainer, optimizer, criterion, training_epochs=10,
+    # make sure you have used nni.algorithms.compression.v2.pytorch.utils.trace_parameters to wrap the optimizer class before initialize
+    traced_optimizer = trace_parameters(Adam)(model.parameters(), lr=2e-5)
+    pruner = MovementPruner(model, config_list, p_trainer, traced_optimizer, criterion, training_epochs=10,
                             warm_up_step=3000, cool_down_beginning_step=27000)
     _, masks = pruner.compress()
...
@@ -8,16 +8,19 @@ Note that pruners use masks to simulate the real pruning. In order to obtain a r
 '''
 import argparse
+import sys
 import torch
 from torchvision import datasets, transforms
 from torch.optim.lr_scheduler import MultiStepLR
 from nni.compression.pytorch import ModelSpeedup
-from examples.model_compress.models.cifar10.vgg import VGG
 from nni.compression.pytorch.utils.counter import count_flops_params
 from nni.algorithms.compression.v2.pytorch.pruning.basic_pruner import L1NormPruner, L2NormPruner
+sys.path.append('../../models')
+from cifar10.vgg import VGG
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
 g_epoch = 0
...
@@ -8,15 +8,19 @@ Note that pruners use masks to simulate the real pruning. In order to obtain a r
 '''
 import argparse
+import sys
 import torch
 from torchvision import datasets, transforms
 from torch.optim.lr_scheduler import MultiStepLR
 from nni.compression.pytorch import ModelSpeedup
-from examples.model_compress.models.cifar10.vgg import VGG
 from nni.compression.pytorch.utils.counter import count_flops_params
 from nni.algorithms.compression.v2.pytorch.pruning.basic_pruner import SlimPruner
+from nni.algorithms.compression.v2.pytorch.utils import trace_parameters
+sys.path.append('../../models')
+from cifar10.vgg import VGG
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
@@ -107,8 +111,9 @@ if __name__ == '__main__':
         'max_sparsity_per_layer': 0.9
     }]
-    optimizer, _ = optimizer_scheduler_generator(model)
-    pruner = SlimPruner(model, config_list, trainer, optimizer, criterion, training_epochs=1, scale=0.0001, mode='global')
+    # make sure you have used nni.algorithms.compression.v2.pytorch.utils.trace_parameters to wrap the optimizer class before initialize
+    traced_optimizer = trace_parameters(torch.optim.SGD)(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
+    pruner = SlimPruner(model, config_list, trainer, traced_optimizer, criterion, training_epochs=1, scale=0.0001, mode='global')
     _, masks = pruner.compress()
     pruner.show_pruned_weights()
     pruner._unwrap_model()
...
@@ -8,15 +8,19 @@ Note that pruners use masks to simulate the real pruning. In order to obtain a r
 '''
 import argparse
+import sys
 import torch
 from torchvision import datasets, transforms
 from torch.optim.lr_scheduler import MultiStepLR
 from nni.compression.pytorch import ModelSpeedup
-from examples.model_compress.models.cifar10.vgg import VGG
 from nni.compression.pytorch.utils.counter import count_flops_params
 from nni.algorithms.compression.v2.pytorch.pruning.basic_pruner import TaylorFOWeightPruner
+from nni.algorithms.compression.v2.pytorch.utils import trace_parameters
+sys.path.append('../../models')
+from cifar10.vgg import VGG
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
@@ -102,11 +106,13 @@ if __name__ == '__main__':
     # Start to prune and speedup
     print('\n' + '=' * 50 + ' START TO PRUNE THE BEST ACCURACY PRETRAINED MODEL ' + '=' * 50)
     config_list = [{
         'total_sparsity': 0.5,
         'op_types': ['Conv2d'],
     }]
-    optimizer, _ = optimizer_scheduler_generator(model)
-    pruner = TaylorFOWeightPruner(model, config_list, trainer, optimizer, criterion, training_batches=1)
+    # make sure you have used nni.algorithms.compression.v2.pytorch.utils.trace_parameters to wrap the optimizer class before initialize
+    traced_optimizer = trace_parameters(torch.optim.SGD)(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
+    pruner = TaylorFOWeightPruner(model, config_list, trainer, traced_optimizer, criterion, training_batches=1)
     _, masks = pruner.compress()
     pruner.show_pruned_weights()
     pruner._unwrap_model()
...
@@ -249,6 +249,24 @@ class Compressor:
         self._wrap_model()
         return module_groups
+    def get_origin2wrapped_parameter_name_map(self) -> Dict[str, str]:
+        """
+        Get the name mapping of parameters from original model to wrapped model.
+        Returns
+        -------
+        Dict[str, str]
+            Return a dict `{original_model_parameter_name: wrapped_model_parameter_name}`
+        """
+        if self.is_wrapped:
+            wrapped_param_names = {id(param): name for name, param in self.bound_model.named_parameters()}
+            self._unwrap_model()
+            parameter_name_map = {name: wrapped_param_names[id(param)] for name, param in self.bound_model.named_parameters()}
+            self._wrap_model()
+            return parameter_name_map
+        else:
+            raise Exception('When only the model is wrapped can get the parameter_name_map.')
     def _wrap_modules(self, layer: LayerInfo, config: Dict):
         """
         This method is implemented in the subclasses, i.e., `Pruner` and `Quantizer`
...
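The get_origin2wrapped_parameter_name_map method added above matches parameters before and after wrapping by object identity. A minimal, self-contained sketch of the same id-based idea, using a hypothetical ToyWrapper rather than NNI's real module wrapper (the resulting map goes from original names to wrapped names, as in the method above):

    import torch.nn as nn

    class ToyWrapper(nn.Module):
        # hypothetical stand-in for a pruner wrapper: it re-registers the original
        # module under the attribute `module`, which changes the parameters' qualified names
        def __init__(self, module: nn.Module):
            super().__init__()
            self.module = module

    model = nn.Sequential(nn.Linear(4, 4))
    origin_names = {id(p): name for name, p in model.named_parameters()}  # names before wrapping
    model[0] = ToyWrapper(model[0])                                       # "wrap" one layer
    # same Parameter objects, new qualified names -> origin-to-wrapped name map
    name_map = {origin_names[id(p)]: name for name, p in model.named_parameters()}
    print(name_map)  # {'0.weight': '0.module.weight', '0.bias': '0.module.bias'}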
@@ -7,6 +7,8 @@ from typing import Dict, List, Callable, Optional
 from torch import Tensor
 from torch.nn import Module
+from nni.algorithms.compression.v2.pytorch.utils import OptimizerConstructHelper
 from .basic_pruner import ADMMPruner
 from .iterative_pruner import IterativePruner, SimulatedAnnealingPruner
 from .tools import LotteryTicketTaskGenerator
@@ -56,9 +58,9 @@ class AutoCompressPruner(IterativePruner):
         - trainer : Callable[[Module, Optimizer, Callable].
           A callable function used to train model or just inference. Take model, optimizer, criterion as input.
           The model will be trained or inferenced `training_epochs` epochs.
-        - optimizer : torch.optim.Optimizer.
-          The optimizer instance used in trainer. Note that this optimizer might be patched during collect data,
-          so do not use this optimizer in other places.
+        - traced_optimizer : nni.common.serializer.Traceable(torch.optim.Optimizer)
+          The traced optimizer instance which the optimizer class is wrapped by nni.algorithms.compression.v2.pytorch.utils.trace_parameters.
+          E.g. traced_optimizer = nni.algorithms.compression.v2.pytorch.utils.trace_parameters(torch.nn.Adam)(model.parameters()).
         - criterion : Callable[[Tensor, Tensor], Tensor].
           The criterion function used in trainer. Take model output and target value as input, and return the loss.
         - iterations : int.
@@ -107,6 +109,8 @@ class AutoCompressPruner(IterativePruner):
                                              sa_params=sa_params,
                                              log_dir=log_dir,
                                              keep_intermediate_result=keep_intermediate_result)
+        if 'traced_optimizer' in admm_params:
+            admm_params['traced_optimizer'] = OptimizerConstructHelper.from_trace(model, admm_params['traced_optimizer'])
         pruner = ADMMPruner(None, None, **admm_params)
         super().__init__(pruner, task_generator, finetuner=finetuner, speed_up=speed_up, dummy_input=dummy_input,
                          evaluator=evaluator, reset_weight=False)
...
@@ -12,8 +12,9 @@ import torch.nn as nn
 from torch.nn import Module
 from torch.optim import Optimizer
+from nni.common.serializer import Traceable
 from nni.algorithms.compression.v2.pytorch.base.pruner import Pruner
-from nni.algorithms.compression.v2.pytorch.utils import CompressorSchema, config_list_canonical
+from nni.algorithms.compression.v2.pytorch.utils import CompressorSchema, config_list_canonical, OptimizerConstructHelper
 from .tools import (
     DataCollector,
@@ -371,9 +372,9 @@ class SlimPruner(BasicPruner):
             # If you don't want to update the model, you can skip `optimizer.step()`, and set train mode False.
             optimizer.step()
             model.train(mode=training)
-    optimizer : torch.optim.Optimizer
-        The optimizer instance used in trainer. Note that this optimizer might be patched during collect data,
-        so do not use this optimizer in other places.
+    traced_optimizer : nni.common.serializer.Traceable(torch.optim.Optimizer)
+        The traced optimizer instance which the optimizer class is wrapped by nni.algorithms.compression.v2.pytorch.utils.trace_parameters.
+        E.g. traced_optimizer = nni.algorithms.compression.v2.pytorch.utils.trace_parameters(torch.nn.Adam)(model.parameters()).
     criterion : Callable[[Tensor, Tensor], Tensor]
         The criterion function used in trainer. Take model output and target value as input, and return the loss.
     training_epochs : int
@@ -388,11 +389,14 @@ class SlimPruner(BasicPruner):
     """
     def __init__(self, model: Module, config_list: List[Dict], trainer: Callable[[Module, Optimizer, Callable], None],
-                 optimizer: Optimizer, criterion: Callable[[Tensor, Tensor], Tensor],
+                 traced_optimizer: Traceable, criterion: Callable[[Tensor, Tensor], Tensor],
                  training_epochs: int, scale: float = 0.0001, mode='global'):
         self.mode = mode
         self.trainer = trainer
-        self.optimizer = optimizer
+        if isinstance(traced_optimizer, OptimizerConstructHelper):
+            self.optimizer_helper = traced_optimizer
+        else:
+            self.optimizer_helper = OptimizerConstructHelper.from_trace(model, traced_optimizer)
         self.criterion = criterion
         self.training_epochs = training_epochs
         self._scale = scale
@@ -420,7 +424,7 @@ class SlimPruner(BasicPruner):
     def reset_tools(self):
         if self.data_collector is None:
-            self.data_collector = WeightTrainerBasedDataCollector(self, self.trainer, self.optimizer, self.criterion,
+            self.data_collector = WeightTrainerBasedDataCollector(self, self.trainer, self.optimizer_helper, self.criterion,
                                                                   self.training_epochs, criterion_patch=self.criterion_patch)
         else:
             self.data_collector.reset()
@@ -467,9 +471,9 @@ class ActivationPruner(BasicPruner):
             # If you don't want to update the model, you can skip `optimizer.step()`, and set train mode False.
             optimizer.step()
             model.train(mode=training)
-    optimizer : torch.optim.Optimizer
-        The optimizer instance used in trainer. Note that this optimizer might be patched during collect data,
-        so do not use this optimizer in other places.
+    traced_optimizer : nni.common.serializer.Traceable(torch.optim.Optimizer)
+        The traced optimizer instance which the optimizer class is wrapped by nni.algorithms.compression.v2.pytorch.utils.trace_parameters.
+        E.g. traced_optimizer = nni.algorithms.compression.v2.pytorch.utils.trace_parameters(torch.nn.Adam)(model.parameters()).
     criterion : Callable[[Tensor, Tensor], Tensor]
         The criterion function used in trainer. Take model output and target value as input, and return the loss.
     training_batches
@@ -489,12 +493,15 @@ class ActivationPruner(BasicPruner):
     """
     def __init__(self, model: Module, config_list: List[Dict], trainer: Callable[[Module, Optimizer, Callable], None],
-                 optimizer: Optimizer, criterion: Callable[[Tensor, Tensor], Tensor], training_batches: int, activation: str = 'relu',
+                 traced_optimizer: Traceable, criterion: Callable[[Tensor, Tensor], Tensor], training_batches: int, activation: str = 'relu',
                  mode: str = 'normal', dummy_input: Optional[Tensor] = None):
         self.mode = mode
         self.dummy_input = dummy_input
         self.trainer = trainer
-        self.optimizer = optimizer
+        if isinstance(traced_optimizer, OptimizerConstructHelper):
+            self.optimizer_helper = traced_optimizer
+        else:
+            self.optimizer_helper = OptimizerConstructHelper.from_trace(model, traced_optimizer)
         self.criterion = criterion
         self.training_batches = training_batches
         self._activation = self._choose_activation(activation)
@@ -525,10 +532,10 @@ class ActivationPruner(BasicPruner):
     def reset_tools(self):
         collector_info = HookCollectorInfo([layer_info for layer_info, _ in self._detect_modules_to_compress()], 'forward', self._collector)
         if self.data_collector is None:
-            self.data_collector = SingleHookTrainerBasedDataCollector(self, self.trainer, self.optimizer, self.criterion,
+            self.data_collector = SingleHookTrainerBasedDataCollector(self, self.trainer, self.optimizer_helper, self.criterion,
                                                                       1, collector_infos=[collector_info])
         else:
-            self.data_collector.reset()
+            self.data_collector.reset(collector_infos=[collector_info])
         if self.metrics_calculator is None:
             self.metrics_calculator = self._get_metrics_calculator()
         if self.sparsity_allocator is None:
@@ -587,9 +594,9 @@ class TaylorFOWeightPruner(BasicPruner):
             # If you don't want to update the model, you can skip `optimizer.step()`, and set train mode False.
             optimizer.step()
             model.train(mode=training)
-    optimizer : torch.optim.Optimizer
-        The optimizer instance used in trainer. Note that this optimizer might be patched during collect data,
-        so do not use this optimizer in other places.
+    traced_optimizer : nni.common.serializer.Traceable(torch.optim.Optimizer)
+        The traced optimizer instance which the optimizer class is wrapped by nni.algorithms.compression.v2.pytorch.utils.trace_parameters.
+        E.g. traced_optimizer = nni.algorithms.compression.v2.pytorch.utils.trace_parameters(torch.nn.Adam)(model.parameters()).
     criterion : Callable[[Tensor, Tensor], Tensor]
         The criterion function used in trainer. Take model output and target value as input, and return the loss.
     training_batches : int
@@ -614,12 +621,15 @@ class TaylorFOWeightPruner(BasicPruner):
     """
    def __init__(self, model: Module, config_list: List[Dict], trainer: Callable[[Module, Optimizer, Callable], None],
-                 optimizer: Optimizer, criterion: Callable[[Tensor, Tensor], Tensor], training_batches: int,
+                 traced_optimizer: Traceable, criterion: Callable[[Tensor, Tensor], Tensor], training_batches: int,
                  mode: str = 'normal', dummy_input: Optional[Tensor] = None):
         self.mode = mode
         self.dummy_input = dummy_input
         self.trainer = trainer
-        self.optimizer = optimizer
+        if isinstance(traced_optimizer, OptimizerConstructHelper):
+            self.optimizer_helper = traced_optimizer
+        else:
+            self.optimizer_helper = OptimizerConstructHelper.from_trace(model, traced_optimizer)
         self.criterion = criterion
         self.training_batches = training_batches
         super().__init__(model, config_list)
@@ -649,10 +659,10 @@ class TaylorFOWeightPruner(BasicPruner):
         hook_targets = {layer_info.name: layer_info.module.weight for layer_info, _ in self._detect_modules_to_compress()}
         collector_info = HookCollectorInfo(hook_targets, 'tensor', self._collector)
         if self.data_collector is None:
-            self.data_collector = SingleHookTrainerBasedDataCollector(self, self.trainer, self.optimizer, self.criterion,
+            self.data_collector = SingleHookTrainerBasedDataCollector(self, self.trainer, self.optimizer_helper, self.criterion,
                                                                       1, collector_infos=[collector_info])
         else:
-            self.data_collector.reset()
+            self.data_collector.reset(collector_infos=[collector_info])
         if self.metrics_calculator is None:
             self.metrics_calculator = MultiDataNormMetricsCalculator(p=1, dim=0)
         if self.sparsity_allocator is None:
@@ -706,9 +716,9 @@ class ADMMPruner(BasicPruner):
             # If you don't want to update the model, you can skip `optimizer.step()`, and set train mode False.
             optimizer.step()
             model.train(mode=training)
-    optimizer : torch.optim.Optimizer
-        The optimizer instance used in trainer. Note that this optimizer might be patched during collect data,
-        so do not use this optimizer in other places.
+    traced_optimizer : nni.common.serializer.Traceable(torch.optim.Optimizer)
+        The traced optimizer instance which the optimizer class is wrapped by nni.algorithms.compression.v2.pytorch.utils.trace_parameters.
+        E.g. traced_optimizer = nni.algorithms.compression.v2.pytorch.utils.trace_parameters(torch.nn.Adam)(model.parameters()).
     criterion : Callable[[Tensor, Tensor], Tensor]
         The criterion function used in trainer. Take model output and target value as input, and return the loss.
     iterations : int
@@ -718,10 +728,12 @@ class ADMMPruner(BasicPruner):
     """
     def __init__(self, model: Module, config_list: List[Dict], trainer: Callable[[Module, Optimizer, Callable], None],
-                 optimizer: Optimizer, criterion: Callable[[Tensor, Tensor], Tensor], iterations: int, training_epochs: int):
+                 traced_optimizer: Traceable, criterion: Callable[[Tensor, Tensor], Tensor], iterations: int, training_epochs: int):
         self.trainer = trainer
-        # TODO: handle optimizer here will case additional memory use, need improve, also in WeightTrainerBasedDataCollector
-        self.optimizer = optimizer
+        if isinstance(traced_optimizer, OptimizerConstructHelper):
+            self.optimizer_helper = traced_optimizer
+        else:
+            self.optimizer_helper = OptimizerConstructHelper.from_trace(model, traced_optimizer)
         self.criterion = criterion
         self.iterations = iterations
         self.training_epochs = training_epochs
@@ -751,7 +763,7 @@ class ADMMPruner(BasicPruner):
     def reset_tools(self):
         if self.data_collector is None:
-            self.data_collector = WeightTrainerBasedDataCollector(self, self.trainer, self.optimizer, self.criterion,
+            self.data_collector = WeightTrainerBasedDataCollector(self, self.trainer, self.optimizer_helper, self.criterion,
                                                                   self.training_epochs, criterion_patch=self.criterion_patch)
         else:
             self.data_collector.reset()
...
@@ -7,6 +7,8 @@ from typing import Dict, List, Callable, Optional
 from torch import Tensor
 from torch.nn import Module
+from nni.algorithms.compression.v2.pytorch.utils import OptimizerConstructHelper
 from .basic_pruner import (
     LevelPruner,
     L1NormPruner,
@@ -107,6 +109,8 @@ class LinearPruner(IterativePruner):
                                              origin_config_list=config_list,
                                              log_dir=log_dir,
                                              keep_intermediate_result=keep_intermediate_result)
+        if 'traced_optimizer' in pruning_params:
+            pruning_params['traced_optimizer'] = OptimizerConstructHelper.from_trace(model, pruning_params['traced_optimizer'])
         pruner = PRUNER_DICT[pruning_algorithm](None, None, **pruning_params)
         super().__init__(pruner, task_generator, finetuner=finetuner, speed_up=speed_up, dummy_input=dummy_input,
                          evaluator=evaluator, reset_weight=False)
@@ -152,6 +156,8 @@ class AGPPruner(IterativePruner):
                                              origin_config_list=config_list,
                                              log_dir=log_dir,
                                              keep_intermediate_result=keep_intermediate_result)
+        if 'traced_optimizer' in pruning_params:
+            pruning_params['traced_optimizer'] = OptimizerConstructHelper.from_trace(model, pruning_params['traced_optimizer'])
         pruner = PRUNER_DICT[pruning_algorithm](None, None, **pruning_params)
         super().__init__(pruner, task_generator, finetuner=finetuner, speed_up=speed_up, dummy_input=dummy_input,
                          evaluator=evaluator, reset_weight=False)
@@ -200,6 +206,8 @@ class LotteryTicketPruner(IterativePruner):
                                              origin_config_list=config_list,
                                              log_dir=log_dir,
                                              keep_intermediate_result=keep_intermediate_result)
+        if 'traced_optimizer' in pruning_params:
+            pruning_params['traced_optimizer'] = OptimizerConstructHelper.from_trace(model, pruning_params['traced_optimizer'])
         pruner = PRUNER_DICT[pruning_algorithm](None, None, **pruning_params)
         super().__init__(pruner, task_generator, finetuner=finetuner, speed_up=speed_up, dummy_input=dummy_input,
                          evaluator=evaluator, reset_weight=reset_weight)
@@ -252,6 +260,8 @@ class SimulatedAnnealingPruner(IterativePruner):
                                              perturbation_magnitude=perturbation_magnitude,
                                              log_dir=log_dir,
                                              keep_intermediate_result=keep_intermediate_result)
+        if 'traced_optimizer' in pruning_params:
+            pruning_params['traced_optimizer'] = OptimizerConstructHelper.from_trace(model, pruning_params['traced_optimizer'])
         pruner = PRUNER_DICT[pruning_algorithm](None, None, **pruning_params)
         super().__init__(pruner, task_generator, finetuner=finetuner, speed_up=speed_up, dummy_input=dummy_input,
                          evaluator=evaluator, reset_weight=False)
@@ -12,7 +12,8 @@ from torch.optim import Optimizer, Adam
 from nni.algorithms.compression.v2.pytorch.base.compressor import Compressor, _setattr, LayerInfo
 from nni.algorithms.compression.v2.pytorch.pruning.basic_pruner import BasicPruner, NORMAL_SCHEMA, EXCLUDE_SCHEMA, INTERNAL_SCHEMA
-from nni.algorithms.compression.v2.pytorch.utils import CompressorSchema
+from nni.algorithms.compression.v2.pytorch.utils import CompressorSchema, OptimizerConstructHelper
+from nni.common.serializer import Traceable
 from .tools.base import TrainerBasedDataCollector
@@ -50,8 +51,6 @@ class PrunerScoredModuleWrapper(Module):
         self.pruner = pruner
         self.weight = Parameter(torch.empty(self.module.weight.size()))
-        self.weight.data = self.module.weight.data
         self.weight_score = Parameter(torch.empty(self.weight.size()))
         torch.nn.init.constant_(self.weight_score, val=0.0)
@@ -60,7 +59,6 @@ class PrunerScoredModuleWrapper(Module):
         if hasattr(self.module, 'bias') and self.module.bias is not None:
             self.register_buffer("bias_mask", torch.ones(self.module.bias.shape))
             self.bias = Parameter(torch.empty(self.module.bias.size()))
-            self.bias.data = self.module.bias.data
         else:
             self.register_buffer("bias_mask", None)
@@ -69,9 +67,11 @@ class PrunerScoredModuleWrapper(Module):
         When using this wrapper to inference, call `_weight2buffer()` to make original weight untrainable.
         The best place to call this function is in `Pruner._wrap_model()`.
         """
+        self.weight.data = self.module.weight.data
         delattr(self.module, 'weight')
         self.module.register_buffer('weight', self.weight.data)
         if hasattr(self.module, 'bias') and self.module.bias is not None:
+            self.bias.data = self.module.bias.data
             delattr(self.module, 'bias')
             self.module.register_buffer('bias', self.bias.data)
@@ -113,22 +113,6 @@ class WeightScoreTrainerBasedDataCollector(TrainerBasedDataCollector):
     """
     Collect all weight_score in wrappers as data used to calculate metrics.
     """
-    def _reset_optimizer(self):
-        """
-        Weed out the weight_score from the parameters passed to optimizer, guaranteed to load the optimizer state dict.
-        """
-        if self._origin_optimizer_cls is not None:
-            optimizer_grouped_parameters = [{
-                "params": [p for n, p in self.compressor.bound_model.named_parameters() if "weight_score" not in n and p.requires_grad]
-            }]
-            if self._origin_optimizer_cls.__name__ == 'SGD':
-                self.optimizer = self._origin_optimizer_cls(optimizer_grouped_parameters, lr=0.001)
-            else:
-                self.optimizer = self._origin_optimizer_cls(optimizer_grouped_parameters)
-            self.optimizer.load_state_dict(self._origin_optimizer_state_dict)
-        else:
-            self.optimizer = None
     def collect(self) -> Dict[str, Tensor]:
         for _ in range(self.training_epochs):
             self.trainer(self.compressor.bound_model, self.optimizer, self.criterion)
@@ -171,9 +155,9 @@ class MovementPruner(BasicPruner):
             # If you don't want to update the model, you can skip `optimizer.step()`, and set train mode False.
             optimizer.step()
             model.train(mode=training)
-    optimizer : torch.optim.Optimizer
-        The optimizer instance used in trainer. Note that this optimizer might be patched during collect data,
-        so do not use this optimizer in other places.
+    traced_optimizer : nni.common.serializer.Traceable(torch.optim.Optimizer)
+        The traced optimizer instance which the optimizer class is wrapped by nni.algorithms.compression.v2.pytorch.utils.trace_parameters.
+        E.g. traced_optimizer = nni.algorithms.compression.v2.pytorch.utils.trace_parameters(torch.nn.Adam)(model.parameters()).
     criterion : Callable[[Tensor, Tensor], Tensor]
         The criterion function used in trainer. Take model output and target value as input, and return the loss.
     training_epochs : int
@@ -188,10 +172,13 @@ class MovementPruner(BasicPruner):
         total_sparsity * (1 - (1 - (current_step - warm_up_step) / (cool_down_beginning_step - warm_up_step)) ** 3).
     """
     def __init__(self, model: Module, config_list: List[Dict], trainer: Callable[[Module, Optimizer, Callable], None],
-                 optimizer: Optimizer, criterion: Callable[[Tensor, Tensor], Tensor], training_epochs: int, warm_up_step: int,
+                 traced_optimizer: Traceable, criterion: Callable[[Tensor, Tensor], Tensor], training_epochs: int, warm_up_step: int,
                  cool_down_beginning_step: int):
         self.trainer = trainer
-        self.optimizer = optimizer
+        if isinstance(traced_optimizer, OptimizerConstructHelper):
+            self.optimizer_helper = traced_optimizer
+        else:
+            self.optimizer_helper = OptimizerConstructHelper.from_trace(model, traced_optimizer)
         self.criterion = criterion
         self.training_epochs = training_epochs
         self.warm_up_step = warm_up_step
@@ -238,7 +225,7 @@ class MovementPruner(BasicPruner):
         self.load_masks(masks)
         if self.data_collector is None:
-            self.data_collector = WeightScoreTrainerBasedDataCollector(self, self.trainer, self.optimizer, self.criterion, self.training_epochs, opt_after_tasks=[_optimizer_patch])
+            self.data_collector = WeightScoreTrainerBasedDataCollector(self, self.trainer, self.optimizer_helper, self.criterion, self.training_epochs, opt_after_tasks=[_optimizer_patch])
         else:
             self.data_collector.reset()
@@ -283,6 +270,15 @@ class MovementPruner(BasicPruner):
         wrapper.to(layer.module.weight.device)
         return wrapper
+    def get_origin2wrapped_parameter_name_map(self) -> Dict[str, str]:
+        if self.is_wrapped:
+            self._unwrap_model()
+            parameter_name_map = {name: name for name, _ in self.bound_model.named_parameters()}
+            self._wrap_model()
+            return parameter_name_map
+        else:
+            raise Exception('When only the model is wrapped can get the parameter_name_map.')
     def compress(self) -> Tuple[Module, Dict]:
         # sparsity grow from 0
         for _, wrapper in self.get_modules_wrapper().items():
...
@@ -14,6 +14,7 @@ from torch.nn import Module
 from torch.optim import Optimizer
 from nni.algorithms.compression.v2.pytorch.base import Compressor, LayerInfo, Task, TaskResult
+from nni.algorithms.compression.v2.pytorch.utils import OptimizerConstructHelper
 _logger = logging.getLogger(__name__)
@@ -76,7 +77,7 @@ class TrainerBasedDataCollector(DataCollector):
     This class includes some trainer based util functions, i.e., patch optimizer or criterion, add hooks.
     """
-    def __init__(self, compressor: Compressor, trainer: Callable[[Module, Optimizer, Callable], None], optimizer: Optimizer,
+    def __init__(self, compressor: Compressor, trainer: Callable[[Module, Optimizer, Callable], None], optimizer_helper: OptimizerConstructHelper,
                  criterion: Callable[[Tensor, Tensor], Tensor], training_epochs: int,
                  opt_before_tasks: List = [], opt_after_tasks: List = [],
                  collector_infos: List[HookCollectorInfo] = [], criterion_patch: Callable[[Callable], Callable] = None):
@@ -133,19 +134,16 @@ class TrainerBasedDataCollector(DataCollector):
         super().__init__(compressor)
         self.trainer = trainer
         self.training_epochs = training_epochs
-        self._origin_optimizer_cls = optimizer.__class__ if optimizer is not None else None
-        self._origin_optimizer_state_dict = optimizer.state_dict() if optimizer is not None else None
+        self.optimizer_helper = optimizer_helper
         self._origin_criterion = criterion
         self._opt_before_tasks = opt_before_tasks
         self._opt_after_tasks = opt_after_tasks
-        self._collector_infos = collector_infos
         self._criterion_patch = criterion_patch
-        self.reset()
+        self.reset(collector_infos)
-    def reset(self):
+    def reset(self, collector_infos: List[HookCollectorInfo] = []):
         # refresh optimizer and criterion
         self._reset_optimizer()
@@ -162,19 +160,13 @@ class TrainerBasedDataCollector(DataCollector):
         self._hook_id = 0
         self._hook_handles = {}
         self._hook_buffer = {}
+        self._collector_infos = collector_infos
         self._add_all_hook()
     def _reset_optimizer(self):
-        self.compressor._unwrap_model()
-        if self._origin_optimizer_cls is not None:
-            if self._origin_optimizer_cls.__name__ == 'SGD':
-                self.optimizer = self._origin_optimizer_cls(self.compressor.bound_model.parameters(), lr=0.001)
-            else:
-                self.optimizer = self._origin_optimizer_cls(self.compressor.bound_model.parameters())
-            self.optimizer.load_state_dict(self._origin_optimizer_state_dict)
-        else:
-            self.optimizer = None
-        self.compressor._wrap_model()
+        parameter_name_map = self.compressor.get_origin2wrapped_parameter_name_map()
+        self.optimizer = self.optimizer_helper.call(self.compressor.bound_model, parameter_name_map)
     def _patch_optimizer(self):
         def patch_step(old_step):
@@ -233,7 +225,7 @@ class TrainerBasedDataCollector(DataCollector):
     def _remove_hook(self, hook_id: int):
         if hook_id not in self._hook_handles:
             raise ValueError("%s is not a valid collector id" % str(hook_id))
-        for handle in self._hook_handles[hook_id]:
+        for handle in self._hook_handles[hook_id].values():
             handle.remove()
         del self._hook_handles[hook_id]
...
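The refactored _reset_optimizer above no longer guesses constructor arguments (the removed code fell back to lr=0.001 for SGD); it rebuilds the optimizer from the trace recorded when the user constructed it. A simplified, purely illustrative sketch of that tracing idea, not NNI's actual trace_parameters / OptimizerConstructHelper implementation:

    import torch
    import torch.nn as nn

    def trace_optimizer(optim_cls):
        # hypothetical stand-in: remember the optimizer class and its keyword arguments
        class Traced:
            def __init__(self, params, **kwargs):
                self.optim_cls, self.kwargs = optim_cls, kwargs
                self.instance = optim_cls(params, **kwargs)  # still a normal, usable optimizer
            def reconstruct(self, model):
                # rebuild the same optimizer, with the same hyperparameters,
                # over the parameters of a (possibly re-wrapped) model
                return self.optim_cls(model.parameters(), **self.kwargs)
        return Traced

    model = nn.Linear(4, 2)
    traced = trace_optimizer(torch.optim.SGD)(model.parameters(), lr=0.01, momentum=0.9)
    fresh_optimizer = traced.reconstruct(model)  # SGD with lr=0.01, momentum=0.9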
@@ -23,7 +23,7 @@ class NormalSparsityAllocator(SparsityAllocator):
         for name, wrapper in self.pruner.get_modules_wrapper().items():
             sparsity_rate = wrapper.config['total_sparsity']
-            assert name in metrics, 'Metric of %s is not calculated.'
+            assert name in metrics, 'Metric of {} is not calculated.'.format(name)
             # We assume the metric value are all positive right now.
             metric = metrics[name]
...
@@ -73,7 +73,7 @@ class FunctionBasedTaskGenerator(TaskGenerator):
         # get current2origin_sparsity and compact2origin_sparsity
         origin_model = torch.load(self._origin_model_path)
         current2origin_sparsity, compact2origin_sparsity, _ = compute_sparsity(origin_model, compact_model, compact_model_masks, self.target_sparsity)
-        _logger.info('\nTask %s total real sparsity compared with original model is:\n%s', str(task_result.task_id), json_tricks.dumps(current2origin_sparsity, indent=4))
+        _logger.debug('\nTask %s total real sparsity compared with original model is:\n%s', str(task_result.task_id), json_tricks.dumps(current2origin_sparsity, indent=4))
         if task_result.task_id != 'origin':
             self._tasks[task_result.task_id].state['current2origin_sparsity'] = current2origin_sparsity
...
@@ -9,3 +9,4 @@ from .pruning import (
     get_model_weights_numel,
     get_module_by_name
 )
+from .constructor_helper import *