Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
e1cb8faa
Unverified
Commit
e1cb8faa
authored
Aug 10, 2021
by
J-shang
Committed by
GitHub
Aug 10, 2021
Browse files
update exclude example (#4031)
parent
26c58399
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
16 additions
and
4 deletions
+16
-4
examples/model_compress/pruning/basic_pruners_torch.py
examples/model_compress/pruning/basic_pruners_torch.py
+16
-4
No files found.
examples/model_compress/pruning/basic_pruners_torch.py
View file @
e1cb8faa
...
...
@@ -20,6 +20,7 @@ from torchvision import datasets, transforms
sys
.
path
.
append
(
'../models'
)
from
mnist.lenet
import
LeNet
from
cifar10.vgg
import
VGG
from
cifar10.resnet
import
ResNet18
from
nni.compression.pytorch.utils.counter
import
count_flops_params
...
...
@@ -119,6 +120,12 @@ def get_model_optimizer_scheduler(args, device, train_loader, test_loader, crite
optimizer
=
torch
.
optim
.
SGD
(
model
.
parameters
(),
lr
=
0.1
,
momentum
=
0.9
,
weight_decay
=
5e-4
)
scheduler
=
MultiStepLR
(
optimizer
,
milestones
=
[
int
(
args
.
pretrain_epochs
*
0.5
),
int
(
args
.
pretrain_epochs
*
0.75
)],
gamma
=
0.1
)
elif
args
.
model
==
'resnet18'
:
model
=
ResNet18
().
to
(
device
)
if
args
.
pretrained_model_dir
is
None
:
optimizer
=
torch
.
optim
.
SGD
(
model
.
parameters
(),
lr
=
0.1
,
momentum
=
0.9
,
weight_decay
=
5e-4
)
scheduler
=
MultiStepLR
(
optimizer
,
milestones
=
[
int
(
args
.
pretrain_epochs
*
0.5
),
int
(
args
.
pretrain_epochs
*
0.75
)],
gamma
=
0.1
)
else
:
raise
ValueError
(
"model not recognized"
)
...
...
@@ -253,14 +260,19 @@ def main(args):
'sparsity'
:
args
.
sparsity
,
'op_types'
:
[
'BatchNorm2d'
],
}]
el
se
:
el
if
args
.
model
==
'resnet18'
:
config_list
=
[{
'sparsity'
:
args
.
sparsity
,
'op_types'
:
[
'Conv2d'
],
'op_names'
:
[
'feature.0'
,
'feature.10'
,
'feature.24'
,
'feature.27'
,
'feature.30'
,
'feature.34'
,
'feature.37'
]
'op_types'
:
[
'Conv2d'
]
},
{
'exclude'
:
True
,
'op_names'
:
[
'feature.10'
]
'op_names'
:
[
'layer1.0.conv1'
,
'layer1.0.conv2'
]
}]
else
:
config_list
=
[{
'sparsity'
:
args
.
sparsity
,
'op_types'
:
[
'Conv2d'
],
'op_names'
:
[
'feature.0'
,
'feature.24'
,
'feature.27'
,
'feature.30'
,
'feature.34'
,
'feature.37'
]
}]
pruner
=
pruner_cls
(
model
,
config_list
,
**
kw_args
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment