OpenDAS / nni · Commit c717ce57
"...git@developer.sourcefind.cn:OpenDAS/opencompass.git" did not exist on "ad872a5dc2c397676b5ff3b2f4d577be02a7ebfe"
Unverified commit c717ce57, authored Jul 09, 2021 by J-shang; committed by GitHub, Jul 09, 2021.
update pruning example (#3844)
parent 507595b0
Showing 1 changed file with 9 additions and 4 deletions.

examples/model_compress/pruning/basic_pruners_torch.py (+9, -4)
@@ -143,7 +143,7 @@ def get_model_optimizer_scheduler(args, device, train_loader, test_loader, criterion):
     model.load_state_dict(torch.load(args.pretrained_model_dir))
     best_acc = test(args, model, device, criterion, test_loader)

-    # setup new opotimizer for fine-tuning
+    # setup new opotimizer for pruning
     optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
     scheduler = MultiStepLR(optimizer, milestones=[int(args.pretrain_epochs * 0.5), int(args.pretrain_epochs * 0.75)], gamma=0.1)
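The milestone arithmetic in the scheduler line above is easy to sanity-check in isolation. A minimal standalone PyTorch sketch; the epoch count and the stand-in model are illustrative, not values from the commit:

import torch
from torch.optim.lr_scheduler import MultiStepLR

model = torch.nn.Linear(10, 10)   # stand-in model, not from the example
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)

pretrain_epochs = 160             # hypothetical epoch budget for illustration
# milestones = [80, 120]: LR decays 0.01 -> 0.001 at epoch 80, then -> 0.0001 at epoch 120
scheduler = MultiStepLR(optimizer, milestones=[int(pretrain_epochs * 0.5), int(pretrain_epochs * 0.75)], gamma=0.1)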
@@ -192,10 +192,10 @@ def main(args):
     # prepare model and data
     train_loader, test_loader, criterion = get_data(args.dataset, args.data_dir, args.batch_size, args.test_batch_size)
-    model, optimizer, scheduler = get_model_optimizer_scheduler(args, device, train_loader, test_loader, criterion)
+    model, optimizer, _ = get_model_optimizer_scheduler(args, device, train_loader, test_loader, criterion)

     dummy_input = get_dummy_input(args, device)
-    flops, params, results = count_flops_params(model, dummy_input)
+    flops, params, _ = count_flops_params(model, dummy_input)
     print(f"FLOPs: {flops}, params: {params}")

     print(f'start {args.pruner} pruning...')
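The two `_` assignments in this hunk simply discard return values the script never uses (the scheduler from pretraining and the per-module counting results). For reference, a hedged sketch of the FLOPs/params check, assuming the NNI 2.x counter API this example relies on (the import path may differ in other NNI versions, and `nni` must be installed):

import torch
from nni.compression.pytorch.utils.counter import count_flops_params

model = torch.nn.Sequential(torch.nn.Conv2d(3, 16, 3), torch.nn.ReLU())  # toy model, not from the example
dummy_input = torch.randn(1, 3, 32, 32)
# The third return value holds per-module details; the commit discards it
# with `_` because only the totals are printed.
flops, params, _ = count_flops_params(model, dummy_input)
print(f"FLOPs: {flops}, params: {params}")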
@@ -273,11 +273,16 @@ def main(args):
     if args.speed_up:
         # Unwrap all modules to normal state
         pruner._unwrap_model()
         m_speedup = ModelSpeedup(model, dummy_input, mask_path, device)
         m_speedup.speedup_model()

     print('start finetuning...')
+
+    # Optimizer used in the pruner might be patched, so recommend to new an optimizer for fine-tuning stage.
+    optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
+    scheduler = MultiStepLR(optimizer, milestones=[int(args.pretrain_epochs * 0.5), int(args.pretrain_epochs * 0.75)], gamma=0.1)
+
     best_top1 = 0
     save_path = os.path.join(args.experiment_data_dir, f'finetuned.pth')
     for epoch in range(args.fine_tune_epochs):
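This hunk is the substantive change: the fine-tuning stage now builds a fresh optimizer and scheduler instead of reusing the one held by the pruner, since (as the added comment says) the pruner may have patched that optimizer, and after ModelSpeedup it may also reference replaced parameter tensors. Note the commit keys the milestones to args.pretrain_epochs even though the loop runs for args.fine_tune_epochs. A minimal sketch of the pattern; build_finetune_optimizer is an illustrative helper name, not part of the commit:

import torch
from torch.optim.lr_scheduler import MultiStepLR

def build_finetune_optimizer(model, pretrain_epochs):
    # Fresh SGD bound to the model's current (possibly post-speedup) parameters,
    # rather than the optimizer instance the pruner may have patched.
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
    # Same decay points as the commit: 50% and 75% of the pretrain epoch budget.
    scheduler = MultiStepLR(optimizer, milestones=[int(pretrain_epochs * 0.5), int(pretrain_epochs * 0.75)], gamma=0.1)
    return optimizer, scheduler

With this helper, the two added lines above would collapse to a single call such as optimizer, scheduler = build_finetune_optimizer(model, args.pretrain_epochs).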