OpenDAS / nni

Unverified commit e60e1838, authored Jun 29, 2020 by chicm-ms, committed by GitHub

Update lottery ticket example (#2559)

parent b82bad0f

Showing 3 changed files with 62 additions and 6 deletions.
examples/model_compress/lottery_torch_mnist_fc.py (+58, -2)
src/sdk/pynni/nni/compression/torch/pruning/lottery_ticket.py (+2, -2)
test/scripts/model_compression.sh (+2, -2)
examples/model_compress/lottery_torch_mnist_fc.py
import argparse
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F

...
@@ -53,6 +55,31 @@ def test(model, test_loader, criterion):
if __name__ == '__main__':
    """
    THE LOTTERY TICKET HYPOTHESIS: FINDING SPARSE, TRAINABLE NEURAL NETWORKS (https://arxiv.org/pdf/1803.03635.pdf)

    The Lottery Ticket Hypothesis. A randomly-initialized, dense neural network contains a subnetwork that is
    initialized such that—when trained in isolation—it can match the test accuracy of the original network after
    training for at most the same number of iterations.

    Identifying winning tickets. We identify a winning ticket by training a network and pruning its
    smallest-magnitude weights. The remaining, unpruned connections constitute the architecture of the
    winning ticket. Unique to our work, each unpruned connection's value is then reset to its initialization
    from the original network before it was trained. This forms our central experiment:
    1. Randomly initialize a neural network f(x; θ0) (where θ0 ∼ Dθ).
    2. Train the network for j iterations, arriving at parameters θj.
    3. Prune p% of the parameters in θj, creating a mask m.
    4. Reset the remaining parameters to their values in θ0, creating the winning ticket f(x; m⊙θ0).
    As described, this pruning approach is one-shot: the network is trained once, p% of weights are
    pruned, and the surviving weights are reset. However, in this paper, we focus on iterative pruning,
    which repeatedly trains, prunes, and resets the network over n rounds; each round prunes p**(1/n)% of
    the weights that survive the previous round. Our results show that iterative pruning finds winning tickets
    that match the accuracy of the original network at smaller sizes than does one-shot pruning.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--train_epochs", type=int, default=10, help="training epochs")
    args = parser.parse_args()

    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
    traindataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
    testdataset = datasets.MNIST('./data', train=False, transform=transform)

...
@@ -63,6 +90,20 @@ if __name__ == '__main__':
    optimizer = torch.optim.Adam(model.parameters(), lr=1.2e-3)
    criterion = nn.CrossEntropyLoss()

    # Record the randomly initialized model weights
    orig_state = copy.deepcopy(model.state_dict())

    # train the model to get unpruned metrics
    for epoch in range(args.train_epochs):
        train(model, train_loader, optimizer, criterion)
    orig_accuracy = test(model, test_loader, criterion)
    print('unpruned model accuracy: {}'.format(orig_accuracy))

    # reset model weights and optimizer for pruning
    model.load_state_dict(orig_state)
    optimizer = torch.optim.Adam(model.parameters(), lr=1.2e-3)

    # Prune the model to find a winning ticket
    configure_list = [{
        'prune_iterations': 5,
        'sparsity': 0.96,
...
@@ -71,14 +112,29 @@ if __name__ == '__main__':
    pruner = LotteryTicketPruner(model, configure_list, optimizer)
    pruner.compress()

    best_accuracy = 0.
    best_state_dict = None
    for i in pruner.get_prune_iterations():
        pruner.prune_iteration_start()
        loss = 0
        accuracy = 0
-        for epoch in range(10):
+        for epoch in range(args.train_epochs):
            loss = train(model, train_loader, optimizer, criterion)
            accuracy = test(model, test_loader, criterion)
            print('current epoch: {0}, loss: {1}, accuracy: {2}'.format(epoch, loss, accuracy))
            if accuracy > best_accuracy:
                best_accuracy = accuracy
                # state dict of weights and masks
                best_state_dict = copy.deepcopy(model.state_dict())
        print('prune iteration: {0}, loss: {1}, accuracy: {2}'.format(i, loss, accuracy))
    pruner.export_model('model.pth', 'mask.pth')

    if best_accuracy > orig_accuracy:
        # load weights and masks
        pruner.bound_model.load_state_dict(best_state_dict)
        # reset weights to original untrained model and keep masks unchanged to export winning ticket
        pruner.load_model_state_dict(orig_state)
        pruner.export_model('model_winning_ticket.pth', 'mask_winning_ticket.pth')
        print('winning ticket has been saved: model_winning_ticket.pth, mask_winning_ticket.pth')
    else:
        print('winning ticket is not found in this run, you can run it again.')
src/sdk/pynni/nni/compression/torch/pruning/lottery_ticket.py

...
@@ -83,12 +83,12 @@ class LotteryTicketPruner(Pruner):
        return max(1 - curr_keep_ratio, 0)

    def _calc_mask(self, wrapper, sparsity):
-        weight = wrapper.weight.data
+        weight = wrapper.module.weight.data
        if self.curr_prune_iteration == 0:
            mask = {'weight_mask': torch.ones(weight.shape).type_as(weight)}
        else:
            curr_sparsity = self._calc_sparsity(sparsity)
-            mask = self.masker.calc_mask(wrapper, curr_sparsity)
+            mask = self.masker.calc_mask(sparsity=curr_sparsity, wrapper=wrapper)
        return mask

    def calc_mask(self, wrapper, **kwargs):

...
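The fragment `return max(1 - curr_keep_ratio, 0)` above is the tail of `_calc_sparsity`, which converts the docstring's p**(1/n) rule into a per-iteration sparsity target for the masker. A self-contained sketch of that schedule, reconstructed from the visible fragment; the standalone function form and its parameter names are assumptions for illustration:

# Sketch of the per-iteration sparsity schedule (not NNI's exact code).
def iterative_sparsity(final_sparsity, prune_iterations, curr_prune_iteration):
    # Fraction of weights a single round keeps: (1 - s) ** (1 / n).
    keep_ratio_once = (1 - final_sparsity) ** (1 / prune_iterations)
    # Cumulative keep ratio after curr_prune_iteration rounds.
    curr_keep_ratio = keep_ratio_once ** curr_prune_iteration
    return max(1 - curr_keep_ratio, 0)

# With the example's settings (sparsity 0.96, 5 iterations) the per-round
# targets are roughly [0.475, 0.724, 0.855, 0.924, 0.96]:
print([round(iterative_sparsity(0.96, 5, i), 3) for i in range(1, 6)])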
test/scripts/model_compression.sh

...
@@ -28,8 +28,8 @@ python3 model_prune_torch.py --pruner_name agp --pretrain_epochs 1 --prune_epoch
echo 'testing mean_activation pruning'
python3 model_prune_torch.py --pruner_name mean_activation --pretrain_epochs 1 --prune_epochs 1

-# echo "testing lottery ticket pruning..."
-# python3 lottery_torch_mnist_fc.py
+echo "testing lottery ticket pruning..."
+python3 lottery_torch_mnist_fc.py --train_epochs 1

echo ""
echo "===========================Testing: quantizers==========================="

...