Unverified Commit 6126960c authored by chicm-ms's avatar chicm-ms Committed by GitHub
Browse files

AMC supports resnet (#2876)

parent 392e55f3
...@@ -529,6 +529,16 @@ You can view [example](https://github.com/microsoft/nni/blob/master/examples/mod ...@@ -529,6 +529,16 @@ You can view [example](https://github.com/microsoft/nni/blob/master/examples/mod
.. autoclass:: nni.compression.torch.AMCPruner .. autoclass:: nni.compression.torch.AMCPruner
``` ```
### Reproduced Experiment
We implemented one of the experiments in [AMC: AutoML for Model Compression and Acceleration on Mobile Devices](https://arxiv.org/pdf/1802.03494.pdf), we pruned **MobileNet** to 50% FLOPS for ImageNet in the paper. Our experiments results are as follows:
| Model | Top 1 acc.(paper/ours) | Top 5 acc. (paper/ours) | FLOPS |
| ------------- | --------------| -------------- | ----- |
| MobileNet | 70.5% / 69.9% | 89.3% / 89.1% | 50% |
The experiments code can be found at [examples/model_compress]( https://github.com/microsoft/nni/tree/master/examples/model_compress/amc/)
## ADMM Pruner ## ADMM Pruner
Alternating Direction Method of Multipliers (ADMM) is a mathematical optimization technique, Alternating Direction Method of Multipliers (ADMM) is a mathematical optimization technique,
by decomposing the original nonconvex problem into two subproblems that can be solved iteratively. In weight pruning problem, these two subproblems are solved via 1) gradient descent algorithm and 2) Euclidean projection respectively. by decomposing the original nonconvex problem into two subproblems that can be solved iteratively. In weight pruning problem, these two subproblems are solved via 1) gradient descent algorithm and 2) Euclidean projection respectively.
......
...@@ -7,7 +7,7 @@ import time ...@@ -7,7 +7,7 @@ import time
import torch import torch
import torch.nn as nn import torch.nn as nn
from torchvision.models import resnet
from nni.compression.torch import AMCPruner from nni.compression.torch import AMCPruner
from data import get_split_dataset from data import get_split_dataset
from utils import AverageMeter, accuracy from utils import AverageMeter, accuracy
...@@ -16,7 +16,8 @@ sys.path.append('../models') ...@@ -16,7 +16,8 @@ sys.path.append('../models')
def parse_args(): def parse_args():
parser = argparse.ArgumentParser(description='AMC search script') parser = argparse.ArgumentParser(description='AMC search script')
parser.add_argument('--model_type', default='mobilenet', type=str, choices=['mobilenet', 'mobilenetv2'], help='model to prune') parser.add_argument('--model_type', default='mobilenet', type=str, choices=['mobilenet', 'mobilenetv2', 'resnet18', 'resnet34', 'resnet50'],
help='model to prune')
parser.add_argument('--dataset', default='cifar10', type=str, choices=['cifar10', 'imagenet'], help='dataset to use (cifar/imagenet)') parser.add_argument('--dataset', default='cifar10', type=str, choices=['cifar10', 'imagenet'], help='dataset to use (cifar/imagenet)')
parser.add_argument('--batch_size', default=50, type=int, help='number of data batch size') parser.add_argument('--batch_size', default=50, type=int, help='number of data batch size')
parser.add_argument('--data_root', default='./cifar10', type=str, help='dataset path') parser.add_argument('--data_root', default='./cifar10', type=str, help='dataset path')
...@@ -28,27 +29,29 @@ def parse_args(): ...@@ -28,27 +29,29 @@ def parse_args():
parser.add_argument('--train_episode', default=800, type=int, help='number of training episode') parser.add_argument('--train_episode', default=800, type=int, help='number of training episode')
parser.add_argument('--n_gpu', default=1, type=int, help='number of gpu to use') parser.add_argument('--n_gpu', default=1, type=int, help='number of gpu to use')
parser.add_argument('--n_worker', default=16, type=int, help='number of data loader worker') parser.add_argument('--n_worker', default=16, type=int, help='number of data loader worker')
parser.add_argument('--job', default='train_export', type=str, choices=['train_export', 'export_only'], parser.add_argument('--suffix', default=None, type=str, help='suffix of auto-generated log directory')
help='search best pruning policy and export or just export model with searched policy')
parser.add_argument('--export_path', default=None, type=str, help='path for exporting models')
parser.add_argument('--searched_model_path', default=None, type=str, help='path for searched best wrapped model')
return parser.parse_args() return parser.parse_args()
def get_model_and_checkpoint(model, dataset, checkpoint_path, n_gpu=1): def get_model_and_checkpoint(model, dataset, checkpoint_path, n_gpu=1):
if model == 'mobilenet' and dataset == 'imagenet': if dataset == 'imagenet':
from mobilenet import MobileNet n_class = 1000
net = MobileNet(n_class=1000) elif dataset == 'cifar10':
elif model == 'mobilenetv2' and dataset == 'imagenet': n_class = 10
from mobilenet_v2 import MobileNetV2 else:
net = MobileNetV2(n_class=1000) raise ValueError('unsupported dataset')
elif model == 'mobilenet' and dataset == 'cifar10':
if model == 'mobilenet':
from mobilenet import MobileNet from mobilenet import MobileNet
net = MobileNet(n_class=10) net = MobileNet(n_class=n_class)
elif model == 'mobilenetv2' and dataset == 'cifar10': elif model == 'mobilenetv2':
from mobilenet_v2 import MobileNetV2 from mobilenet_v2 import MobileNetV2
net = MobileNetV2(n_class=10) net = MobileNetV2(n_class=n_class)
elif model.startswith('resnet'):
net = resnet.__dict__[model](pretrained=True)
in_features = net.fc.in_features
net.fc = nn.Linear(in_features, n_class)
else: else:
raise NotImplementedError raise NotImplementedError
if checkpoint_path: if checkpoint_path:
...@@ -130,7 +133,6 @@ if __name__ == "__main__": ...@@ -130,7 +133,6 @@ if __name__ == "__main__":
}] }]
pruner = AMCPruner( pruner = AMCPruner(
model, config_list, validate, val_loader, model_type=args.model_type, dataset=args.dataset, model, config_list, validate, val_loader, model_type=args.model_type, dataset=args.dataset,
train_episode=args.train_episode, job=args.job, export_path=args.export_path, train_episode=args.train_episode, flops_ratio=args.flops_ratio, lbound=args.lbound,
searched_model_path=args.searched_model_path, rbound=args.rbound, suffix=args.suffix)
flops_ratio=args.flops_ratio, lbound=args.lbound, rbound=args.rbound)
pruner.compress() pruner.compress()
...@@ -16,6 +16,7 @@ from tensorboardX import SummaryWriter ...@@ -16,6 +16,7 @@ from tensorboardX import SummaryWriter
from nni.compression.torch.pruning.amc.lib.net_measure import measure_model from nni.compression.torch.pruning.amc.lib.net_measure import measure_model
from nni.compression.torch.pruning.amc.lib.utils import get_output_folder from nni.compression.torch.pruning.amc.lib.utils import get_output_folder
from nni.compression.torch import ModelSpeedup
from data import get_dataset from data import get_dataset
from utils import AverageMeter, accuracy, progress_bar from utils import AverageMeter, accuracy, progress_bar
...@@ -28,17 +29,19 @@ def parse_args(): ...@@ -28,17 +29,19 @@ def parse_args():
parser = argparse.ArgumentParser(description='AMC train / fine-tune script') parser = argparse.ArgumentParser(description='AMC train / fine-tune script')
parser.add_argument('--model_type', default='mobilenet', type=str, help='name of the model to train') parser.add_argument('--model_type', default='mobilenet', type=str, help='name of the model to train')
parser.add_argument('--dataset', default='cifar10', type=str, help='name of the dataset to train') parser.add_argument('--dataset', default='cifar10', type=str, help='name of the dataset to train')
parser.add_argument('--lr', default=0.1, type=float, help='learning rate') parser.add_argument('--lr', default=0.05, type=float, help='learning rate')
parser.add_argument('--n_gpu', default=1, type=int, help='number of GPUs to use') parser.add_argument('--n_gpu', default=4, type=int, help='number of GPUs to use')
parser.add_argument('--batch_size', default=128, type=int, help='batch size') parser.add_argument('--batch_size', default=256, type=int, help='batch size')
parser.add_argument('--n_worker', default=4, type=int, help='number of data loader worker') parser.add_argument('--n_worker', default=32, type=int, help='number of data loader worker')
parser.add_argument('--lr_type', default='exp', type=str, help='lr scheduler (exp/cos/step3/fixed)') parser.add_argument('--lr_type', default='cos', type=str, help='lr scheduler (exp/cos/step3/fixed)')
parser.add_argument('--n_epoch', default=50, type=int, help='number of epochs to train') parser.add_argument('--n_epoch', default=150, type=int, help='number of epochs to train')
parser.add_argument('--wd', default=4e-5, type=float, help='weight decay') parser.add_argument('--wd', default=4e-5, type=float, help='weight decay')
parser.add_argument('--seed', default=None, type=int, help='random seed to set') parser.add_argument('--seed', default=None, type=int, help='random seed to set')
parser.add_argument('--data_root', default='./data', type=str, help='dataset path') parser.add_argument('--data_root', default='./data', type=str, help='dataset path')
# resume # resume
parser.add_argument('--ckpt_path', default=None, type=str, help='checkpoint path to fine tune') parser.add_argument('--ckpt_path', default=None, type=str, help='checkpoint path to fine tune')
parser.add_argument('--mask_path', default=None, type=str, help='mask path for speedup')
# run eval # run eval
parser.add_argument('--eval', action='store_true', help='Simply run eval') parser.add_argument('--eval', action='store_true', help='Simply run eval')
parser.add_argument('--calc_flops', action='store_true', help='Calculate flops') parser.add_argument('--calc_flops', action='store_true', help='Calculate flops')
...@@ -56,20 +59,21 @@ def get_model(args): ...@@ -56,20 +59,21 @@ def get_model(args):
raise NotImplementedError raise NotImplementedError
if args.model_type == 'mobilenet': if args.model_type == 'mobilenet':
net = MobileNet(n_class=n_class).cuda() net = MobileNet(n_class=n_class)
elif args.model_type == 'mobilenetv2': elif args.model_type == 'mobilenetv2':
net = MobileNetV2(n_class=n_class).cuda() net = MobileNetV2(n_class=n_class)
else: else:
raise NotImplementedError raise NotImplementedError
if args.ckpt_path is not None: if args.ckpt_path is not None:
# the checkpoint can be a saved whole model object exported by amc_search.py, or a state_dict # the checkpoint can be state_dict exported by amc_search.py or saved by amc_train.py
print('=> Loading checkpoint {} ..'.format(args.ckpt_path)) print('=> Loading checkpoint {} ..'.format(args.ckpt_path))
ckpt = torch.load(args.ckpt_path) net.load_state_dict(torch.load(args.ckpt_path))
if type(ckpt) == dict: if args.mask_path is not None:
net.load_state_dict(ckpt['state_dict']) SZ = 224 if args.dataset == 'imagenet' else 32
else: data = torch.randn(2, 3, SZ, SZ)
net = ckpt ms = ModelSpeedup(net, data, args.mask_path)
ms.speedup_model()
net.to(args.device) net.to(args.device)
if torch.cuda.is_available() and args.n_gpu > 1: if torch.cuda.is_available() and args.n_gpu > 1:
...@@ -204,7 +208,7 @@ if __name__ == '__main__': ...@@ -204,7 +208,7 @@ if __name__ == '__main__':
if args.calc_flops: if args.calc_flops:
IMAGE_SIZE = 224 if args.dataset == 'imagenet' else 32 IMAGE_SIZE = 224 if args.dataset == 'imagenet' else 32
n_flops, n_params = measure_model(net, IMAGE_SIZE, IMAGE_SIZE) n_flops, n_params = measure_model(net, IMAGE_SIZE, IMAGE_SIZE, args.device)
print('=> Model Parameter: {:.3f} M, FLOPs: {:.3f}M'.format(n_params / 1e6, n_flops / 1e6)) print('=> Model Parameter: {:.3f} M, FLOPs: {:.3f}M'.format(n_params / 1e6, n_flops / 1e6))
exit(0) exit(0)
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
# Licensed under the MIT license. # Licensed under the MIT license.
import os import os
import logging
from copy import deepcopy from copy import deepcopy
from argparse import Namespace from argparse import Namespace
import numpy as np import numpy as np
...@@ -15,6 +16,8 @@ from .lib.utils import get_output_folder ...@@ -15,6 +16,8 @@ from .lib.utils import get_output_folder
torch.backends.cudnn.deterministic = True torch.backends.cudnn.deterministic = True
_logger = logging.getLogger(__name__)
class AMCPruner(Pruner): class AMCPruner(Pruner):
""" """
A pytorch implementation of AMC: AutoML for Model Compression and Acceleration on Mobile Devices. A pytorch implementation of AMC: AutoML for Model Compression and Acceleration on Mobile Devices.
...@@ -38,13 +41,6 @@ class AMCPruner(Pruner): ...@@ -38,13 +41,6 @@ class AMCPruner(Pruner):
Data loader of validation dataset. Data loader of validation dataset.
suffix: str suffix: str
suffix to help you remember what experiment you ran. Default: None. suffix to help you remember what experiment you ran. Default: None.
job: str
train_export: search best pruned model and export after search.
export_only: export a searched model, searched_model_path and export_path must be specified.
searched_model_path: str
when job == export_only, use searched_model_path to specify the path of the searched model.
export_path: str
path for exporting models
# parameters for pruning environment # parameters for pruning environment
model_type: str model_type: str
...@@ -118,9 +114,6 @@ class AMCPruner(Pruner): ...@@ -118,9 +114,6 @@ class AMCPruner(Pruner):
evaluator, evaluator,
val_loader, val_loader,
suffix=None, suffix=None,
job='train_export',
export_path=None,
searched_model_path=None,
model_type='mobilenet', model_type='mobilenet',
dataset='cifar10', dataset='cifar10',
flops_ratio=0.5, flops_ratio=0.5,
...@@ -149,9 +142,8 @@ class AMCPruner(Pruner): ...@@ -149,9 +142,8 @@ class AMCPruner(Pruner):
epsilon=50000, epsilon=50000,
seed=None): seed=None):
self.job = job self.val_loader = val_loader
self.searched_model_path = searched_model_path self.evaluator = evaluator
self.export_path = export_path
if seed is not None: if seed is not None:
np.random.seed(seed) np.random.seed(seed)
...@@ -165,11 +157,9 @@ class AMCPruner(Pruner): ...@@ -165,11 +157,9 @@ class AMCPruner(Pruner):
# build folder and logs # build folder and logs
base_folder_name = '{}_{}_r{}_search'.format(model_type, dataset, flops_ratio) base_folder_name = '{}_{}_r{}_search'.format(model_type, dataset, flops_ratio)
if suffix is not None: if suffix is not None:
base_folder_name = base_folder_name + '_' + suffix self.output_dir = os.path.join(output_dir, base_folder_name + '-' + suffix)
self.output_dir = get_output_folder(output_dir, base_folder_name) else:
self.output_dir = get_output_folder(output_dir, base_folder_name)
if self.export_path is None:
self.export_path = os.path.join(self.output_dir, '{}_r{}_exported.pth'.format(model_type, flops_ratio))
self.env_args = Namespace( self.env_args = Namespace(
model_type=model_type, model_type=model_type,
...@@ -182,47 +172,42 @@ class AMCPruner(Pruner): ...@@ -182,47 +172,42 @@ class AMCPruner(Pruner):
channel_round=channel_round, channel_round=channel_round,
output=self.output_dir output=self.output_dir
) )
self.env = ChannelPruningEnv( self.env = ChannelPruningEnv(
self, evaluator, val_loader, checkpoint, args=self.env_args) self, evaluator, val_loader, checkpoint, args=self.env_args)
_logger.info('=> Saving logs to %s', self.output_dir)
if self.job == 'train_export': self.tfwriter = SummaryWriter(log_dir=self.output_dir)
print('=> Saving logs to {}'.format(self.output_dir)) self.text_writer = open(os.path.join(self.output_dir, 'log.txt'), 'w')
self.tfwriter = SummaryWriter(log_dir=self.output_dir) _logger.info('=> Output path: %s...', self.output_dir)
self.text_writer = open(os.path.join(self.output_dir, 'log.txt'), 'w')
print('=> Output path: {}...'.format(self.output_dir)) nb_states = self.env.layer_embedding.shape[1]
nb_actions = 1 # just 1 action here
nb_states = self.env.layer_embedding.shape[1]
nb_actions = 1 # just 1 action here rmsize = rmsize * len(self.env.prunable_idx) # for each layer
_logger.info('** Actual replay buffer size: %d', rmsize)
rmsize = rmsize * len(self.env.prunable_idx) # for each layer
print('** Actual replay buffer size: {}'.format(rmsize)) self.ddpg_args = Namespace(
hidden1=hidden1,
self.ddpg_args = Namespace( hidden2=hidden2,
hidden1=hidden1, lr_c=lr_c,
hidden2=hidden2, lr_a=lr_a,
lr_c=lr_c, warmup=warmup,
lr_a=lr_a, discount=discount,
warmup=warmup, bsize=bsize,
discount=discount, rmsize=rmsize,
bsize=bsize, window_length=window_length,
rmsize=rmsize, tau=tau,
window_length=window_length, init_delta=init_delta,
tau=tau, delta_decay=delta_decay,
init_delta=init_delta, max_episode_length=max_episode_length,
delta_decay=delta_decay, debug=debug,
max_episode_length=max_episode_length, train_episode=train_episode,
debug=debug, epsilon=epsilon
train_episode=train_episode, )
epsilon=epsilon self.agent = DDPG(nb_states, nb_actions, self.ddpg_args)
)
self.agent = DDPG(nb_states, nb_actions, self.ddpg_args)
def compress(self): def compress(self):
if self.job == 'train_export': self.train(self.ddpg_args.train_episode, self.agent, self.env, self.output_dir)
self.train(self.ddpg_args.train_episode, self.agent, self.env, self.output_dir)
self.export_pruned_model()
def train(self, num_episode, agent, env, output_dir): def train(self, num_episode, agent, env, output_dir):
agent.is_training = True agent.is_training = True
...@@ -263,12 +248,11 @@ class AMCPruner(Pruner): ...@@ -263,12 +248,11 @@ class AMCPruner(Pruner):
observation = deepcopy(observation2) observation = deepcopy(observation2)
if done: # end of episode if done: # end of episode
print( _logger.info(
'#{}: episode_reward:{:.4f} acc: {:.4f}, ratio: {:.4f}'.format( '#%d: episode_reward: %.4f acc: %.4f, ratio: %.4f',
episode, episode_reward, episode, episode_reward,
info['accuracy'], info['accuracy'],
info['compress_ratio'] info['compress_ratio']
)
) )
self.text_writer.write( self.text_writer.write(
'#{}: episode_reward:{:.4f} acc: {:.4f}, ratio: {:.4f}\n'.format( '#{}: episode_reward:{:.4f} acc: {:.4f}, ratio: {:.4f}\n'.format(
...@@ -310,19 +294,3 @@ class AMCPruner(Pruner): ...@@ -310,19 +294,3 @@ class AMCPruner(Pruner):
self.text_writer.write('best reward: {}\n'.format(env.best_reward)) self.text_writer.write('best reward: {}\n'.format(env.best_reward))
self.text_writer.write('best policy: {}\n'.format(env.best_strategy)) self.text_writer.write('best policy: {}\n'.format(env.best_strategy))
self.text_writer.close() self.text_writer.close()
def export_pruned_model(self):
if self.searched_model_path is None:
wrapper_model_ckpt = os.path.join(self.output_dir, 'best_wrapped_model.pth')
else:
wrapper_model_ckpt = self.searched_model_path
self.env.reset()
self.bound_model.load_state_dict(torch.load(wrapper_model_ckpt))
print('validate searched model:', self.env._validate(self.env._val_loader, self.env.model))
self.env.export_model()
self._unwrap_model()
print('validate exported model:', self.env._validate(self.env._val_loader, self.env.model))
torch.save(self.bound_model, self.export_path)
print('exported model saved to: {}'.format(self.export_path))
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
# Licensed under the MIT license. # Licensed under the MIT license.
import os import os
import logging
import time import time
import math import math
import copy import copy
...@@ -10,9 +11,10 @@ import torch ...@@ -10,9 +11,10 @@ import torch
import torch.nn as nn import torch.nn as nn
from nni.compression.torch.compressor import PrunerModuleWrapper from nni.compression.torch.compressor import PrunerModuleWrapper
from .lib.utils import prGreen
from .. import AMCWeightMasker from .. import AMCWeightMasker
_logger = logging.getLogger(__name__)
# for pruning # for pruning
def acc_reward(net, acc, flops): def acc_reward(net, acc, flops):
return acc * 0.01 return acc * 0.01
...@@ -139,13 +141,13 @@ class ChannelPruningEnv: ...@@ -139,13 +141,13 @@ class ChannelPruningEnv:
# build reward # build reward
self.reset() # restore weight self.reset() # restore weight
self.org_acc = self._validate(self._val_loader, self.model) self.org_acc = self._validate(self._val_loader, self.model)
print('=> original acc: {:.3f}%'.format(self.org_acc)) _logger.info('=> original acc: %.3f', self.org_acc)
self.org_model_size = sum(self.wsize_list) self.org_model_size = sum(self.wsize_list)
print('=> original weight size: {:.4f} M param'.format(self.org_model_size * 1. / 1e6)) _logger.info('=> original weight size: %.4f M param', self.org_model_size * 1. / 1e6)
self.org_flops = sum(self.flops_list) self.org_flops = sum(self.flops_list)
print('=> FLOPs:') _logger.info('=> FLOPs:')
print([self.layer_info_dict[idx]['flops']/1e6 for idx in sorted(self.layer_info_dict.keys())]) _logger.info([self.layer_info_dict[idx]['flops']/1e6 for idx in sorted(self.layer_info_dict.keys())])
print('=> original FLOPs: {:.4f} M'.format(self.org_flops * 1. / 1e6)) _logger.info('=> original FLOPs: %.4f M', self.org_flops * 1. / 1e6)
self.expected_preserve_computation = self.preserve_ratio * self.org_flops self.expected_preserve_computation = self.preserve_ratio * self.org_flops
...@@ -200,10 +202,12 @@ class ChannelPruningEnv: ...@@ -200,10 +202,12 @@ class ChannelPruningEnv:
self.best_reward = reward self.best_reward = reward
self.best_strategy = self.strategy.copy() self.best_strategy = self.strategy.copy()
self.best_d_prime_list = self.d_prime_list.copy() self.best_d_prime_list = self.d_prime_list.copy()
torch.save(self.model.state_dict(), os.path.join(self.args.output, 'best_wrapped_model.pth')) best_model = os.path.join(self.args.output, 'best_model.pth')
prGreen('New best reward: {:.4f}, acc: {:.4f}, compress: {:.4f}'.format(self.best_reward, acc, compress_ratio)) best_mask = os.path.join(self.args.output, 'best_mask.pth')
prGreen('New best policy: {}'.format(self.best_strategy)) self.pruner.export_model(model_path=best_model, mask_path=best_mask)
prGreen('New best d primes: {}'.format(self.best_d_prime_list)) _logger.info('New best reward: %.4f, acc: %.4f, compress: %.4f', self.best_reward, acc, compress_ratio)
_logger.info('New best policy: %s', self.best_strategy)
_logger.info('New best d primes: %s', self.best_d_prime_list)
obs = self.layer_embedding[self.cur_ind, :].copy() # actually the same as the last state obs = self.layer_embedding[self.cur_ind, :].copy() # actually the same as the last state
done = True done = True
return obs, reward, done, info_set return obs, reward, done, info_set
...@@ -242,9 +246,6 @@ class ChannelPruningEnv: ...@@ -242,9 +246,6 @@ class ChannelPruningEnv:
self.index_buffer = {} self.index_buffer = {}
return obs return obs
def set_export_path(self, path):
self.export_path = path
def prune_kernel(self, op_idx, preserve_ratio, preserve_idx=None): def prune_kernel(self, op_idx, preserve_ratio, preserve_idx=None):
m_list = list(self.model.modules()) m_list = list(self.model.modules())
op = m_list[op_idx] op = m_list[op_idx]
...@@ -273,66 +274,6 @@ class ChannelPruningEnv: ...@@ -273,66 +274,6 @@ class ChannelPruningEnv:
action = (m == 1).sum().item() / m.numel() action = (m == 1).sum().item() / m.numel()
return action, d_prime, preserve_idx return action, d_prime, preserve_idx
def export_model(self):
while True:
self.export_layer(self.prunable_idx[self.cur_ind])
if self._is_final_layer():
break
self.cur_ind += 1
#TODO replace this speedup implementation with nni.compression.torch.ModelSpeedup
def export_layer(self, op_idx):
m_list = list(self.model.modules())
op = m_list[op_idx]
assert type(op) == PrunerModuleWrapper
w = op.module.weight.cpu().data
m = op.weight_mask.cpu().data
if type(op.module) == nn.Linear:
w = w.unsqueeze(-1).unsqueeze(-1)
m = m.unsqueeze(-1).unsqueeze(-1)
d_prime = (m.sum((0, 2, 3)) > 0).sum().item()
preserve_idx = np.nonzero((m.sum((0, 2, 3)) > 0).numpy())[0]
assert d_prime <= w.size(1)
if d_prime == w.size(1):
return
mask = np.zeros(w.size(1), bool)
mask[preserve_idx] = True
rec_weight = torch.zeros((w.size(0), d_prime, w.size(2), w.size(3)))
rec_weight = w[:, preserve_idx, :, :]
if type(op.module) == nn.Linear:
rec_weight = rec_weight.squeeze()
# no need to provide bias mask for channel pruning
rec_mask = torch.ones_like(rec_weight)
# assign new weight and mask
device = op.module.weight.device
op.module.weight.data = rec_weight.to(device)
op.weight_mask = rec_mask.to(device)
if type(op.module) == nn.Conv2d:
op.module.in_channels = d_prime
else:
# Linear
op.module.in_features = d_prime
# export prev layers
prev_idx = self.prunable_idx[self.prunable_idx.index(op_idx) - 1]
for idx in range(prev_idx, op_idx):
m = m_list[idx]
if type(m) == nn.Conv2d: # depthwise
m.weight.data = m.weight.data[mask, :, :, :]
if m.groups == m.in_channels:
m.groups = int(np.sum(mask))
m.out_channels = d_prime
elif type(m) == nn.BatchNorm2d:
m.weight.data = m.weight.data[mask]
m.bias.data = m.bias.data[mask]
m.running_mean.data = m.running_mean.data[mask]
m.running_var.data = m.running_var.data[mask]
m.num_features = d_prime
def _is_final_layer(self): def _is_final_layer(self):
return self.cur_ind == len(self.prunable_idx) - 1 return self.cur_ind == len(self.prunable_idx) - 1
...@@ -456,7 +397,7 @@ class ChannelPruningEnv: ...@@ -456,7 +397,7 @@ class ChannelPruningEnv:
else: # same group else: # same group
share_group.append(c_idx) share_group.append(c_idx)
self.shared_idx.append(share_group) self.shared_idx.append(share_group)
print('=> Conv layers to share channels: {}'.format(self.shared_idx)) _logger.info('=> Conv layers to share channels: %s', self.shared_idx)
self.min_strategy_dict = copy.deepcopy(self.strategy_dict) self.min_strategy_dict = copy.deepcopy(self.strategy_dict)
...@@ -464,10 +405,10 @@ class ChannelPruningEnv: ...@@ -464,10 +405,10 @@ class ChannelPruningEnv:
for _, v in self.buffer_dict.items(): for _, v in self.buffer_dict.items():
self.buffer_idx += v self.buffer_idx += v
print('=> Prunable layer idx: {}'.format(self.prunable_idx)) _logger.info('=> Prunable layer idx: %s', self.prunable_idx)
print('=> Buffer layer idx: {}'.format(self.buffer_idx)) _logger.info('=> Buffer layer idx: %s', self.buffer_idx)
print('=> Shared idx: {}'.format(self.shared_idx)) _logger.info('=> Shared idx: %s', self.shared_idx)
print('=> Initial min strategy dict: {}'.format(self.min_strategy_dict)) _logger.info('=> Initial min strategy dict: %s', self.min_strategy_dict)
# added for supporting residual connections during pruning # added for supporting residual connections during pruning
self.visited = [False] * len(self.prunable_idx) self.visited = [False] * len(self.prunable_idx)
...@@ -504,7 +445,7 @@ class ChannelPruningEnv: ...@@ -504,7 +445,7 @@ class ChannelPruningEnv:
device = m.module.weight.device device = m.module.weight.device
# now let the image flow # now let the image flow
print('=> Extracting information...') _logger.info('=> Extracting information...')
with torch.no_grad(): with torch.no_grad():
for i_b, (inputs, target) in enumerate(self._val_loader): # use image from train set for i_b, (inputs, target) in enumerate(self._val_loader): # use image from train set
if i_b == self.n_calibration_batches: if i_b == self.n_calibration_batches:
...@@ -522,7 +463,7 @@ class ChannelPruningEnv: ...@@ -522,7 +463,7 @@ class ChannelPruningEnv:
self.layer_info_dict[idx]['flops'] = m_list[idx].flops self.layer_info_dict[idx]['flops'] = m_list[idx].flops
self.wsize_list.append(m_list[idx].params) self.wsize_list.append(m_list[idx].params)
self.flops_list.append(m_list[idx].flops) self.flops_list.append(m_list[idx].flops)
print('flops:', self.flops_list) _logger.info('flops: %s', self.flops_list)
for idx in self.prunable_idx: for idx in self.prunable_idx:
f_in_np = m_list[idx].input_feat.data.cpu().numpy() f_in_np = m_list[idx].input_feat.data.cpu().numpy()
f_out_np = m_list[idx].output_feat.data.cpu().numpy() f_out_np = m_list[idx].output_feat.data.cpu().numpy()
...@@ -559,7 +500,7 @@ class ChannelPruningEnv: ...@@ -559,7 +500,7 @@ class ChannelPruningEnv:
def _build_state_embedding(self): def _build_state_embedding(self):
# build the static part of the state embedding # build the static part of the state embedding
print('Building state embedding...') _logger.info('Building state embedding...')
layer_embedding = [] layer_embedding = []
module_list = list(self.model.modules()) module_list = list(self.model.modules())
for i, ind in enumerate(self.prunable_idx): for i, ind in enumerate(self.prunable_idx):
...@@ -590,7 +531,7 @@ class ChannelPruningEnv: ...@@ -590,7 +531,7 @@ class ChannelPruningEnv:
# normalize the state # normalize the state
layer_embedding = np.array(layer_embedding, 'float') layer_embedding = np.array(layer_embedding, 'float')
print('=> shape of embedding (n_layer * n_dim): {}'.format(layer_embedding.shape)) _logger.info('=> shape of embedding (n_layer * n_dim): %s', layer_embedding.shape)
assert len(layer_embedding.shape) == 2, layer_embedding.shape assert len(layer_embedding.shape) == 2, layer_embedding.shape
for i in range(layer_embedding.shape[1]): for i in range(layer_embedding.shape[1]):
fmin = min(layer_embedding[:, i]) fmin = min(layer_embedding[:, i])
......
...@@ -85,11 +85,11 @@ def measure_layer(layer, x): ...@@ -85,11 +85,11 @@ def measure_layer(layer, x):
return return
def measure_model(model, H, W): def measure_model(model, H, W, device):
global count_ops, count_params global count_ops, count_params
count_ops = 0 count_ops = 0
count_params = 0 count_params = 0
data = torch.zeros(2, 3, H, W).cuda() data = torch.zeros(2, 3, H, W).to(device)
def should_measure(x): def should_measure(x):
return is_leaf(x) return is_leaf(x)
......
...@@ -111,14 +111,3 @@ def get_output_folder(parent_dir, env_name): ...@@ -111,14 +111,3 @@ def get_output_folder(parent_dir, env_name):
parent_dir = parent_dir + '-run{}'.format(experiment_id) parent_dir = parent_dir + '-run{}'.format(experiment_id)
os.makedirs(parent_dir, exist_ok=True) os.makedirs(parent_dir, exist_ok=True)
return parent_dir return parent_dir
# logging
def prRed(prt): print("\033[91m {}\033[00m" .format(prt))
def prGreen(prt): print("\033[92m {}\033[00m" .format(prt))
def prYellow(prt): print("\033[93m {}\033[00m" .format(prt))
def prLightPurple(prt): print("\033[94m {}\033[00m" .format(prt))
def prPurple(prt): print("\033[95m {}\033[00m" .format(prt))
def prCyan(prt): print("\033[96m {}\033[00m" .format(prt))
def prLightGray(prt): print("\033[97m {}\033[00m" .format(prt))
def prBlack(prt): print("\033[98m {}\033[00m" .format(prt))
...@@ -812,25 +812,20 @@ class AMCWeightMasker(WeightMasker): ...@@ -812,25 +812,20 @@ class AMCWeightMasker(WeightMasker):
masked_X = X[:, mask] masked_X = X[:, mask]
if w.shape[2] == 1: # 1x1 conv or fc if w.shape[2] == 1: # 1x1 conv or fc
rec_weight = least_square_sklearn(X=masked_X, Y=Y) rec_weight = least_square_sklearn(X=masked_X, Y=Y)
# (C_out, K_h, K_w, C_in') rec_weight = rec_weight.reshape(-1, 1, 1, d_prime) # (C_out, K_h, K_w, C_in')
rec_weight = rec_weight.reshape(-1, 1, 1, d_prime) rec_weight = np.transpose(rec_weight, (0, 3, 1, 2)) # (C_out, C_in', K_h, K_w)
# (C_out, C_in', K_h, K_w)
rec_weight = np.transpose(rec_weight, (0, 3, 1, 2))
else:
raise NotImplementedError(
'Current code only supports 1x1 conv now!')
rec_weight_pad = np.zeros_like(w)
# pylint: disable=all
rec_weight_pad[:, mask, :, :] = rec_weight
rec_weight = rec_weight_pad
if wrapper.type == 'Linear': rec_weight_pad = np.zeros_like(w)
rec_weight = rec_weight.squeeze() # pylint: disable=all
assert len(rec_weight.shape) == 2 rec_weight_pad[:, mask, :, :] = rec_weight
rec_weight = rec_weight_pad
if wrapper.type == 'Linear':
rec_weight = rec_weight.squeeze()
assert len(rec_weight.shape) == 2
# now assign # now assign
wrapper.module.weight.data = torch.from_numpy( wrapper.module.weight.data = torch.from_numpy(rec_weight).to(weight.device)
rec_weight).to(weight.device)
mask_weight = torch.zeros_like(weight) mask_weight = torch.zeros_like(weight)
if wrapper.type == 'Linear': if wrapper.type == 'Linear':
......
...@@ -285,5 +285,11 @@ class PrunerTestCase(TestCase): ...@@ -285,5 +285,11 @@ class PrunerTestCase(TestCase):
pruner = AMCPruner(model, config_list, validate, val_loader, train_episode=1) pruner = AMCPruner(model, config_list, validate, val_loader, train_episode=1)
pruner.compress() pruner.compress()
pruner.export_model('./model_tmp.pth', './mask_tmp.pth', './onnx_tmp.pth', input_shape=(2,3,32,32))
filePaths = ['./model_tmp.pth', './mask_tmp.pth', './onnx_tmp.pth']
for f in filePaths:
if os.path.exists(f):
os.remove(f)
if __name__ == '__main__': if __name__ == '__main__':
main() main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment