Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
wangsen
paddle_dbnet
Commits
76f40469
Commit
76f40469
authored
May 12, 2021
by
LDOUBLEV
Browse files
revert prune
parent
b79fee11
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
17 additions
and
31 deletions
+17
-31
deploy/slim/prune/sensitivity_anal.py
deploy/slim/prune/sensitivity_anal.py
+17
-31
No files found.
deploy/slim/prune/sensitivity_anal.py
View file @
76f40469
...
@@ -24,14 +24,6 @@ sys.path.append(__dir__)
...
@@ -24,14 +24,6 @@ sys.path.append(__dir__)
sys
.
path
.
append
(
os
.
path
.
join
(
__dir__
,
'..'
,
'..'
,
'..'
))
sys
.
path
.
append
(
os
.
path
.
join
(
__dir__
,
'..'
,
'..'
,
'..'
))
sys
.
path
.
append
(
os
.
path
.
join
(
__dir__
,
'..'
,
'..'
,
'..'
,
'tools'
))
sys
.
path
.
append
(
os
.
path
.
join
(
__dir__
,
'..'
,
'..'
,
'..'
,
'tools'
))
import
json
import
cv2
import
paddle
from
paddle
import
fluid
import
paddleslim
as
slim
from
copy
import
deepcopy
from
tools
import
program
import
paddle
import
paddle
import
paddle.distributed
as
dist
import
paddle.distributed
as
dist
from
ppocr.data
import
build_dataloader
from
ppocr.data
import
build_dataloader
...
@@ -46,28 +38,14 @@ import tools.program as program
...
@@ -46,28 +38,14 @@ import tools.program as program
dist
.
get_world_size
()
dist
.
get_world_size
()
def
get_pruned_params
(
parameters
,
mode
=
"det"
):
def
get_pruned_params
(
parameters
):
if
mode
==
"det"
:
skip_prune_params
=
[
"conv2d_56.w_0"
,
"conv2d_54.w_0"
,
"conv2d_51.w_0"
,
"conv_last_weights"
,
"conv14_linear_weights"
,
"conv13_expand_weights"
,
"conv12_linear_weights"
,
"conv12_expand_weights"
,
"conv7_expand_weights"
,
"conv8_expand_weights"
,
"conv8_linear_weights"
,
"conv5_linear_weights"
,
"conv5_expand_weights"
,
"conv3_linear_weights"
]
skip_prune_params
=
skip_prune_params
+
[
'conv2d_53.w_0'
]
else
:
skip_prune_params
=
None
params
=
[]
params
=
[]
for
param
in
parameters
:
for
param
in
parameters
:
if
len
(
if
len
(
param
.
shape
param
.
shape
)
==
4
and
'depthwise'
not
in
param
.
name
and
'transpose'
not
in
param
.
name
and
"conv2d_57"
not
in
param
.
name
and
"conv2d_56"
not
in
param
.
name
:
)
==
4
and
'depthwise'
not
in
param
.
name
and
'transpose'
not
in
param
.
name
and
"conv2d_57"
not
in
param
.
name
and
"conv2d_56"
not
in
param
.
name
:
if
param
.
name
not
in
skip_prune_params
:
params
.
append
(
param
.
name
)
params
.
append
(
param
.
name
)
return
params
return
params
...
@@ -118,6 +96,11 @@ def main(config, device, logger, vdl_writer):
...
@@ -118,6 +96,11 @@ def main(config, device, logger, vdl_writer):
# load pretrain model
# load pretrain model
pre_best_model_dict
=
init_model
(
config
,
model
,
logger
,
optimizer
)
pre_best_model_dict
=
init_model
(
config
,
model
,
logger
,
optimizer
)
logger
.
info
(
'train dataloader has {} iters, valid dataloader has {} iters'
.
format
(
len
(
train_dataloader
),
len
(
valid_dataloader
)))
# build metric
eval_class
=
build_metric
(
config
[
'Metric'
])
logger
.
info
(
'train dataloader has {} iters, valid dataloader has {} iters'
.
logger
.
info
(
'train dataloader has {} iters, valid dataloader has {} iters'
.
format
(
len
(
train_dataloader
),
len
(
valid_dataloader
)))
format
(
len
(
train_dataloader
),
len
(
valid_dataloader
)))
...
@@ -127,20 +110,22 @@ def main(config, device, logger, vdl_writer):
...
@@ -127,20 +110,22 @@ def main(config, device, logger, vdl_writer):
logger
.
info
(
f
"metric['hmean']:
{
metric
[
'hmean'
]
}
"
)
logger
.
info
(
f
"metric['hmean']:
{
metric
[
'hmean'
]
}
"
)
return
metric
[
'hmean'
]
return
metric
[
'hmean'
]
pruner
.
sensitive
(
params_sensitive
=
pruner
.
sensitive
(
eval_func
=
eval_fn
,
eval_func
=
eval_fn
,
sen_file
=
"./sen.pickle"
,
sen_file
=
"./sen.pickle"
,
skip_vars
=
[
skip_vars
=
[
"conv2d_57.w_0"
,
"conv2d_transpose_2.w_0"
,
"conv2d_transpose_3.w_0"
"conv2d_57.w_0"
,
"conv2d_transpose_2.w_0"
,
"conv2d_transpose_3.w_0"
])
])
params
=
get_pruned_params
(
model
.
parameters
())
logger
.
info
(
ratios
=
{}
"The sensitivity analysis results of model parameters saved in sen.pickle"
# set the prune ratio is 0.2
)
for
param
in
params
:
# calculate pruned params's ratio
ratios
[
param
]
=
0.2
params_sensitive
=
pruner
.
_get_ratios_by_loss
(
params_sensitive
,
loss
=
0.02
)
for
key
in
params_sensitive
.
keys
():
logger
.
info
(
f
"
{
key
}
,
{
params_sensitive
[
key
]
}
"
)
plan
=
pruner
.
prune_vars
(
ratios
,
[
0
])
plan
=
pruner
.
prune_vars
(
params_sensitive
,
[
0
])
for
param
in
model
.
parameters
():
for
param
in
model
.
parameters
():
if
(
"weights"
in
param
.
name
and
"conv"
in
param
.
name
)
or
(
if
(
"weights"
in
param
.
name
and
"conv"
in
param
.
name
)
or
(
"w_0"
in
param
.
name
and
"conv2d"
in
param
.
name
):
"w_0"
in
param
.
name
and
"conv2d"
in
param
.
name
):
...
@@ -150,6 +135,7 @@ def main(config, device, logger, vdl_writer):
...
@@ -150,6 +135,7 @@ def main(config, device, logger, vdl_writer):
logger
.
info
(
f
"FLOPs after pruning:
{
flops
}
"
)
logger
.
info
(
f
"FLOPs after pruning:
{
flops
}
"
)
# start train
# start train
program
.
train
(
config
,
train_dataloader
,
valid_dataloader
,
device
,
model
,
program
.
train
(
config
,
train_dataloader
,
valid_dataloader
,
device
,
model
,
loss_class
,
optimizer
,
lr_scheduler
,
post_process_class
,
loss_class
,
optimizer
,
lr_scheduler
,
post_process_class
,
eval_class
,
pre_best_model_dict
,
logger
,
vdl_writer
)
eval_class
,
pre_best_model_dict
,
logger
,
vdl_writer
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment