Unverified Commit e9f3cddf authored by chicm-ms's avatar chicm-ms Committed by GitHub
Browse files

AutoML for model compression (#2573)

parent 3757cf27
...@@ -28,6 +28,7 @@ jobs: ...@@ -28,6 +28,7 @@ jobs:
set -e set -e
sudo apt-get install -y pandoc sudo apt-get install -y pandoc
python3 -m pip install torch==1.5.0+cpu torchvision==0.6.0+cpu -f https://download.pytorch.org/whl/torch_stable.html --user python3 -m pip install torch==1.5.0+cpu torchvision==0.6.0+cpu -f https://download.pytorch.org/whl/torch_stable.html --user
python3 -m pip install tensorboardX==1.9
python3 -m pip install tensorflow==2.2.0 --user python3 -m pip install tensorflow==2.2.0 --user
python3 -m pip install keras==2.4.2 --user python3 -m pip install keras==2.4.2 --user
python3 -m pip install gym onnx peewee thop --user python3 -m pip install gym onnx peewee thop --user
...@@ -68,6 +69,7 @@ jobs: ...@@ -68,6 +69,7 @@ jobs:
- script: | - script: |
set -e set -e
python3 -m pip install torch==1.3.1+cpu torchvision==0.4.2+cpu -f https://download.pytorch.org/whl/torch_stable.html --user python3 -m pip install torch==1.3.1+cpu torchvision==0.4.2+cpu -f https://download.pytorch.org/whl/torch_stable.html --user
python3 -m pip install tensorboardX==1.9
python3 -m pip install tensorflow==1.15.2 --user python3 -m pip install tensorflow==1.15.2 --user
python3 -m pip install keras==2.1.6 --user python3 -m pip install keras==2.1.6 --user
python3 -m pip install gym onnx peewee --user python3 -m pip install gym onnx peewee --user
...@@ -117,6 +119,7 @@ jobs: ...@@ -117,6 +119,7 @@ jobs:
set -e set -e
# pytorch Mac binary does not support CUDA, default is cpu version # pytorch Mac binary does not support CUDA, default is cpu version
python3 -m pip install torchvision==0.6.0 torch==1.5.0 --user python3 -m pip install torchvision==0.6.0 torch==1.5.0 --user
python3 -m pip install tensorboardX==1.9
python3 -m pip install tensorflow==1.15.2 --user python3 -m pip install tensorflow==1.15.2 --user
brew install swig@3 brew install swig@3
rm -f /usr/local/bin/swig rm -f /usr/local/bin/swig
...@@ -144,6 +147,7 @@ jobs: ...@@ -144,6 +147,7 @@ jobs:
python -m pip install scikit-learn==0.23.2 --user python -m pip install scikit-learn==0.23.2 --user
python -m pip install keras==2.1.6 --user python -m pip install keras==2.1.6 --user
python -m pip install torch==1.5.0+cpu torchvision==0.6.0+cpu -f https://download.pytorch.org/whl/torch_stable.html --user python -m pip install torch==1.5.0+cpu torchvision==0.6.0+cpu -f https://download.pytorch.org/whl/torch_stable.html --user
python -m pip install tensorboardX==1.9
python -m pip install tensorflow==1.15.2 --user python -m pip install tensorflow==1.15.2 --user
displayName: 'Install dependencies' displayName: 'Install dependencies'
- script: | - script: |
......
...@@ -20,6 +20,7 @@ We provide several pruning algorithms that support fine-grained weight pruning a ...@@ -20,6 +20,7 @@ We provide several pruning algorithms that support fine-grained weight pruning a
* [NetAdapt Pruner](#netadapt-pruner) * [NetAdapt Pruner](#netadapt-pruner)
* [SimulatedAnnealing Pruner](#simulatedannealing-pruner) * [SimulatedAnnealing Pruner](#simulatedannealing-pruner)
* [AutoCompress Pruner](#autocompress-pruner) * [AutoCompress Pruner](#autocompress-pruner)
* [AutoML for Model Compression Pruner](#automl-for-model-compression-pruner)
* [Sensitivity Pruner](#sensitivity-pruner) * [Sensitivity Pruner](#sensitivity-pruner)
**Others** **Others**
...@@ -476,6 +477,39 @@ You can view [example](https://github.com/microsoft/nni/blob/master/examples/mod ...@@ -476,6 +477,39 @@ You can view [example](https://github.com/microsoft/nni/blob/master/examples/mod
.. autoclass:: nni.compression.torch.AutoCompressPruner .. autoclass:: nni.compression.torch.AutoCompressPruner
``` ```
## AutoML for Model Compression Pruner
AutoML for Model Compression Pruner (AMCPruner) leverages reinforcement learning to provide the model compression policy.
This learning-based compression policy outperforms conventional rule-based compression policy by having higher compression ratio,
better preserving the accuracy and freeing human labor.
![](../../img/amc_pruner.jpg)
For more details, please refer to [AMC: AutoML for Model Compression and Acceleration on Mobile Devices](https://arxiv.org/pdf/1802.03494.pdf).
#### Usage
PyTorch code
```python
from nni.compression.torch import AMCPruner
config_list = [{
'op_types': ['Conv2d', 'Linear']
}]
pruner = AMCPruner(model, config_list, evaluator, val_loader, flops_ratio=0.5)
pruner.compress()
```
You can view [example](https://github.com/microsoft/nni/blob/master/examples/model_compress/amc/) for more information.
#### User configuration for AMC Pruner
##### PyTorch
```eval_rst
.. autoclass:: nni.compression.torch.AMCPruner
```
## ADMM Pruner ## ADMM Pruner
Alternating Direction Method of Multipliers (ADMM) is a mathematical optimization technique, Alternating Direction Method of Multipliers (ADMM) is a mathematical optimization technique,
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import sys
import argparse
import time
import torch
import torch.nn as nn
from nni.compression.torch import AMCPruner
from data import get_split_dataset
from utils import AverageMeter, accuracy
sys.path.append('../models')
def parse_args():
    """Define and parse the command-line options for the AMC search script."""
    parser = argparse.ArgumentParser(description='AMC search script')
    arg = parser.add_argument
    # model / data selection
    arg('--model_type', default='mobilenet', type=str, choices=['mobilenet', 'mobilenetv2'], help='model to prune')
    arg('--dataset', default='cifar10', type=str, choices=['cifar10', 'imagenet'], help='dataset to use (cifar/imagenet)')
    arg('--batch_size', default=50, type=int, help='number of data batch size')
    arg('--data_root', default='./cifar10', type=str, help='dataset path')
    # pruning targets
    arg('--flops_ratio', default=0.5, type=float, help='target flops ratio to preserve of the model')
    arg('--lbound', default=0.2, type=float, help='minimum sparsity')
    arg('--rbound', default=1., type=float, help='maximum sparsity')
    # run configuration
    arg('--ckpt_path', default=None, type=str, help='manual path of checkpoint')
    arg('--train_episode', default=800, type=int, help='number of training episode')
    arg('--n_gpu', default=1, type=int, help='number of gpu to use')
    arg('--n_worker', default=16, type=int, help='number of data loader worker')
    arg('--job', default='train_export', type=str, choices=['train_export', 'export_only'],
        help='search best pruning policy and export or just export model with searched policy')
    arg('--export_path', default=None, type=str, help='path for exporting models')
    arg('--searched_model_path', default=None, type=str, help='path for searched best wrapped model')
    return parser.parse_args()
def get_model_and_checkpoint(model, dataset, checkpoint_path, n_gpu=1):
    """Construct the requested network and optionally restore its weights.

    Parameters
    ----------
    model : str
        'mobilenet' or 'mobilenetv2'.
    dataset : str
        'imagenet' or 'cifar10'; selects the classifier output size (1000 vs 10).
    checkpoint_path : str or None
        Path to a saved state_dict, or a checkpoint dict holding one under 'state_dict'.
    n_gpu : int
        Model is moved to CUDA when > 0 (and CUDA is available) and wrapped in
        DataParallel when > 1.
    """
    if model == 'mobilenet' and dataset == 'imagenet':
        from mobilenet import MobileNet
        net = MobileNet(n_class=1000)
    elif model == 'mobilenetv2' and dataset == 'imagenet':
        from mobilenet_v2 import MobileNetV2
        net = MobileNetV2(n_class=1000)
    elif model == 'mobilenet' and dataset == 'cifar10':
        from mobilenet import MobileNet
        net = MobileNet(n_class=10)
    elif model == 'mobilenetv2' and dataset == 'cifar10':
        from mobilenet_v2 import MobileNetV2
        net = MobileNetV2(n_class=10)
    else:
        raise NotImplementedError
    if checkpoint_path:
        print('loading {}...'.format(checkpoint_path))
        # load on CPU first; the model is moved to GPU below when requested
        sd = torch.load(checkpoint_path, map_location=torch.device('cpu'))
        if 'state_dict' in sd:  # a checkpoint but not a state_dict
            sd = sd['state_dict']
        # strip the 'module.' prefix left by DataParallel-saved checkpoints
        sd = {k.replace('module.', ''): v for k, v in sd.items()}
        net.load_state_dict(sd)
    if torch.cuda.is_available() and n_gpu > 0:
        net = net.cuda()
        if n_gpu > 1:
            net = torch.nn.DataParallel(net, range(n_gpu))
    return net
def init_data(args):
    """Build the train/val loaders used by the AMC search.

    The train set is split into train + val subsets: 5k held-out samples
    for CIFAR, 3k for ImageNet. Returns (train_loader, val_loader).
    """
    holdout = 5000 if 'cifar' in args.dataset else 3000
    # shuffle=False -> the split uses the same sampling order on every run
    loaders = get_split_dataset(
        args.dataset, args.batch_size,
        args.n_worker, holdout,
        data_root=args.data_root,
        shuffle=False
    )
    train_loader, val_loader, _ = loaders
    return train_loader, val_loader
def validate(val_loader, model, verbose=False):
    """Evaluate `model` on `val_loader` and return the top-5 accuracy.

    This function is handed to AMCPruner as the evaluator callback, so the
    (val_loader, model, verbose) signature must stay unchanged.

    Parameters
    ----------
    val_loader : iterable of (inputs, targets) batches
    model : torch.nn.Module
    verbose : bool
        Print loss/top1/top5 and wall time when true.
    """
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    # Run on whatever device the model already lives on, instead of reading
    # the module-level `device` global (which only exists when this file is
    # executed as a script) and instead of the unconditional
    # CrossEntropyLoss().cuda(), which crashed on CPU-only machines.
    device = next(model.parameters()).device
    criterion = nn.CrossEntropyLoss().to(device)
    # switch to evaluate mode
    model.eval()
    end = time.time()
    t1 = time.time()
    with torch.no_grad():
        for i, (inputs, target) in enumerate(val_loader):
            inputs = inputs.to(device)
            target = target.to(device)
            # compute output (torch.autograd.Variable is deprecated; tensors work directly)
            output = model(inputs)
            loss = criterion(output, target)
            # measure accuracy and record loss
            prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
            losses.update(loss.item(), inputs.size(0))
            top1.update(prec1.item(), inputs.size(0))
            top5.update(prec5.item(), inputs.size(0))
            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()
    t2 = time.time()
    if verbose:
        print('* Test loss: %.3f top1: %.3f top5: %.3f time: %.3f' %
              (losses.avg, top1.avg, top5.avg, t2 - t1))
    return top5.avg
if __name__ == "__main__":
    args = parse_args()
    # Use CUDA only when it is available and at least one GPU was requested.
    device = torch.device('cuda') if torch.cuda.is_available() and args.n_gpu > 0 else torch.device('cpu')

    model = get_model_and_checkpoint(args.model_type, args.dataset, checkpoint_path=args.ckpt_path, n_gpu=args.n_gpu)
    # Only the validation split is needed here; AMCPruner drives evaluation.
    _, val_loader = init_data(args)

    # Prune every convolutional and fully-connected layer.
    config_list = [{
        'op_types': ['Conv2d', 'Linear']
    }]
    pruner = AMCPruner(
        model, config_list, validate, val_loader, model_type=args.model_type, dataset=args.dataset,
        train_episode=args.train_episode, job=args.job, export_path=args.export_path,
        searched_model_path=args.searched_model_path,
        flops_ratio=args.flops_ratio, lbound=args.lbound, rbound=args.rbound)
    pruner.compress()
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import sys
import os
import time
import argparse
import shutil
import math
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from tensorboardX import SummaryWriter
from nni.compression.torch.pruning.amc.lib.net_measure import measure_model
from nni.compression.torch.pruning.amc.lib.utils import get_output_folder
from data import get_dataset
from utils import AverageMeter, accuracy, progress_bar
sys.path.append('../models')
from mobilenet import MobileNet
from mobilenet_v2 import MobileNetV2
def parse_args():
    """Define and parse the command-line options for AMC training / fine-tuning."""
    parser = argparse.ArgumentParser(description='AMC train / fine-tune script')
    add = parser.add_argument
    add('--model_type', default='mobilenet', type=str, help='name of the model to train')
    add('--dataset', default='cifar10', type=str, help='name of the dataset to train')
    add('--lr', default=0.1, type=float, help='learning rate')
    add('--n_gpu', default=1, type=int, help='number of GPUs to use')
    add('--batch_size', default=128, type=int, help='batch size')
    add('--n_worker', default=4, type=int, help='number of data loader worker')
    add('--lr_type', default='exp', type=str, help='lr scheduler (exp/cos/step3/fixed)')
    add('--n_epoch', default=50, type=int, help='number of epochs to train')
    add('--wd', default=4e-5, type=float, help='weight decay')
    add('--seed', default=None, type=int, help='random seed to set')
    add('--data_root', default='./data', type=str, help='dataset path')
    # resume
    add('--ckpt_path', default=None, type=str, help='checkpoint path to fine tune')
    # run eval
    add('--eval', action='store_true', help='Simply run eval')
    add('--calc_flops', action='store_true', help='Calculate flops')
    return parser.parse_args()
def get_model(args):
    """Build the network selected by `args` and optionally restore a checkpoint.

    Uses args.dataset to pick the classifier size, args.model_type to pick the
    architecture, args.ckpt_path for optional weights and args.device for
    placement. Returns the (possibly DataParallel-wrapped) model.
    Raises NotImplementedError for unknown dataset/model names.
    """
    print('=> Building model..')
    if args.dataset == 'imagenet':
        n_class = 1000
    elif args.dataset == 'cifar10':
        n_class = 10
    else:
        raise NotImplementedError
    # Build on CPU first. The original code called .cuda() unconditionally,
    # which crashed on machines without CUDA even though args.device already
    # supports CPU execution; the model is moved to args.device below.
    if args.model_type == 'mobilenet':
        net = MobileNet(n_class=n_class)
    elif args.model_type == 'mobilenetv2':
        net = MobileNetV2(n_class=n_class)
    else:
        raise NotImplementedError
    if args.ckpt_path is not None:
        # the checkpoint can be a saved whole model object exported by amc_search.py, or a state_dict
        print('=> Loading checkpoint {} ..'.format(args.ckpt_path))
        # map_location='cpu' lets a GPU-saved checkpoint load on a CPU-only
        # host; the result is moved to args.device right after.
        ckpt = torch.load(args.ckpt_path, map_location='cpu')
        if isinstance(ckpt, dict):
            net.load_state_dict(ckpt['state_dict'])
        else:
            net = ckpt
    net.to(args.device)
    if torch.cuda.is_available() and args.n_gpu > 1:
        net = torch.nn.DataParallel(net, list(range(args.n_gpu)))
    return net
def train(epoch, train_loader, device):
    """Run one training epoch over `train_loader` on `device`.

    Relies on the module-level globals created in __main__: `net` (the model),
    `criterion` (loss), `optimizer` and `writer` (TensorBoard logger). Prints a
    progress bar per batch and logs epoch-average loss/top1/top5 to `writer`.
    """
    print('\nEpoch: %d' % epoch)
    net.train()
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()
    end = time.time()

    for batch_idx, (inputs, targets) in enumerate(train_loader):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets)

        loss.backward()
        optimizer.step()

        # measure accuracy and record loss
        prec1, prec5 = accuracy(outputs.data, targets.data, topk=(1, 5))
        losses.update(loss.item(), inputs.size(0))
        top1.update(prec1.item(), inputs.size(0))
        top5.update(prec5.item(), inputs.size(0))
        # timing
        batch_time.update(time.time() - end)
        end = time.time()

        progress_bar(batch_idx, len(train_loader), 'Loss: {:.3f} | Acc1: {:.3f}% | Acc5: {:.3f}%'
                     .format(losses.avg, top1.avg, top5.avg))
    # log epoch averages for TensorBoard
    writer.add_scalar('loss/train', losses.avg, epoch)
    writer.add_scalar('acc/train_top1', top1.avg, epoch)
    writer.add_scalar('acc/train_top5', top5.avg, epoch)
def test(epoch, test_loader, device, save=True):
    """Evaluate the global `net` on `test_loader`; optionally log and checkpoint.

    Relies on module-level globals from __main__: `net`, `criterion`, `writer`,
    `optimizer`, `args`, `log_dir` and `best_acc`. When `save` is true, epoch
    metrics are written to TensorBoard and a checkpoint is saved, marked best
    whenever top-1 accuracy improves on `best_acc`.
    """
    global best_acc
    net.eval()

    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    end = time.time()
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(test_loader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, targets)

            # measure accuracy and record loss
            prec1, prec5 = accuracy(outputs.data, targets.data, topk=(1, 5))
            losses.update(loss.item(), inputs.size(0))
            top1.update(prec1.item(), inputs.size(0))
            top5.update(prec5.item(), inputs.size(0))
            # timing
            batch_time.update(time.time() - end)
            end = time.time()

            progress_bar(batch_idx, len(test_loader), 'Loss: {:.3f} | Acc1: {:.3f}% | Acc5: {:.3f}%'
                         .format(losses.avg, top1.avg, top5.avg))

    if save:
        writer.add_scalar('loss/test', losses.avg, epoch)
        writer.add_scalar('acc/test_top1', top1.avg, epoch)
        writer.add_scalar('acc/test_top5', top5.avg, epoch)

        is_best = False
        if top1.avg > best_acc:
            best_acc = top1.avg
            is_best = True

        print('Current best acc: {}'.format(best_acc))
        # unwrap DataParallel before saving so the state_dict keys are portable
        save_checkpoint({
            'epoch': epoch,
            'model': args.model_type,
            'dataset': args.dataset,
            'state_dict': net.module.state_dict() if isinstance(net, nn.DataParallel) else net.state_dict(),
            'acc': top1.avg,
            'optimizer': optimizer.state_dict(),
        }, is_best, checkpoint_dir=log_dir)
def adjust_learning_rate(optimizer, epoch):
    """Compute the learning rate for `epoch`, apply it to every param group
    and return it.

    Reads the schedule type, base lr and epoch count from the module-level
    `args`. NOTE(review): 'step3' appears in the --lr_type help text but is
    not handled here and falls through to NotImplementedError — confirm.
    """
    schedule = args.lr_type
    if schedule == 'fixed':
        lr = args.lr
    elif schedule == 'cos':  # cos without warm-up
        lr = 0.5 * args.lr * (1 + math.cos(math.pi * epoch / args.n_epoch))
    elif schedule == 'exp':
        step, decay = 1, 0.96
        lr = args.lr * decay ** (epoch // step)
    else:
        raise NotImplementedError
    print('=> lr: {}'.format(lr))
    for group in optimizer.param_groups:
        group['lr'] = lr
    return lr
def save_checkpoint(state, is_best, checkpoint_dir='.'):
    """Serialize `state` to ckpt.pth.tar inside `checkpoint_dir`; when
    `is_best` is true, also copy it to ckpt.best.pth.tar."""
    target = os.path.join(checkpoint_dir, 'ckpt.pth.tar')
    print('=> Saving checkpoint to {}'.format(target))
    torch.save(state, target)
    if is_best:
        shutil.copyfile(target, target.replace('.pth.tar', '.best.pth.tar'))
if __name__ == '__main__':
    args = parse_args()

    # cudnn autotuner speeds up training with fixed input shapes
    if torch.cuda.is_available():
        torch.backends.cudnn.benchmark = True

    args.device = torch.device('cuda') if torch.cuda.is_available() and args.n_gpu > 0 else torch.device('cpu')

    best_acc = 0  # best test accuracy
    start_epoch = 0  # start from epoch 0 or last checkpoint epoch

    # seed all RNGs for reproducibility when a seed was requested
    if args.seed is not None:
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        torch.cuda.manual_seed(args.seed)

    print('=> Preparing data..')
    train_loader, val_loader, n_class = get_dataset(args.dataset, args.batch_size, args.n_worker,
                                                    data_root=args.data_root)

    net = get_model(args)  # for measure

    # --calc_flops: only report parameter / FLOPs counts, then exit
    if args.calc_flops:
        IMAGE_SIZE = 224 if args.dataset == 'imagenet' else 32
        n_flops, n_params = measure_model(net, IMAGE_SIZE, IMAGE_SIZE)
        print('=> Model Parameter: {:.3f} M, FLOPs: {:.3f}M'.format(n_params / 1e6, n_flops / 1e6))
        exit(0)

    criterion = nn.CrossEntropyLoss()
    print('Using SGD...')
    print('weight decay = {}'.format(args.wd))
    optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=args.wd)

    if args.eval:  # just run eval
        print('=> Start evaluation...')
        test(0, val_loader, args.device, save=False)
    else:  # train
        print('=> Start training...')
        print('Training {} on {}...'.format(args.model_type, args.dataset))
        train_type = 'train' if args.ckpt_path is None else 'finetune'
        # get_output_folder allocates a fresh numbered run directory
        log_dir = get_output_folder('./logs', '{}_{}_{}'.format(args.model_type, args.dataset, train_type))
        print('=> Saving logs to {}'.format(log_dir))
        # tf writer
        writer = SummaryWriter(logdir=log_dir)

        for epoch in range(start_epoch, start_epoch + args.n_epoch):
            lr = adjust_learning_rate(optimizer, epoch)
            train(epoch, train_loader, args.device)
            test(epoch, val_loader, args.device)

        writer.close()
        print('=> Best top-1 acc: {}%'.format(best_acc))
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import torch
import torch.nn.parallel
import torch.optim
import torch.utils.data
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data.sampler import SubsetRandomSampler
import numpy as np
import os
def get_dataset(dset_name, batch_size, n_worker, data_root='../../data'):
    """Build the standard train/val DataLoaders for cifar10 or imagenet.

    Returns (train_loader, val_loader, n_class); raises NotImplementedError
    for any other dataset name. CIFAR10 is downloaded into `data_root` on
    demand; for imagenet, `data_root` must already contain `train/` and
    `val/` image folders.
    """
    # standard CIFAR augmentation + per-channel normalization
    cifar_tran_train = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ]
    cifar_tran_test = [
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ]
    print('=> Preparing data..')
    if dset_name == 'cifar10':
        transform_train = transforms.Compose(cifar_tran_train)
        transform_test = transforms.Compose(cifar_tran_test)
        trainset = torchvision.datasets.CIFAR10(root=data_root, train=True, download=True, transform=transform_train)
        train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True,
                                                   num_workers=n_worker, pin_memory=True, sampler=None)
        testset = torchvision.datasets.CIFAR10(root=data_root, train=False, download=True, transform=transform_test)
        val_loader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False,
                                                 num_workers=n_worker, pin_memory=True)
        n_class = 10
    elif dset_name == 'imagenet':
        # get dir
        traindir = os.path.join(data_root, 'train')
        valdir = os.path.join(data_root, 'val')
        # preprocessing
        input_size = 224
        imagenet_tran_train = [
            transforms.RandomResizedCrop(input_size, scale=(0.2, 1.0)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
        imagenet_tran_test = [
            # resize so the center crop covers 87.5% of the shorter side
            transforms.Resize(int(input_size / 0.875)),
            transforms.CenterCrop(input_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
        train_loader = torch.utils.data.DataLoader(
            datasets.ImageFolder(traindir, transforms.Compose(imagenet_tran_train)),
            batch_size=batch_size, shuffle=True,
            num_workers=n_worker, pin_memory=True, sampler=None)
        val_loader = torch.utils.data.DataLoader(
            datasets.ImageFolder(valdir, transforms.Compose(imagenet_tran_test)),
            batch_size=batch_size, shuffle=False,
            num_workers=n_worker, pin_memory=True)
        n_class = 1000
    else:
        raise NotImplementedError
    return train_loader, val_loader, n_class
def get_split_dataset(dset_name, batch_size, n_worker, val_size, data_root='../data', shuffle=True):
    '''
    split the train set into train / val for rl search

    Both loaders draw from the *train* split of the dataset: the first
    `val_size` indices form the validation subset and the remainder the
    training subset. With shuffle=False a sequential sampler is used so the
    split order is identical across runs.

    Returns (train_loader, val_loader, n_class); raises NotImplementedError
    for unknown dataset names.
    '''
    if shuffle:
        index_sampler = SubsetRandomSampler
    else:  # every time we use the same order for the split subset
        class SubsetSequentialSampler(SubsetRandomSampler):
            def __iter__(self):
                return (self.indices[i] for i in torch.arange(len(self.indices)).int())
        index_sampler = SubsetSequentialSampler

    print('=> Preparing data: {}...'.format(dset_name))
    if dset_name == 'cifar10':
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])
        # FIX: the training subset must come from CIFAR10 — this used to build
        # CIFAR100 here, which has 100 classes and different images, while the
        # val subset and n_class = 10 below assume CIFAR10.
        trainset = torchvision.datasets.CIFAR10(root=data_root, train=True, download=True, transform=transform_train)
        valset = torchvision.datasets.CIFAR10(root=data_root, train=True, download=True, transform=transform_test)
        n_train = len(trainset)
        indices = list(range(n_train))
        # NOTE(review): index shuffling is disabled in this branch (the
        # imagenet branch does shuffle) — the val subset is always the first
        # `val_size` samples. Presumably intentional for reproducible search.
        #np.random.shuffle(indices)
        assert val_size < n_train
        train_idx, val_idx = indices[val_size:], indices[:val_size]
        train_sampler = index_sampler(train_idx)
        val_sampler = index_sampler(val_idx)
        train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=False, sampler=train_sampler,
                                                   num_workers=n_worker, pin_memory=True)
        val_loader = torch.utils.data.DataLoader(valset, batch_size=batch_size, shuffle=False, sampler=val_sampler,
                                                 num_workers=n_worker, pin_memory=True)
        n_class = 10
    elif dset_name == 'imagenet':
        train_dir = os.path.join(data_root, 'train')
        val_dir = os.path.join(data_root, 'val')
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        input_size = 224
        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(input_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
        test_transform = transforms.Compose([
            transforms.Resize(int(input_size/0.875)),
            transforms.CenterCrop(input_size),
            transforms.ToTensor(),
            normalize,
        ])
        # Both subsets are carved out of the train directory; the held-out
        # val_dir is not used for the RL search split.
        trainset = datasets.ImageFolder(train_dir, train_transform)
        valset = datasets.ImageFolder(train_dir, test_transform)
        n_train = len(trainset)
        indices = list(range(n_train))
        np.random.shuffle(indices)
        assert val_size < n_train
        train_idx, val_idx = indices[val_size:], indices[:val_size]
        train_sampler = index_sampler(train_idx)
        val_sampler = index_sampler(val_idx)
        train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, sampler=train_sampler,
                                                   num_workers=n_worker, pin_memory=True)
        val_loader = torch.utils.data.DataLoader(valset, batch_size=batch_size, sampler=val_sampler,
                                                 num_workers=n_worker, pin_memory=True)
        n_class = 1000
    else:
        raise NotImplementedError
    return train_loader, val_loader, n_class
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import sys
import os
import time
class AverageMeter(object):
    """Keeps a running sum and count of observations.

    `val` holds the most recent per-item value, `sum`/`count` the running
    totals, and `avg` the mean over everything recorded so far.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Drop all recorded statistics."""
        self.val, self.avg, self.sum, self.count = 0, 0, 0, 0

    def _recompute_avg(self):
        # guard against division by zero when nothing has been recorded
        if self.count > 0:
            self.avg = self.sum / self.count

    def update(self, val, n=1):
        """Record a per-item value `val` observed `n` times."""
        self.val = val
        self.sum += val * n
        self.count += n
        self._recompute_avg()

    def accumulate(self, val, n=1):
        """Record an already-summed value `val` covering `n` items (leaves `val` attr untouched)."""
        self.sum += val
        self.count += n
        self._recompute_avg()
def accuracy(output, target, topk=(1, 5)):
    """Computes the precision@k for the specified values of k.

    Parameters
    ----------
    output : torch.Tensor
        Logits/scores of shape (batch, num_classes).
    target : torch.Tensor
        Ground-truth class indices of shape (batch,).
    topk : tuple of int
        The k values to evaluate.

    Returns a list with one percentage (0-d tensor) per feasible k, followed
    by a `[0.0]` placeholder for every k larger than the number of classes.
    """
    batch_size = target.size(0)
    num = output.size(1)

    target_topk = []
    appendices = []
    for k in topk:
        if k <= num:
            target_topk.append(k)
        else:
            # not enough classes for this k: report 0.0 instead of failing
            appendices.append([0.0])
    topk = target_topk
    maxk = max(topk)
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # reshape(-1) instead of view(-1): row slices of the transposed
        # prediction matrix are not guaranteed to be contiguous, and view()
        # raises on non-contiguous tensors in recent PyTorch versions (same
        # fix as in the official PyTorch imagenet example).
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res + appendices
# Custom progress bar
# Query the terminal width once at import time. The original
# os.popen('stty size') pipe raised when stdout is not a tty (CI jobs,
# output redirection, Windows), crashing on import; query via the stdlib
# and fall back to a conventional 80 columns.
try:
    term_width = os.get_terminal_size().columns
except OSError:
    term_width = 80

TOTAL_BAR_LENGTH = 40.
last_time = time.time()
begin_time = last_time
def progress_bar(current, total, msg=None):
    """Render a Keras-style progress bar on stdout for step `current` of `total`.

    Prints per-step and total elapsed time plus an optional `msg`, then pads
    to the terminal edge and repositions the cursor. Uses the module-level
    globals `term_width`, `TOTAL_BAR_LENGTH`, `last_time` and `begin_time`.
    """
    def format_time(seconds):
        # Render a duration using (at most) the two most significant units
        # among D/h/m/s/ms; '0ms' when everything rounds to zero.
        days = int(seconds / 3600 / 24)
        seconds = seconds - days * 3600 * 24
        hours = int(seconds / 3600)
        seconds = seconds - hours * 3600
        minutes = int(seconds / 60)
        seconds = seconds - minutes * 60
        secondsf = int(seconds)
        seconds = seconds - secondsf
        millis = int(seconds * 1000)

        f = ''
        i = 1
        if days > 0:
            f += str(days) + 'D'
            i += 1
        if hours > 0 and i <= 2:
            f += str(hours) + 'h'
            i += 1
        if minutes > 0 and i <= 2:
            f += str(minutes) + 'm'
            i += 1
        if secondsf > 0 and i <= 2:
            f += str(secondsf) + 's'
            i += 1
        if millis > 0 and i <= 2:
            f += str(millis) + 'ms'
            i += 1
        if f == '':
            f = '0ms'
        return f

    global last_time, begin_time
    if current == 0:
        begin_time = time.time()  # Reset for new bar.

    cur_len = int(TOTAL_BAR_LENGTH*current/total)
    rest_len = int(TOTAL_BAR_LENGTH - cur_len) - 1

    # draw the [====>....] bar
    sys.stdout.write(' [')
    for i in range(cur_len):
        sys.stdout.write('=')
    sys.stdout.write('>')
    for i in range(rest_len):
        sys.stdout.write('.')
    sys.stdout.write(']')

    cur_time = time.time()
    step_time = cur_time - last_time
    last_time = cur_time
    tot_time = cur_time - begin_time

    L = []
    L.append(' Step: %s' % format_time(step_time))
    L.append(' | Tot: %s' % format_time(tot_time))
    if msg:
        L.append(' | ' + msg)

    msg = ''.join(L)
    sys.stdout.write(msg)
    # pad with spaces so leftovers from longer previous lines are erased
    for i in range(term_width-int(TOTAL_BAR_LENGTH)-len(msg)-3):
        sys.stdout.write(' ')

    # Go back to the center of the bar.
    for i in range(term_width-int(TOTAL_BAR_LENGTH/2)+2):
        sys.stdout.write('\b')
    sys.stdout.write(' %d/%d ' % (current+1, total))

    # carriage-return to overwrite the line until the final step
    if current < total-1:
        sys.stdout.write('\r')
    else:
        sys.stdout.write('\n')
    sys.stdout.flush()
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import torch.nn as nn
import math
def conv_bn(inp, oup, stride):
    """3x3 convolution (padding 1, no bias) followed by batch norm and ReLU."""
    layers = [
        nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
        nn.BatchNorm2d(oup),
        nn.ReLU(inplace=True),
    ]
    return nn.Sequential(*layers)
def conv_dw(inp, oup, stride):
    """Depthwise-separable convolution block.

    A 3x3 depthwise conv (groups == channels) + BN + ReLU, followed by a
    1x1 pointwise conv + BN + ReLU.
    """
    depthwise = [
        nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False),
        nn.BatchNorm2d(inp),
        nn.ReLU(inplace=True),
    ]
    pointwise = [
        nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
        nn.BatchNorm2d(oup),
        nn.ReLU(inplace=True),
    ]
    return nn.Sequential(*depthwise, *pointwise)
class MobileNet(nn.Module):
    """MobileNet v1: a stem conv followed by depthwise-separable blocks.

    Parameters
    ----------
    n_class : int
        Number of outputs of the final linear classifier.
    profile : str
        'normal' for the original width settings, '0.5flops' for the
        narrower per-layer widths found by AMC at a 50% FLOPs budget.
    """
    def __init__(self, n_class, profile='normal'):
        super(MobileNet, self).__init__()

        # original
        if profile == 'normal':
            in_planes = 32
            # int entry = output channels (stride 1); tuple = (channels, stride)
            cfg = [64, (128, 2), 128, (256, 2), 256, (512, 2), 512, 512, 512, 512, 512, (1024, 2), 1024]
        # 0.5 AMC
        elif profile == '0.5flops':
            in_planes = 24
            cfg = [48, (96, 2), 80, (192, 2), 200, (328, 2), 352, 368, 360, 328, 400, (736, 2), 752]
        else:
            raise NotImplementedError

        self.conv1 = conv_bn(3, in_planes, stride=2)
        self.features = self._make_layers(in_planes, cfg, conv_dw)
        self.classifier = nn.Sequential(
            nn.Linear(cfg[-1], n_class),
        )
        self._initialize_weights()

    def forward(self, x):
        x = self.conv1(x)
        x = self.features(x)
        x = x.mean(3).mean(2)  # global average pooling
        x = self.classifier(x)
        return x

    def _make_layers(self, in_planes, cfg, layer):
        # Stack one `layer` block per cfg entry, threading the channel count through.
        layers = []
        for x in cfg:
            out_planes = x if isinstance(x, int) else x[0]
            stride = 1 if isinstance(x, int) else x[1]
            layers.append(layer(in_planes, out_planes, stride))
            in_planes = out_planes
        return nn.Sequential(*layers)

    def _initialize_weights(self):
        # He-style init for convs, unit BN scale, small normal init for linear layers.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                n = m.weight.size(1)
                m.weight.data.normal_(0, 0.01)
                m.bias.data.zero_()
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import torch.nn as nn
import math
def conv_bn(inp, oup, stride):
    """3x3 convolution (padding 1, no bias) followed by BatchNorm and ReLU6."""
    conv = nn.Conv2d(inp, oup, 3, stride, 1, bias=False)
    return nn.Sequential(conv, nn.BatchNorm2d(oup), nn.ReLU6(inplace=True))
def conv_1x1_bn(inp, oup):
    """Pointwise 1x1 convolution (no bias) followed by BatchNorm and ReLU6."""
    conv = nn.Conv2d(inp, oup, 1, 1, 0, bias=False)
    return nn.Sequential(conv, nn.BatchNorm2d(oup), nn.ReLU6(inplace=True))
class InvertedResidual(nn.Module):
    """MobileNetV2 inverted-residual block.

    Optionally expands the input by `expand_ratio` (pointwise conv), filters
    with a 3x3 depthwise conv, then projects back with a linear 1x1 conv.
    A residual shortcut is used when stride is 1 and channel counts match.
    """
    def __init__(self, inp, oup, stride, expand_ratio):
        super(InvertedResidual, self).__init__()
        self.stride = stride
        assert stride in [1, 2]

        hidden_dim = round(inp * expand_ratio)
        self.use_res_connect = self.stride == 1 and inp == oup

        layers = []
        if expand_ratio != 1:
            # pw expansion (omitted entirely when the ratio is 1)
            layers += [
                nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.ReLU6(inplace=True),
            ]
        # dw
        layers += [
            nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
            nn.BatchNorm2d(hidden_dim),
            nn.ReLU6(inplace=True),
        ]
        # pw-linear (no activation)
        layers += [
            nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
            nn.BatchNorm2d(oup),
        ]
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        out = self.conv(x)
        return x + out if self.use_res_connect else out
class MobileNetV2(nn.Module):
    """MobileNetV2 backbone built from inverted-residual blocks.

    Parameters
    ----------
    n_class : int
        Number of classifier outputs.
    input_size : int
        Input resolution; must be a multiple of 32.
    width_mult : float
        Uniform channel-width multiplier applied to every stage.
    """
    def __init__(self, n_class=1000, input_size=224, width_mult=1.):
        super(MobileNetV2, self).__init__()
        block = InvertedResidual
        input_channel = 32
        last_channel = 1280
        # per stage: t = expansion ratio, c = output channels,
        # n = number of blocks, s = stride of the first block
        interverted_residual_setting = [
            # t, c, n, s
            [1, 16, 1, 1],
            [6, 24, 2, 2],
            [6, 32, 3, 2],
            [6, 64, 4, 2],
            [6, 96, 3, 1],
            [6, 160, 3, 2],
            [6, 320, 1, 1],
        ]

        # building first layer
        assert input_size % 32 == 0
        input_channel = int(input_channel * width_mult)
        self.last_channel = int(last_channel * width_mult) if width_mult > 1.0 else last_channel
        self.features = [conv_bn(3, input_channel, 2)]
        # building inverted residual blocks
        for t, c, n, s in interverted_residual_setting:
            output_channel = int(c * width_mult)
            for i in range(n):
                if i == 0:
                    # the first block of each stage carries the stage stride
                    self.features.append(block(input_channel, output_channel, s, expand_ratio=t))
                else:
                    self.features.append(block(input_channel, output_channel, 1, expand_ratio=t))
                input_channel = output_channel
        # building last several layers
        self.features.append(conv_1x1_bn(input_channel, self.last_channel))
        # make it nn.Sequential
        self.features = nn.Sequential(*self.features)

        # building classifier
        self.classifier = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(self.last_channel, n_class),
        )

        self._initialize_weights()

    def forward(self, x):
        x = self.features(x)
        x = x.mean(3).mean(2)  # global average pooling
        x = self.classifier(x)
        return x

    def _initialize_weights(self):
        # He-style init for convs, unit BN scale, small normal init for linear layers.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                n = m.weight.size(1)
                m.weight.data.normal_(0, 0.01)
                m.bias.data.zero_()
...@@ -54,20 +54,34 @@ class Compressor: ...@@ -54,20 +54,34 @@ class Compressor:
self._fwd_hook_handles = {} self._fwd_hook_handles = {}
self._fwd_hook_id = 0 self._fwd_hook_id = 0
for layer, config in self._detect_modules_to_compress(): self.reset()
wrapper = self._wrap_modules(layer, config)
self.modules_wrapper.append(wrapper)
if not self.modules_wrapper: if not self.modules_wrapper:
_logger.warning('Nothing is configured to compress, please check your model and config_list') _logger.warning('Nothing is configured to compress, please check your model and config_list')
self._wrap_model()
def validate_config(self, model, config_list): def validate_config(self, model, config_list):
""" """
subclass can optionally implement this method to check if config_list if valid subclass can optionally implement this method to check if config_list if valid
""" """
pass pass
def reset(self, checkpoint=None):
"""
reset model state dict and model wrapper
"""
self._unwrap_model()
if checkpoint is not None:
self.bound_model.load_state_dict(checkpoint)
self.modules_to_compress = None
self.modules_wrapper = []
for layer, config in self._detect_modules_to_compress():
wrapper = self._wrap_modules(layer, config)
self.modules_wrapper.append(wrapper)
self._wrap_model()
def _detect_modules_to_compress(self): def _detect_modules_to_compress(self):
""" """
detect all modules should be compressed, and save the result in `self.modules_to_compress`. detect all modules should be compressed, and save the result in `self.modules_to_compress`.
...@@ -346,7 +360,7 @@ class Pruner(Compressor): ...@@ -346,7 +360,7 @@ class Pruner(Compressor):
config : dict config : dict
the configuration for generating the mask the configuration for generating the mask
""" """
_logger.info("Module detected to compress : %s.", layer.name) _logger.debug("Module detected to compress : %s.", layer.name)
wrapper = PrunerModuleWrapper(layer.module, layer.name, layer.type, config, self) wrapper = PrunerModuleWrapper(layer.module, layer.name, layer.type, config, self)
assert hasattr(layer.module, 'weight'), "module %s does not have 'weight' attribute" % layer.name assert hasattr(layer.module, 'weight'), "module %s does not have 'weight' attribute" % layer.name
# move newly registered buffers to the same device of weight # move newly registered buffers to the same device of weight
...@@ -381,7 +395,7 @@ class Pruner(Compressor): ...@@ -381,7 +395,7 @@ class Pruner(Compressor):
if weight_mask is not None: if weight_mask is not None:
mask_sum = weight_mask.sum().item() mask_sum = weight_mask.sum().item()
mask_num = weight_mask.numel() mask_num = weight_mask.numel()
_logger.info('Layer: %s Sparsity: %.4f', wrapper.name, 1 - mask_sum / mask_num) _logger.debug('Layer: %s Sparsity: %.4f', wrapper.name, 1 - mask_sum / mask_num)
wrapper.module.weight.data = wrapper.module.weight.data.mul(weight_mask) wrapper.module.weight.data = wrapper.module.weight.data.mul(weight_mask)
if bias_mask is not None: if bias_mask is not None:
wrapper.module.bias.data = wrapper.module.bias.data.mul(bias_mask) wrapper.module.bias.data = wrapper.module.bias.data.mul(bias_mask)
......
...@@ -12,3 +12,5 @@ from .net_adapt_pruner import NetAdaptPruner ...@@ -12,3 +12,5 @@ from .net_adapt_pruner import NetAdaptPruner
from .admm_pruner import ADMMPruner from .admm_pruner import ADMMPruner
from .auto_compress_pruner import AutoCompressPruner from .auto_compress_pruner import AutoCompressPruner
from .sensitivity_pruner import SensitivityPruner from .sensitivity_pruner import SensitivityPruner
from .amc import AMCPruner
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .amc_pruner import AMCPruner
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import os
from copy import deepcopy
from argparse import Namespace
import numpy as np
import torch
from nni.compression.torch.compressor import Pruner
from .channel_pruning_env import ChannelPruningEnv
from .lib.agent import DDPG
from .lib.utils import get_output_folder
# Force deterministic cuDNN kernels so that pruning-search runs are reproducible
# (at some cost in speed).
torch.backends.cudnn.deterministic = True
class AMCPruner(Pruner):
    """
    A pytorch implementation of AMC: AutoML for Model Compression and Acceleration on Mobile Devices.
    (https://arxiv.org/pdf/1802.03494.pdf)

    Parameters:
        model: nn.Module
            The model to be pruned.
        config_list: list
            Configuration list to configure layer pruning.
            Supported keys:
            - op_types: operation type to be pruned
            - op_names: operation name to be pruned
        evaluator: function
            function to evaluate the pruned model.
            The prototype of the function:
            >>> def evaluator(val_loader, model):
            >>>     ...
            >>>     return acc
        val_loader: torch.utils.data.DataLoader
            Data loader of validation dataset.
        suffix: str
            suffix to help you remember what experiment you ran. Default: None.
        job: str
            train_export: search best pruned model and export after search.
            export_only: export a searched model, searched_model_path and export_path must be specified.
        searched_model_path: str
            when job == export_only, use searched_model_path to specify the path of the searched model.
        export_path: str
            path for exporting models

        # parameters for pruning environment
        model_type: str
            model type to prune, currently 'mobilenet' and 'mobilenetv2' are supported. Default: mobilenet
        dataset: str
            dataset name, used only to build the output folder name. Default: cifar10
        flops_ratio: float
            preserve flops ratio. Default: 0.5
        lbound: float
            minimum weight preserve ratio for each layer. Default: 0.2
        rbound: float
            maximum weight preserve ratio for each layer. Default: 1.0
        reward: function
            reward function type:
            - acc_reward: accuracy * 0.01
            - acc_flops_reward: - (100 - accuracy) * 0.01 * np.log(flops)
            Default: acc_reward

        # parameters for channel pruning
        n_calibration_batches: int
            number of batches to extract layer information. Default: 60
        n_points_per_layer: int
            number of feature points per layer. Default: 10
        channel_round: int
            round channel to multiple of channel_round. Default: 8

        # parameters for ddpg agent
        hidden1: int
            hidden num of first fully connect layer. Default: 300
        hidden2: int
            hidden num of second fully connect layer. Default: 300
        lr_c: float
            learning rate for critic. Default: 1e-3
        lr_a: float
            learning rate for actor. Default: 1e-4
        warmup: int
            number of episodes without training but only filling the replay memory. During warmup episodes,
            random actions ares used for pruning. Default: 100
        discount: float
            next Q value discount for deep Q value target. Default: 0.99
        bsize: int
            minibatch size for training DDPG agent. Default: 64
        rmsize: int
            memory size for each layer. Default: 100
        window_length: int
            replay buffer window length. Default: 1
        tau: float
            moving average for target network being used by soft_update. Default: 0.99

        # noise
        init_delta: float
            initial variance of truncated normal distribution
        delta_decay: float
            delta decay during exploration

        # parameters for training ddpg agent
        max_episode_length: int
            maximum episode length
        output_dir: str
            output directory to save log files and model files. Default: ./logs
        debug: boolean
            debug mode
        train_episode: int
            train iters each timestep. Default: 800
        epsilon: int
            linear decay of exploration policy. Default: 50000
        seed: int
            random seed to set for reproduce experiment. Default: None
    """

    def __init__(
            self,
            model,
            config_list,
            evaluator,
            val_loader,
            suffix=None,
            job='train_export',
            export_path=None,
            searched_model_path=None,
            model_type='mobilenet',
            dataset='cifar10',
            flops_ratio=0.5,
            lbound=0.2,
            rbound=1.,
            reward='acc_reward',
            n_calibration_batches=60,
            n_points_per_layer=10,
            channel_round=8,
            hidden1=300,
            hidden2=300,
            lr_c=1e-3,
            lr_a=1e-4,
            warmup=100,
            discount=1.,
            bsize=64,
            rmsize=100,
            window_length=1,
            tau=0.01,
            init_delta=0.5,
            delta_decay=0.99,
            max_episode_length=1e9,
            output_dir='./logs',
            debug=False,
            train_episode=800,
            epsilon=50000,
            seed=None):
        # Lazy import: tensorboardX is only required when AMCPruner is actually used.
        from tensorboardX import SummaryWriter
        self.job = job
        self.searched_model_path = searched_model_path
        self.export_path = export_path

        # Seed all RNGs for reproducibility when requested.
        if seed is not None:
            np.random.seed(seed)
            torch.manual_seed(seed)
            torch.cuda.manual_seed(seed)

        # Snapshot the unpruned weights; the env restores from this checkpoint
        # at the start of every search episode.
        checkpoint = deepcopy(model.state_dict())
        super().__init__(model, config_list, optimizer=None)

        # build folder and logs
        base_folder_name = '{}_{}_r{}_search'.format(model_type, dataset, flops_ratio)
        if suffix is not None:
            base_folder_name = base_folder_name + '_' + suffix
        self.output_dir = get_output_folder(output_dir, base_folder_name)

        if self.export_path is None:
            self.export_path = os.path.join(self.output_dir, '{}_r{}_exported.pth'.format(model_type, flops_ratio))

        # Hyper-parameters consumed by ChannelPruningEnv.
        self.env_args = Namespace(
            model_type=model_type,
            preserve_ratio=flops_ratio,
            lbound=lbound,
            rbound=rbound,
            reward=reward,
            n_calibration_batches=n_calibration_batches,
            n_points_per_layer=n_points_per_layer,
            channel_round=channel_round,
            output=self.output_dir
        )
        self.env = ChannelPruningEnv(
            self, evaluator, val_loader, checkpoint, args=self.env_args)

        if self.job == 'train_export':
            print('=> Saving logs to {}'.format(self.output_dir))
            self.tfwriter = SummaryWriter(logdir=self.output_dir)
            self.text_writer = open(os.path.join(self.output_dir, 'log.txt'), 'w')
        print('=> Output path: {}...'.format(self.output_dir))

        # DDPG state is the layer embedding produced by the pruning env.
        nb_states = self.env.layer_embedding.shape[1]
        nb_actions = 1  # just 1 action here

        rmsize = rmsize * len(self.env.prunable_idx)  # for each layer
        print('** Actual replay buffer size: {}'.format(rmsize))

        # Hyper-parameters consumed by the DDPG agent.
        self.ddpg_args = Namespace(
            hidden1=hidden1,
            hidden2=hidden2,
            lr_c=lr_c,
            lr_a=lr_a,
            warmup=warmup,
            discount=discount,
            bsize=bsize,
            rmsize=rmsize,
            window_length=window_length,
            tau=tau,
            init_delta=init_delta,
            delta_decay=delta_decay,
            max_episode_length=max_episode_length,
            debug=debug,
            train_episode=train_episode,
            epsilon=epsilon
        )
        self.agent = DDPG(nb_states, nb_actions, self.ddpg_args)

    def compress(self):
        """Run the DDPG search (when job == 'train_export') and export the pruned model."""
        if self.job == 'train_export':
            self.train(self.ddpg_args.train_episode, self.agent, self.env, self.output_dir)
        self.export_pruned_model()

    def train(self, num_episode, agent, env, output_dir):
        """Search loop: run `num_episode` pruning episodes, updating the DDPG agent."""
        agent.is_training = True
        step = episode = episode_steps = 0
        episode_reward = 0.
        observation = None
        T = []  # trajectory
        while episode < num_episode:  # counting based on episode
            # reset if it is the start of episode
            if observation is None:
                observation = deepcopy(env.reset())
                agent.reset(observation)

            # agent pick action ...
            if episode <= self.ddpg_args.warmup:
                # During warmup, act randomly to fill the replay buffer.
                action = agent.random_action()
                # action = sample_from_truncated_normal_distribution(lower=0., upper=1., mu=env.preserve_ratio, sigma=0.5)
            else:
                action = agent.select_action(observation, episode=episode)

            # env response with next_observation, reward, terminate_info
            observation2, reward, done, info = env.step(action)

            T.append([reward, deepcopy(observation), deepcopy(observation2), action, done])

            # fix-length, never reach here
            # if max_episode_length and episode_steps >= max_episode_length - 1:
            #     done = True

            # [optional] save intermideate model
            if num_episode / 3 <= 1 or episode % int(num_episode / 3) == 0:
                agent.save_model(output_dir)

            # update
            step += 1
            episode_steps += 1
            episode_reward += reward
            observation = deepcopy(observation2)

            if done:  # end of episode
                print(
                    '#{}: episode_reward:{:.4f} acc: {:.4f}, ratio: {:.4f}'.format(
                        episode, episode_reward,
                        info['accuracy'],
                        info['compress_ratio']
                    )
                )
                self.text_writer.write(
                    '#{}: episode_reward:{:.4f} acc: {:.4f}, ratio: {:.4f}\n'.format(
                        episode, episode_reward,
                        info['accuracy'],
                        info['compress_ratio']
                    )
                )
                # Episodic reward: every transition in the trajectory is credited
                # with the final (end-of-episode) reward.
                final_reward = T[-1][0]
                # print('final_reward: {}'.format(final_reward))
                # agent observe and update policy
                for _, s_t, s_t1, a_t, done in T:
                    agent.observe(final_reward, s_t, s_t1, a_t, done)
                    if episode > self.ddpg_args.warmup:
                        agent.update_policy()

                #agent.memory.append(
                #    observation,
                #    agent.select_action(observation, episode=episode),
                #    0., False
                #)

                # reset
                observation = None
                episode_steps = 0
                episode_reward = 0.
                episode += 1
                T = []

                # TensorBoard bookkeeping for this episode.
                self.tfwriter.add_scalar('reward/last', final_reward, episode)
                self.tfwriter.add_scalar('reward/best', env.best_reward, episode)
                self.tfwriter.add_scalar('info/accuracy', info['accuracy'], episode)
                self.tfwriter.add_scalar('info/compress_ratio', info['compress_ratio'], episode)
                self.tfwriter.add_text('info/best_policy', str(env.best_strategy), episode)
                # record the preserve rate for each layer
                for i, preserve_rate in enumerate(env.strategy):
                    self.tfwriter.add_scalar('preserve_rate/{}'.format(i), preserve_rate, episode)

                self.text_writer.write('best reward: {}\n'.format(env.best_reward))
                self.text_writer.write('best policy: {}\n'.format(env.best_strategy))
        self.text_writer.close()

    def export_pruned_model(self):
        """Load the best (wrapped) checkpoint, validate it, and export the unwrapped model."""
        if self.searched_model_path is None:
            wrapper_model_ckpt = os.path.join(self.output_dir, 'best_wrapped_model.pth')
        else:
            wrapper_model_ckpt = self.searched_model_path
        self.env.reset()
        self.bound_model.load_state_dict(torch.load(wrapper_model_ckpt))

        print('validate searched model:', self.env._validate(self.env._val_loader, self.env.model))
        self.env.export_model()
        self._unwrap_model()
        print('validate exported model:', self.env._validate(self.env._val_loader, self.env.model))

        torch.save(self.bound_model, self.export_path)
        print('exported model saved to: {}'.format(self.export_path))
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import numpy as np
import torch
import torch.nn as nn
from torch.optim import Adam
from .memory import SequentialMemory
from .utils import to_numpy, to_tensor
# MSE loss for the critic's TD-error.
criterion = nn.MSELoss()
# Run networks on GPU whenever one is available.
USE_CUDA = torch.cuda.is_available()
class Actor(nn.Module):
    """Policy network: maps a state vector to actions squashed into (0, 1)."""

    def __init__(self, nb_states, nb_actions, hidden1=400, hidden2=300):
        super(Actor, self).__init__()
        self.fc1 = nn.Linear(nb_states, hidden1)
        self.fc2 = nn.Linear(hidden1, hidden2)
        self.fc3 = nn.Linear(hidden2, nb_actions)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Two ReLU hidden layers followed by a sigmoid output layer."""
        hidden = self.relu(self.fc1(x))
        hidden = self.relu(self.fc2(hidden))
        return self.sigmoid(self.fc3(hidden))
class Critic(nn.Module):
    """Q-network: scores a (state, action) pair with a single scalar."""

    def __init__(self, nb_states, nb_actions, hidden1=400, hidden2=300):
        super(Critic, self).__init__()
        self.fc11 = nn.Linear(nb_states, hidden1)
        self.fc12 = nn.Linear(nb_actions, hidden1)
        self.fc2 = nn.Linear(hidden1, hidden2)
        self.fc3 = nn.Linear(hidden2, 1)
        self.relu = nn.ReLU()

    def forward(self, xs):
        """`xs` is a (state, action) pair; their embeddings are summed before the MLP."""
        state, action = xs
        merged = self.relu(self.fc11(state) + self.fc12(action))
        hidden = self.relu(self.fc2(merged))
        return self.fc3(hidden)
class DDPG(object):
    """Deep Deterministic Policy Gradient agent used by AMC.

    Emits one continuous action per state (the per-layer preserve ratio,
    bounded to [lbound, rbound]) and is trained off-policy from a replay
    buffer, using a moving-average reward baseline to reduce variance.
    """

    def __init__(self, nb_states, nb_actions, args):
        """
        Parameters
        ----------
        nb_states : int
            dimension of the state (layer embedding) vector
        nb_actions : int
            dimension of the action vector (1 for AMC)
        args : Namespace
            hyper-parameters: hidden1/2, lr_a/c, rmsize, window_length,
            bsize, tau, discount, epsilon, init_delta, delta_decay, warmup
        """
        self.nb_states = nb_states
        self.nb_actions = nb_actions

        # Create Actor and Critic networks plus their target copies.
        net_cfg = {
            'hidden1': args.hidden1,
            'hidden2': args.hidden2,
        }
        self.actor = Actor(self.nb_states, self.nb_actions, **net_cfg)
        self.actor_target = Actor(self.nb_states, self.nb_actions, **net_cfg)
        self.actor_optim = Adam(self.actor.parameters(), lr=args.lr_a)

        self.critic = Critic(self.nb_states, self.nb_actions, **net_cfg)
        self.critic_target = Critic(self.nb_states, self.nb_actions, **net_cfg)
        self.critic_optim = Adam(self.critic.parameters(), lr=args.lr_c)

        # Targets start as exact copies of the online networks.
        self.hard_update(self.actor_target, self.actor)
        self.hard_update(self.critic_target, self.critic)

        # Create replay buffer
        self.memory = SequentialMemory(limit=args.rmsize, window_length=args.window_length)

        # Hyper-parameters
        self.batch_size = args.bsize
        self.tau = args.tau
        self.discount = args.discount
        self.depsilon = 1.0 / args.epsilon
        self.lbound = 0.  # action lower bound
        self.rbound = 1.  # action upper bound

        # Exploration noise: truncated normal with exponentially decaying sigma.
        self.init_delta = args.init_delta
        self.delta_decay = args.delta_decay
        self.warmup = args.warmup

        self.epsilon = 1.0
        self.is_training = True

        if USE_CUDA:
            self.cuda()

        # Moving-average reward baseline (per-batch, hence a small alpha).
        self.moving_average = None
        self.moving_alpha = 0.5

    def update_policy(self):
        """One DDPG update: critic regression, actor ascent, then soft target update."""
        # Sample batch
        state_batch, action_batch, reward_batch, \
            next_state_batch, terminal_batch = self.memory.sample_and_split(self.batch_size)

        # Normalize the reward with the moving-average baseline.
        batch_mean_reward = np.mean(reward_batch)
        if self.moving_average is None:
            self.moving_average = batch_mean_reward
        else:
            self.moving_average += self.moving_alpha * (batch_mean_reward - self.moving_average)
        reward_batch -= self.moving_average

        # Prepare the target Q batch (no gradient through the targets).
        with torch.no_grad():
            next_q_values = self.critic_target([
                to_tensor(next_state_batch),
                self.actor_target(to_tensor(next_state_batch)),
            ])

        # `terminal_batch` is a 0/1 continuation mask (0 at episode ends).
        # np.float64 replaces the `np.float` alias removed in NumPy 1.24 (same dtype).
        target_q_batch = to_tensor(reward_batch) + \
            self.discount * to_tensor(terminal_batch.astype(np.float64)) * next_q_values

        # Critic update
        self.critic.zero_grad()
        q_batch = self.critic([to_tensor(state_batch), to_tensor(action_batch)])
        value_loss = criterion(q_batch, target_q_batch)
        value_loss.backward()
        self.critic_optim.step()

        # Actor update: maximize Q(s, pi(s)) by minimizing its negation.
        self.actor.zero_grad()
        policy_loss = -self.critic([  # pylint: disable=all
            to_tensor(state_batch),
            self.actor(to_tensor(state_batch))
        ])
        policy_loss = policy_loss.mean()
        policy_loss.backward()
        self.actor_optim.step()

        # Target update
        self.soft_update(self.actor_target, self.actor)
        self.soft_update(self.critic_target, self.critic)

    def eval(self):
        """Switch all four networks to eval mode."""
        self.actor.eval()
        self.actor_target.eval()
        self.critic.eval()
        self.critic_target.eval()

    def cuda(self):
        """Move all four networks to GPU."""
        self.actor.cuda()
        self.actor_target.cuda()
        self.critic.cuda()
        self.critic_target.cuda()

    def observe(self, r_t, s_t, s_t1, a_t, done):
        """Store a transition in the replay buffer (training mode only)."""
        if self.is_training:
            self.memory.append(s_t, a_t, r_t, done)  # save to memory

    def random_action(self):
        """Uniform random action within the action bounds (used during warmup)."""
        return np.random.uniform(self.lbound, self.rbound, self.nb_actions)

    def select_action(self, s_t, episode):
        """Actor's action for state `s_t` plus decaying truncated-normal exploration noise."""
        action = to_numpy(self.actor(to_tensor(np.array(s_t).reshape(1, -1)))).squeeze(0)
        delta = self.init_delta * (self.delta_decay ** (episode - self.warmup))
        action = self.sample_from_truncated_normal_distribution(lower=self.lbound, upper=self.rbound, mu=action, sigma=delta)
        action = np.clip(action, self.lbound, self.rbound)
        return action

    def reset(self, obs):
        """Episode-start hook; stateless for this agent."""
        pass

    def load_weights(self, output):
        """Load actor/critic weights from `{output}/actor.pkl` and `{output}/critic.pkl`."""
        if output is None:
            return
        self.actor.load_state_dict(
            torch.load('{}/actor.pkl'.format(output))
        )
        self.critic.load_state_dict(
            torch.load('{}/critic.pkl'.format(output))
        )

    def save_model(self, output):
        """Save actor/critic weights to `{output}/actor.pkl` and `{output}/critic.pkl`."""
        torch.save(
            self.actor.state_dict(),
            '{}/actor.pkl'.format(output)
        )
        torch.save(
            self.critic.state_dict(),
            '{}/critic.pkl'.format(output)
        )

    def soft_update(self, target, source):
        """Polyak averaging: target <- (1 - tau) * target + tau * source."""
        for target_param, param in zip(target.parameters(), source.parameters()):
            target_param.data.copy_(
                target_param.data * (1.0 - self.tau) + param.data * self.tau
            )

    def hard_update(self, target, source):
        """Copy source parameters into target verbatim."""
        for target_param, param in zip(target.parameters(), source.parameters()):
            target_param.data.copy_(param.data)

    def sample_from_truncated_normal_distribution(self, lower, upper, mu, sigma, size=1):
        """Sample from a normal(mu, sigma) truncated to [lower, upper]."""
        from scipy import stats
        return stats.truncnorm.rvs((lower - mu) / sigma, (upper - mu) / sigma, loc=mu, scale=sigma, size=size)
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import absolute_import
from collections import deque, namedtuple
import warnings
import random
import numpy as np
# [reference] https://github.com/matthiasplappert/keras-rl/blob/master/rl/memory.py
# This is to be understood as a transition: Given `state0`, performing `action`
# yields `reward` and results in `state1`, which might be `terminal`.
# `terminal1` flags whether `state1` ended the episode.
Experience = namedtuple('Experience', 'state0, action, reward, state1, terminal1')
def sample_batch_indexes(low, high, size):
    """Draw `size` indexes from the half-open range [low, high).

    Draws without replacement when the range is large enough; otherwise
    samples with replacement and emits a warning, since oversampling
    indicates a too-short warm-up phase.
    """
    if high - low >= size:
        # We have enough data. Draw without replacement, that is each index is unique in the
        # batch. We cannot use `np.random.choice` here because it is horribly inefficient as
        # the memory grows. See https://github.com/numpy/numpy/issues/2764 for a discussion.
        # `random.sample` does the same thing (drawing without replacement) and is way faster.
        r = range(low, high)
        batch_idxs = random.sample(r, size)
    else:
        # Not enough data. Help ourselves with sampling from the range, but the same index
        # can occur multiple times. This is not good and should be avoided by picking a
        # large enough warm-up phase.
        warnings.warn(
            'Not enough entries to sample without replacement. '
            'Consider increasing your warm-up phase to avoid oversampling!')
        # `np.random.random_integers(low, high - 1)` is deprecated/removed;
        # `randint`'s upper bound is exclusive, so this is equivalent.
        batch_idxs = np.random.randint(low, high, size=size)
    assert len(batch_idxs) == size
    return batch_idxs
class RingBuffer(object):
    """Fixed-capacity circular buffer with O(1) append and O(1) random access."""

    def __init__(self, maxlen):
        self.maxlen = maxlen          # capacity
        self.start = 0                # physical index of the oldest element
        self.length = 0               # number of stored elements
        self.data = [None] * maxlen

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        """Return the idx-th oldest element; raises KeyError when out of range."""
        if not 0 <= idx < self.length:
            raise KeyError()
        return self.data[(self.start + idx) % self.maxlen]

    def append(self, v):
        """Append `v`, evicting the oldest element once the buffer is full."""
        if self.length < self.maxlen:
            self.length += 1
        elif self.length == self.maxlen:
            self.start = (self.start + 1) % self.maxlen
        else:
            # length can never exceed maxlen
            raise RuntimeError()
        self.data[(self.start + self.length - 1) % self.maxlen] = v
def zeroed_observation(observation):
    """Return a zero-filled structure with the same layout as `observation`.

    Arrays become zero arrays of the same shape, iterables are recursed
    into element-wise, and scalars become 0.
    """
    if hasattr(observation, 'shape'):
        return np.zeros(observation.shape)
    if hasattr(observation, '__iter__'):
        return [zeroed_observation(item) for item in observation]
    return 0.
class Memory(object):
    """Base class for replay memories.

    Tracks the most recent `window_length` observations/terminal flags so
    that a windowed state can be assembled without touching the full buffer.
    """

    def __init__(self, window_length, ignore_episode_boundaries=False):
        self.window_length = window_length
        self.ignore_episode_boundaries = ignore_episode_boundaries
        self.recent_observations = deque(maxlen=window_length)
        self.recent_terminals = deque(maxlen=window_length)

    def sample(self, batch_size, batch_idxs=None):
        """Subclasses must implement batch sampling."""
        raise NotImplementedError()

    def append(self, observation, action, reward, terminal, training=True):
        """Record the observation and terminal flag in the recent-history deques."""
        self.recent_observations.append(observation)
        self.recent_terminals.append(terminal)

    def get_recent_state(self, current_observation):
        """Assemble the windowed state ending at `current_observation`.

        Walks backwards through recent observations without crossing an
        episode boundary (unless `ignore_episode_boundaries`), then left-pads
        with zeroed observations up to `window_length`.
        """
        state = [current_observation]
        newest = len(self.recent_observations) - 1
        for back in range(self.window_length - 1):
            pos = newest - back
            prev_done = self.recent_terminals[pos - 1] if pos - 1 >= 0 else False
            if pos < 0 or (not self.ignore_episode_boundaries and prev_done):
                # The previous observation ended an episode; stop so we do not
                # leak state across episodes.
                break
            state.insert(0, self.recent_observations[pos])
        while len(state) < self.window_length:
            state.insert(0, zeroed_observation(state[0]))
        return state

    def get_config(self):
        """Return the constructor arguments as a dict."""
        return {
            'window_length': self.window_length,
            'ignore_episode_boundaries': self.ignore_episode_boundaries,
        }
class SequentialMemory(Memory):
    """Replay memory storing (observation, action, reward, terminal) streams
    in parallel ring buffers of capacity `limit`."""

    def __init__(self, limit, **kwargs):
        super(SequentialMemory, self).__init__(**kwargs)
        self.limit = limit

        # Do not use deque to implement the memory. This data structure may seem convenient but
        # it is way too slow on random access. Instead, we use our own ring buffer implementation.
        self.actions = RingBuffer(limit)
        self.rewards = RingBuffer(limit)
        self.terminals = RingBuffer(limit)
        self.observations = RingBuffer(limit)

    def sample(self, batch_size, batch_idxs=None):
        """Sample `batch_size` Experience tuples, never spanning an episode boundary."""
        if batch_idxs is None:
            # Draw random indexes such that we have at least a single entry before each
            # index.
            batch_idxs = sample_batch_indexes(0, self.nb_entries - 1, size=batch_size)
        # Shift by one so every idx has a predecessor (state0 lives at idx - 1).
        batch_idxs = np.array(batch_idxs) + 1
        assert np.min(batch_idxs) >= 1
        assert np.max(batch_idxs) < self.nb_entries
        assert len(batch_idxs) == batch_size

        # Create experiences
        experiences = []
        for idx in batch_idxs:
            terminal0 = self.terminals[idx - 2] if idx >= 2 else False
            while terminal0:
                # Skip this transition because the environment was reset here. Select a new, random
                # transition and use this instead. This may cause the batch to contain the same
                # transition twice.
                idx = sample_batch_indexes(1, self.nb_entries, size=1)[0]
                terminal0 = self.terminals[idx - 2] if idx >= 2 else False
            assert 1 <= idx < self.nb_entries

            # This code is slightly complicated by the fact that subsequent observations might be
            # from different episodes. We ensure that an experience never spans multiple episodes.
            # This is probably not that important in practice but it seems cleaner.
            state0 = [self.observations[idx - 1]]
            for offset in range(0, self.window_length - 1):
                current_idx = idx - 2 - offset
                current_terminal = self.terminals[current_idx - 1] if current_idx - 1 > 0 else False
                if current_idx < 0 or (not self.ignore_episode_boundaries and current_terminal):
                    # The previously handled observation was terminal, don't add the current one.
                    # Otherwise we would leak into a different episode.
                    break
                state0.insert(0, self.observations[current_idx])
            # Left-pad with zeroed observations up to the window length.
            while len(state0) < self.window_length:
                state0.insert(0, zeroed_observation(state0[0]))
            action = self.actions[idx - 1]
            reward = self.rewards[idx - 1]
            terminal1 = self.terminals[idx - 1]

            # Okay, now we need to create the follow-up state. This is state0 shifted on timestep
            # to the right. Again, we need to be careful to not include an observation from the next
            # episode if the last state is terminal.
            state1 = [np.copy(x) for x in state0[1:]]
            state1.append(self.observations[idx])

            assert len(state0) == self.window_length
            assert len(state1) == len(state0)
            experiences.append(Experience(state0=state0, action=action, reward=reward,
                                          state1=state1, terminal1=terminal1))
        assert len(experiences) == batch_size
        return experiences

    def sample_and_split(self, batch_size, batch_idxs=None):
        """Sample a batch and split it into flat float64 arrays of shape (batch_size, -1).

        Note: the returned terminal batch is a *continuation* mask -- 0. for
        terminal transitions and 1. otherwise -- ready to multiply with a
        bootstrapped Q value.
        """
        experiences = self.sample(batch_size, batch_idxs)

        state0_batch = []
        reward_batch = []
        action_batch = []
        terminal1_batch = []
        state1_batch = []
        for e in experiences:
            state0_batch.append(e.state0)
            state1_batch.append(e.state1)
            reward_batch.append(e.reward)
            action_batch.append(e.action)
            terminal1_batch.append(0. if e.terminal1 else 1.)

        # Prepare and validate parameters.
        state0_batch = np.array(state0_batch, 'double').reshape(batch_size, -1)
        state1_batch = np.array(state1_batch, 'double').reshape(batch_size, -1)
        terminal1_batch = np.array(terminal1_batch, 'double').reshape(batch_size, -1)
        reward_batch = np.array(reward_batch, 'double').reshape(batch_size, -1)
        action_batch = np.array(action_batch, 'double').reshape(batch_size, -1)

        return state0_batch, action_batch, reward_batch, state1_batch, terminal1_batch

    def append(self, observation, action, reward, terminal, training=True):
        """Store one transition (only when `training` is True)."""
        super(SequentialMemory, self).append(observation, action, reward, terminal, training=training)

        # This needs to be understood as follows: in `observation`, take `action`, obtain `reward`
        # and weather the next state is `terminal` or not.
        if training:
            self.observations.append(observation)
            self.actions.append(action)
            self.rewards.append(reward)
            self.terminals.append(terminal)

    @property
    def nb_entries(self):
        # Number of stored transitions.
        return len(self.observations)

    def get_config(self):
        """Extend the base config with the buffer capacity."""
        config = super(SequentialMemory, self).get_config()
        config['limit'] = self.limit
        return config
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import torch
# [reference] https://github.com/ShichenLiu/CondenseNet/blob/master/utils.py
def get_num_gen(gen):
    """Count the number of items produced by the iterable `gen` (consumes it)."""
    count = 0
    for _ in gen:
        count += 1
    return count
def is_leaf(model):
    """Return True when `model` has no child modules (i.e. it is a leaf layer)."""
    return sum(1 for _ in model.children()) == 0
def get_layer_info(layer):
    """Extract the layer's class name from its repr, e.g. 'Conv2d(...)' -> 'Conv2d'."""
    rep = str(layer)
    return rep[:rep.find('(')].strip()
def get_layer_param(model):
    """Total number of parameter elements in `model`."""
    import functools
    import operator
    return sum(functools.reduce(operator.mul, p.size(), 1) for p in model.parameters())
# Module-level accumulators shared by measure_layer/measure_model.
count_ops = 0
count_params = 0
def measure_layer(layer, x):
    """Accumulate FLOPs and parameter counts of `layer` applied to input `x`
    into the module-level `count_ops` / `count_params` totals.

    Parameters
    ----------
    layer : nn.Module
        leaf module being measured
    x : torch.Tensor
        the input about to be fed to `layer` (used for spatial sizes)
    """
    global count_ops, count_params
    delta_ops = 0
    delta_params = 0
    multi_add = 1  # count a multiply+add pair as one op
    type_name = get_layer_info(layer)

    # ops_conv
    if type_name in ['Conv2d']:
        out_h = int((x.size()[2] + 2 * layer.padding[0] - layer.kernel_size[0]) /
                    layer.stride[0] + 1)
        out_w = int((x.size()[3] + 2 * layer.padding[1] - layer.kernel_size[1]) /
                    layer.stride[1] + 1)
        delta_ops = layer.in_channels * layer.out_channels * layer.kernel_size[0] * \
            layer.kernel_size[1] * out_h * out_w / layer.groups * multi_add
        delta_params = get_layer_param(layer)

    # ops_nonlinearity
    elif type_name in ['ReLU']:
        delta_ops = x.numel() / x.size(0)
        delta_params = get_layer_param(layer)

    # ops_pooling
    # NOTE(review): assumes scalar kernel_size/padding/stride -- confirm no
    # tuple-configured AvgPool2d reaches this path.
    elif type_name in ['AvgPool2d']:
        in_w = x.size()[2]
        kernel_ops = layer.kernel_size * layer.kernel_size
        out_w = int((in_w + 2 * layer.padding - layer.kernel_size) / layer.stride + 1)
        out_h = int((in_w + 2 * layer.padding - layer.kernel_size) / layer.stride + 1)
        delta_ops = x.size()[1] * out_w * out_h * kernel_ops
        delta_params = get_layer_param(layer)

    elif type_name in ['AdaptiveAvgPool2d']:
        delta_ops = x.size()[1] * x.size()[2] * x.size()[3]
        delta_params = get_layer_param(layer)

    # ops_linear
    elif type_name in ['Linear']:
        weight_ops = layer.weight.numel() * multi_add
        # Layers built with bias=False would otherwise crash on None.numel().
        bias_ops = layer.bias.numel() if layer.bias is not None else 0
        delta_ops = weight_ops + bias_ops
        delta_params = get_layer_param(layer)

    # ops_nothing: parameter-only layers contribute no counted ops
    elif type_name in ['BatchNorm2d', 'Dropout2d', 'DropChannel', 'Dropout']:
        delta_params = get_layer_param(layer)

    # unknown layer type: count parameters only
    else:
        delta_params = get_layer_param(layer)

    count_ops += delta_ops
    count_params += delta_params
    return
def measure_model(model, H, W):
    """Return (flops, params) of `model` for a dummy (2, 3, H, W) input.

    Temporarily wraps every leaf module's `forward` so the wrapper records the
    layer's cost via `measure_layer` before delegating to the original
    forward, runs one dummy forward pass, then restores the originals.
    """
    global count_ops, count_params
    count_ops = 0
    count_params = 0
    # NOTE(review): hard-codes .cuda(), so this requires a CUDA device and a
    # model already on GPU -- confirm before running on CPU-only machines.
    data = torch.zeros(2, 3, H, W).cuda()

    def should_measure(x):
        # Only leaf modules perform computation directly.
        return is_leaf(x)

    def modify_forward(model):
        # Recursively patch every leaf's forward with a measuring wrapper.
        for child in model.children():
            if should_measure(child):
                def new_forward(m):
                    # Factory binds `m` now, avoiding the late-binding closure pitfall.
                    def lambda_forward(x):
                        measure_layer(m, x)
                        return m.old_forward(x)
                    return lambda_forward
                child.old_forward = child.forward
                child.forward = new_forward(child)
            else:
                modify_forward(child)

    def restore_forward(model):
        # Undo the patching done by modify_forward.
        for child in model.children():
            # leaf node
            if is_leaf(child) and hasattr(child, 'old_forward'):
                child.forward = child.old_forward
                child.old_forward = None
            else:
                restore_forward(child)

    modify_forward(model)
    model.forward(data)
    restore_forward(model)

    return count_ops, count_params
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import os
import torch
class TextLogger(object):
    """File logger that flushes and fsyncs on every write so that log lines
    reach the disk immediately (useful for long-running searches)."""

    def __init__(self, filepath):
        self.f = open(filepath, 'w')
        self.fid = self.f.fileno()
        self.filepath = filepath

    def close(self):
        """Close the underlying file handle."""
        self.f.close()

    def write(self, content):
        """Write `content` and force it to disk right away."""
        self.f.write(content)
        self.f.flush()
        os.fsync(self.fid)

    def write_buf(self, content):
        """Buffered write: no flush, no fsync."""
        self.f.write(content)

    def print_and_write(self, content):
        """Echo `content` to stdout and append it (plus a newline) to the log."""
        print(content)
        self.write(content + '\n')
def to_numpy(var):
    """Convert a torch tensor to a numpy array, detaching it from the graph
    and moving it to host memory if necessary.

    `.cpu()` is a no-op for CPU tensors, so no cuda-availability branch is
    needed (the old branch also mis-keyed on global availability rather than
    the tensor's own device).
    """
    return var.detach().cpu().numpy()
def to_tensor(ndarray, requires_grad=False):
    """Wrap a numpy array in a float32 tensor, moved to GPU when available.

    Gradient tracking is off by default; pass requires_grad=True to enable it.
    """
    out = torch.from_numpy(ndarray).float()
    if requires_grad:
        out.requires_grad_()
    if torch.cuda.is_available():
        return out.cuda()
    return out
def measure_layer_for_pruning(wrapper, x):
    """Record the wrapped layer's cost on its wrapper.

    Sets `wrapper.flops` and `wrapper.params` for Conv2d and Linear modules;
    other layer types are left untouched.

    Parameters
    ----------
    wrapper : object
        module wrapper exposing the wrapped layer as `wrapper.module`
    x : torch.Tensor
        the input about to be fed to the layer (used for spatial sizes)
    """
    def get_layer_type(layer):
        # Class name from the repr, e.g. 'Conv2d(...)' -> 'Conv2d'.
        layer_str = str(layer)
        return layer_str[:layer_str.find('(')].strip()

    def get_layer_param(model):
        # Total number of parameter elements.
        import operator
        import functools
        return sum([functools.reduce(operator.mul, i.size(), 1) for i in model.parameters()])

    multi_add = 1  # count a multiply+add pair as one op
    layer = wrapper.module
    type_name = get_layer_type(layer)

    # ops_conv
    if type_name in ['Conv2d']:
        out_h = int((x.size()[2] + 2 * layer.padding[0] - layer.kernel_size[0]) /
                    layer.stride[0] + 1)
        out_w = int((x.size()[3] + 2 * layer.padding[1] - layer.kernel_size[1]) /
                    layer.stride[1] + 1)
        wrapper.flops = layer.in_channels * layer.out_channels * layer.kernel_size[0] * \
            layer.kernel_size[1] * out_h * out_w / layer.groups * multi_add
        wrapper.params = get_layer_param(layer)
    # ops_linear
    elif type_name in ['Linear']:
        weight_ops = layer.weight.numel() * multi_add
        # Layers built with bias=False would otherwise crash on None.numel().
        bias_ops = layer.bias.numel() if layer.bias is not None else 0
        wrapper.flops = weight_ops + bias_ops
        wrapper.params = get_layer_param(layer)
    return
def least_square_sklearn(X, Y):
    """Solve the least-squares problem min ||X w - Y|| with no intercept term.

    Returns the fitted coefficient matrix.
    """
    # Local import keeps sklearn an optional dependency of this module.
    from sklearn.linear_model import LinearRegression
    model = LinearRegression(fit_intercept=False)
    model.fit(X, Y)
    return model.coef_
def get_output_folder(parent_dir, env_name):
    """Return save folder.

    Assumes folders in the parent_dir have suffix -run{run
    number}. Finds the highest run number and sets the output folder
    to that number + 1. This is just convenient so that if you run the
    same script multiple times tensorboard can plot all of the results
    on the same plots with different names.

    Parameters
    ----------
    parent_dir: str
        Path of the directory containing all experiment runs.
    env_name: str
        Experiment name, used as the folder-name prefix.

    Returns
    -------
    parent_dir/run_dir
        Path to this run's save directory (created if necessary).
    """
    os.makedirs(parent_dir, exist_ok=True)
    experiment_id = 0
    for folder_name in os.listdir(parent_dir):
        if not os.path.isdir(os.path.join(parent_dir, folder_name)):
            continue
        try:
            folder_id = int(folder_name.split('-run')[-1])
        except ValueError:
            # Folder without a numeric -run suffix: ignore it.
            # (The old bare `except: pass` also hid real errors.)
            continue
        if folder_id > experiment_id:
            experiment_id = folder_id
    experiment_id += 1

    parent_dir = os.path.join(parent_dir, env_name)
    parent_dir = parent_dir + '-run{}'.format(experiment_id)
    os.makedirs(parent_dir, exist_ok=True)
    return parent_dir
# logging helpers: each prints its argument wrapped in an ANSI color code
def _make_color_printer(ansi_code):
    """Build a print function that wraps its argument in the given ANSI color."""
    def _printer(prt):
        print("\033[{}m {}\033[00m".format(ansi_code, prt))
    return _printer

prRed = _make_color_printer(91)
prGreen = _make_color_printer(92)
prYellow = _make_color_printer(93)
prLightPurple = _make_color_printer(94)
prPurple = _make_color_printer(95)
prCyan = _make_color_printer(96)
prLightGray = _make_color_printer(97)
prBlack = _make_color_printer(98)
...@@ -2,19 +2,40 @@ ...@@ -2,19 +2,40 @@
# Licensed under the MIT license. # Licensed under the MIT license.
import logging import logging
import math
import numpy as np
import torch import torch
from .weight_masker import WeightMasker from .weight_masker import WeightMasker
__all__ = ['L1FilterPrunerMasker', 'L2FilterPrunerMasker', 'FPGMPrunerMasker', \ __all__ = ['L1FilterPrunerMasker', 'L2FilterPrunerMasker', 'FPGMPrunerMasker', \
'TaylorFOWeightFilterPrunerMasker', 'ActivationAPoZRankFilterPrunerMasker', \ 'TaylorFOWeightFilterPrunerMasker', 'ActivationAPoZRankFilterPrunerMasker', \
'ActivationMeanRankFilterPrunerMasker', 'SlimPrunerMasker'] 'ActivationMeanRankFilterPrunerMasker', 'SlimPrunerMasker', 'AMCWeightMasker']
logger = logging.getLogger('torch filter pruners') logger = logging.getLogger('torch filter pruners')
class StructuredWeightMasker(WeightMasker): class StructuredWeightMasker(WeightMasker):
""" """
A structured pruning masker base class that prunes convolutional layer filters. A structured pruning masker base class that prunes convolutional layer filters.
Parameters
----------
model: nn.Module
model to be pruned
pruner: Pruner
A Pruner instance used to prune the model
preserve_round: int
after pruning, preserve filters/channels round to `preserve_round`, for example:
for a Conv2d layer, output channel is 32, sparsity is 0.2, if preserve_round is
1 (no preserve round), then there will be int(32 * 0.2) = 6 filters pruned, and
32 - 6 = 26 filters are preserved. If preserve_round is 4, preserved filters will
be round up to 28 (which can be divided by 4) and only 4 filters are pruned.
""" """
def __init__(self, model, pruner, preserve_round=1):
self.model = model
self.pruner = pruner
self.preserve_round = preserve_round
def calc_mask(self, sparsity, wrapper, wrapper_idx=None): def calc_mask(self, sparsity, wrapper, wrapper_idx=None):
""" """
Calculate the mask of given layer. Calculate the mask of given layer.
...@@ -53,9 +74,16 @@ class StructuredWeightMasker(WeightMasker): ...@@ -53,9 +74,16 @@ class StructuredWeightMasker(WeightMasker):
mask_bias = None mask_bias = None
mask = {'weight_mask': mask_weight, 'bias_mask': mask_bias} mask = {'weight_mask': mask_weight, 'bias_mask': mask_bias}
filters = weight.size(0) num_total = weight.size(0)
num_prune = int(filters * sparsity) num_prune = int(num_total * sparsity)
if filters < 2 or num_prune < 1: if self.preserve_round > 1:
num_preserve = num_total - num_prune
num_preserve = int(math.ceil(num_preserve * 1. / self.preserve_round) * self.preserve_round)
if num_preserve > num_total:
num_preserve = int(math.floor(num_total * 1. / self.preserve_round) * self.preserve_round)
num_prune = num_total - num_preserve
if num_total < 2 or num_prune < 1:
return mask return mask
# weight*mask_weight: apply base mask for iterative pruning # weight*mask_weight: apply base mask for iterative pruning
return self.get_mask(mask, weight*mask_weight, num_prune, wrapper, wrapper_idx) return self.get_mask(mask, weight*mask_weight, num_prune, wrapper, wrapper_idx)
...@@ -365,3 +393,135 @@ class SlimPrunerMasker(WeightMasker): ...@@ -365,3 +393,135 @@ class SlimPrunerMasker(WeightMasker):
mask_bias = mask_weight.clone() mask_bias = mask_weight.clone()
mask = {'weight_mask': mask_weight.detach(), 'bias_mask': mask_bias.detach()} mask = {'weight_mask': mask_weight.detach(), 'bias_mask': mask_bias.detach()}
return mask return mask
def least_square_sklearn(X, Y):
    """Least-squares fit of Y against X with no intercept term.

    Parameters
    ----------
    X : array-like of shape (n_samples, n_features)
        Input features.
    Y : array-like of shape (n_samples, n_targets)
        Regression targets.

    Returns
    -------
    ndarray
        Coefficients of the fitted linear model.
    """
    # Imported lazily so sklearn is only required when this path is used.
    from sklearn.linear_model import LinearRegression
    regressor = LinearRegression(fit_intercept=False)
    regressor.fit(X, Y)
    return regressor.coef_
class AMCWeightMasker(WeightMasker):
    """
    Weight masker class for AMC pruner. Currently, AMCPruner only supports pruning
    kernel size 1x1 pointwise Conv2d layers. Before using this class to prune
    kernels, AMCPruner collects input and output feature maps for each layer; the
    feature maps are flattened and saved into ``wrapper.input_feat`` and
    ``wrapper.output_feat``.

    Parameters
    ----------
    model : nn.Module
        model to be pruned
    pruner : Pruner
        A Pruner instance used to prune the model
    preserve_round : int
        after pruning, preserve filters/channels round to `preserve_round`, for example:
        for a Conv2d layer, output channel is 32, sparsity is 0.2, if preserve_round is
        1 (no preserve round), then there will be int(32 * 0.2) = 6 filters pruned, and
        32 - 6 = 26 filters are preserved. If preserve_round is 4, preserved filters will
        be round up to 28 (which can be divided by 4) and only 4 filters are pruned.
    """

    def __init__(self, model, pruner, preserve_round=1):
        self.model = model
        self.pruner = pruner
        self.preserve_round = preserve_round

    def calc_mask(self, sparsity, wrapper, wrapper_idx=None, preserve_idx=None):
        """
        Calculate the mask of the given layer.

        Parameters
        ----------
        sparsity : float
            pruning ratio, preserved weight ratio is `1 - sparsity`
        wrapper : PrunerModuleWrapper
            layer wrapper of this layer
        wrapper_idx : int
            index of this wrapper in pruner's all wrappers
        preserve_idx : array-like or None
            input-channel indices to preserve; when None they are chosen by
            input-channel importance (sum of absolute weights)

        Returns
        -------
        dict
            dictionary for storing masks, keys of the dict:
            'weight_mask': weight mask tensor
            'bias_mask': bias mask tensor (optional)
        """
        msg = 'module type {} is not supported!'.format(wrapper.type)
        assert wrapper.type in ['Conv2d', 'Linear'], msg
        weight = wrapper.module.weight.data
        bias = None
        if hasattr(wrapper.module, 'bias') and wrapper.module.bias is not None:
            bias = wrapper.module.bias.data
        if wrapper.weight_mask is None:
            mask_weight = torch.ones(weight.size()).type_as(weight).detach()
        else:
            mask_weight = wrapper.weight_mask.clone()
        if bias is not None:
            if wrapper.bias_mask is None:
                mask_bias = torch.ones(bias.size()).type_as(bias).detach()
            else:
                mask_bias = wrapper.bias_mask.clone()
        else:
            mask_bias = None
        mask = {'weight_mask': mask_weight, 'bias_mask': mask_bias}

        # AMC prunes input channels, hence dim 1 of the weight tensor.
        num_total = weight.size(1)
        num_prune = int(num_total * sparsity)
        # BUG FIX: num_preserve must be defined on every path -- previously it
        # was only assigned inside the `preserve_round > 1` branch, so the
        # default preserve_round == 1 raised NameError at the get_mask() call.
        num_preserve = num_total - num_prune
        if self.preserve_round > 1:
            # Round the number of preserved channels up to a multiple of
            # preserve_round, capped at the total channel count.
            num_preserve = int(math.ceil(num_preserve * 1. / self.preserve_round) * self.preserve_round)
            if num_preserve > num_total:
                num_preserve = num_total
            num_prune = num_total - num_preserve

        if (num_total < 2 or num_prune < 1) and preserve_idx is None:
            return mask

        return self.get_mask(mask, weight, num_preserve, wrapper, wrapper_idx, preserve_idx)

    def get_mask(self, base_mask, weight, num_preserve, wrapper, wrapper_idx, preserve_idx):
        """Select preserved input channels, reconstruct the weight by least
        squares on the collected feature maps, and build the resulting masks.

        Returns a dict with 'weight_mask' and 'bias_mask' entries.
        """
        w = weight.data.cpu().numpy()
        if wrapper.type == 'Linear':
            # Give Linear weights dummy spatial dims so Conv2d code applies.
            w = w[:, :, None, None]
        if preserve_idx is None:
            # Sum magnitude along C_out and the spatial dims, sort descending.
            importance = np.abs(w).sum((0, 2, 3))
            sorted_idx = np.argsort(-importance)
            d_prime = num_preserve
            preserve_idx = sorted_idx[:d_prime]  # indices to preserve
        else:
            d_prime = len(preserve_idx)
        assert len(preserve_idx) == d_prime
        mask = np.zeros(w.shape[1], bool)
        mask[preserve_idx] = True

        # Reconstruct the weights from the recorded features, X, Y <= [N, C].
        X, Y = wrapper.input_feat, wrapper.output_feat
        masked_X = X[:, mask]
        if w.shape[2] == 1:  # 1x1 conv or fc
            rec_weight = least_square_sklearn(X=masked_X, Y=Y)
            rec_weight = rec_weight.reshape(-1, 1, 1, d_prime)  # (C_out, K_h, K_w, C_in')
            rec_weight = np.transpose(rec_weight, (0, 3, 1, 2))  # (C_out, C_in', K_h, K_w)
        else:
            raise NotImplementedError('Current code only supports 1x1 conv now!')
        rec_weight_pad = np.zeros_like(w)
        # pylint: disable=all
        rec_weight_pad[:, mask, :, :] = rec_weight
        rec_weight = rec_weight_pad

        if wrapper.type == 'Linear':
            rec_weight = rec_weight.squeeze()
            assert len(rec_weight.shape) == 2

        # Assign the reconstructed weights back into the wrapped module.
        wrapper.module.weight.data = torch.from_numpy(rec_weight).to(weight.device)
        mask_weight = torch.zeros_like(weight)
        # BUG FIX: initialize mask_bias so it is defined even in the Linear
        # branch when base_mask['bias_mask'] is None (previously NameError).
        mask_bias = None
        if wrapper.type == 'Linear':
            mask_weight[:, preserve_idx] = 1.
            if base_mask['bias_mask'] is not None and wrapper.module.bias is not None:
                mask_bias = torch.ones_like(wrapper.module.bias)
        else:
            mask_weight[:, preserve_idx, :, :] = 1.
        return {'weight_mask': mask_weight.detach(), 'bias_mask': mask_bias}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment