"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "d6f4774c1c66a7e72951ab60e2241aff14e5d688"
Commit e60e1838 authored by chicm-ms, committed by GitHub

Update lottery ticket example (#2559)

parent b82bad0f
import argparse
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
@@ -53,6 +55,31 @@ def test(model, test_loader, criterion):
if __name__ == '__main__':
"""
THE LOTTERY TICKET HYPOTHESIS: FINDING SPARSE, TRAINABLE NEURAL NETWORKS (https://arxiv.org/pdf/1803.03635.pdf)
The Lottery Ticket Hypothesis. A randomly-initialized, dense neural network contains a subnetwork that is
initialized such that—when trained in isolation—it can match the test accuracy of the original network after
training for at most the same number of iterations.
Identifying winning tickets. We identify a winning ticket by training a network and pruning its
smallest-magnitude weights. The remaining, unpruned connections constitute the architecture of the
winning ticket. Unique to our work, each unpruned connection’s value is then reset to its initialization
from original network before it was trained. This forms our central experiment:
1. Randomly initialize a neural network f(x; θ0) (where θ0 ∼ Dθ).
2. Train the network for j iterations, arriving at parameters θj .
3. Prune p% of the parameters in θj , creating a mask m.
4. Reset the remaining parameters to their values in θ0, creating the winning ticket f(x; m θ0).
As described, this pruning approach is one-shot: the network is trained once, p% of weights are
pruned, and the surviving weights are reset. However, in this paper, we focus on iterative pruning,
which repeatedly trains, prunes, and resets the network over n rounds; each round prunes p**(1/n) % of
the weights that survive the previous round. Our results show that iterative pruning finds winning tickets
that match the accuracy of the original network at smaller sizes than does one-shot pruning.
"""
    parser = argparse.ArgumentParser()
    parser.add_argument("--train_epochs", type=int, default=10, help="training epochs")
    args = parser.parse_args()

    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
    traindataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
    testdataset = datasets.MNIST('./data', train=False, transform=transform)
@@ -63,6 +90,20 @@ if __name__ == '__main__':
    optimizer = torch.optim.Adam(model.parameters(), lr=1.2e-3)
    criterion = nn.CrossEntropyLoss()

    # Record the randomly initialized model weights
    orig_state = copy.deepcopy(model.state_dict())

    # train the model to get unpruned metrics
    for epoch in range(args.train_epochs):
        train(model, train_loader, optimizer, criterion)
    orig_accuracy = test(model, test_loader, criterion)
    print('unpruned model accuracy: {}'.format(orig_accuracy))

    # reset model weights and optimizer for pruning
    model.load_state_dict(orig_state)
    optimizer = torch.optim.Adam(model.parameters(), lr=1.2e-3)

    # Prune the model to find a winning ticket
    configure_list = [{
        'prune_iterations': 5,
        'sparsity': 0.96,
@@ -71,14 +112,29 @@ if __name__ == '__main__':
    pruner = LotteryTicketPruner(model, configure_list, optimizer)
    pruner.compress()

    best_accuracy = 0.
    best_state_dict = None

    for i in pruner.get_prune_iterations():
        pruner.prune_iteration_start()
        loss = 0
        accuracy = 0
        for epoch in range(args.train_epochs):
            loss = train(model, train_loader, optimizer, criterion)
            accuracy = test(model, test_loader, criterion)
            print('current epoch: {0}, loss: {1}, accuracy: {2}'.format(epoch, loss, accuracy))
        if accuracy > best_accuracy:
            best_accuracy = accuracy
            # state dict of weights and masks
            best_state_dict = copy.deepcopy(model.state_dict())
        print('prune iteration: {0}, loss: {1}, accuracy: {2}'.format(i, loss, accuracy))
    pruner.export_model('model.pth', 'mask.pth')
    if best_accuracy > orig_accuracy:
        # load the best weights and masks found during iterative pruning
        pruner.bound_model.load_state_dict(best_state_dict)
        # reset weights to the original untrained model, keeping masks unchanged, to export the winning ticket
        pruner.load_model_state_dict(orig_state)
        pruner.export_model('model_winning_ticket.pth', 'mask_winning_ticket.pth')
        print('winning ticket has been saved: model_winning_ticket.pth, mask_winning_ticket.pth')
    else:
        print('winning ticket is not found in this run, you can run it again.')
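The four-step experiment quoted in the example's docstring can also be reproduced without NNI. Below is a minimal plain-PyTorch sketch of one pruning round: train, build a mask from the smallest-magnitude weights, then reset the surviving weights to their initial values θ0. The helper names (magnitude_masks, reset_to_ticket) are hypothetical illustrations, not part of NNI's API, and `model`/`orig_state` are assumed to be defined as in this example.

# Sketch: one lottery-ticket round in plain PyTorch (hypothetical helpers,
# not NNI's implementation).
import torch

def magnitude_masks(model, sparsity):
    # Mask out the smallest-magnitude `sparsity` fraction of each weight tensor.
    masks = {}
    for name, param in model.named_parameters():
        if name.endswith('weight'):
            k = max(int(param.numel() * sparsity), 1)
            threshold = param.detach().abs().flatten().kthvalue(k).values
            masks[name] = (param.detach().abs() > threshold).float()
    return masks

def reset_to_ticket(model, orig_state, masks):
    # Step 4 of the experiment: surviving weights go back to theta_0, the rest to zero.
    model.load_state_dict(orig_state)
    with torch.no_grad():
        for name, param in model.named_parameters():
            if name in masks:
                param.mul_(masks[name])

After reset_to_ticket, retraining the masked model (re-applying the masks after each optimizer step) tests whether the ticket actually wins.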
@@ -83,12 +83,12 @@ class LotteryTicketPruner(Pruner):
        return max(1 - curr_keep_ratio, 0)

    def _calc_mask(self, wrapper, sparsity):
        weight = wrapper.module.weight.data
        if self.curr_prune_iteration == 0:
            mask = {'weight_mask': torch.ones(weight.shape).type_as(weight)}
        else:
            curr_sparsity = self._calc_sparsity(sparsity)
            mask = self.masker.calc_mask(sparsity=curr_sparsity, wrapper=wrapper)
        return mask

    def calc_mask(self, wrapper, **kwargs):
...
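The `return max(1 - curr_keep_ratio, 0)` line above encodes the paper's p^(1/n) schedule: each round keeps a constant fraction of the previously surviving weights, so the configured sparsity is reached after the last round. A small sketch of that arithmetic, assuming the keep ratio at round i of n is (1 - final_sparsity) ** (i / n); this form is inferred from the snippet and the paper, not copied from NNI's source.

# Per-round sparsity for iterative pruning (inferred schedule, not NNI's exact code).
def curr_sparsity(final_sparsity, i, n):
    curr_keep_ratio = (1 - final_sparsity) ** (i / n)
    return max(1 - curr_keep_ratio, 0)

# With sparsity=0.96 and prune_iterations=5 as in this example's configure_list:
for i in range(1, 6):
    print(i, round(curr_sparsity(0.96, i, 5), 4))
# prints roughly: 1 0.4747, 2 0.7241, 3 0.855, 4 0.9239, 5 0.96

With the example's settings, each round prunes roughly 47% of the weights that survived the previous round.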
@@ -28,8 +28,8 @@ python3 model_prune_torch.py --pruner_name agp --pretrain_epochs 1 --prune_epoch
echo 'testing mean_activation pruning'
python3 model_prune_torch.py --pruner_name mean_activation --pretrain_epochs 1 --prune_epochs 1
echo "testing lottery ticket pruning..."
python3 lottery_torch_mnist_fc.py --train_epochs 1
echo ""
echo "===========================Testing: quantizers==========================="
...
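When the example finds a winning ticket, it leaves model_winning_ticket.pth and mask_winning_ticket.pth behind. A quick sanity check of the exported mask's sparsity, as a minimal sketch: it assumes the mask file is a dict keyed by layer name whose values are mask tensors (or dicts of them); adjust the lookup if your NNI version lays the file out differently.

# Inspect the exported winning-ticket mask (sketch; file layout is an assumption).
import torch

masks = torch.load('mask_winning_ticket.pth')
for layer, entry in masks.items():
    tensors = entry.values() if isinstance(entry, dict) else [entry]
    for t in tensors:
        kept = (t != 0).sum().item()
        print('{}: {:.1%} of weights kept'.format(layer, kept / t.numel()))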