Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
0ab4916a
Unverified
Commit
0ab4916a
authored
May 20, 2022
by
Ming Yu
Committed by
GitHub
May 20, 2022
Browse files
[Model Compression] Add replace module (#4492)
parent
aa1f71c8
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
85 additions
and
0 deletions
+85
-0
nni/compression/pytorch/speedup/compress_modules.py
nni/compression/pytorch/speedup/compress_modules.py
+85
-0
No files found.
nni/compression/pytorch/speedup/compress_modules.py
View file @
0ab4916a
...
...
@@ -11,6 +11,7 @@ _logger = logging.getLogger(__name__)
replace_module
=
{
'BatchNorm2d'
:
lambda
module
,
masks
:
replace_batchnorm2d
(
module
,
masks
),
'BatchNorm1d'
:
lambda
module
,
masks
:
replace_batchnorm1d
(
module
,
masks
),
'InstanceNorm2d'
:
lambda
module
,
masks
:
replace_instancenorm2d
(
module
,
masks
),
'Conv2d'
:
lambda
module
,
masks
:
replace_conv2d
(
module
,
masks
),
'Linear'
:
lambda
module
,
masks
:
replace_linear
(
module
,
masks
),
'MaxPool2d'
:
lambda
module
,
masks
:
no_replace
(
module
,
masks
),
...
...
@@ -43,6 +44,7 @@ replace_module = {
'Upsample'
:
lambda
module
,
masks
:
no_replace
(
module
,
masks
),
'LayerNorm'
:
lambda
module
,
masks
:
replace_layernorm
(
module
,
masks
),
'ConvTranspose2d'
:
lambda
module
,
masks
:
replace_convtranspose2d
(
module
,
masks
),
'PixelShuffle'
:
lambda
module
,
masks
:
replace_pixelshuffle
(
module
,
masks
),
'Flatten'
:
lambda
module
,
masks
:
no_replace
(
module
,
masks
)
}
...
...
@@ -280,6 +282,51 @@ def replace_batchnorm2d(norm, masks):
return
new_norm
def
replace_instancenorm2d
(
norm
,
masks
):
"""
Parameters
----------
norm : torch.nn.InstanceNorm2d
The instancenorm module to be replace
masks : Tuple of the input masks, output masks and weight masks
Tuple of the masks, for example
([input_m1, input_m2], [output_m], {'weight':weight_m})
Returns
-------
torch.nn.InstanceNorm2d
The new instancenorm module
"""
in_masks
,
output_mask
,
_
=
masks
assert
isinstance
(
norm
,
nn
.
InstanceNorm2d
)
in_mask
=
in_masks
[
0
]
# N, C, H, W
_
,
remained_in
=
convert_to_coarse_mask
(
in_mask
,
1
)
_
,
remained_out
=
convert_to_coarse_mask
(
output_mask
,
1
)
if
remained_in
.
size
(
0
)
!=
remained_out
.
size
(
0
):
raise
ShapeMisMatchError
()
num_features
=
remained_in
.
size
(
0
)
_logger
.
info
(
"replace instancenorm2d with num_features: %d"
,
num_features
)
new_norm
=
torch
.
nn
.
InstanceNorm2d
(
num_features
=
num_features
,
eps
=
norm
.
eps
,
momentum
=
norm
.
momentum
,
affine
=
norm
.
affine
,
track_running_stats
=
norm
.
track_running_stats
)
# assign weights
if
norm
.
affine
:
new_norm
.
weight
.
data
=
torch
.
index_select
(
norm
.
weight
.
data
,
0
,
remained_in
)
new_norm
.
bias
.
data
=
torch
.
index_select
(
norm
.
bias
.
data
,
0
,
remained_in
)
if
norm
.
track_running_stats
:
new_norm
.
running_mean
.
data
=
torch
.
index_select
(
norm
.
running_mean
.
data
,
0
,
remained_in
)
new_norm
.
running_var
.
data
=
torch
.
index_select
(
norm
.
running_var
.
data
,
0
,
remained_in
)
return
new_norm
def
replace_conv2d
(
conv
,
masks
):
"""
Replace the original conv with a new one according to the infered
...
...
@@ -544,3 +591,41 @@ def replace_layernorm(layernorm, masks):
new_shape
.
append
(
n_remained
)
return
nn
.
LayerNorm
(
tuple
(
new_shape
),
layernorm
.
eps
,
layernorm
.
elementwise_affine
)
def
replace_pixelshuffle
(
pixelshuffle
,
masks
):
"""
Parameters
----------
norm : torch.nn.PixelShuffle
The pixelshuffle module to be replace
masks : Tuple of the input masks, output masks and weight masks
Tuple of the masks, for example
([input_m1, input_m2], [output_m], {'weight':weight_m})
Returns
-------
torch.nn.PixelShuffle
The new pixelshuffle module
"""
in_masks
,
output_mask
,
_
=
masks
assert
isinstance
(
pixelshuffle
,
torch
.
nn
.
PixelShuffle
)
if
len
(
in_masks
)
!=
1
:
raise
InputsNumberError
()
in_mask
=
in_masks
[
0
]
# N, C, H, W
_
,
remained_in
=
convert_to_coarse_mask
(
in_mask
,
1
)
_
,
remained_out
=
convert_to_coarse_mask
(
output_mask
,
1
)
upscale_factor
=
pixelshuffle
.
upscale_factor
if
remained_in
.
size
(
0
)
%
(
upscale_factor
*
upscale_factor
):
_logger
.
debug
(
"Shape mismatch, remained_in:%d upscale_factor:%d"
,
remained_in
.
size
(
0
),
remained_out
.
size
(
0
))
raise
ShapeMisMatchError
()
if
remained_out
.
size
(
0
)
*
upscale_factor
*
upscale_factor
!=
remained_in
:
raise
ShapeMisMatchError
()
new_pixelshuffle
=
torch
.
nn
.
PixelShuffle
(
upscale_factor
)
return
new_pixelshuffle
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment