Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
6dfdc546
Unverified
Commit
6dfdc546
authored
Dec 13, 2021
by
J-shang
Committed by
GitHub
Dec 13, 2021
Browse files
[Compression v2] Add optimizer & lr scheduler construct helper (#4332)
parent
7978c25a
Changes
25
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
146 additions
and
17 deletions
+146
-17
nni/algorithms/compression/v2/pytorch/utils/constructor_helper.py
...rithms/compression/v2/pytorch/utils/constructor_helper.py
+126
-0
nni/common/serializer.py
nni/common/serializer.py
+2
-2
test/ut/compression/v2/test_iterative_pruner_torch.py
test/ut/compression/v2/test_iterative_pruner_torch.py
+3
-3
test/ut/compression/v2/test_pruner_torch.py
test/ut/compression/v2/test_pruner_torch.py
+8
-8
test/ut/compression/v2/test_pruning_tools_torch.py
test/ut/compression/v2/test_pruning_tools_torch.py
+7
-4
No files found.
nni/algorithms/compression/v2/pytorch/utils/constructor_helper.py
0 → 100644
View file @
6dfdc546
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from
copy
import
deepcopy
from
typing
import
Callable
,
Dict
,
List
,
Type
from
torch
import
Tensor
from
torch.nn
import
Module
from
torch.optim
import
Optimizer
from
torch.optim.lr_scheduler
import
_LRScheduler
from
nni.common.serializer
import
_trace_cls
from
nni.common.serializer
import
Traceable
__all__
=
[
'OptimizerConstructHelper'
,
'LRSchedulerConstructHelper'
,
'trace_parameters'
]
def
trace_parameters
(
base
,
kw_only
=
True
):
if
not
isinstance
(
base
,
type
):
raise
Exception
(
'Only class can be traced by this function.'
)
return
_trace_cls
(
base
,
kw_only
,
call_super
=
False
)
class
ConstructHelper
:
def
__init__
(
self
,
callable_obj
:
Callable
,
*
args
,
**
kwargs
):
assert
callable
(
callable_obj
),
'`callable_obj` must be a callable object.'
self
.
callable_obj
=
callable_obj
self
.
args
=
deepcopy
(
args
)
self
.
kwargs
=
deepcopy
(
kwargs
)
def
call
(
self
):
args
=
deepcopy
(
self
.
args
)
kwargs
=
deepcopy
(
self
.
kwargs
)
return
self
.
callable_obj
(
*
args
,
**
kwargs
)
class
OptimizerConstructHelper
(
ConstructHelper
):
def
__init__
(
self
,
model
:
Module
,
optimizer_class
:
Type
[
Optimizer
],
*
args
,
**
kwargs
):
assert
isinstance
(
model
,
Module
),
'Only support pytorch module.'
assert
issubclass
(
optimizer_class
,
Optimizer
),
'Only support pytorch optimizer'
args
=
list
(
args
)
if
'params'
in
kwargs
:
kwargs
[
'params'
]
=
self
.
params2names
(
model
,
kwargs
[
'params'
])
else
:
args
[
0
]
=
self
.
params2names
(
model
,
args
[
0
])
super
().
__init__
(
optimizer_class
,
*
args
,
**
kwargs
)
def
params2names
(
self
,
model
:
Module
,
params
:
List
)
->
List
[
Dict
]:
param_groups
=
list
(
params
)
assert
len
(
param_groups
)
>
0
if
not
isinstance
(
param_groups
[
0
],
dict
):
param_groups
=
[{
'params'
:
param_groups
}]
for
param_group
in
param_groups
:
params
=
param_group
[
'params'
]
if
isinstance
(
params
,
Tensor
):
params
=
[
params
]
elif
isinstance
(
params
,
set
):
raise
TypeError
(
'optimizer parameters need to be organized in ordered collections, but '
'the ordering of tensors in sets will change between runs. Please use a list instead.'
)
else
:
params
=
list
(
params
)
param_ids
=
[
id
(
p
)
for
p
in
params
]
param_group
[
'params'
]
=
[
name
for
name
,
p
in
model
.
named_parameters
()
if
id
(
p
)
in
param_ids
]
return
param_groups
def
names2params
(
self
,
wrapped_model
:
Module
,
origin2wrapped_name_map
:
Dict
,
params
:
List
[
Dict
])
->
List
[
Dict
]:
param_groups
=
deepcopy
(
params
)
for
param_group
in
param_groups
:
wrapped_names
=
[
origin2wrapped_name_map
.
get
(
name
,
name
)
for
name
in
param_group
[
'params'
]]
param_group
[
'params'
]
=
[
p
for
name
,
p
in
wrapped_model
.
named_parameters
()
if
name
in
wrapped_names
]
return
param_groups
def
call
(
self
,
wrapped_model
:
Module
,
origin2wrapped_name_map
:
Dict
)
->
Optimizer
:
args
=
deepcopy
(
self
.
args
)
kwargs
=
deepcopy
(
self
.
kwargs
)
if
'params'
in
kwargs
:
kwargs
[
'params'
]
=
self
.
names2params
(
wrapped_model
,
origin2wrapped_name_map
,
kwargs
[
'params'
])
else
:
args
[
0
]
=
self
.
names2params
(
wrapped_model
,
origin2wrapped_name_map
,
args
[
0
])
return
self
.
callable_obj
(
*
args
,
**
kwargs
)
@
staticmethod
def
from_trace
(
model
:
Module
,
optimizer_trace
:
Traceable
):
assert
isinstance
(
optimizer_trace
,
Traceable
),
\
'Please use nni.algorithms.compression.v2.pytorch.utils.trace_parameters to wrap the optimizer class before initialize the optimizer.'
assert
isinstance
(
optimizer_trace
,
Optimizer
),
\
'It is not an instance of torch.nn.Optimizer.'
return
OptimizerConstructHelper
(
model
,
optimizer_trace
.
_get_nni_attr
(
'symbol'
),
*
optimizer_trace
.
_get_nni_attr
(
'args'
),
**
optimizer_trace
.
_get_nni_attr
(
'kwargs'
))
class
LRSchedulerConstructHelper
(
ConstructHelper
):
def
__init__
(
self
,
lr_scheduler_class
:
Type
[
_LRScheduler
],
*
args
,
**
kwargs
):
args
=
list
(
args
)
if
'optimizer'
in
kwargs
:
kwargs
[
'optimizer'
]
=
None
else
:
args
[
0
]
=
None
super
().
__init__
(
lr_scheduler_class
,
*
args
,
**
kwargs
)
def
call
(
self
,
optimizer
:
Optimizer
)
->
_LRScheduler
:
args
=
deepcopy
(
self
.
args
)
kwargs
=
deepcopy
(
self
.
kwargs
)
if
'optimizer'
in
kwargs
:
kwargs
[
'optimizer'
]
=
optimizer
else
:
args
[
0
]
=
optimizer
return
self
.
callable_obj
(
*
args
,
**
kwargs
)
@
staticmethod
def
from_trace
(
lr_scheduler_trace
:
Traceable
):
assert
isinstance
(
lr_scheduler_trace
,
Traceable
),
\
'Please use nni.algorithms.compression.v2.pytorch.utils.trace_parameters to wrap the lr scheduler class before initialize the scheduler.'
assert
isinstance
(
lr_scheduler_trace
,
_LRScheduler
),
\
'It is not an instance of torch.nn.lr_scheduler._LRScheduler.'
return
LRSchedulerConstructHelper
(
lr_scheduler_trace
.
trace_symbol
,
*
lr_scheduler_trace
.
trace_args
,
**
lr_scheduler_trace
.
trace_kwargs
)
nni/common/serializer.py
View file @
6dfdc546
...
@@ -344,7 +344,7 @@ def load(string: Optional[str] = None, *, fp: Optional[Any] = None, ignore_comme
...
@@ -344,7 +344,7 @@ def load(string: Optional[str] = None, *, fp: Optional[Any] = None, ignore_comme
return
json_tricks
.
load
(
fp
,
obj_pairs_hooks
=
hooks
,
**
json_tricks_kwargs
)
return
json_tricks
.
load
(
fp
,
obj_pairs_hooks
=
hooks
,
**
json_tricks_kwargs
)
def
_trace_cls
(
base
,
kw_only
):
def
_trace_cls
(
base
,
kw_only
,
call_super
=
True
):
# the implementation to trace a class is to store a copy of init arguments
# the implementation to trace a class is to store a copy of init arguments
# this won't support class that defines a customized new but should work for most cases
# this won't support class that defines a customized new but should work for most cases
...
@@ -354,7 +354,7 @@ def _trace_cls(base, kw_only):
...
@@ -354,7 +354,7 @@ def _trace_cls(base, kw_only):
args
,
kwargs
=
_formulate_arguments
(
base
.
__init__
,
args
,
kwargs
,
kw_only
,
is_class_init
=
True
)
args
,
kwargs
=
_formulate_arguments
(
base
.
__init__
,
args
,
kwargs
,
kw_only
,
is_class_init
=
True
)
# calling serializable object init to initialize the full object
# calling serializable object init to initialize the full object
super
().
__init__
(
symbol
=
base
,
args
=
args
,
kwargs
=
kwargs
,
call_super
=
True
)
super
().
__init__
(
symbol
=
base
,
args
=
args
,
kwargs
=
kwargs
,
call_super
=
call_super
)
_copy_class_wrapper_attributes
(
base
,
wrapper
)
_copy_class_wrapper_attributes
(
base
,
wrapper
)
...
...
test/ut/compression/v2/test_iterative_pruner_torch.py
View file @
6dfdc546
...
@@ -15,7 +15,7 @@ from nni.algorithms.compression.v2.pytorch.pruning import (
...
@@ -15,7 +15,7 @@ from nni.algorithms.compression.v2.pytorch.pruning import (
AutoCompressPruner
AutoCompressPruner
)
)
from
nni.algorithms.compression.v2.pytorch.utils
import
compute_sparsity_mask2compact
from
nni.algorithms.compression.v2.pytorch.utils
import
compute_sparsity_mask2compact
,
trace_parameters
class
TorchModel
(
torch
.
nn
.
Module
):
class
TorchModel
(
torch
.
nn
.
Module
):
...
@@ -52,7 +52,7 @@ def trainer(model, optimizer, criterion):
...
@@ -52,7 +52,7 @@ def trainer(model, optimizer, criterion):
def
get_optimizer
(
model
):
def
get_optimizer
(
model
):
return
torch
.
optim
.
SGD
(
model
.
parameters
(),
lr
=
0.1
,
momentum
=
0.9
,
weight_decay
=
5e-4
)
return
trace_parameters
(
torch
.
optim
.
SGD
)
(
model
.
parameters
(),
lr
=
0.1
,
momentum
=
0.9
,
weight_decay
=
5e-4
)
criterion
=
torch
.
nn
.
CrossEntropyLoss
()
criterion
=
torch
.
nn
.
CrossEntropyLoss
()
...
@@ -104,7 +104,7 @@ class IterativePrunerTestCase(unittest.TestCase):
...
@@ -104,7 +104,7 @@ class IterativePrunerTestCase(unittest.TestCase):
config_list
=
[{
'op_types'
:
[
'Conv2d'
],
'sparsity'
:
0.8
}]
config_list
=
[{
'op_types'
:
[
'Conv2d'
],
'sparsity'
:
0.8
}]
admm_params
=
{
admm_params
=
{
'trainer'
:
trainer
,
'trainer'
:
trainer
,
'optimizer'
:
get_optimizer
(
model
),
'
traced_
optimizer'
:
get_optimizer
(
model
),
'criterion'
:
criterion
,
'criterion'
:
criterion
,
'iterations'
:
10
,
'iterations'
:
10
,
'training_epochs'
:
1
'training_epochs'
:
1
...
...
test/ut/compression/v2/test_pruner_torch.py
View file @
6dfdc546
...
@@ -18,7 +18,7 @@ from nni.algorithms.compression.v2.pytorch.pruning import (
...
@@ -18,7 +18,7 @@ from nni.algorithms.compression.v2.pytorch.pruning import (
ADMMPruner
,
ADMMPruner
,
MovementPruner
MovementPruner
)
)
from
nni.algorithms.compression.v2.pytorch.utils
import
compute_sparsity_mask2compact
from
nni.algorithms.compression.v2.pytorch.utils
import
compute_sparsity_mask2compact
,
trace_parameters
class
TorchModel
(
torch
.
nn
.
Module
):
class
TorchModel
(
torch
.
nn
.
Module
):
...
@@ -55,7 +55,7 @@ def trainer(model, optimizer, criterion):
...
@@ -55,7 +55,7 @@ def trainer(model, optimizer, criterion):
def
get_optimizer
(
model
):
def
get_optimizer
(
model
):
return
torch
.
optim
.
SGD
(
model
.
parameters
(),
lr
=
0.1
,
momentum
=
0.9
,
weight_decay
=
5e-4
)
return
trace_parameters
(
torch
.
optim
.
SGD
)
(
model
.
parameters
(),
lr
=
0.1
,
momentum
=
0.9
,
weight_decay
=
5e-4
)
criterion
=
torch
.
nn
.
CrossEntropyLoss
()
criterion
=
torch
.
nn
.
CrossEntropyLoss
()
...
@@ -104,7 +104,7 @@ class PrunerTestCase(unittest.TestCase):
...
@@ -104,7 +104,7 @@ class PrunerTestCase(unittest.TestCase):
def
test_slim_pruner
(
self
):
def
test_slim_pruner
(
self
):
model
=
TorchModel
()
model
=
TorchModel
()
config_list
=
[{
'op_types'
:
[
'BatchNorm2d'
],
'total_sparsity'
:
0.8
}]
config_list
=
[{
'op_types'
:
[
'BatchNorm2d'
],
'total_sparsity'
:
0.8
}]
pruner
=
SlimPruner
(
model
=
model
,
config_list
=
config_list
,
trainer
=
trainer
,
optimizer
=
get_optimizer
(
model
),
pruner
=
SlimPruner
(
model
=
model
,
config_list
=
config_list
,
trainer
=
trainer
,
traced_
optimizer
=
get_optimizer
(
model
),
criterion
=
criterion
,
training_epochs
=
1
,
scale
=
0.001
,
mode
=
'global'
)
criterion
=
criterion
,
training_epochs
=
1
,
scale
=
0.001
,
mode
=
'global'
)
pruned_model
,
masks
=
pruner
.
compress
()
pruned_model
,
masks
=
pruner
.
compress
()
pruner
.
_unwrap_model
()
pruner
.
_unwrap_model
()
...
@@ -115,7 +115,7 @@ class PrunerTestCase(unittest.TestCase):
...
@@ -115,7 +115,7 @@ class PrunerTestCase(unittest.TestCase):
model
=
TorchModel
()
model
=
TorchModel
()
config_list
=
[{
'op_types'
:
[
'Conv2d'
],
'sparsity'
:
0.8
}]
config_list
=
[{
'op_types'
:
[
'Conv2d'
],
'sparsity'
:
0.8
}]
pruner
=
ActivationAPoZRankPruner
(
model
=
model
,
config_list
=
config_list
,
trainer
=
trainer
,
pruner
=
ActivationAPoZRankPruner
(
model
=
model
,
config_list
=
config_list
,
trainer
=
trainer
,
optimizer
=
get_optimizer
(
model
),
criterion
=
criterion
,
training_batches
=
1
,
traced_
optimizer
=
get_optimizer
(
model
),
criterion
=
criterion
,
training_batches
=
5
,
activation
=
'relu'
,
mode
=
'dependency_aware'
,
activation
=
'relu'
,
mode
=
'dependency_aware'
,
dummy_input
=
torch
.
rand
(
10
,
1
,
28
,
28
))
dummy_input
=
torch
.
rand
(
10
,
1
,
28
,
28
))
pruned_model
,
masks
=
pruner
.
compress
()
pruned_model
,
masks
=
pruner
.
compress
()
...
@@ -127,7 +127,7 @@ class PrunerTestCase(unittest.TestCase):
...
@@ -127,7 +127,7 @@ class PrunerTestCase(unittest.TestCase):
model
=
TorchModel
()
model
=
TorchModel
()
config_list
=
[{
'op_types'
:
[
'Conv2d'
],
'sparsity'
:
0.8
}]
config_list
=
[{
'op_types'
:
[
'Conv2d'
],
'sparsity'
:
0.8
}]
pruner
=
ActivationMeanRankPruner
(
model
=
model
,
config_list
=
config_list
,
trainer
=
trainer
,
pruner
=
ActivationMeanRankPruner
(
model
=
model
,
config_list
=
config_list
,
trainer
=
trainer
,
optimizer
=
get_optimizer
(
model
),
criterion
=
criterion
,
training_batches
=
1
,
traced_
optimizer
=
get_optimizer
(
model
),
criterion
=
criterion
,
training_batches
=
5
,
activation
=
'relu'
,
mode
=
'dependency_aware'
,
activation
=
'relu'
,
mode
=
'dependency_aware'
,
dummy_input
=
torch
.
rand
(
10
,
1
,
28
,
28
))
dummy_input
=
torch
.
rand
(
10
,
1
,
28
,
28
))
pruned_model
,
masks
=
pruner
.
compress
()
pruned_model
,
masks
=
pruner
.
compress
()
...
@@ -139,7 +139,7 @@ class PrunerTestCase(unittest.TestCase):
...
@@ -139,7 +139,7 @@ class PrunerTestCase(unittest.TestCase):
model
=
TorchModel
()
model
=
TorchModel
()
config_list
=
[{
'op_types'
:
[
'Conv2d'
],
'sparsity'
:
0.8
}]
config_list
=
[{
'op_types'
:
[
'Conv2d'
],
'sparsity'
:
0.8
}]
pruner
=
TaylorFOWeightPruner
(
model
=
model
,
config_list
=
config_list
,
trainer
=
trainer
,
pruner
=
TaylorFOWeightPruner
(
model
=
model
,
config_list
=
config_list
,
trainer
=
trainer
,
optimizer
=
get_optimizer
(
model
),
criterion
=
criterion
,
training_batches
=
1
,
traced_
optimizer
=
get_optimizer
(
model
),
criterion
=
criterion
,
training_batches
=
5
,
mode
=
'dependency_aware'
,
dummy_input
=
torch
.
rand
(
10
,
1
,
28
,
28
))
mode
=
'dependency_aware'
,
dummy_input
=
torch
.
rand
(
10
,
1
,
28
,
28
))
pruned_model
,
masks
=
pruner
.
compress
()
pruned_model
,
masks
=
pruner
.
compress
()
pruner
.
_unwrap_model
()
pruner
.
_unwrap_model
()
...
@@ -149,7 +149,7 @@ class PrunerTestCase(unittest.TestCase):
...
@@ -149,7 +149,7 @@ class PrunerTestCase(unittest.TestCase):
def
test_admm_pruner
(
self
):
def
test_admm_pruner
(
self
):
model
=
TorchModel
()
model
=
TorchModel
()
config_list
=
[{
'op_types'
:
[
'Conv2d'
],
'sparsity'
:
0.8
,
'rho'
:
1e-3
}]
config_list
=
[{
'op_types'
:
[
'Conv2d'
],
'sparsity'
:
0.8
,
'rho'
:
1e-3
}]
pruner
=
ADMMPruner
(
model
=
model
,
config_list
=
config_list
,
trainer
=
trainer
,
optimizer
=
get_optimizer
(
model
),
pruner
=
ADMMPruner
(
model
=
model
,
config_list
=
config_list
,
trainer
=
trainer
,
traced_
optimizer
=
get_optimizer
(
model
),
criterion
=
criterion
,
iterations
=
2
,
training_epochs
=
1
)
criterion
=
criterion
,
iterations
=
2
,
training_epochs
=
1
)
pruned_model
,
masks
=
pruner
.
compress
()
pruned_model
,
masks
=
pruner
.
compress
()
pruner
.
_unwrap_model
()
pruner
.
_unwrap_model
()
...
@@ -159,7 +159,7 @@ class PrunerTestCase(unittest.TestCase):
...
@@ -159,7 +159,7 @@ class PrunerTestCase(unittest.TestCase):
def
test_movement_pruner
(
self
):
def
test_movement_pruner
(
self
):
model
=
TorchModel
()
model
=
TorchModel
()
config_list
=
[{
'op_types'
:
[
'Conv2d'
],
'sparsity'
:
0.8
}]
config_list
=
[{
'op_types'
:
[
'Conv2d'
],
'sparsity'
:
0.8
}]
pruner
=
MovementPruner
(
model
=
model
,
config_list
=
config_list
,
trainer
=
trainer
,
optimizer
=
get_optimizer
(
model
),
pruner
=
MovementPruner
(
model
=
model
,
config_list
=
config_list
,
trainer
=
trainer
,
traced_
optimizer
=
get_optimizer
(
model
),
criterion
=
criterion
,
training_epochs
=
5
,
warm_up_step
=
0
,
cool_down_beginning_step
=
4
)
criterion
=
criterion
,
training_epochs
=
5
,
warm_up_step
=
0
,
cool_down_beginning_step
=
4
)
pruned_model
,
masks
=
pruner
.
compress
()
pruned_model
,
masks
=
pruner
.
compress
()
pruner
.
_unwrap_model
()
pruner
.
_unwrap_model
()
...
...
test/ut/compression/v2/test_pruning_tools_torch.py
View file @
6dfdc546
...
@@ -24,7 +24,8 @@ from nni.algorithms.compression.v2.pytorch.pruning.tools import (
...
@@ -24,7 +24,8 @@ from nni.algorithms.compression.v2.pytorch.pruning.tools import (
GlobalSparsityAllocator
GlobalSparsityAllocator
)
)
from
nni.algorithms.compression.v2.pytorch.pruning.tools.base
import
HookCollectorInfo
from
nni.algorithms.compression.v2.pytorch.pruning.tools.base
import
HookCollectorInfo
from
nni.algorithms.compression.v2.pytorch.utils
import
get_module_by_name
from
nni.algorithms.compression.v2.pytorch.utils
import
get_module_by_name
,
trace_parameters
from
nni.algorithms.compression.v2.pytorch.utils.constructor_helper
import
OptimizerConstructHelper
class
TorchModel
(
torch
.
nn
.
Module
):
class
TorchModel
(
torch
.
nn
.
Module
):
...
@@ -61,7 +62,7 @@ def trainer(model, optimizer, criterion):
...
@@ -61,7 +62,7 @@ def trainer(model, optimizer, criterion):
def
get_optimizer
(
model
):
def
get_optimizer
(
model
):
return
torch
.
optim
.
SGD
(
model
.
parameters
(),
lr
=
0.1
,
momentum
=
0.9
,
weight_decay
=
5e-4
)
return
trace_parameters
(
torch
.
optim
.
SGD
)
(
model
.
parameters
(),
lr
=
0.1
,
momentum
=
0.9
,
weight_decay
=
5e-4
)
criterion
=
torch
.
nn
.
CrossEntropyLoss
()
criterion
=
torch
.
nn
.
CrossEntropyLoss
()
...
@@ -88,7 +89,8 @@ class PruningToolsTestCase(unittest.TestCase):
...
@@ -88,7 +89,8 @@ class PruningToolsTestCase(unittest.TestCase):
model
.
conv1
.
module
.
weight
.
data
=
torch
.
ones
(
5
,
1
,
5
,
5
)
model
.
conv1
.
module
.
weight
.
data
=
torch
.
ones
(
5
,
1
,
5
,
5
)
model
.
conv2
.
module
.
weight
.
data
=
torch
.
ones
(
10
,
5
,
5
,
5
)
model
.
conv2
.
module
.
weight
.
data
=
torch
.
ones
(
10
,
5
,
5
,
5
)
data_collector
=
WeightTrainerBasedDataCollector
(
pruner
,
trainer
,
get_optimizer
(
model
),
criterion
,
1
,
opt_after_tasks
=
[
opt_after
])
optimizer_helper
=
OptimizerConstructHelper
.
from_trace
(
model
,
get_optimizer
(
model
))
data_collector
=
WeightTrainerBasedDataCollector
(
pruner
,
trainer
,
optimizer_helper
,
criterion
,
1
,
opt_after_tasks
=
[
opt_after
])
data
=
data_collector
.
collect
()
data
=
data_collector
.
collect
()
assert
all
(
torch
.
equal
(
get_module_by_name
(
model
,
module_name
)[
1
].
module
.
weight
.
data
,
data
[
module_name
])
for
module_name
in
[
'conv1'
,
'conv2'
])
assert
all
(
torch
.
equal
(
get_module_by_name
(
model
,
module_name
)[
1
].
module
.
weight
.
data
,
data
[
module_name
])
for
module_name
in
[
'conv1'
,
'conv2'
])
assert
all
(
t
.
numel
()
==
(
t
==
1
).
type_as
(
t
).
sum
().
item
()
for
t
in
data
.
values
())
assert
all
(
t
.
numel
()
==
(
t
==
1
).
type_as
(
t
).
sum
().
item
()
for
t
in
data
.
values
())
...
@@ -102,7 +104,8 @@ class PruningToolsTestCase(unittest.TestCase):
...
@@ -102,7 +104,8 @@ class PruningToolsTestCase(unittest.TestCase):
hook_targets
=
{
'conv1'
:
model
.
conv1
.
module
.
weight
,
'conv2'
:
model
.
conv2
.
module
.
weight
}
hook_targets
=
{
'conv1'
:
model
.
conv1
.
module
.
weight
,
'conv2'
:
model
.
conv2
.
module
.
weight
}
collector_info
=
HookCollectorInfo
(
hook_targets
,
'tensor'
,
_collector
)
collector_info
=
HookCollectorInfo
(
hook_targets
,
'tensor'
,
_collector
)
data_collector
=
SingleHookTrainerBasedDataCollector
(
pruner
,
trainer
,
get_optimizer
(
model
),
criterion
,
2
,
collector_infos
=
[
collector_info
])
optimizer_helper
=
OptimizerConstructHelper
.
from_trace
(
model
,
get_optimizer
(
model
))
data_collector
=
SingleHookTrainerBasedDataCollector
(
pruner
,
trainer
,
optimizer_helper
,
criterion
,
2
,
collector_infos
=
[
collector_info
])
data
=
data_collector
.
collect
()
data
=
data_collector
.
collect
()
assert
all
(
len
(
t
)
==
2
for
t
in
data
.
values
())
assert
all
(
len
(
t
)
==
2
for
t
in
data
.
values
())
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment