Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
decb78ee
"src/vscode:/vscode.git/clone" did not exist on "03eef73c5be07a1e02c090eacd24f0a9f6aa850e"
Unverified
Commit
decb78ee
authored
Jan 07, 2021
by
lin bin
Committed by
GitHub
Jan 07, 2021
Browse files
fix flops counter bug in auto_pruners_torch.py (#3265)
parent
99bc4594
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
4 additions
and
4 deletions
+4
-4
examples/model_compress/pruning/auto_pruners_torch.py
examples/model_compress/pruning/auto_pruners_torch.py
+4
-4
No files found.
examples/model_compress/pruning/auto_pruners_torch.py
View file @
decb78ee
...
...
@@ -186,7 +186,7 @@ def get_trained_model_optimizer(args, device, train_loader, val_loader, criterio
if
args
.
save_model
:
torch
.
save
(
state_dict
,
os
.
path
.
join
(
args
.
experiment_data_dir
,
'model_trained.pth'
))
print
(
'Model trained saved to %s'
,
args
.
experiment_data_dir
)
print
(
'Model trained saved to %s'
%
args
.
experiment_data_dir
)
return
model
,
optimizer
...
...
@@ -312,7 +312,7 @@ def main(args):
if
args
.
save_model
:
pruner
.
export_model
(
os
.
path
.
join
(
args
.
experiment_data_dir
,
'model_masked.pth'
),
os
.
path
.
join
(
args
.
experiment_data_dir
,
'mask.pth'
))
print
(
'Masked model saved to %s'
,
args
.
experiment_data_dir
)
print
(
'Masked model saved to %s'
%
args
.
experiment_data_dir
)
# model speed up
if
args
.
speed_up
:
...
...
@@ -336,7 +336,7 @@ def main(args):
result
[
'performance'
][
'speedup'
]
=
evaluation_result
torch
.
save
(
model
.
state_dict
(),
os
.
path
.
join
(
args
.
experiment_data_dir
,
'model_speed_up.pth'
))
print
(
'Speed up model saved to %s'
,
args
.
experiment_data_dir
)
print
(
'Speed up model saved to %s'
%
args
.
experiment_data_dir
)
flops
,
params
,
_
=
count_flops_params
(
model
,
get_input_size
(
args
.
dataset
))
result
[
'flops'
][
'speedup'
]
=
flops
result
[
'params'
][
'speedup'
]
=
params
...
...
@@ -367,7 +367,7 @@ def main(args):
torch
.
save
(
model
.
state_dict
(),
os
.
path
.
join
(
args
.
experiment_data_dir
,
'model_fine_tuned.pth'
))
print
(
'Evaluation result (fine tuned): %s'
%
best_acc
)
print
(
'Fined tuned model saved to %s'
,
args
.
experiment_data_dir
)
print
(
'Fined tuned model saved to %s'
%
args
.
experiment_data_dir
)
result
[
'performance'
][
'finetuned'
]
=
best_acc
with
open
(
os
.
path
.
join
(
args
.
experiment_data_dir
,
'result.json'
),
'w+'
)
as
f
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment