OpenDAS / nni / Commits / 6e643b00

Unverified commit 6e643b00, authored Dec 28, 2021 by J-shang, committed by GitHub on Dec 28, 2021

fix pruning examples & pruner memory usage optimize (#4412)

parent f46f0cf4
Showing 10 changed files with 94 additions and 72 deletions (+94 -72).
examples/model_compress/pruning/v2/activation_pruning_torch.py  +6 -6
examples/model_compress/pruning/v2/admm_pruning_torch.py  +4 -4
examples/model_compress/pruning/v2/fpgm_pruning_torch.py  +4 -4
examples/model_compress/pruning/v2/level_pruning_torch.py  +4 -4
examples/model_compress/pruning/v2/norm_pruning_torch.py  +4 -4
examples/model_compress/pruning/v2/slim_pruning_torch.py  +4 -4
examples/model_compress/pruning/v2/taylorfo_pruning_torch.py  +5 -5
nni/algorithms/compression/v2/pytorch/pruning/basic_pruner.py  +29 -4
nni/algorithms/compression/v2/pytorch/pruning/tools/metrics_calculator.py  +23 -26
test/ut/compression/v2/test_pruning_tools_torch.py  +11 -11
examples/model_compress/pruning/v2/activation_pruning_torch.py (view file @ 6e643b00)

@@ -72,9 +72,9 @@ def evaluator(model):
     print('Accuracy: {}%\n'.format(acc))
     return acc


-def optimizer_scheduler_generator(model, _lr=0.1, _momentum=0.9, _weight_decay=5e-4):
+def optimizer_scheduler_generator(model, _lr=0.1, _momentum=0.9, _weight_decay=5e-4, total_epoch=160):
     optimizer = torch.optim.SGD(model.parameters(), lr=_lr, momentum=_momentum, weight_decay=_weight_decay)
-    scheduler = MultiStepLR(optimizer, milestones=[int(args.pretrain_epochs * 0.5), int(args.pretrain_epochs * 0.75)], gamma=0.1)
+    scheduler = MultiStepLR(optimizer, milestones=[int(total_epoch * 0.5), int(total_epoch * 0.75)], gamma=0.1)
     return optimizer, scheduler


 if __name__ == '__main__':
@@ -90,7 +90,7 @@ if __name__ == '__main__':
     print('\n' + '=' * 50 + ' START TO TRAIN THE MODEL ' + '=' * 50)
     model = VGG().to(device)
-    optimizer, scheduler = optimizer_scheduler_generator(model)
+    optimizer, scheduler = optimizer_scheduler_generator(model, total_epoch=args.pretrain_epochs)
     criterion = torch.nn.CrossEntropyLoss()
     pre_best_acc = 0.0
     best_state_dict = None
@@ -117,9 +117,9 @@ if __name__ == '__main__':
     # make sure you have used nni.algorithms.compression.v2.pytorch.utils.trace_parameters to wrap the optimizer class before initialize
     traced_optimizer = trace_parameters(torch.optim.SGD)(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
     if 'apoz' in args.pruner:
-        pruner = ActivationAPoZRankPruner(model, config_list, trainer, traced_optimizer, criterion, training_batches=1)
+        pruner = ActivationAPoZRankPruner(model, config_list, trainer, traced_optimizer, criterion, training_batches=20)
     else:
-        pruner = ActivationMeanRankPruner(model, config_list, trainer, traced_optimizer, criterion, training_batches=1)
+        pruner = ActivationMeanRankPruner(model, config_list, trainer, traced_optimizer, criterion, training_batches=20)
     _, masks = pruner.compress()
     pruner.show_pruned_weights()
     pruner._unwrap_model()
@@ -129,7 +129,7 @@ if __name__ == '__main__':
     # Optimizer used in the pruner might be patched, so recommend to new an optimizer for fine-tuning stage.
     print('\n' + '=' * 50 + ' START TO FINE TUNE THE MODEL ' + '=' * 50)
-    optimizer, scheduler = optimizer_scheduler_generator(model, _lr=0.01)
+    optimizer, scheduler = optimizer_scheduler_generator(model, _lr=0.01, total_epoch=args.fine_tune_epochs)
     best_acc = 0.0
     g_epoch = 0
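The new total_epoch parameter above fixes a subtle scheduling bug: the helper previously derived its MultiStepLR milestones from args.pretrain_epochs even when it was reused for the much shorter fine-tuning stage, so the learning rate there never actually decayed. A minimal standalone sketch of the corrected milestone rule (the epoch counts are illustrative; 160 is only the new default), and the same fix is applied to each of the example scripts below:

# Sketch of the milestone rule in the patched optimizer_scheduler_generator (illustrative values).
def milestones(total_epoch):
    return [int(total_epoch * 0.5), int(total_epoch * 0.75)]

print(milestones(160))  # pretraining stage -> [80, 120]
print(milestones(20))   # fine-tuning stage -> [10, 15]
# Before the fix, both stages used milestones computed from args.pretrain_epochs,
# so a short fine-tuning run never hit a learning-rate decay step.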
examples/model_compress/pruning/v2/admm_pruning_torch.py (view file @ 6e643b00)

@@ -71,9 +71,9 @@ def evaluator(model):
     print('Accuracy: {}%\n'.format(acc))
     return acc


-def optimizer_scheduler_generator(model, _lr=0.1, _momentum=0.9, _weight_decay=5e-4):
+def optimizer_scheduler_generator(model, _lr=0.1, _momentum=0.9, _weight_decay=5e-4, total_epoch=160):
     optimizer = torch.optim.SGD(model.parameters(), lr=_lr, momentum=_momentum, weight_decay=_weight_decay)
-    scheduler = MultiStepLR(optimizer, milestones=[int(args.pretrain_epochs * 0.5), int(args.pretrain_epochs * 0.75)], gamma=0.1)
+    scheduler = MultiStepLR(optimizer, milestones=[int(total_epoch * 0.5), int(total_epoch * 0.75)], gamma=0.1)
     return optimizer, scheduler


 if __name__ == '__main__':
@@ -86,7 +86,7 @@ if __name__ == '__main__':
     print('\n' + '=' * 50 + ' START TO TRAIN THE MODEL ' + '=' * 50)
     model = VGG().to(device)
-    optimizer, scheduler = optimizer_scheduler_generator(model)
+    optimizer, scheduler = optimizer_scheduler_generator(model, total_epoch=args.pretrain_epochs)
     criterion = torch.nn.CrossEntropyLoss()
     pre_best_acc = 0.0
     best_state_dict = None
@@ -125,7 +125,7 @@ if __name__ == '__main__':
     # Optimizer used in the pruner might be patched, so recommend to new an optimizer for fine-tuning stage.
     print('\n' + '=' * 50 + ' START TO FINE TUNE THE MODEL ' + '=' * 50)
-    optimizer, scheduler = optimizer_scheduler_generator(model, _lr=0.01)
+    optimizer, scheduler = optimizer_scheduler_generator(model, _lr=0.01, total_epoch=args.fine_tune_epochs)
     best_acc = 0.0
     g_epoch = 0
examples/model_compress/pruning/v2/fpgm_pruning_torch.py (view file @ 6e643b00)

@@ -71,9 +71,9 @@ def evaluator(model):
     print('Accuracy: {}%\n'.format(acc))
     return acc


-def optimizer_scheduler_generator(model, _lr=0.1, _momentum=0.9, _weight_decay=5e-4):
+def optimizer_scheduler_generator(model, _lr=0.1, _momentum=0.9, _weight_decay=5e-4, total_epoch=160):
     optimizer = torch.optim.SGD(model.parameters(), lr=_lr, momentum=_momentum, weight_decay=_weight_decay)
-    scheduler = MultiStepLR(optimizer, milestones=[int(args.pretrain_epochs * 0.5), int(args.pretrain_epochs * 0.75)], gamma=0.1)
+    scheduler = MultiStepLR(optimizer, milestones=[int(total_epoch * 0.5), int(total_epoch * 0.75)], gamma=0.1)
     return optimizer, scheduler


 if __name__ == '__main__':
@@ -86,7 +86,7 @@ if __name__ == '__main__':
     print('\n' + '=' * 50 + ' START TO TRAIN THE MODEL ' + '=' * 50)
     model = VGG().to(device)
-    optimizer, scheduler = optimizer_scheduler_generator(model)
+    optimizer, scheduler = optimizer_scheduler_generator(model, total_epoch=args.pretrain_epochs)
     criterion = torch.nn.CrossEntropyLoss()
     pre_best_acc = 0.0
     best_state_dict = None
@@ -119,7 +119,7 @@ if __name__ == '__main__':
     # Optimizer used in the pruner might be patched, so recommend to new an optimizer for fine-tuning stage.
     print('\n' + '=' * 50 + ' START TO FINE TUNE THE MODEL ' + '=' * 50)
-    optimizer, scheduler = optimizer_scheduler_generator(model, _lr=0.01)
+    optimizer, scheduler = optimizer_scheduler_generator(model, _lr=0.01, total_epoch=args.fine_tune_epochs)
     best_acc = 0.0
     for i in range(args.fine_tune_epochs):
examples/model_compress/pruning/v2/level_pruning_torch.py (view file @ 6e643b00)

@@ -70,9 +70,9 @@ def evaluator(model):
     print('Accuracy: {}%\n'.format(acc))
     return acc


-def optimizer_scheduler_generator(model, _lr=0.1, _momentum=0.9, _weight_decay=5e-4):
+def optimizer_scheduler_generator(model, _lr=0.1, _momentum=0.9, _weight_decay=5e-4, total_epoch=160):
     optimizer = torch.optim.SGD(model.parameters(), lr=_lr, momentum=_momentum, weight_decay=_weight_decay)
-    scheduler = MultiStepLR(optimizer, milestones=[int(args.pretrain_epochs * 0.5), int(args.pretrain_epochs * 0.75)], gamma=0.1)
+    scheduler = MultiStepLR(optimizer, milestones=[int(total_epoch * 0.5), int(total_epoch * 0.75)], gamma=0.1)
     return optimizer, scheduler


 if __name__ == '__main__':
@@ -85,7 +85,7 @@ if __name__ == '__main__':
     print('\n' + '=' * 50 + ' START TO TRAIN THE MODEL ' + '=' * 50)
     model = VGG().to(device)
-    optimizer, scheduler = optimizer_scheduler_generator(model)
+    optimizer, scheduler = optimizer_scheduler_generator(model, total_epoch=args.pretrain_epochs)
     criterion = torch.nn.CrossEntropyLoss()
     pre_best_acc = 0.0
     best_state_dict = None
@@ -117,7 +117,7 @@ if __name__ == '__main__':
     # Optimizer used in the pruner might be patched, so recommend to new an optimizer for fine-tuning stage.
     print('\n' + '=' * 50 + ' START TO FINE TUNE THE MODEL ' + '=' * 50)
-    optimizer, scheduler = optimizer_scheduler_generator(model, _lr=0.01)
+    optimizer, scheduler = optimizer_scheduler_generator(model, _lr=0.01, total_epoch=args.fine_tune_epochs)
     best_acc = 0.0
     g_epoch = 0
examples/model_compress/pruning/v2/norm_pruning_torch.py (view file @ 6e643b00)

@@ -71,9 +71,9 @@ def evaluator(model):
     print('Accuracy: {}%\n'.format(acc))
     return acc


-def optimizer_scheduler_generator(model, _lr=0.1, _momentum=0.9, _weight_decay=5e-4):
+def optimizer_scheduler_generator(model, _lr=0.1, _momentum=0.9, _weight_decay=5e-4, total_epoch=160):
     optimizer = torch.optim.SGD(model.parameters(), lr=_lr, momentum=_momentum, weight_decay=_weight_decay)
-    scheduler = MultiStepLR(optimizer, milestones=[int(args.pretrain_epochs * 0.5), int(args.pretrain_epochs * 0.75)], gamma=0.1)
+    scheduler = MultiStepLR(optimizer, milestones=[int(total_epoch * 0.5), int(total_epoch * 0.75)], gamma=0.1)
     return optimizer, scheduler


 if __name__ == '__main__':
@@ -89,7 +89,7 @@ if __name__ == '__main__':
     print('\n' + '=' * 50 + ' START TO TRAIN THE MODEL ' + '=' * 50)
     model = VGG().to(device)
-    optimizer, scheduler = optimizer_scheduler_generator(model)
+    optimizer, scheduler = optimizer_scheduler_generator(model, total_epoch=args.pretrain_epochs)
     criterion = torch.nn.CrossEntropyLoss()
     pre_best_acc = 0.0
     best_state_dict = None
@@ -125,7 +125,7 @@ if __name__ == '__main__':
     # Optimizer used in the pruner might be patched, so recommend to new an optimizer for fine-tuning stage.
     print('\n' + '=' * 50 + ' START TO FINE TUNE THE MODEL ' + '=' * 50)
-    optimizer, scheduler = optimizer_scheduler_generator(model, _lr=0.01)
+    optimizer, scheduler = optimizer_scheduler_generator(model, _lr=0.01, total_epoch=args.fine_tune_epochs)
     best_acc = 0.0
     for i in range(args.fine_tune_epochs):
examples/model_compress/pruning/v2/slim_pruning_torch.py (view file @ 6e643b00)

@@ -72,9 +72,9 @@ def evaluator(model):
     print('Accuracy: {}%\n'.format(acc))
     return acc


-def optimizer_scheduler_generator(model, _lr=0.1, _momentum=0.9, _weight_decay=5e-4):
+def optimizer_scheduler_generator(model, _lr=0.1, _momentum=0.9, _weight_decay=5e-4, total_epoch=160):
     optimizer = torch.optim.SGD(model.parameters(), lr=_lr, momentum=_momentum, weight_decay=_weight_decay)
-    scheduler = MultiStepLR(optimizer, milestones=[int(args.pretrain_epochs * 0.5), int(args.pretrain_epochs * 0.75)], gamma=0.1)
+    scheduler = MultiStepLR(optimizer, milestones=[int(total_epoch * 0.5), int(total_epoch * 0.75)], gamma=0.1)
     return optimizer, scheduler


 if __name__ == '__main__':
@@ -87,7 +87,7 @@ if __name__ == '__main__':
     print('\n' + '=' * 50 + ' START TO TRAIN THE MODEL ' + '=' * 50)
     model = VGG().to(device)
-    optimizer, scheduler = optimizer_scheduler_generator(model)
+    optimizer, scheduler = optimizer_scheduler_generator(model, total_epoch=args.pretrain_epochs)
     criterion = torch.nn.CrossEntropyLoss()
     pre_best_acc = 0.0
     best_state_dict = None
@@ -124,7 +124,7 @@ if __name__ == '__main__':
     # Optimizer used in the pruner might be patched, so recommend to new an optimizer for fine-tuning stage.
     print('\n' + '=' * 50 + ' START TO FINE TUNE THE MODEL ' + '=' * 50)
-    optimizer, scheduler = optimizer_scheduler_generator(model, _lr=0.01)
+    optimizer, scheduler = optimizer_scheduler_generator(model, _lr=0.01, total_epoch=args.fine_tune_epochs)
     best_acc = 0.0
     g_epoch = 0
     for i in range(args.fine_tune_epochs):
examples/model_compress/pruning/v2/taylorfo_pruning_torch.py (view file @ 6e643b00)

@@ -72,9 +72,9 @@ def evaluator(model):
     print('Accuracy: {}%\n'.format(acc))
     return acc


-def optimizer_scheduler_generator(model, _lr=0.1, _momentum=0.9, _weight_decay=5e-4):
+def optimizer_scheduler_generator(model, _lr=0.1, _momentum=0.9, _weight_decay=5e-4, total_epoch=160):
     optimizer = torch.optim.SGD(model.parameters(), lr=_lr, momentum=_momentum, weight_decay=_weight_decay)
-    scheduler = MultiStepLR(optimizer, milestones=[int(args.pretrain_epochs * 0.5), int(args.pretrain_epochs * 0.75)], gamma=0.1)
+    scheduler = MultiStepLR(optimizer, milestones=[int(total_epoch * 0.5), int(total_epoch * 0.75)], gamma=0.1)
     return optimizer, scheduler


 if __name__ == '__main__':
@@ -87,7 +87,7 @@ if __name__ == '__main__':
     print('\n' + '=' * 50 + ' START TO TRAIN THE MODEL ' + '=' * 50)
     model = VGG().to(device)
-    optimizer, scheduler = optimizer_scheduler_generator(model)
+    optimizer, scheduler = optimizer_scheduler_generator(model, total_epoch=args.pretrain_epochs)
     criterion = torch.nn.CrossEntropyLoss()
     pre_best_acc = 0.0
     best_state_dict = None
@@ -113,7 +113,7 @@ if __name__ == '__main__':
     # make sure you have used nni.algorithms.compression.v2.pytorch.utils.trace_parameters to wrap the optimizer class before initialize
     traced_optimizer = trace_parameters(torch.optim.SGD)(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
-    pruner = TaylorFOWeightPruner(model, config_list, trainer, traced_optimizer, criterion, training_batches=1)
+    pruner = TaylorFOWeightPruner(model, config_list, trainer, traced_optimizer, criterion, training_batches=20)
     _, masks = pruner.compress()
     pruner.show_pruned_weights()
     pruner._unwrap_model()
@@ -123,7 +123,7 @@ if __name__ == '__main__':
     # Optimizer used in the pruner might be patched, so recommend to new an optimizer for fine-tuning stage.
     print('\n' + '=' * 50 + ' START TO FINE TUNE THE MODEL ' + '=' * 50)
-    optimizer, scheduler = optimizer_scheduler_generator(model, _lr=0.01)
+    optimizer, scheduler = optimizer_scheduler_generator(model, _lr=0.01, total_epoch=args.fine_tune_epochs)
     best_acc = 0.0
     g_epoch = 0
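Raising training_batches from 1 to 20 means the Taylor-FO importance scores are accumulated over 20 mini-batches instead of a single one, which makes the pruning decision far less sensitive to any one batch's gradients. A rough standalone sketch of that accumulation, assuming the usual first-order Taylor term, (weight x gradient) squared; the tensors and shapes here are made up for illustration and are not taken from the repository:

import torch

torch.manual_seed(0)
weight = torch.randn(8, 4)                        # stand-in for a layer's weight
count, running_sum = 0, torch.zeros_like(weight)  # the [count, sum] buffer layout used by the pruner
for _ in range(20):                               # training_batches=20 in the updated example
    grad = torch.randn_like(weight)               # stand-in for the gradient from a real backward pass
    running_sum += (weight * grad) ** 2           # assumed first-order Taylor importance term
    count += 1
importance = running_sum                          # a 20-batch estimate instead of a single-batch one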
nni/algorithms/compression/v2/pytorch/pruning/basic_pruner.py (view file @ 6e643b00)

@@ -524,11 +524,23 @@ class ActivationPruner(BasicPruner):
             raise 'Unsupported activatoin {}'.format(activation)

     def _collector(self, buffer: List) -> Callable[[Module, Tensor, Tensor], None]:
+        assert len(buffer) == 0, 'Buffer pass to activation pruner collector is not empty.'
+        # The length of the buffer used in this pruner will always be 2.
+        # buffer[0] is the number of how many batches are counted in buffer[1].
+        # buffer[1] is a tensor and the size of buffer[1] is same as the activation.
+        buffer.append(0)
+
         def collect_activation(_module: Module, _input: Tensor, output: Tensor):
-            if len(buffer) < self.training_batches:
-                buffer.append(self._activation(output.detach()))
+            if len(buffer) == 1:
+                buffer.append(torch.zeros_like(output))
+            if buffer[0] < self.training_batches:
+                buffer[1] += self._activation_trans(output)
+                buffer[0] += 1
         return collect_activation

+    def _activation_trans(self, output: Tensor) -> Tensor:
+        raise NotImplementedError()
+
     def reset_tools(self):
         collector_info = HookCollectorInfo([layer_info for layer_info, _ in self._detect_modules_to_compress()], 'forward', self._collector)
         if self.data_collector is None:
@@ -551,11 +563,19 @@ class ActivationPruner(BasicPruner):
 class ActivationAPoZRankPruner(ActivationPruner):
+    def _activation_trans(self, output: Tensor) -> Tensor:
+        # return a matrix that the position of zero in `output` is one, others is zero.
+        return torch.eq(self._activation(output.detach()), torch.zeros_like(output)).type_as(output)
+
     def _get_metrics_calculator(self) -> MetricsCalculator:
         return APoZRankMetricsCalculator(dim=1)


 class ActivationMeanRankPruner(ActivationPruner):
+    def _activation_trans(self, output: Tensor) -> Tensor:
+        # return the activation of `output` directly.
+        return self._activation(output.detach())
+
     def _get_metrics_calculator(self) -> MetricsCalculator:
         return MeanRankMetricsCalculator(dim=1)
@@ -647,9 +667,14 @@ class TaylorFOWeightPruner(BasicPruner):
         schema.validate(config_list)

     def _collector(self, buffer: List, weight_tensor: Tensor) -> Callable[[Tensor], None]:
+        assert len(buffer) == 0, 'Buffer pass to taylor pruner collector is not empty.'
+        buffer.append(0)
+        buffer.append(torch.zeros_like(weight_tensor))
+
         def collect_taylor(grad: Tensor):
-            if len(buffer) < self.training_batches:
-                buffer.append(self._calculate_taylor_expansion(weight_tensor, grad))
+            if buffer[0] < self.training_batches:
+                buffer[1] += self._calculate_taylor_expansion(weight_tensor, grad)
+                buffer[0] += 1
         return collect_taylor

     def _calculate_taylor_expansion(self, weight_tensor: Tensor, grad: Tensor) -> Tensor:
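This is the memory optimization named in the commit title: the forward hook no longer appends one detached activation tensor per observed batch (memory growing linearly with training_batches); it keeps a two-element buffer, a batch counter plus a single accumulator updated in place, and the per-pruner _activation_trans decides what gets accumulated. A simplified before/after sketch in plain PyTorch, outside the NNI classes (the shapes and the identity accumulation are illustrative only):

import torch

def old_hook(buffer, output, training_batches):
    # before the patch: one stored tensor per batch -> memory grows with training_batches
    if len(buffer) < training_batches:
        buffer.append(output.detach())

def new_hook(buffer, output, training_batches):
    # after the patch: buffer == [batch_count, running_sum]; memory stays constant
    if len(buffer) == 1:
        buffer.append(torch.zeros_like(output))
    if buffer[0] < training_batches:
        buffer[1] += output.detach()   # the real code accumulates self._activation_trans(output)
        buffer[0] += 1

old_buf, new_buf = [], [0]
for _ in range(20):
    act = torch.randn(32, 16, 8, 8)    # illustrative activation shape
    old_hook(old_buf, act, 20)         # ends up holding 20 full activation tensors
    new_hook(new_buf, act, 20)         # always holds one accumulator plus a counter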
nni/algorithms/compression/v2/pytorch/pruning/tools/metrics_calculator.py (view file @ 6e643b00)

@@ -75,19 +75,20 @@ class NormMetricsCalculator(MetricsCalculator):
 class MultiDataNormMetricsCalculator(NormMetricsCalculator):
     """
-    Sum each list of tensor in data at first, then calculate the specify norm for each sumed tensor.
-    TaylorFO pruner use this to calculate metric.
+    The data value format is a two-element list [batch_number, cumulative_data].
+    Directly use the cumulative_data as new_data to calculate norm metric.
+    TaylorFO pruner uses this to calculate metric.
     """
     def calculate_metrics(self, data: Dict[str, List[Tensor]]) -> Dict[str, Tensor]:
-        new_data = {name: sum(list_tensor) for name, list_tensor in data.items()}
+        new_data = {name: buffer[1] for name, buffer in data.items()}
         return super().calculate_metrics(new_data)


 class DistMetricsCalculator(MetricsCalculator):
     """
     Calculate the sum of specify distance for each element with all other elements in specify `dim` in each tensor in data.
-    FPGM pruner use this to calculate metric.
+    FPGM pruner uses this to calculate metric.
     """
     def __init__(self, p: float, dim: Union[int, List[int]]):
@@ -153,26 +154,23 @@ class DistMetricsCalculator(MetricsCalculator):
 class APoZRankMetricsCalculator(MetricsCalculator):
     """
-    This metric counts the zero number at the same position in the tensor list in data,
-    then sum the zero number on `dim` and calculate the non-zero rate.
+    The data value format is a two-element list [batch_number, batch_wise_zeros_count_sum].
+    This metric sum the zero number on `dim` then devide the (batch_number * across_dim_size) to calculate the non-zero rate.
     Note that the metric we return is (1 - apoz), because we assume a higher metric value has higher importance.
-    APoZRank pruner use this to calculate metric.
+    APoZRank pruner uses this to calculate metric.
     """
-    def calculate_metrics(self, data: Dict[str, List[Tensor]]) -> Dict[str, Tensor]:
+    def calculate_metrics(self, data: Dict[str, List]) -> Dict[str, Tensor]:
         metrics = {}
-        for name, tensor_list in data.items():
-            activations = torch.cat(tensor_list, dim=0)
-            _eq_zero = torch.eq(activations, torch.zeros_like(activations))
-            keeped_dim = list(range(len(_eq_zero.size()))) if self.dim is None else self.dim
-            across_dim = list(range(len(_eq_zero.size())))
+        for name, (num, zero_counts) in data.items():
+            # NOTE: dim=0 means the batch dim is 0
+            keeped_dim = list(range(len(zero_counts.size()))) if self.dim is None else self.dim
+            across_dim = list(range(len(zero_counts.size())))
             [across_dim.pop(i) for i in reversed(keeped_dim)]
-            # The element number on each [keeped_dim + 1] in _eq_zero
-            total_size = 1
-            for dim, dim_size in enumerate(_eq_zero.size()):
+            # The element number on each keeped_dim in zero_counts
+            total_size = num
+            for dim, dim_size in enumerate(zero_counts.size()):
                 if dim not in keeped_dim:
                     total_size *= dim_size
-            _apoz = torch.sum(_eq_zero, dim=across_dim).type_as(activations) / total_size
+            _apoz = torch.sum(zero_counts, dim=across_dim).type_as(zero_counts) / total_size
             # NOTE: the metric is (1 - apoz) because we assume the smaller metric value is more needed to be pruned.
             metrics[name] = torch.ones_like(_apoz) - _apoz
         return metrics
@@ -180,16 +178,15 @@ class APoZRankMetricsCalculator(MetricsCalculator):
 class MeanRankMetricsCalculator(MetricsCalculator):
     """
-    This metric simply concat the list of tensor on dim 0, and average on `dim`.
-    MeanRank pruner use this to calculate metric.
+    The data value format is a two-element list [batch_number, batch_wise_activation_sum].
+    This metric simply calculate the average on `self.dim`, then divide by the batch_number.
+    MeanRank pruner uses this to calculate metric.
     """
     def calculate_metrics(self, data: Dict[str, List[Tensor]]) -> Dict[str, Tensor]:
         metrics = {}
-        for name, tensor_list in data.items():
-            activations = torch.cat(tensor_list, dim=0)
-            keeped_dim = list(range(len(activations.size()))) if self.dim is None else self.dim
-            across_dim = list(range(len(activations.size())))
+        for name, (num, activation_sum) in data.items():
+            # NOTE: dim=0 means the batch dim is 0
+            keeped_dim = list(range(len(activation_sum.size()))) if self.dim is None else self.dim
+            across_dim = list(range(len(activation_sum.size())))
             [across_dim.pop(i) for i in reversed(keeped_dim)]
-            metrics[name] = torch.mean(activations, across_dim)
+            metrics[name] = torch.mean(activation_sum, across_dim) / num
         return metrics
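The arithmetic behind the rewritten calculators can be checked by hand against the updated unit test below. With dim=1 kept, APoZ divides the summed zero counts by batch_number times the size of the reduced dimension, and MeanRank averages the accumulated activations and then divides by batch_number. A small standalone check, reusing the same tensors as the test:

import torch

# APoZRankMetricsCalculator with the [batch_number, zero_count_sum] format (dim=1 kept):
num = 2
zero_counts = torch.tensor([[0., 0., 1.], [0., 0., 0.]])
total_size = num * zero_counts.size(0)        # batch_number * reduced-dim size = 4
apoz = zero_counts.sum(dim=0) / total_size    # -> [0., 0., 0.25]
print(1 - apoz)                               # -> [1., 1., 0.75], the test's expected value for key '2'

# MeanRankMetricsCalculator with the [batch_number, activation_sum] format (dim=1 kept):
activation_sum = torch.tensor([[0., 1.], [1., 0.]])
print(activation_sum.mean(dim=0) / num)       # -> [0.25, 0.25], the test's expected value for key '1'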
test/ut/compression/v2/test_pruning_tools_torch.py (view file @ 6e643b00)

@@ -139,12 +139,12 @@ class PruningToolsTestCase(unittest.TestCase):
         # Test MultiDataNormMetricsCalculator
         metrics_calculator = MultiDataNormMetricsCalculator(dim=0, p=1)
         data = {
-            '1': [torch.ones(3, 3, 3), torch.ones(3, 3, 3) * 2],
-            '2': [torch.ones(4, 4), torch.ones(4, 4) * 2]
+            '1': [2, torch.ones(3, 3, 3) * 2],
+            '2': [2, torch.ones(4, 4) * 2]
         }
         result = {
-            '1': torch.ones(3) * 27,
-            '2': torch.ones(4) * 12
+            '1': torch.ones(3) * 18,
+            '2': torch.ones(4) * 8
         }
         metrics = metrics_calculator.calculate_metrics(data)
         assert all(torch.equal(result[k], v) for k, v in metrics.items())
@@ -152,12 +152,12 @@ class PruningToolsTestCase(unittest.TestCase):
         # Test APoZRankMetricsCalculator
         metrics_calculator = APoZRankMetricsCalculator(dim=1)
         data = {
-            '1': [torch.tensor([[1, 0], [0, 1]], dtype=torch.float32), torch.tensor([[0, 1], [1, 0]], dtype=torch.float32)],
-            '2': [torch.tensor([[1, 0, 1], [0, 1, 0]], dtype=torch.float32), torch.tensor([[0, 0, 1], [0, 0, 0]], dtype=torch.float32)]
+            '1': [2, torch.tensor([[1, 1], [1, 1]], dtype=torch.float32)],
+            '2': [2, torch.tensor([[0, 0, 1], [0, 0, 0]], dtype=torch.float32)]
         }
         result = {
-            '1': torch.tensor([0.5, 0.5], dtype=torch.float32),
-            '2': torch.tensor([0.25, 0.25, 0.5], dtype=torch.float32)
+            '1': torch.tensor([0.5, 0.5], dtype=torch.float32),
+            '2': torch.tensor([1, 1, 0.75], dtype=torch.float32)
         }
         metrics = metrics_calculator.calculate_metrics(data)
         assert all(torch.equal(result[k], v) for k, v in metrics.items())
@@ -165,12 +165,12 @@ class PruningToolsTestCase(unittest.TestCase):
         # Test MeanRankMetricsCalculator
         metrics_calculator = MeanRankMetricsCalculator(dim=1)
         data = {
-            '1': [torch.tensor([[1, 0], [0, 1]], dtype=torch.float32), torch.tensor([[0, 1], [1, 0]], dtype=torch.float32)],
-            '2': [torch.tensor([[1, 0, 1], [0, 1, 0]], dtype=torch.float32), torch.tensor([[0, 0, 1], [0, 0, 0]], dtype=torch.float32)]
+            '1': [2, torch.tensor([[0, 1], [1, 0]], dtype=torch.float32)],
+            '2': [2, torch.tensor([[0, 0, 1], [0, 0, 0]], dtype=torch.float32)]
         }
         result = {
-            '1': torch.tensor([0.5, 0.5], dtype=torch.float32),
-            '2': torch.tensor([0.25, 0.25, 0.5], dtype=torch.float32)
+            '1': torch.tensor([0.25, 0.25], dtype=torch.float32),
+            '2': torch.tensor([0, 0, 0.25], dtype=torch.float32)
         }
         metrics = metrics_calculator.calculate_metrics(data)
         assert all(torch.equal(result[k], v) for k, v in metrics.items())
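The new MultiDataNormMetricsCalculator expectations can be reproduced the same way: with dim=0 kept and p=1, the metric is simply the L1 norm of the accumulated tensor over the non-kept dimensions, and the batch count stored in the buffer is not used by this calculator. A quick standalone check of the values asserted above:

import torch

cumulative = torch.ones(3, 3, 3) * 2                                       # buffer[1] for key '1' in the updated test
assert torch.equal(cumulative.abs().sum(dim=(1, 2)), torch.ones(3) * 18)   # 9 elements * 2 per kept index

cumulative = torch.ones(4, 4) * 2                                          # buffer[1] for key '2'
assert torch.equal(cumulative.abs().sum(dim=1), torch.ones(4) * 8)         # 4 elements * 2 per kept index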