Unverified Commit e0422c5e authored by Panacea, committed by GitHub

[Compression V2] AMC Pruner (#4320)

parent 5fe29b06
......@@ -28,6 +28,7 @@ and how to schedule sparsity in each iteration are implemented as iterative prun
* `Lottery Ticket Pruner <#lottery-ticket-pruner>`__
* `Simulated Annealing Pruner <#simulated-annealing-pruner>`__
* `Auto Compress Pruner <#auto-compress-pruner>`__
* `AMC Pruner <#amc-pruner>`__
Level Pruner
------------
......@@ -554,3 +555,33 @@ User configuration for Auto Compress Pruner
**PyTorch**
.. autoclass:: nni.algorithms.compression.v2.pytorch.pruning.AutoCompressPruner
AMC Pruner
----------
AMC pruner leverages reinforcement learning to search for the model compression policy.
According to the authors, this learning-based compression policy outperforms conventional rule-based policies: it achieves a higher compression ratio,
better preserves accuracy, and reduces the need for manual tuning.
For more details, please refer to `AMC: AutoML for Model Compression and Acceleration on Mobile Devices <https://arxiv.org/pdf/1802.03494.pdf>`__.
Usage
^^^^^
PyTorch code
.. code-block:: python
from nni.algorithms.compression.v2.pytorch.pruning import AMCPruner
config_list = [{'op_types': ['Conv2d'], 'total_sparsity': 0.5, 'max_sparsity_per_layer': 0.8}]
pruner = AMCPruner(400, model, config_list, dummy_input, evaluator, finetuner=finetuner)
pruner.compress()
The full script can be found :githublink:`here <examples/model_compress/pruning/v2/amc_pruning_torch.py>`.
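After ``compress()`` finishes, the best result can be retrieved, as in the full example script (a minimal sketch; the tuple layout below follows the unpacking used in that script):

.. code-block:: python

    # the example script unpacks the result as (_, model, masks, score, _)
    _, best_model, best_masks, best_score, _ = pruner.get_best_result()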
User configuration for AMC Pruner
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
**PyTorch**
.. autoclass:: nni.algorithms.compression.v2.pytorch.pruning.AMCPruner
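# The example script below (examples/model_compress/pruning/v2/amc_pruning_torch.py) pretrains a VGG model on CIFAR-10,
# then searches a pruning policy with AMCPruner and compares FLOPs, parameter count and accuracy before and after pruning.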
import sys
from tqdm import tqdm
import torch
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import MultiStepLR
from nni.algorithms.compression.v2.pytorch.pruning import AMCPruner
from nni.compression.pytorch.utils.counter import count_flops_params
sys.path.append('../../models')
from cifar10.vgg import VGG
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
train_loader = torch.utils.data.DataLoader(
datasets.CIFAR10('./data', train=True, transform=transforms.Compose([
transforms.RandomHorizontalFlip(),
transforms.RandomCrop(32, 4),
transforms.ToTensor(),
normalize,
]), download=True),
batch_size=128, shuffle=True)
test_loader = torch.utils.data.DataLoader(
datasets.CIFAR10('./data', train=False, transform=transforms.Compose([
transforms.ToTensor(),
normalize,
])),
batch_size=128, shuffle=False)
criterion = torch.nn.CrossEntropyLoss()
def trainer(model, optimizer, criterion, epoch):
model.train()
for data, target in tqdm(iterable=train_loader, desc='Epoch {}'.format(epoch)):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
loss = criterion(output, target)
loss.backward()
optimizer.step()
def finetuner(model):
model.train()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
criterion = torch.nn.CrossEntropyLoss()
for data, target in tqdm(iterable=train_loader, desc='Epoch PFs'):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
loss = criterion(output, target)
loss.backward()
optimizer.step()
def evaluator(model):
model.eval()
correct = 0
with torch.no_grad():
for data, target in tqdm(iterable=test_loader, desc='Test'):
data, target = data.to(device), target.to(device)
output = model(data)
pred = output.argmax(dim=1, keepdim=True)
correct += pred.eq(target.view_as(pred)).sum().item()
acc = 100 * correct / len(test_loader.dataset)
print('Accuracy: {}%\n'.format(acc))
return acc
if __name__ == '__main__':
# model = MobileNetV2(n_class=10).to(device)
model = VGG().to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
scheduler = MultiStepLR(optimizer, milestones=[50, 75], gamma=0.1)
criterion = torch.nn.CrossEntropyLoss()
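# Pretrain the model for 100 epochs before searching for the pruning policy.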
for i in range(100):
trainer(model, optimizer, criterion, i)
scheduler.step()
pre_best_acc = evaluator(model)
dummy_input = torch.rand(10, 3, 32, 32).to(device)
pre_flops, pre_params, _ = count_flops_params(model, dummy_input)
config_list = [{'op_types': ['Conv2d'], 'total_sparsity': 0.5, 'max_sparsity_per_layer': 0.8}]
# If you just want to keep the final result as the best result, you can pass None as evaluator.
# Otherwise, the result with the highest score (given by the evaluator) will be kept as the best result.
ddpg_params = {'hidden1': 300, 'hidden2': 300, 'lr_c': 1e-3, 'lr_a': 1e-4, 'warmup': 100, 'discount': 1., 'bsize': 64,
'rmsize': 100, 'window_length': 1, 'tau': 0.01, 'init_delta': 0.5, 'delta_decay': 0.99, 'max_episode_length': 1e9, 'epsilon': 50000}
pruner = AMCPruner(400, model, config_list, dummy_input, evaluator, finetuner=finetuner, ddpg_params=ddpg_params, target='flops')
pruner.compress()
_, model, masks, best_acc, _ = pruner.get_best_result()
flops, params, _ = count_flops_params(model, dummy_input)
print(f'Pretrained model FLOPs {pre_flops/1e6:.2f} M, #Params: {pre_params/1e6:.2f}M, Accuracy: {pre_best_acc: .2f}%')
print(f'Finetuned model FLOPs {flops/1e6:.2f} M, #Params: {params/1e6:.2f}M, Accuracy: {best_acc: .2f}%')
......@@ -19,7 +19,8 @@ class Task:
# NOTE: If we want to support multi-thread, this part need to refactor, maybe use file and lock to sync.
_reference_counter = {}
def __init__(self, task_id: int, model_path: str, masks_path: str, config_list_path: str) -> None:
def __init__(self, task_id: int, model_path: str, masks_path: str, config_list_path: str,
speed_up: Optional[bool] = True, finetune: Optional[bool] = True, evaluate: Optional[bool] = True):
"""
Parameters
----------
......@@ -31,12 +32,22 @@ class Task:
The path of the masks that applied on the model before pruning.
config_list_path
The path of the config list that used in this task.
speed_up
Whether this task needs speed-up. True means follow the scheduler's default behavior; False means skip speed-up.
finetune
Whether this task needs finetuning. True means follow the scheduler's default behavior; False means skip finetuning.
evaluate
Whether this task needs evaluation. True means follow the scheduler's default behavior; False means skip evaluation.
"""
self.task_id = task_id
self.model_path = model_path
self.masks_path = masks_path
self.config_list_path = config_list_path
self.speed_up = speed_up
self.finetune = finetune
self.evaluate = evaluate
self.status = 'Pending'
self.score: Optional[float] = None
......@@ -54,6 +65,9 @@ class Task:
'model_path': str(self.model_path),
'masks_path': str(self.masks_path),
'config_list_path': str(self.config_list_path),
'speed_up': self.speed_up,
'finetune': self.finetune,
'evaluate': self.evaluate,
'status': self.status,
'score': self.score,
'state': self.state
......
......@@ -3,3 +3,4 @@ from .basic_scheduler import PruningScheduler
from .iterative_pruner import *
from .movement_pruner import MovementPruner
from .auto_compress_pruner import AutoCompressPruner
from .amc_pruner import AMCPruner
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from copy import deepcopy
from pathlib import Path
from typing import Dict, List, Callable, Optional
import json_tricks
import torch
from torch import Tensor
from torch.nn import Module
from nni.algorithms.compression.v2.pytorch.base import Task, TaskResult
from nni.algorithms.compression.v2.pytorch.utils import compute_sparsity, config_list_canonical
from nni.compression.pytorch.utils.counter import count_flops_params
from .iterative_pruner import IterativePruner, PRUNER_DICT
from .tools import TaskGenerator
from .tools.rl_env import DDPG, AMCEnv
class AMCTaskGenerator(TaskGenerator):
"""
Parameters
----------
total_episode
The total episode number.
dummy_input
Used for inference and counting the FLOPs.
origin_model
The original, unwrapped PyTorch model to be pruned.
origin_config_list
The original config list provided by the user. Note that this config_list configures the original model directly,
so the sparsity already applied by origin_masks should also be recorded in origin_config_list.
origin_masks
The pre-applied masks on the original model. These masks may be user-defined or generated by a previous pruning run.
log_dir
The log directory used to save the task generator log.
keep_intermediate_result
Whether to keep intermediate results, including the intermediate model and masks from each iteration.
ddpg_params
The DDPG agent parameters.
target : str
'flops' or 'params'. Note that sparsity in other pruners always refers to parameter sparsity, but in AMC you can choose FLOPs sparsity.
This parameter specifies what the sparsity setting in config_list refers to.
"""
def __init__(self, total_episode: int, dummy_input: Tensor, origin_model: Module, origin_config_list: List[Dict],
origin_masks: Dict[str, Dict[str, Tensor]] = {}, log_dir: str = '.', keep_intermediate_result: bool = False,
ddpg_params: Dict = {}, target: str = 'flops'):
self.total_episode = total_episode
self.current_episode = 0
self.dummy_input = dummy_input
self.ddpg_params = ddpg_params
self.target = target
self.config_list_copy = deepcopy(origin_config_list)
super().__init__(origin_model=origin_model, origin_masks=origin_masks, origin_config_list=origin_config_list,
log_dir=log_dir, keep_intermediate_result=keep_intermediate_result)
def init_pending_tasks(self) -> List[Task]:
origin_model = torch.load(self._origin_model_path)
origin_masks = torch.load(self._origin_masks_path)
with open(self._origin_config_list_path, "r") as f:
origin_config_list = json_tricks.load(f)
self.T = []
self.action = None
self.observation = None
self.warmup_episode = self.ddpg_params['warmup'] if 'warmup' in self.ddpg_params.keys() else int(self.total_episode / 4)
config_list_copy = config_list_canonical(origin_model, origin_config_list)
total_sparsity = config_list_copy[0]['total_sparsity']
max_sparsity_per_layer = config_list_copy[0].get('max_sparsity_per_layer', 1.)
self.env = AMCEnv(origin_model, origin_config_list, self.dummy_input, total_sparsity, max_sparsity_per_layer, self.target)
self.agent = DDPG(len(self.env.state_feature), 1, self.ddpg_params)
self.agent.is_training = True
task_result = TaskResult('origin', origin_model, origin_masks, origin_masks, None)
return self.generate_tasks(task_result)
def generate_tasks(self, task_result: TaskResult) -> List[Task]:
# append experience & update agent policy
if task_result.task_id != 'origin':
action, reward, observation, done = self.env.step(self.action, task_result.compact_model)
self.T.append([reward, self.observation, observation, self.action, done])
self.observation = observation.copy()
if done:
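# The final evaluation score (shifted by -1) is used as the reward for every transition stored in this episode.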
final_reward = task_result.score - 1
# agent observe and update policy
for _, s_t, s_t1, a_t, d_t in self.T:
self.agent.observe(final_reward, s_t, s_t1, a_t, d_t)
if self.current_episode > self.warmup_episode:
self.agent.update_policy()
self.current_episode += 1
self.T = []
self.action = None
self.observation = None
# update current2origin_sparsity in log file
origin_model = torch.load(self._origin_model_path)
compact_model = task_result.compact_model
compact_model_masks = task_result.compact_model_masks
current2origin_sparsity, _, _ = compute_sparsity(origin_model, compact_model, compact_model_masks, self.temp_config_list)
self._tasks[task_result.task_id].state['current2origin_sparsity'] = current2origin_sparsity
current2origin_sparsity, _, _ = compute_sparsity(origin_model, compact_model, compact_model_masks, self.config_list_copy)
self._tasks[task_result.task_id].state['current_total_sparsity'] = current2origin_sparsity
flops, params, _ = count_flops_params(compact_model, self.dummy_input, verbose=False)
self._tasks[task_result.task_id].state['current_flops'] = '{:.2f} M'.format(flops / 1e6)
self._tasks[task_result.task_id].state['current_params'] = '{:.2f} M'.format(params / 1e6)
# generate new action
if self.current_episode < self.total_episode:
if self.observation is None:
self.observation = self.env.reset().copy()
self.temp_config_list = []
compact_model = torch.load(self._origin_model_path)
compact_model_masks = torch.load(self._origin_masks_path)
else:
compact_model = task_result.compact_model
compact_model_masks = task_result.compact_model_masks
if self.current_episode <= self.warmup_episode:
action = self.agent.random_action()
else:
action = self.agent.select_action(self.observation, episode=self.current_episode)
action = action.tolist()[0]
self.action = self.env.correct_action(action, compact_model)
sub_config_list = [{'op_names': [self.env.current_op_name], 'total_sparsity': self.action}]
self.temp_config_list.extend(sub_config_list)
task_id = self._task_id_candidate
if self.env.is_first_layer() or self.env.is_final_layer():
task_config_list = self.temp_config_list
else:
task_config_list = sub_config_list
config_list_path = Path(self._intermediate_result_dir, '{}_config_list.json'.format(task_id))
with Path(config_list_path).open('w') as f:
json_tricks.dump(task_config_list, f, indent=4)
model_path = Path(self._intermediate_result_dir, '{}_compact_model.pth'.format(task_result.task_id))
masks_path = Path(self._intermediate_result_dir, '{}_compact_model_masks.pth'.format(task_result.task_id))
torch.save(compact_model, model_path)
torch.save(compact_model_masks, masks_path)
task = Task(task_id, model_path, masks_path, config_list_path)
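# Only the task generated for the last layer of an episode is finetuned and evaluated; intermediate per-layer tasks skip both.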
if not self.env.is_final_layer():
task.finetune = False
task.evaluate = False
self._tasks[task_id] = task
self._task_id_candidate += 1
return [task]
else:
return []
class AMCPruner(IterativePruner):
"""
A PyTorch implementation of AMC: AutoML for Model Compression and Acceleration on Mobile Devices.
(https://arxiv.org/pdf/1802.03494.pdf)
It is suggested to set the same `total_sparsity` value in every entry of `config_list`;
the AMC pruner treats the first sparsity in `config_list` as the global sparsity.
Parameters
----------
total_episode : int
The total episode number.
model : Module
The model to be pruned.
config_list : List[Dict]
Supported keys :
- total_sparsity : Specify the total sparsity across all layers in this config; each layer may end up with a different sparsity.
- max_sparsity_per_layer : Always used together with total_sparsity. Limits the maximum sparsity of each layer.
- op_types : Operation types to be pruned.
- op_names : Operation names to be pruned.
- exclude : Set to True to exclude the layers selected by op_types and op_names from pruning.
dummy_input : torch.Tensor
`dummy_input` is required for speeding up and tracing the model in the RL environment.
evaluator : Callable[[Module], float]
Evaluate the pruned model and give a score.
pruning_algorithm : str
Supported pruning algorithms: ['l1', 'l2', 'fpgm', 'apoz', 'mean_activation', 'taylorfo'].
This iterative pruner uses the chosen corresponding pruner to prune the model in each iteration.
log_dir : str
The log directory used to save the results; you can find the best result under this folder.
keep_intermediate_result : bool
Whether to keep intermediate results, including the intermediate model and masks from each iteration.
finetuner : Optional[Callable[[Module], None]]
The finetuner handles all finetuning logic; it takes a PyTorch module as input and is called in each iteration.
ddpg_params : Dict
Configuration dict for the DDPG agent; any unset key falls back to its default implicitly.
- hidden1: number of hidden units in the first fully connected layer. Default: 300
- hidden2: number of hidden units in the second fully connected layer. Default: 300
- lr_c: learning rate for the critic. Default: 1e-3
- lr_a: learning rate for the actor. Default: 1e-4
- warmup: number of episodes without training, used only to fill the replay memory. During warmup episodes, random actions are used for pruning. Default: 100
- discount: discount factor for the next Q value in the Q-learning target. Default: 1.0
- bsize: minibatch size for training the DDPG agent. Default: 64
- rmsize: replay memory size for each layer. Default: 100
- window_length: replay buffer window length. Default: 1
- tau: moving-average factor for the target networks used by soft_update. Default: 0.01
- init_delta: initial variance of the truncated normal distribution. Default: 0.5
- delta_decay: delta decay during exploration. Default: 0.99
- max_episode_length: maximum episode length. Default: 1e9
- epsilon: linear decay of the exploration policy. Default: 50000
pruning_params : Dict
If the pruner corresponding to the chosen pruning_algorithm has extra parameters, pass them in as a dict.
target : str
'flops' or 'params'. Note that sparsity in other pruners always refers to parameter sparsity, but in AMC you can choose FLOPs sparsity.
This parameter specifies what the sparsity setting in config_list refers to.
"""
def __init__(self, total_episode: int, model: Module, config_list: List[Dict], dummy_input: Tensor,
evaluator: Callable[[Module], float], pruning_algorithm: str = 'l1', log_dir: str = '.',
keep_intermediate_result: bool = False, finetuner: Optional[Callable[[Module], None]] = None,
ddpg_params: dict = {}, pruning_params: dict = {}, target: str = 'flops'):
assert pruning_algorithm in ['l1', 'l2', 'fpgm', 'apoz', 'mean_activation', 'taylorfo'], \
"Only support pruning_algorithm in ['l1', 'l2', 'fpgm', 'apoz', 'mean_activation', 'taylorfo']"
task_generator = AMCTaskGenerator(total_episode=total_episode,
dummy_input=dummy_input,
origin_model=model,
origin_config_list=config_list,
log_dir=log_dir,
keep_intermediate_result=keep_intermediate_result,
ddpg_params=ddpg_params,
target=target)
pruner = PRUNER_DICT[pruning_algorithm](None, None, **pruning_params)
super().__init__(pruner, task_generator, finetuner=finetuner, speed_up=True, dummy_input=dummy_input,
evaluator=evaluator, reset_weight=False)
......@@ -73,12 +73,12 @@ class PruningScheduler(BasePruningScheduler):
self.pruner._unwrap_model()
# speed up
if self.speed_up:
if self.speed_up and task.speed_up:
ModelSpeedup(compact_model, self.dummy_input, pruner_generated_masks).speedup_model()
compact_model_masks = {}
# finetune
if self.finetuner is not None:
if self.finetuner is not None and task.finetune:
if self.speed_up:
self.finetuner(compact_model)
else:
......@@ -87,7 +87,7 @@ class PruningScheduler(BasePruningScheduler):
self.pruner._unwrap_model()
# evaluate
if self.evaluator is not None:
if self.evaluator is not None and task.evaluate:
if self.speed_up:
score = self.evaluator(compact_model)
else:
......@@ -112,7 +112,7 @@ class PruningScheduler(BasePruningScheduler):
self.pruner.load_masks(masks)
# finetune
if self.finetuner is not None:
if self.finetuner is not None and task.finetune:
self.finetuner(model)
# pruning model
......@@ -127,12 +127,12 @@ class PruningScheduler(BasePruningScheduler):
compact_model.load_state_dict(checkpoint)
# speed up
if self.speed_up:
if self.speed_up and task.speed_up:
ModelSpeedup(compact_model, self.dummy_input, pruner_generated_masks).speedup_model()
compact_model_masks = {}
# evaluate
if self.evaluator is not None:
if self.evaluator is not None and task.evaluate:
if self.speed_up:
score = self.evaluator(compact_model)
else:
......
......@@ -513,7 +513,7 @@ class TaskGenerator:
task_id = task_result.task_id
task = self._tasks[task_id]
task.score = score
if self._best_score is None or score > self._best_score:
if self._best_score is None or (score is not None and score > self._best_score):
self._best_score = score
self._best_task_id = task_id
with Path(task.config_list_path).open('r') as fr:
......
from .agent import DDPG
from .amc_env import AMCEnv
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import numpy as np
import torch
import torch.nn as nn
from torch.optim import Adam
from .memory import SequentialMemory
criterion = nn.MSELoss()
USE_CUDA = torch.cuda.is_available()
def to_numpy(var):
use_cuda = torch.cuda.is_available()
return var.cpu().data.numpy() if use_cuda else var.data.numpy()
def to_tensor(ndarray, requires_grad=False): # return a float tensor by default
tensor = torch.from_numpy(ndarray).float() # by default does not require grad
if requires_grad:
tensor.requires_grad_()
return tensor.cuda() if torch.cuda.is_available() else tensor
class Actor(nn.Module):
def __init__(self, nb_states, nb_actions, hidden1=400, hidden2=300):
super(Actor, self).__init__()
self.fc1 = nn.Linear(nb_states, hidden1)
self.fc2 = nn.Linear(hidden1, hidden2)
self.fc3 = nn.Linear(hidden2, nb_actions)
self.relu = nn.ReLU()
self.sigmoid = nn.Sigmoid()
def forward(self, x):
out = self.fc1(x)
out = self.relu(out)
out = self.fc2(out)
out = self.relu(out)
out = self.fc3(out)
out = self.sigmoid(out)
return out
class Critic(nn.Module):
def __init__(self, nb_states, nb_actions, hidden1=400, hidden2=300):
super(Critic, self).__init__()
self.fc11 = nn.Linear(nb_states, hidden1)
self.fc12 = nn.Linear(nb_actions, hidden1)
self.fc2 = nn.Linear(hidden1, hidden2)
self.fc3 = nn.Linear(hidden2, 1)
self.relu = nn.ReLU()
def forward(self, xs):
x, a = xs
out = self.fc11(x) + self.fc12(a)
out = self.relu(out)
out = self.fc2(out)
out = self.relu(out)
out = self.fc3(out)
return out
class DDPG(nn.Module):
def __init__(self, nb_states, nb_actions, args):
super(DDPG, self).__init__()
self.ddpg_params = {'hidden1': 300, 'hidden2': 300, 'lr_c': 1e-3, 'lr_a': 1e-4, 'warmup': 100, 'discount': 1., 'bsize': 64,
'rmsize': 100, 'window_length': 1, 'tau': 0.01, 'init_delta': 0.5, 'delta_decay': 0.99, 'max_episode_length': 1e9, 'epsilon': 50000}
for key in args:
assert key in self.ddpg_params.keys(), "Error! Illegal key: {}".format(key)
self.ddpg_params[key] = args[key]
self.nb_states = nb_states
self.nb_actions = nb_actions
# Create Actor and Critic Networks
net_cfg = {
'hidden1': self.ddpg_params['hidden1'],
'hidden2': self.ddpg_params['hidden2'],
# 'init_w': self.ddpg_params['init_w
}
self.actor = Actor(self.nb_states, self.nb_actions, **net_cfg)
self.actor_target = Actor(self.nb_states, self.nb_actions, **net_cfg)
self.actor_optim = Adam(self.actor.parameters(), lr=self.ddpg_params['lr_a'])
self.critic = Critic(self.nb_states, self.nb_actions, **net_cfg)
self.critic_target = Critic(self.nb_states, self.nb_actions, **net_cfg)
self.critic_optim = Adam(self.critic.parameters(), lr=self.ddpg_params['lr_c'])
self.hard_update(self.actor_target, self.actor) # Make sure target is with the same weight
self.hard_update(self.critic_target, self.critic)
# Create replay buffer
self.memory = SequentialMemory(limit=self.ddpg_params['rmsize'], window_length=self.ddpg_params['window_length'])
# Hyper-parameters
self.batch_size = self.ddpg_params['bsize']
self.tau = self.ddpg_params['tau']
self.discount = self.ddpg_params['discount']
self.depsilon = 1.0 / self.ddpg_params['epsilon']
self.lbound = 0. # self.ddpg_params['lbound']
self.rbound = 1. # self.ddpg_params['rbound']
# noise
self.init_delta = self.ddpg_params['init_delta']
self.delta_decay = self.ddpg_params['delta_decay']
self.warmup = self.ddpg_params['warmup']
self.epsilon = 1.0
# self.s_t = None # Most recent state
# self.a_t = None # Most recent action
self.is_training = True
#
if USE_CUDA: self.cuda()
# moving average baseline
self.moving_average = None
self.moving_alpha = 0.5 # based on batch, so small
def update_policy(self):
# Sample batch
state_batch, action_batch, reward_batch, \
next_state_batch, terminal_batch = self.memory.sample_and_split(self.batch_size)
# normalize the reward
batch_mean_reward = np.mean(reward_batch)
if self.moving_average is None:
self.moving_average = batch_mean_reward
else:
self.moving_average += self.moving_alpha * (batch_mean_reward - self.moving_average)
reward_batch -= self.moving_average
# Prepare for the target q batch
with torch.no_grad():
next_q_values = self.critic_target([
to_tensor(next_state_batch),
self.actor_target(to_tensor(next_state_batch)),
])
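# Bellman target: r + discount * mask * Q_target(s', actor_target(s')), where mask is 0 for terminal transitions
# (terminal_batch already stores 0. for terminal and 1. otherwise, see SequentialMemory.sample_and_split).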
target_q_batch = to_tensor(reward_batch) + \
self.discount * to_tensor(terminal_batch.astype(np.float)) * next_q_values
# Critic update
self.critic.zero_grad()
q_batch = self.critic([to_tensor(state_batch), to_tensor(action_batch)])
value_loss = criterion(q_batch, target_q_batch)
value_loss.backward()
self.critic_optim.step()
# Actor update
self.actor.zero_grad()
policy_loss = -self.critic([ # pylint: disable=all
to_tensor(state_batch),
self.actor(to_tensor(state_batch))
])
policy_loss = policy_loss.mean()
policy_loss.backward()
self.actor_optim.step()
# Target update
self.soft_update(self.actor_target, self.actor)
self.soft_update(self.critic_target, self.critic)
def observe(self, r_t, s_t, s_t1, a_t, done):
if self.is_training:
self.memory.append(s_t, a_t, r_t, done) # save to memory
def random_action(self):
action = np.random.uniform(self.lbound, self.rbound, self.nb_actions)
# self.a_t = action
return action
def select_action(self, s_t, episode):
action = to_numpy(self.actor(to_tensor(np.array(s_t).reshape(1, -1)))).squeeze(0)
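# Exploration: perturb the actor's deterministic output with truncated-normal noise whose variance decays after warmup.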
delta = self.init_delta * (self.delta_decay ** (episode - self.warmup))
# action += self.is_training * max(self.epsilon, 0) * self.random_process.sample()
action = self.sample_from_truncated_normal_distribution(lower=self.lbound, upper=self.rbound, mu=action, sigma=delta)
action = np.clip(action, self.lbound, self.rbound)
return action
def load_weights(self, output):
if output is None: return
self.actor.load_state_dict(
torch.load('{}/actor.pkl'.format(output))
)
self.critic.load_state_dict(
torch.load('{}/critic.pkl'.format(output))
)
def save_model(self, output):
torch.save(
self.actor.state_dict(),
'{}/actor.pkl'.format(output)
)
torch.save(
self.critic.state_dict(),
'{}/critic.pkl'.format(output)
)
def soft_update(self, target, source):
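# Polyak averaging: target = (1 - tau) * target + tau * source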
for target_param, param in zip(target.parameters(), source.parameters()):
target_param.data.copy_(
target_param.data * (1.0 - self.tau) + param.data * self.tau
)
def hard_update(self, target, source):
for target_param, param in zip(target.parameters(), source.parameters()):
target_param.data.copy_(param.data)
def sample_from_truncated_normal_distribution(self, lower, upper, mu, sigma, size=1):
from scipy import stats
return stats.truncnorm.rvs((lower-mu)/sigma, (upper-mu)/sigma, loc=mu, scale=sigma, size=size)
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from collections import OrderedDict
from copy import Error
import logging
from typing import Dict, List
import numpy as np
from torch import Tensor
from torch.nn import Module
from nni.algorithms.compression.v2.pytorch.utils import config_list_canonical
from nni.compression.pytorch.utils.counter import count_flops_params
_logger = logging.getLogger(__name__)
class AMCEnv:
def __init__(self, model: Module, config_list: List[Dict], dummy_input: Tensor, total_sparsity: float, max_sparsity_per_layer: Dict[str, float], target: str = 'flops'):
pruning_op_names = []
[pruning_op_names.extend(config['op_names']) for config in config_list_canonical(model, config_list)]
self.pruning_ops = OrderedDict()
self.pruning_types = []
for i, (name, layer) in enumerate(model.named_modules()):
if name in pruning_op_names:
op_type = type(layer).__name__
stride = np.power(np.prod(layer.stride), 1 / len(layer.stride)) if hasattr(layer, 'stride') else 0
kernel_size = np.power(np.prod(layer.kernel_size), 1 / len(layer.kernel_size)) if hasattr(layer, 'kernel_size') else 1
self.pruning_ops[name] = (i, op_type, stride, kernel_size)
self.pruning_types.append(op_type)
self.pruning_types = list(set(self.pruning_types))
self.pruning_op_names = list(self.pruning_ops.keys())
self.dummy_input = dummy_input
self.total_sparsity = total_sparsity
self.max_sparsity_per_layer = max_sparsity_per_layer
assert target in ['flops', 'params']
self.target = target
self.origin_target, self.origin_params_num, self.origin_statistics = count_flops_params(model, dummy_input, verbose=False)
self.origin_statistics = {result['name']: result for result in self.origin_statistics}
self.under_pruning_target = sum([self.origin_statistics[name][self.target] for name in self.pruning_op_names])
self.excepted_pruning_target = self.total_sparsity * self.under_pruning_target
def reset(self):
self.ops_iter = iter(self.pruning_ops)
# build embedding (static part)
self._build_state_embedding(self.origin_statistics)
observation = self.layer_embedding[0].copy()
return observation
def correct_action(self, action: float, model: Module):
try:
op_name = next(self.ops_iter)
index = self.pruning_op_names.index(op_name)
_, _, current_statistics = count_flops_params(model, self.dummy_input, verbose=False)
current_statistics = {result['name']: result for result in current_statistics}
total_current_target = sum([current_statistics[name][self.target] for name in self.pruning_op_names])
previous_pruning_target = self.under_pruning_target - total_current_target
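# Lower bound: this layer must prune at least enough that, even if all later layers prune at their per-layer maximum,
# the overall pruning target can still be reached. Upper bound: the smaller of this layer's remaining per-layer cap
# and the remaining overall pruning budget.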
max_rest_pruning_target = sum([current_statistics[name][self.target] * self.max_sparsity_per_layer[name] for name in self.pruning_op_names[index + 1:]])
min_current_pruning_target = self.excepted_pruning_target - previous_pruning_target - max_rest_pruning_target
max_current_pruning_target_1 = self.origin_statistics[op_name][self.target] * self.max_sparsity_per_layer[op_name] - (self.origin_statistics[op_name][self.target] - current_statistics[op_name][self.target])
max_current_pruning_target_2 = self.excepted_pruning_target - previous_pruning_target
max_current_pruning_target = min(max_current_pruning_target_1, max_current_pruning_target_2)
min_action = min_current_pruning_target / current_statistics[op_name][self.target]
max_action = max_current_pruning_target / current_statistics[op_name][self.target]
if min_action > self.max_sparsity_per_layer[op_name]:
_logger.warning('[%s] min action > max sparsity per layer: %f > %f', op_name, min_action, self.max_sparsity_per_layer[op_name])
action = max(0., min(max_action, max(min_action, action)))
self.current_op_name = op_name
self.current_op_target = current_statistics[op_name][self.target]
except StopIteration:
raise Error('Something went wrong; this should not happen.')
return action
def step(self, action: float, model: Module):
_, _, current_statistics = count_flops_params(model, self.dummy_input, verbose=False)
current_statistics = {result['name']: result for result in current_statistics}
index = self.pruning_op_names.index(self.current_op_name)
action = 1 - current_statistics[self.current_op_name][self.target] / self.current_op_target
total_current_target = sum([current_statistics[name][self.target] for name in self.pruning_op_names])
previous_pruning_target = self.under_pruning_target - total_current_target
rest_target = sum([current_statistics[name][self.target] for name in self.pruning_op_names[index + 1:]])
self.layer_embedding[index][-3] = previous_pruning_target / self.under_pruning_target # reduced
self.layer_embedding[index][-2] = rest_target / self.under_pruning_target # rest
self.layer_embedding[index][-1] = action # last action
observation = self.layer_embedding[index, :].copy()
return action, 0, observation, self.is_final_layer()
def is_first_layer(self):
return self.pruning_op_names.index(self.current_op_name) == 0
def is_final_layer(self):
return self.pruning_op_names.index(self.current_op_name) == len(self.pruning_op_names) - 1
@property
def state_feature(self):
return ['index', 'layer_type', 'input_size', 'output_size', 'stride', 'kernel_size', 'params_size', 'reduced', 'rest', 'a_{t-1}']
def _build_state_embedding(self, statistics: Dict[str, Dict]):
_logger.info('Building state embedding...')
layer_embedding = []
for name, (idx, op_type, stride, kernel_size) in self.pruning_ops.items():
state = []
state.append(idx) # index
state.append(self.pruning_types.index(op_type)) # layer type
state.append(np.prod(statistics[name]['input_size'])) # input size
state.append(np.prod(statistics[name]['output_size'])) # output size
state.append(stride) # stride
state.append(kernel_size) # kernel size
state.append(statistics[name]['params']) # params size
state.append(0.) # reduced
state.append(1.) # rest
state.append(0.) # a_{t-1}
layer_embedding.append(np.array(state))
layer_embedding = np.array(layer_embedding, 'float')
_logger.info('=> shape of embedding (n_layer * n_dim): %s', layer_embedding.shape)
assert len(layer_embedding.shape) == 2, layer_embedding.shape
# normalize the state
for i in range(layer_embedding.shape[1]):
fmin = min(layer_embedding[:, i])
fmax = max(layer_embedding[:, i])
if fmax - fmin > 0:
layer_embedding[:, i] = (layer_embedding[:, i] - fmin) / (fmax - fmin)
self.layer_embedding = layer_embedding
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import absolute_import
from collections import deque, namedtuple
import warnings
import random
import numpy as np
# [reference] https://github.com/matthiasplappert/keras-rl/blob/master/rl/memory.py
# This is to be understood as a transition: Given `state0`, performing `action`
# yields `reward` and results in `state1`, which might be `terminal`.
Experience = namedtuple('Experience', 'state0, action, reward, state1, terminal1')
def sample_batch_indexes(low, high, size):
if high - low >= size:
# We have enough data. Draw without replacement, that is each index is unique in the
# batch. We cannot use `np.random.choice` here because it is horribly inefficient as
# the memory grows. See https://github.com/numpy/numpy/issues/2764 for a discussion.
# `random.sample` does the same thing (drawing without replacement) and is way faster.
r = range(low, high)
batch_idxs = random.sample(r, size)
else:
# Not enough data. Help ourselves with sampling from the range, but the same index
# can occur multiple times. This is not good and should be avoided by picking a
# large enough warm-up phase.
warnings.warn(
'Not enough entries to sample without replacement. '
'Consider increasing your warm-up phase to avoid oversampling!')
batch_idxs = np.random.random_integers(low, high - 1, size=size)
assert len(batch_idxs) == size
return batch_idxs
class RingBuffer(object):
def __init__(self, maxlen):
self.maxlen = maxlen
self.start = 0
self.length = 0
self.data = [None for _ in range(maxlen)]
def __len__(self):
return self.length
def __getitem__(self, idx):
if idx < 0 or idx >= self.length:
raise KeyError()
return self.data[(self.start + idx) % self.maxlen]
def append(self, v):
if self.length < self.maxlen:
# We have space, simply increase the length.
self.length += 1
elif self.length == self.maxlen:
# No space, "remove" the first item.
self.start = (self.start + 1) % self.maxlen
else:
# This should never happen.
raise RuntimeError()
self.data[(self.start + self.length - 1) % self.maxlen] = v
def zeroed_observation(observation):
if hasattr(observation, 'shape'):
return np.zeros(observation.shape)
elif hasattr(observation, '__iter__'):
out = []
for x in observation:
out.append(zeroed_observation(x))
return out
else:
return 0.
class Memory(object):
def __init__(self, window_length, ignore_episode_boundaries=False):
self.window_length = window_length
self.ignore_episode_boundaries = ignore_episode_boundaries
self.recent_observations = deque(maxlen=window_length)
self.recent_terminals = deque(maxlen=window_length)
def sample(self, batch_size, batch_idxs=None):
raise NotImplementedError()
def append(self, observation, action, reward, terminal, training=True):
self.recent_observations.append(observation)
self.recent_terminals.append(terminal)
def get_recent_state(self, current_observation):
# This code is slightly complicated by the fact that subsequent observations might be
# from different episodes. We ensure that an experience never spans multiple episodes.
# This is probably not that important in practice but it seems cleaner.
state = [current_observation]
idx = len(self.recent_observations) - 1
for offset in range(0, self.window_length - 1):
current_idx = idx - offset
current_terminal = self.recent_terminals[current_idx - 1] if current_idx - 1 >= 0 else False
if current_idx < 0 or (not self.ignore_episode_boundaries and current_terminal):
# The previously handled observation was terminal, don't add the current one.
# Otherwise we would leak into a different episode.
break
state.insert(0, self.recent_observations[current_idx])
while len(state) < self.window_length:
state.insert(0, zeroed_observation(state[0]))
return state
def get_config(self):
config = {
'window_length': self.window_length,
'ignore_episode_boundaries': self.ignore_episode_boundaries,
}
return config
class SequentialMemory(Memory):
def __init__(self, limit, **kwargs):
super(SequentialMemory, self).__init__(**kwargs)
self.limit = limit
# Do not use deque to implement the memory. This data structure may seem convenient but
# it is way too slow on random access. Instead, we use our own ring buffer implementation.
self.actions = RingBuffer(limit)
self.rewards = RingBuffer(limit)
self.terminals = RingBuffer(limit)
self.observations = RingBuffer(limit)
def sample(self, batch_size, batch_idxs=None):
if batch_idxs is None:
# Draw random indexes such that we have at least a single entry before each
# index.
batch_idxs = sample_batch_indexes(0, self.nb_entries - 1, size=batch_size)
batch_idxs = np.array(batch_idxs) + 1
assert np.min(batch_idxs) >= 1
assert np.max(batch_idxs) < self.nb_entries
assert len(batch_idxs) == batch_size
# Create experiences
experiences = []
for idx in batch_idxs:
terminal0 = self.terminals[idx - 2] if idx >= 2 else False
while terminal0:
# Skip this transition because the environment was reset here. Select a new, random
# transition and use this instead. This may cause the batch to contain the same
# transition twice.
idx = sample_batch_indexes(1, self.nb_entries, size=1)[0]
terminal0 = self.terminals[idx - 2] if idx >= 2 else False
assert 1 <= idx < self.nb_entries
# This code is slightly complicated by the fact that subsequent observations might be
# from different episodes. We ensure that an experience never spans multiple episodes.
# This is probably not that important in practice but it seems cleaner.
state0 = [self.observations[idx - 1]]
for offset in range(0, self.window_length - 1):
current_idx = idx - 2 - offset
current_terminal = self.terminals[current_idx - 1] if current_idx - 1 > 0 else False
if current_idx < 0 or (not self.ignore_episode_boundaries and current_terminal):
# The previously handled observation was terminal, don't add the current one.
# Otherwise we would leak into a different episode.
break
state0.insert(0, self.observations[current_idx])
while len(state0) < self.window_length:
state0.insert(0, zeroed_observation(state0[0]))
action = self.actions[idx - 1]
reward = self.rewards[idx - 1]
terminal1 = self.terminals[idx - 1]
# Okay, now we need to create the follow-up state. This is state0 shifted on timestep
# to the right. Again, we need to be careful to not include an observation from the next
# episode if the last state is terminal.
state1 = [np.copy(x) for x in state0[1:]]
state1.append(self.observations[idx])
assert len(state0) == self.window_length
assert len(state1) == len(state0)
experiences.append(Experience(state0=state0, action=action, reward=reward,
state1=state1, terminal1=terminal1))
assert len(experiences) == batch_size
return experiences
def sample_and_split(self, batch_size, batch_idxs=None):
experiences = self.sample(batch_size, batch_idxs)
state0_batch = []
reward_batch = []
action_batch = []
terminal1_batch = []
state1_batch = []
for e in experiences:
state0_batch.append(e.state0)
state1_batch.append(e.state1)
reward_batch.append(e.reward)
action_batch.append(e.action)
terminal1_batch.append(0. if e.terminal1 else 1.)
# Prepare and validate parameters.
state0_batch = np.array(state0_batch, 'double').reshape(batch_size, -1)
state1_batch = np.array(state1_batch, 'double').reshape(batch_size, -1)
terminal1_batch = np.array(terminal1_batch, 'double').reshape(batch_size, -1)
reward_batch = np.array(reward_batch, 'double').reshape(batch_size, -1)
action_batch = np.array(action_batch, 'double').reshape(batch_size, -1)
return state0_batch, action_batch, reward_batch, state1_batch, terminal1_batch
def append(self, observation, action, reward, terminal, training=True):
super(SequentialMemory, self).append(observation, action, reward, terminal, training=training)
# This needs to be understood as follows: in `observation`, take `action`, obtain `reward`,
# and whether the next state is `terminal` or not.
if training:
self.observations.append(observation)
self.actions.append(action)
self.rewards.append(reward)
self.terminals.append(terminal)
@property
def nb_entries(self):
return len(self.observations)
def get_config(self):
config = super(SequentialMemory, self).get_config()
config['limit'] = self.limit
return config
......@@ -37,6 +37,8 @@ def config_list_canonical(model: Module, config_list: List[Dict]) -> List[Dict]:
'layer2.0.conv1', 'layer2.1.conv1', 'layer3.0.conv1', 'layer3.1.conv1',
'layer4.0.conv1', 'layer4.1.conv1']}]
'''
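# work on a copy so the caller's config_list is not modified in place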
config_list = deepcopy(config_list)
for config in config_list:
if 'sparsity' in config:
if 'sparsity_per_layer' in config:
......
......@@ -12,7 +12,8 @@ from nni.algorithms.compression.v2.pytorch.pruning import (
AGPPruner,
LotteryTicketPruner,
SimulatedAnnealingPruner,
AutoCompressPruner
AutoCompressPruner,
AMCPruner
)
from nni.algorithms.compression.v2.pytorch.utils import compute_sparsity_mask2compact, trace_parameters
......@@ -21,9 +22,9 @@ from nni.algorithms.compression.v2.pytorch.utils import compute_sparsity_mask2co
class TorchModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv1 = torch.nn.Conv2d(1, 5, 5, 1)
self.bn1 = torch.nn.BatchNorm2d(5)
self.conv2 = torch.nn.Conv2d(5, 10, 5, 1)
self.conv1 = torch.nn.Conv2d(1, 10, 5, 1)
self.bn1 = torch.nn.BatchNorm2d(10)
self.conv2 = torch.nn.Conv2d(10, 10, 5, 1)
self.bn2 = torch.nn.BatchNorm2d(10)
self.fc1 = torch.nn.Linear(4 * 4 * 10, 100)
self.fc2 = torch.nn.Linear(100, 10)
......@@ -33,7 +34,7 @@ class TorchModel(torch.nn.Module):
x = F.max_pool2d(x, 2, 2)
x = F.relu(self.bn2(self.conv2(x)))
x = F.max_pool2d(x, 2, 2)
x = x.view(-1, 4 * 4 * 10)
x = x.view(x.size(0), -1)
x = F.relu(self.fc1(x))
x = self.fc2(x)
return F.log_softmax(x, dim=1)
......@@ -62,6 +63,11 @@ def evaluator(model):
return random.random()
def finetuner(model):
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
trainer(model, optimizer, criterion)
class IterativePrunerTestCase(unittest.TestCase):
def test_linear_pruner(self):
model = TorchModel()
......@@ -120,5 +126,15 @@ class IterativePrunerTestCase(unittest.TestCase):
print(sparsity_list)
assert 0.78 < sparsity_list[0]['total_sparsity'] < 0.82
def test_amc_pruner(self):
model = TorchModel()
config_list = [{'op_types': ['Conv2d'], 'total_sparsity': 0.5, 'max_sparsity_per_layer': 0.8}]
dummy_input = torch.rand(10, 1, 28, 28)
ddpg_params = {'hidden1': 300, 'hidden2': 300, 'lr_c': 1e-3, 'lr_a': 1e-4, 'warmup': 5, 'discount': 1.,
'bsize': 64, 'rmsize': 100, 'window_length': 1, 'tau': 0.01, 'init_delta': 0.5, 'delta_decay': 0.99,
'max_episode_length': 1e9, 'epsilon': 50000}
pruner = AMCPruner(10, model, config_list, dummy_input, evaluator, finetuner=finetuner, ddpg_params=ddpg_params, target='flops', log_dir='../../../logs')
pruner.compress()
if __name__ == '__main__':
unittest.main()