"vscode:/vscode.git/clone" did not exist on "8c8d8ca23b08f146dae402dee7b4891e50aca16d"
Unverified commit 6126960c, authored by chicm-ms and committed by GitHub

AMC supports resnet (#2876)

parent 392e55f3
......@@ -529,6 +529,16 @@ You can view [example](https://github.com/microsoft/nni/blob/master/examples/mod
.. autoclass:: nni.compression.torch.AMCPruner
```
### Reproduced Experiment
We implemented one of the experiments from [AMC: AutoML for Model Compression and Acceleration on Mobile Devices](https://arxiv.org/pdf/1802.03494.pdf): pruning **MobileNet** to 50% of its FLOPs on ImageNet. Our experimental results are as follows:

| Model     | Top 1 acc. (paper/ours) | Top 5 acc. (paper/ours) | FLOPs |
| --------- | ----------------------- | ----------------------- | ----- |
| MobileNet | 70.5% / 69.9%           | 89.3% / 89.1%           | 50%   |

The experiment code can be found at [examples/model_compress](https://github.com/microsoft/nni/tree/master/examples/model_compress/amc/).
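
For orientation, here is a minimal sketch of driving the pruner directly. The `config_list` contents and the `validate` body are placeholders rather than the exact settings of the reproduced experiment; the evaluator signature `validate(val_loader, model)` matches what the example script passes in.

```python
from nni.compression.torch import AMCPruner

# `model` and `val_loader` come from your own training setup

def validate(val_loader, model):
    # placeholder: evaluate `model` on `val_loader` and return top-1 accuracy
    ...

# placeholder config; see the example script above for the exact settings
config_list = [{'op_types': ['Conv2d']}]
pruner = AMCPruner(model, config_list, validate, val_loader,
                   model_type='mobilenet', dataset='cifar10', flops_ratio=0.5)
pruner.compress()
```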
## ADMM Pruner
Alternating Direction Method of Multipliers (ADMM) is a mathematical optimization technique that decomposes the original nonconvex problem into two subproblems which can be solved iteratively. In the weight pruning problem, these two subproblems are solved via 1) a gradient descent algorithm and 2) Euclidean projection, respectively.
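
As a sketch of the standard formulation (notation is ours, following the ADMM pruning literature rather than NNI's code): with loss $f(W)$ and sparsity constraint set $S$, the constraint is absorbed into an indicator function $g$, giving

$$\min_{W,Z} \; f(W) + g(Z) \quad \text{s.t.} \quad W = Z,$$

and ADMM alternates

$$W^{k+1} = \arg\min_W \; f(W) + \tfrac{\rho}{2}\lVert W - Z^k + U^k\rVert_F^2 \quad \text{(gradient descent)},$$

$$Z^{k+1} = \Pi_S\bigl(W^{k+1} + U^k\bigr) \quad \text{(Euclidean projection onto } S\text{)},$$

$$U^{k+1} = U^k + W^{k+1} - Z^{k+1}.$$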
......
......@@ -7,7 +7,7 @@ import time
import torch
import torch.nn as nn
from torchvision.models import resnet
from nni.compression.torch import AMCPruner
from data import get_split_dataset
from utils import AverageMeter, accuracy
......@@ -16,7 +16,8 @@ sys.path.append('../models')
def parse_args():
parser = argparse.ArgumentParser(description='AMC search script')
parser.add_argument('--model_type', default='mobilenet', type=str, choices=['mobilenet', 'mobilenetv2'], help='model to prune')
parser.add_argument('--model_type', default='mobilenet', type=str, choices=['mobilenet', 'mobilenetv2', 'resnet18', 'resnet34', 'resnet50'],
help='model to prune')
parser.add_argument('--dataset', default='cifar10', type=str, choices=['cifar10', 'imagenet'], help='dataset to use (cifar/imagenet)')
parser.add_argument('--batch_size', default=50, type=int, help='number of data batch size')
parser.add_argument('--data_root', default='./cifar10', type=str, help='dataset path')
......@@ -28,27 +29,29 @@ def parse_args():
parser.add_argument('--train_episode', default=800, type=int, help='number of training episode')
parser.add_argument('--n_gpu', default=1, type=int, help='number of gpu to use')
parser.add_argument('--n_worker', default=16, type=int, help='number of data loader worker')
parser.add_argument('--job', default='train_export', type=str, choices=['train_export', 'export_only'],
help='search best pruning policy and export or just export model with searched policy')
parser.add_argument('--export_path', default=None, type=str, help='path for exporting models')
parser.add_argument('--searched_model_path', default=None, type=str, help='path for searched best wrapped model')
parser.add_argument('--suffix', default=None, type=str, help='suffix of auto-generated log directory')
return parser.parse_args()
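# Example invocation exercising the new resnet support (paths illustrative):
#   python amc_search.py --model_type resnet50 --dataset cifar10 --data_root ./cifar10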
def get_model_and_checkpoint(model, dataset, checkpoint_path, n_gpu=1):
if model == 'mobilenet' and dataset == 'imagenet':
from mobilenet import MobileNet
net = MobileNet(n_class=1000)
elif model == 'mobilenetv2' and dataset == 'imagenet':
from mobilenet_v2 import MobileNetV2
net = MobileNetV2(n_class=1000)
elif model == 'mobilenet' and dataset == 'cifar10':
if dataset == 'imagenet':
n_class = 1000
elif dataset == 'cifar10':
n_class = 10
else:
raise ValueError('unsupported dataset')
if model == 'mobilenet':
from mobilenet import MobileNet
net = MobileNet(n_class=10)
elif model == 'mobilenetv2' and dataset == 'cifar10':
net = MobileNet(n_class=n_class)
elif model == 'mobilenetv2':
from mobilenet_v2 import MobileNetV2
net = MobileNetV2(n_class=10)
net = MobileNetV2(n_class=n_class)
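# torchvision's pretrained resnets ship with a 1000-class ImageNet head,
# so the final fc layer is replaced to match n_class for the chosen dataset.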
elif model.startswith('resnet'):
net = resnet.__dict__[model](pretrained=True)
in_features = net.fc.in_features
net.fc = nn.Linear(in_features, n_class)
else:
raise NotImplementedError
if checkpoint_path:
......@@ -130,7 +133,6 @@ if __name__ == "__main__":
}]
pruner = AMCPruner(
model, config_list, validate, val_loader, model_type=args.model_type, dataset=args.dataset,
train_episode=args.train_episode, job=args.job, export_path=args.export_path,
searched_model_path=args.searched_model_path,
flops_ratio=args.flops_ratio, lbound=args.lbound, rbound=args.rbound)
train_episode=args.train_episode, flops_ratio=args.flops_ratio, lbound=args.lbound,
rbound=args.rbound, suffix=args.suffix)
pruner.compress()
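# The search exports its best result as it goes: each time the environment
# finds a new best reward, it writes best_model.pth and best_mask.pth to the
# output directory.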
......@@ -16,6 +16,7 @@ from tensorboardX import SummaryWriter
from nni.compression.torch.pruning.amc.lib.net_measure import measure_model
from nni.compression.torch.pruning.amc.lib.utils import get_output_folder
from nni.compression.torch import ModelSpeedup
from data import get_dataset
from utils import AverageMeter, accuracy, progress_bar
......@@ -28,17 +29,19 @@ def parse_args():
parser = argparse.ArgumentParser(description='AMC train / fine-tune script')
parser.add_argument('--model_type', default='mobilenet', type=str, help='name of the model to train')
parser.add_argument('--dataset', default='cifar10', type=str, help='name of the dataset to train')
parser.add_argument('--lr', default=0.1, type=float, help='learning rate')
parser.add_argument('--n_gpu', default=1, type=int, help='number of GPUs to use')
parser.add_argument('--batch_size', default=128, type=int, help='batch size')
parser.add_argument('--n_worker', default=4, type=int, help='number of data loader worker')
parser.add_argument('--lr_type', default='exp', type=str, help='lr scheduler (exp/cos/step3/fixed)')
parser.add_argument('--n_epoch', default=50, type=int, help='number of epochs to train')
parser.add_argument('--lr', default=0.05, type=float, help='learning rate')
parser.add_argument('--n_gpu', default=4, type=int, help='number of GPUs to use')
parser.add_argument('--batch_size', default=256, type=int, help='batch size')
parser.add_argument('--n_worker', default=32, type=int, help='number of data loader worker')
parser.add_argument('--lr_type', default='cos', type=str, help='lr scheduler (exp/cos/step3/fixed)')
parser.add_argument('--n_epoch', default=150, type=int, help='number of epochs to train')
parser.add_argument('--wd', default=4e-5, type=float, help='weight decay')
parser.add_argument('--seed', default=None, type=int, help='random seed to set')
parser.add_argument('--data_root', default='./data', type=str, help='dataset path')
# resume
parser.add_argument('--ckpt_path', default=None, type=str, help='checkpoint path to fine tune')
parser.add_argument('--mask_path', default=None, type=str, help='mask path for speedup')
# run eval
parser.add_argument('--eval', action='store_true', help='Simply run eval')
parser.add_argument('--calc_flops', action='store_true', help='Calculate flops')
......@@ -56,20 +59,21 @@ def get_model(args):
raise NotImplementedError
if args.model_type == 'mobilenet':
net = MobileNet(n_class=n_class).cuda()
net = MobileNet(n_class=n_class)
elif args.model_type == 'mobilenetv2':
net = MobileNetV2(n_class=n_class).cuda()
net = MobileNetV2(n_class=n_class)
else:
raise NotImplementedError
if args.ckpt_path is not None:
# the checkpoint can be a saved whole model object exported by amc_search.py, or a state_dict
# the checkpoint can be state_dict exported by amc_search.py or saved by amc_train.py
print('=> Loading checkpoint {} ..'.format(args.ckpt_path))
ckpt = torch.load(args.ckpt_path)
if type(ckpt) == dict:
net.load_state_dict(ckpt['state_dict'])
else:
net = ckpt
net.load_state_dict(torch.load(args.ckpt_path))
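# If a mask exported by amc_search.py is provided, ModelSpeedup physically
# removes the masked channels; the random tensor below is a dummy input used
# only for shape inference.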
if args.mask_path is not None:
SZ = 224 if args.dataset == 'imagenet' else 32
data = torch.randn(2, 3, SZ, SZ)
ms = ModelSpeedup(net, data, args.mask_path)
ms.speedup_model()
net.to(args.device)
if torch.cuda.is_available() and args.n_gpu > 1:
......@@ -204,7 +208,7 @@ if __name__ == '__main__':
if args.calc_flops:
IMAGE_SIZE = 224 if args.dataset == 'imagenet' else 32
n_flops, n_params = measure_model(net, IMAGE_SIZE, IMAGE_SIZE)
n_flops, n_params = measure_model(net, IMAGE_SIZE, IMAGE_SIZE, args.device)
print('=> Model Parameter: {:.3f} M, FLOPs: {:.3f}M'.format(n_params / 1e6, n_flops / 1e6))
exit(0)
......
......@@ -2,6 +2,7 @@
# Licensed under the MIT license.
import os
import logging
from copy import deepcopy
from argparse import Namespace
import numpy as np
......@@ -15,6 +16,8 @@ from .lib.utils import get_output_folder
torch.backends.cudnn.deterministic = True
_logger = logging.getLogger(__name__)
class AMCPruner(Pruner):
"""
A pytorch implementation of AMC: AutoML for Model Compression and Acceleration on Mobile Devices.
......@@ -38,13 +41,6 @@ class AMCPruner(Pruner):
Data loader of validation dataset.
suffix: str
suffix to help you remember what experiment you ran. Default: None.
job: str
train_export: search best pruned model and export after search.
export_only: export a searched model, searched_model_path and export_path must be specified.
searched_model_path: str
when job == export_only, use searched_model_path to specify the path of the searched model.
export_path: str
path for exporting models
# parameters for pruning environment
model_type: str
......@@ -118,9 +114,6 @@ class AMCPruner(Pruner):
evaluator,
val_loader,
suffix=None,
job='train_export',
export_path=None,
searched_model_path=None,
model_type='mobilenet',
dataset='cifar10',
flops_ratio=0.5,
......@@ -149,9 +142,8 @@ class AMCPruner(Pruner):
epsilon=50000,
seed=None):
self.job = job
self.searched_model_path = searched_model_path
self.export_path = export_path
self.val_loader = val_loader
self.evaluator = evaluator
if seed is not None:
np.random.seed(seed)
......@@ -165,11 +157,9 @@ class AMCPruner(Pruner):
# build folder and logs
base_folder_name = '{}_{}_r{}_search'.format(model_type, dataset, flops_ratio)
if suffix is not None:
base_folder_name = base_folder_name + '_' + suffix
self.output_dir = get_output_folder(output_dir, base_folder_name)
if self.export_path is None:
self.export_path = os.path.join(self.output_dir, '{}_r{}_exported.pth'.format(model_type, flops_ratio))
self.output_dir = os.path.join(output_dir, base_folder_name + '-' + suffix)
else:
self.output_dir = get_output_folder(output_dir, base_folder_name)
self.env_args = Namespace(
model_type=model_type,
......@@ -182,47 +172,42 @@ class AMCPruner(Pruner):
channel_round=channel_round,
output=self.output_dir
)
self.env = ChannelPruningEnv(
self, evaluator, val_loader, checkpoint, args=self.env_args)
if self.job == 'train_export':
print('=> Saving logs to {}'.format(self.output_dir))
self.tfwriter = SummaryWriter(log_dir=self.output_dir)
self.text_writer = open(os.path.join(self.output_dir, 'log.txt'), 'w')
print('=> Output path: {}...'.format(self.output_dir))
print('** Actual replay buffer size: {}'.format(rmsize))
_logger.info('=> Saving logs to %s', self.output_dir)
self.tfwriter = SummaryWriter(log_dir=self.output_dir)
self.text_writer = open(os.path.join(self.output_dir, 'log.txt'), 'w')
_logger.info('=> Output path: %s...', self.output_dir)
nb_states = self.env.layer_embedding.shape[1]
nb_actions = 1 # just 1 action here
rmsize = rmsize * len(self.env.prunable_idx) # for each layer
_logger.info('** Actual replay buffer size: %d', rmsize)
self.ddpg_args = Namespace(
hidden1=hidden1,
hidden2=hidden2,
lr_c=lr_c,
lr_a=lr_a,
warmup=warmup,
discount=discount,
bsize=bsize,
rmsize=rmsize,
window_length=window_length,
tau=tau,
init_delta=init_delta,
delta_decay=delta_decay,
max_episode_length=max_episode_length,
debug=debug,
train_episode=train_episode,
epsilon=epsilon
)
self.agent = DDPG(nb_states, nb_actions, self.ddpg_args)
def compress(self):
if self.job == 'train_export':
self.train(self.ddpg_args.train_episode, self.agent, self.env, self.output_dir)
self.export_pruned_model()
self.train(self.ddpg_args.train_episode, self.agent, self.env, self.output_dir)
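# Export of the best model/mask happens inside ChannelPruningEnv whenever a
# new best reward is found, so compress() only needs to run the search.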
def train(self, num_episode, agent, env, output_dir):
agent.is_training = True
......@@ -263,12 +248,11 @@ class AMCPruner(Pruner):
observation = deepcopy(observation2)
if done: # end of episode
print(
'#{}: episode_reward:{:.4f} acc: {:.4f}, ratio: {:.4f}'.format(
_logger.info(
'#%d: episode_reward: %.4f acc: %.4f, ratio: %.4f',
episode, episode_reward,
info['accuracy'],
info['compress_ratio']
)
)
self.text_writer.write(
'#{}: episode_reward:{:.4f} acc: {:.4f}, ratio: {:.4f}\n'.format(
......@@ -310,19 +294,3 @@ class AMCPruner(Pruner):
self.text_writer.write('best reward: {}\n'.format(env.best_reward))
self.text_writer.write('best policy: {}\n'.format(env.best_strategy))
self.text_writer.close()
def export_pruned_model(self):
if self.searched_model_path is None:
wrapper_model_ckpt = os.path.join(self.output_dir, 'best_wrapped_model.pth')
else:
wrapper_model_ckpt = self.searched_model_path
self.env.reset()
self.bound_model.load_state_dict(torch.load(wrapper_model_ckpt))
print('validate searched model:', self.env._validate(self.env._val_loader, self.env.model))
self.env.export_model()
self._unwrap_model()
print('validate exported model:', self.env._validate(self.env._val_loader, self.env.model))
torch.save(self.bound_model, self.export_path)
print('exported model saved to: {}'.format(self.export_path))
......@@ -2,6 +2,7 @@
# Licensed under the MIT license.
import os
import logging
import time
import math
import copy
......@@ -10,9 +11,10 @@ import torch
import torch.nn as nn
from nni.compression.torch.compressor import PrunerModuleWrapper
from .lib.utils import prGreen
from .. import AMCWeightMasker
_logger = logging.getLogger(__name__)
# for pruning
def acc_reward(net, acc, flops):
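# Reward for resource-constrained search: accuracy only, scaled from percent
# to [0, 1]. The flops argument is intentionally unused here; the FLOPs
# budget is enforced by the environment when it bounds each layer's action.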
return acc * 0.01
......@@ -139,13 +141,13 @@ class ChannelPruningEnv:
# build reward
self.reset() # restore weight
self.org_acc = self._validate(self._val_loader, self.model)
print('=> original acc: {:.3f}%'.format(self.org_acc))
_logger.info('=> original acc: %.3f', self.org_acc)
self.org_model_size = sum(self.wsize_list)
print('=> original weight size: {:.4f} M param'.format(self.org_model_size * 1. / 1e6))
_logger.info('=> original weight size: %.4f M param', self.org_model_size * 1. / 1e6)
self.org_flops = sum(self.flops_list)
print('=> FLOPs:')
print([self.layer_info_dict[idx]['flops']/1e6 for idx in sorted(self.layer_info_dict.keys())])
print('=> original FLOPs: {:.4f} M'.format(self.org_flops * 1. / 1e6))
_logger.info('=> FLOPs:')
_logger.info([self.layer_info_dict[idx]['flops']/1e6 for idx in sorted(self.layer_info_dict.keys())])
_logger.info('=> original FLOPs: %.4f M', self.org_flops * 1. / 1e6)
self.expected_preserve_computation = self.preserve_ratio * self.org_flops
......@@ -200,10 +202,12 @@ class ChannelPruningEnv:
self.best_reward = reward
self.best_strategy = self.strategy.copy()
self.best_d_prime_list = self.d_prime_list.copy()
torch.save(self.model.state_dict(), os.path.join(self.args.output, 'best_wrapped_model.pth'))
prGreen('New best reward: {:.4f}, acc: {:.4f}, compress: {:.4f}'.format(self.best_reward, acc, compress_ratio))
prGreen('New best policy: {}'.format(self.best_strategy))
prGreen('New best d primes: {}'.format(self.best_d_prime_list))
best_model = os.path.join(self.args.output, 'best_model.pth')
best_mask = os.path.join(self.args.output, 'best_mask.pth')
self.pruner.export_model(model_path=best_model, mask_path=best_mask)
_logger.info('New best reward: %.4f, acc: %.4f, compress: %.4f', self.best_reward, acc, compress_ratio)
_logger.info('New best policy: %s', self.best_strategy)
_logger.info('New best d primes: %s', self.best_d_prime_list)
obs = self.layer_embedding[self.cur_ind, :].copy() # actually the same as the last state
done = True
return obs, reward, done, info_set
......@@ -242,9 +246,6 @@ class ChannelPruningEnv:
self.index_buffer = {}
return obs
def set_export_path(self, path):
self.export_path = path
def prune_kernel(self, op_idx, preserve_ratio, preserve_idx=None):
m_list = list(self.model.modules())
op = m_list[op_idx]
......@@ -273,66 +274,6 @@ class ChannelPruningEnv:
action = (m == 1).sum().item() / m.numel()
return action, d_prime, preserve_idx
def export_model(self):
while True:
self.export_layer(self.prunable_idx[self.cur_ind])
if self._is_final_layer():
break
self.cur_ind += 1
#TODO replace this speedup implementation with nni.compression.torch.ModelSpeedup
def export_layer(self, op_idx):
m_list = list(self.model.modules())
op = m_list[op_idx]
assert type(op) == PrunerModuleWrapper
w = op.module.weight.cpu().data
m = op.weight_mask.cpu().data
if type(op.module) == nn.Linear:
w = w.unsqueeze(-1).unsqueeze(-1)
m = m.unsqueeze(-1).unsqueeze(-1)
d_prime = (m.sum((0, 2, 3)) > 0).sum().item()
preserve_idx = np.nonzero((m.sum((0, 2, 3)) > 0).numpy())[0]
assert d_prime <= w.size(1)
if d_prime == w.size(1):
return
mask = np.zeros(w.size(1), bool)
mask[preserve_idx] = True
rec_weight = torch.zeros((w.size(0), d_prime, w.size(2), w.size(3)))
rec_weight = w[:, preserve_idx, :, :]
if type(op.module) == nn.Linear:
rec_weight = rec_weight.squeeze()
# no need to provide bias mask for channel pruning
rec_mask = torch.ones_like(rec_weight)
# assign new weight and mask
device = op.module.weight.device
op.module.weight.data = rec_weight.to(device)
op.weight_mask = rec_mask.to(device)
if type(op.module) == nn.Conv2d:
op.module.in_channels = d_prime
else:
# Linear
op.module.in_features = d_prime
# export prev layers
prev_idx = self.prunable_idx[self.prunable_idx.index(op_idx) - 1]
for idx in range(prev_idx, op_idx):
m = m_list[idx]
if type(m) == nn.Conv2d: # depthwise
m.weight.data = m.weight.data[mask, :, :, :]
if m.groups == m.in_channels:
m.groups = int(np.sum(mask))
m.out_channels = d_prime
elif type(m) == nn.BatchNorm2d:
m.weight.data = m.weight.data[mask]
m.bias.data = m.bias.data[mask]
m.running_mean.data = m.running_mean.data[mask]
m.running_var.data = m.running_var.data[mask]
m.num_features = d_prime
def _is_final_layer(self):
return self.cur_ind == len(self.prunable_idx) - 1
......@@ -456,7 +397,7 @@ class ChannelPruningEnv:
else: # same group
share_group.append(c_idx)
self.shared_idx.append(share_group)
print('=> Conv layers to share channels: {}'.format(self.shared_idx))
_logger.info('=> Conv layers to share channels: %s', self.shared_idx)
self.min_strategy_dict = copy.deepcopy(self.strategy_dict)
......@@ -464,10 +405,10 @@ class ChannelPruningEnv:
for _, v in self.buffer_dict.items():
self.buffer_idx += v
print('=> Prunable layer idx: {}'.format(self.prunable_idx))
print('=> Buffer layer idx: {}'.format(self.buffer_idx))
print('=> Shared idx: {}'.format(self.shared_idx))
print('=> Initial min strategy dict: {}'.format(self.min_strategy_dict))
_logger.info('=> Prunable layer idx: %s', self.prunable_idx)
_logger.info('=> Buffer layer idx: %s', self.buffer_idx)
_logger.info('=> Shared idx: %s', self.shared_idx)
_logger.info('=> Initial min strategy dict: %s', self.min_strategy_dict)
# added for supporting residual connections during pruning
self.visited = [False] * len(self.prunable_idx)
......@@ -504,7 +445,7 @@ class ChannelPruningEnv:
device = m.module.weight.device
# now let the image flow
print('=> Extracting information...')
_logger.info('=> Extracting information...')
with torch.no_grad():
for i_b, (inputs, target) in enumerate(self._val_loader): # use image from train set
if i_b == self.n_calibration_batches:
......@@ -522,7 +463,7 @@ class ChannelPruningEnv:
self.layer_info_dict[idx]['flops'] = m_list[idx].flops
self.wsize_list.append(m_list[idx].params)
self.flops_list.append(m_list[idx].flops)
print('flops:', self.flops_list)
_logger.info('flops: %s', self.flops_list)
for idx in self.prunable_idx:
f_in_np = m_list[idx].input_feat.data.cpu().numpy()
f_out_np = m_list[idx].output_feat.data.cpu().numpy()
......@@ -559,7 +500,7 @@ class ChannelPruningEnv:
def _build_state_embedding(self):
# build the static part of the state embedding
print('Building state embedding...')
_logger.info('Building state embedding...')
layer_embedding = []
module_list = list(self.model.modules())
for i, ind in enumerate(self.prunable_idx):
......@@ -590,7 +531,7 @@ class ChannelPruningEnv:
# normalize the state
layer_embedding = np.array(layer_embedding, 'float')
print('=> shape of embedding (n_layer * n_dim): {}'.format(layer_embedding.shape))
_logger.info('=> shape of embedding (n_layer * n_dim): %s', layer_embedding.shape)
assert len(layer_embedding.shape) == 2, layer_embedding.shape
for i in range(layer_embedding.shape[1]):
fmin = min(layer_embedding[:, i])
......
......@@ -85,11 +85,11 @@ def measure_layer(layer, x):
return
def measure_model(model, H, W):
def measure_model(model, H, W, device):
global count_ops, count_params
count_ops = 0
count_params = 0
data = torch.zeros(2, 3, H, W).cuda()
data = torch.zeros(2, 3, H, W).to(device)
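# The dummy input follows the caller-supplied device, so FLOPs can be
# measured on CPU-only machines as well as on GPU.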
def should_measure(x):
return is_leaf(x)
......
......@@ -111,14 +111,3 @@ def get_output_folder(parent_dir, env_name):
parent_dir = parent_dir + '-run{}'.format(experiment_id)
os.makedirs(parent_dir, exist_ok=True)
return parent_dir
# logging
def prRed(prt): print("\033[91m {}\033[00m" .format(prt))
def prGreen(prt): print("\033[92m {}\033[00m" .format(prt))
def prYellow(prt): print("\033[93m {}\033[00m" .format(prt))
def prLightPurple(prt): print("\033[94m {}\033[00m" .format(prt))
def prPurple(prt): print("\033[95m {}\033[00m" .format(prt))
def prCyan(prt): print("\033[96m {}\033[00m" .format(prt))
def prLightGray(prt): print("\033[97m {}\033[00m" .format(prt))
def prBlack(prt): print("\033[98m {}\033[00m" .format(prt))
......@@ -812,25 +812,20 @@ class AMCWeightMasker(WeightMasker):
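# Least-squares channel reconstruction (as in channel-pruning work such as
# He et al., ICCV 2017): refit the weights over the kept input channels so
# the layer's output on calibration data best matches the original.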
masked_X = X[:, mask]
if w.shape[2] == 1: # 1x1 conv or fc
rec_weight = least_square_sklearn(X=masked_X, Y=Y)
rec_weight = rec_weight.reshape(-1, 1, 1, d_prime)  # (C_out, K_h, K_w, C_in')
rec_weight = np.transpose(rec_weight, (0, 3, 1, 2))  # (C_out, C_in', K_h, K_w)
else:
raise NotImplementedError('Current code only supports 1x1 conv now!')
rec_weight_pad = np.zeros_like(w)
# pylint: disable=all
rec_weight_pad[:, mask, :, :] = rec_weight
rec_weight = rec_weight_pad
if wrapper.type == 'Linear':
rec_weight = rec_weight.squeeze()
assert len(rec_weight.shape) == 2
# now assign
wrapper.module.weight.data = torch.from_numpy(rec_weight).to(weight.device)
mask_weight = torch.zeros_like(weight)
if wrapper.type == 'Linear':
......
......@@ -285,5 +285,11 @@ class PrunerTestCase(TestCase):
pruner = AMCPruner(model, config_list, validate, val_loader, train_episode=1)
pruner.compress()
pruner.export_model('./model_tmp.pth', './mask_tmp.pth', './onnx_tmp.pth', input_shape=(2,3,32,32))
filePaths = ['./model_tmp.pth', './mask_tmp.pth', './onnx_tmp.pth']
for f in filePaths:
if os.path.exists(f):
os.remove(f)
if __name__ == '__main__':
main()