"docs/source/git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "f6ec539457dad41c30789c9941c5905dee2ee5c0"
Unverified Commit e862d39e authored by chicm-ms's avatar chicm-ms Committed by GitHub
Browse files

Refactor pruner examples (#2099)

# Quick Start to Compress a Model
NNI provides very simple APIs for compressing a model. The compression includes pruning algorithms and quantization algorithms, and they are used in the same way, so here we use the slim pruner as an example to show the usage.
## Write configuration
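The snippet below is a minimal sketch of this step, reusing the slim-pruner settings that appear later in these examples (70% sparsity on `BatchNorm2d` layers); treat the exact values as illustrative rather than prescriptive.

```python
from nni.compression.torch import SlimPruner

# Prune 70% of the channels, selected via the BatchNorm scaling factors.
config_list = [{
    'sparsity': 0.7,
    'op_types': ['BatchNorm2d']
}]

# Wrap the model with the pruner; compress() computes the masks and returns the wrapped model.
pruner = SlimPruner(model, config_list)
model = pruner.compress()
```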
After training, you get the accuracy of the pruned model. You can export the model weights and the generated masks to files as shown below:
```python
pruner.export_model(model_path='pruned_vgg19_cifar10.pth', mask_path='mask_vgg19_cifar10.pth')
```
The complete code of model compression examples can be found [here](https://github.com/microsoft/nni/blob/master/examples/model_compress/model_prune_torch.py).
## Speed up the model
Masks do not provide a real speedup of your model. The model should be sped up based on the exported masks, so we provide an API to speed up your model as shown below. After invoking `apply_compression_results` on your model, it becomes smaller and its inference latency shorter.
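A minimal usage sketch follows; the import path and exact signature of `apply_compression_results` are assumptions here, inferred from how the other compression utilities in these examples are imported from `nni.compression.torch`.

```python
# Assumed import path for the speed-up helper mentioned above.
from nni.compression.torch import apply_compression_results

# Shrink the model according to the masks exported by pruner.export_model().
apply_compression_results(model, 'mask_vgg19_cifar10.pth')
```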
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from nni.compression.torch import ActivationAPoZRankFilterPruner
from models.cifar10.vgg import VGG
def train(model, device, train_loader, optimizer):
model.train()
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
loss = F.cross_entropy(output, target)
loss.backward()
optimizer.step()
if batch_idx % 100 == 0:
print('{:2.0f}% Loss {}'.format(100 * batch_idx / len(train_loader), loss.item()))
def test(model, device, test_loader):
model.eval()
test_loss = 0
correct = 0
with torch.no_grad():
for data, target in test_loader:
data, target = data.to(device), target.to(device)
output = model(data)
            test_loss += F.cross_entropy(output, target, reduction='sum').item()  # VGG outputs raw logits, so use cross-entropy for the test loss
pred = output.argmax(dim=1, keepdim=True)
correct += pred.eq(target.view_as(pred)).sum().item()
test_loss /= len(test_loader.dataset)
acc = 100 * correct / len(test_loader.dataset)
    print('Loss: {} Accuracy: {}%\n'.format(test_loss, acc))
return acc
def main():
torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
train_loader = torch.utils.data.DataLoader(
datasets.CIFAR10('./data.cifar10', train=True, download=True,
transform=transforms.Compose([
transforms.Pad(4),
transforms.RandomCrop(32),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])),
batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(
datasets.CIFAR10('./data.cifar10', train=False, transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])),
batch_size=200, shuffle=False)
model = VGG(depth=16)
model.to(device)
# Train the base VGG-16 model
print('=' * 10 + 'Train the unpruned base model' + '=' * 10)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 160, 0)
for epoch in range(160):
train(model, device, train_loader, optimizer)
test(model, device, test_loader)
lr_scheduler.step(epoch)
torch.save(model.state_dict(), 'vgg16_cifar10.pth')
# Test base model accuracy
print('=' * 10 + 'Test on the original model' + '=' * 10)
model.load_state_dict(torch.load('vgg16_cifar10.pth'))
test(model, device, test_loader)
# top1 = 93.51%
# Pruning Configuration, in paper 'PRUNING FILTERS FOR EFFICIENT CONVNETS',
# Conv_1, Conv_8, Conv_9, Conv_10, Conv_11, Conv_12 are pruned with 50% sparsity, as 'VGG-16-pruned-A'
configure_list = [{
'sparsity': 0.5,
'op_types': ['default'],
'op_names': ['feature.0', 'feature.24', 'feature.27', 'feature.30', 'feature.34', 'feature.37']
}]
# Prune model and test accuracy without fine tuning.
print('=' * 10 + 'Test on the pruned model before fine tune' + '=' * 10)
pruner = ActivationAPoZRankFilterPruner(model, configure_list)
model = pruner.compress()
test(model, device, test_loader)
# top1 = 88.19%
# Fine tune the pruned model for 40 epochs and test accuracy
print('=' * 10 + 'Fine tuning' + '=' * 10)
optimizer_finetune = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-4)
best_top1 = 0
for epoch in range(40):
pruner.update_epoch(epoch)
print('# Epoch {} #'.format(epoch))
train(model, device, train_loader, optimizer_finetune)
top1 = test(model, device, test_loader)
if top1 > best_top1:
best_top1 = top1
# Export the best model, 'model_path' stores state_dict of the pruned model,
# mask_path stores mask_dict of the pruned model
pruner.export_model(model_path='pruned_vgg16_cifar10.pth', mask_path='mask_vgg16_cifar10.pth')
# Test the exported model
print('=' * 10 + 'Test on the pruned model after fine tune' + '=' * 10)
new_model = VGG(depth=16)
new_model.to(device)
new_model.load_state_dict(torch.load('pruned_vgg16_cifar10.pth'))
test(new_model, device, test_loader)
# top1 = 93.53%
if __name__ == '__main__':
main()
import math
import os
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from nni.compression.torch import ActivationMeanRankFilterPruner
from models.cifar10.vgg import VGG
def train(model, device, train_loader, optimizer):
model.train()
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
loss = F.cross_entropy(output, target)
loss.backward()
optimizer.step()
if batch_idx % 100 == 0:
print('{:2.0f}% Loss {}'.format(100 * batch_idx / len(train_loader), loss.item()))
def test(model, device, test_loader):
model.eval()
test_loss = 0
correct = 0
with torch.no_grad():
for data, target in test_loader:
data, target = data.to(device), target.to(device)
output = model(data)
            test_loss += F.cross_entropy(output, target, reduction='sum').item()  # VGG outputs raw logits, so use cross-entropy for the test loss
pred = output.argmax(dim=1, keepdim=True)
correct += pred.eq(target.view_as(pred)).sum().item()
test_loss /= len(test_loader.dataset)
acc = 100 * correct / len(test_loader.dataset)
    print('Loss: {} Accuracy: {}%\n'.format(test_loss, acc))
return acc
def main():
parser = argparse.ArgumentParser("multiple gpu with pruning")
parser.add_argument("--epochs", type=int, default=160)
parser.add_argument("--retrain", default=False, action="store_true")
parser.add_argument("--parallel", default=False, action="store_true")
args = parser.parse_args()
torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
train_loader = torch.utils.data.DataLoader(
datasets.CIFAR10('./data.cifar10', train=True, download=True,
transform=transforms.Compose([
transforms.Pad(4),
transforms.RandomCrop(32),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])),
batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(
datasets.CIFAR10('./data.cifar10', train=False, transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])),
batch_size=200, shuffle=False)
model = VGG(depth=16)
model.to(device)
# Train the base VGG-16 model
if args.retrain:
print('=' * 10 + 'Train the unpruned base model' + '=' * 10)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs, 0)
for epoch in range(args.epochs):
train(model, device, train_loader, optimizer)
test(model, device, test_loader)
lr_scheduler.step(epoch)
torch.save(model.state_dict(), 'vgg16_cifar10.pth')
else:
assert os.path.isfile('vgg16_cifar10.pth'), "can not find checkpoint 'vgg16_cifar10.pth'"
model.load_state_dict(torch.load('vgg16_cifar10.pth'))
# Test base model accuracy
print('=' * 10 + 'Test on the original model' + '=' * 10)
test(model, device, test_loader)
# top1 = 93.51%
# Pruning Configuration, in paper 'PRUNING FILTERS FOR EFFICIENT CONVNETS',
# Conv_1, Conv_8, Conv_9, Conv_10, Conv_11, Conv_12 are pruned with 50% sparsity, as 'VGG-16-pruned-A'
configure_list = [{
'sparsity': 0.5,
'op_types': ['default'],
'op_names': ['feature.0', 'feature.24', 'feature.27', 'feature.30', 'feature.34', 'feature.37']
}]
# Prune model and test accuracy without fine tuning.
print('=' * 10 + 'Test on the pruned model before fine tune' + '=' * 10)
pruner = ActivationMeanRankFilterPruner(model, configure_list)
model = pruner.compress()
if args.parallel:
if torch.cuda.device_count() > 1:
print("use {} gpus for pruning".format(torch.cuda.device_count()))
model = nn.DataParallel(model)
else:
print("only detect 1 gpu, fall back")
model.to(device)
test(model, device, test_loader)
# top1 = 88.19%
# Fine tune the pruned model for 40 epochs and test accuracy
print('=' * 10 + 'Fine tuning' + '=' * 10)
optimizer_finetune = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-4)
best_top1 = 0
for epoch in range(40):
pruner.update_epoch(epoch)
print('# Epoch {} #'.format(epoch))
train(model, device, train_loader, optimizer_finetune)
top1 = test(model, device, test_loader)
if top1 > best_top1:
best_top1 = top1
# Export the best model, 'model_path' stores state_dict of the pruned model,
# mask_path stores mask_dict of the pruned model
pruner.export_model(model_path='pruned_vgg16_cifar10.pth', mask_path='mask_vgg16_cifar10.pth')
# Test the exported model
print('=' * 10 + 'Test on the pruned model after fine tune' + '=' * 10)
new_model = VGG(depth=16)
new_model.to(device)
new_model.load_state_dict(torch.load('pruned_vgg16_cifar10.pth'))
test(new_model, device, test_loader)
# top1 = 93.53%
if __name__ == '__main__':
main()
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from nni.compression.torch import FPGMPruner
class Mnist(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(1, 20, 5, 1)
self.conv2 = nn.Conv2d(20, 50, 5, 1)
self.fc1 = nn.Linear(4 * 4 * 50, 500)
self.fc2 = nn.Linear(500, 10)
def forward(self, x):
x = F.relu(self.conv1(x))
x = F.max_pool2d(x, 2, 2)
x = F.relu(self.conv2(x))
x = F.max_pool2d(x, 2, 2)
x = x.view(x.size(0), -1)
x = F.relu(self.fc1(x))
x = self.fc2(x)
return F.log_softmax(x, dim=1)
def _get_conv_weight_sparsity(self, conv_layer):
num_zero_filters = (conv_layer.weight.data.sum((1, 2, 3)) == 0).sum()
num_filters = conv_layer.weight.data.size(0)
return num_zero_filters, num_filters, float(num_zero_filters)/num_filters
def print_conv_filter_sparsity(self):
if isinstance(self.conv1, nn.Conv2d):
conv1_data = self._get_conv_weight_sparsity(self.conv1)
conv2_data = self._get_conv_weight_sparsity(self.conv2)
else:
# self.conv1 is wrapped as PrunerModuleWrapper
conv1_data = self._get_conv_weight_sparsity(self.conv1.module)
conv2_data = self._get_conv_weight_sparsity(self.conv2.module)
print('conv1: num zero filters: {}, num filters: {}, sparsity: {:.4f}'.format(conv1_data[0], conv1_data[1], conv1_data[2]))
print('conv2: num zero filters: {}, num filters: {}, sparsity: {:.4f}'.format(conv2_data[0], conv2_data[1], conv2_data[2]))
def train(model, device, train_loader, optimizer):
model.train()
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
loss = F.nll_loss(output, target)
if batch_idx % 100 == 0:
print('{:.2f}% Loss {:.4f}'.format(100 * batch_idx / len(train_loader), loss.item()))
if batch_idx == 0:
model.print_conv_filter_sparsity()
loss.backward()
optimizer.step()
def test(model, device, test_loader):
model.eval()
test_loss = 0
correct = 0
with torch.no_grad():
for data, target in test_loader:
data, target = data.to(device), target.to(device)
output = model(data)
test_loss += F.nll_loss(output, target, reduction='sum').item()
pred = output.argmax(dim=1, keepdim=True)
correct += pred.eq(target.view_as(pred)).sum().item()
test_loss /= len(test_loader.dataset)
    print('Loss: {:.4f} Accuracy: {}%\n'.format(
        test_loss, 100 * correct / len(test_loader.dataset)))
def main():
torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_loader = torch.utils.data.DataLoader(
datasets.MNIST('data', train=True, download=True, transform=trans),
batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(
datasets.MNIST('data', train=False, transform=trans),
batch_size=1000, shuffle=True)
model = Mnist()
model.to(device)
model.print_conv_filter_sparsity()
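    # FPGM pruning configuration: prune 50% of the filters in every Conv2d layer.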
configure_list = [{
'sparsity': 0.5,
'op_types': ['Conv2d']
}]
pruner = FPGMPruner(model, configure_list)
pruner.compress()
model.to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
for epoch in range(10):
pruner.update_epoch(epoch)
print('# Epoch {} #'.format(epoch))
train(model, device, train_loader, optimizer)
test(model, device, test_loader)
pruner.export_model('model.pth', 'mask.pth')
if __name__ == '__main__':
main()
from nni.compression.torch import AGP_Pruner
import torch
import torch.nn.functional as F
from torchvision import datasets, transforms
class Mnist(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv1 = torch.nn.Conv2d(1, 20, 5, 1)
self.conv2 = torch.nn.Conv2d(20, 50, 5, 1)
self.fc1 = torch.nn.Linear(4 * 4 * 50, 500)
self.fc2 = torch.nn.Linear(500, 10)
def forward(self, x):
x = F.relu(self.conv1(x))
x = F.max_pool2d(x, 2, 2)
x = F.relu(self.conv2(x))
x = F.max_pool2d(x, 2, 2)
x = x.view(-1, 4 * 4 * 50)
x = F.relu(self.fc1(x))
x = self.fc2(x)
return F.log_softmax(x, dim=1)
def train(model, device, train_loader, optimizer):
model.train()
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
loss = F.nll_loss(output, target)
loss.backward()
optimizer.step()
if batch_idx % 100 == 0:
print('{:2.0f}% Loss {}'.format(100 * batch_idx / len(train_loader), loss.item()))
def test(model, device, test_loader):
model.eval()
test_loss = 0
correct = 0
with torch.no_grad():
for data, target in test_loader:
data, target = data.to(device), target.to(device)
output = model(data)
test_loss += F.nll_loss(output, target, reduction='sum').item()
pred = output.argmax(dim=1, keepdim=True)
correct += pred.eq(target.view_as(pred)).sum().item()
test_loss /= len(test_loader.dataset)
    print('Loss: {} Accuracy: {}%\n'.format(
        test_loss, 100 * correct / len(test_loader.dataset)))
def main():
torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_loader = torch.utils.data.DataLoader(
datasets.MNIST('data', train=True, download=True, transform=trans),
batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(
datasets.MNIST('data', train=False, transform=trans),
batch_size=1000, shuffle=True)
model = Mnist()
model = model.to(device)
    '''You can use LevelPruner instead, for example:
    pruner = LevelPruner(model, configure_list)
    '''
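    # AGP schedule: sparsity grows from 0 to 0.8 between epoch 0 and epoch 10, with masks updated every epoch.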
configure_list = [{
'initial_sparsity': 0,
'final_sparsity': 0.8,
'start_epoch': 0,
'end_epoch': 10,
'frequency': 1,
'op_types': ['default']
}]
pruner = AGP_Pruner(model, configure_list)
model = pruner.compress()
model = model.to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
for epoch in range(10):
pruner.update_epoch(epoch)
print('# Epoch {} #'.format(epoch))
train(model, device, train_loader, optimizer)
test(model, device, test_loader)
pruner.export_model('model.pth', 'mask.pth', 'model.onnx', [1, 1, 28, 28], device)
if __name__ == '__main__':
main()
import os
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from models.cifar10.vgg import VGG
import nni
from nni.compression.torch import LevelPruner, SlimPruner, FPGMPruner, L1FilterPruner, \
L2FilterPruner, AGP_Pruner, ActivationMeanRankFilterPruner, ActivationAPoZRankFilterPruner
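# Map each pruner name to the dataset, model, and pruner configuration used by this example script.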
prune_config = {
'level': {
'dataset_name': 'mnist',
'model_name': 'naive',
'pruner_class': LevelPruner,
'config_list': [{
'sparsity': 0.5,
'op_types': ['default'],
}]
},
'agp': {
'dataset_name': 'mnist',
'model_name': 'naive',
'pruner_class': AGP_Pruner,
'config_list': [{
'initial_sparsity': 0,
'final_sparsity': 0.8,
'start_epoch': 0,
'end_epoch': 10,
'frequency': 1,
'op_types': ['default']
}]
},
'slim': {
'dataset_name': 'cifar10',
'model_name': 'vgg19',
'pruner_class': SlimPruner,
'config_list': [{
'sparsity': 0.7,
'op_types': ['BatchNorm2d']
}]
},
'fpgm': {
'dataset_name': 'mnist',
'model_name': 'naive',
'pruner_class': FPGMPruner,
'config_list':[{
'sparsity': 0.5,
'op_types': ['Conv2d']
}]
},
'l1': {
'dataset_name': 'cifar10',
'model_name': 'vgg16',
'pruner_class': L1FilterPruner,
'config_list': [{
'sparsity': 0.5,
'op_types': ['default'],
'op_names': ['feature.0', 'feature.24', 'feature.27', 'feature.30', 'feature.34', 'feature.37']
}]
},
'mean_activation': {
'dataset_name': 'cifar10',
'model_name': 'vgg16',
'pruner_class': ActivationMeanRankFilterPruner,
        'config_list': [{
'sparsity': 0.5,
'op_types': ['default'],
'op_names': ['feature.0', 'feature.24', 'feature.27', 'feature.30', 'feature.34', 'feature.37']
}]
},
'apoz': {
'dataset_name': 'cifar10',
'model_name': 'vgg16',
'pruner_class': ActivationAPoZRankFilterPruner,
'config_list': [{
'sparsity': 0.5,
'op_types': ['default'],
'op_names': ['feature.0', 'feature.24', 'feature.27', 'feature.30', 'feature.34', 'feature.37']
}]
}
}
def get_data_loaders(dataset_name='mnist', batch_size=128):
assert dataset_name in ['cifar10', 'mnist']
if dataset_name == 'cifar10':
        ds_class = datasets.CIFAR10
MEAN, STD = (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)
else:
ds_class = datasets.MNIST
MEAN, STD = (0.1307,), (0.3081,)
train_loader = DataLoader(
ds_class(
'./data', train=True, download=True,
transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize(MEAN, STD)])
),
batch_size=batch_size, shuffle=True
)
test_loader = DataLoader(
ds_class(
'./data', train=False, download=True,
transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize(MEAN, STD)])
),
batch_size=batch_size, shuffle=False
)
return train_loader, test_loader
class NaiveModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(1, 20, 5, 1)
self.conv2 = nn.Conv2d(20, 50, 5, 1)
self.bn1 = nn.BatchNorm2d(self.conv1.out_channels)
self.bn2 = nn.BatchNorm2d(self.conv2.out_channels)
self.fc1 = nn.Linear(4 * 4 * 50, 500)
self.fc2 = nn.Linear(500, 10)
def forward(self, x):
x = F.relu(self.bn1(self.conv1(x)))
x = F.max_pool2d(x, 2, 2)
x = F.relu(self.bn2(self.conv2(x)))
x = F.max_pool2d(x, 2, 2)
x = x.view(-1, 4 * 4 * 50)
x = F.relu(self.fc1(x))
x = self.fc2(x)
return x
def create_model(model_name='naive'):
assert model_name in ['naive', 'vgg16', 'vgg19']
if model_name == 'naive':
return NaiveModel()
elif model_name == 'vgg16':
return VGG(16)
else:
return VGG(19)
def create_pruner(model, pruner_name):
pruner_class = prune_config[pruner_name]['pruner_class']
config_list = prune_config[pruner_name]['config_list']
return pruner_class(model, config_list)
def train(model, device, train_loader, optimizer):
model.train()
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
loss = F.cross_entropy(output, target)
loss.backward()
optimizer.step()
if batch_idx % 100 == 0:
print('{:2.0f}% Loss {}'.format(100 * batch_idx / len(train_loader), loss.item()))
def test(model, device, test_loader):
model.eval()
test_loss = 0
correct = 0
with torch.no_grad():
for data, target in test_loader:
data, target = data.to(device), target.to(device)
output = model(data)
test_loss += F.cross_entropy(output, target, reduction='sum').item()
pred = output.argmax(dim=1, keepdim=True)
correct += pred.eq(target.view_as(pred)).sum().item()
test_loss /= len(test_loader.dataset)
acc = 100 * correct / len(test_loader.dataset)
    print('Loss: {} Accuracy: {}%\n'.format(test_loss, acc))
return acc
def main(args):
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model_name = prune_config[args.pruner_name]['model_name']
dataset_name = prune_config[args.pruner_name]['dataset_name']
train_loader, test_loader = get_data_loaders(dataset_name, args.batch_size)
    model = create_model(model_name).to(device)
if args.resume_from is not None and os.path.exists(args.resume_from):
print('loading checkpoint {} ...'.format(args.resume_from))
model.load_state_dict(torch.load(args.resume_from))
test(model, device, test_loader)
else:
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)
        if args.multi_gpu and torch.cuda.device_count() > 1:
model = nn.DataParallel(model)
print('start training')
pretrain_model_path = os.path.join(
args.checkpoints_dir, 'pretrain_{}_{}_{}.pth'.format(model_name, dataset_name, args.pruner_name))
for epoch in range(args.pretrain_epochs):
train(model, device, train_loader, optimizer)
test(model, device, test_loader)
torch.save(model.state_dict(), pretrain_model_path)
print('start model pruning...')
if not os.path.exists(args.checkpoints_dir):
os.makedirs(args.checkpoints_dir)
model_path = os.path.join(args.checkpoints_dir, 'pruned_{}_{}_{}.pth'.format(model_name, dataset_name, args.pruner_name))
mask_path = os.path.join(args.checkpoints_dir, 'mask_{}_{}_{}.pth'.format(model_name, dataset_name, args.pruner_name))
# pruner needs to be initialized from a model not wrapped by DataParallel
if isinstance(model, nn.DataParallel):
model = model.module
optimizer_finetune = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-4)
best_top1 = 0
pruner = create_pruner(model, args.pruner_name)
model = pruner.compress()
if args.multi_gpu and torch.cuda.device_count() > 1:
model = nn.DataParallel(model)
for epoch in range(args.prune_epochs):
pruner.update_epoch(epoch)
print('# Epoch {} #'.format(epoch))
train(model, device, train_loader, optimizer_finetune)
top1 = test(model, device, test_loader)
if top1 > best_top1:
best_top1 = top1
# Export the best model, 'model_path' stores state_dict of the pruned model,
# mask_path stores mask_dict of the pruned model
pruner.export_model(model_path=model_path, mask_path=mask_path)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("--pruner_name", type=str, default="level", help="pruner name")
parser.add_argument("--batch_size", type=int, default=256)
parser.add_argument("--pretrain_epochs", type=int, default=10, help="training epochs before model pruning")
parser.add_argument("--prune_epochs", type=int, default=10, help="training epochs for model pruning")
parser.add_argument("--checkpoints_dir", type=str, default="./checkpoints", help="checkpoints directory")
parser.add_argument("--resume_from", type=str, default=None, help="pretrained model weights")
parser.add_argument("--multi_gpu", action="store_true", help="Use multiple GPUs for training")
args = parser.parse_args()
main(args)
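# Example invocation of this script (flag values are illustrative only):
#   python examples/model_compress/model_prune_torch.py --pruner_name slim --pretrain_epochs 10 --prune_epochs 10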
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from nni.compression.torch import SlimPruner
class fc1(nn.Module):
def __init__(self, num_classes=10):
super(fc1, self).__init__()
self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
self.bn1 = nn.BatchNorm2d(32)
self.relu1 = nn.ReLU(inplace=True)
self.linear1 = nn.Linear(32*28*28, 300)
self.relu2 = nn.ReLU(inplace=True)
self.linear2 = nn.Linear(300, 100)
self.relu3 = nn.ReLU(inplace=True)
self.linear3 = nn.Linear(100, num_classes)
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu1(x)
x = torch.flatten(x,1)
x = self.linear1(x)
x = self.relu2(x)
x = self.linear2(x)
x = self.relu3(x)
x = self.linear3(x)
return x
def train(model, train_loader, optimizer, criterion, device):
model.train()
for imgs, targets in train_loader:
optimizer.zero_grad()
imgs, targets = imgs.to(device), targets.to(device)
output = model(imgs)
train_loss = criterion(output, targets)
train_loss.backward()
optimizer.step()
return train_loss.item()
def test(model, test_loader, criterion, device):
model.eval()
test_loss = 0
correct = 0
with torch.no_grad():
for data, target in test_loader:
data, target = data.to(device), target.to(device)
output = model(data)
            test_loss += F.cross_entropy(output, target, reduction='sum').item()  # sum up batch loss; fc1 outputs raw logits
pred = output.data.max(1, keepdim=True)[1] # get the index of the max log-probability
correct += pred.eq(target.data.view_as(pred)).sum().item()
test_loss /= len(test_loader.dataset)
accuracy = 100. * correct / len(test_loader.dataset)
return accuracy
if __name__ == '__main__':
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
traindataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
testdataset = datasets.MNIST('./data', train=False, transform=transform)
train_loader = torch.utils.data.DataLoader(traindataset, batch_size=60, shuffle=True, num_workers=10, drop_last=False)
    test_loader = torch.utils.data.DataLoader(testdataset, batch_size=60, shuffle=False, num_workers=10, drop_last=False)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = fc1()
criterion = nn.CrossEntropyLoss()
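    # Slim pruning configuration: prune 86% of the channels, selected via the BatchNorm scaling factors.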
    configure_list = [{
        'sparsity': 0.86,
        'op_types': ['BatchNorm2d']
    }]
pruner = SlimPruner(model, configure_list)
pruner.compress()
if torch.cuda.device_count()>1:
model = nn.DataParallel(model)
model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1.2e-3)
for name, par in model.named_parameters():
print(name)
# for i in pruner.get_prune_iterations():
# pruner.prune_iteration_start()
loss = 0
accuracy = 0
for epoch in range(10):
loss = train(model, train_loader, optimizer, criterion, device)
accuracy = test(model, test_loader, criterion, device)
print('current epoch: {0}, loss: {1}, accuracy: {2}'.format(epoch, loss, accuracy))
# print('prune iteration: {0}, loss: {1}, accuracy: {2}'.format(i, loss, accuracy))
pruner.export_model('model.pth', 'mask.pth')