OpenDAS / nni / Commits / 88925314

Unverified commit 88925314, authored Oct 22, 2021 by J-shang, committed by GitHub on Oct 22, 2021
[Bugbash] fix bug in compression (#4259)
parent a55a5559
Showing 5 changed files with 19 additions and 8 deletions (+19 -8)
examples/model_compress/pruning/v2/iterative_pruning_torch.py   +3 -1
examples/model_compress/pruning/v2/scheduler_torch.py           +3 -1
examples/model_compress/pruning/v2/simple_pruning_torch.py      +4 -2
nni/algorithms/compression/pytorch/pruning/iterative_pruner.py  +1 -0
nni/compression/pytorch/speedup/compress_modules.py             +8 -4
examples/model_compress/pruning/v2/iterative_pruning_torch.py

+import sys
 from tqdm import tqdm
 import torch
 ...
@@ -5,7 +6,8 @@ from torchvision import datasets, transforms
 from nni.algorithms.compression.v2.pytorch.pruning import AGPPruner
-from examples.model_compress.models.cifar10.vgg import VGG
+sys.path.append('../../models')
+from cifar10.vgg import VGG
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 ...
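
Note: the same import fix appears in all three example scripts in this commit. The old package-style import (from examples.model_compress.models.cifar10.vgg import VGG) only resolves when the repository root is on sys.path; the scripts now add the models directory explicitly so they can be launched directly from examples/model_compress/pruning/v2. A minimal sketch of the resulting pattern (the relative path is taken from the diff above and assumes the script is run from its own directory):

import sys

# Make the bundled example models importable without treating the repo as a package.
sys.path.append('../../models')
from cifar10.vgg import VGG  # resolves to examples/model_compress/models/cifar10/vgg.py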

examples/model_compress/pruning/v2/scheduler_torch.py

+import sys
 from tqdm import tqdm
 import torch
 ...
@@ -7,7 +8,8 @@ from nni.algorithms.compression.v2.pytorch.pruning import L1NormPruner
 from nni.algorithms.compression.v2.pytorch.pruning.tools import AGPTaskGenerator
 from nni.algorithms.compression.v2.pytorch.pruning.basic_scheduler import PruningScheduler
-from examples.model_compress.models.cifar10.vgg import VGG
+sys.path.append('../../models')
+from cifar10.vgg import VGG
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 ...

examples/model_compress/pruning/v2/simple_pruning_torch.py

+import sys
 from tqdm import tqdm
 import torch
 ...
@@ -6,7 +7,8 @@ from torchvision import datasets, transforms
 from nni.algorithms.compression.v2.pytorch.pruning import L1NormPruner
 from nni.compression.pytorch.speedup import ModelSpeedup
-from examples.model_compress.models.cifar10.vgg import VGG
+sys.path.append('../../models')
+from cifar10.vgg import VGG
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 ...
@@ -72,7 +74,7 @@ if __name__ == '__main__':
     evaluator(model)
     pruner._unwrap_model()
-    ModelSpeedup(model, dummy_input=torch.rand(10, 3, 32, 32).to(device), masks_file='simple_masks.pth').speedup_model()
+    ModelSpeedup(model, dummy_input=torch.rand(10, 3, 32, 32).to(device), masks_file=masks).speedup_model()
     print('\nThe accuracy after speed up:')
     evaluator(model)
 ...
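
Note on the last hunk: the previous code pointed ModelSpeedup at a 'simple_masks.pth' file that this example never writes; passing the in-memory mask dict returned by the pruner removes that dependency. A self-contained sketch of the flow under that assumption; TinyNet is a stand-in model for illustration and is not part of the example:

import torch
import torch.nn as nn
from nni.algorithms.compression.v2.pytorch.pruning import L1NormPruner
from nni.compression.pytorch.speedup import ModelSpeedup

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(16)
        self.conv2 = nn.Conv2d(16, 8, 3, padding=1)

    def forward(self, x):
        x = torch.relu(self.bn1(self.conv1(x)))
        return self.conv2(x).mean(dim=(2, 3))

model = TinyNet().to(device)
config_list = [{'op_names': ['conv1'], 'sparsity': 0.5}]

pruner = L1NormPruner(model, config_list)
_, masks = pruner.compress()   # masks is a dict: {layer_name: {'weight': mask_tensor, ...}}
pruner._unwrap_model()         # strip the pruning wrappers before speedup

# Pass the dict directly instead of a masks_file path on disk.
ModelSpeedup(model, dummy_input=torch.rand(10, 3, 32, 32).to(device),
             masks_file=masks).speedup_model()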

nni/algorithms/compression/pytorch/pruning/iterative_pruner.py

 ...
@@ -384,6 +384,7 @@ class ADMMPruner(IterativePruner):
             for i, wrapper in enumerate(self.get_modules_wrapper()):
                 z = wrapper.module.weight.data + self.U[i]
                 self.Z[i] = self._projection(z, wrapper.config['sparsity'], wrapper)
+                torch.cuda.empty_cache()
                 self.U[i] = self.U[i] + wrapper.module.weight.data - self.Z[i]

         # apply prune
 ...
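
Note: the single added line, torch.cuda.empty_cache(), returns cached allocator blocks to the GPU between per-layer ADMM projections, which lowers peak memory once a projection's temporaries are no longer referenced. A generic sketch of that pattern (not the ADMMPruner code itself):

import torch

if torch.cuda.is_available():
    a = torch.randn(4096, 4096, device='cuda')
    b = a @ a                    # large temporary result
    total = b.sum().item()
    del a, b                     # drop the references first...
    torch.cuda.empty_cache()     # ...then release the cached blocks back to the driver
    print(total)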

nni/compression/pytorch/speedup/compress_modules.py

 ...
@@ -110,6 +110,8 @@ def replace_prelu(prelu, masks):
     in_mask = in_masks[0]
     weight_mask = weight_mask['weight']
+    if weight_mask.size(0) == 1:
+        return prelu
     pruned_in, remained_in = convert_to_coarse_mask(in_mask, 1)
     pruned_out, remained_out = convert_to_coarse_mask(output_mask, 1)
     n_remained_in = weight_mask.size(0) - pruned_in.size(0)
 ...
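
Note on the early return added above: nn.PReLU defaults to a single slope shared by every channel (num_parameters=1), so its weight has shape [1] no matter how many channels survive pruning and the module can be reused as-is; only a per-channel PReLU needs its weight narrowed. A small illustration:

import torch.nn as nn

shared = nn.PReLU()                        # weight.shape == torch.Size([1]) -> returned unchanged
per_channel = nn.PReLU(num_parameters=16)  # weight.shape == torch.Size([16]) -> must be narrowed
print(shared.weight.shape, per_channel.weight.shape)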

@@ -221,8 +223,9 @@ def replace_batchnorm1d(norm, masks):
                              affine=norm.affine, track_running_stats=norm.track_running_stats)
     # assign weights
-    new_norm.weight.data = torch.index_select(norm.weight.data, 0, remained_in)
-    new_norm.bias.data = torch.index_select(norm.bias.data, 0, remained_in)
+    if norm.affine:
+        new_norm.weight.data = torch.index_select(norm.weight.data, 0, remained_in)
+        new_norm.bias.data = torch.index_select(norm.bias.data, 0, remained_in)
     new_norm.running_mean.data = torch.index_select(norm.running_mean.data, 0, remained_in)
 ...

@@ -264,8 +267,9 @@ def replace_batchnorm2d(norm, masks):
                              affine=norm.affine, track_running_stats=norm.track_running_stats)
     # assign weights
-    new_norm.weight.data = torch.index_select(norm.weight.data, 0, remained_in)
-    new_norm.bias.data = torch.index_select(norm.bias.data, 0, remained_in)
+    if norm.affine:
+        new_norm.weight.data = torch.index_select(norm.weight.data, 0, remained_in)
+        new_norm.bias.data = torch.index_select(norm.bias.data, 0, remained_in)
     new_norm.running_mean.data = torch.index_select(norm.running_mean.data, 0, remained_in)
 ...
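
Note on the new norm.affine guard (the same change is made in replace_batchnorm1d and replace_batchnorm2d): a BatchNorm module constructed with affine=False has weight and bias set to None, so calling index_select on norm.weight.data raises an AttributeError; only the running statistics exist and can be copied. A small illustration:

import torch.nn as nn

bn = nn.BatchNorm2d(8, affine=False)
print(bn.weight, bn.bias)        # None None -> nothing to index_select
print(bn.running_mean.shape)     # torch.Size([8]); running stats are still tracked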